diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/db.py | 20 | ||||
-rw-r--r-- | src/main.py | 24 | ||||
-rw-r--r-- | src/server.py | 85 | ||||
-rw-r--r-- | src/templates.py | 97 | ||||
-rw-r--r-- | src/utils.py | 19 |
5 files changed, 245 insertions, 0 deletions
diff --git a/src/db.py b/src/db.py new file mode 100644 index 0000000..8aa20f8 --- /dev/null +++ b/src/db.py @@ -0,0 +1,20 @@ +import secrets + +def insert_file(conn, filename: str, expiration_days: int, content_length: int): + cur = conn.cursor() + file_id = secrets.token_urlsafe() + cur.execute( + 'INSERT INTO files(id, filename, created, expires, content_length) VALUES(?, ?, datetime(), datetime(datetime(), ?), ?)', + (file_id, filename, f'+{expiration_days} days', content_length) + ) + conn.commit() + return file_id + +def get_file(conn, file_id: str): + cur = conn.cursor() + res = cur.execute( + 'SELECT filename, expires, content_length FROM files WHERE id = ?', + (file_id,) + ) + return res.fetchone() + diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..56c8e9e --- /dev/null +++ b/src/main.py @@ -0,0 +1,24 @@ +import http.server +import logging +import os +import sys + +import server + +logger = logging.getLogger(__name__) +hostName = os.environ['HOST'] +serverPort = int(os.environ['PORT']) + +if __name__ == '__main__': + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + webServer = http.server.HTTPServer((hostName, serverPort), server.MyServer) + logger.info('Server started at http://%s:%s.' % (hostName, serverPort)) + + try: + webServer.serve_forever() + except KeyboardInterrupt: + pass + + webServer.server_close() + conn.close() + logger.info('Server stopped.') diff --git a/src/server.py b/src/server.py new file mode 100644 index 0000000..2cd6741 --- /dev/null +++ b/src/server.py @@ -0,0 +1,85 @@ +import http.server +import logging +import os +import sqlite3 +import tempfile + +import db +import templates +import utils + +logger = logging.getLogger(__name__) +conn = sqlite3.connect('db.sqlite3') +files_directory = 'download' +authorized_key = os.environ['KEY'] + +class MyServer(http.server.BaseHTTPRequestHandler): + def do_GET(self): + match self.path: + case '/': + self._serve_str(templates.index, 200, 'text/html') + case '/main.js': + self._serve_file('public/main.js', 'application/javascript') + case '/main.css': + self._serve_file('public/main.css', 'text/css') + case path: + prefix = f'/{files_directory}/' + if path.startswith(prefix): + file_id = path[len(prefix):] + res = db.get_file(conn, file_id) + if res is None: + self._serve_str(templates.not_found, 404, 'text/html') + else: + filename, _, content_length = res + path = os.path.join(files_directory, file_id) + headers = [ + ('Content-Disposition', f'attachment; filename={filename}'), + ('Content-Length', content_length) + ] + self._serve_file(path, 'application/octet-stream', headers) + else: + file_id = path[1:] + res = db.get_file(conn, file_id) + if res is None: + self._serve_str(templates.not_found, 404, 'text/html') + else: + filename, expires, _ = res + href = os.path.join(files_directory, file_id) + self._serve_str(templates.download(href, filename, expires), 200, 'text/html') + + def do_POST(self): + key = self.headers['X-Key'] + if not key == authorized_key: + logging.info('Unauthorized to upload file: wrong key') + self._serve_str('Unauthorized', 401) + + else: + logging.info('Uploading file') + content_length = int(self.headers['content-length']) + filename = utils.sanitize_filename(self.headers['X-FileName']) + expiration = self.headers['X-Expiration'] + + with tempfile.NamedTemporaryFile(delete = False) as tmp: + utils.transfer(self.rfile, tmp, content_length = content_length) + + logging.info('File uploaded') + file_id = db.insert_file(conn, filename, expiration, content_length) + os.makedirs(files_directory, exist_ok=True) + os.rename(tmp.name, os.path.join(files_directory, file_id)) + + self._serve_str(file_id, 200) + + def _serve_str(self, s, code, content_type='text/plain'): + self.send_response(code) + self.send_header('Content-type', content_type) + self.end_headers() + self.wfile.write(bytes(s, 'utf-8')) + + def _serve_file(self, filename, content_type, headers = []): + self.send_response(200) + self.send_header('Content-type', content_type) + for header_name, header_value in headers: + self.send_header(header_name, header_value) + self.end_headers() + with open(filename, 'rb') as f: + utils.transfer(f, self.wfile) diff --git a/src/templates.py b/src/templates.py new file mode 100644 index 0000000..1308fc0 --- /dev/null +++ b/src/templates.py @@ -0,0 +1,97 @@ +import html + +page: str = ''' + <!doctype html> + <html lang="fr"> + <meta charset="utf-8"> + <meta name="viewport" content="width=device-width"> + + <title>Files</title> + <link rel="stylesheet" href="/main.css"> + <script src="/main.js"></script> + + <a href="/"> + <h1>Files</h1> + </a> +''' + +index: str = f''' + {page} + + <form> + <label> + File + <input type="file" name="file" required> + </label> + + <label> + Expiration + <select name="expiration"> + <option value="1">1 day</option> + <option value="2">2 days</option> + <option value="3">3 days</option> + <option value="4">4 days</option> + <option value="5">5 days</option> + <option value="6">6 days</option> + <option value="7" selected>7 days</option> + <option value="8">8 days</option> + <option value="9">9 days</option> + <option value="10">10 days</option> + <option value="11">11 days</option> + <option value="12">12 days</option> + <option value="13">13 days</option> + <option value="14">14 days</option> + <option value="15">15 days</option> + <option value="16">16 days</option> + <option value="17">17 days</option> + <option value="18">18 days</option> + <option value="19">19 days</option> + <option value="20">20 days</option> + <option value="21">21 days</option> + <option value="22">22 days</option> + <option value="23">23 days</option> + <option value="24">24 days</option> + <option value="25">25 days</option> + <option value="26">26 days</option> + <option value="27">27 days</option> + <option value="28">28 days</option> + <option value="29">29 days</option> + <option value="30">30 days</option> + <option value="31">31 days</option> + </select> + </label> + + <label> + Key + <input type="password" name="key" required> + </label> + + <div class="g-Loading"> + <div class="g-Spinner"></div> + Uploading… + </div> + + <div class="g-Error"> + </div> + + <input type="submit" value="Upload"> + </form> +''' + +def download(href: str, filename: str, expires: str) -> str: + return f''' + {page} + + <div> + <a class="g-Link" href="{html.escape(href)}">{html.escape(filename)}</a> + <div> + Expires: {html.escape(expires)} + </div> + </div> + ''' + +not_found: str = f''' + {page} + + Sorry, the file you are looking for can not be found. It may have already expired. +''' diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..ccf92c0 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,19 @@ +import io + +def transfer(reader, writer, content_length = None, buffer_size = io.DEFAULT_BUFFER_SIZE): + if content_length is None: + while (data := reader.read(buffer_size)): + writer.write(data) + else: + remaining = content_length + while remaining > 0: + size = min(buffer_size, remaining) + writer.write(reader.read(size)) + remaining -= size + +def sanitize_filename(s: str) -> str: + return '.'.join([sanitize_filename_part(p) for p in s.split('.')]) + +def sanitize_filename_part(s: str) -> str: + alnum_or_space = ''.join([c if c.isalnum() else ' ' for c in s]) + return '-'.join(alnum_or_space.split()) |