""" Not used as part of the streamlit app, but run offline to train the architecture for the rag architecture. """ import argparse import chromadb import os import shutil import time from chromadb import Settings from src.common import data_dir from src.datatypes import * # Set up and parse expected arguments to the script parser = argparse.ArgumentParser(prog="train_rag", description="Train a Chroma DB document store from the product dataset provided for use in a RAG LLM architecture") parser.add_argument('-in_vectors', help="Optional existing Chroma DB to load and extend") parser.add_argument('-out_vectors', required=True, help="The name for the output Chroma DB") parser.add_argument('-products_db', required=True, help="The products sqlite to vectorise") parser.add_argument('-overwrite', action='store_true', help="Overwrite the in_vectors store from blank (defaults to append and skip existing)") parser.add_argument('-delete', action='store_true', help="Delete product set from vector store") args = parser.parse_args() def need_to_copy_vector_store(args) -> bool: if args.in_vectors is None: return False if args.in_vectors == args.out_vectors: return False return True def copy_vector_store(args) -> None: src = os.path.join(data_dir, 'vector_stores', f"{args.in_vectors}_chroma") dest = os.path.join(data_dir, 'vector_stores', f"{args.out_vectors}_chroma") shutil.copytree(src, dest) def empty_or_create_out_vector_store(args) -> None: chroma_client = get_out_vector_store_client(args, True) chroma_client.reset() def get_product_collection_from_client(client:chromadb.Client) -> chromadb.Collection: return client.get_or_create_collection(name='products', metadata={'hnsw:space': 'cosine'}) def connect_to_product_db(args) -> None: """ Connect to the requested product DB which will load the products. On failure this will raise an exception which will propagate to the command line, which is fine as this is a script, not part of the running app """ if DataLoader.active_db != args.products_db: DataLoader.set_db_name(args.products_db) else: DataLoader.load_data() def get_out_vector_store_client(args, allow_reset: bool = False) -> chromadb.Client: out_dir = os.path.join(data_dir, 'vector_stores', f"{args.out_vectors}_chroma") chroma_settings = Settings() chroma_settings.allow_reset = allow_reset return chromadb.PersistentClient(path=out_dir, settings=chroma_settings) def prepare_to_vectorise(args) -> chromadb.Client: connect_to_product_db(args) # Do this first as non-destructive # Now do possibly destructive setup if args.overwrite: empty_or_create_out_vector_store(args) elif need_to_copy_vector_store(args): copy_vector_store(args) return get_out_vector_store_client(args) def document_for_product(product: Product) -> str: """ Builds a string document for vectorisation from a product """ category = product.category.singular_name category_sentence = f"The {product.name} is a {category}." price_rating_sentence = f"It costs ${product.price} and is rated {product.average_rating} stars." feature_sentence = f"The {product.name} features {join_items_comma_and(product.features)}." return f"{category_sentence} {price_rating_sentence} {feature_sentence} {product.description}" def vectorise(vector_client: chromadb.Client) -> None: """ Add documents representing the products from the products database into the vector store Document is a built string from the features of the product IDs are loaded as "prod_{id from db}" Metadata is loaded with the category """ collection = get_product_collection_from_client(vector_client) products = Product.all_as_list() ids = [f"prod_{p.id}" for p in products] documents = [document_for_product(p) for p in products] metadata = [{'category': p.category.singular_name} for p in products] print(f"Vectorising {len(products)} products") collection.upsert(ids=ids, documents=documents, metadatas=metadata) def prepare_to_delete_vectors(args) -> chromadb.Client: connect_to_product_db(args) # Do this first as non-destructive # Now do possibly destructive setup if need_to_copy_vector_store(args): copy_vector_store(args) return get_out_vector_store_client(args) def delete_vectors(vector_client: chromadb.Client) -> None: collection = get_product_collection_from_client(vector_client) products = Product.all_as_list() ids = [f"prod_{p.id}" for p in products] collection.delete(ids=ids) def train(args): if args.delete: vector_store = prepare_to_delete_vectors(args) delete_vectors(vector_store) else: vector_store = prepare_to_vectorise(args) vectorise(vector_store) if __name__ == "__main__": start = time() train(args) end = time() print(f"Training took {end-start:.2f} seconds")