aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoris2024-05-20 20:09:01 +0200
committerJoris2024-06-02 14:20:50 +0200
commite8da9790dc6d55cd2e8883322cdf9a7bf5b4f5b7 (patch)
treec960e1036e3d93c1f91b58dfe9e3c8e6038ed363
parentd2f33d8e854b7fda6bae389d983025e6c8609842 (diff)
Migrate to sanic
-rw-r--r--flake.nix3
-rw-r--r--src/controller.py58
-rw-r--r--src/main-old.py33
-rw-r--r--src/main.py40
-rw-r--r--src/templates.py14
-rw-r--r--src/utils.py11
-rw-r--r--static/main.css (renamed from public/main.css)0
-rw-r--r--static/main.js (renamed from public/main.js)0
8 files changed, 126 insertions, 33 deletions
diff --git a/flake.nix b/flake.nix
index 92951f0..170fe8d 100644
--- a/flake.nix
+++ b/flake.nix
@@ -13,9 +13,12 @@
(pkgs.python3.withPackages (pythonPackages: with pythonPackages; [
sqlite
watchexec
+ sanic
+ setuptools
]))
];
shellHook = ''
+ export DEBUG="TRUE"
export HOST="127.0.0.1"
export PORT="8080"
export KEY="1234"
diff --git a/src/controller.py b/src/controller.py
new file mode 100644
index 0000000..351d0bc
--- /dev/null
+++ b/src/controller.py
@@ -0,0 +1,58 @@
+import io
+import logging
+import os
+import sanic
+import sqlite3
+import tempfile
+
+import db
+import templates
+import utils
+
+conn = sqlite3.connect('db.sqlite3')
+files_directory = 'files'
+authorized_key = os.environ['KEY']
+
+def index():
+ return sanic.html(templates.index)
+
+async def upload(request):
+ key = request.headers.get('X-Key')
+ if not key == authorized_key:
+ sanic.log.logging.info('Unauthorized to upload file: wrong key')
+ return sanic.text('Unauthorized', status = 401)
+ else:
+ sanic.log.logging.info('Uploading file')
+ content_length = int(request.headers.get('content-length'))
+ filename = utils.sanitize_filename(request.headers.get('X-FileName'))
+ expiration = request.headers.get('X-Expiration')
+
+ with tempfile.NamedTemporaryFile(delete = False) as tmp:
+ while data := await request.stream.read():
+ tmp.write(data)
+
+ sanic.log.logging.info('File uploaded')
+ file_id = db.insert_file(conn, filename, expiration, content_length)
+ os.makedirs(files_directory, exist_ok=True)
+ os.rename(tmp.name, os.path.join(files_directory, file_id))
+
+ return sanic.text(file_id)
+
+async def file(file_id: str, download: bool):
+ res = db.get_file(conn, file_id)
+ if res is None:
+ self._serve_str(templates.not_found, 404, 'text/html')
+ else:
+ filename, expires, content_length = res
+ disk_path = os.path.join(files_directory, file_id)
+ if download:
+ return await sanic.response.file_stream(
+ disk_path,
+ chunk_size = io.DEFAULT_BUFFER_SIZE,
+ headers = {
+ 'Content-Disposition': f'attachment; filename={filename}',
+ 'Content-Length': content_length
+ }
+ )
+ else:
+ return sanic.html(templates.file_page(file_id, filename, expires))
diff --git a/src/main-old.py b/src/main-old.py
new file mode 100644
index 0000000..42d7c8c
--- /dev/null
+++ b/src/main-old.py
@@ -0,0 +1,33 @@
+# import http.server
+# import logging
+# import os
+# import sys
+
+# import server
+
+# logger = logging.getLogger(__name__)
+# hostName = os.environ['HOST']
+# serverPort = int(os.environ['PORT'])
+
+# if __name__ == '__main__':
+# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+# webServer = http.server.HTTPServer((hostName, serverPort), server.MyServer)
+# logger.info('Server started at http://%s:%s.' % (hostName, serverPort))
+
+# try:
+# webServer.serve_forever()
+# except KeyboardInterrupt:
+# pass
+
+# webServer.server_close()
+# conn.close()
+# logger.info('Server stopped.')
+
+from sanic import Sanic
+from sanic.response import text
+
+app = Sanic("MyHelloWorldApp")
+
+@app.get("/")
+async def hello_world(request):
+ return text("Hello, world.")
diff --git a/src/main.py b/src/main.py
index 56c8e9e..b678aae 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,24 +1,28 @@
-import http.server
-import logging
+import sanic
import os
-import sys
-import server
+import controller
-logger = logging.getLogger(__name__)
-hostName = os.environ['HOST']
-serverPort = int(os.environ['PORT'])
+app = sanic.Sanic("Files")
-if __name__ == '__main__':
- logging.basicConfig(stream=sys.stdout, level=logging.INFO)
- webServer = http.server.HTTPServer((hostName, serverPort), server.MyServer)
- logger.info('Server started at http://%s:%s.' % (hostName, serverPort))
+@app.get("/")
+async def index(request):
+ return controller.index()
- try:
- webServer.serve_forever()
- except KeyboardInterrupt:
- pass
+@app.post("/", stream = True)
+async def upload(request):
+ return await controller.upload(request)
- webServer.server_close()
- conn.close()
- logger.info('Server stopped.')
+@app.get("/<file_id:str>")
+async def file_page(request, file_id):
+ return await controller.file(file_id, download = False)
+
+@app.get("/<file_id:str>/download")
+async def file_download(request, file_id):
+ return await controller.file(file_id, download = True)
+
+app.static("/static/", "static/")
+
+if __name__ == "__main__":
+ debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'TRUE'
+ app.run(debug=debug, access_log=True)
diff --git a/src/templates.py b/src/templates.py
index c605e57..8125f69 100644
--- a/src/templates.py
+++ b/src/templates.py
@@ -8,15 +8,15 @@ page: str = '''
<meta name="viewport" content="width=device-width">
<title>Files</title>
- <link rel="stylesheet" href="/main.css">
- <script src="/main.js"></script>
+ <link rel="stylesheet" href="/static/main.css">
+ <script src="/static/main.js"></script>
<a href="/">
<h1>Files</h1>
</a>
'''
-index: str = f'''
+pub index: str = f'''
{page}
<form>
@@ -79,8 +79,14 @@ index: str = f'''
</form>
'''
-def download(href: str, filename: str, expires: str) -> str:
+def file_page(file_id: str, filename: str, expires: str) -> str:
+ href = f'{file_id}/download'
expires_in = datetime.datetime.strptime(expires, '%Y-%m-%d %H:%M:%S') - datetime.datetime.now()
+
+ print()
+ print(href)
+ print()
+
return f'''
{page}
diff --git a/src/utils.py b/src/utils.py
index ccf92c0..151217f 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,16 +1,5 @@
import io
-def transfer(reader, writer, content_length = None, buffer_size = io.DEFAULT_BUFFER_SIZE):
- if content_length is None:
- while (data := reader.read(buffer_size)):
- writer.write(data)
- else:
- remaining = content_length
- while remaining > 0:
- size = min(buffer_size, remaining)
- writer.write(reader.read(size))
- remaining -= size
-
def sanitize_filename(s: str) -> str:
return '.'.join([sanitize_filename_part(p) for p in s.split('.')])
diff --git a/public/main.css b/static/main.css
index db9a678..db9a678 100644
--- a/public/main.css
+++ b/static/main.css
diff --git a/public/main.js b/static/main.js
index 1729d38..1729d38 100644
--- a/public/main.js
+++ b/static/main.js