From d2f33d8e854b7fda6bae389d983025e6c8609842 Mon Sep 17 00:00:00 2001 From: Joris Date: Mon, 20 May 2024 10:09:52 +0200 Subject: Factor download page and file --- src/server.py | 35 +++++++++++++++++------------------ src/templates.py | 6 ++++-- 2 files changed, 21 insertions(+), 20 deletions(-) (limited to 'src') diff --git a/src/server.py b/src/server.py index 2cd6741..5927052 100644 --- a/src/server.py +++ b/src/server.py @@ -10,7 +10,7 @@ import utils logger = logging.getLogger(__name__) conn = sqlite3.connect('db.sqlite3') -files_directory = 'download' +files_directory = 'files' authorized_key = os.environ['KEY'] class MyServer(http.server.BaseHTTPRequestHandler): @@ -23,28 +23,27 @@ class MyServer(http.server.BaseHTTPRequestHandler): case '/main.css': self._serve_file('public/main.css', 'text/css') case path: - prefix = f'/{files_directory}/' - if path.startswith(prefix): - file_id = path[len(prefix):] - res = db.get_file(conn, file_id) - if res is None: - self._serve_str(templates.not_found, 404, 'text/html') - else: - filename, _, content_length = res - path = os.path.join(files_directory, file_id) + if path.endswith('?download'): + download = True + path = path[:-len('?download')] + else: + download = False + + file_id = path[1:] + res = db.get_file(conn, file_id) + if res is None: + self._serve_str(templates.not_found, 404, 'text/html') + else: + filename, expires, content_length = res + disk_path = os.path.join(files_directory, file_id) + if download: headers = [ ('Content-Disposition', f'attachment; filename={filename}'), ('Content-Length', content_length) ] - self._serve_file(path, 'application/octet-stream', headers) - else: - file_id = path[1:] - res = db.get_file(conn, file_id) - if res is None: - self._serve_str(templates.not_found, 404, 'text/html') + self._serve_file(disk_path, 'application/octet-stream', headers) else: - filename, expires, _ = res - href = os.path.join(files_directory, file_id) + href = f'{file_id}?download' self._serve_str(templates.download(href, filename, expires), 200, 'text/html') def do_POST(self): diff --git a/src/templates.py b/src/templates.py index 1308fc0..c605e57 100644 --- a/src/templates.py +++ b/src/templates.py @@ -1,4 +1,5 @@ import html +import datetime page: str = ''' @@ -79,13 +80,14 @@ index: str = f''' ''' def download(href: str, filename: str, expires: str) -> str: + expires_in = datetime.datetime.strptime(expires, '%Y-%m-%d %H:%M:%S') - datetime.datetime.now() return f''' {page}