From 2ee900eed41aebeb0f6f791f96bacb21779e6ac0 Mon Sep 17 00:00:00 2001
From: Joris
Date: Fri, 7 Jun 2024 12:34:47 +0200
Subject: Add job to remove expired files

---
 src/db.rs   | 84 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 src/jobs.rs | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 src/main.rs |  3 +++
 3 files changed, 162 insertions(+), 2 deletions(-)
 create mode 100644 src/jobs.rs

(limited to 'src')

diff --git a/src/db.rs b/src/db.rs
index e1bb7e3..ab699c6 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -1,3 +1,4 @@
+use chrono::{DateTime, Local};
 use tokio_rusqlite::{params, Connection, Result};
 
 use crate::model::{decode_datetime, encode_datetime, File};
@@ -64,6 +65,49 @@ pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File>
     .await
 }
 
+pub async fn list_expire_after(conn: &Connection, time: DateTime<Local>) -> Result<Vec<String>> {
+    conn.call(move |conn| {
+        let mut stmt = conn.prepare(
+            r#"
+            SELECT
+                id
+            FROM
+                files
+            WHERE
+                expires_at > ?
+            "#,
+        )?;
+
+        let iter = stmt.query_map([encode_datetime(time)], |row| row.get(0))?;
+
+        let mut res = vec![];
+        for id in iter {
+            res.push(id?)
+        }
+        Ok(res)
+    })
+    .await
+}
+
+pub async fn remove_expire_before(conn: &Connection, time: DateTime<Local>) -> Result<()> {
+    conn.call(move |conn| {
+        conn.execute(
+            &format!(
+                r#"
+            DELETE FROM
+                files
+            WHERE
+                expires_at <= ?
+            "#
+            ),
+            [encode_datetime(time)]
+        )?;
+
+        Ok(())
+    })
+    .await
+}
+
 fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error {
     tokio_rusqlite::Error::Other(msg.into())
 }
@@ -71,8 +115,9 @@ fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::model::local_time;
+    use crate::model::{generate_file_id, local_time};
     use chrono::Duration;
+    use std::collections::HashSet;
     use std::ops::Add;
 
     #[tokio::test]
@@ -105,9 +150,44 @@ mod tests {
         assert!(file_res.unwrap().is_none());
     }
 
