Spaces:
Runtime error
Runtime error
""" | |
Not used as part of the Streamlit app; run offline to build (train) the Chroma DB vector store used by the RAG architecture.
""" | |
import argparse | |
import chromadb | |
import os | |
import shutil | |
import time | |
from chromadb import Settings | |
from src.common import data_dir | |
from src.datatypes import * | |
# Set up the expected command line arguments for the script
parser = argparse.ArgumentParser(prog="train_rag",
                                 description="Train a Chroma DB document store from the product dataset provided for use in a RAG LLM architecture")
parser.add_argument('-in_vectors', help="Optional existing Chroma DB to load and extend")
parser.add_argument('-out_vectors', required=True, help="The name for the output Chroma DB")
parser.add_argument('-products_db', required=True, help="The products sqlite to vectorise")
# -overwrite resets the OUTPUT store and skips copying in_vectors (see prepare_to_vectorise);
# without it products are upserted, i.e. existing entries are updated, not skipped.
parser.add_argument('-overwrite', action='store_true',
                    help="Rebuild the out_vectors store from blank, ignoring in_vectors (default: copy in_vectors if given, then add products, updating any already present)")
parser.add_argument('-delete', action='store_true',
                    help="Delete the products in products_db from the out_vectors store instead of adding them (-overwrite is ignored)")
# Only read the command line when run as a script, so importing this module
# (e.g. to reuse document_for_product) does not parse or reject the importer's sys.argv.
args = parser.parse_args() if __name__ == "__main__" else None
def need_to_copy_vector_store(args) -> bool:
    """
    Decide whether the input vector store must be copied to the output location.

    A copy is only needed when an input store was named and it is a different
    store from the output one.

    :param args: parsed command line arguments (uses in_vectors and out_vectors)
    :return: True when in_vectors is set and differs from out_vectors, else False
    """
    return args.in_vectors is not None and args.in_vectors != args.out_vectors
def copy_vector_store(args) -> None:
    """
    Duplicate the in_vectors Chroma DB directory as the out_vectors directory.

    Both live under <data_dir>/vector_stores as "<name>_chroma".

    NOTE: shutil.copytree raises FileExistsError if the out_vectors directory
    already exists, so re-running with the same in/out pair without -overwrite fails.

    :param args: parsed command line arguments (uses in_vectors and out_vectors)
    """
    stores_dir = os.path.join(data_dir, 'vector_stores')
    source_dir = os.path.join(stores_dir, f"{args.in_vectors}_chroma")
    target_dir = os.path.join(stores_dir, f"{args.out_vectors}_chroma")
    shutil.copytree(source_dir, target_dir)
def empty_or_create_out_vector_store(args) -> None:
    """
    Ensure the out_vectors store exists and holds no data.

    Opens (creating if necessary) the output store with resetting enabled and
    then resets it, wiping any existing contents.

    :param args: parsed command line arguments (uses out_vectors)
    """
    resettable_client = get_out_vector_store_client(args, allow_reset=True)
    resettable_client.reset()
def get_product_collection_from_client(client:chromadb.Client) -> chromadb.Collection:
    """
    Fetch the 'products' collection from a Chroma client, creating it if missing.

    The collection is configured to use cosine distance for its HNSW index.

    :param client: Chroma client for the vector store
    :return: the 'products' collection
    """
    index_config = {'hnsw:space': 'cosine'}
    return client.get_or_create_collection(name='products', metadata=index_config)
def connect_to_product_db(args) -> None:
    """
    Point DataLoader at the requested product DB so its products are loaded.

    If that DB is already the active one, its data is loaded again with
    load_data(); otherwise set_db_name() switches to it (presumably loading it
    as part of the switch — confirm in DataLoader).

    Any failure raises and is left to propagate to the command line, which is
    acceptable because this is an offline script, not part of the running app.

    :param args: parsed command line arguments (uses products_db)
    """
    requested_db = args.products_db
    if DataLoader.active_db == requested_db:
        DataLoader.load_data()
    else:
        DataLoader.set_db_name(requested_db)
def get_out_vector_store_client(args, allow_reset: bool = False) -> chromadb.Client:
    """
    Open a persistent Chroma client on the out_vectors store.

    The store lives in <data_dir>/vector_stores/<out_vectors>_chroma.

    :param args: parsed command line arguments (uses out_vectors)
    :param allow_reset: whether the client may call reset() to wipe the store
    :return: a chromadb PersistentClient for the output store
    """
    store_path = os.path.join(data_dir, 'vector_stores', f"{args.out_vectors}_chroma")
    return chromadb.PersistentClient(path=store_path, settings=Settings(allow_reset=allow_reset))
