aboutsummaryrefslogtreecommitdiff
path: root/src/routes.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/routes.rs')
-rw-r--r--src/routes.rs209
1 files changed, 209 insertions, 0 deletions
diff --git a/src/routes.rs b/src/routes.rs
new file mode 100644
index 0000000..b54e565
--- /dev/null
+++ b/src/routes.rs
@@ -0,0 +1,209 @@
+use std::ops::Add;
+use std::path::{Path, PathBuf};
+
+use chrono::Duration;
+use futures_util::TryStreamExt;
+use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody};
+use hyper::body::{Bytes, Frame, Incoming};
+use hyper::header::{HeaderName, HeaderValue, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_TYPE};
+use hyper::{Method, Request, Response, Result, StatusCode};
+use tokio::io::AsyncWriteExt;
+use tokio::{fs, fs::File};
+use tokio_rusqlite::Connection;
+use tokio_util::io::ReaderStream;
+
+use crate::db;
+use crate::model;
+use crate::templates;
+use crate::util;
+
+pub async fn routes(
+ request: Request<Incoming>,
+ db_conn: Connection,
+ authorized_key: String,
+ files_dir: String,
+) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
+ let path = &request.uri().path().split('/').collect::<Vec<&str>>()[1..];
+ let files_dir = Path::new(&files_dir);
+
+ match (request.method(), path) {
+ (&Method::GET, [""]) => Ok(response(StatusCode::OK, templates::INDEX.to_string())),
+ (&Method::GET, ["static", "main.js"]) => Ok(static_file(
+ include_str!("static/main.js").to_string(),
+ "application/javascript",
+ )),
+ (&Method::GET, ["static", "main.css"]) => Ok(static_file(
+ include_str!("static/main.css").to_string(),
+ "text/css",
+ )),
+ (&Method::POST, [""]) => upload_file(request, db_conn, authorized_key, files_dir).await,
+ (&Method::GET, [file_id]) => get(db_conn, file_id, GetFile::ShowPage, files_dir).await,
+ (&Method::GET, [file_id, "download"]) => {
+ get(db_conn, file_id, GetFile::Download, files_dir).await
+ }
+ _ => Ok(not_found()),
+ }
+}
+
+async fn upload_file(
+ request: Request<Incoming>,
+ db_conn: Connection,
+ authorized_key: String,
+ files_dir: &Path,
+) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
+ let key = get_header(&request, "X-Key");
+ if key != Some(authorized_key) {
+ log::info!("Unauthorized file upload");
+ Ok(response(
+ StatusCode::UNAUTHORIZED,
+ "Unauthorized".to_string(),
+ ))
+ } else {
+ let file_id = model::generate_file_id();
+ let filename = get_header(&request, "X-Filename").map(|s| util::sanitize_filename(&s));
+ let expiration_days: Option<i64> =
+ get_header(&request, "X-Expiration").and_then(|s| s.parse().ok());
+ let content_length: Option<usize> =
+ get_header(&request, "Content-Length").and_then(|s| s.parse().ok());
+
+ match (filename, expiration_days, content_length) {
+ (Some(filename), Some(expiration_days), Some(content_length)) => {
+ let _ = fs::create_dir(files_dir).await;
+ let path = files_dir.join(&file_id);
+ let mut file = File::create(&path).await.unwrap();
+
+ let mut incoming = request.into_body();
+ while let Some(frame) = incoming.frame().await {
+ if let Ok(data) = frame {
+ let _ = file.write_all(&data.into_data().unwrap()).await;
+ let _ = file.flush().await;
+ }
+ }
+
+ let file = model::File {
+ id: file_id.clone(),
+ name: filename,
+ expires_at: model::local_time().add(Duration::days(expiration_days)),
+ content_length,
+ };
+
+ match db::insert_file(&db_conn, file.clone()).await {
+ Ok(_) => Ok(response(StatusCode::OK, file_id)),
+ Err(msg) => {
+ log::error!("Insert file: {msg}");
+ if let Err(msg) = fs::remove_file(path).await {
+ log::error!("Remove file: {msg}");
+ };
+ Ok(internal_server_error())
+ }
+ }
+ }
+ _ => Ok(bad_request()),
+ }
+ }
+}
+
+fn get_header(request: &Request<Incoming>, header: &str) -> Option<String> {
+ request
+ .headers()
+ .get(header)?
+ .to_str()
+ .ok()
+ .map(|str| str.to_string())
+}
+
+enum GetFile {
+ ShowPage,
+ Download,
+}
+
+async fn get(
+ db_conn: Connection,
+ file_id: &str,
+ get_file: GetFile,
+ files_dir: &Path,
+) -> Result<Response<BoxBody<Bytes, std::io::Error>>> {
+ let file = db::get_file(&db_conn, file_id.to_string()).await;
+ match (get_file, file) {
+ (GetFile::ShowPage, Ok(Some(file))) => {
+ Ok(response(StatusCode::OK, templates::file_page(file)))
+ }
+ (GetFile::Download, Ok(Some(file))) => {
+ let path = files_dir.join(file_id);
+ Ok(stream_file(path, file).await)
+ }
+ (_, Err(msg)) => {
+ log::error!("Getting file: {msg}");
+ Ok(internal_server_error())
+ }
+ (_, Ok(None)) => Ok(not_found()),
+ }
+}
+
+fn static_file(text: String, content_type: &str) -> Response<BoxBody<Bytes, std::io::Error>> {
+ let response = Response::builder()
+ .body(Full::new(text.into()).map_err(|e| match e {}).boxed())
+ .unwrap();
+ with_headers(response, vec![(CONTENT_TYPE, content_type)])
+}
+
+fn response(status_code: StatusCode, text: String) -> Response<BoxBody<Bytes, std::io::Error>> {
+ Response::builder()
+ .status(status_code)
+ .body(Full::new(text.into()).map_err(|e| match e {}).boxed())
+ .unwrap()
+}
+
+async fn stream_file(path: PathBuf, file: model::File) -> Response<BoxBody<Bytes, std::io::Error>> {
+ match File::open(path).await {
+ Err(e) => {
+ log::error!("Unable to open file: {e}");
+ not_found()
+ }
+ Ok(disk_file) => {
+ let reader_stream = ReaderStream::new(disk_file);
+ let stream_body = StreamBody::new(reader_stream.map_ok(Frame::data));
+ let boxed_body = stream_body.boxed();
+
+ let response = Response::builder().body(boxed_body).unwrap();
+
+ with_headers(
+ response,
+ vec![
+ (
+ CONTENT_DISPOSITION,
+ &format!("attachment; filename={}", file.name),
+ ),
+ (CONTENT_LENGTH, &file.content_length.to_string()),
+ ],
+ )
+ }
+ }
+}
+
+fn not_found() -> Response<BoxBody<Bytes, std::io::Error>> {
+ response(StatusCode::NOT_FOUND, templates::NOT_FOUND.to_string())
+}
+
+fn bad_request() -> Response<BoxBody<Bytes, std::io::Error>> {
+ response(StatusCode::BAD_REQUEST, templates::BAD_REQUEST.to_string())
+}
+
+fn internal_server_error() -> Response<BoxBody<Bytes, std::io::Error>> {
+ response(
+ StatusCode::INTERNAL_SERVER_ERROR,
+ templates::INTERNAL_SERVER_ERROR.to_string(),
+ )
+}
+
+pub fn with_headers(
+ response: Response<BoxBody<Bytes, std::io::Error>>,
+ headers: Vec<(HeaderName, &str)>,
+) -> Response<BoxBody<Bytes, std::io::Error>> {
+ let mut response = response;
+ let response_headers = response.headers_mut();
+ for (name, value) in headers {
+ response_headers.insert(name, HeaderValue::from_str(value).unwrap());
+ }
+ response
+}