+    #[tokio::test]
+    async fn test_list_non_expirable() {
+        let conn = get_connection().await;
+        let file_expire = dummy_file(Duration::zero());
+        let file_no_expire_1 = dummy_file(Duration::minutes(1));
+        let file_no_expire_2 = dummy_file(Duration::minutes(1));
+        assert!(insert_file(&conn, file_expire.clone()).await.is_ok());
+        assert!(insert_file(&conn, file_no_expire_1.clone()).await.is_ok());
+        assert!(insert_file(&conn, file_no_expire_2.clone()).await.is_ok());
+        let list = list_expire_after(&conn, Local::now()).await;
+        assert!(list.is_ok());
+        assert_eq!(
+            HashSet::from_iter(list.unwrap().iter()),
+            HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id])
+        )
+    }
+
+    #[tokio::test]
+    async fn test_remove_expire_before() {
+        let conn = get_connection().await;
+        let file_1 = dummy_file(Duration::zero());
+        let file_2 = dummy_file(Duration::zero());
+        let file_3 = dummy_file(Duration::minutes(1));
+        let file_4 = dummy_file(Duration::minutes(1));
+        assert!(insert_file(&conn, file_1.clone()).await.is_ok());
+        assert!(insert_file(&conn, file_2.clone()).await.is_ok());
+        assert!(insert_file(&conn, file_3.clone()).await.is_ok());
+        assert!(insert_file(&conn, file_4.clone()).await.is_ok());
+        assert!(remove_expire_before(&conn, Local::now()).await.is_ok());
+        assert!(get_file(&conn, file_1.id).await.unwrap().is_none());
+        assert!(get_file(&conn, file_2.id).await.unwrap().is_none());
+        assert!(get_file(&conn, file_3.id).await.unwrap().is_some());
+        assert!(get_file(&conn, file_4.id).await.unwrap().is_some());
+    }
+
     fn dummy_file(td: Duration) -> File {
         File {
-            id: "1234".to_string(),
+            id: generate_file_id(),
             name: "foo".to_string(),
             expires_at: local_time().add(td),
             content_length: 100,
diff --git a/src/jobs.rs b/src/jobs.rs
new file mode 100644
index 0000000..a01a70d
--- /dev/null
+++ b/src/jobs.rs
@@ -0,0 +1,77 @@
+use std::collections::HashSet;
+use std::path::Path;
+
+use chrono::Local;
+use tokio::fs;
+use tokio::time::{sleep, Duration};
+use tokio_rusqlite::Connection;
+
+use crate::db;
+
+pub async fn start(db_conn: Connection, files_dir: String) {
+    loop {
+        log::info!("Starting removing expired files");
+        cleanup_expired(&db_conn, &files_dir).await;
+
+        // Sleeping 1 day
+        sleep(Duration::from_secs(24 * 60 * 60)).await;
+    }
+}
+
+async fn cleanup_expired(db_conn: &Connection, files_dir: &String) {
+    let time = Local::now();
+
+    match read_dir(files_dir).await {
+        Err(msg) => log::error!("Listing files: {msg}"),
+        Ok(files) => match db::list_expire_after(db_conn, time).await {
+            Err(msg) => log::error!("Getting non expirable files: {msg}"),
+            Ok(non_expirable) => {
+                let non_expirable = HashSet::<String>::from_iter(non_expirable.iter().cloned());
+                let expired_ids = files.difference(&non_expirable);
+                let count = remove_files(files_dir, expired_ids.cloned()).await;
+                log::info!("Removed {} files", count);
+                if let Err(msg) = db::remove_expire_before(db_conn, time).await {
+                    log::error!("Removing files: {msg}")
+                }
+            }
+        },
+    }
+}
+
+async fn read_dir(files_dir: &String) -> Result<HashSet<String>, String> {
+    match fs::read_dir(files_dir).await {
+        Err(msg) => Err(msg.to_string()),
+        Ok(mut read_dir) => {
+            let mut files = HashSet::<String>::new();
+            loop {
+                let entry = read_dir.next_entry().await;
+                match entry {
+                    Ok(Some(entry)) => match entry.file_name().into_string() {
+                        Ok(filename) => {
+                            files.insert(filename.clone());
+                        }
+                        Err(_) => log::error!("Decoding filename"),
+                    },
+                    Ok(None) => break,
+                    Err(msg) => log::error!("File entry: {msg}"),
+                }
+            }
+            Ok(files)
+        }
+    }
+}
+
+async fn remove_files<I>(files_dir: &String, ids: I) -> i32
+where
+    I: Iterator<Item = String>,
+{
+    let mut count = 0;
+    for id in ids {
+        let path = Path::new(&files_dir).join(id.clone());
+        match fs::remove_file(path).await {
+            Err(msg) => log::error!("Removing file: {msg}"),
+            Ok(_) => count += 1
+        }
+    }
+    count
+}
diff --git a/src/main.rs b/src/main.rs
index 27da278..b2af6de 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,6 +8,7 @@ use tokio::net::TcpListener;
 use tokio_rusqlite::Connection;
 
 mod db;
+mod jobs;
 mod model;
 mod routes;
 mod templates;
@@ -27,6 +28,8 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
         .await
         .expect("Error while openning DB conection");
 
+    tokio::spawn(jobs::start(db_conn.clone(), files_dir.clone()));
+
     let addr: SocketAddr = format!("{host}:{port}")
         .parse()
         .unwrap_or_else(|_| panic!("Invalid address: {host}:{port}"));
-- 
cgit v1.2.3