def prepare_to_vectorise(args) -> chromadb.Client:
    """
    Get the output vector store ready to receive product vectors and return a client for it.

    The product DB is connected first because that step is non-destructive. A
    failure there aborts before any vector store is touched.

    - With -overwrite, the out_vectors store is reset to empty and in_vectors is ignored.
    - Otherwise, if in_vectors names a different store, it is copied to out_vectors
      so the products are added on top of its contents.

    :param args: parsed command line arguments
    :return: a client for the out_vectors store
    """
    connect_to_product_db(args)  # Do this first as non-destructive
    # Now do possibly destructive setup
    if args.overwrite:
        empty_or_create_out_vector_store(args)
    elif need_to_copy_vector_store(args):
        copy_vector_store(args)
    # Recent chromadb versions share one underlying system per persist path and raise
    # ValueError when a second client for that path is created with different Settings.
    # The reset above used allow_reset=True, so request the same setting here.
    return get_out_vector_store_client(args, allow_reset=args.overwrite)
def document_for_product(product: Product) -> str:
    """
    Build the text that gets embedded in the vector store for one product.

    The document is four parts joined by single spaces: a sentence naming the
    product's category, a sentence giving its price and average rating, a
    sentence listing its features, and finally the product description.

    :param product: the product to describe
    :return: the document string
    """
    sentences = [
        f"The {product.name} is a {product.category.singular_name}.",
        f"It costs ${product.price} and is rated {product.average_rating} stars.",
        f"The {product.name} features {join_items_comma_and(product.features)}.",
        f"{product.description}",
    ]
    return " ".join(sentences)
def vectorise(vector_client: chromadb.Client, batch_size: int = 1000) -> None:
    """
    Add documents representing the products from the products database into the vector store.

    Each product becomes one record in the 'products' collection:
    - id: "prod_{id from db}"
    - document: the string built by document_for_product
    - metadata: {'category': <the product category's singular name>}

    Records are upserted, so products already in the store are updated rather than duplicated.

    :param vector_client: client for the vector store to add the products to
    :param batch_size: maximum number of records sent to chromadb in one upsert call
    :raises ValueError: if batch_size is less than 1
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    collection = get_product_collection_from_client(vector_client)
    products = Product.all_as_list()
    ids = [f"prod_{p.id}" for p in products]
    documents = [document_for_product(p) for p in products]
    metadata = [{'category': p.category.singular_name} for p in products]
    print(f"Vectorising {len(products)} products")
    if not ids:
        # chromadb rejects an upsert with an empty id list, so stop instead of crashing
        return
    # chromadb caps how many records a single call may submit, so send the products
    # in slices; each slice is an independent upsert of distinct records.
    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        collection.upsert(ids=ids[start:stop],
                          documents=documents[start:stop],
                          metadatas=metadata[start:stop])
def prepare_to_delete_vectors(args) -> chromadb.Client:
    """
    Get the output vector store ready for product deletion and return a client for it.

    The product DB is connected first because that step is non-destructive. If
    in_vectors names a different store, it is copied to out_vectors so the
    deletion happens on the copy. -overwrite is not consulted here.

    :param args: parsed command line arguments
    :return: a client for the out_vectors store
    """
    connect_to_product_db(args)  # non-destructive, so done before touching any vector store
    copy_needed = need_to_copy_vector_store(args)
    if copy_needed:
        copy_vector_store(args)
    return get_out_vector_store_client(args)
def delete_vectors(vector_client: chromadb.Client) -> None:
    """
    Remove the records for every product in the products database from the vector store.

    Record ids are derived as "prod_{id from db}", matching how vectorise stores them.

    :param vector_client: client for the vector store to delete the products from
    """
    collection = get_product_collection_from_client(vector_client)
    products = Product.all_as_list()
    ids = [f"prod_{p.id}" for p in products]
    print(f"Deleting {len(ids)} products")
    if not ids:
        # Never call delete with an empty id list. Depending on the chromadb version,
        # ids=[] is either rejected or treated as "no filter", deleting every record.
        return
    collection.delete(ids=ids)
def train(args):
    """
    Run the operation requested on the command line against the out_vectors store.

    With -delete, the products in products_db are removed from the store.
    Otherwise they are vectorised and added to it.

    :param args: parsed command line arguments
    """
    if not args.delete:
        vectorise(prepare_to_vectorise(args))
        return
    delete_vectors(prepare_to_delete_vectors(args))
if __name__ == "__main__":
    # `time` is imported as a module, so calling `time()` directly raises TypeError.
    # perf_counter is the clock intended for measuring elapsed durations.
    start = time.perf_counter()
    train(args)
    end = time.perf_counter()
    print(f"Training took {end-start:.2f} seconds")