llm-arch / src /training /train_rag.py
alfraser's picture
Fixed the time.time bug here. Also a call to reset the Chroma DB
1cb115b
raw
history blame
5.04 kB
"""
Not used as part of the Streamlit app; run offline to build (or prune) the Chroma DB vector store of products used by the RAG architecture.
"""
import argparse
import chromadb
import os
import shutil
import time
from chromadb import Settings
from src.common import data_dir
from src.datatypes import *
# Set up and parse expected arguments to the script.
# NOTE: parsing happens at import time; the functions below take `args` explicitly,
# and the __main__ guard at the bottom of the file passes this module-level value.
parser = argparse.ArgumentParser(
    prog="train_rag",
    description="Train a Chroma DB document store from the product dataset provided for use in a RAG LLM architecture",
)
parser.add_argument('-in_vectors', help="Optional existing Chroma DB to load and extend")
parser.add_argument('-out_vectors', required=True, help="The name for the output Chroma DB")
parser.add_argument('-products_db', required=True, help="The products sqlite to vectorise")
# -overwrite resets the *output* store (see prepare_to_vectorise) and the default
# path uses upsert, which updates products already present rather than skipping them.
parser.add_argument('-overwrite', action='store_true',
                    help="Empty the out_vectors store before vectorising and ignore in_vectors "
                         "(default: add to the existing store, updating products already in it)")
# train() takes the delete branch without consulting -overwrite.
parser.add_argument('-delete', action='store_true',
                    help="Delete the product set from the out_vectors store instead of adding it "
                         "(-overwrite is ignored)")
args = parser.parse_args()
def need_to_copy_vector_store(args) -> bool:
    """
    Decide whether the input vector store must be copied to the output location.

    A copy is only needed when an input store was named and it is not the same
    store as the output; otherwise the output store is used in place.
    """
    source_name = args.in_vectors
    return source_name is not None and source_name != args.out_vectors
def copy_vector_store(args) -> None:
    """
    Copy the input Chroma DB directory to the output Chroma DB directory.

    Both stores live under ``<data_dir>/vector_stores`` as ``<name>_chroma``.
    ``shutil.copytree`` raises if the destination directory already exists.
    """
    stores_root = os.path.join(data_dir, 'vector_stores')
    shutil.copytree(
        os.path.join(stores_root, f"{args.in_vectors}_chroma"),
        os.path.join(stores_root, f"{args.out_vectors}_chroma"),
    )
def empty_or_create_out_vector_store(args) -> None:
    """
    Reset the output Chroma DB, opening (and so creating) it first if needed.

    The client is opened with ``allow_reset`` enabled because Chroma only
    permits ``reset()`` on clients configured that way.
    """
    get_out_vector_store_client(args, allow_reset=True).reset()
def get_product_collection_from_client(client: chromadb.Client) -> chromadb.Collection:
    """
    Fetch the 'products' collection from the given Chroma client, creating it if absent.

    The collection metadata requests cosine distance for the HNSW index.
    """
    collection_name = 'products'
    index_metadata = {'hnsw:space': 'cosine'}
    return client.get_or_create_collection(name=collection_name, metadata=index_metadata)
def connect_to_product_db(args) -> None:
    """
    Point the DataLoader at the requested product DB so its products are loaded.

    If that DB is already the active one, its data is simply reloaded; otherwise
    the DataLoader is switched to the new DB name.

    Any failure raises and propagates to the command line, which is acceptable
    because this is an offline script, not part of the running app.
    """
    if DataLoader.active_db == args.products_db:
        DataLoader.load_data()
    else:
        DataLoader.set_db_name(args.products_db)
def get_out_vector_store_client(args, allow_reset: bool = False) -> chromadb.Client:
    """
    Open a persistent Chroma client on the output vector store directory.

    :param args: parsed command line arguments; ``args.out_vectors`` names the store
    :param allow_reset: whether the client is permitted to call ``reset()``
    :return: a ``chromadb.PersistentClient`` rooted at
        ``<data_dir>/vector_stores/<out_vectors>_chroma``
    """
    settings = Settings()
    settings.allow_reset = allow_reset
    store_path = os.path.join(data_dir, 'vector_stores', f"{args.out_vectors}_chroma")
    return chromadb.PersistentClient(path=store_path, settings=settings)
