Add ability to add new entry with custom shortcode
This commit is contained in:
parent
8a527e723e
commit
b939ad415d
|
@ -50,9 +50,10 @@ class Slinky:
|
|||
self.db = db.ShortcodeDB(db_url) # pylint: disable=invalid-name
|
||||
self.session = self.db.session()
|
||||
|
||||
def add(
|
||||
def add( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
url: str,
|
||||
shortcode: str = '',
|
||||
url: str = '',
|
||||
length: int = 4,
|
||||
fixed_views: int = -1,
|
||||
expiry: datetime = datetime.max,
|
||||
|
@ -61,7 +62,11 @@ class Slinky:
|
|||
Add a shortcode to the DB
|
||||
|
||||
Args:
|
||||
shortcode (str): URL path to use for the shortcode. If not provided,
|
||||
one will be generated.
|
||||
url (str): URL to redirect to
|
||||
length (int): length of the desired shortcode. Only used when a shortcode
|
||||
is generated. Defaults to 4.
|
||||
fixed_views (int, optional): number of views to serve before expiring.
|
||||
Defaults to 0 (no limit).
|
||||
expiry (int, optional): date of expiry. Defaults to 0 (no limit).
|
||||
|
@ -69,6 +74,7 @@ class Slinky:
|
|||
Returns:
|
||||
str: shortcode for the redirect
|
||||
"""
|
||||
if not shortcode:
|
||||
shortcode = random_string(length=length)
|
||||
|
||||
if self.get_by_shortcode(shortcode).url:
|
||||
|
|
|
@ -7,10 +7,11 @@
|
|||
</div>
|
||||
<br />
|
||||
<form action="/_/add" method="post">
|
||||
{{ form.shortcode.label }} {{ form.shortcode(value=shortcode) }}<br />
|
||||
{{ form.url.label }} {{ form.url }}<br />
|
||||
{{ form.length.label }} {{ form.length }}<br />
|
||||
{{ form.fixed_views.label }} {{ form.fixed_views }} (-1 for unlimited)<br />
|
||||
{{ form.expiry.label}} {{ form.expiry(class='datepicker') }} (leave as default for unlimited)<br />
|
||||
{{ form.expiry.label}} {{ form.expiry(class="datepicker") }} (leave as default for unlimited)<br />
|
||||
|
||||
<button id="submit" class="btn btn-primary" type="submit" onclick="waiting();" style="margin: 1em 0;">
|
||||
Create shortcode
|
||||
|
@ -19,7 +20,7 @@
|
|||
</form>
|
||||
|
||||
<div id="content">
|
||||
{% if shortcode -%}
|
||||
{% if final_url -%}
|
||||
<table class="table table-striped table-sm">
|
||||
<thead>
|
||||
<tr>
|
||||
|
@ -28,7 +29,7 @@
|
|||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td><a href="{{request.host_url}}{{ shortcode }}">{{request.host_url}}{{ shortcode }}</a></td>
|
||||
<td><a href="{{ final_url }}">{{ final_url }}</a></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
|
|
@ -15,7 +15,7 @@ from wtforms import HiddenField # type: ignore[import]
|
|||
from wtforms import DateTimeLocalField, IntegerField, StringField
|
||||
from wtforms.validators import DataRequired, Length # type: ignore[import]
|
||||
|
||||
from slinky import Slinky
|
||||
from slinky import Slinky, random_string
|
||||
|
||||
slinky_webapp = Blueprint('webapp', __name__, template_folder='templates')
|
||||
|
||||
|
@ -36,6 +36,14 @@ class AddForm(FlaskForm): # type: ignore[misc]
|
|||
Add form definition
|
||||
"""
|
||||
|
||||
shortcode = StringField(
|
||||
'Shortcode',
|
||||
validators=[DataRequired(), Length(1, 2048)],
|
||||
render_kw={
|
||||
'size': 64,
|
||||
'maxlength': 2048,
|
||||
},
|
||||
)
|
||||
url = StringField(
|
||||
'URL',
|
||||
validators=[DataRequired(), Length(1, 2048)],
|
||||
|
@ -88,7 +96,7 @@ def protect(func: Callable[..., Response]) -> Callable[..., Response]:
|
|||
os.environ.get('FLASK_ENV', '') != 'development'
|
||||
and request.headers['X-Forwarded-For'] not in cfg['allowed_ips']
|
||||
):
|
||||
print(f'Protected URL access attempt from {request.remote_addr}')
|
||||
logging.warning('Protected URL access attempt from %s', request.remote_addr)
|
||||
return Response('Not found', 404)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
@ -133,31 +141,60 @@ def add() -> Response:
|
|||
Returns:
|
||||
Response: HTTP response
|
||||
"""
|
||||
shortcode = ''
|
||||
slinky = Slinky(cfg['db'])
|
||||
|
||||
for attempts in range(50):
|
||||
shortcode = random_string()
|
||||
if slinky.get_by_shortcode(shortcode).url:
|
||||
logging.warning(
|
||||
'Shortcode already exists. Retrying (%s/50).',
|
||||
attempts,
|
||||
)
|
||||
else:
|
||||
break
|
||||
else:
|
||||
return Response(
|
||||
render_template('error.html', msg='Could not create a unique shortcode'),
|
||||
500,
|
||||
)
|
||||
|
||||
url = ''
|
||||
final_url = ''
|
||||
|
||||
form = AddForm(meta={'csrf': False})
|
||||
|
||||
if form.is_submitted():
|
||||
shortcode = form.shortcode.data.strip()
|
||||
url = form.url.data.strip()
|
||||
length = form.length.data
|
||||
fixed_views = form.fixed_views.data
|
||||
expiry = form.expiry.data or datetime.max
|
||||
|
||||
if url:
|
||||
slinky = Slinky(cfg['db'])
|
||||
for attempts in range(50):
|
||||
try:
|
||||
shortcode = slinky.add(url, length, fixed_views, expiry)
|
||||
break
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
'Shortcode already exists. Retrying (%d/50).', attempts
|
||||
shortcode = slinky.add(
|
||||
shortcode=shortcode,
|
||||
url=url,
|
||||
length=length,
|
||||
fixed_views=fixed_views,
|
||||
expiry=expiry,
|
||||
)
|
||||
else:
|
||||
return Response('Could not create a unique shortcode', 500)
|
||||
except ValueError as error:
|
||||
logging.warning(error)
|
||||
return Response(render_template('error.html', msg=error), 400)
|
||||
|
||||
return Response(render_template('add.html', form=form, shortcode=shortcode), 200)
|
||||
if form.is_submitted():
|
||||
final_url = f'{request.host_url}/{shortcode}'
|
||||
|
||||
return Response(
|
||||
render_template(
|
||||
'add.html',
|
||||
form=form,
|
||||
shortcode=shortcode,
|
||||
final_url=final_url,
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@slinky_webapp.route('/_/list', methods=['GET', 'POST'])
|
||||
|
|
7
templates/error.html
Normal file
7
templates/error.html
Normal file
|
@ -0,0 +1,7 @@
|
|||
{% include '_head.html' -%}
|
||||
<main class="container">
|
||||
<div class="container">
|
||||
<div class="alert alert-danger" role="alert">ERROR: {{ msg }}</div>
|
||||
</div>
|
||||
</main>
|
||||
{% include '_tail.html' -%}
|
|
@ -34,7 +34,7 @@ class TestSlinky(TestCase):
|
|||
Ensure we can add a shortcode to the DB
|
||||
"""
|
||||
self.assertEqual(
|
||||
Slinky(self.test_db).add('https://www.example.com'),
|
||||
Slinky(self.test_db).add(url='https://www.example.com'),
|
||||
'abcd',
|
||||
)
|
||||
|
||||
|
@ -56,7 +56,20 @@ class TestSlinky(TestCase):
|
|||
self.assertRaises(
|
||||
ValueError,
|
||||
Slinky(self.test_db).add,
|
||||
'https://www.example.com',
|
||||
url='https://www.example.com',
|
||||
)
|
||||
|
||||
@mock.patch('sqlalchemy.orm.session.Session.add', return_value=None)
|
||||
def test_supplied_shortcode(self, *_: Any) -> None:
|
||||
"""
|
||||
Ensure a shortcode can be supplied
|
||||
"""
|
||||
self.assertEqual(
|
||||
'__TEST__',
|
||||
Slinky(self.test_db).add(
|
||||
shortcode='__TEST__',
|
||||
url='https://www.example.com',
|
||||
),
|
||||
)
|
||||
|
||||
def test_get_all(self) -> None:
|
||||
|
|
|
@ -1,56 +1,69 @@
|
|||
"""
|
||||
Test Slinky
|
||||
Test Slinky web interface
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from unittest import TestCase, mock
|
||||
|
||||
from slinky import web
|
||||
from flask import Flask
|
||||
from flask_bootstrap import Bootstrap # type: ignore[import]
|
||||
from slinky.web import slinky_webapp
|
||||
|
||||
|
||||
@mock.patch.dict('slinky.web.cfg', {'db': 'sqlite:///tests/test.db'})
|
||||
class TestWeb(TestCase):
|
||||
"""
|
||||
Class to test Slinky code
|
||||
"""
|
||||
|
||||
def test_simple_redirect(self, *_: Any) -> None:
|
||||
def setUp(self) -> None:
|
||||
self.app = Flask(__name__, template_folder='../templates')
|
||||
self.app.register_blueprint(slinky_webapp)
|
||||
self.app_context = self.app.app_context()
|
||||
self.app_context.push()
|
||||
self.client = self.app.test_client()
|
||||
|
||||
Bootstrap(self.app)
|
||||
|
||||
mock.patch.dict('slinky.web.cfg', {'db': 'sqlite:///tests/test.db'}).start()
|
||||
|
||||
def test_simple_redirect(self) -> None:
|
||||
"""
|
||||
Ensure simple redirect works
|
||||
"""
|
||||
response = web.try_path_as_shortcode('egie')
|
||||
response = self.client.get('/egie')
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.location, 'https://example.com')
|
||||
|
||||
def test_fixed_views(self, *_: Any) -> None:
|
||||
def test_fixed_views(self) -> None:
|
||||
"""
|
||||
Ensure simple redirect works
|
||||
Ensure depleted fixed views returns a 404
|
||||
"""
|
||||
response = web.try_path_as_shortcode('egig')
|
||||
response = self.client.get('/egig')
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_expiry(self, *_: Any) -> None:
|
||||
def test_expiry(self) -> None:
|
||||
"""
|
||||
Ensure simple redirect works
|
||||
Ensure expired redirect returns a 404
|
||||
"""
|
||||
response = web.try_path_as_shortcode('egif')
|
||||
response = self.client.get('/egif')
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
@mock.patch(
|
||||
'slinky.web.AddForm',
|
||||
return_value=mock.Mock(
|
||||
shortcode=mock.Mock(data=''),
|
||||
url=mock.Mock(data='https://example.com'),
|
||||
fixed_views=mock.Mock(data=0),
|
||||
expiry=mock.Mock(data='1970-01-01 00:00:00.000000'),
|
||||
),
|
||||
)
|
||||
@mock.patch('slinky.random_string', return_value='egie')
|
||||
def test_no_unique_shortcode(self, *_: Any) -> None:
|
||||
def test_no_unique_shortcode(self) -> None:
|
||||
"""
|
||||
Ensure non-unique shortcode generation returns a 500 error
|
||||
"""
|
||||
request = mock.MagicMock()
|
||||
request.headers = {'X-Forwarded-For': '127.0.0.1'}
|
||||
with mock.patch("slinky.web.request", request):
|
||||
response = web.add()
|
||||
with mock.patch('slinky.web.random_string', return_value='egie'):
|
||||
response = self.client.get(
|
||||
'/_/add', headers={'x-forwarded-for': '127.0.0.1'}
|
||||
)
|
||||
self.assertEqual(response.status_code, 500)
|
||||
|
||||
def test_conflicting_random_string(self) -> None:
|
||||
"""
|
||||
Test the condition where the random_string() returns an existing shortcode
|
||||
"""
|
||||
with mock.patch('slinky.web.random_string', side_effect=['egie', 'egiz']):
|
||||
response = self.client.get(
|
||||
'/_/add',
|
||||
headers={'x-forwarded-for': '127.0.0.1'},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
|
Loading…
Reference in a new issue