aboutsummaryrefslogtreecommitdiff
path: root/src/db
diff options
context:
space:
mode:
authorJoris2020-05-31 17:39:28 +0200
committerJoris2020-05-31 17:39:28 +0200
commit1bed85a9b107d1b03e71b848829cb7b1f33060f4 (patch)
tree7e9bc7dd0813b9426e293dde7913717746a479f8 /src/db
parenta585e507cbe2c05cc846013cafe433953e514295 (diff)
Prevent removing a tag being used
Diffstat (limited to 'src/db')
-rw-r--r--src/db/init.py8
-rw-r--r--src/db/task_tags.py9
2 files changed, 14 insertions, 3 deletions
diff --git a/src/db/init.py b/src/db/init.py
index 6b4cbea..77920cf 100644
--- a/src/db/init.py
+++ b/src/db/init.py
@@ -8,9 +8,9 @@ def init(path):
database = sqlite3.connect(path)
- if is_db_new:
+ cursor = database.cursor()
- cursor = database.cursor()
+ if is_db_new:
cursor.execute(
" CREATE TABLE IF NOT EXISTS tasks("
@@ -43,6 +43,8 @@ def init(path):
" PRIMARY KEY (task_id, tag_id)"
" )")
- database.commit()
+ cursor.execute("PRAGMA foreign_keys = ON")
+
+ database.commit()
return database
diff --git a/src/db/task_tags.py b/src/db/task_tags.py
index 34366e0..93dc627 100644
--- a/src/db/task_tags.py
+++ b/src/db/task_tags.py
@@ -4,6 +4,15 @@ from typing import List
from model.task_tag import TaskTag
+def one_is_used(cursor: Cursor, tag_ids: List[int]) -> bool:
+ if len(tag_ids) >= 1:
+ cursor.execute(
+ "SELECT task_id FROM task_tags WHERE tag_id IN (%s) LIMIT 1" % ",".join("?"*len(tag_ids)),
+ tag_ids)
+ return len(cursor.fetchall()) == 1
+ else:
+ return False
+
def get(cursor: Cursor) -> List[TaskTag]:
cursor.execute("SELECT task_id, tag_id FROM task_tags")
return [TaskTag(r[0], r[1]) for r in cursor.fetchall()]