aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJoris2020-05-31 17:39:28 +0200
committerJoris2020-05-31 17:39:28 +0200
commit1bed85a9b107d1b03e71b848829cb7b1f33060f4 (patch)
tree7e9bc7dd0813b9426e293dde7913717746a479f8 /src
parenta585e507cbe2c05cc846013cafe433953e514295 (diff)
Prevent removing a tag being used
Diffstat (limited to 'src')
-rw-r--r--src/db/init.py8
-rw-r--r--src/db/task_tags.py9
-rw-r--r--src/gui/tags/panel/table/menu.py29
-rw-r--r--src/gui/tags/panel/table/model.py1
-rw-r--r--src/gui/tags/panel/table/widget.py11
-rw-r--r--src/gui/tasks/table/widget.py25
-rw-r--r--src/service/tasks.py8
7 files changed, 61 insertions, 30 deletions
diff --git a/src/db/init.py b/src/db/init.py
index 6b4cbea..77920cf 100644
--- a/src/db/init.py
+++ b/src/db/init.py
@@ -8,9 +8,9 @@ def init(path):
database = sqlite3.connect(path)
- if is_db_new:
+ cursor = database.cursor()
- cursor = database.cursor()
+ if is_db_new:
cursor.execute(
" CREATE TABLE IF NOT EXISTS tasks("
@@ -43,6 +43,8 @@ def init(path):
" PRIMARY KEY (task_id, tag_id)"
" )")
- database.commit()
+ cursor.execute("PRAGMA foreign_keys = ON")
+
+ database.commit()
return database
diff --git a/src/db/task_tags.py b/src/db/task_tags.py
index 34366e0..93dc627 100644
--- a/src/db/task_tags.py
+++ b/src/db/task_tags.py
@@ -4,6 +4,15 @@ from typing import List
from model.task_tag import TaskTag
+def one_is_used(cursor: Cursor, tag_ids: List[int]) -> bool:
+ if len(tag_ids) >= 1:
+ cursor.execute(
+ "SELECT task_id FROM task_tags WHERE tag_id IN (%s) LIMIT 1" % ",".join("?"*len(tag_ids)),
+ tag_ids)
+ return len(cursor.fetchall()) == 1
+ else:
+ return False
+
def get(cursor: Cursor) -> List[TaskTag]:
cursor.execute("SELECT task_id, tag_id FROM task_tags")
return [TaskTag(r[0], r[1]) for r in cursor.fetchall()]
diff --git a/src/gui/tags/panel/table/menu.py b/src/gui/tags/panel/table/menu.py
index 2238444..6bf812e 100644
--- a/src/gui/tags/panel/table/menu.py
+++ b/src/gui/tags/panel/table/menu.py
@@ -1,25 +1,36 @@
from PyQt5 import QtWidgets, QtCore
+from model.tag import Tag, ValidTagForm
+import database
import db.tags
+import db.task_tags
import gui.tags.panel.dialog
-from model.tag import Tag, ValidTagForm
def open(table, update_tag_signal, position):
rows = set([index.row() for index in table.selectedIndexes()])
menu = QtWidgets.QMenu(table)
+ actions = 0
+
if len(rows) == 1:
modify_action = menu.addAction(gui.icon.dialog_open(menu.style()), 'modify')
+ actions += 1
else:
modify_action = QtWidgets.QAction(menu)
- delete_action = menu.addAction(gui.icon.trash(menu.style()), 'delete')
+ tags = table.model().row_ids(rows)
+ if not db.task_tags.one_is_used(database.cursor(), tags):
+ delete_action = menu.addAction(gui.icon.trash(menu.style()), 'delete')
+ actions += 1
+ else:
+ delete_action = QtWidgets.QAction(menu)
- action = menu.exec_(table.mapToGlobal(position + QtCore.QPoint(15, 20)))
- if action == modify_action and len(rows) == 1:
- row = list(rows)[0]
- tag = table.model().get_at(row)
- gui.tags.panel.dialog.update(table, update_tag_signal, row, tag).exec_()
- elif action == delete_action:
- gui.tags.panel.dialog.show_delete(table, rows)
+ if actions > 0:
+ action = menu.exec_(table.mapToGlobal(position + QtCore.QPoint(15, 20)))
+ if action == modify_action and len(rows) == 1:
+ row = list(rows)[0]
+ tag = table.model().get_at(row)
+ gui.tags.panel.dialog.update(table, update_tag_signal, row, tag).exec_()
+ elif action == delete_action:
+ gui.tags.panel.dialog.show_delete(table, rows)
diff --git a/src/gui/tags/panel/table/model.py b/src/gui/tags/panel/table/model.py
index 7c66b5d..353f747 100644
--- a/src/gui/tags/panel/table/model.py
+++ b/src/gui/tags/panel/table/model.py
@@ -1,5 +1,6 @@
from PyQt5 import QtCore, QtWidgets, QtGui
from PyQt5.QtCore import Qt
+from typing import List
from model.tag import Tag
import time
diff --git a/src/gui/tags/panel/table/widget.py b/src/gui/tags/panel/table/widget.py
index f0bf82c..0ef67c2 100644
--- a/src/gui/tags/panel/table/widget.py
+++ b/src/gui/tags/panel/table/widget.py
@@ -1,13 +1,14 @@
from PyQt5 import QtWidgets
from PyQt5.QtCore import Qt
+from model.tag import Tag, ValidTagForm
+import database
import db.tags
+import db.task_tags
+import gui.tags.panel.dialog
import gui.tags.panel.signal
import gui.tags.panel.table.menu
import gui.tags.panel.table.model
-import gui.tags.panel.dialog
-from model.tag import Tag, ValidTagForm
-import database
class Widget(QtWidgets.QTableView):
@@ -60,7 +61,9 @@ class Widget(QtWidgets.QTableView):
gui.tags.panel.dialog.update(self, self._update_tag_signal, row, tag).exec_()
elif event.key() == Qt.Key_Delete:
rows = self.get_selected_rows()
- gui.tags.panel.dialog.show_delete(self, rows)
+ tags = self.model().row_ids(rows)
+ if not db.task_tags.one_is_used(database.cursor(), tags):
+ gui.tags.panel.dialog.show_delete(self, rows)
def get_selected_rows(self):
return list(set([index.row() for index in self.selectedIndexes()]))
diff --git a/src/gui/tasks/table/widget.py b/src/gui/tasks/table/widget.py
index 0a8d216..82c0456 100644
--- a/src/gui/tasks/table/widget.py
+++ b/src/gui/tasks/table/widget.py
@@ -99,16 +99,21 @@ class Widget(QtWidgets.QTableWidget):
reverse = is_rev)
def update_task(self, row, task: Task, tags: List[int]):
- # TODO: just update if sort order is not impacted
- # self._tasks[row] = task
- # task_ids = [t.id for t in self._tasks]
- # filtred_task_tags = [tt for tt in self._task_tags if tt.task_id in task_ids]
- # new_task_tags = [TaskTag(task_id=task.id, tag_id=tag_id) for tag_id in tags]
- # self._task_tags = filtred_task_tags + new_task_tags
- # self.update_row(row)
- self.delete_rows([row])
- row = self.insert(task, tags)
- self.selectRow(row)
+ self._tasks[row] = task
+ filtred_task_tags = [tt for tt in self._task_tags if tt.task_id in [t.id for t in self._tasks if t.id != task.id]]
+ new_task_tags = [TaskTag(task_id=task.id, tag_id=tag_id) for tag_id in tags]
+ self._task_tags = filtred_task_tags + new_task_tags
+
+ # Update task in table
+ self.sort()
+ row_after_sort = [i for i in range(len(self._tasks)) if self._tasks[i].id == task.id][0]
+ if row_after_sort == row:
+ self.update_row(row)
+ else:
+ self.removeRow(row)
+ self.insertRow(row_after_sort)
+ self.update_row(row_after_sort)
+ self.selectRow(row_after_sort)
def update_view(self):
for row in range(len(self._tasks)):
diff --git a/src/service/tasks.py b/src/service/tasks.py
index 6c3444b..870002a 100644
--- a/src/service/tasks.py
+++ b/src/service/tasks.py
@@ -10,18 +10,18 @@ def get(cursor) -> List[Task]:
def create(cursor, task_form: ValidTaskForm) -> Task:
task = db.tasks.insert(cursor, task_form)
- new_task_tags = db.task_tags.insert_many(cursor, task.id, task_form.tags)
+ db.task_tags.insert_many(cursor, task.id, task_form.tags)
database.commit()
return task
def update(cursor, task: Task, task_form: ValidTaskForm) -> Task:
- updated_task = db.tasks.update(cursor, task, task_form)
db.task_tags.delete(cursor, [task.id])
- new_task_tags = db.task_tags.insert_many(cursor, task.id, task_form.tags)
+ updated_task = db.tasks.update(cursor, task, task_form)
+ db.task_tags.insert_many(cursor, task.id, task_form.tags)
database.commit()
return updated_task
def delete(cursor, task_ids: List[int]):
- db.tasks.delete(cursor, task_ids)
db.task_tags.delete(cursor, task_ids)
+ db.tasks.delete(cursor, task_ids)
database.commit()