aboutsummaryrefslogtreecommitdiff
path: root/src/db.rs
blob: e1bb7e3cdfa803dac0c8143f417e21e65e53b6a2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use tokio_rusqlite::{params, Connection, Result};

use crate::model::{decode_datetime, encode_datetime, File};

pub async fn insert_file(conn: &Connection, file: File) -> Result<()> {
    conn.call(move |conn| {
        conn.execute(
            r#"
                INSERT INTO
                    files(id, created_at, expires_at, filename, content_length)
                VALUES
                    (?1, datetime(), ?2, ?3, ?4)
                "#,
            params![
                file.id,
                encode_datetime(file.expires_at),
                file.name,
                file.content_length
            ],
        )
        .map_err(tokio_rusqlite::Error::Rusqlite)
    })
    .await
    .map(|_| ())
}

/// Looks up a file record by id, ignoring records that have already expired.
///
/// Only rows whose stored `expires_at` compares greater than SQLite's current
/// `datetime()` are considered. If the query yields several rows, only the first
/// is used. Presumably `id` is unique; verify against `init-db.sql`.
///
/// NOTE(review): SQLite's `datetime()` is UTC and the comparison is done on the
/// stored text, so this assumes `encode_datetime` writes UTC in a format that sorts
/// correctly against `YYYY-MM-DD HH:MM:SS`. Confirm.
///
/// # Errors
///
/// Returns an error if the query cannot be prepared or run, if a row cannot be read
/// (the message includes the underlying rusqlite error), or if the stored
/// `expires_at` cannot be decoded by `decode_datetime`.
pub async fn get_file(conn: &Connection, file_id: String) -> Result<Option<File>> {
    conn.call(move |conn| {
        let mut stmt = conn.prepare(
            r#"
            SELECT
                filename, expires_at, content_length
            FROM
                files
            WHERE
                id = ?
                AND expires_at > datetime()
            "#,
        )?;

        // Bind a borrow so the owned id can be moved into the returned `File`.
        let mut rows = stmt.query_map([&file_id], |row| {
            let res: (String, String, usize) = (row.get(0)?, row.get(1)?, row.get(2)?);
            Ok(res)
        })?;

        match rows.next() {
            None => Ok(None),
            // Keep the original message but surface the cause instead of dropping it.
            Some(Err(e)) => Err(rusqlite_other_error(&format!(
                "Error reading file in DB: {e}"
            ))),
            Some(Ok((filename, expires_at, content_length))) => {
                match decode_datetime(&expires_at) {
                    Some(decoded) => Ok(Some(File {
                        id: file_id,
                        name: filename,
                        expires_at: decoded,
                        content_length,
                    })),
                    None => Err(rusqlite_other_error(&format!(
                        "Error decoding datetime: {expires_at}"
                    ))),
                }
            }
        }
    })
    .await
}

fn rusqlite_other_error(msg: &str) -> tokio_rusqlite::Error {
    tokio_rusqlite::Error::Other(msg.into())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::local_time;
    use chrono::Duration;

    #[tokio::test]
    async fn test_insert_and_get_file() {
        let conn = get_connection().await;
        let file = dummy_file(Duration::minutes(1));
        insert_file(&conn, file.clone())
            .await
            .expect("inserting a file should succeed");
        let fetched = get_file(&conn, file.id.clone())
            .await
            .expect("looking up an existing file should succeed");
        assert_eq!(fetched, Some(file));
    }

    #[tokio::test]
    async fn test_expired_file_err() {
        let conn = get_connection().await;
        // Expire strictly in the past. With `Duration::zero()` the result of
        // `expires_at > datetime()` would hinge on timestamp precision/rounding.
        let file = dummy_file(Duration::minutes(-1));
        insert_file(&conn, file.clone())
            .await
            .expect("inserting a file should succeed");
        let fetched = get_file(&conn, file.id.clone())
            .await
            .expect("looking up an expired file should not error");
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_wrong_file_err() {
        let conn = get_connection().await;
        let file = dummy_file(Duration::minutes(1));
        insert_file(&conn, file.clone())
            .await
            .expect("inserting a file should succeed");
        let fetched = get_file(&conn, "wrong-id".to_string())
            .await
            .expect("looking up an unknown id should not error");
        assert!(fetched.is_none());
    }

    /// Builds a fixed test record that expires `td` after `local_time()`.
    fn dummy_file(td: Duration) -> File {
        File {
            id: "1234".to_string(),
            name: "foo".to_string(),
            expires_at: local_time() + td,
            content_length: 100,
        }
    }

    /// Opens an in-memory database and applies the schema from `init-db.sql`.
    ///
    /// Panics if the database cannot be opened, the script cannot be read, or any
    /// statement in the script fails.
    async fn get_connection() -> Connection {
        let conn = Connection::open_in_memory()
            .await
            .expect("opening an in-memory database should succeed");
        let init_db = tokio::fs::read_to_string("init-db.sql")
            .await
            .expect("init-db.sql should be readable");

        // `execute_batch` runs every statement in the script. Propagating with `?`
        // (rather than wrapping the result in `Ok`) makes a broken schema fail here
        // instead of being silently ignored.
        conn.call(move |conn| Ok(conn.execute_batch(&init_db)?))
            .await
            .expect("applying init-db.sql should succeed");
        conn
    }
}