aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/db.rs84
-rw-r--r--src/jobs.rs77
-rw-r--r--src/main.rs3
3 files changed, 162 insertions, 2 deletions
diff --git a/src/db.rs b/src/db.rs
index e1bb7e3..ab699c6 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -1,3 +1,4 @@
+use chrono::{DateTime, Local};
use tokio_rusqlite::{params, Connection, Result};
use crate::model::{decode_datetime, encode_datetime, File};
@@ -64,6 +65,49 @@ pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File>
.await
}
+pub async fn list_expire_after(conn: &Connection, time: DateTime<Local>) -> Result<Vec<String>> {
+ conn.call(move |conn| {
+ let mut stmt = conn.prepare(
+ r#"
+ SELECT
+ id
+ FROM
+ files
+ WHERE
+ expires_at > ?
+ "#,
+ )?;
+
+ let iter = stmt.query_map([encode_datetime(time)], |row| row.get(0))?;
+
+ let mut res = vec![];
+ for id in iter {
+ res.push(id?)
+ }
+ Ok(res)
+ })
+ .await
+}
+
+pub async fn remove_expire_before(conn: &Connection, time: DateTime<Local>) -> Result<()> {
+ conn.call(move |conn| {
+ conn.execute(
+ &format!(
+ r#"
+ DELETE FROM
+ files
+ WHERE
+ expires_at <= ?
+ "#
+ ),
+ [encode_datetime(time)]
+ )?;
+
+ Ok(())
+ })
+ .await
+}
+
fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error {
tokio_rusqlite::Error::Other(msg.into())
}
@@ -71,8 +115,9 @@ fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error {
#[cfg(test)]
mod tests {
use super::*;
- use crate::model::local_time;
+ use crate::model::{generate_file_id, local_time};
use chrono::Duration;
+ use std::collections::HashSet;
use std::ops::Add;
#[tokio::test]
@@ -105,9 +150,44 @@ mod tests {
assert!(file_res.unwrap().is_none());
}
+ #[tokio::test]
+ async fn test_list_non_expirable() {
+ let conn = get_connection().await;
+ let file_expire = dummy_file(Duration::zero());
+ let file_no_expire_1 = dummy_file(Duration::minutes(1));
+ let file_no_expire_2 = dummy_file(Duration::minutes(1));
+ assert!(insert_file(&conn, file_expire.clone()).await.is_ok());
+ assert!(insert_file(&conn, file_no_expire_1.clone()).await.is_ok());
+ assert!(insert_file(&conn, file_no_expire_2.clone()).await.is_ok());
+ let list = list_expire_after(&conn, Local::now()).await;
+ assert!(list.is_ok());
+ assert_eq!(
+ HashSet::from_iter(list.unwrap().iter()),
+ HashSet::from([&file_no_expire_1.id, &file_no_expire_2.id])
+ )
+ }
+
+ #[tokio::test]
+ async fn test_remove_expire_before() {
+ let conn = get_connection().await;
+ let file_1 = dummy_file(Duration::zero());
+ let file_2 = dummy_file(Duration::zero());
+ let file_3 = dummy_file(Duration::minutes(1));
+ let file_4 = dummy_file(Duration::minutes(1));
+ assert!(insert_file(&conn, file_1.clone()).await.is_ok());
+ assert!(insert_file(&conn, file_2.clone()).await.is_ok());
+ assert!(insert_file(&conn, file_3.clone()).await.is_ok());
+ assert!(insert_file(&conn, file_4.clone()).await.is_ok());
+ assert!(remove_expire_before(&conn, Local::now()).await.is_ok());
+ assert!(get_file(&conn, file_1.id).await.unwrap().is_none());
+ assert!(get_file(&conn, file_2.id).await.unwrap().is_none());
+ assert!(get_file(&conn, file_3.id).await.unwrap().is_some());
+ assert!(get_file(&conn, file_4.id).await.unwrap().is_some());
+ }
+
fn dummy_file(td: Duration) -> File {
File {
- id: "1234".to_string(),
+ id: generate_file_id(),
name: "foo".to_string(),
expires_at: local_time().add(td),
content_length: 100,
diff --git a/src/jobs.rs b/src/jobs.rs
new file mode 100644
index 0000000..a01a70d
--- /dev/null
+++ b/src/jobs.rs
@@ -0,0 +1,77 @@
+use std::collections::HashSet;
+use std::path::Path;
+
+use chrono::Local;
+use tokio::fs;
+use tokio::time::{sleep, Duration};
+use tokio_rusqlite::Connection;
+
+use crate::db;
+
+pub async fn start(db_conn: Connection, files_dir: String) {
+ loop {
+ log::info!("Starting removing expired files");
+ cleanup_expired(&db_conn, &files_dir).await;
+
+ // Sleeping 1 day
+ sleep(Duration::from_secs(24 * 60 * 60)).await;
+ }
+}
+
+async fn cleanup_expired(db_conn: &Connection, files_dir: &String) {
+ let time = Local::now();
+
+ match read_dir(files_dir).await {
+ Err(msg) => log::error!("Listing files: {msg}"),
+ Ok(files) => match db::list_expire_after(db_conn, time).await {
+ Err(msg) => log::error!("Getting non expirable files: {msg}"),
+ Ok(non_expirable) => {
+ let non_expirable = HashSet::<String>::from_iter(non_expirable.iter().cloned());
+ let expired_ids = files.difference(&non_expirable);
+ let count = remove_files(files_dir, expired_ids.cloned()).await;
+ log::info!("Removed {} files", count);
+ if let Err(msg) = db::remove_expire_before(db_conn, time).await {
+ log::error!("Removing files: {msg}")
+ }
+ }
+ },
+ }
+}
+
+async fn read_dir(files_dir: &String) -> Result<HashSet<String>, String> {
+ match fs::read_dir(files_dir).await {
+ Err(msg) => Err(msg.to_string()),
+ Ok(mut read_dir) => {
+ let mut files = HashSet::<String>::new();
+ loop {
+ let entry = read_dir.next_entry().await;
+ match entry {
+ Ok(Some(entry)) => match entry.file_name().into_string() {
+ Ok(filename) => {
+ files.insert(filename.clone());
+ }
+ Err(_) => log::error!("Decoding filename"),
+ },
+ Ok(None) => break,
+ Err(msg) => log::error!("File entry: {msg}"),
+ }
+ }
+ Ok(files)
+ }
+ }
+}
+
+async fn remove_files<I>(files_dir: &String, ids: I) -> i32
+where
+ I: Iterator<Item = String>,
+{
+ let mut count = 0;
+ for id in ids {
+ let path = Path::new(&files_dir).join(id.clone());
+ match fs::remove_file(path).await {
+ Err(msg) => log::error!("Removing file: {msg}"),
+ Ok(_) => count += 1
+ }
+ }
+ count
+}
diff --git a/src/main.rs b/src/main.rs
index 27da278..b2af6de 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,6 +8,7 @@ use tokio::net::TcpListener;
use tokio_rusqlite::Connection;
mod db;
+mod jobs;
mod model;
mod routes;
mod templates;
@@ -27,6 +28,8 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
.await
.expect("Error while openning DB conection");
+ tokio::spawn(jobs::start(db_conn.clone(), files_dir.clone()));
+
let addr: SocketAddr = format!("{host}:{port}")
.parse()
.unwrap_or_else(|_| panic!("Invalid address: {host}:{port}"));