]> rtime.felk.cvut.cz Git - hubacji1/coffee-getter.git/commitdiff
Rework db
authorJiri Vlasak <jiri.vlasak.2@cvut.cz>
Fri, 4 Nov 2022 13:51:19 +0000 (14:51 +0100)
committerJiri Vlasak <jiri.vlasak.2@cvut.cz>
Fri, 4 Nov 2022 15:02:06 +0000 (16:02 +0100)
coffee_getter/db.py

index 896ddc9f8bd553a73ffa12bc6aba0241fddfcc1e..d2dfdeb5d3d55a8c91dd99ae8260c0073a4f9d07 100644 (file)
@@ -1,74 +1,82 @@
-# -*- coding: utf-8 -*-
-"""Database access."""
 from sqlite3 import connect
 
-class FileNotSetError(ValueError):
-    pass
 
-class ArgCountError(ValueError):
-    pass
+def Q(q, b="", t=["now", "-7 days"]):
+    """Return db query identified by `q`.
+
+    :param q: Identifier of query.
+    :param b: Optionally, specify the beverage(s).
+    :param t: Optionally, specify the time range.
+    """
+    assert isinstance(t, tuple) or isinstance(t, list)
+    assert len(t) == 2
+    if t[0] == "now":
+        dtf = f"datetime('now', 'localtime', '{t[1]}')"
+        dtt = "datetime('now', 'localtime')"
+    elif t[1] == "now":
+        dtf = f"datetime('now', 'localtime', '{t[0]}')"
+        dtt = "datetime('now', 'localtime')"
+    else:
+        dtf = f"datetime('{t[0]}', 'localtime')"
+        dtt = f"datetime('{t[1]}', 'localtime')"
+    if q == "get_drinks":
+        return f"""
 
-class Db:
-    def __init__(self, dbpath=False):
-        if dbpath:
-            self.con = connect(dbpath)
+        SELECT count(*), flavor
+        FROM coffees
+        WHERE time BETWEEN
+            {dtf}
+            AND {dtt}
+        GROUP BY flavor
+
+        """
+    elif q == "get_drinkers_of":
+        if isinstance(b, tuple) or isinstance(b, list):
+            assert len(b) > 0
+            f = f"WHERE flavor = '{b[0]}'"
+            for i in b[1:]:
+                f += f" OR flavor = '{i}'"
         else:
-            self.con = None
-            raise FileNotSetError("Database file must be set")
+            assert b != ""
+            f = f"WHERE flavor = '{b}'"
+        return f"""
+
+        SELECT count(*), users.name FROM coffees
+        LEFT JOIN identifiers on coffees.id = identifiers.userid
+        LEFT JOIN users on identifiers.userid = users.id
+        {f}
+        AND coffees.time BETWEEN
+            {dtf}
+            AND {dtt}
+        GROUP BY identifiers.userid
+
+        """
+
+
+class Db:
+    def __init__(self, db_path):
+        self.con = connect(db_path)
         self.cur = self.con.cursor()
-        return None
 
     def __del__(self):
         if self.con:
             self.con.close()
 
     def get_top_drinks(self):
-        """Return list of pairs of drink name and count."""
-        q = """
-
-        SELECT count(*), flavor
-        FROM coffees
-        WHERE time BETWEEN
-            datetime('now', 'localtime', '-7 days')
-            AND datetime('now', 'localtime')
-        GROUP BY flavor
-
-        """
+        q = Q("get_drinks")
         top = []
         for (cnt, dn) in self.cur.execute(q):
             top.append((dn, cnt))
         top.sort(key=lambda x: (x[1], x[0]), reverse=True)
-        return top
+        return tuple(top)
 
     def getTopMateDrinkers(self):
         """Return list of pairs of name, count for Mate drinkers."""
         users = {}
-        que = """
-
-        SELECT count(*), users.name FROM coffees
-        LEFT JOIN identifiers on coffees.id = identifiers.id
-        LEFT JOIN users on identifiers.userid = users.id
-        WHERE flavor = 'Club-Mate 0,5 l'
-        AND coffees.time BETWEEN
-            datetime('now', 'localtime', '-7 days') AND
-            datetime('now', 'localtime')
-        GROUP BY identifiers.userid
-
-        """
+        que = Q("get_drinkers_of", "Club-Mate 0,5 l")
         for (cnt, un) in self.cur.execute(que):
             users[un] = cnt * 0.5
