diff --git a/slinky/__init__.py b/slinky/__init__.py
index a4c6764..05f18eb 100644
--- a/slinky/__init__.py
+++ b/slinky/__init__.py
@@ -50,9 +50,10 @@ class Slinky:
self.db = db.ShortcodeDB(db_url) # pylint: disable=invalid-name
self.session = self.db.session()
- def add(
+ def add( # pylint: disable=too-many-arguments
self,
- url: str,
+ shortcode: str = '',
+ url: str = '',
length: int = 4,
fixed_views: int = -1,
expiry: datetime = datetime.max,
@@ -61,7 +62,11 @@ class Slinky:
Add a shortcode to the DB
Args:
+ shortcode (str): URL path to use for the shortcode. If not provided,
+ one will be generated.
url (str): URL to redirect to
+ length (int): length of the desired shortcode. Only used when a shortcode
+ is generated. Defaults to 4.
fixed_views (int, optional): number of views to serve before expiring.
Defaults to 0 (no limit).
expiry (int, optional): date of expiry. Defaults to 0 (no limit).
@@ -69,7 +74,8 @@ class Slinky:
Returns:
str: shortcode for the redirect
"""
- shortcode = random_string(length=length)
+ if not shortcode:
+ shortcode = random_string(length=length)
if self.get_by_shortcode(shortcode).url:
raise ValueError(f'Shortcode {shortcode} already exists')
diff --git a/slinky/templates/add.html b/slinky/templates/add.html
index 4851bae..638a67d 100644
--- a/slinky/templates/add.html
+++ b/slinky/templates/add.html
@@ -7,10 +7,11 @@
- {% if shortcode -%}
+ {% if final_url -%}
diff --git a/slinky/web.py b/slinky/web.py
index 7a296f5..14b0cd3 100644
--- a/slinky/web.py
+++ b/slinky/web.py
@@ -15,7 +15,7 @@ from wtforms import HiddenField # type: ignore[import]
from wtforms import DateTimeLocalField, IntegerField, StringField
from wtforms.validators import DataRequired, Length # type: ignore[import]
-from slinky import Slinky
+from slinky import Slinky, random_string
slinky_webapp = Blueprint('webapp', __name__, template_folder='templates')
@@ -36,6 +36,14 @@ class AddForm(FlaskForm): # type: ignore[misc]
Add form definition
"""
+ shortcode = StringField(
+ 'Shortcode',
+ validators=[DataRequired(), Length(1, 2048)],
+ render_kw={
+ 'size': 64,
+ 'maxlength': 2048,
+ },
+ )
url = StringField(
'URL',
validators=[DataRequired(), Length(1, 2048)],
@@ -88,7 +96,7 @@ def protect(func: Callable[..., Response]) -> Callable[..., Response]:
os.environ.get('FLASK_ENV', '') != 'development'
and request.headers['X-Forwarded-For'] not in cfg['allowed_ips']
):
- print(f'Protected URL access attempt from {request.remote_addr}')
+ logging.warning('Protected URL access attempt from %s', request.remote_addr)
return Response('Not found', 404)
return func(*args, **kwargs)
@@ -133,31 +141,60 @@ def add() -> Response:
Returns:
Response: HTTP response
"""
- shortcode = ''
+ slinky = Slinky(cfg['db'])
+
+ for attempts in range(50):
+ shortcode = random_string()
+ if slinky.get_by_shortcode(shortcode).url:
+ logging.warning(
+ 'Shortcode already exists. Retrying (%s/50).',
+ attempts,
+ )
+ else:
+ break
+ else:
+ return Response(
+ render_template('error.html', msg='Could not create a unique shortcode'),
+ 500,
+ )
+
url = ''
+ final_url = ''
form = AddForm(meta={'csrf': False})
if form.is_submitted():
+ shortcode = form.shortcode.data.strip()
url = form.url.data.strip()
length = form.length.data
fixed_views = form.fixed_views.data
expiry = form.expiry.data or datetime.max
if url:
- slinky = Slinky(cfg['db'])
- for attempts in range(50):
- try:
- shortcode = slinky.add(url, length, fixed_views, expiry)
- break
- except ValueError:
- logging.warning(
- 'Shortcode already exists. Retrying (%d/50).', attempts
- )
- else:
- return Response('Could not create a unique shortcode', 500)
+ try:
+ shortcode = slinky.add(
+ shortcode=shortcode,
+ url=url,
+ length=length,
+ fixed_views=fixed_views,
+ expiry=expiry,
+ )
+ except ValueError as error:
+ logging.warning(error)
+ return Response(render_template('error.html', msg=error), 400)
- return Response(render_template('add.html', form=form, shortcode=shortcode), 200)
+ if form.is_submitted():
+ final_url = f'{request.host_url}/{shortcode}'
+
+ return Response(
+ render_template(
+ 'add.html',
+ form=form,
+ shortcode=shortcode,
+ final_url=final_url,
+ ),
+ 200,
+ )
@slinky_webapp.route('/_/list', methods=['GET', 'POST'])
diff --git a/templates/error.html b/templates/error.html
new file mode 100644
index 0000000..3093ba3
--- /dev/null
+++ b/templates/error.html
@@ -0,0 +1,7 @@
+{% include '_head.html' -%}
+
+
+
+{% include '_tail.html' -%}
diff --git a/tests/test_slinky.py b/tests/test_slinky.py
index 76af67b..143fcbf 100644
--- a/tests/test_slinky.py
+++ b/tests/test_slinky.py
@@ -34,7 +34,7 @@ class TestSlinky(TestCase):
Ensure we can add a shortcode to the DB
"""
self.assertEqual(
- Slinky(self.test_db).add('https://www.example.com'),
+ Slinky(self.test_db).add(url='https://www.example.com'),
'abcd',
)
@@ -56,7 +56,20 @@ class TestSlinky(TestCase):
self.assertRaises(
ValueError,
Slinky(self.test_db).add,
- 'https://www.example.com',
+ url='https://www.example.com',
+ )
+
+ @mock.patch('sqlalchemy.orm.session.Session.add', return_value=None)
+ def test_supplied_shortcode(self, *_: Any) -> None:
+ """
+ Ensure a shortcode can be supplied
+ """
+ self.assertEqual(
+ '__TEST__',
+ Slinky(self.test_db).add(
+ shortcode='__TEST__',
+ url='https://www.example.com',
+ ),
)
def test_get_all(self) -> None:
diff --git a/tests/test_web.py b/tests/test_web.py
index 2f0c6ef..099ac86 100644
--- a/tests/test_web.py
+++ b/tests/test_web.py
@@ -1,56 +1,69 @@
"""
-Test Slinky
+Test Slinky web interface
"""
-from typing import Any
+
from unittest import TestCase, mock
-from slinky import web
+from flask import Flask
+from flask_bootstrap import Bootstrap # type: ignore[import]
+from slinky.web import slinky_webapp
-@mock.patch.dict('slinky.web.cfg', {'db': 'sqlite:///tests/test.db'})
class TestWeb(TestCase):
"""
Class to test Slinky code
"""
- def test_simple_redirect(self, *_: Any) -> None:
+ def setUp(self) -> None:
+ self.app = Flask(__name__, template_folder='../templates')
+ self.app.register_blueprint(slinky_webapp)
+ self.app_context = self.app.app_context()
+ self.app_context.push()
+ self.client = self.app.test_client()
+
+ Bootstrap(self.app)
+
+ mock.patch.dict('slinky.web.cfg', {'db': 'sqlite:///tests/test.db'}).start()
+
+ def test_simple_redirect(self) -> None:
"""
Ensure simple redirect works
"""
- response = web.try_path_as_shortcode('egie')
+ response = self.client.get('/egie')
self.assertEqual(response.status_code, 302)
self.assertEqual(response.location, 'https://example.com')
- def test_fixed_views(self, *_: Any) -> None:
+ def test_fixed_views(self) -> None:
"""
- Ensure simple redirect works
+ Ensure depleted fixed views returns a 404
"""
- response = web.try_path_as_shortcode('egig')
+ response = self.client.get('/egig')
self.assertEqual(response.status_code, 404)
- def test_expiry(self, *_: Any) -> None:
+ def test_expiry(self) -> None:
"""
- Ensure simple redirect works
+ Ensure expired redirect returns a 404
"""
- response = web.try_path_as_shortcode('egif')
+ response = self.client.get('/egif')
self.assertEqual(response.status_code, 404)
- @mock.patch(
- 'slinky.web.AddForm',
- return_value=mock.Mock(
- shortcode=mock.Mock(data=''),
- url=mock.Mock(data='https://example.com'),
- fixed_views=mock.Mock(data=0),
- expiry=mock.Mock(data='1970-01-01 00:00:00.000000'),
- ),
- )
- @mock.patch('slinky.random_string', return_value='egie')
- def test_no_unique_shortcode(self, *_: Any) -> None:
+ def test_no_unique_shortcode(self) -> None:
"""
Ensure non-unique shortcode generation returns a 500 error
"""
- request = mock.MagicMock()
- request.headers = {'X-Forwarded-For': '127.0.0.1'}
- with mock.patch("slinky.web.request", request):
- response = web.add()
+ with mock.patch('slinky.web.random_string', return_value='egie'):
+ response = self.client.get(
+ '/_/add', headers={'x-forwarded-for': '127.0.0.1'}
+ )
self.assertEqual(response.status_code, 500)
+
+ def test_conflicting_random_string(self) -> None:
+ """
+ Test the condition where the random_string() returns an existing shortcode
+ """
+ with mock.patch('slinky.web.random_string', side_effect=['egie', 'egiz']):
+ response = self.client.get(
+ '/_/add',
+ headers={'x-forwarded-for': '127.0.0.1'},
+ )
+ self.assertEqual(response.status_code, 200)