From 436ddf6f23242eb709b591cd5e9cbf1553f8d390 Mon Sep 17 00:00:00 2001 From: Joris Date: Mon, 20 May 2024 09:40:11 +0200 Subject: Allow to upload file and download from given link --- src/db.py | 20 ++++++++++++ src/main.py | 24 ++++++++++++++ src/server.py | 85 +++++++++++++++++++++++++++++++++++++++++++++++++ src/templates.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/utils.py | 19 +++++++++++ 5 files changed, 245 insertions(+) create mode 100644 src/db.py create mode 100644 src/main.py create mode 100644 src/server.py create mode 100644 src/templates.py create mode 100644 src/utils.py (limited to 'src') diff --git a/src/db.py b/src/db.py new file mode 100644 index 0000000..8aa20f8 --- /dev/null +++ b/src/db.py @@ -0,0 +1,20 @@ +import secrets + +def insert_file(conn, filename: str, expiration_days: int, content_length: int): + cur = conn.cursor() + file_id = secrets.token_urlsafe() + cur.execute( + 'INSERT INTO files(id, filename, created, expires, content_length) VALUES(?, ?, datetime(), datetime(datetime(), ?), ?)', + (file_id, filename, f'+{expiration_days} days', content_length) + ) + conn.commit() + return file_id + +def get_file(conn, file_id: str): + cur = conn.cursor() + res = cur.execute( + 'SELECT filename, expires, content_length FROM files WHERE id = ?', + (file_id,) + ) + return res.fetchone() + diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..56c8e9e --- /dev/null +++ b/src/main.py @@ -0,0 +1,24 @@ +import http.server +import logging +import os +import sys + +import server + +logger = logging.getLogger(__name__) +hostName = os.environ['HOST'] +serverPort = int(os.environ['PORT']) + +if __name__ == '__main__': + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + webServer = http.server.HTTPServer((hostName, serverPort), server.MyServer) + logger.info('Server started at http://%s:%s.' % (hostName, serverPort)) + + try: + webServer.serve_forever() + except KeyboardInterrupt: + pass + + webServer.server_close() + conn.close() + logger.info('Server stopped.') diff --git a/src/server.py b/src/server.py new file mode 100644 index 0000000..2cd6741 --- /dev/null +++ b/src/server.py @@ -0,0 +1,85 @@ +import http.server +import logging +import os +import sqlite3 +import tempfile + +import db +import templates +import utils + +logger = logging.getLogger(__name__) +conn = sqlite3.connect('db.sqlite3') +files_directory = 'download' +authorized_key = os.environ['KEY'] + +class MyServer(http.server.BaseHTTPRequestHandler): + def do_GET(self): + match self.path: + case '/': + self._serve_str(templates.index, 200, 'text/html') + case '/main.js': + self._serve_file('public/main.js', 'application/javascript') + case '/main.css': + self._serve_file('public/main.css', 'text/css') + case path: + prefix = f'/{files_directory}/' + if path.startswith(prefix): + file_id = path[len(prefix):] + res = db.get_file(conn, file_id) + if res is None: + self._serve_str(templates.not_found, 404, 'text/html') + else: + filename, _, content_length = res + path = os.path.join(files_directory, file_id) + headers = [ + ('Content-Disposition', f'attachment; filename={filename}'), + ('Content-Length', content_length) + ] + self._serve_file(path, 'application/octet-stream', headers) + else: + file_id = path[1:] + res = db.get_file(conn, file_id) + if res is None: + self._serve_str(templates.not_found, 404, 'text/html') + else: + filename, expires, _ = res + href = os.path.join(files_directory, file_id) + self._serve_str(templates.download(href, filename, expires), 200, 'text/html') + + def do_POST(self): + key = self.headers['X-Key'] + if not key == authorized_key: + logging.info('Unauthorized to upload file: wrong key') + self._serve_str('Unauthorized', 401) + + else: + logging.info('Uploading file') + content_length = int(self.headers['content-length']) + filename = utils.sanitize_filename(self.headers['X-FileName']) + expiration = self.headers['X-Expiration'] + + with tempfile.NamedTemporaryFile(delete = False) as tmp: + utils.transfer(self.rfile, tmp, content_length = content_length) + + logging.info('File uploaded') + file_id = db.insert_file(conn, filename, expiration, content_length) + os.makedirs(files_directory, exist_ok=True) + os.rename(tmp.name, os.path.join(files_directory, file_id)) + + self._serve_str(file_id, 200) + + def _serve_str(self, s, code, content_type='text/plain'): + self.send_response(code) + self.send_header('Content-type', content_type) + self.end_headers() + self.wfile.write(bytes(s, 'utf-8')) + + def _serve_file(self, filename, content_type, headers = []): + self.send_response(200) + self.send_header('Content-type', content_type) + for header_name, header_value in headers: + self.send_header(header_name, header_value) + self.end_headers() + with open(filename, 'rb') as f: + utils.transfer(f, self.wfile) diff --git a/src/templates.py b/src/templates.py new file mode 100644 index 0000000..1308fc0 --- /dev/null +++ b/src/templates.py @@ -0,0 +1,97 @@ +import html + +page: str = ''' + + + + + + Files + + + + +

Files

+
+''' + +index: str = f''' + {page} + +
+ + + + + + +
+
+ Uploading… +
+ +
+
+ + +
+''' + +def download(href: str, filename: str, expires: str) -> str: + return f''' + {page} + +
+ {html.escape(filename)} +
+ Expires: {html.escape(expires)} +
+
+ ''' + +not_found: str = f''' + {page} + + Sorry, the file you are looking for can not be found. It may have already expired. +''' diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..ccf92c0 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,19 @@ +import io + +def transfer(reader, writer, content_length = None, buffer_size = io.DEFAULT_BUFFER_SIZE): + if content_length is None: + while (data := reader.read(buffer_size)): + writer.write(data) + else: + remaining = content_length + while remaining > 0: + size = min(buffer_size, remaining) + writer.write(reader.read(size)) + remaining -= size + +def sanitize_filename(s: str) -> str: + return '.'.join([sanitize_filename_part(p) for p in s.split('.')]) + +def sanitize_filename_part(s: str) -> str: + alnum_or_space = ''.join([c if c.isalnum() else ' ' for c in s]) + return '-'.join(alnum_or_space.split()) -- cgit v1.2.3