diff --git a/.gitignore b/.gitignore index 2e5f17d..c9fd033 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .pyenv/ .vscode/ __pycache__/ +slinky.db diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..710c173 --- /dev/null +++ b/config.yaml @@ -0,0 +1,2 @@ +--- +db: sqlite:///slinky.db diff --git a/requirements.txt b/requirements.txt index 84548f4..5a5c7fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ flask flask_bootstrap flask_wtf +pyyaml waitress Flask-SQLAlchemy diff --git a/slinky/__init__.py b/slinky/__init__.py index 0085bd3..cd019c8 100644 --- a/slinky/__init__.py +++ b/slinky/__init__.py @@ -4,12 +4,28 @@ Main code import random import string +from dataclasses import dataclass from datetime import datetime from typing import Optional +import sqlalchemy + from slinky import db +@dataclass +class Shortcode: + """ + Simple dataclass to allow for typing of Shortcodes + """ + + id: int # pylint: disable=invalid-name + shortcode: str + url: str + fixed_views: int + expiry: str + + def random_string(length: int = 4) -> str: """ Create a random, alphanumeric string of the length specified @@ -25,34 +41,80 @@ def random_string(length: int = 4) -> str: return ''.join(random.SystemRandom().choice(allowed_chars) for _ in range(length)) -def add_shortcode( - url: str, - length: int = 4, - fixed_views: int = 0, - expiry: datetime = datetime.max, -) -> str: +class Slinky: """ - Add a shortcode to the DB - - Args: - url (str): URL to redirect to - fixed_views (int, optional): number of views to serve before expiring. - Defaults to 0 (no limit). - expiry (int, optional): date of expiry. Defaults to 0 (no limit). - - Returns: - str: shortcode for the redirect + Class for Slinky """ - session = db.session() - shortcode = random_string(length=length) - dbentry = db.ShortURL( - shortcode=shortcode, - url=url, - fixed_views=fixed_views, - expiry=expiry, - ) - session.add(dbentry) - session.commit() - session.close() - return shortcode + def __init__(self, url: str) -> None: + self.db = db.ShortcodeDB(url) # pylint: disable=invalid-name + self.session = self.db.session() + + def add( + self, + url: str, + length: int = 4, + fixed_views: int = -1, + expiry: datetime = datetime.max, + ) -> str: + """ + Add a shortcode to the DB + + Args: + url (str): URL to redirect to + fixed_views (int, optional): number of views to serve before expiring. + Defaults to 0 (no limit). + expiry (int, optional): date of expiry. Defaults to 0 (no limit). + + Returns: + str: shortcode for the redirect + """ + shortcode = random_string(length=length) + + if self.get(shortcode).url: + raise ValueError(f'Shortcode {shortcode} already exists') + + dbentry = db.ShortURL( + shortcode=shortcode, + url=url, + fixed_views=fixed_views, + expiry=expiry, + ) + self.session.add(dbentry) + self.session.commit() + + return shortcode + + def get(self, shortcode: str) -> Shortcode: + """ + Return a Shortcode object for a given shortcode + + Args: + shortcode (str): the shortcode to look up + + Returns: + Shortcode: full Shortcode object for the given shortcode + """ + entry = self.session.query(db.ShortURL).filter_by(shortcode=shortcode).first() + + if entry: + return Shortcode( + entry.id, + entry.shortcode, + entry.url, + entry.fixed_views, + entry.expiry, + ) + return Shortcode(0, '', '', 0, '1970-01-01 00:00:00.000000') + + def remove_view(self, sc_id: int) -> None: + """ + Reduce the fixed views count by one + + Args: + id (int): ID of the DB entry to reduce the fixed_views count + """ + self.session.query(db.ShortURL).filter_by(id=sc_id).update( + {db.ShortURL.fixed_views: db.ShortURL.fixed_views - 1} + ) + self.session.commit() diff --git a/slinky/db.py b/slinky/db.py index 3a4c1de..cf68915 100644 --- a/slinky/db.py +++ b/slinky/db.py @@ -1,6 +1,8 @@ """ DB component """ +from dataclasses import dataclass + from sqlalchemy import Column, Integer, String, create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, sessionmaker @@ -8,7 +10,7 @@ from sqlalchemy.orm import Session, sessionmaker Base = declarative_base() -class ShortURL(Base): # pylint: disable=too-few-public-methods +class ShortURL(Base): # type: ignore[misc, valid-type] # pylint: disable=too-few-public-methods """ Class to describe the DB schema for ShortURLs """ @@ -22,14 +24,21 @@ class ShortURL(Base): # pylint: disable=too-few-public-methods expiry = Column(Integer, unique=False, nullable=False) -def session() -> Session: +@dataclass +class ShortcodeDB: """ - Create a DB session + Class to represent the database + """ + url: str - Returns: - Session: the DB session object - """ - engine = create_engine('sqlite:////tmp/test.db') - Base.metadata.create_all(engine) - new_session = sessionmaker(bind=engine) - return new_session() + def session(self) -> Session: + """ + Create a DB session + + Returns: + Session: the DB session object + """ + engine = create_engine(self.url) + Base.metadata.create_all(engine) + new_session = sessionmaker(bind=engine) + return new_session() diff --git a/slinky/templates/add.html b/slinky/templates/add.html index 7f30f7a..9a943da 100644 --- a/slinky/templates/add.html +++ b/slinky/templates/add.html @@ -6,7 +6,7 @@

