llm-arch / src /datatypes.py
alfraser's picture
Added runner for pricing fact checks to assess the level of fact embedding in the latest model
c319c31
raw
history blame
No virus
4.66 kB
import sqlite3
from typing import List
from src.common import *
class DataLoader:
    """
    Loads a products SQLite database into the in-memory registries
    ``Category.all``, ``Feature.all``, ``Product.all`` and ``Review.all``.

    Database files live as ``<name>.db`` in ``<data_dir>/sqlite``. Throughout
    this class a database "name" is the file stem, without the ``.db`` suffix.
    """
    # Name (file stem, no ".db" suffix) of the currently selected database
    active_db = "02_baseline_products_dataset"
    db_dir = os.path.join(data_dir, 'sqlite')
    db_file = os.path.join(db_dir, f"{active_db}.db")
    # True once load_data has completed; later calls are no-ops unless reload=True
    loaded = False

    @classmethod
    def set_db_name(cls, name: str):
        """
        Select the database ``name`` (no ``.db`` suffix) and reload all data from it.

        Does nothing if ``name`` is already the active database. If loading the
        new database raises, ``db_file`` is restored to its previous value before
        the exception propagates. This keeps ``db_file`` and ``active_db`` in
        agreement.
        """
        if name == cls.active_db:
            return
        previous_file = cls.db_file
        cls.db_file = os.path.join(cls.db_dir, f"{name}.db")
        try:
            DataLoader.load_data(reload=True)
        except Exception:
            cls.db_file = previous_file
            raise
        cls.active_db = name

    @staticmethod
    def current_db() -> str:
        """
        Return the name of the active database, without any ``.db`` suffix.

        ``active_db`` is stored without the suffix (``set_db_name`` appends
        ``.db`` itself). The suffix is therefore stripped only when actually
        present; unconditionally dropping three characters truncated the name.
        """
        name = DataLoader.active_db
        return name[:-3] if name.endswith('.db') else name

    @staticmethod
    def available_dbs() -> List[str]:
        """
        Return the names (no ``.db`` suffix) of the ``.db`` files in ``db_dir``
        whose file name contains 'products'.
        """
        return [f[:-3] for f in os.listdir(DataLoader.db_dir)
                if 'products' in f and f.endswith('.db')]

    @staticmethod
    def load_data(reload=False):
        """
        Populate the model registries from ``DataLoader.db_file``.

        :param reload: when False (default) and data is already loaded, do
            nothing. When True, discard the registries and read the database again.
        :raises FileNotFoundError: if ``db_file`` does not exist. This is checked
            before anything is wiped, because ``sqlite3.connect`` would otherwise
            silently create an empty database file at that path.
        """
        if DataLoader.loaded and not reload:
            return
        if not os.path.isfile(DataLoader.db_file):
            raise FileNotFoundError(f"Database file not found: {DataLoader.db_file}")

        # Wipe out any prior data; stay "not loaded" until the read completes,
        # so a failed reload is retried by the next load_data() call
        Review.all = {}
        Feature.all = {}
        Product.all = {}
        Category.all = {}
        DataLoader.loaded = False

        con = sqlite3.connect(DataLoader.db_file)
        try:
            cur = con.cursor()
            # Columns are mapped positionally onto constructor arguments, so
            # these reads depend on each table's column order.
            for c in cur.execute('SELECT * FROM categories').fetchall():
                Category.all[c[0]] = Category(c[0], c[1])

            for f in cur.execute('SELECT * FROM features').fetchall():
                category = Category.all[f[2]]
                feat = Feature(f[0], f[1], category)
                Feature.all[f[0]] = feat
                category.features.append(feat)

            for p in cur.execute('SELECT * FROM products').fetchall():
                category = Category.all[p[4]]
                prod = Product(p[0], p[1], p[2], p[3], category)
                Product.all[p[0]] = prod
                category.products.append(prod)

            # Link table: column 1 is the product id, column 2 the feature id
            for pf in cur.execute('SELECT * FROM product_features').fetchall():
                prod = Product.all[pf[1]]
                feat = Feature.all[pf[2]]
                prod.features.append(feat)
                feat.products.append(prod)

            for r in cur.execute('SELECT * FROM reviews').fetchall():
                prod = Product.all[r[1]]
                rev = Review(r[0], r[2], r[3], prod)
                Review.all[r[0]] = rev
                prod.reviews.append(rev)
        finally:
            # sqlite3's connection context manager only commits/rolls back;
            # it does not close, so close explicitly.
            con.close()
        DataLoader.loaded = True
class Category:
    """A product category, holding the features and products that belong to it."""
    # Registry of all loaded categories, keyed by category id
    all = {}

    @staticmethod
    def all_sorted():
        """Return every registered category in a new list, ordered by name."""
        return sorted(Category.all.values(), key=lambda x: x.name)

    @staticmethod
    def by_name(name: str):
        """Return the first registered category whose name equals ``name``, or None."""
        return next((c for c in Category.all.values() if c.name == name), None)

    def __init__(self, id, name):
        self.id = id
        self.name = name
        self.features = []
        self.products = []

    @property
    def feature_count(self):
        """Number of features attached to this category."""
        return len(self.features)

    @property
    def product_count(self):
        """Number of products attached to this category."""
        return len(self.products)

    @property
    def singular_name(self):
        """
        The category name with one trailing "s" clipped off, if it ends in "s".

        Uses ``endswith`` so that an empty name yields "" rather than
        raising IndexError.
        """
        return self.name[:-1] if self.name.endswith("s") else self.name

    @property
    def lower_singular_name(self):
        """``singular_name`` converted to lower case."""
        return self.singular_name.lower()
class Feature:
    """
    A single product feature, owned by one category and shared by
    any number of products.
    """
    # Registry of all loaded features, keyed by feature id
    all = {}

    def __init__(self, id, name, category):
        self.id, self.name = id, name
        self.category = category
        # Products that offer this feature
        self.products = []

    def __repr__(self):
        # Represent the feature simply by its name
        return self.name

    @property
    def product_count(self):
        """How many products offer this feature."""
        linked = self.products
        return len(linked)
class Product:
    """A product with its category, features and reviews."""
    # Registry of all loaded products, keyed by product id
    all = {}
    # Decimal places used by ``average_rating``. A property cannot receive
    # call arguments, so the precision is configured via this attribute
    # (overridable per class or per instance).
    rating_decimals = 2

    def __init__(self, id, name, description, price, category):
        self.id = id
        self.name = name
        self.description = description
        self.price = round(price, 2)
        self.category = category
        self.features = []
        self.reviews = []

    @property
    def feature_count(self):
        """Number of features attached to this product."""
        return len(self.features)

    @property
    def review_count(self):
        """Number of reviews attached to this product."""
        return len(self.reviews)

    @property
    def average_rating(self):
        """
        Mean ``rating`` over this product's reviews, rounded to
        ``rating_decimals`` places. Returns 0.0 when there are no reviews.
        """
        count = self.review_count
        if count == 0:
            return 0.0
        total = sum(r.rating for r in self.reviews)
        return float(round(total / count, self.rating_decimals))

    @staticmethod
    def for_ids(ids: List[str]):
        """Return the registered products for ``ids``, in order (KeyError if any is unknown)."""
        return [Product.all[i] for i in ids]

    @staticmethod
    def all_as_list():
        """Return all registered products as a new list."""
        return list(Product.all.values())
class Review:
    """One review of a product: its id, rating and free text."""
    # Registry of all loaded reviews, keyed by review id
    all = {}

    def __init__(self, id, rating, review_text, product):
        self.id, self.rating = id, rating
        self.review_text, self.product = review_text, product