aboutsummaryrefslogtreecommitdiff
path: root/src/db.rs
diff options
context:
space:
mode:
authorJoris2024-06-02 14:38:13 +0200
committerJoris2024-06-02 14:38:22 +0200
commit1019ea1ed341e3a7769c046aa0be5764789360b6 (patch)
tree1a0d8a4f00cff252d661c42fc23ed4c19795da6f /src/db.rs
parente8da9790dc6d55cd2e8883322cdf9a7bf5b4f5b7 (diff)
Migrate to Rust and Hyper
With sanic, downloading a file locally is around ten times slower than with Rust and hyper. Maybe `pypy` could have helped, but I didn’t succeed to set it up quickly with the dependencies.
Diffstat (limited to 'src/db.rs')
-rw-r--r--src/db.rs125
1 files changed, 125 insertions, 0 deletions
diff --git a/src/db.rs b/src/db.rs
new file mode 100644
index 0000000..e1bb7e3
--- /dev/null
+++ b/src/db.rs
@@ -0,0 +1,125 @@
+use tokio_rusqlite::{params, Connection, Result};
+
+use crate::model::{decode_datetime, encode_datetime, File};
+
+pub async fn insert_file(conn: &Connection, file: File) -> Result<()> {
+ conn.call(move |conn| {
+ conn.execute(
+ r#"
+ INSERT INTO
+ files(id, created_at, expires_at, filename, content_length)
+ VALUES
+ (?1, datetime(), ?2, ?3, ?4)
+ "#,
+ params![
+ file.id,
+ encode_datetime(file.expires_at),
+ file.name,
+ file.content_length
+ ],
+ )
+ .map_err(tokio_rusqlite::Error::Rusqlite)
+ })
+ .await
+ .map(|_| ())
+}
+
+pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File>> {
+ conn.call(move |conn| {
+ let mut stmt = conn.prepare(
+ r#"
+ SELECT
+ filename, expires_at, content_length
+ FROM
+ files
+ WHERE
+ id = ?
+ AND expires_at > datetime()
+ "#,
+ )?;
+
+ let mut iter = stmt.query_map([file_id.clone()], |row| {
+ let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?);
+ Ok(res)
+ })?;
+
+ match iter.next() {
+ Some(Ok((filename, expires_at, content_length))) => {
+ match decode_datetime(&expires_at) {
+ Some(expires_at) => Ok(Some(File {
+ id: file_id.clone(),
+ name: filename,
+ expires_at,
+ content_length,
+ })),
+ _ => Err(rusqlite_other_error(&format!(
+ "Error decoding datetime: {expires_at}"
+ ))),
+ }
+ }
+ Some(_) => Err(rusqlite_other_error("Error reading file in DB")),
+ None => Ok(None),
+ }
+ })
+ .await
+}
+
+fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error {
+ tokio_rusqlite::Error::Other(msg.into())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::model::local_time;
+ use chrono::Duration;
+ use std::ops::Add;
+
+ #[tokio::test]
+ async fn test_insert_and_get_file() {
+ let conn = get_connection().await;
+ let file = dummy_file(Duration::minutes(1));
+ assert!(insert_file(&conn, file.clone()).await.is_ok());
+ let file_res = get_file(&conn, file.id.clone()).await;
+ assert!(file_res.is_ok());
+ assert_eq!(file_res.unwrap(), Some(file));
+ }
+
+ #[tokio::test]
+ async fn test_expired_file_err() {
+ let conn = get_connection().await;
+ let file = dummy_file(Duration::zero());
+ assert!(insert_file(&conn, file.clone()).await.is_ok());
+ let file_res = get_file(&conn, file.id.clone()).await;
+ assert!(file_res.is_ok());
+ assert!(file_res.unwrap().is_none());
+ }
+
+ #[tokio::test]
+ async fn test_wrong_file_err() {
+ let conn = get_connection().await;
+ let file = dummy_file(Duration::minutes(1));
+ assert!(insert_file(&conn, file.clone()).await.is_ok());
+ let file_res = get_file(&conn, "wrong-id".to_string()).await;
+ assert!(file_res.is_ok());
+ assert!(file_res.unwrap().is_none());
+ }
+
+ fn dummy_file(td: Duration) -> File {
+ File {
+ id: "1234".to_string(),
+ name: "foo".to_string(),
+ expires_at: local_time().add(td),
+ content_length: 100,
+ }
+ }
+
+ async fn get_connection() -> Connection {
+ let conn = Connection::open_in_memory().await.unwrap();
+ let init_db = tokio::fs::read_to_string("init-db.sql").await.unwrap();
+
+ let res = conn.call(move |conn| Ok(conn.execute(&init_db, []))).await;
+ assert!(res.is_ok());
+ conn
+ }
+}