From 2ee900eed41aebeb0f6f791f96bacb21779e6ac0 Mon Sep 17 00:00:00 2001 From: Joris Date: Fri, 7 Jun 2024 12:34:47 +0200 Subject: Add job to remove expired files --- src/db.rs | 84 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 2 deletions(-) (limited to 'src/db.rs') diff --git a/src/db.rs b/src/db.rs index e1bb7e3..ab699c6 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,3 +1,4 @@ +use chrono::{DateTime, Local}; use tokio_rusqlite::{params, Connection, Result}; use crate::model::{decode_datetime, encode_datetime, File}; @@ -64,6 +65,49 @@ pub async fn get_file(conn: &Connection, file_id: String) -> Result .await } +pub async fn list_expire_after(conn: &Connection, time: DateTime) -> Result> { + conn.call(move |conn| { + let mut stmt = conn.prepare( + r#" + SELECT + id + FROM + files + WHERE + expires_at > ? + "#, + )?; + + let iter = stmt.query_map([encode_datetime(time)], |row| row.get(0))?; + + let mut res = vec![]; + for id in iter { + res.push(id?) + } + Ok(res) + }) + .await +} + +pub async fn remove_expire_before(conn: &Connection, time: DateTime) -> Result<()> { + conn.call(move |conn| { + conn.execute( + &format!( + r#" + DELETE FROM + files + WHERE + expires_at <= ? + "# + ), + [encode_datetime(time)] + )?; + + Ok(()) + }) + .await +} + fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error { tokio_rusqlite::Error::Other(msg.into()) } @@ -71,8 +115,9 @@ fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error { #[cfg(test)] mod tests { use super::*; - use crate::model::local_time; + use crate::model::{generate_file_id, local_time}; use chrono::Duration; + use std::collections::HashSet; use std::ops::Add; #[tokio::test] @@ -105,9 +150,44 @@ mod tests { assert!(file_res.unwrap().is_none()); } + #[tokio::test] + async fn test_list_non_expirable() { + let conn = get_connection().await; + let file_expire = dummy_file(Duration::zero()); + let file_no_expire_1 = dummy_file(Duration::minutes(1)); + let file_no_expire_2 = dummy_file(Duration::minutes(1)); + assert!(insert_file(&conn, file_expire.clone()).await.is_ok()); + assert!(insert_file(&conn, file_no_expire_1.clone()).await.is_ok()); + assert!(insert_file(&conn, file_no_expire_2.clone()).await.is_ok()); + let list = list_expire_after(&conn, Local::now()).await; + assert!(list.is_ok()); + assert_eq!( + HashSet::from_iter(list.unwrap().iter()), + HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id]) + ) + } + + #[tokio::test] + async fn test_remove_expire_before() { + let conn = get_connection().await; + let file_1 = dummy_file(Duration::zero()); + let file_2 = dummy_file(Duration::zero()); + let file_3 = dummy_file(Duration::minutes(1)); + let file_4 = dummy_file(Duration::minutes(1)); + assert!(insert_file(&conn, file_1.clone()).await.is_ok()); + assert!(insert_file(&conn, file_2.clone()).await.is_ok()); + assert!(insert_file(&conn, file_3.clone()).await.is_ok()); + assert!(insert_file(&conn, file_4.clone()).await.is_ok()); + assert!(remove_expire_before(&conn, Local::now()).await.is_ok()); + assert!(get_file(&conn, file_1.id).await.unwrap().is_none()); + assert!(get_file(&conn, file_2.id).await.unwrap().is_none()); + assert!(get_file(&conn, file_3.id).await.unwrap().is_some()); + assert!(get_file(&conn, file_4.id).await.unwrap().is_some()); + } + fn dummy_file(td: Duration) -> File { File { - id: "1234".to_string(), + id: generate_file_id(), name: "foo".to_string(), expires_at: local_time().add(td), content_length: 100, -- cgit v1.2.3