-        que = """
-
-        SELECT count(*), users.name FROM coffees
-        LEFT JOIN identifiers on coffees.id = identifiers.id
-        LEFT JOIN users on identifiers.userid = users.id
-        WHERE flavor = 'Club-Mate 0,33 l'
-        AND coffees.time BETWEEN
-            datetime('now', 'localtime', '-7 days') AND
-            datetime('now', 'localtime')
-        GROUP BY identifiers.userid
-
-        """
+        que = Q("get_drinkers_of", "Club-Mate 0,33 l")
         for (cnt, un) in self.cur.execute(que):
             if un in users:
                 users[un] += cnt * 0.33
@@ -78,27 +86,16 @@ class Db:
         for (un, cnt) in users.items():
             top.append((un, cnt))
         top.sort(key=lambda x: (x[1], x[0]), reverse=True)
-        return top
+        return tuple(top)
 
     def get_top_tea_drinkers(self):
         """Return list of pairs of name, count for tea drinkers."""
-        q = """
-
-        SELECT count(*), users.name FROM coffees
-        LEFT JOIN identifiers on coffees.id = identifiers.id
-        LEFT JOIN users on identifiers.userid = users.id
-        WHERE flavor = 'tea'
-        AND coffees.time BETWEEN
-            datetime('now', 'localtime', '-7 days') AND
-            datetime('now', 'localtime')
-        GROUP BY identifiers.userid
-
-        """
+        q = Q("get_drinkers_of", "tea")
         top = []
         for (cnt, un) in self.cur.execute(q):
             top.append((un, cnt))
         top.sort(key=lambda x: (x[1], x[0]), reverse=True)
-        return top
+        return tuple(top)
 
     def getDrunkSum(self, *args, **kwargs):
         """Return list of drunken ``flavor`` from ``dtf`` to ``dtt``.
@@ -108,9 +105,8 @@ class Db:
         dtf -- Date and time *from*.
         dtt -- Date and time *to*.
         """
-        if not ((len(args) == 3 and len(kwargs) == 0) or
-                (len(args) == 0 and len(kwargs) == 3)):
-            raise ArgCountError("3 arguments needed: flavor, from, and to")
+        assert ((len(args) == 3 and len(kwargs) == 0)
+                or (len(args) == 0 and len(kwargs) == 3))
         if args:
             flavor = args[0]
             dtf = args[1]
@@ -119,26 +115,11 @@ class Db:
             flavor = kwargs["flavor"]
             dtf = kwargs["dtf"]
             dtt = kwargs["dtt"]
-        flavors = flavor.split(";")
-        que = """
-            SELECT count(*), users.name FROM coffees
-            INNER JOIN users ON coffees.id = users.id
-        """
-        for f in flavors:
-            if f is flavors[0]:
-                que += "WHERE flavor = '{}'".format(f)
-            else:
-                que += "OR flavor = '{}'".format(f)
-        que += """
-            AND coffees.time BETWEEN
-                datetime('{}', 'localtime') AND
-                datetime('{}', 'localtime')
-            GROUP BY coffees.id
-        """.format(dtf, dtt)
+        que = Q("get_drinkers_of", flavor.split(";"), (dtf, dtt))
         drunk = []
         for (cnt, un) in self.cur.execute(que):
             drunk.append((un, cnt))
-        return drunk
+        return tuple(drunk)
 
     def getDrunkList(self, *args, **kwargs):
         """Return dict of lists of drunken ``flavor`` from ``dtf`` to ``dtt``.
@@ -148,9 +129,8 @@ class Db:
         dtf -- Date and time *from*.
         dtt -- Date and time *to*.
         """
-        if not ((len(args) == 3 and len(kwargs) == 0) or
-                (len(args) == 0 and len(kwargs) == 3)):
-            raise ArgCountError("3 arguments needed: flavor, from, and to")
+        assert ((len(args) == 3 and len(kwargs) == 0)
+                or (len(args) == 0 and len(kwargs) == 3))
         if args:
             flavor = args[0]
             dtf = args[1]
@@ -163,20 +143,9 @@ class Db:
         drunk = {}
         i = 0
         for f in flavors:
-            que = """
-
-            SELECT count(*), users.name FROM coffees
-            LEFT JOIN identifiers on coffees.id = identifiers.id
-            LEFT JOIN users on identifiers.userid = users.id
-            WHERE flavor = '{}'
-            AND coffees.time BETWEEN
-                datetime('{}', 'localtime') AND
-                datetime('{}', 'localtime')
-            GROUP BY identifiers.userid
-
-            """.format(f, dtf, dtt)
+            que = Q("get_drinkers_of", f, (dtf, dtt))
             for (cnt, un) in self.cur.execute(que):
-                if not un in drunk:
+                if un not in drunk:
                     drunk[un] = [0 for j in range(i)]
                 drunk[un].append(cnt)
             i += 1