llm-arch / src /datatypes.py
alfraser's picture
Added a couple of convenience methods to the classes to get the representations I needed, but put that in the classes
8531ccc
raw
history blame
4.64 kB
import sqlite3
from typing import List
from src.common import *
class DataLoader:
    """Loads the product dataset from a SQLite file into the in-memory registries.

    The registries are the class-level ``all`` dicts on ``Category``,
    ``Feature``, ``Product`` and ``Review``, each keyed by the first column of
    the matching table. All state here is class-level, so there is a single
    active database per process.
    """

    # Name of the active database, WITHOUT the ".db" extension.
    active_db = "01_all_products_dataset"
    # Directory that holds the SQLite files.
    db_dir = os.path.join(data_dir, 'sqlite')
    # Full path of the SQLite file that load_data() reads.
    db_file = os.path.join(db_dir, f"{active_db}.db")
    # True once load_data() has completed successfully at least once.
    loaded = False

    @classmethod
    def set_db_name(cls, name: str) -> None:
        """Switch to the database called ``name`` (no extension) and reload.

        Does nothing when ``name`` is already the active database. If the
        reload raises, ``db_file`` is restored and ``active_db`` is left
        unchanged, so the loader keeps describing the data that is actually
        in memory; the exception is then re-raised.
        """
        if name == cls.active_db:
            return

        new_file = os.path.join(cls.db_dir, f"{name}.db")
        print(f"Switching database file from {cls.db_file} to {new_file}")
        previous_file = cls.db_file
        cls.db_file = new_file
        try:
            DataLoader.load_data(reload=True)
        except Exception:
            # Roll back so db_file and active_db never disagree.
            cls.db_file = previous_file
            raise
        cls.active_db = name

    @staticmethod
    def current_db() -> str:
        """Return the active database name, in the same form as available_dbs()."""
        # active_db is stored without the ".db" suffix, so it must not be
        # sliced: clipping three characters would corrupt the name.
        return DataLoader.active_db

    @staticmethod
    def available_dbs() -> List[str]:
        """List database names (extension stripped) found in ``db_dir``.

        Only files whose name contains 'products' and ends in '.db' are listed.
        """
        # f[:-3] drops the trailing ".db".
        return [f[:-3] for f in os.listdir(DataLoader.db_dir) if ('products' in f) and f.endswith('.db')]

    @staticmethod
    def load_data(reload=False) -> None:
        """Read every table from ``db_file`` and rebuild the registries.

        Returns immediately when data is already loaded, unless ``reload`` is
        true. The new object graph is built in local dicts and only published
        to the ``all`` registries once everything has loaded, so a failure
        part-way through leaves the previously loaded data intact.

        NOTE(review): rows are read with ``SELECT *`` and consumed by column
        position, so this depends on the tables' column order:
        categories (id, name); features (id, name, category id);
        products (id, name, description, price, category id);
        product_features (unused, product id, feature id);
        reviews (id, product id, rating, review text).
        """
        if DataLoader.loaded and not reload:
            return

        print(f"Loading {DataLoader.db_file}")
        con = sqlite3.connect(DataLoader.db_file)
        try:
            cur = con.cursor()
            category_rows = cur.execute('SELECT * FROM categories').fetchall()
            feature_rows = cur.execute('SELECT * FROM features').fetchall()
            product_rows = cur.execute('SELECT * FROM products').fetchall()
            product_feature_rows = cur.execute('SELECT * FROM product_features').fetchall()
            review_rows = cur.execute('SELECT * FROM reviews').fetchall()
        finally:
            # Always release the connection, including when a query fails.
            con.close()

        categories = {}
        for c in category_rows:
            categories[c[0]] = Category(c[0], c[1])

        features = {}
        for f in feature_rows:
            feat = Feature(f[0], f[1], categories[f[2]])
            features[f[0]] = feat
            categories[f[2]].features.append(feat)

        products = {}
        for p in product_rows:
            prod = Product(p[0], p[1], p[2], p[3], categories[p[4]])
            products[p[0]] = prod
            categories[p[4]].products.append(prod)

        # Link products and features in both directions.
        for pf in product_feature_rows:
            products[pf[1]].features.append(features[pf[2]])
            features[pf[2]].products.append(products[pf[1]])

        reviews = {}
        for r in review_rows:
            rev = Review(r[0], r[2], r[3], products[r[1]])
            reviews[r[0]] = rev
            products[r[1]].reviews.append(rev)

        # Publish the fully built graph, replacing any prior data.
        Review.all = reviews
        Feature.all = features
        Product.all = products
        Category.all = categories

        print("Data loaded")
        DataLoader.loaded = True
