aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJoris2024-06-02 14:38:13 +0200
committerJoris2024-06-02 14:38:22 +0200
commit1019ea1ed341e3a7769c046aa0be5764789360b6 (patch)
tree1a0d8a4f00cff252d661c42fc23ed4c19795da6f /src
parente8da9790dc6d55cd2e8883322cdf9a7bf5b4f5b7 (diff)
Migrate to Rust and Hyper
With sanic, downloading a file locally is around ten times slower than with Rust and hyper. Maybe `pypy` could have helped, but I didn’t succeed to set it up quickly with the dependencies.
Diffstat (limited to 'src')
-rw-r--r--src/controller.py58
-rw-r--r--src/db.py24
-rw-r--r--src/db.rs125
-rw-r--r--src/main-old.py33
-rw-r--r--src/main.py28
-rw-r--r--src/main.rs68
-rw-r--r--src/model.rs52
-rw-r--r--src/routes.rs209
-rw-r--r--src/server.py84
-rw-r--r--src/static/main.css94
-rw-r--r--src/static/main.js49
-rw-r--r--src/templates.py105
-rw-r--r--src/templates.rs134
-rw-r--r--src/util.rs125
-rw-r--r--src/utils.py8
15 files changed, 856 insertions, 340 deletions
diff --git a/src/controller.py b/src/controller.py
deleted file mode 100644
index 351d0bc..0000000
--- a/src/controller.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import io
-import logging
-import os
-import sanic
-import sqlite3
-import tempfile
-
-import db
-import templates
-import utils
-
-conn = sqlite3.connect('db.sqlite3')
-files_directory = 'files'
-authorized_key = os.environ['KEY']
-
-def index():
- return sanic.html(templates.index)
-
-async def upload(request):
- key = request.headers.get('X-Key')
- if not key == authorized_key:
- sanic.log.logging.info('Unauthorized to upload file: wrong key')
- return sanic.text('Unauthorized', status = 401)
- else:
- sanic.log.logging.info('Uploading file')
- content_length = int(request.headers.get('content-length'))
- filename = utils.sanitize_filename(request.headers.get('X-FileName'))
- expiration = request.headers.get('X-Expiration')
-
- with tempfile.NamedTemporaryFile(delete = False) as tmp:
- while data := await request.stream.read():
- tmp.write(data)
-
- sanic.log.logging.info('File uploaded')
- file_id = db.insert_file(conn, filename, expiration, content_length)
- os.makedirs(files_directory, exist_ok=True)
- os.rename(tmp.name, os.path.join(files_directory, file_id))
-
- return sanic.text(file_id)
-
-async def file(file_id: str, download: bool):
- res = db.get_file(conn, file_id)
- if res is None:
- self._serve_str(templates.not_found, 404, 'text/html')
- else:
- filename, expires, content_length = res
- disk_path = os.path.join(files_directory, file_id)
- if download:
- return await sanic.response.file_stream(
- disk_path,
- chunk_size = io.DEFAULT_BUFFER_SIZE,
- headers = {
- 'Content-Disposition': f'attachment; filename={filename}',
- 'Content-Length': content_length
- }
- )
- else:
- return sanic.html(templates.file_page(file_id, filename, expires))
diff --git a/src/db.py b/src/db.py
deleted file mode 100644
index a6e29fd..0000000
--- a/src/db.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import secrets
-
-def insert_file(conn, filename: str, expiration_days: int, content_length: int):
- cur = conn.cursor()
- file_id = secrets.token_urlsafe()
- cur.execute(
- 'INSERT INTO files(id, filename, created, expires, content_length) VALUES(?, ?, datetime(), datetime(datetime(), ?), ?)',
- (file_id, filename, f'+{expiration_days} days', content_length)
- )
- conn.commit()
- return file_id
-
-def get_file(conn, file_id: str):
- cur = conn.cursor()
- res = cur.execute(
- '''
- SELECT filename, expires, content_length
- FROM files
- WHERE id = ? AND expires > datetime()
- ''',
- (file_id,)
- )
- return res.fetchone()
-
diff --git a/src/db.rs b/src/db.rs
new file mode 100644
index 0000000..e1bb7e3
--- /dev/null
+++ b/src/db.rs
@@ -0,0 +1,125 @@
+use tokio_rusqlite::{params, Connection, Result};
+
+use crate::model::{decode_datetime, encode_datetime, File};
+
+pub async fn insert_file(conn: &Connection, file: File) -> Result<()> {
+ conn.call(move |conn| {
+ conn.execute(
+ r#"
+ INSERT INTO
+ files(id, created_at, expires_at, filename, content_length)
+ VALUES
+ (?1, datetime(), ?2, ?3, ?4)
+ "#,
+ params![
+ file.id,
+ encode_datetime(file.expires_at),
+ file.name,
+ file.content_length
+ ],
+ )
+ .map_err(tokio_rusqlite::Error::Rusqlite)
+ })
+ .await
+ .map(|_| ())
+}
+
+pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File>> {
+ conn.call(move |conn| {
+ let mut stmt = conn.prepare(
+ r#"
+ SELECT
+ filename, expires_at, content_length
+ FROM
+ files
+ WHERE
+ id = ?
+ AND expires_at > datetime()
+ "#,
+ )?;
+
+ let mut iter = stmt.query_map([file_id.clone()], |row| {
+ let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?);
+ Ok(res)
+ })?;
+
+ match iter.next() {
+ Some(Ok((filename, expires_at, content_length))) => {
+ match decode_datetime(&expires_at) {
+ Some(expires_at) => Ok(Some(File {
+ id: file_id.clone(),
+ name: filename,
+ expires_at,
+ content_length,
+ })),
+ _ => Err(rusqlite_other_error(&format!(
+ "Error decoding datetime: {expires_at}"
+ ))),
+ }
+ }
+ Some(_) => Err(rusqlite_other_error("Error reading file in DB")),
+ None => Ok(None),
+ }
+ })
+ .await
+}
+
+fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error {
+ tokio_rusqlite::Error::Other(msg.into())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::model::local_time;
+ use chrono::Duration;
+ use std::ops::Add;
+
+ #[tokio::test]
+ async fn test_insert_and_get_file() {
+ let conn = get_connection().await;
+ let file = dummy_file(Duration::minutes(1));
+ assert!(insert_file(&conn, file.clone()).await.is_ok());
+ let file_res = get_file(&conn, file.id.clone()).await;
+ assert!(file_res.is_ok());
+ assert_eq!(file_res.unwrap(), Some(file));
+ }
+
+ #[tokio::test]
+ async fn test_expired_file_err() {
+ let conn = get_connection().await;
+ let file = dummy_file(Duration::zero());
+ assert!(insert_file(&conn, file.clone()).await.is_ok());
+ let file_res = get_file(&conn, file.id.clone()).await;
+ assert!(file_res.is_ok());
+ assert!(file_res.unwrap().is_none());
+ }
+
+ #[tokio::test]
+ async fn test_wrong_file_err() {
+ let conn = get_connection().await;
+ let file = dummy_file(Duration::minutes(1));
+ assert!(insert_file(&conn, file.clone()).await.is_ok());
+ let file_res = get_file(&conn, "wrong-id".to_string()).await;
+ assert!(file_res.is_ok());
+ assert!(file_res.unwrap().is_none());
+ }
+
+ fn dummy_file(td: Duration) -> File {
+ File {
+ id: "1234".to_string(),
+ name: "foo".to_string(),
+ expires_at: local_time().add(td),
+ content_length: 100,
+ }
+ }
+
+ async fn get_connection() -> Connection {
+ let conn = Connection::open_in_memory().await.unwrap();
+ let init_db = tokio::fs::read_to_string("init-db.sql").await.unwrap();
+
+ let res = conn.call(move |conn| Ok(conn.execute(&init_db, []))).await;
+ assert!(res.is_ok());
+ conn
+ }
+}
diff --git a/src/main-old.py b/src/main-old.py
deleted file mode 100644
index 42d7c8c..0000000
--- a/src/main-old.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# import http.server
-# import logging
-# import os
-# import sys
-
-# import server
-
-# logger = logging.getLogger(__name__)
-# hostName = os.environ['HOST']
-# serverPort = int(os.environ['PORT'])
-
-# if __name__ == '__main__':
-# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-# webServer = http.server.HTTPServer((hostName, serverPort), server.MyServer)
-# logger.info('Server started at http://%s:%s.' % (hostName, serverPort))
-
-# try:
-# webServer.serve_forever()
-# except KeyboardInterrupt:
-# pass
-
-# webServer.server_close()
-# conn.close()
-# logger.info('Server stopped.')
-
-from sanic import Sanic
-from sanic.response import text
-
-app = Sanic("MyHelloWorldApp")
-
-@app.get("/")
-async def hello_world(request):
- return text("Hello, world.")
diff --git a/src/main.py b/src/main.py
deleted file mode 100644
index b678aae..0000000
--- a/src/main.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import sanic
-import os
-
-import controller
-
-app = sanic.Sanic("Files")
-
-@app.get("/")
-async def index(request):
- return controller.index()
-
-@app.post("/", stream = True)
-async def upload(request):
- return await controller.upload(request)
-
-@app.get("/<file_id:str>")
-async def file_page(request, file_id):
- return await controller.file(file_id, download = False)
-
-@app.get("/<file_id:str>/download")
-async def file_download(request, file_id):
- return await controller.file(file_id, download = True)
-
-app.static("/static/", "static/")
-
-if __name__ == "__main__":
- debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'TRUE'
- app.run(debug=debug, access_log=True)
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..27da278
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,68 @@
+use std::env;
+use std::net::SocketAddr;
+
+use hyper::server::conn::http1;
+use hyper::service::service_fn;
+use hyper_util::rt::TokioIo;
+use tokio::net::TcpListener;
+use tokio_rusqlite::Connection;
+
+mod db;
+mod model;
+mod routes;
+mod templates;
+mod util;
+
+#[tokio::main]
+async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
+ env_logger::init();
+
+ let host = get_env("HOST");
+ let port = get_env("PORT");
+ let db_path = get_env("DB");
+ let authorized_key = get_env("KEY");
+ let files_dir = get_env("FILES_DIR");
+
+ let db_conn = Connection::open(db_path)
+ .await
+ .expect("Error while openning DB conection");
+
+ let addr: SocketAddr = format!("{host}:{port}")
+ .parse()
+ .unwrap_or_else(|_| panic!("Invalid address: {host}:{port}"));
+
+ let listener = TcpListener::bind(addr).await?;
+ log::info!("Listening on http://{}", addr);
+
+ loop {
+ let (stream, _) = listener.accept().await?;
+ let io = TokioIo::new(stream);
+
+ let db_conn = db_conn.clone();
+ let authorized_key = authorized_key.clone();
+ let files_dir = files_dir.clone();
+
+ tokio::task::spawn(async move {
+ if let Err(err) = http1::Builder::new()
+ .serve_connection(
+ io,
+ service_fn(move |req| {
+ routes::routes(
+ req,
+ db_conn.clone(),
+ authorized_key.clone(),
+ files_dir.clone(),
+ )
+ }),
+ )
+ .await
+ {
+ log::error!("Failed to serve connection: {:?}", err);
+ }
+ });
+ }
+}
+
+fn get_env(key: &str) -> String {
+ env::var(key).unwrap_or_else(|_| panic!("Missing environment variable {key}"))
+}
diff --git a/src/model.rs b/src/model.rs
new file mode 100644
index 0000000..ed4fbf8
--- /dev/null
+++ b/src/model.rs
@@ -0,0 +1,52 @@
+use base64::{engine::general_purpose::URL_SAFE, Engine as _};
+use chrono::{DateTime, Local, NaiveDateTime, TimeZone};
+use rand_core::{OsRng, RngCore};
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct File {
+ pub id: String,
+ pub name: String,
+ pub expires_at: DateTime<Local>,
+ pub content_length: usize,
+}
+
+pub fn local_time() -> DateTime<Local> {
+ let dt = Local::now();
+ match decode_datetime(&encode_datetime(dt)) {
+ Some(res) => res,
+ None => dt,
+ }
+}
+
+// Using 20 bytes (160 bits) to file identifiers
+// https://owasp.org/www-community/vulnerabilities/Insufficient_Session-ID_Length
+// https://www.rfc-editor.org/rfc/rfc6749.html#section-10.10
+const FILE_ID_BYTES: usize = 20;
+
+pub fn generate_file_id() -> String {
+ let mut token = [0u8; FILE_ID_BYTES];
+ OsRng.fill_bytes(&mut token);
+ URL_SAFE.encode(token)
+}
+
+const FORMAT: &str = "%Y-%m-%d %H:%M:%S";
+
+pub fn encode_datetime(dt: DateTime<Local>) -> String {
+ dt.naive_utc().format(FORMAT).to_string()
+}
+
+pub fn decode_datetime(str: &str) -> Option<DateTime<Local>> {
+ let naive_time = NaiveDateTime::parse_from_str(str, FORMAT).ok()?;
+ Some(Local.from_utc_datetime(&naive_time))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_datetime_serialization() {
+ let dt = local_time();
+ assert_eq!(decode_datetime(&encode_datetime(dt)), Some(dt))
+ }
+}
diff --git a/src/routes.rs b/src/routes.rs
new file mode 100644
index 0000000..b54e565
--- /dev/null
+++ b/src/routes.rs
@@ -0,0 +1,209 @@
+use std::ops::Add;
+use std::path::{Path, PathBuf};
+
+use chrono::Duration;
+use futures_util::TryStreamExt;
+use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
+use hyper::body::{Bytes, Frame, Incoming};
+use hyper::header::{HeaderName, HeaderValue, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_TYPE};
+use hyper::{Method, Request, Response, Result, StatusCode};
+use tokio::io::AsyncWriteExt;
+use tokio::{fs, fs::File};
+use tokio_rusqlite::Connection;
+use tokio_util::io::ReaderStream;
+
+use crate::db;
+use crate::model;
+use crate::templates;
+use crate::util;
+
+pub async fn routes(
+ request: Request<Incoming>,
+ db_conn: Connection,
+ authorized_key: String,
+ files_dir: String,
+) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
+ let path = &request.uri().path().split('/').collect::<Vec<&str>>()[1..];
+ let files_dir = Path::new(&files_dir);
+
+ match (request.method(), path) {
+ (&Method::GET, [""]) => Ok(response(StatusCode::OK, templates::INDEX.to_string())),
+ (&Method::GET, ["static", "main.js"]) => Ok(static_file(
+ include_str!("static/main.js").to_string(),
+ "application/javascript",
+ )),
+ (&Method::GET, ["static", "main.css"]) => Ok(static_file(
+ include_str!("static/main.css").to_string(),
+ "text/css",
+ )),
+ (&Method::POST, [""]) => upload_file(request, db_conn, authorized_key, files_dir).await,
+ (&Method::GET, [file_id]) => get(db_conn, file_id, GetFile::ShowPage, files_dir).await,
+ (&Method::GET, [file_id, "download"]) => {
+ get(db_conn, file_id, GetFile::Download, files_dir).await
+ }
+ _ => Ok(not_found()),
+ }
+}
+
+async fn upload_file(
+ request: Request<Incoming>,
+ db_conn: Connection,
+ authorized_key: String,
+ files_dir: &Path,
+) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
+ let key = get_header(&request, "X-Key");
+ if key != Some(authorized_key) {
+ log::info!("Unauthorized file upload");
+ Ok(response(
+ StatusCode::UNAUTHORIZED,
+ "Unauthorized".to_string(),
+ ))
+ } else {
+ let file_id = model::generate_file_id();
+ let filename = get_header(&request, "X-Filename").map(|s| util::sanitize_filename(&s));
+ let expiration_days: Option<i64> =
+ get_header(&request, "X-Expiration").and_then(|s| s.parse().ok());
+ let content_length: Option<usize> =
+ get_header(&request, "Content-Length").and_then(|s| s.parse().ok());
+
+ match (filename, expiration_days, content_length) {
+ (Some(filename), Some(expiration_days), Some(content_length)) => {
+ let _ = fs::create_dir(files_dir).await;
+ let path = files_dir.join(&file_id);
+ let mut file = File::create(&path).await.unwrap();
+
+ let mut incoming = request.into_body();
+ while let Some(frame) = incoming.frame().await {
+ if let Ok(data) = frame {
+ let _ = file.write_all(&data.into_data().unwrap()).await;
+ let _ = file.flush().await;
+ }
+ }
+
+ let file = model::File {
+ id: file_id.clone(),
+ name: filename,
+ expires_at: model::local_time().add(Duration::days(expiration_days)),
+ content_length,
+ };
+
+ match db::insert_file(&db_conn, file.clone()).await {
+ Ok(_) => Ok(response(StatusCode::OK, file_id)),
+ Err(msg) => {
+ log::error!("Insert file: {msg}");
+ if let Err(msg) = fs::remove_file(path).await {
+ log::error!("Remove file: {msg}");
+ };
+ Ok(internal_server_error())
+ }
+ }
+ }
+ _ => Ok(bad_request()),
+ }
+ }
+}
+
+fn get_header(request: &Request<Incoming>, header: &str) -> Option<String> {
+ request
+ .headers()
+ .get(header)?
+ .to_str()
+ .ok()
+ .map(|str| str.to_string())
+}
+
+enum GetFile {
+ ShowPage,
+ Download,
+}
+
+async fn get(
+ db_conn: Connection,
+ file_id: &str,
+ get_file: GetFile,
+ files_dir: &Path,
+) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
+ let file = db::get_file(&db_conn, file_id.to_string()).await;
+ match (get_file, file) {
+ (GetFile::ShowPage, Ok(Some(file))) => {
+ Ok(response(StatusCode::OK, templates::file_page(file)))
+ }
+ (GetFile::Download, Ok(Some(file))) => {
+ let path = files_dir.join(file_id);
+ Ok(stream_file(path, file).await)
+ }
+ (_, Err(msg)) => {
+ log::error!("Getting file: {msg}");
+ Ok(internal_server_error())
+ }
+ (_, Ok(None)) => Ok(not_found()),
+ }
+}
+
+fn static_file(text: String, content_type: &str) -> Response<BoxBody<Bytes, std::io::Error>> {
+ let response = Response::builder()
+ .body(Full::new(text.into()).map_err(|e| match e {}).boxed())
+ .unwrap();
+ with_headers(response, vec![(CONTENT_TYPE, content_type)])
+}
+
+fn response(status_code: StatusCode, text: String) -> Response<BoxBody<Bytes, std::io::Error>> {
+ Response::builder()
+ .status(status_code)
+ .body(Full::new(text.into()).map_err(|e| match e {}).boxed())
+ .unwrap()
+}
+
+async fn stream_file(path: PathBuf, file: model::File) -> Response<BoxBody<Bytes, std::io::Error>> {
+ match File::open(path).await {
+ Err(e) => {
+ log::error!("Unable to open file: {e}");
+ not_found()
+ }
+ Ok(disk_file) => {
+ let reader_stream = ReaderStream::new(disk_file);
+ let stream_body = StreamBody::new(reader_stream.map_ok(Frame::data));
+ let boxed_body = stream_body.boxed();
+
+ let response = Response::builder().body(boxed_body).unwrap();
+
+ with_headers(
+ response,
+ vec![
+ (
+ CONTENT_DISPOSITION,
+ &format!("attachment; filename={}", file.name),
+ ),
+ (CONTENT_LENGTH, &file.content_length.to_string()),
+ ],
+ )
+ }
+ }
+}
+
+fn not_found() -> Response<BoxBody<Bytes, std::io::Error>> {
+ response(StatusCode::NOT_FOUND, templates::NOT_FOUND.to_string())
+}
+
+fn bad_request() -> Response<BoxBody<Bytes, std::io::Error>> {
+ response(StatusCode::BAD_REQUEST, templates::BAD_REQUEST.to_string())
+}
+
+fn internal_server_error() -> Response<BoxBody<Bytes, std::io::Error>> {
+ response(
+ StatusCode::INTERNAL_SERVER_ERROR,
+ templates::INTERNAL_SERVER_ERROR.to_string(),
+ )
+}
+
+pub fn with_headers(
+ response: Response<BoxBody<Bytes, std::io::Error>>,
+ headers: Vec<(HeaderName, &str)>,
+) -> Response<BoxBody<Bytes, std::io::Error>> {
+ let mut response = response;
+ let response_headers = response.headers_mut();
+ for (name, value) in headers {
+ response_headers.insert(name, HeaderValue::from_str(value).unwrap());
+ }
+ response
+}
diff --git a/src/server.py b/src/server.py
deleted file mode 100644
index 5927052..0000000
--- a/src/server.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import http.server
-import logging
-import os
-import sqlite3
-import tempfile
-
-import db
-import templates
-import utils
-
-logger = logging.getLogger(__name__)
-conn = sqlite3.connect('db.sqlite3')
-files_directory = 'files'
-authorized_key = os.environ['KEY']
-
-class MyServer(http.server.BaseHTTPRequestHandler):
- def do_GET(self):
- match self.path:
- case '/':
- self._serve_str(templates.index, 200, 'text/html')
- case '/main.js':
- self._serve_file('public/main.js', 'application/javascript')
- case '/main.css':
- self._serve_file('public/main.css', 'text/css')
- case path:
- if path.endswith('?download'):
- download = True
- path = path[:-len('?download')]
- else:
- download = False
-
- file_id = path[1:]
- res = db.get_file(conn, file_id)
- if res is None:
- self._serve_str(templates.not_found, 404, 'text/html')
- else:
- filename, expires, content_length = res
- disk_path = os.path.join(files_directory, file_id)
- if download:
- headers = [
- ('Content-Disposition', f'attachment; filename={filename}'),
- ('Content-Length', content_length)
- ]
- self._serve_file(disk_path, 'application/octet-stream', headers)
- else:
- href = f'{file_id}?download'
- self._serve_str(templates.download(href, filename, expires), 200, 'text/html')
-
- def do_POST(self):
- key = self.headers['X-Key']
- if not key == authorized_key:
- logging.info('Unauthorized to upload file: wrong key')
- self._serve_str('Unauthorized', 401)
-
- else:
- logging.info('Uploading file')
- content_length = int(self.headers['content-length'])
- filename = utils.sanitize_filename(self.headers['X-FileName'])
- expiration = self.headers['X-Expiration']
-
- with tempfile.NamedTemporaryFile(delete = False) as tmp:
- utils.transfer(self.rfile, tmp, content_length = content_length)
-
- logging.info('File uploaded')
- file_id = db.insert_file(conn, filename, expiration, content_length)
- os.makedirs(files_directory, exist_ok=True)
- os.rename(tmp.name, os.path.join(files_directory, file_id))
-
- self._serve_str(file_id, 200)
-
- def _serve_str(self, s, code, content_type='text/plain'):
- self.send_response(code)
- self.send_header('Content-type', content_type)
- self.end_headers()
- self.wfile.write(bytes(s, 'utf-8'))
-
- def _serve_file(self, filename, content_type, headers = []):
- self.send_response(200)
- self.send_header('Content-type', content_type)
- for header_name, header_value in headers:
- self.send_header(header_name, header_value)
- self.end_headers()
- with open(filename, 'rb') as f:
- utils.transfer(f, self.wfile)
diff --git a/src/static/main.css b/src/static/main.css
new file mode 100644
index 0000000..af0ee54
--- /dev/null
+++ b/src/static/main.css
@@ -0,0 +1,94 @@
+html {
+ margin: 0 1rem;
+ font-size: 16px;
+ line-height: 1.4rem;
+ font-family: sans-serif;
+ box-sizing: border-box;
+}
+
+*, *:before, *:after {
+ box-sizing: inherit;
+}
+
+body {
+ max-width: 30rem;
+ margin: 0 auto;
+}
+
+a {
+ text-decoration: none;
+ color: #06C;
+}
+
+h1 {
+ text-align: center;
+ font-variant: small-caps;
+ font-size: 40px;
+ letter-spacing: 0.2rem;
+ margin-bottom: 4rem;
+}
+
+.g-Link {
+ text-decoration: underline;
+}
+
+label {
+ display: flex;
+ gap: 0.5rem;
+ flex-direction: column;
+ margin-bottom: 2rem;
+}
+
+input, select {
+ font-size: inherit;
+ border: 1px solid black;
+ height: 2rem;
+ background: white;
+}
+
+input[type=file] {
+ align-content: center;
+ padding-left: 2px;
+}
+
+input[type=submit] {
+ width: 100%;
+ background: #06C;
+ cursor: pointer;
+ border: none;
+ color: white;
+}
+
+.g-Loading {
+ display: none;
+ align-items: center;
+ justify-content: center;
+ gap: 1rem;
+ margin-bottom: 2rem;
+}
+
+.g-Error {
+ text-align: center;
+ margin-bottom: 2rem;
+ color: #C00;
+}
+
+.g-Spinner {
+ width: 25px;
+ height: 25px;
+ border: 4px solid #06C;
+ border-bottom-color: transparent;
+ border-radius: 50%;
+ display: inline-block;
+ box-sizing: border-box;
+ animation: rotation 1s linear infinite;
+}
+
+@keyframes rotation {
+ 0% {
+ transform: rotate(0deg);
+ }
+ 100% {
+ transform: rotate(360deg);
+ }
+}
diff --git a/src/static/main.js b/src/static/main.js
new file mode 100644
index 0000000..40e62d6
--- /dev/null
+++ b/src/static/main.js
@@ -0,0 +1,49 @@
+window.onload = function() {
+ const form = document.querySelector('form')
+
+ if (form !== null) {
+ const submit = document.querySelector('input[type="submit"]')
+ const loading = document.querySelector('.g-Loading')
+ const error = document.querySelector('.g-Error')
+
+ function showError(msg) {
+ loading.style.display = 'none'
+ submit.disabled = false
+ error.innerText = msg
+ error.style.display = 'block'
+ }
+
+ form.onsubmit = function(event) {
+ event.preventDefault()
+
+ loading.style.display = 'flex'
+ submit.disabled = true
+ error.style.display = 'none'
+
+ const key = document.querySelector('input[name="key"]').value
+ const expiration = document.querySelector('select[name="expiration"]').value
+ const file = document.querySelector('input[name="file"]').files[0]
+ const filename = file.name.replace(/[^0-9a-zA-Z\.]/gi, '-')
+
+ // Wait a bit to prevent showing the loader too briefly
+ setTimeout(function() {
+ const xhr = new XMLHttpRequest()
+ xhr.open('POST', '/', true)
+ xhr.onload = function () {
+ if (xhr.status === 200) {
+ window.location = `/${xhr.responseText}`
+ } else {
+ showError(`Error uploading: ${xhr.status}`)
+ }
+ }
+ xhr.onerror = function () {
+ showError('Upload error')
+ }
+ xhr.setRequestHeader('X-FileName', filename)
+ xhr.setRequestHeader('X-Expiration', expiration)
+ xhr.setRequestHeader('X-Key', key)
+ xhr.send(file)
+ }, 500)
+ }
+ }
+}
diff --git a/src/templates.py b/src/templates.py
deleted file mode 100644
index 8125f69..0000000
--- a/src/templates.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import html
-import datetime
-
-page: str = '''
- <!doctype html>
- <html lang="fr">
- <meta charset="utf-8">
- <meta name="viewport" content="width=device-width">
-
- <title>Files</title>
- <link rel="stylesheet" href="/static/main.css">
- <script src="/static/main.js"></script>
-
- <a href="/">
- <h1>Files</h1>
- </a>
-'''
-
-pub index: str = f'''
- {page}
-
- <form>
- <label>
- File
- <input type="file" name="file" required>
- </label>
-
- <label>
- Expiration
- <select name="expiration">
- <option value="1">1 day</option>
- <option value="2">2 days</option>
- <option value="3">3 days</option>
- <option value="4">4 days</option>
- <option value="5">5 days</option>
- <option value="6">6 days</option>
- <option value="7" selected>7 days</option>
- <option value="8">8 days</option>
- <option value="9">9 days</option>
- <option value="10">10 days</option>
- <option value="11">11 days</option>
- <option value="12">12 days</option>
- <option value="13">13 days</option>
- <option value="14">14 days</option>
- <option value="15">15 days</option>
- <option value="16">16 days</option>
- <option value="17">17 days</option>
- <option value="18">18 days</option>
- <option value="19">19 days</option>
- <option value="20">20 days</option>
- <option value="21">21 days</option>
- <option value="22">22 days</option>
- <option value="23">23 days</option>
- <option value="24">24 days</option>
- <option value="25">25 days</option>
- <option value="26">26 days</option>
- <option value="27">27 days</option>
- <option value="28">28 days</option>
- <option value="29">29 days</option>
- <option value="30">30 days</option>
- <option value="31">31 days</option>
- </select>
- </label>
-
- <label>
- Key
- <input type="password" name="key" required>
- </label>
-
- <div class="g-Loading">
- <div class="g-Spinner"></div>
- Uploading…
- </div>
-
- <div class="g-Error">
- </div>
-
- <input type="submit" value="Upload">
- </form>
-'''
-
-def file_page(file_id: str, filename: str, expires: str) -> str:
- href = f'{file_id}/download'
- expires_in = datetime.datetime.strptime(expires, '%Y-%m-%d %H:%M:%S') - datetime.datetime.now()
-
- print()
- print(href)
- print()
-
- return f'''
- {page}
-
- <div>
- <a class="g-Link" href="{html.escape(href)}">{html.escape(filename)}</a>
- <div>
- Expires in {expires_in}
- </div>
- </div>
- '''
-
-not_found: str = f'''
- {page}
-
- Oops, not found!
-'''
diff --git a/src/templates.rs b/src/templates.rs
new file mode 100644
index 0000000..b551bf6
--- /dev/null
+++ b/src/templates.rs
@@ -0,0 +1,134 @@
+use chrono::Local;
+
+use crate::model::File;
+use crate::util;
+
+const PAGE: &str = r#"
+<!doctype html>
+<html lang="fr">
+<meta charset="utf-8">
+<meta name="viewport" content="width=device-width">
+
+<title>Files</title>
+<link rel="stylesheet" href="/static/main.css">
+<script src="/static/main.js"></script>
+
+<a href="/">
+ <h1>Files</h1>
+</a>
+"#;
+
+pub const INDEX: &str = const_format::concatcp!(
+ PAGE,
+ r#"
+<form>
+ <label>
+ File
+ <input type="file" name="file" required>
+ </label>
+
+ <label>
+ Expiration
+ <select name="expiration">
+ <option value="1">1 day</option>
+ <option value="2">2 days</option>
+ <option value="3">3 days</option>
+ <option value="4">4 days</option>
+ <option value="5">5 days</option>
+ <option value="6">6 days</option>
+ <option value="7" selected>7 days</option>
+ <option value="8">8 days</option>
+ <option value="9">9 days</option>
+ <option value="10">10 days</option>
+ <option value="11">11 days</option>
+ <option value="12">12 days</option>
+ <option value="13">13 days</option>
+ <option value="14">14 days</option>
+ <option value="15">15 days</option>
+ <option value="16">16 days</option>
+ <option value="17">17 days</option>
+ <option value="18">18 days</option>
+ <option value="19">19 days</option>
+ <option value="20">20 days</option>
+ <option value="21">21 days</option>
+ <option value="22">22 days</option>
+ <option value="23">23 days</option>
+ <option value="24">24 days</option>
+ <option value="25">25 days</option>
+ <option value="26">26 days</option>
+ <option value="27">27 days</option>
+ <option value="28">28 days</option>
+ <option value="29">29 days</option>
+ <option value="30">30 days</option>
+ <option value="31">31 days</option>
+ </select>
+ </label>
+
+ <label>
+ Key
+ <input type="password" name="key" required>
+ </label>
+
+ <div class="g-Loading">
+ <div class="g-Spinner"></div>
+ Uploading…
+ </div>
+
+ <div class="g-Error">
+ </div>
+
+ <input type="submit" value="Upload">
+</form>"#
+);
+
+pub const NOT_FOUND: &str = const_format::concatcp!(
+ PAGE,
+ r#"
+ <div>
+ Oops, not found.
+ </div>
+ "#
+);
+
+pub const BAD_REQUEST: &str = const_format::concatcp!(
+ PAGE,
+ r#"
+ <div>
+ Oops, bad request.
+ </div>
+ "#
+);
+
+pub const INTERNAL_SERVER_ERROR: &str = const_format::concatcp!(
+ PAGE,
+ r#"
+ <div>
+ Oops, internal server error.
+ </div>
+ "#
+);
+
+pub fn file_page(file: File) -> String {
+ let href = format!("{}/download", file.id);
+ let expiration = file.expires_at.signed_duration_since(Local::now());
+
+ format!(
+ r#"
+ {page}
+
+ <div>
+ <div>
+ <a class="g-Link" href="{href}">{filename}</a> – {size}
+ </div>
+ <div>
+ Expires in <b>{expiration}</b>.
+ </div>
+ </div>
+ "#,
+ page = PAGE,
+ href = html_escape::encode_text(&href),
+ filename = html_escape::encode_text(&file.name),
+ expiration = html_escape::encode_text(&util::pretty_print_duration(expiration)),
+ size = html_escape::encode_text(&util::pretty_print_bytes(file.content_length))
+ )
+}
diff --git a/src/util.rs b/src/util.rs
new file mode 100644
index 0000000..9bc7cb9
--- /dev/null
+++ b/src/util.rs
@@ -0,0 +1,125 @@
+use chrono::Duration;
+
+pub fn sanitize_filename(s: &str) -> String {
+ s.split('.')
+ .map(sanitize_filename_part)
+ .collect::<Vec<String>>()
+ .join(".")
+}
+
+pub fn sanitize_filename_part(s: &str) -> String {
+ s.chars()
+ .map(|c| {
+ if c.is_ascii_alphanumeric() {
+ c.to_lowercase().collect::<String>()
+ } else {
+ " ".to_string()
+ }
+ })
+ .collect::<String>()
+ .split_whitespace()
+ .collect::<Vec<&str>>()
+ .join("-")
+}
+
+pub fn pretty_print_duration(d: Duration) -> String {
+ if d.num_days() > 0 {
+ let plural = if d.num_days() > 1 { "s" } else { "" };
+ format!("{} day{}", d.num_days(), plural)
+ } else if d.num_hours() > 0 {
+ format!("{} h", d.num_hours())
+ } else if d.num_minutes() > 0 {
+ format!("{} min", d.num_minutes())
+ } else {
+ format!("{} s", d.num_seconds())
+ }
+}
+
+pub fn pretty_print_bytes(bytes: usize) -> String {
+ let ko = bytes / 1024;
+ let mo = ko / 1024;
+ let go = mo / 1024;
+ if go > 0 {
+ format!("{} GB", go)
+ } else if mo > 0 {
+ format!("{} MB", mo)
+ } else if ko > 0 {
+ format!("{} KB", ko)
+ } else {
+ format!("{} B", bytes)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_sanitize_filename() {
+ assert_eq!(sanitize_filename(""), "");
+ assert_eq!(sanitize_filename("foo bar 123"), "foo-bar-123");
+ assert_eq!(sanitize_filename("foo bar.123"), "foo-bar.123");
+ assert_eq!(sanitize_filename("foo ( test+2 ).xml"), "foo-test-2.xml");
+ }
+
+ #[test]
+ fn test_sanitize_filename_part() {
+ assert_eq!(sanitize_filename_part(""), "");
+ assert_eq!(sanitize_filename_part("foo123BAZ"), "foo123baz");
+ assert_eq!(sanitize_filename_part("foo-123-BAZ"), "foo-123-baz");
+ assert_eq!(sanitize_filename_part("[()] */+-!;?<'> ?:"), "");
+ assert_eq!(sanitize_filename_part("foo [bar] -- BAZ3"), "foo-bar-baz3");
+ }
+
+ #[test]
+ fn test_pretty_print_duration() {
+ assert_eq!(
+ pretty_print_duration(Duration::days(2)),
+ "2 days".to_string()
+ );
+ assert_eq!(
+ pretty_print_duration(Duration::hours(30)),
+ "1 day".to_string()
+ );
+ assert_eq!(
+ pretty_print_duration(Duration::days(1)),
+ "1 day".to_string()
+ );
+ assert_eq!(
+ pretty_print_duration(Duration::hours(15)),
+ "15 h".to_string()
+ );
+ assert_eq!(
+ pretty_print_duration(Duration::minutes(70)),
+ "1 h".to_string()
+ );
+ assert_eq!(
+ pretty_print_duration(Duration::minutes(44)),
+ "44 min".to_string()
+ );
+ assert_eq!(
+ pretty_print_duration(Duration::seconds(100)),
+ "1 min".to_string()
+ );
+ assert_eq!(
+ pretty_print_duration(Duration::seconds(7)),
+ "7 s".to_string()
+ );
+ assert_eq!(pretty_print_duration(Duration::zero()), "0 s".to_string());
+ }
+
+ #[test]
+ fn test_pretty_print_bytes() {
+ assert_eq!(pretty_print_bytes(0), "0 B");
+ assert_eq!(pretty_print_bytes(10), "10 B");
+ assert_eq!(pretty_print_bytes(1024), "1 KB");
+ assert_eq!(pretty_print_bytes(1100), "1 KB");
+ assert_eq!(pretty_print_bytes(54 * 1024), "54 KB");
+ assert_eq!(pretty_print_bytes(1024 * 1024), "1 MB");
+ assert_eq!(pretty_print_bytes(1300 * 1024), "1 MB");
+ assert_eq!(pretty_print_bytes(79 * 1024 * 1024), "79 MB");
+ assert_eq!(pretty_print_bytes(1024 * 1024 * 1024), "1 GB");
+ assert_eq!(pretty_print_bytes(1300 * 1024 * 1024), "1 GB");
+ assert_eq!(pretty_print_bytes(245 * 1024 * 1024 * 1024), "245 GB");
+ }
+}
diff --git a/src/utils.py b/src/utils.py
deleted file mode 100644
index 151217f..0000000
--- a/src/utils.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import io
-
-def sanitize_filename(s: str) -> str:
- return '.'.join([sanitize_filename_part(p) for p in s.split('.')])
-
-def sanitize_filename_part(s: str) -> str:
- alnum_or_space = ''.join([c if c.isalnum() else ' ' for c in s])
- return '-'.join(alnum_or_space.split())