# Health Insurance RAG demo (Streamlit).
# Run with: streamlit run <this_script>.py

import os

import faiss
import pandas as pd
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ---------------------------
# Configuration & Constants
# ---------------------------
CSV_PATH = r'complicated 21010.csv'
FAISS_INDEX_PATH = 'faiss_index.bin'  # declared for index persistence (see sketch at the end)
MODEL_DIR = r'model fine_tuned 1'  # Ensure the path is correct
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOP_K = 5

# ---------------------------
# Load Data & Build FAISS Index
# ---------------------------
st.title("Health Insurance RAG System")

df = pd.read_csv(CSV_PATH)
texts = df["Plan Name"].tolist()  # Simplified retrieval corpus: plan names only

st.write("Loading SentenceTransformer...")
embedder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)

st.write("Building FAISS index...")
embeddings = embedder.encode(texts, convert_to_numpy=True)
index = faiss.IndexFlatL2(embeddings.shape[1])  # all-MiniLM-L6-v2 yields 384-dim vectors
index.add(embeddings)

# ---------------------------
# Load Tokenizer & Model with Error Handling
# ---------------------------
st.write("Loading Tokenizer and Model...")
try:
    if os.path.exists(MODEL_DIR):
        # Prefer the locally fine-tuned model when it exists.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR).to(DEVICE)
    else:
        st.warning("Local model not found. Using a pre-trained Hugging Face model.")
        MODEL_NAME = "facebook/bart-large"  # Change as needed
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
    model.eval()  # inference only
except Exception as e:
    st.error(f"Error loading model: {e}")
    st.stop()

# ---------------------------
# Query Processing
# ---------------------------
user_query = st.text_input("Enter your query:")

if user_query:
    # Embed the query and retrieve the TOP_K nearest plan names.
    query_embedding = embedder.encode([user_query], convert_to_numpy=True)
    _, indices = index.search(query_embedding, TOP_K)
    candidate_texts = [texts[i] for i in indices[0]]

    selected_context = st.selectbox("Select a retrieved context:", candidate_texts)

    if st.button("Generate Answer", key="generate_button"):
        # T5-style "question: ... context: ..." prompt for the seq2seq model.
        input_text = f"question: {user_query} context: {selected_context}"
        inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True).to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(inputs, max_length=150)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.write("### Generated Answer:")
        st.write(answer)
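
# ---------------------------
# Optional: Persisting the FAISS Index (illustrative sketch, not wired in)
# ---------------------------
# FAISS_INDEX_PATH is declared above but never used. A minimal sketch of one
# way to use it, so the index is not rebuilt on every Streamlit rerun: serialize
# with faiss.write_index / faiss.read_index (the standard FAISS calls) and cache
# the result with st.cache_resource. The helper name `load_or_build_index` and
# the caching decorator are assumptions, not part of the original script.
@st.cache_resource
def load_or_build_index(plan_texts: tuple) -> "faiss.Index":
    """Load the index from disk if present; otherwise build and persist it."""
    if os.path.exists(FAISS_INDEX_PATH):
        return faiss.read_index(FAISS_INDEX_PATH)
    vecs = embedder.encode(list(plan_texts), convert_to_numpy=True)
    idx = faiss.IndexFlatL2(vecs.shape[1])
    idx.add(vecs)
    faiss.write_index(idx, FAISS_INDEX_PATH)
    return idx

# Usage (would replace the inline index build above; tuple() keeps the
# argument hashable for st.cache_resource):
# index = load_or_build_index(tuple(texts))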