diff options
-rw-r--r-- | flake.nix | 3 | ||||
-rw-r--r-- | src/controller.py | 58 | ||||
-rw-r--r-- | src/main-old.py | 33 | ||||
-rw-r--r-- | src/main.py | 40 | ||||
-rw-r--r-- | src/templates.py | 14 | ||||
-rw-r--r-- | src/utils.py | 11 | ||||
-rw-r--r-- | static/main.css (renamed from public/main.css) | 0 | ||||
-rw-r--r-- | static/main.js (renamed from public/main.js) | 0 |
8 files changed, 126 insertions, 33 deletions
@@ -13,9 +13,12 @@ (pkgs.python3.withPackages (pythonPackages: with pythonPackages; [ sqlite watchexec + sanic + setuptools ])) ]; shellHook = '' + export DEBUG="TRUE" export HOST="127.0.0.1" export PORT="8080" export KEY="1234" diff --git a/src/controller.py b/src/controller.py new file mode 100644 index 0000000..351d0bc --- /dev/null +++ b/src/controller.py @@ -0,0 +1,58 @@ +import io +import logging +import os +import sanic +import sqlite3 +import tempfile + +import db +import templates +import utils + +conn = sqlite3.connect('db.sqlite3') +files_directory = 'files' +authorized_key = os.environ['KEY'] + +def index(): + return sanic.html(templates.index) + +async def upload(request): + key = request.headers.get('X-Key') + if not key == authorized_key: + sanic.log.logging.info('Unauthorized to upload file: wrong key') + return sanic.text('Unauthorized', status = 401) + else: + sanic.log.logging.info('Uploading file') + content_length = int(request.headers.get('content-length')) + filename = utils.sanitize_filename(request.headers.get('X-FileName')) + expiration = request.headers.get('X-Expiration') + + with tempfile.NamedTemporaryFile(delete = False) as tmp: + while data := await request.stream.read(): + tmp.write(data) + + sanic.log.logging.info('File uploaded') + file_id = db.insert_file(conn, filename, expiration, content_length) + os.makedirs(files_directory, exist_ok=True) + os.rename(tmp.name, os.path.join(files_directory, file_id)) + + return sanic.text(file_id) + +async def file(file_id: str, download: bool): + res = db.get_file(conn, file_id) + if res is None: + self._serve_str(templates.not_found, 404, 'text/html') + else: + filename, expires, content_length = res + disk_path = os.path.join(files_directory, file_id) + if download: + return await sanic.response.file_stream( + disk_path, + chunk_size = io.DEFAULT_BUFFER_SIZE, + headers = { + 'Content-Disposition': f'attachment; filename={filename}', + 'Content-Length': content_length + } + ) + else: + return sanic.html(templates.file_page(file_id, filename, expires)) diff --git a/src/main-old.py b/src/main-old.py new file mode 100644 index 0000000..42d7c8c --- /dev/null +++ b/src/main-old.py @@ -0,0 +1,33 @@ +# import http.server +# import logging +# import os +# import sys + +# import server + +# logger = logging.getLogger(__name__) +# hostName = os.environ['HOST'] +# serverPort = int(os.environ['PORT']) + +# if __name__ == '__main__': +# logging.basicConfig(stream=sys.stdout, level=logging.INFO) +# webServer = http.server.HTTPServer((hostName, serverPort), server.MyServer) +# logger.info('Server started at http://%s:%s.' % (hostName, serverPort)) + +# try: +# webServer.serve_forever() +# except KeyboardInterrupt: +# pass + +# webServer.server_close() +# conn.close() +# logger.info('Server stopped.') + +from sanic import Sanic +from sanic.response import text + +app = Sanic("MyHelloWorldApp") + +@app.get("/") +async def hello_world(request): + return text("Hello, world.") diff --git a/src/main.py b/src/main.py index 56c8e9e..b678aae 100644 --- a/src/main.py +++ b/src/main.py @@ -1,24 +1,28 @@ -import http.server -import logging +import sanic import os -import sys -import server +import controller -logger = logging.getLogger(__name__) -hostName = os.environ['HOST'] -serverPort = int(os.environ['PORT']) +app = sanic.Sanic("Files") -if __name__ == '__main__': - logging.basicConfig(stream=sys.stdout, level=logging.INFO) - webServer = http.server.HTTPServer((hostName, serverPort), server.MyServer) - logger.info('Server started at http://%s:%s.' % (hostName, serverPort)) +@app.get("/") +async def index(request): + return controller.index() - try: - webServer.serve_forever() - except KeyboardInterrupt: - pass +@app.post("/", stream = True) +async def upload(request): + return await controller.upload(request) - webServer.server_close() - conn.close() - logger.info('Server stopped.') +@app.get("/<file_id:str>") +async def file_page(request, file_id): + return await controller.file(file_id, download = False) + +@app.get("/<file_id:str>/download") +async def file_download(request, file_id): + return await controller.file(file_id, download = True) + +app.static("/static/", "static/") + +if __name__ == "__main__": + debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'TRUE' + app.run(debug=debug, access_log=True) diff --git a/src/templates.py b/src/templates.py index c605e57..8125f69 100644 --- a/src/templates.py +++ b/src/templates.py @@ -8,15 +8,15 @@ page: str = ''' <meta name="viewport" content="width=device-width"> <title>Files</title> - <link rel="stylesheet" href="/main.css"> - <script src="/main.js"></script> + <link rel="stylesheet" href="/static/main.css"> + <script src="/static/main.js"></script> <a href="/"> <h1>Files</h1> </a> ''' -index: str = f''' +pub index: str = f''' {page} <form> @@ -79,8 +79,14 @@ index: str = f''' </form> ''' -def download(href: str, filename: str, expires: str) -> str: +def file_page(file_id: str, filename: str, expires: str) -> str: + href = f'{file_id}/download' expires_in = datetime.datetime.strptime(expires, '%Y-%m-%d %H:%M:%S') - datetime.datetime.now() + + print() + print(href) + print() + return f''' {page} diff --git a/src/utils.py b/src/utils.py index ccf92c0..151217f 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,16 +1,5 @@ import io -def transfer(reader, writer, content_length = None, buffer_size = io.DEFAULT_BUFFER_SIZE): - if content_length is None: - while (data := reader.read(buffer_size)): - writer.write(data) - else: - remaining = content_length - while remaining > 0: - size = min(buffer_size, remaining) - writer.write(reader.read(size)) - remaining -= size - def sanitize_filename(s: str) -> str: return '.'.join([sanitize_filename_part(p) for p in s.split('.')]) diff --git a/public/main.css b/static/main.css index db9a678..db9a678 100644 --- a/public/main.css +++ b/static/main.css diff --git a/public/main.js b/static/main.js index 1729d38..1729d38 100644 --- a/public/main.js +++ b/static/main.js |