# src/rag_application.py
"""Retrieval-augmented question answering over product data.

Product records are flattened to text, embedded with a sentence-transformers
model, indexed in a FAISS vector store persisted on disk, and queried through
a LangChain ``RetrievalQA`` chain backed by a HuggingFace LLM wrapper.
"""
import os

from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from sentence_transformers import SentenceTransformer

from custom_llm import HuggingFaceLLMWrapper
from data_loader import load_all_data

DATA_DIR = "../datasets/microlabs_usa/"
PERSIST_DIRECTORY = "faiss_index"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "meta-llama/Llama-3.2-3B"  # Replace with your actual model name

# (label written into the document text, key looked up in the product record)
_PRODUCT_FIELDS = (
    ("Product Name", "Product_Name"),
    ("Usage", "Usage"),
    ("Composition", "Composition"),
    ("Warnings", "Warnings"),
    ("Dosage and Administration", "Dosage_and_Administration"),
    ("Side Effects", "Side_Effects"),
    ("Drug Interactions", "Drug_Interactions"),
)


class _SentenceTransformerEmbeddings:
    """Adapt a ``SentenceTransformer`` to what LangChain's FAISS store calls.

    The FAISS store either calls ``embed_documents`` / ``embed_query`` on the
    embedding object or calls the object itself with a query string. A raw
    ``SentenceTransformer`` supports neither, so this adapter provides all
    three entry points and always returns plain Python lists of floats.
    """

    def __init__(self, model_name=EMBEDDING_MODEL):
        self._model = SentenceTransformer(model_name)

    def embed_documents(self, texts, show_progress_bar=False):
        """Return one embedding (list of floats) per text in ``texts``."""
        vectors = self._model.encode(
            list(texts), show_progress_bar=show_progress_bar
        )
        return vectors.tolist()

    def embed_query(self, text):
        """Return the embedding (list of floats) of a single query string."""
        return self._model.encode(text).tolist()

    def __call__(self, text):
        """Embed a single query; used when the store treats this as a callable."""
        return self.embed_query(text)


def load_documents():
    """Load all products and render each one as a multi-line text document.

    Returns:
        A list of strings, one per product returned by ``load_all_data()``.
        Each string has one ``"<label>: <value>"`` line per entry in
        ``_PRODUCT_FIELDS``; a missing key yields an empty value.
    """
    return [
        "\n".join(
            f"{label}: {product.get(key, '')}" for label, key in _PRODUCT_FIELDS
        )
        for product in load_all_data()
    ]


def create_embeddings(documents):
    """Embed ``documents``, build a FAISS index from them and save it to disk.

    Args:
        documents: List of document strings to index.

    Returns:
        The FAISS vector store, also saved under ``PERSIST_DIRECTORY``.

    Raises:
        ValueError: If ``documents`` is empty, since an index cannot be built
            from zero vectors.
    """
    if not documents:
        raise ValueError("Cannot build a FAISS index from an empty document list.")

    embedder = _SentenceTransformerEmbeddings()
    vectors = embedder.embed_documents(documents, show_progress_bar=True)

    # from_embeddings expects (text, vector) pairs plus the embedding object
    # that will later be used to embed incoming queries.
    db = FAISS.from_embeddings(list(zip(documents, vectors)), embedder)
    db.save_local(PERSIST_DIRECTORY)
    return db


def load_embeddings_db():
    """Load the FAISS vector store previously saved in ``PERSIST_DIRECTORY``.

    Returns:
        The loaded FAISS vector store.

    Raises:
        FileNotFoundError: If ``PERSIST_DIRECTORY`` does not exist.
    """
    if not os.path.exists(PERSIST_DIRECTORY):
        raise FileNotFoundError(
            "FAISS index not found. Please run `create_embeddings` first."
        )
    # NOTE(review): load_local unpickles the stored docstore, so only load
    # indexes this application wrote itself. Recent LangChain releases also
    # appear to require allow_dangerous_deserialization=True here -- confirm
    # against the installed version before adding it.
    return FAISS.load_local(PERSIST_DIRECTORY, _SentenceTransformerEmbeddings())


def setup_qa_chain(device="cuda"):
    """Build a ``RetrievalQA`` chain over the persisted FAISS index.

    Args:
        device: Device string passed to ``HuggingFaceLLMWrapper``; use
            ``"cpu"`` when no GPU is available.

    Returns:
        A ``RetrievalQA`` chain of type ``"stuff"`` using the FAISS retriever.

    Raises:
        FileNotFoundError: If the FAISS index has not been created yet.
    """
    db = load_embeddings_db()
    llm = HuggingFaceLLMWrapper(model_name=LLM_MODEL, device=device)
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(),
    )


def main():
    """Build the index if it is missing, then answer one sample question."""
    # Only load and embed the data when there is no persisted index yet;
    # setup_qa_chain() loads the index from disk itself.
    if not os.path.exists(PERSIST_DIRECTORY):
        create_embeddings(load_documents())

    qa_chain = setup_qa_chain()

    # Test query
    query = "What is the composition of Paracetamol?"
    answer = qa_chain.run(query)
    print(f"Q: {query}\nA: {answer}")


if __name__ == "__main__":
    main()