Bring code up-to-date

This commit is contained in:
Scott Wallace 2024-11-27 09:50:37 +00:00
parent b939ad415d
commit 97512598f5
Signed by: scott
SSH key fingerprint: SHA256:+LJug6Dj01Jdg86CILGng9r0lJseUrpI0xfRqdW9Uws
8 changed files with 186 additions and 187 deletions

10
main.py
View file

@ -1,20 +1,18 @@
""" """
Main Flask-based app for Slinky Main Flask-based app for Slinky
""" """
from flask import Flask, Response, render_template from flask import Flask, Response, render_template
from flask_bootstrap import Bootstrap # type: ignore[import]
from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.middleware.proxy_fix import ProxyFix
from slinky.web import protect, slinky_webapp from slinky.web import protect, slinky_webapp
app = Flask(__name__) app = Flask(__name__)
app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1) # type: ignore[assignment] app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1)
app.register_blueprint(slinky_webapp) app.register_blueprint(slinky_webapp)
Bootstrap(app)
@app.route("/")
@app.route('/')
@protect @protect
def index() -> Response: def index() -> Response:
""" """
@ -23,4 +21,4 @@ def index() -> Response:
Returns: Returns:
str: string of page content str: string of page content
""" """
return Response(render_template('index.html'), 200) return Response(render_template("index.html"), 200)

View file

@ -1,5 +1,4 @@
flask flask
flask_bootstrap
flask_wtf flask_wtf
psycopg2-binary psycopg2-binary
pyyaml pyyaml

View file

