aboutsummaryrefslogtreecommitdiff
path: root/src/server.py
diff options
context:
space:
mode:
authorJoris2024-05-20 10:09:52 +0200
committerJoris2024-05-20 10:09:52 +0200
commitd2f33d8e854b7fda6bae389d983025e6c8609842 (patch)
treef76997cf83eed847c6803535bed03331c84b66b4 /src/server.py
parent0167ad139146892c444fcfb2b4fe8d91a7871293 (diff)
Factor download page and file
Diffstat (limited to 'src/server.py')
-rw-r--r--src/server.py35
1 files changed, 17 insertions, 18 deletions
diff --git a/src/server.py b/src/server.py
index 2cd6741..5927052 100644
--- a/src/server.py
+++ b/src/server.py
@@ -10,7 +10,7 @@ import utils
logger = logging.getLogger(__name__)
conn = sqlite3.connect('db.sqlite3')
-files_directory = 'download'
+files_directory = 'files'
authorized_key = os.environ['KEY']
class MyServer(http.server.BaseHTTPRequestHandler):
@@ -23,28 +23,27 @@ class MyServer(http.server.BaseHTTPRequestHandler):
case '/main.css':
self._serve_file('public/main.css', 'text/css')
case path:
- prefix = f'/{files_directory}/'
- if path.startswith(prefix):
- file_id = path[len(prefix):]
- res = db.get_file(conn, file_id)
- if res is None:
- self._serve_str(templates.not_found, 404, 'text/html')
- else:
- filename, _, content_length = res
- path = os.path.join(files_directory, file_id)
+ if path.endswith('?download'):
+ download = True
+ path = path[:-len('?download')]
+ else:
+ download = False
+
+ file_id = path[1:]
+ res = db.get_file(conn, file_id)
+ if res is None:
+ self._serve_str(templates.not_found, 404, 'text/html')
+ else:
+ filename, expires, content_length = res
+ disk_path = os.path.join(files_directory, file_id)
+ if download:
headers = [
('Content-Disposition', f'attachment; filename={filename}'),
('Content-Length', content_length)
]
- self._serve_file(path, 'application/octet-stream', headers)
- else:
- file_id = path[1:]
- res = db.get_file(conn, file_id)
- if res is None:
- self._serve_str(templates.not_found, 404, 'text/html')
+ self._serve_file(disk_path, 'application/octet-stream', headers)
else:
- filename, expires, _ = res
- href = os.path.join(files_directory, file_id)
+ href = f'{file_id}?download'
self._serve_str(templates.download(href, filename, expires), 200, 'text/html')
def do_POST(self):