import os
import sqlite3
from typing import List

from src.common import *


class DataLoader:
    """Load the product catalogue from a SQLite file into in-memory registries.

    All state is kept at class level. ``load_data`` fills the ``all``
    dictionaries of ``Category``, ``Feature``, ``Product`` and ``Review``
    and links the created objects to each other.
    """

    # Database name without the ".db" extension.
    active_db = "02_baseline_products_dataset"
    # NOTE(review): ``data_dir`` is expected to come from ``src.common``.
    db_dir = os.path.join(data_dir, 'sqlite')
    db_file = os.path.join(db_dir, f"{active_db}.db")
    # True once ``load_data`` has completed successfully at least once.
    loaded = False

    @classmethod
    def set_db_name(cls, name: str):
        """Switch to the database called *name* and reload all data from it.

        Does nothing when *name* is already the active database.

        Args:
            name: Database name without the ".db" extension.

        Raises:
            Whatever ``load_data`` raises. In that case ``db_file`` is put
            back to its previous value and ``active_db`` is left unchanged,
            so the two attributes never disagree.
        """
        if name == cls.active_db:
            return

        previous_file = cls.db_file
        cls.db_file = os.path.join(DataLoader.db_dir, f"{name}.db")
        try:
            DataLoader.load_data(reload=True)
        except Exception:
            cls.db_file = previous_file
            raise
        cls.active_db = name

    @staticmethod
    def current_db() -> str:
        """Return the name of the active database, without the extension."""
        # ``active_db`` is stored without ".db", so nothing must be clipped.
        return DataLoader.active_db

    @staticmethod
    def available_dbs() -> List[str]:
        """Return the names (extension removed) of the selectable databases.

        A file in ``db_dir`` qualifies when its name contains ``products``
        and ends with ``.db``.
        """
        return [
            f[:-3]
            for f in os.listdir(DataLoader.db_dir)
            if ('products' in f) and f.endswith('.db')
        ]

    @staticmethod
    def load_data(reload=False):
        """Read every table from ``db_file`` and rebuild the registries.

        Args:
            reload: When false and data has already been loaded, return
                immediately. When true, always read the database again.

        The new objects are collected in local dictionaries and only
        published to ``Category.all`` / ``Feature.all`` / ``Product.all`` /
        ``Review.all`` after every table has been read, so a failure half
        way through leaves the previously loaded data untouched.

        Rows are accessed by position, so the column order of each table
        matters (see the comments next to each query).
        """
        if DataLoader.loaded and not reload:
            return

        categories = {}
        features = {}
        products = {}
        reviews = {}

        con = sqlite3.connect(DataLoader.db_file)
        try:
            cur = con.cursor()

            # categories: (id, name)
            for c in cur.execute('SELECT * FROM categories').fetchall():
                categories[c[0]] = Category(c[0], c[1])

            # features: (id, name, category id)
            for f in cur.execute('SELECT * FROM features').fetchall():
                category = categories[f[2]]
                feat = Feature(f[0], f[1], category)
                features[f[0]] = feat
                category.features.append(feat)

            # products: (id, name, description, price, category id)
            for p in cur.execute('SELECT * FROM products').fetchall():
                category = categories[p[4]]
                prod = Product(p[0], p[1], p[2], p[3], category)
                products[p[0]] = prod
                category.products.append(prod)

            # product_features: column 1 is the product id, column 2 the
            # feature id; column 0 is not used.
            for pf in cur.execute('SELECT * FROM product_features').fetchall():
                prod = products[pf[1]]
                feat = features[pf[2]]
                prod.features.append(feat)
                feat.products.append(prod)

            # reviews: (id, product id, rating, review text)
            for r in cur.execute('SELECT * FROM reviews').fetchall():
                prod = products[r[1]]
                rev = Review(r[0], r[2], r[3], prod)
                reviews[r[0]] = rev
                prod.reviews.append(rev)
        finally:
            # Always release the connection, including on query errors.
            con.close()

        # Publish the freshly built registries, replacing any prior data.
        Review.all = reviews
        Feature.all = features
        Product.all = products
        Category.all = categories
        DataLoader.loaded = True


class Category:
    """A product category holding its features and products."""

    # Registry of loaded categories, keyed by category id.
    all = {}

    @staticmethod
    def all_sorted():
        """Return all loaded categories as a list sorted by name."""
        return sorted(Category.all.values(), key=lambda x: x.name)

    @staticmethod
    def by_name(name: str):
        """Return the first loaded category whose name equals *name*.

        Returns ``None`` when no category matches.
        """
        for c in Category.all.values():
            if c.name == name:
                return c
        return None

    def __init__(self, id, name):
        self.id = id
        self.name = name
        self.features = []
        self.products = []

    @property
    def feature_count(self):
        """Number of features attached to this category."""
        return len(self.features)

    @property
    def product_count(self):
        """Number of products attached to this category."""
        return len(self.products)

    @property
    def singular_name(self):
        """The name with one trailing lowercase ``s`` removed, if present."""
        # endswith() is also safe for an empty name, unlike name[-1].
        if self.name.endswith("s"):
            return self.name[:-1]  # Clip the s
        return self.name

    @property
    def lower_singular_name(self):
        """Lower-cased version of ``singular_name``."""
        return self.singular_name.lower()


class Feature:
    """A feature that belongs to one category and is shared by products."""

    # Registry of loaded features, keyed by feature id.
    all = {}

    def __init__(self, id, name, category):
        self.id = id
        self.name = name
        self.category = category
        self.products = []

    @property
    def product_count(self):
        """Number of products that have this feature."""
        return len(self.products)

    def __repr__(self):
        return self.name


class Product:
    """A product with its category, features and reviews."""

    # Registry of loaded products, keyed by product id.
    all = {}

    def __init__(self, id, name, description, price, category):
        self.id = id
        self.name = name
        self.description = description
        self.price = round(price, 2)
        self.category = category
        self.features = []
        self.reviews = []

    @property
    def feature_count(self):
        """Number of features attached to this product."""
        return len(self.features)

    @property
    def review_count(self):
        """Number of reviews attached to this product."""
        return len(self.reviews)

    @property
    def average_rating(self, decimals=2):
        """Mean review rating rounded to *decimals* places; 0.0 if no reviews.

        Because this is a property, normal attribute access always uses the
        default of two decimals; the parameter is kept only so the getter's
        signature stays unchanged.
        """
        if self.review_count == 0:
            return 0.0
        total = sum(r.rating for r in self.reviews)
        return float(round(total / self.review_count, decimals))

    @staticmethod
    def for_ids(ids: List[str]):
        """Return the products for *ids*, in the same order.

        Raises:
            KeyError: If an id is not present in ``Product.all``.
        """
        return [Product.all[i] for i in ids]

    @staticmethod
    def all_as_list():
        """Return all loaded products as a list."""
        return list(Product.all.values())


class Review:
    """A single rated review of a product."""

    # Registry of loaded reviews, keyed by review id.
    all = {}

    def __init__(self, id, rating, review_text, product):
        self.id = id
        self.rating = rating
        self.review_text = review_text
        self.product = product