class Category:
    """A product category that owns lists of its features and products.

    Instances are registered in the class-level ``all`` dict.
    """

    # Registry of categories; the lookups below read its values.
    all = {}

    @staticmethod
    def all_sorted():
        """Return every registered category as a list ordered by name."""
        return sorted(Category.all.values(), key=lambda cat: cat.name)

    @staticmethod
    def by_name(name: str):
        """Return the first registered category called ``name``, or None."""
        for cat in Category.all.values():
            if cat.name == name:
                return cat
        return None

    def __init__(self, id, name):
        self.id = id
        self.name = name
        self.features = []
        self.products = []

    @property
    def feature_count(self):
        """Number of features attached to this category."""
        return len(self.features)

    @property
    def product_count(self):
        """Number of products attached to this category."""
        return len(self.products)

    @property
    def singular_name(self):
        """The name with one trailing "s" removed, if it has one.

        This is a naive singularisation: it only clips a final "s".
        """
        # endswith() is safe on an empty name, where name[-1] would raise.
        if self.name.endswith("s"):
            return self.name[:-1]  # Clip the s
        return self.name
class Feature:
    """A named feature that belongs to one category and is shared by products."""

    # Class-level registry of features.
    all = {}

    def __init__(self, id, name, category):
        self.id, self.name, self.category = id, name, category
        # Products carrying this feature; starts empty.
        self.products = []

    @property
    def product_count(self):
        """How many products currently reference this feature."""
        linked = self.products
        return len(linked)

    def __repr__(self):
        # The feature is represented by its bare name.
        return self.name
class Product:
    """A product with a price, a category, and attached features and reviews."""

    # Class-level registry of products; for_ids() indexes it by key.
    all = {}

    def __init__(self, id, name, description, price, category):
        self.id = id
        self.name = name
        self.description = description
        # Prices are stored rounded to two decimal places.
        self.price = round(price, 2)
        self.category = category
        self.features = []
        self.reviews = []

    @property
    def feature_count(self):
        """Number of features attached to this product."""
        attached = self.features
        return len(attached)

    @property
    def review_count(self):
        """Number of reviews attached to this product."""
        attached = self.reviews
        return len(attached)

    @property
    def average_rating(self, decimals=2):
        """Mean review rating rounded to ``decimals`` places; 0.0 with no reviews.

        Accessed as a property, so ``decimals`` always takes its default.
        """
        count = self.review_count
        if count == 0:
            return 0.0
        total = sum(review.rating for review in self.reviews)
        return float(round(total / count, decimals))

    @staticmethod
    def for_ids(ids: List[str]):
        """Return the registered products for ``ids``, in the given order.

        Raises KeyError for an id that is not in the registry.
        """
        registry = Product.all
        return [registry[product_id] for product_id in ids]

    @staticmethod
    def all_as_list():
        """Return every registered product as a new list."""
        return [*Product.all.values()]
class Review:
    """One review of a product: a rating plus free-text commentary."""

    # Class-level registry of reviews.
    all = {}

    def __init__(self, id, rating, review_text, product):
        self.id, self.rating = id, rating
        self.review_text, self.product = review_text, product