aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/db.py20
-rw-r--r--src/main.py24
-rw-r--r--src/server.py85
-rw-r--r--src/templates.py97
-rw-r--r--src/utils.py19
5 files changed, 245 insertions, 0 deletions
diff --git a/src/db.py b/src/db.py
new file mode 100644
index 0000000..8aa20f8
--- /dev/null
+++ b/src/db.py
@@ -0,0 +1,20 @@
+import secrets
+
+def insert_file(conn, filename: str, expiration_days: int, content_length: int):
+ cur = conn.cursor()
+ file_id = secrets.token_urlsafe()
+ cur.execute(
+ 'INSERT INTO files(id, filename, created, expires, content_length) VALUES(?, ?, datetime(), datetime(datetime(), ?), ?)',
+ (file_id, filename, f'+{expiration_days} days', content_length)
+ )
+ conn.commit()
+ return file_id
+
+def get_file(conn, file_id: str):
+ cur = conn.cursor()
+ res = cur.execute(
+ 'SELECT filename, expires, content_length FROM files WHERE id = ?',
+ (file_id,)
+ )
+ return res.fetchone()
+
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..56c8e9e
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,24 @@
+import http.server
+import logging
+import os
+import sys
+
+import server
+
+logger = logging.getLogger(__name__)
+hostName = os.environ['HOST']
+serverPort = int(os.environ['PORT'])
+
+if __name__ == '__main__':
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+ webServer = http.server.HTTPServer((hostName, serverPort), server.MyServer)
+ logger.info('Server started at http://%s:%s.' % (hostName, serverPort))
+
+ try:
+ webServer.serve_forever()
+ except KeyboardInterrupt:
+ pass
+
+ webServer.server_close()
+ conn.close()
+ logger.info('Server stopped.')
diff --git a/src/server.py b/src/server.py
new file mode 100644
index 0000000..2cd6741
--- /dev/null
+++ b/src/server.py
@@ -0,0 +1,85 @@
+import http.server
+import logging
+import os
+import sqlite3
+import tempfile
+
+import db
+import templates
+import utils
+
+logger = logging.getLogger(__name__)
+conn = sqlite3.connect('db.sqlite3')
+files_directory = 'download'
+authorized_key = os.environ['KEY']
+
+class MyServer(http.server.BaseHTTPRequestHandler):
+ def do_GET(self):
+ match self.path:
+ case '/':
+ self._serve_str(templates.index, 200, 'text/html')
+ case '/main.js':
+ self._serve_file('public/main.js', 'application/javascript')
+ case '/main.css':
+ self._serve_file('public/main.css', 'text/css')
+ case path:
+ prefix = f'/{files_directory}/'
+ if path.startswith(prefix):
+ file_id = path[len(prefix):]
+ res = db.get_file(conn, file_id)
+ if res is None:
+ self._serve_str(templates.not_found, 404, 'text/html')
+ else:
+ filename, _, content_length = res
+ path = os.path.join(files_directory, file_id)
+ headers = [
+ ('Content-Disposition', f'attachment; filename={filename}'),
+ ('Content-Length', content_length)
+ ]
+ self._serve_file(path, 'application/octet-stream', headers)
+ else:
+ file_id = path[1:]
+ res = db.get_file(conn, file_id)
+ if res is None:
+ self._serve_str(templates.not_found, 404, 'text/html')
+ else:
+ filename, expires, _ = res
+ href = os.path.join(files_directory, file_id)
+ self._serve_str(templates.download(href, filename, expires), 200, 'text/html')
+
+ def do_POST(self):
+ key = self.headers['X-Key']
+ if not key == authorized_key:
+ logging.info('Unauthorized to upload file: wrong key')
+ self._serve_str('Unauthorized', 401)
+
+ else:
+ logging.info('Uploading file')
+ content_length = int(self.headers['content-length'])
+ filename = utils.sanitize_filename(self.headers['X-FileName'])
+ expiration = self.headers['X-Expiration']
+
+ with tempfile.NamedTemporaryFile(delete = False) as tmp:
+ utils.transfer(self.rfile, tmp, content_length = content_length)
+
+ logging.info('File uploaded')
+ file_id = db.insert_file(conn, filename, expiration, content_length)
+ os.makedirs(files_directory, exist_ok=True)
+ os.rename(tmp.name, os.path.join(files_directory, file_id))
+
+ self._serve_str(file_id, 200)
+
+ def _serve_str(self, s, code, content_type='text/plain'):
+ self.send_response(code)
+ self.send_header('Content-type', content_type)
+ self.end_headers()
+ self.wfile.write(bytes(s, 'utf-8'))
+
+ def _serve_file(self, filename, content_type, headers = []):
+ self.send_response(200)
+ self.send_header('Content-type', content_type)
+ for header_name, header_value in headers:
+ self.send_header(header_name, header_value)
+ self.end_headers()
+ with open(filename, 'rb') as f:
+ utils.transfer(f, self.wfile)
diff --git a/src/templates.py b/src/templates.py
new file mode 100644
index 0000000..1308fc0
--- /dev/null
+++ b/src/templates.py
@@ -0,0 +1,97 @@
+import html
+
+page: str = '''
+ <!doctype html>
+ <html lang="fr">
+ <meta charset="utf-8">
+ <meta name="viewport" content="width=device-width">
+
+ <title>Files</title>
+ <link rel="stylesheet" href="/main.css">
+ <script src="/main.js"></script>
+
+ <a href="/">
+ <h1>Files</h1>
+ </a>
+'''
+
+index: str = f'''
+ {page}
+
+ <form>
+ <label>
+ File
+ <input type="file" name="file" required>
+ </label>
+
+ <label>
+ Expiration
+ <select name="expiration">
+ <option value="1">1 day</option>
+ <option value="2">2 days</option>
+ <option value="3">3 days</option>
+ <option value="4">4 days</option>
+ <option value="5">5 days</option>
+ <option value="6">6 days</option>
+ <option value="7" selected>7 days</option>
+ <option value="8">8 days</option>
+ <option value="9">9 days</option>
+ <option value="10">10 days</option>
+ <option value="11">11 days</option>
+ <option value="12">12 days</option>
+ <option value="13">13 days</option>
+ <option value="14">14 days</option>
+ <option value="15">15 days</option>
+ <option value="16">16 days</option>
+ <option value="17">17 days</option>
+ <option value="18">18 days</option>
+ <option value="19">19 days</option>
+ <option value="20">20 days</option>
+ <option value="21">21 days</option>
+ <option value="22">22 days</option>
+ <option value="23">23 days</option>
+ <option value="24">24 days</option>
+ <option value="25">25 days</option>
+ <option value="26">26 days</option>
+ <option value="27">27 days</option>
+ <option value="28">28 days</option>
+ <option value="29">29 days</option>
+ <option value="30">30 days</option>
+ <option value="31">31 days</option>
+ </select>
+ </label>
+
+ <label>
+ Key
+ <input type="password" name="key" required>
+ </label>
+
+ <div class="g-Loading">
+ <div class="g-Spinner"></div>
+ Uploading…
+ </div>
+
+ <div class="g-Error">
+ </div>
+
+ <input type="submit" value="Upload">
+ </form>
+'''
+
+def download(href: str, filename: str, expires: str) -> str:
+ return f'''
+ {page}
+
+ <div>
+ <a class="g-Link" href="{html.escape(href)}">{html.escape(filename)}</a>
+ <div>
+ Expires: {html.escape(expires)}
+ </div>
+ </div>
+ '''
+
+not_found: str = f'''
+ {page}
+
+ Sorry, the file you are looking for can not be found. It may have already expired.
+'''
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..ccf92c0
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,19 @@
+import io
+
+def transfer(reader, writer, content_length = None, buffer_size = io.DEFAULT_BUFFER_SIZE):
+ if content_length is None:
+ while (data := reader.read(buffer_size)):
+ writer.write(data)
+ else:
+ remaining = content_length
+ while remaining > 0:
+ size = min(buffer_size, remaining)
+ writer.write(reader.read(size))
+ remaining -= size
+
+def sanitize_filename(s: str) -> str:
+ return '.'.join([sanitize_filename_part(p) for p in s.split('.')])
+
+def sanitize_filename_part(s: str) -> str:
+ alnum_or_space = ''.join([c if c.isalnum() else ' ' for c in s])
+ return '-'.join(alnum_or_space.split())