use tokio_rusqlite::{params, Connection, Result}; use crate::model::{decode_datetime, encode_datetime, File}; pub async fn insert_file(conn: &Connection, file: File) -> Result<()> { conn.call(move |conn| { conn.execute( r#" INSERT INTO files(id, created_at, expires_at, filename, content_length) VALUES (?1, datetime(), ?2, ?3, ?4) "#, params![ file.id, encode_datetime(file.expires_at), file.name, file.content_length ], ) .map_err(tokio_rusqlite::Error::Rusqlite) }) .await .map(|_| ()) } pub async fn get_file(conn: &Connection, file_id: String) -> Result> { conn.call(move |conn| { let mut stmt = conn.prepare( r#" SELECT filename, expires_at, content_length FROM files WHERE id = ? AND expires_at > datetime() "#, )?; let mut iter = stmt.query_map([file_id.clone()], |row| { let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?); Ok(res) })?; match iter.next() { Some(Ok((filename, expires_at, content_length))) => { match decode_datetime(&expires_at) { Some(expires_at) => Ok(Some(File { id: file_id.clone(), name: filename, expires_at, content_length, })), _ => Err(rusqlite_other_error(&format!( "Error decoding datetime: {expires_at}" ))), } } Some(_) => Err(rusqlite_other_error("Error reading file in DB")), None => Ok(None), } }) .await } fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error { tokio_rusqlite::Error::Other(msg.into()) } #[cfg(test)] mod tests { use super::*; use crate::model::local_time; use chrono::Duration; use std::ops::Add; #[tokio::test] async fn test_insert_and_get_file() { let conn = get_connection().await; let file = dummy_file(Duration::minutes(1)); assert!(insert_file(&conn, file.clone()).await.is_ok()); let file_res = get_file(&conn, file.id.clone()).await; assert!(file_res.is_ok()); assert_eq!(file_res.unwrap(), Some(file)); } #[tokio::test] async fn test_expired_file_err() { let conn = get_connection().await; let file = dummy_file(Duration::zero()); assert!(insert_file(&conn, file.clone()).await.is_ok()); let file_res = get_file(&conn, file.id.clone()).await; assert!(file_res.is_ok()); assert!(file_res.unwrap().is_none()); } #[tokio::test] async fn test_wrong_file_err() { let conn = get_connection().await; let file = dummy_file(Duration::minutes(1)); assert!(insert_file(&conn, file.clone()).await.is_ok()); let file_res = get_file(&conn, "wrong-id".to_string()).await; assert!(file_res.is_ok()); assert!(file_res.unwrap().is_none()); } fn dummy_file(td: Duration) -> File { File { id: "1234".to_string(), name: "foo".to_string(), expires_at: local_time().add(td), content_length: 100, } } async fn get_connection() -> Connection { let conn = Connection::open_in_memory().await.unwrap(); let init_db = tokio::fs::read_to_string("init-db.sql").await.unwrap(); let res = conn.call(move |conn| Ok(conn.execute(&init_db, []))).await; assert!(res.is_ok()); conn } }