Add a shortcode


-
+ {{ form.url.label }} {{ form.url }}
{{ form.length.label }} {{ form.length }}
{{ form.fixed_views.label }} {{ form.fixed_views }} (0 = unlimited)
diff --git a/slinky/web.py b/slinky/web.py index 2ac9886..b87d61b 100644 --- a/slinky/web.py +++ b/slinky/web.py @@ -2,17 +2,22 @@ Web component """ +import logging from datetime import datetime -from flask import Blueprint, render_template +import yaml +from flask import Blueprint, Response, redirect, render_template from flask_wtf import FlaskForm from wtforms import DateTimeLocalField, IntegerField, StringField from wtforms.validators import DataRequired, Length -from slinky import add_shortcode +from slinky import Slinky slinky_webapp = Blueprint('webapp', __name__, template_folder='templates') +with open('config.yaml', encoding='utf-8-sig') as conffile: + cfg = yaml.safe_load(conffile) + class ShortURLForm(FlaskForm): # type: ignore[misc] """ @@ -25,7 +30,7 @@ class ShortURLForm(FlaskForm): # type: ignore[misc] render_kw={ 'size': 64, 'maxlength': 2048, - 'placeholder': 'e.g. www.example.com', + 'placeholder': 'https://www.example.com', }, ) fixed_views = IntegerField( @@ -33,7 +38,7 @@ class ShortURLForm(FlaskForm): # type: ignore[misc] validators=[DataRequired()], render_kw={ 'size': 3, - 'value': 0, + 'value': -1, }, ) length = IntegerField( @@ -54,7 +59,34 @@ class ShortURLForm(FlaskForm): # type: ignore[misc] ) -@slinky_webapp.route('/add', methods=['GET', 'POST']) +@slinky_webapp.route('/') +def try_path_as_shortcode(path: str) -> Response: + """ + Try the initial path as a shortcode, redirect if found + + Returns: + Optional[Response]: redirect if found, otherwise continue on + """ + should_redirect = True + slinky = Slinky(cfg['db']) + shortcode = slinky.get(path) + if shortcode.url: + if shortcode.fixed_views == 0: + logging.warning('Shortcode out of views') + should_redirect = False + elif shortcode.fixed_views > 0: + slinky.remove_view(shortcode.id) + if datetime.fromisoformat(shortcode.expiry) < datetime.now(): + logging.warning('Shortcode expired') + should_redirect = False + + if should_redirect: + return redirect(shortcode.url, 302) + + return Response('Not found', 404) + + +@slinky_webapp.route('/_/add', methods=['GET', 'POST']) def add() -> str: """ Create and add a new shorturl @@ -74,6 +106,12 @@ def add() -> str: expiry = form.expiry.data or datetime.max if url: - shortcode = add_shortcode(url, length, fixed_views, expiry) + slinky = Slinky(cfg['db']) + while True: + try: + shortcode = slinky.add(url, length, fixed_views, expiry) + break + except ValueError: + logging.warning('Shortcode already exists. Retrying.') return render_template('add.html', form=form, shortcode=shortcode) diff --git a/templates/_head.html b/templates/_head.html index 2da7b90..e15f33f 100644 --- a/templates/_head.html +++ b/templates/_head.html @@ -87,7 +87,7 @@ diff --git a/tests/test.db b/tests/test.db new file mode 100644 index 0000000..14a658c Binary files /dev/null and b/tests/test.db differ diff --git a/tests/test_slinky.py b/tests/test_slinky.py index 2383468..1e40b24 100644 --- a/tests/test_slinky.py +++ b/tests/test_slinky.py @@ -1,16 +1,58 @@ -from unittest import TestCase +""" +Test Slinky +""" +import random +from typing import Any +from unittest import TestCase, mock -import slinky +from slinky import Slinky, random_string class TestSlinky(TestCase): + """ + Class to test Slinky code + """ + + test_db = 'sqlite:///tests/test.db' + def test_random_string(self) -> None: """ Ensure the random string generates correctly """ - self.assertEqual(4, len(slinky.random_string())) - self.assertEqual(8, len(slinky.random_string(8))) - self.assertEqual(16, len(slinky.random_string(16))) - self.assertEqual(64, len(slinky.random_string(64))) - self.assertTrue(slinky.random_string(128).isalnum(), True) + rnd_len = random.randint(8, 128) + + self.assertEqual(4, len(random_string())) + self.assertEqual(rnd_len, len(random_string(rnd_len))) + + self.assertTrue(random_string(128).isalnum(), True) + + @mock.patch('sqlalchemy.orm.session.Session.add', return_value=None) + @mock.patch('slinky.random_string', return_value='abcd') + def test_add(self, *_: Any) -> None: + """ + Ensure we can add a shortcode to the DB + """ + self.assertEqual( + Slinky(self.test_db).add('https://www.example.com'), + 'abcd', + ) + + def test_get(self) -> None: + """ + Ensure we can fetch a URL for a known shortcode + """ + + self.assertEqual('https://example.com', Slinky(self.test_db).get('egie').url) + + @mock.patch('sqlalchemy.orm.session.Session.add', return_value=None) + @mock.patch('slinky.random_string', return_value='egie') + def test_duplicate_shortcode(self, *_: Any) -> None: + """ + Ensure duplicate shortcodes raise a ValueError exception + """ + self.assertRaises( + ValueError, + Slinky(self.test_db).add, + 'https://www.example.com', + ) diff --git a/tests/test_web.py b/tests/test_web.py new file mode 100644 index 0000000..4484d82 --- /dev/null +++ b/tests/test_web.py @@ -0,0 +1,35 @@ +""" +Test Slinky +""" +from typing import Any +from unittest import TestCase, mock + +from slinky import web + +@mock.patch.dict('slinky.web.cfg', {'db': 'sqlite:///tests/test.db'}) +class TestWeb(TestCase): + """ + Class to test Slinky code + """ + + def test_simple_redirect(self, *_: Any) -> None: + """ + Ensure simple redirect works + """ + response = web.try_path_as_shortcode('egie') + self.assertEqual(response.status_code, 302) + self.assertEqual(response.location, 'https://example.com') + + def test_fixed_views(self, *_: Any) -> None: + """ + Ensure simple redirect works + """ + response = web.try_path_as_shortcode('egig') + self.assertEqual(response.status_code, 404) + + def test_expiry(self, *_: Any) -> None: + """ + Ensure simple redirect works + """ + response = web.try_path_as_shortcode('egif') + self.assertEqual(response.status_code, 404)