@ -6,9 +6,6 @@ import random
import string import string
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Optional
import sqlalchemy # type: ignore[import]
from slinky import db from slinky import db
@ -38,7 +35,7 @@ def random_string(length: int = 4) -> str:
""" """
allowed_chars: str = string.ascii_letters + string.digits allowed_chars: str = string.ascii_letters + string.digits
return ''.join(random.SystemRandom().choice(allowed_chars) for _ in range(length)) return "".join(random.SystemRandom().choice(allowed_chars) for _ in range(length))
class Slinky: class Slinky:
@ -52,8 +49,8 @@ class Slinky:
def add( # pylint: disable=too-many-arguments def add( # pylint: disable=too-many-arguments
self, self,
shortcode: str = '', shortcode: str = "",
url: str = '', url: str = "",
length: int = 4, length: int = 4,
fixed_views: int = -1, fixed_views: int = -1,
expiry: datetime = datetime.max, expiry: datetime = datetime.max,
@ -78,7 +75,7 @@ class Slinky:
shortcode = random_string(length=length) shortcode = random_string(length=length)
if self.get_by_shortcode(shortcode).url: if self.get_by_shortcode(shortcode).url:
raise ValueError(f'Shortcode {shortcode} already exists') raise ValueError(f"Shortcode {shortcode} already exists")
dbentry = db.ShortURL( dbentry = db.ShortURL(
shortcode=shortcode, shortcode=shortcode,
@ -114,7 +111,7 @@ class Slinky:
) )
self.session.close() self.session.close()
return ret_sc return ret_sc
return Shortcode(0, '', '', 0, '1970-01-01 00:00:00.000000') return Shortcode(0, "", "", 0, "1970-01-01 00:00:00.000000")
def remove_view(self, sc_id: int) -> None: def remove_view(self, sc_id: int) -> None:
""" """

View file

@ -1,21 +1,22 @@
""" """
DB component DB component
""" """
from dataclasses import dataclass from dataclasses import dataclass
from sqlalchemy import Column, Integer, String, create_engine # type: ignore[import] from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base # type: ignore[import] from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker # type: ignore[import] from sqlalchemy.orm import Session, sessionmaker
Base = declarative_base() Base = declarative_base()
class ShortURL(Base): # type: ignore[misc, valid-type] # pylint: disable=too-few-public-methods class ShortURL(Base): # pylint: disable=too-few-public-methods
""" """
Class to describe the DB schema for ShortURLs Class to describe the DB schema for ShortURLs
""" """
__tablename__ = 'shorturl' __tablename__ = "shorturl"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
shortcode = Column(String(128), unique=True, nullable=False) shortcode = Column(String(128), unique=True, nullable=False)

View file

@ -10,71 +10,70 @@ from typing import Any, Callable
import yaml import yaml
from flask import Blueprint, Response, render_template, request from flask import Blueprint, Response, render_template, request
from flask_wtf import FlaskForm # type: ignore[import] from flask_wtf import FlaskForm
from wtforms import HiddenField # type: ignore[import] from wtforms import DateTimeLocalField, HiddenField, IntegerField, StringField
from wtforms import DateTimeLocalField, IntegerField, StringField from wtforms.validators import DataRequired, Length
from wtforms.validators import DataRequired, Length # type: ignore[import]
from slinky import Slinky, random_string from slinky import Slinky, random_string
slinky_webapp = Blueprint('webapp', __name__, template_folder='templates') slinky_webapp = Blueprint("webapp", __name__, template_folder="templates")
with open('config.yaml', encoding='utf-8-sig') as conffile: with open("config.yaml", encoding="utf-8-sig") as conffile:
cfg = yaml.safe_load(conffile) cfg = yaml.safe_load(conffile)
class DelForm(FlaskForm): # type: ignore[misc] class DelForm(FlaskForm):
""" """
Delete form definition Delete form definition
""" """
delete = HiddenField('delete') delete = HiddenField("delete")
class AddForm(FlaskForm): # type: ignore[misc] class AddForm(FlaskForm):
""" """
Add form definition Add form definition
""" """
shortcode = StringField( shortcode = StringField(
'Shortcode', "Shortcode",
validators=[DataRequired(), Length(1, 2048)], validators=[DataRequired(), Length(1, 2048)],
render_kw={ render_kw={
'size': 64, "size": 64,
'maxlength': 2048, "maxlength": 2048,
}, },
) )
url = StringField( url = StringField(
'URL', "URL",
validators=[DataRequired(), Length(1, 2048)], validators=[DataRequired(), Length(1, 2048)],
render_kw={ render_kw={
'size': 64, "size": 64,
'maxlength': 2048, "maxlength": 2048,
'placeholder': 'https://www.example.com', "placeholder": "https://www.example.com",
}, },
) )
fixed_views = IntegerField( fixed_views = IntegerField(
'Fixed number of views', "Fixed number of views",
validators=[DataRequired()], validators=[DataRequired()],
render_kw={ render_kw={
'size': 3, "size": 3,
'value': -1, "value": -1,
}, },
) )
length = IntegerField( length = IntegerField(
'Shortcode length', "Shortcode length",
validators=[DataRequired()], validators=[DataRequired()],
render_kw={ render_kw={
'size': 3, "size": 3,
'value': 4, "value": 4,
}, },
) )
expiry = DateTimeLocalField( expiry = DateTimeLocalField(
'Expiry', "Expiry",
format='%Y-%m-%dT%H:%M', format="%Y-%m-%dT%H:%M",
render_kw={ render_kw={
'size': 8, "size": 8,
'maxlength': 10, "maxlength": 10,
}, },
) )
@ -92,18 +91,26 @@ def protect(func: Callable[..., Response]) -> Callable[..., Response]:
@wraps(func) @wraps(func)
def check_ip(*args: Any, **kwargs: Any) -> Response: def check_ip(*args: Any, **kwargs: Any) -> Response:
remote_addr = request.remote_addr
if "x-forwarded-for" in request.headers:
remote_addr = request.headers["x-forwarded-for"]
if "x-real-ip" in request.headers:
remote_addr = request.headers["x-real-ip"]
if ( if (
os.environ.get('FLASK_ENV', '') != 'development' os.environ.get("FLASK_ENV", "") != "development"
and request.headers['X-Forwarded-For'] not in cfg['allowed_ips'] and remote_addr not in cfg["allowed_ips"]
): ):
logging.warning('Protected URL access attempt from %s', request.remote_addr) logging.warning("Protected URL access attempt from %s", remote_addr)
return Response('Not found', 404) return Response("Not found", 404)
return func(*args, **kwargs) return func(*args, **kwargs)
return check_ip return check_ip
@slinky_webapp.route('/<path:path>', strict_slashes=False) @slinky_webapp.route("/<path:path>", strict_slashes=False)
def try_path_as_shortcode(path: str) -> Response: def try_path_as_shortcode(path: str) -> Response:
""" """
Try the initial path as a shortcode, redirect if found Try the initial path as a shortcode, redirect if found
@ -112,27 +119,27 @@ def try_path_as_shortcode(path: str) -> Response:
Response: redirect if found, otherwise 404 Response: redirect if found, otherwise 404
""" """
should_redirect = True should_redirect = True
slinky = Slinky(cfg['db']) slinky = Slinky(cfg["db"])
shortcode = slinky.get_by_shortcode(path) shortcode = slinky.get_by_shortcode(path)
if shortcode.url: if shortcode.url:
if shortcode.fixed_views == 0: if shortcode.fixed_views == 0:
logging.warning('Shortcode out of views') logging.warning("Shortcode out of views")
should_redirect = False should_redirect = False
elif shortcode.fixed_views > 0: elif shortcode.fixed_views > 0:
slinky.remove_view(shortcode.id) slinky.remove_view(shortcode.id)
if datetime.fromisoformat(shortcode.expiry) < datetime.now(): if datetime.fromisoformat(shortcode.expiry) < datetime.now():
logging.warning('Shortcode expired') logging.warning("Shortcode expired")
should_redirect = False should_redirect = False
if should_redirect: if should_redirect:
return Response( return Response(
'Redirecting...', status=302, headers={'location': shortcode.url} "Redirecting...", status=302, headers={"location": shortcode.url}
) )
return Response('Not found', 404) return Response("Not found", 404)
@slinky_webapp.route('/_/add', methods=['GET', 'POST']) @slinky_webapp.route("/_/add", methods=["GET", "POST"])
@protect @protect
def add() -> Response: def add() -> Response:
""" """
@ -141,27 +148,27 @@ def add() -> Response:
Returns: Returns:
Response: HTTP response Response: HTTP response
""" """
slinky = Slinky(cfg['db']) slinky = Slinky(cfg["db"])
for attempts in range(50): for attempts in range(50):
shortcode = random_string() shortcode = random_string()
if slinky.get_by_shortcode(shortcode).url: if slinky.get_by_shortcode(shortcode).url:
logging.warning( logging.warning(
'Shortcode already exists. Retrying (%s/50).', "Shortcode already exists. Retrying (%s/50).",
attempts, attempts,
) )
else: else:
break break
else: else:
return Response( return Response(
render_template('error.html', msg='Could not create a unique shortcode'), render_template("error.html", msg="Could not create a unique shortcode"),
500, 500,
) )
url = '' url = ""
final_url = '' final_url = ""
form = AddForm(meta={'csrf': False}) form = AddForm(meta={"csrf": False})
if form.is_submitted(): if form.is_submitted():
shortcode = form.shortcode.data.strip() shortcode = form.shortcode.data.strip()
@ -181,14 +188,14 @@ def add() -> Response:
) )
except ValueError as error: except ValueError as error:
logging.warning(error) logging.warning(error)
return Response(render_template('error.html', msg=error), 400) return Response(render_template("error.html", msg=error), 400)
if form.is_submitted(): if form.is_submitted():
final_url = f'{request.host_url}/{shortcode}' final_url = f"{request.host_url}/{shortcode}"
return Response( return Response(
render_template( render_template(
'add.html', "add.html",
form=form, form=form,
shortcode=shortcode, shortcode=shortcode,
final_url=final_url, final_url=final_url,
@ -197,7 +204,7 @@ def add() -> Response:
) )
@slinky_webapp.route('/_/list', methods=['GET', 'POST']) @slinky_webapp.route("/_/list", methods=["GET", "POST"])
@protect @protect
def lister() -> Response: def lister() -> Response:
""" """
@ -206,19 +213,19 @@ def lister() -> Response:
Returns: Returns:
Response: HTTP response Response: HTTP response
""" """
form = DelForm(meta={'csrf': False}) form = DelForm(meta={"csrf": False})
slinky = Slinky(cfg['db']) slinky = Slinky(cfg["db"])
if form.is_submitted(): if form.is_submitted():
slinky.delete_by_shortcode(form.delete.data.strip()) slinky.delete_by_shortcode(form.delete.data.strip())
return Response( return Response(
render_template('list.html', form=form, shortcodes=slinky.get_all()), render_template("list.html", form=form, shortcodes=slinky.get_all()),
200, 200,
) )
@slinky_webapp.route('/_/edit/<int:id>', methods=['GET', 'POST']) @slinky_webapp.route("/_/edit/<int:id>", methods=["GET", "POST"])
@protect @protect
def edit(id: int) -> Response: # pylint: disable=invalid-name,redefined-builtin def edit(id: int) -> Response: # pylint: disable=invalid-name,redefined-builtin
""" """
@ -227,15 +234,15 @@ def edit(id: int) -> Response: # pylint: disable=invalid-name,redefined-builtin
Returns: Returns:
Response: HTTP response Response: HTTP response
""" """
form = DelForm(meta={'csrf': False}) form = DelForm(meta={"csrf": False})
slinky = Slinky(cfg['db']) slinky = Slinky(cfg["db"])
logging.debug('Editing: %d', id) logging.debug("Editing: %d", id)
if form.is_submitted(): if form.is_submitted():
slinky.delete_by_shortcode(form.delete.data.strip()) slinky.delete_by_shortcode(form.delete.data.strip())
return Response( return Response(
render_template('edit.html', form=form, shortcodes=slinky.get_all()), render_template("edit.html", form=form, shortcodes=slinky.get_all()),
200, 200,
) )

View file

@ -11,9 +11,8 @@
<link rel="canonical" href="https://getbootstrap.com/docs/5.0/examples/navbar-static/"> <link rel="canonical" href="https://getbootstrap.com/docs/5.0/examples/navbar-static/">
<!-- Bootstrap CSS --> <!-- Bootstrap CSS -->
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.0-beta3/dist/css/bootstrap.min.css" rel="stylesheet" <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-eOJMYsd53ii+scO/bJGFsiCZc+5NDVN2yr8+0RDqr0Ql0h+rP48ckxlpbzKgwra6" crossorigin="anonymous"> integrity="sha384-QWTKZyjpPEjISv5WaRU9OFeRpok6YctnYmDr5pNlyT2bRjXh0JMhjY6hW+ALEwIH" crossorigin="anonymous">
<link rel="stylesheet" href="{{url_for('bootstrap.static', filename='datepicker.css')}}">
<meta name="theme-color" content="#7952b3"> <meta name="theme-color" content="#7952b3">
<style> <style>

View file

@ -1,6 +1,6 @@
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.0.0-beta3/dist/js/bootstrap.bundle.min.js" <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/js/bootstrap.bundle.min.js"
integrity="sha384-JEW9xMcG8R+pH31jmWH6WWP0WintQrMb4s7ZOdauHnUtxwoG2vI5DkLtS3qm9Ekf" crossorigin="anonymous"> integrity="sha384-YvpcrYf0tY3lHB60NNkmXc5s9fDVZLESaAA55NDzOxhy9GkcIdslK1eN7N6jIeHz"
</script> crossorigin="anonymous"></script>
<script> <script>
var tooltipTriggerList = [].slice.call(document.querySelectorAll('[data-bs-toggle="tooltip"]')); var tooltipTriggerList = [].slice.call(document.querySelectorAll('[data-bs-toggle="tooltip"]'));
var tooltipList = tooltipTriggerList.map(function (tooltipTriggerEl) { var tooltipList = tooltipTriggerList.map(function (tooltipTriggerEl) {

View file

@ -5,7 +5,7 @@ Test Slinky web interface
from unittest import TestCase, mock from unittest import TestCase, mock
from flask import Flask from flask import Flask
from flask_bootstrap import Bootstrap # type: ignore[import]
from slinky.web import slinky_webapp from slinky.web import slinky_webapp
@ -15,45 +15,43 @@ class TestWeb(TestCase):
""" """
def setUp(self) -> None: def setUp(self) -> None:
self.app = Flask(__name__, template_folder='../templates') self.app = Flask(__name__, template_folder="../templates")
self.app.register_blueprint(slinky_webapp) self.app.register_blueprint(slinky_webapp)
self.app_context = self.app.app_context() self.app_context = self.app.app_context()
self.app_context.push() self.app_context.push()
self.client = self.app.test_client() self.client = self.app.test_client()
Bootstrap(self.app) mock.patch.dict("slinky.web.cfg", {"db": "sqlite:///tests/test.db"}).start()
mock.patch.dict('slinky.web.cfg', {'db': 'sqlite:///tests/test.db'}).start()
def test_simple_redirect(self) -> None: def test_simple_redirect(self) -> None:
""" """
Ensure simple redirect works Ensure simple redirect works
""" """
response = self.client.get('/egie') response = self.client.get("/egie")
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual(response.location, 'https://example.com') self.assertEqual(response.location, "https://example.com")
def test_fixed_views(self) -> None: def test_fixed_views(self) -> None:
""" """
Ensure depleted fixed views returns a 404 Ensure depleted fixed views returns a 404
""" """
response = self.client.get('/egig') response = self.client.get("/egig")
self.assertEqual(response.status_code, 404) self.assertEqual(response.status_code, 404)
def test_expiry(self) -> None: def test_expiry(self) -> None:
""" """
Ensure expired redirect returns a 404 Ensure expired redirect returns a 404
""" """
response = self.client.get('/egif') response = self.client.get("/egif")
self.assertEqual(response.status_code, 404) self.assertEqual(response.status_code, 404)
def test_no_unique_shortcode(self) -> None: def test_no_unique_shortcode(self) -> None:
""" """
Ensure non-unique shortcode generation returns a 500 error Ensure non-unique shortcode generation returns a 500 error
""" """
with mock.patch('slinky.web.random_string', return_value='egie'): with mock.patch("slinky.web.random_string", return_value="egie"):
response = self.client.get( response = self.client.get(
'/_/add', headers={'x-forwarded-for': '127.0.0.1'} "/_/add", headers={"x-forwarded-for": "127.0.0.1"}
) )
self.assertEqual(response.status_code, 500) self.assertEqual(response.status_code, 500)
@ -61,9 +59,9 @@ class TestWeb(TestCase):
""" """
Test the condition where the random_string() returns an existing shortcode Test the condition where the random_string() returns an existing shortcode
""" """
with mock.patch('slinky.web.random_string', side_effect=['egie', 'egiz']): with mock.patch("slinky.web.random_string", side_effect=["egie", "egiz"]):
response = self.client.get( response = self.client.get(
'/_/add', "/_/add",
headers={'x-forwarded-for': '127.0.0.1'}, headers={"x-forwarded-for": "127.0.0.1"},
) )
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)