From 1bed85a9b107d1b03e71b848829cb7b1f33060f4 Mon Sep 17 00:00:00 2001 From: Joris Date: Sun, 31 May 2020 17:39:28 +0200 Subject: Prevent removing a tag being used --- src/db/init.py | 8 +++++--- src/db/task_tags.py | 9 +++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) (limited to 'src/db') diff --git a/src/db/init.py b/src/db/init.py index 6b4cbea..77920cf 100644 --- a/src/db/init.py +++ b/src/db/init.py @@ -8,9 +8,9 @@ def init(path): database = sqlite3.connect(path) - if is_db_new: + cursor = database.cursor() - cursor = database.cursor() + if is_db_new: cursor.execute( " CREATE TABLE IF NOT EXISTS tasks(" @@ -43,6 +43,8 @@ def init(path): " PRIMARY KEY (task_id, tag_id)" " )") - database.commit() + cursor.execute("PRAGMA foreign_keys = ON") + + database.commit() return database diff --git a/src/db/task_tags.py b/src/db/task_tags.py index 34366e0..93dc627 100644 --- a/src/db/task_tags.py +++ b/src/db/task_tags.py @@ -4,6 +4,15 @@ from typing import List from model.task_tag import TaskTag +def one_is_used(cursor: Cursor, tag_ids: List[int]) -> bool: + if len(tag_ids) >= 1: + cursor.execute( + "SELECT task_id FROM task_tags WHERE tag_id IN (%s) LIMIT 1" % ",".join("?"*len(tag_ids)), + tag_ids) + return len(cursor.fetchall()) == 1 + else: + return False + def get(cursor: Cursor) -> List[TaskTag]: cursor.execute("SELECT task_id, tag_id FROM task_tags") return [TaskTag(r[0], r[1]) for r in cursor.fetchall()] -- cgit v1.2.3