From 1bed85a9b107d1b03e71b848829cb7b1f33060f4 Mon Sep 17 00:00:00 2001 From: Joris Date: Sun, 31 May 2020 17:39:28 +0200 Subject: Prevent removing a tag being used --- src/db/init.py | 8 +++++--- src/db/task_tags.py | 9 +++++++++ src/gui/tags/panel/table/menu.py | 29 ++++++++++++++++++++--------- src/gui/tags/panel/table/model.py | 1 + src/gui/tags/panel/table/widget.py | 11 +++++++---- src/gui/tasks/table/widget.py | 25 +++++++++++++++---------- src/service/tasks.py | 8 ++++---- 7 files changed, 61 insertions(+), 30 deletions(-) diff --git a/src/db/init.py b/src/db/init.py index 6b4cbea..77920cf 100644 --- a/src/db/init.py +++ b/src/db/init.py @@ -8,9 +8,9 @@ def init(path): database = sqlite3.connect(path) - if is_db_new: + cursor = database.cursor() - cursor = database.cursor() + if is_db_new: cursor.execute( " CREATE TABLE IF NOT EXISTS tasks(" @@ -43,6 +43,8 @@ def init(path): " PRIMARY KEY (task_id, tag_id)" " )") - database.commit() + cursor.execute("PRAGMA foreign_keys = ON") + + database.commit() return database diff --git a/src/db/task_tags.py b/src/db/task_tags.py index 34366e0..93dc627 100644 --- a/src/db/task_tags.py +++ b/src/db/task_tags.py @@ -4,6 +4,15 @@ from typing import List from model.task_tag import TaskTag +def one_is_used(cursor: Cursor, tag_ids: List[int]) -> bool: + if len(tag_ids) >= 1: + cursor.execute( + "SELECT task_id FROM task_tags WHERE tag_id IN (%s) LIMIT 1" % ",".join("?"*len(tag_ids)), + tag_ids) + return len(cursor.fetchall()) == 1 + else: + return False + def get(cursor: Cursor) -> List[TaskTag]: cursor.execute("SELECT task_id, tag_id FROM task_tags") return [TaskTag(r[0], r[1]) for r in cursor.fetchall()] diff --git a/src/gui/tags/panel/table/menu.py b/src/gui/tags/panel/table/menu.py index 2238444..6bf812e 100644 --- a/src/gui/tags/panel/table/menu.py +++ b/src/gui/tags/panel/table/menu.py @@ -1,25 +1,36 @@ from PyQt5 import QtWidgets, QtCore +from model.tag import Tag, ValidTagForm +import database import db.tags +import db.task_tags import gui.tags.panel.dialog -from model.tag import Tag, ValidTagForm def open(table, update_tag_signal, position): rows = set([index.row() for index in table.selectedIndexes()]) menu = QtWidgets.QMenu(table) + actions = 0 + if len(rows) == 1: modify_action = menu.addAction(gui.icon.dialog_open(menu.style()), 'modify') + actions += 1 else: modify_action = QtWidgets.QAction(menu) - delete_action = menu.addAction(gui.icon.trash(menu.style()), 'delete') + tags = table.model().row_ids(rows) + if not db.task_tags.one_is_used(database.cursor(), tags): + delete_action = menu.addAction(gui.icon.trash(menu.style()), 'delete') + actions += 1 + else: + delete_action = QtWidgets.QAction(menu) - action = menu.exec_(table.mapToGlobal(position + QtCore.QPoint(15, 20))) - if action == modify_action and len(rows) == 1: - row = list(rows)[0] - tag = table.model().get_at(row) - gui.tags.panel.dialog.update(table, update_tag_signal, row, tag).exec_() - elif action == delete_action: - gui.tags.panel.dialog.show_delete(table, rows) + if actions > 0: + action = menu.exec_(table.mapToGlobal(position + QtCore.QPoint(15, 20))) + if action == modify_action and len(rows) == 1: + row = list(rows)[0] + tag = table.model().get_at(row) + gui.tags.panel.dialog.update(table, update_tag_signal, row, tag).exec_() + elif action == delete_action: + gui.tags.panel.dialog.show_delete(table, rows) diff --git a/src/gui/tags/panel/table/model.py b/src/gui/tags/panel/table/model.py index 7c66b5d..353f747 100644 --- a/src/gui/tags/panel/table/model.py +++ b/src/gui/tags/panel/table/model.py @@ -1,5 +1,6 @@ from PyQt5 import QtCore, QtWidgets, QtGui from PyQt5.QtCore import Qt +from typing import List from model.tag import Tag import time diff --git a/src/gui/tags/panel/table/widget.py b/src/gui/tags/panel/table/widget.py index f0bf82c..0ef67c2 100644 --- a/src/gui/tags/panel/table/widget.py +++ b/src/gui/tags/panel/table/widget.py @@ -1,13 +1,14 @@ from PyQt5 import QtWidgets from PyQt5.QtCore import Qt +from model.tag import Tag, ValidTagForm +import database import db.tags +import db.task_tags +import gui.tags.panel.dialog import gui.tags.panel.signal import gui.tags.panel.table.menu import gui.tags.panel.table.model -import gui.tags.panel.dialog -from model.tag import Tag, ValidTagForm -import database class Widget(QtWidgets.QTableView): @@ -60,7 +61,9 @@ class Widget(QtWidgets.QTableView): gui.tags.panel.dialog.update(self, self._update_tag_signal, row, tag).exec_() elif event.key() == Qt.Key_Delete: rows = self.get_selected_rows() - gui.tags.panel.dialog.show_delete(self, rows) + tags = self.model().row_ids(rows) + if not db.task_tags.one_is_used(database.cursor(), tags): + gui.tags.panel.dialog.show_delete(self, rows) def get_selected_rows(self): return list(set([index.row() for index in self.selectedIndexes()])) diff --git a/src/gui/tasks/table/widget.py b/src/gui/tasks/table/widget.py index 0a8d216..82c0456 100644 --- a/src/gui/tasks/table/widget.py +++ b/src/gui/tasks/table/widget.py @@ -99,16 +99,21 @@ class Widget(QtWidgets.QTableWidget): reverse = is_rev) def update_task(self, row, task: Task, tags: List[int]): - # TODO: just update if sort order is not impacted - # self._tasks[row] = task - # task_ids = [t.id for t in self._tasks] - # filtred_task_tags = [tt for tt in self._task_tags if tt.task_id in task_ids] - # new_task_tags = [TaskTag(task_id=task.id, tag_id=tag_id) for tag_id in tags] - # self._task_tags = filtred_task_tags + new_task_tags - # self.update_row(row) - self.delete_rows([row]) - row = self.insert(task, tags) - self.selectRow(row) + self._tasks[row] = task + filtred_task_tags = [tt for tt in self._task_tags if tt.task_id in [t.id for t in self._tasks if t.id != task.id]] + new_task_tags = [TaskTag(task_id=task.id, tag_id=tag_id) for tag_id in tags] + self._task_tags = filtred_task_tags + new_task_tags + + # Update task in table + self.sort() + row_after_sort = [i for i in range(len(self._tasks)) if self._tasks[i].id == task.id][0] + if row_after_sort == row: + self.update_row(row) + else: + self.removeRow(row) + self.insertRow(row_after_sort) + self.update_row(row_after_sort) + self.selectRow(row_after_sort) def update_view(self): for row in range(len(self._tasks)): diff --git a/src/service/tasks.py b/src/service/tasks.py index 6c3444b..870002a 100644 --- a/src/service/tasks.py +++ b/src/service/tasks.py @@ -10,18 +10,18 @@ def get(cursor) -> List[Task]: def create(cursor, task_form: ValidTaskForm) -> Task: task = db.tasks.insert(cursor, task_form) - new_task_tags = db.task_tags.insert_many(cursor, task.id, task_form.tags) + db.task_tags.insert_many(cursor, task.id, task_form.tags) database.commit() return task def update(cursor, task: Task, task_form: ValidTaskForm) -> Task: - updated_task = db.tasks.update(cursor, task, task_form) db.task_tags.delete(cursor, [task.id]) - new_task_tags = db.task_tags.insert_many(cursor, task.id, task_form.tags) + updated_task = db.tasks.update(cursor, task, task_form) + db.task_tags.insert_many(cursor, task.id, task_form.tags) database.commit() return updated_task def delete(cursor, task_ids: List[int]): - db.tasks.delete(cursor, task_ids) db.task_tags.delete(cursor, task_ids) + db.tasks.delete(cursor, task_ids) database.commit() -- cgit v1.2.3