summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--jimbrella/umbrellas.py46
1 files changed, 24 insertions, 22 deletions
diff --git a/jimbrella/umbrellas.py b/jimbrella/umbrellas.py
index 6da7184..ff88277 100644
--- a/jimbrella/umbrellas.py
+++ b/jimbrella/umbrellas.py
@@ -1,5 +1,6 @@
import sqlite3
from datetime import datetime, timedelta
+from typing import Union
from .utils import human_datetime, human_timedelta
from .config import DUE_HOURS, ADMIN_LOG_PATH
from .exceptions import *
@@ -42,12 +43,21 @@ class Umbrellas:
"""
self.path = path
- def read(self) -> list:
+ def read(self, umbid=None) -> Union[dict, list]:
+ """Read umbrella data from database.
+
+ If umbid is an integer, returns dict pertaining to umbrella #<umbid>.
+ If umbid is None, returns list of dicts for all umbrellas.
+ """
db = sqlite3.connect(self.path)
- db.row_factory = sqlite.Row
- umbrellas = db.execute("SELECT * FROM Umbrellas").fetchall()
+ db.row_factory = sqlite3.Row
+ if umbid is None:
+ data = db.execute("SELECT * FROM Umbrellas").fetchall()
+ else:
+ data = db.execute("SELECT * FROM Umbrellas WHERE id = ?", (umbid,)).fetchone()
+
db.close()
- return umbrellas
+ return data
def update(self, umb) -> dict:
"""Update Umbrella table with new data given in `umb`.
@@ -76,7 +86,7 @@ class Umbrellas:
# check if umbrella #<id> exists in database
umbid = umb["id"]
- umb_in_db = db.execute("SELECT * FROM Umbrellas WHERE id = ?", umbid).fetchone()
+ umb_in_db = db.execute("SELECT * FROM Umbrellas WHERE id = ?", (umbid,)).fetchone()
if umb_in_db is None:
raise UmbrellaNotFoundError(umbid)
@@ -89,7 +99,7 @@ class Umbrellas:
if umb_in_db["status"] != umb["status"]:
diff["status"] = (umb_in_db["status"], umb["status"])
- db.execute("UPDATE Umbrellas SET status = ? WHERE id = ?", status, umbid)
+ db.execute("UPDATE Umbrellas SET status = ? WHERE id = ?", (status, umbid))
else:
raise UmbrellaValueError("status")
@@ -106,9 +116,9 @@ class Umbrellas:
db.execute(
"UPDATE Umbrellas SET ? = ? WHERE id = ?",
- col,
+ (col,
umb[col],
- umbid,
+ umbid,)
)
if "lent_at" in umb:
@@ -130,8 +140,8 @@ class Umbrellas:
db.execute(
"UPDATE Umbrellas SET lent_at = ? WHERE id = ?",
- lent_at.isoformat(timespec="milliseconds"),
- umbid,
+ ( lent_at.isoformat(timespec="milliseconds"),
+ umbid,)
)
else:
# discard unneeded fields
@@ -142,7 +152,7 @@ class Umbrellas:
"tenant_email",
"lent_at",
):
- db.execute("UPDATE Umbrellas SET ? = NULL WHERE id = ?", col, umbid)
+ db.execute("UPDATE Umbrellas SET ? = NULL WHERE id = ?", (col, umbid))
# now that new data are validated, commit the SQL transaction
db.commit()
@@ -153,10 +163,7 @@ class Umbrellas:
self, umbid, date, tenant_name, tenant_id, tenant_phone="", tenant_email=""
) -> None:
"""When a user has borrowed an umbrella."""
- db = sqlite3.connect(self.path)
- db.row_factory = sqlite3.Row
- umb = db.execute("SELECT * FROM Umbrellas WHERE id = ?", umbid)
- db.close()
+ umb = self.read(umbid)
if umb is None:
raise UmbrellaNotFoundError(umbid)
@@ -181,10 +188,7 @@ class Umbrellas:
`tenant_name` and `tenant_id` are used to verify if the umbrella is returned by the same
person who borrowed it.
"""
- db = sqlite3.connect(self.path)
- db.row_factory = sqlite3.Row
- umb = db.execute("SELECT * FROM Umbrellas WHERE id = ?", umbid)
- db.close()
+ umb = self.read(umbid)
if umb is None:
raise UmbrellaNotFoundError(umbid)
@@ -202,9 +206,7 @@ class Umbrellas:
def mark_overdue(self, umbid) -> None:
"""When an umbrella is overdue, change its status to "overdue"."""
- db = sqlite3.connect(self.path)
- db.row_factory = sqlite3.Row
- umb = db.execute("SELECT * FROM Umbrellas WHERE id = ?", umbid)
+ umb = self.read(umbid)
if umb is None:
raise UmbrellaNotFoundError(umbid)