diff options
author | Joris | 2024-06-02 14:38:13 +0200 |
---|---|---|
committer | Joris | 2024-06-02 14:38:22 +0200 |
commit | 1019ea1ed341e3a7769c046aa0be5764789360b6 (patch) | |
tree | 1a0d8a4f00cff252d661c42fc23ed4c19795da6f /src | |
parent | e8da9790dc6d55cd2e8883322cdf9a7bf5b4f5b7 (diff) |
Migrate to Rust and Hyper
With sanic, downloading a file locally is around ten times slower than
with Rust and hyper.
Maybe `pypy` could have helped, but I didn’t succeed to set it up
quickly with the dependencies.
Diffstat (limited to 'src')
-rw-r--r-- | src/controller.py | 58 | ||||
-rw-r--r-- | src/db.py | 24 | ||||
-rw-r--r-- | src/db.rs | 125 | ||||
-rw-r--r-- | src/main-old.py | 33 | ||||
-rw-r--r-- | src/main.py | 28 | ||||
-rw-r--r-- | src/main.rs | 68 | ||||
-rw-r--r-- | src/model.rs | 52 | ||||
-rw-r--r-- | src/routes.rs | 209 | ||||
-rw-r--r-- | src/server.py | 84 | ||||
-rw-r--r-- | src/static/main.css | 94 | ||||
-rw-r--r-- | src/static/main.js | 49 | ||||
-rw-r--r-- | src/templates.py | 105 | ||||
-rw-r--r-- | src/templates.rs | 134 | ||||
-rw-r--r-- | src/util.rs | 125 | ||||
-rw-r--r-- | src/utils.py | 8 |
15 files changed, 856 insertions, 340 deletions
diff --git a/src/controller.py b/src/controller.py deleted file mode 100644 index 351d0bc..0000000 --- a/src/controller.py +++ /dev/null @@ -1,58 +0,0 @@ -import io -import logging -import os -import sanic -import sqlite3 -import tempfile - -import db -import templates -import utils - -conn = sqlite3.connect('db.sqlite3') -files_directory = 'files' -authorized_key = os.environ['KEY'] - -def index(): - return sanic.html(templates.index) - -async def upload(request): - key = request.headers.get('X-Key') - if not key == authorized_key: - sanic.log.logging.info('Unauthorized to upload file: wrong key') - return sanic.text('Unauthorized', status = 401) - else: - sanic.log.logging.info('Uploading file') - content_length = int(request.headers.get('content-length')) - filename = utils.sanitize_filename(request.headers.get('X-FileName')) - expiration = request.headers.get('X-Expiration') - - with tempfile.NamedTemporaryFile(delete = False) as tmp: - while data := await request.stream.read(): - tmp.write(data) - - sanic.log.logging.info('File uploaded') - file_id = db.insert_file(conn, filename, expiration, content_length) - os.makedirs(files_directory, exist_ok=True) - os.rename(tmp.name, os.path.join(files_directory, file_id)) - - return sanic.text(file_id) - -async def file(file_id: str, download: bool): - res = db.get_file(conn, file_id) - if res is None: - self._serve_str(templates.not_found, 404, 'text/html') - else: - filename, expires, content_length = res - disk_path = os.path.join(files_directory, file_id) - if download: - return await sanic.response.file_stream( - disk_path, - chunk_size = io.DEFAULT_BUFFER_SIZE, - headers = { - 'Content-Disposition': f'attachment; filename={filename}', - 'Content-Length': content_length - } - ) - else: - return sanic.html(templates.file_page(file_id, filename, expires)) diff --git a/src/db.py b/src/db.py deleted file mode 100644 index a6e29fd..0000000 --- a/src/db.py +++ /dev/null @@ -1,24 +0,0 @@ -import secrets - -def insert_file(conn, filename: str, expiration_days: int, content_length: int): - cur = conn.cursor() - file_id = secrets.token_urlsafe() - cur.execute( - 'INSERT INTO files(id, filename, created, expires, content_length) VALUES(?, ?, datetime(), datetime(datetime(), ?), ?)', - (file_id, filename, f'+{expiration_days} days', content_length) - ) - conn.commit() - return file_id - -def get_file(conn, file_id: str): - cur = conn.cursor() - res = cur.execute( - ''' - SELECT filename, expires, content_length - FROM files - WHERE id = ? AND expires > datetime() - ''', - (file_id,) - ) - return res.fetchone() - diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..e1bb7e3 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,125 @@ +use tokio_rusqlite::{params, Connection, Result}; + +use crate::model::{decode_datetime, encode_datetime, File}; + +pub async fn insert_file(conn: &Connection, file: File) -> Result<()> { + conn.call(move |conn| { + conn.execute( + r#" + INSERT INTO + files(id, created_at, expires_at, filename, content_length) + VALUES + (?1, datetime(), ?2, ?3, ?4) + "#, + params![ + file.id, + encode_datetime(file.expires_at), + file.name, + file.content_length + ], + ) + .map_err(tokio_rusqlite::Error::Rusqlite) + }) + .await + .map(|_| ()) +} + +pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File>> { + conn.call(move |conn| { + let mut stmt = conn.prepare( + r#" + SELECT + filename, expires_at, content_length + FROM + files + WHERE + id = ? + AND expires_at > datetime() + "#, + )?; + + let mut iter = stmt.query_map([file_id.clone()], |row| { + let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); + Ok(res) + })?; + + match iter.next() { + Some(Ok((filename, expires_at, content_length))) => { + match decode_datetime(&expires_at) { + Some(expires_at) => Ok(Some(File { + id: file_id.clone(), + name: filename, + expires_at, + content_length, + })), + _ => Err(rusqlite_other_error(&format!( + "Error decoding datetime: {expires_at}" + ))), + } + } + Some(_) => Err(rusqlite_other_error("Error reading file in DB")), + None => Ok(None), + } + }) + .await +} + +fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error { + tokio_rusqlite::Error::Other(msg.into()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::model::local_time; + use chrono::Duration; + use std::ops::Add; + + #[tokio::test] + async fn test_insert_and_get_file() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert_file(&conn, file.clone()).await.is_ok()); + let file_res = get_file(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert_eq!(file_res.unwrap(), Some(file)); + } + + #[tokio::test] + async fn test_expired_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::zero()); + assert!(insert_file(&conn, file.clone()).await.is_ok()); + let file_res = get_file(&conn, file.id.clone()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + #[tokio::test] + async fn test_wrong_file_err() { + let conn = get_connection().await; + let file = dummy_file(Duration::minutes(1)); + assert!(insert_file(&conn, file.clone()).await.is_ok()); + let file_res = get_file(&conn, "wrong-id".to_string()).await; + assert!(file_res.is_ok()); + assert!(file_res.unwrap().is_none()); + } + + fn dummy_file(td: Duration) -> File { + File { + id: "1234".to_string(), + name: "foo".to_string(), + expires_at: local_time().add(td), + content_length: 100, + } + } + + async fn get_connection() -> Connection { + let conn = Connection::open_in_memory().await.unwrap(); + let init_db = tokio::fs::read_to_string("init-db.sql").await.unwrap(); + + let res = conn.call(move |conn| Ok(conn.execute(&init_db, []))).await; + assert!(res.is_ok()); + conn + } +} diff --git a/src/main-old.py b/src/main-old.py deleted file mode 100644 index 42d7c8c..0000000 --- a/src/main-old.py +++ /dev/null @@ -1,33 +0,0 @@ -# import http.server -# import logging -# import os -# import sys - -# import server - -# logger = logging.getLogger(__name__) -# hostName = os.environ['HOST'] -# serverPort = int(os.environ['PORT']) - -# if __name__ == '__main__': -# logging.basicConfig(stream=sys.stdout, level=logging.INFO) -# webServer = http.server.HTTPServer((hostName, serverPort), server.MyServer) -# logger.info('Server started at http://%s:%s.' % (hostName, serverPort)) - -# try: -# webServer.serve_forever() -# except KeyboardInterrupt: -# pass - -# webServer.server_close() -# conn.close() -# logger.info('Server stopped.') - -from sanic import Sanic -from sanic.response import text - -app = Sanic("MyHelloWorldApp") - -@app.get("/") -async def hello_world(request): - return text("Hello, world.") diff --git a/src/main.py b/src/main.py deleted file mode 100644 index b678aae..0000000 --- a/src/main.py +++ /dev/null @@ -1,28 +0,0 @@ -import sanic -import os - -import controller - -app = sanic.Sanic("Files") - -@app.get("/") -async def index(request): - return controller.index() - -@app.post("/", stream = True) -async def upload(request): - return await controller.upload(request) - -@app.get("/<file_id:str>") -async def file_page(request, file_id): - return await controller.file(file_id, download = False) - -@app.get("/<file_id:str>/download") -async def file_download(request, file_id): - return await controller.file(file_id, download = True) - -app.static("/static/", "static/") - -if __name__ == "__main__": - debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'TRUE' - app.run(debug=debug, access_log=True) diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..27da278 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,68 @@ +use std::env; +use std::net::SocketAddr; + +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper_util::rt::TokioIo; +use tokio::net::TcpListener; +use tokio_rusqlite::Connection; + +mod db; +mod model; +mod routes; +mod templates; +mod util; + +#[tokio::main] +async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { + env_logger::init(); + + let host = get_env("HOST"); + let port = get_env("PORT"); + let db_path = get_env("DB"); + let authorized_key = get_env("KEY"); + let files_dir = get_env("FILES_DIR"); + + let db_conn = Connection::open(db_path) + .await + .expect("Error while openning DB conection"); + + let addr: SocketAddr = format!("{host}:{port}") + .parse() + .unwrap_or_else(|_| panic!("Invalid address: {host}:{port}")); + + let listener = TcpListener::bind(addr).await?; + log::info!("Listening on http://{}", addr); + + loop { + let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); + + let db_conn = db_conn.clone(); + let authorized_key = authorized_key.clone(); + let files_dir = files_dir.clone(); + + tokio::task::spawn(async move { + if let Err(err) = http1::Builder::new() + .serve_connection( + io, + service_fn(move |req| { + routes::routes( + req, + db_conn.clone(), + authorized_key.clone(), + files_dir.clone(), + ) + }), + ) + .await + { + log::error!("Failed to serve connection: {:?}", err); + } + }); + } +} + +fn get_env(key: &str) -> String { + env::var(key).unwrap_or_else(|_| panic!("Missing environment variable {key}")) +} diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..ed4fbf8 --- /dev/null +++ b/src/model.rs @@ -0,0 +1,52 @@ +use base64::{engine::general_purpose::URL_SAFE, Engine as _}; +use chrono::{DateTime, Local, NaiveDateTime, TimeZone}; +use rand_core::{OsRng, RngCore}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct File { + pub id: String, + pub name: String, + pub expires_at: DateTime<Local>, + pub content_length: usize, +} + +pub fn local_time() -> DateTime<Local> { + let dt = Local::now(); + match decode_datetime(&encode_datetime(dt)) { + Some(res) => res, + None => dt, + } +} + +// Using 20 bytes (160 bits) to file identifiers +// https://owasp.org/www-community/vulnerabilities/Insufficient_Session-ID_Length +// https://www.rfc-editor.org/rfc/rfc6749.html#section-10.10 +const FILE_ID_BYTES: usize = 20; + +pub fn generate_file_id() -> String { + let mut token = [0u8; FILE_ID_BYTES]; + OsRng.fill_bytes(&mut token); + URL_SAFE.encode(token) +} + +const FORMAT: &str = "%Y-%m-%d %H:%M:%S"; + +pub fn encode_datetime(dt: DateTime<Local>) -> String { + dt.naive_utc().format(FORMAT).to_string() +} + +pub fn decode_datetime(str: &str) -> Option<DateTime<Local>> { + let naive_time = NaiveDateTime::parse_from_str(str, FORMAT).ok()?; + Some(Local.from_utc_datetime(&naive_time)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_datetime_serialization() { + let dt = local_time(); + assert_eq!(decode_datetime(&encode_datetime(dt)), Some(dt)) + } +} diff --git a/src/routes.rs b/src/routes.rs new file mode 100644 index 0000000..b54e565 --- /dev/null +++ b/src/routes.rs @@ -0,0 +1,209 @@ +use std::ops::Add; +use std::path::{Path, PathBuf}; + +use chrono::Duration; +use futures_util::TryStreamExt; +use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; +use hyper::body::{Bytes, Frame, Incoming}; +use hyper::header::{HeaderName, HeaderValue, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_TYPE}; +use hyper::{Method, Request, Response, Result, StatusCode}; +use tokio::io::AsyncWriteExt; +use tokio::{fs, fs::File}; +use tokio_rusqlite::Connection; +use tokio_util::io::ReaderStream; + +use crate::db; +use crate::model; +use crate::templates; +use crate::util; + +pub async fn routes( + request: Request<Incoming>, + db_conn: Connection, + authorized_key: String, + files_dir: String, +) -> Result<Response<BoxBody<Bytes, std::io::Error>>> { + let path = &request.uri().path().split('/').collect::<Vec<&str>>()[1..]; + let files_dir = Path::new(&files_dir); + + match (request.method(), path) { + (&Method::GET, [""]) => Ok(response(StatusCode::OK, templates::INDEX.to_string())), + (&Method::GET, ["static", "main.js"]) => Ok(static_file( + include_str!("static/main.js").to_string(), + "application/javascript", + )), + (&Method::GET, ["static", "main.css"]) => Ok(static_file( + include_str!("static/main.css").to_string(), + "text/css", + )), + (&Method::POST, [""]) => upload_file(request, db_conn, authorized_key, files_dir).await, + (&Method::GET, [file_id]) => get(db_conn, file_id, GetFile::ShowPage, files_dir).await, + (&Method::GET, [file_id, "download"]) => { + get(db_conn, file_id, GetFile::Download, files_dir).await + } + _ => Ok(not_found()), + } +} + +async fn upload_file( + request: Request<Incoming>, + db_conn: Connection, + authorized_key: String, + files_dir: &Path, +) -> Result<Response<BoxBody<Bytes, std::io::Error>>> { + let key = get_header(&request, "X-Key"); + if key != Some(authorized_key) { + log::info!("Unauthorized file upload"); + Ok(response( + StatusCode::UNAUTHORIZED, + "Unauthorized".to_string(), + )) + } else { + let file_id = model::generate_file_id(); + let filename = get_header(&request, "X-Filename").map(|s| util::sanitize_filename(&s)); + let expiration_days: Option<i64> = + get_header(&request, "X-Expiration").and_then(|s| s.parse().ok()); + let content_length: Option<usize> = + get_header(&request, "Content-Length").and_then(|s| s.parse().ok()); + + match (filename, expiration_days, content_length) { + (Some(filename), Some(expiration_days), Some(content_length)) => { + let _ = fs::create_dir(files_dir).await; + let path = files_dir.join(&file_id); + let mut file = File::create(&path).await.unwrap(); + + let mut incoming = request.into_body(); + while let Some(frame) = incoming.frame().await { + if let Ok(data) = frame { + let _ = file.write_all(&data.into_data().unwrap()).await; + let _ = file.flush().await; + } + } + + let file = model::File { + id: file_id.clone(), + name: filename, + expires_at: model::local_time().add(Duration::days(expiration_days)), + content_length, + }; + + match db::insert_file(&db_conn, file.clone()).await { + Ok(_) => Ok(response(StatusCode::OK, file_id)), + Err(msg) => { + log::error!("Insert file: {msg}"); + if let Err(msg) = fs::remove_file(path).await { + log::error!("Remove file: {msg}"); + }; + Ok(internal_server_error()) + } + } + } + _ => Ok(bad_request()), + } + } +} + +fn get_header(request: &Request<Incoming>, header: &str) -> Option<String> { + request + .headers() + .get(header)? + .to_str() + .ok() + .map(|str| str.to_string()) +} + +enum GetFile { + ShowPage, + Download, +} + +async fn get( + db_conn: Connection, + file_id: &str, + get_file: GetFile, + files_dir: &Path, +) -> Result<Response<BoxBody<Bytes, std::io::Error>>> { + let file = db::get_file(&db_conn, file_id.to_string()).await; + match (get_file, file) { + (GetFile::ShowPage, Ok(Some(file))) => { + Ok(response(StatusCode::OK, templates::file_page(file))) + } + (GetFile::Download, Ok(Some(file))) => { + let path = files_dir.join(file_id); + Ok(stream_file(path, file).await) + } + (_, Err(msg)) => { + log::error!("Getting file: {msg}"); + Ok(internal_server_error()) + } + (_, Ok(None)) => Ok(not_found()), + } +} + +fn static_file(text: String, content_type: &str) -> Response<BoxBody<Bytes, std::io::Error>> { + let response = Response::builder() + .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) + .unwrap(); + with_headers(response, vec![(CONTENT_TYPE, content_type)]) +} + +fn response(status_code: StatusCode, text: String) -> Response<BoxBody<Bytes, std::io::Error>> { + Response::builder() + .status(status_code) + .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) + .unwrap() +} + +async fn stream_file(path: PathBuf, file: model::File) -> Response<BoxBody<Bytes, std::io::Error>> { + match File::open(path).await { + Err(e) => { + log::error!("Unable to open file: {e}"); + not_found() + } + Ok(disk_file) => { + let reader_stream = ReaderStream::new(disk_file); + let stream_body = StreamBody::new(reader_stream.map_ok(Frame::data)); + let boxed_body = stream_body.boxed(); + + let response = Response::builder().body(boxed_body).unwrap(); + + with_headers( + response, + vec![ + ( + CONTENT_DISPOSITION, + &format!("attachment; filename={}", file.name), + ), + (CONTENT_LENGTH, &file.content_length.to_string()), + ], + ) + } + } +} + +fn not_found() -> Response<BoxBody<Bytes, std::io::Error>> { + response(StatusCode::NOT_FOUND, templates::NOT_FOUND.to_string()) +} + +fn bad_request() -> Response<BoxBody<Bytes, std::io::Error>> { + response(StatusCode::BAD_REQUEST, templates::BAD_REQUEST.to_string()) +} + +fn internal_server_error() -> Response<BoxBody<Bytes, std::io::Error>> { + response( + StatusCode::INTERNAL_SERVER_ERROR, + templates::INTERNAL_SERVER_ERROR.to_string(), + ) +} + +pub fn with_headers( + response: Response<BoxBody<Bytes, std::io::Error>>, + headers: Vec<(HeaderName, &str)>, +) -> Response<BoxBody<Bytes, std::io::Error>> { + let mut response = response; + let response_headers = response.headers_mut(); + for (name, value) in headers { + response_headers.insert(name, HeaderValue::from_str(value).unwrap()); + } + response +} diff --git a/src/server.py b/src/server.py deleted file mode 100644 index 5927052..0000000 --- a/src/server.py +++ /dev/null @@ -1,84 +0,0 @@ -import http.server -import logging -import os -import sqlite3 -import tempfile - -import db -import templates -import utils - -logger = logging.getLogger(__name__) -conn = sqlite3.connect('db.sqlite3') -files_directory = 'files' -authorized_key = os.environ['KEY'] - -class MyServer(http.server.BaseHTTPRequestHandler): - def do_GET(self): - match self.path: - case '/': - self._serve_str(templates.index, 200, 'text/html') - case '/main.js': - self._serve_file('public/main.js', 'application/javascript') - case '/main.css': - self._serve_file('public/main.css', 'text/css') - case path: - if path.endswith('?download'): - download = True - path = path[:-len('?download')] - else: - download = False - - file_id = path[1:] - res = db.get_file(conn, file_id) - if res is None: - self._serve_str(templates.not_found, 404, 'text/html') - else: - filename, expires, content_length = res - disk_path = os.path.join(files_directory, file_id) - if download: - headers = [ - ('Content-Disposition', f'attachment; filename={filename}'), - ('Content-Length', content_length) - ] - self._serve_file(disk_path, 'application/octet-stream', headers) - else: - href = f'{file_id}?download' - self._serve_str(templates.download(href, filename, expires), 200, 'text/html') - - def do_POST(self): - key = self.headers['X-Key'] - if not key == authorized_key: - logging.info('Unauthorized to upload file: wrong key') - self._serve_str('Unauthorized', 401) - - else: - logging.info('Uploading file') - content_length = int(self.headers['content-length']) - filename = utils.sanitize_filename(self.headers['X-FileName']) - expiration = self.headers['X-Expiration'] - - with tempfile.NamedTemporaryFile(delete = False) as tmp: - utils.transfer(self.rfile, tmp, content_length = content_length) - - logging.info('File uploaded') - file_id = db.insert_file(conn, filename, expiration, content_length) - os.makedirs(files_directory, exist_ok=True) - os.rename(tmp.name, os.path.join(files_directory, file_id)) - - self._serve_str(file_id, 200) - - def _serve_str(self, s, code, content_type='text/plain'): - self.send_response(code) - self.send_header('Content-type', content_type) - self.end_headers() - self.wfile.write(bytes(s, 'utf-8')) - - def _serve_file(self, filename, content_type, headers = []): - self.send_response(200) - self.send_header('Content-type', content_type) - for header_name, header_value in headers: - self.send_header(header_name, header_value) - self.end_headers() - with open(filename, 'rb') as f: - utils.transfer(f, self.wfile) diff --git a/src/static/main.css b/src/static/main.css new file mode 100644 index 0000000..af0ee54 --- /dev/null +++ b/src/static/main.css @@ -0,0 +1,94 @@ +html { + margin: 0 1rem; + font-size: 16px; + line-height: 1.4rem; + font-family: sans-serif; + box-sizing: border-box; +} + +*, *:before, *:after { + box-sizing: inherit; +} + +body { + max-width: 30rem; + margin: 0 auto; +} + +a { + text-decoration: none; + color: #06C; +} + +h1 { + text-align: center; + font-variant: small-caps; + font-size: 40px; + letter-spacing: 0.2rem; + margin-bottom: 4rem; +} + +.g-Link { + text-decoration: underline; +} + +label { + display: flex; + gap: 0.5rem; + flex-direction: column; + margin-bottom: 2rem; +} + +input, select { + font-size: inherit; + border: 1px solid black; + height: 2rem; + background: white; +} + +input[type=file] { + align-content: center; + padding-left: 2px; +} + +input[type=submit] { + width: 100%; + background: #06C; + cursor: pointer; + border: none; + color: white; +} + +.g-Loading { + display: none; + align-items: center; + justify-content: center; + gap: 1rem; + margin-bottom: 2rem; +} + +.g-Error { + text-align: center; + margin-bottom: 2rem; + color: #C00; +} + +.g-Spinner { + width: 25px; + height: 25px; + border: 4px solid #06C; + border-bottom-color: transparent; + border-radius: 50%; + display: inline-block; + box-sizing: border-box; + animation: rotation 1s linear infinite; +} + +@keyframes rotation { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} diff --git a/src/static/main.js b/src/static/main.js new file mode 100644 index 0000000..40e62d6 --- /dev/null +++ b/src/static/main.js @@ -0,0 +1,49 @@ +window.onload = function() { + const form = document.querySelector('form') + + if (form !== null) { + const submit = document.querySelector('input[type="submit"]') + const loading = document.querySelector('.g-Loading') + const error = document.querySelector('.g-Error') + + function showError(msg) { + loading.style.display = 'none' + submit.disabled = false + error.innerText = msg + error.style.display = 'block' + } + + form.onsubmit = function(event) { + event.preventDefault() + + loading.style.display = 'flex' + submit.disabled = true + error.style.display = 'none' + + const key = document.querySelector('input[name="key"]').value + const expiration = document.querySelector('select[name="expiration"]').value + const file = document.querySelector('input[name="file"]').files[0] + const filename = file.name.replace(/[^0-9a-zA-Z\.]/gi, '-') + + // Wait a bit to prevent showing the loader too briefly + setTimeout(function() { + const xhr = new XMLHttpRequest() + xhr.open('POST', '/', true) + xhr.onload = function () { + if (xhr.status === 200) { + window.location = `/${xhr.responseText}` + } else { + showError(`Error uploading: ${xhr.status}`) + } + } + xhr.onerror = function () { + showError('Upload error') + } + xhr.setRequestHeader('X-FileName', filename) + xhr.setRequestHeader('X-Expiration', expiration) + xhr.setRequestHeader('X-Key', key) + xhr.send(file) + }, 500) + } + } +} diff --git a/src/templates.py b/src/templates.py deleted file mode 100644 index 8125f69..0000000 --- a/src/templates.py +++ /dev/null @@ -1,105 +0,0 @@ -import html -import datetime - -page: str = ''' - <!doctype html> - <html lang="fr"> - <meta charset="utf-8"> - <meta name="viewport" content="width=device-width"> - - <title>Files</title> - <link rel="stylesheet" href="/static/main.css"> - <script src="/static/main.js"></script> - - <a href="/"> - <h1>Files</h1> - </a> -''' - -pub index: str = f''' - {page} - - <form> - <label> - File - <input type="file" name="file" required> - </label> - - <label> - Expiration - <select name="expiration"> - <option value="1">1 day</option> - <option value="2">2 days</option> - <option value="3">3 days</option> - <option value="4">4 days</option> - <option value="5">5 days</option> - <option value="6">6 days</option> - <option value="7" selected>7 days</option> - <option value="8">8 days</option> - <option value="9">9 days</option> - <option value="10">10 days</option> - <option value="11">11 days</option> - <option value="12">12 days</option> - <option value="13">13 days</option> - <option value="14">14 days</option> - <option value="15">15 days</option> - <option value="16">16 days</option> - <option value="17">17 days</option> - <option value="18">18 days</option> - <option value="19">19 days</option> - <option value="20">20 days</option> - <option value="21">21 days</option> - <option value="22">22 days</option> - <option value="23">23 days</option> - <option value="24">24 days</option> - <option value="25">25 days</option> - <option value="26">26 days</option> - <option value="27">27 days</option> - <option value="28">28 days</option> - <option value="29">29 days</option> - <option value="30">30 days</option> - <option value="31">31 days</option> - </select> - </label> - - <label> - Key - <input type="password" name="key" required> - </label> - - <div class="g-Loading"> - <div class="g-Spinner"></div> - Uploading… - </div> - - <div class="g-Error"> - </div> - - <input type="submit" value="Upload"> - </form> -''' - -def file_page(file_id: str, filename: str, expires: str) -> str: - href = f'{file_id}/download' - expires_in = datetime.datetime.strptime(expires, '%Y-%m-%d %H:%M:%S') - datetime.datetime.now() - - print() - print(href) - print() - - return f''' - {page} - - <div> - <a class="g-Link" href="{html.escape(href)}">{html.escape(filename)}</a> - <div> - Expires in {expires_in} - </div> - </div> - ''' - -not_found: str = f''' - {page} - - Oops, not found! -''' diff --git a/src/templates.rs b/src/templates.rs new file mode 100644 index 0000000..b551bf6 --- /dev/null +++ b/src/templates.rs @@ -0,0 +1,134 @@ +use chrono::Local; + +use crate::model::File; +use crate::util; + +const PAGE: &str = r#" +<!doctype html> +<html lang="fr"> +<meta charset="utf-8"> +<meta name="viewport" content="width=device-width"> + +<title>Files</title> +<link rel="stylesheet" href="/static/main.css"> +<script src="/static/main.js"></script> + +<a href="/"> + <h1>Files</h1> +</a> +"#; + +pub const INDEX: &str = const_format::concatcp!( + PAGE, + r#" +<form> + <label> + File + <input type="file" name="file" required> + </label> + + <label> + Expiration + <select name="expiration"> + <option value="1">1 day</option> + <option value="2">2 days</option> + <option value="3">3 days</option> + <option value="4">4 days</option> + <option value="5">5 days</option> + <option value="6">6 days</option> + <option value="7" selected>7 days</option> + <option value="8">8 days</option> + <option value="9">9 days</option> + <option value="10">10 days</option> + <option value="11">11 days</option> + <option value="12">12 days</option> + <option value="13">13 days</option> + <option value="14">14 days</option> + <option value="15">15 days</option> + <option value="16">16 days</option> + <option value="17">17 days</option> + <option value="18">18 days</option> + <option value="19">19 days</option> + <option value="20">20 days</option> + <option value="21">21 days</option> + <option value="22">22 days</option> + <option value="23">23 days</option> + <option value="24">24 days</option> + <option value="25">25 days</option> + <option value="26">26 days</option> + <option value="27">27 days</option> + <option value="28">28 days</option> + <option value="29">29 days</option> + <option value="30">30 days</option> + <option value="31">31 days</option> + </select> + </label> + + <label> + Key + <input type="password" name="key" required> + </label> + + <div class="g-Loading"> + <div class="g-Spinner"></div> + Uploading… + </div> + + <div class="g-Error"> + </div> + + <input type="submit" value="Upload"> +</form>"# +); + +pub const NOT_FOUND: &str = const_format::concatcp!( + PAGE, + r#" + <div> + Oops, not found. + </div> + "# +); + +pub const BAD_REQUEST: &str = const_format::concatcp!( + PAGE, + r#" + <div> + Oops, bad request. + </div> + "# +); + +pub const INTERNAL_SERVER_ERROR: &str = const_format::concatcp!( + PAGE, + r#" + <div> + Oops, internal server error. + </div> + "# +); + +pub fn file_page(file: File) -> String { + let href = format!("{}/download", file.id); + let expiration = file.expires_at.signed_duration_since(Local::now()); + + format!( + r#" + {page} + + <div> + <div> + <a class="g-Link" href="{href}">{filename}</a> – {size} + </div> + <div> + Expires in <b>{expiration}</b>. + </div> + </div> + "#, + page = PAGE, + href = html_escape::encode_text(&href), + filename = html_escape::encode_text(&file.name), + expiration = html_escape::encode_text(&util::pretty_print_duration(expiration)), + size = html_escape::encode_text(&util::pretty_print_bytes(file.content_length)) + ) +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..9bc7cb9 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,125 @@ +use chrono::Duration; + +pub fn sanitize_filename(s: &str) -> String { + s.split('.') + .map(sanitize_filename_part) + .collect::<Vec<String>>() + .join(".") +} + +pub fn sanitize_filename_part(s: &str) -> String { + s.chars() + .map(|c| { + if c.is_ascii_alphanumeric() { + c.to_lowercase().collect::<String>() + } else { + " ".to_string() + } + }) + .collect::<String>() + .split_whitespace() + .collect::<Vec<&str>>() + .join("-") +} + +pub fn pretty_print_duration(d: Duration) -> String { + if d.num_days() > 0 { + let plural = if d.num_days() > 1 { "s" } else { "" }; + format!("{} day{}", d.num_days(), plural) + } else if d.num_hours() > 0 { + format!("{} h", d.num_hours()) + } else if d.num_minutes() > 0 { + format!("{} min", d.num_minutes()) + } else { + format!("{} s", d.num_seconds()) + } +} + +pub fn pretty_print_bytes(bytes: usize) -> String { + let ko = bytes / 1024; + let mo = ko / 1024; + let go = mo / 1024; + if go > 0 { + format!("{} GB", go) + } else if mo > 0 { + format!("{} MB", mo) + } else if ko > 0 { + format!("{} KB", ko) + } else { + format!("{} B", bytes) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sanitize_filename() { + assert_eq!(sanitize_filename(""), ""); + assert_eq!(sanitize_filename("foo bar 123"), "foo-bar-123"); + assert_eq!(sanitize_filename("foo bar.123"), "foo-bar.123"); + assert_eq!(sanitize_filename("foo ( test+2 ).xml"), "foo-test-2.xml"); + } + + #[test] + fn test_sanitize_filename_part() { + assert_eq!(sanitize_filename_part(""), ""); + assert_eq!(sanitize_filename_part("foo123BAZ"), "foo123baz"); + assert_eq!(sanitize_filename_part("foo-123-BAZ"), "foo-123-baz"); + assert_eq!(sanitize_filename_part("[()] */+-!;?<'> ?:"), ""); + assert_eq!(sanitize_filename_part("foo [bar] -- BAZ3"), "foo-bar-baz3"); + } + + #[test] + fn test_pretty_print_duration() { + assert_eq!( + pretty_print_duration(Duration::days(2)), + "2 days".to_string() + ); + assert_eq!( + pretty_print_duration(Duration::hours(30)), + "1 day".to_string() + ); + assert_eq!( + pretty_print_duration(Duration::days(1)), + "1 day".to_string() + ); + assert_eq!( + pretty_print_duration(Duration::hours(15)), + "15 h".to_string() + ); + assert_eq!( + pretty_print_duration(Duration::minutes(70)), + "1 h".to_string() + ); + assert_eq!( + pretty_print_duration(Duration::minutes(44)), + "44 min".to_string() + ); + assert_eq!( + pretty_print_duration(Duration::seconds(100)), + "1 min".to_string() + ); + assert_eq!( + pretty_print_duration(Duration::seconds(7)), + "7 s".to_string() + ); + assert_eq!(pretty_print_duration(Duration::zero()), "0 s".to_string()); + } + + #[test] + fn test_pretty_print_bytes() { + assert_eq!(pretty_print_bytes(0), "0 B"); + assert_eq!(pretty_print_bytes(10), "10 B"); + assert_eq!(pretty_print_bytes(1024), "1 KB"); + assert_eq!(pretty_print_bytes(1100), "1 KB"); + assert_eq!(pretty_print_bytes(54 * 1024), "54 KB"); + assert_eq!(pretty_print_bytes(1024 * 1024), "1 MB"); + assert_eq!(pretty_print_bytes(1300 * 1024), "1 MB"); + assert_eq!(pretty_print_bytes(79 * 1024 * 1024), "79 MB"); + assert_eq!(pretty_print_bytes(1024 * 1024 * 1024), "1 GB"); + assert_eq!(pretty_print_bytes(1300 * 1024 * 1024), "1 GB"); + assert_eq!(pretty_print_bytes(245 * 1024 * 1024 * 1024), "245 GB"); + } +} diff --git a/src/utils.py b/src/utils.py deleted file mode 100644 index 151217f..0000000 --- a/src/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -import io - -def sanitize_filename(s: str) -> str: - return '.'.join([sanitize_filename_part(p) for p in s.split('.')]) - -def sanitize_filename_part(s: str) -> str: - alnum_or_space = ''.join([c if c.isalnum() else ' ' for c in s]) - return '-'.join(alnum_or_space.split()) |