use std::ops::Add; use std::path::{Path, PathBuf}; use chrono::Duration; use futures_util::TryStreamExt; use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; use hyper::body::{Bytes, Frame, Incoming}; use hyper::header::{HeaderName, HeaderValue, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_TYPE}; use hyper::{Method, Request, Response, Result, StatusCode}; use tokio::io::AsyncWriteExt; use tokio::{fs, fs::File}; use tokio_rusqlite::Connection; use tokio_util::io::ReaderStream; use crate::db; use crate::model; use crate::templates; use crate::util; pub async fn routes( request: Request, db_conn: Connection, authorized_key: String, files_dir: String, ) -> Result>> { let path = &request.uri().path().split('/').collect::>()[1..]; let files_dir = Path::new(&files_dir); match (request.method(), path) { (&Method::GET, [""]) => { Ok(with_headers( response(StatusCode::OK, templates::INDEX.to_string()), vec![(CONTENT_TYPE, "text/html")], )) }, (&Method::GET, ["static", "main.js"]) => Ok(static_file( include_str!("static/main.js").to_string(), "application/javascript", )), (&Method::GET, ["static", "main.css"]) => Ok(static_file( include_str!("static/main.css").to_string(), "text/css", )), (&Method::POST, ["upload"]) => upload_file(request, db_conn, authorized_key, files_dir).await, (&Method::GET, ["share", file_id]) => get(db_conn, file_id, GetFile::ShowPage, files_dir).await, (&Method::GET, ["share", file_id, "download"]) => { get(db_conn, file_id, GetFile::Download, files_dir).await } _ => Ok(not_found()), } } async fn upload_file( request: Request, db_conn: Connection, authorized_key: String, files_dir: &Path, ) -> Result>> { let key = get_header(&request, "X-Key"); if key != Some(authorized_key) { log::info!("Unauthorized file upload"); Ok(response( StatusCode::UNAUTHORIZED, "Unauthorized".to_string(), )) } else { let file_id = model::generate_file_id(); let filename = get_header(&request, "X-Filename").map(|s| util::sanitize_filename(&s)); let expiration_days: Option = get_header(&request, "X-Expiration").and_then(|s| s.parse().ok()); let content_length: Option = get_header(&request, "Content-Length").and_then(|s| s.parse().ok()); match (filename, expiration_days, content_length) { (Some(filename), Some(expiration_days), Some(content_length)) => { let _ = fs::create_dir(files_dir).await; let path = files_dir.join(&file_id); let mut file = File::create(&path).await.unwrap(); let mut incoming = request.into_body(); while let Some(frame) = incoming.frame().await { if let Ok(data) = frame { let _ = file.write_all(&data.into_data().unwrap()).await; let _ = file.flush().await; } } let file = model::File { id: file_id.clone(), name: filename, expires_at: model::local_time().add(Duration::days(expiration_days)), content_length, }; match db::insert_file(&db_conn, file.clone()).await { Ok(_) => Ok(response(StatusCode::OK, file_id)), Err(msg) => { log::error!("Insert file: {msg}"); if let Err(msg) = fs::remove_file(path).await { log::error!("Remove file: {msg}"); }; Ok(internal_server_error()) } } } _ => Ok(bad_request()), } } } fn get_header(request: &Request, header: &str) -> Option { request .headers() .get(header)? .to_str() .ok() .map(|str| str.to_string()) } enum GetFile { ShowPage, Download, } async fn get( db_conn: Connection, file_id: &str, get_file: GetFile, files_dir: &Path, ) -> Result>> { let file = db::get_file(&db_conn, file_id.to_string()).await; match (get_file, file) { (GetFile::ShowPage, Ok(Some(file))) => { Ok(with_headers( response(StatusCode::OK, templates::file_page(file)), vec![(CONTENT_TYPE, "text/html")], )) } (GetFile::Download, Ok(Some(file))) => { let path = files_dir.join(file_id); Ok(stream_file(path, file).await) } (_, Err(msg)) => { log::error!("Getting file: {msg}"); Ok(internal_server_error()) } (_, Ok(None)) => Ok(not_found()), } } fn static_file(text: String, content_type: &str) -> Response> { let response = Response::builder() .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) .unwrap(); with_headers(response, vec![(CONTENT_TYPE, content_type)]) } fn response(status_code: StatusCode, text: String) -> Response> { Response::builder() .status(status_code) .body(Full::new(text.into()).map_err(|e| match e {}).boxed()) .unwrap() } async fn stream_file(path: PathBuf, file: model::File) -> Response> { match File::open(path).await { Err(e) => { log::error!("Unable to open file: {e}"); not_found() } Ok(disk_file) => { let reader_stream = ReaderStream::new(disk_file); let stream_body = StreamBody::new(reader_stream.map_ok(Frame::data)); let boxed_body = stream_body.boxed(); let response = Response::builder().body(boxed_body).unwrap(); with_headers( response, vec![ ( CONTENT_DISPOSITION, &format!("attachment; filename={}", file.name), ), (CONTENT_LENGTH, &file.content_length.to_string()), ], ) } } } fn not_found() -> Response> { response(StatusCode::NOT_FOUND, templates::NOT_FOUND.to_string()) } fn bad_request() -> Response> { response(StatusCode::BAD_REQUEST, templates::BAD_REQUEST.to_string()) } fn internal_server_error() -> Response> { response( StatusCode::INTERNAL_SERVER_ERROR, templates::INTERNAL_SERVER_ERROR.to_string(), ) } pub fn with_headers( response: Response>, headers: Vec<(HeaderName, &str)>, ) -> Response> { let mut response = response; let response_headers = response.headers_mut(); for (name, value) in headers { response_headers.insert(name, HeaderValue::from_str(value).unwrap()); } response }