From 1019ea1ed341e3a7769c046aa0be5764789360b6 Mon Sep 17 00:00:00 2001 From: Joris Date: Sun, 2 Jun 2024 14:38:13 +0200 Subject: Migrate to Rust and Hyper With sanic, downloading a file locally is around ten times slower than with Rust and hyper. Maybe `pypy` could have helped, but I didn’t succeed to set it up quickly with the dependencies. --- src/routes.rs | 209 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 src/routes.rs (limited to 'src/routes.rs') diff --git a/src/routes.rs b/src/routes.rs new file mode 100644 index 0000000..b54e565 --- /dev/null +++ b/src/routes.rs @@ -0,0 +1,209 @@ +use std::ops::Add; +use std::path::{Path, PathBuf}; + +use chrono::Duration; +use futures_util::TryStreamExt; +use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; +use hyper::body::{Bytes, Frame, Incoming}; +use hyper::header::{HeaderName, HeaderValue, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_TYPE}; +use hyper::{Method, Request, Response, Result, StatusCode}; +use tokio::io::AsyncWriteExt; +use tokio::{fs, fs::File}; +use tokio_rusqlite::Connection; +use tokio_util::io::ReaderStream; + +use crate::db; +use crate::model; +use crate::templates; +use crate::util; + +pub async fn routes( + request: Request, + db_conn: Connection, + authorized_key: String, + files_dir: String, +) -> Result>> { + let path = &request.uri().path().split('/').collect::>()[1..]; + let files_dir = Path::new(&files_dir); + + match (request.method(), path) { + (&Method::GET, [""]) => Ok(response(StatusCode::OK, templates::INDEX.to_string())), + (&Method::GET, ["static", "main.js"]) => Ok(static_file( + include_str!("static/main.js").to_string(), + "application/javascript", + )), + (&Method::GET, ["static", "main.css"]) => Ok(static_file( + include_str!("static/main.css").to_string(), + "text/css", + )), + (&Method::POST, [""]) => upload_file(request, db_conn, authorized_key, files_dir).await, + (&Method::GET, [file_id]) => get(db_conn, file_id, GetFile::ShowPage, files_dir).await, + (&Method::GET, [file_id, "download"]) => { + get(db_conn, file_id, GetFile::Download, files_dir).await + } + _ => Ok(not_found()), + } +} + +async fn upload_file( + request: Request, + db_conn: Connection, + authorized_key: String, + files_dir: &Path, +) -> Result>> { + let key = get_header(&request, "X-Key"); + if key != Some(authorized_key) { + log::info!("Unauthorized file upload"); + Ok(response( + StatusCode::UNAUTHORIZED, + "Unauthorized".to_string(), + )) + } else { + let file_id = model::generate_file_id(); + let filename = get_header(&request, "X-Filename").map(|s| util::sanitize_filename(&s)); + let expiration_days: Option = + get_header(&request, "X-Expiration").and_then(|s| s.parse().ok()); + let content_length: Option = + get_header(&request, "Content-Length").and_then(|s| s.parse().ok()); + + match (filename, expiration_days, content_length) { + (Some(filename), Some(expiration_days), Some(content_length)) => { + let _ = fs::create_dir(files_dir).await; + let path = files_dir.join(&file_id); + let mut file = File::create(&path).await.unwrap(); + + let mut incoming = request.into_body(); + while let Some(frame) = incoming.frame().await { + if let Ok(data) = frame { + let _ = file.write_all(&data.into_data().unwrap()).await; + let _ = file.flush().await; + } + } + + let file = model::File { + id: file_id.clone(), + name: filename, + expires_at: model::local_time().add(Duration::days(expiration_days)), + content_length, + }; + + match db::insert_file(&db_conn, file.clone()).await { + Ok(_) => Ok(response(StatusCode::OK, file_id)), + Err(msg) => { + log::error!("Insert file: {msg}"); + if let Err(msg) = fs::remove_file(path).await { + log::error!("Remove file: {msg}"); + }; + Ok(internal_server_error()) + } + } + } + _ => Ok(bad_request()), + } + } +} + +fn get_header(request: &Request, header: &str) -> Option { + request + .headers() + .get(header)? + .to_str() + .ok() + .map(|str| str.to_string()) +} + +enum GetFile { + ShowPage, + Download, +} + +async fn get( + db_conn: Connection, + file_id: &str, + get_file: GetFile, + files_dir: &Path, +) -> Result>> { + let file = db::get_file(&db_conn, file_id.to_string()).await; + match (get_file, file) { + (GetFile::ShowPage, Ok(Some(file))) => { + Ok(response(StatusCode::OK, templates::file_page(file))) + } + (GetFile::Download, Ok(Some(file))) => { + let path = files_dir.join(file_id); + Ok(stream_file(path, file).await) + } + (_, Err(msg)) => { + log::error!("Getting file: {msg}"); + Ok(internal_server_error()) + } + (_, Ok(None)) => Ok(not_found()), + } +} + +fn static_file(text: String, content_type: &str) -> Response> { + let response = Response::builder() + .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) + .unwrap(); + with_headers(response, vec![(CONTENT_TYPE, content_type)]) +} + +fn response(status_code: StatusCode, text: String) -> Response> { + Response::builder() + .status(status_code) + .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) + .unwrap() +} + +async fn stream_file(path: PathBuf, file: model::File) -> Response> { + match File::open(path).await { + Err(e) => { + log::error!("Unable to open file: {e}"); + not_found() + } + Ok(disk_file) => { + let reader_stream = ReaderStream::new(disk_file); + let stream_body = StreamBody::new(reader_stream.map_ok(Frame::data)); + let boxed_body = stream_body.boxed(); + + let response = Response::builder().body(boxed_body).unwrap(); + + with_headers( + response, + vec![ + ( + CONTENT_DISPOSITION, + &format!("attachment; filename={}", file.name), + ), + (CONTENT_LENGTH, &file.content_length.to_string()), + ], + ) + } + } +} + +fn not_found() -> Response> { + response(StatusCode::NOT_FOUND, templates::NOT_FOUND.to_string()) +} + +fn bad_request() -> Response> { + response(StatusCode::BAD_REQUEST, templates::BAD_REQUEST.to_string()) +} + +fn internal_server_error() -> Response> { + response( + StatusCode::INTERNAL_SERVER_ERROR, + templates::INTERNAL_SERVER_ERROR.to_string(), + ) +} + +pub fn with_headers( + response: Response>, + headers: Vec<(HeaderName, &str)>, +) -> Response> { + let mut response = response; + let response_headers = response.headers_mut(); + for (name, value) in headers { + response_headers.insert(name, HeaderValue::from_str(value).unwrap()); + } + response +} -- cgit v1.2.3