def prepare_to_vectorise(args) -> chromadb.Client:
    """
    Load the product DB and get the output vector store ready for adding products.

    With ``-overwrite`` the output store is reset (``in_vectors`` is ignored).
    Otherwise, if a different input store was named, it is copied to the output
    location first.

    :param args: parsed command line arguments
    :return: a Chroma client on the output vector store
    """
    connect_to_product_db(args)  # Do this first as non-destructive
    # Now do possibly destructive setup
    if args.overwrite:
        empty_or_create_out_vector_store(args)
        # The reset above used a client opened with allow_reset=True. Chroma shares
        # one system per persist path and (in chromadb 0.4.x) raises ValueError if a
        # second client for that path is opened with different settings, so the
        # returned client must be opened with the same settings.
        return get_out_vector_store_client(args, allow_reset=True)
    if need_to_copy_vector_store(args):
        copy_vector_store(args)
    return get_out_vector_store_client(args)
def document_for_product(product: Product) -> str:
    """
    Build the text document that is embedded for a product.

    The document is four space-separated parts: a category sentence, a price and
    rating sentence, a features sentence, then the product's own description.
    """
    category = product.category.singular_name
    name = product.name
    parts = [
        f"The {name} is a {category}.",
        f"It costs ${product.price} and is rated {product.average_rating} stars.",
        f"The {name} features {join_items_comma_and(product.features)}.",
        f"{product.description}",
    ]
    return " ".join(parts)
def vectorise(vector_client: chromadb.Client) -> None:
    """
    Add documents representing the products from the products database to the vector store.

    - Documents are built from the product's fields by document_for_product
    - IDs are "prod_{id from db}"; upsert updates any entry that already has that id
    - Metadata holds the product's category singular name

    When there are no products, only the count is reported. Chroma's id
    validation rejects an empty id list, so the upsert is skipped.
    """
    collection = get_product_collection_from_client(vector_client)
    products = Product.all_as_list()
    print(f"Vectorising {len(products)} products")
    if not products:
        return
    ids = [f"prod_{p.id}" for p in products]
    documents = [document_for_product(p) for p in products]
    metadata = [{'category': p.category.singular_name} for p in products]
    collection.upsert(ids=ids, documents=documents, metadatas=metadata)
def prepare_to_delete_vectors(args) -> chromadb.Client:
    """
    Load the product DB and get the output vector store ready for removing products.

    The output store is never reset here; it is only seeded from the input store
    when a different one was named.

    :param args: parsed command line arguments
    :return: a Chroma client on the output vector store
    """
    connect_to_product_db(args)  # non-destructive, so done before touching any store
    copy_needed = need_to_copy_vector_store(args)
    if copy_needed:
        copy_vector_store(args)  # possibly destructive setup
    return get_out_vector_store_client(args)
def delete_vectors(vector_client: chromadb.Client) -> None:
    """
    Remove the vectors for every product in the products database from the vector store.

    IDs are "prod_{id from db}", matching those written by vectorise.

    An empty product set is a no-op. Chroma treats an empty id list passed to
    ``Collection.delete`` as "no id filter" (observed in chromadb 0.4.x), which
    would delete every entry in the collection.
    """
    collection = get_product_collection_from_client(vector_client)
    products = Product.all_as_list()
    ids = [f"prod_{p.id}" for p in products]
    print(f"Deleting {len(ids)} products")
    if not ids:
        return
    collection.delete(ids=ids)
def train(args):
    """
    Run the requested operation against the vector store.

    With ``-delete`` the product set is removed from the output store; otherwise
    the products are vectorised into it.
    """
    if not args.delete:
        vectorise(prepare_to_vectorise(args))
        return
    delete_vectors(prepare_to_delete_vectors(args))
if __name__ == "__main__":
    # `time` is the module here (import time), so call one of its clocks.
    # perf_counter is monotonic and intended for measuring elapsed durations.
    start = time.perf_counter()
    train(args)
    end = time.perf_counter()
    print(f"Training took {end-start:.2f} seconds")