import os
import streamlit as st
import pandas as pd
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ---------------------------
# Configuration & Constants
# ---------------------------
CSV_PATH = r'complicated 21010.csv'
FAISS_INDEX_PATH = 'faiss_index.bin'
MODEL_DIR = r'model fine_tuned 1'  # Ensure the path is correct
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOP_K = 5

# ---------------------------
# Load Data, Embedder & FAISS Index
# ---------------------------
st.title("Health Insurance RAG System")
df = pd.read_csv(CSV_PATH)
texts = df["Plan Name"].tolist()  # Simplified text retrieval: embed plan names only

st.write("Loading SentenceTransformer...")
embedder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
st.write("Building FAISS index...") | |
index = faiss.IndexFlatL2(384) | |
embeddings = embedder.encode(texts, convert_to_numpy=True) | |
index.add(embeddings) | |
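
# FAISS_INDEX_PATH is declared above but otherwise unused; a minimal persistence
# sketch, assuming the CSV (and hence the embeddings) is static between runs.
# Writing the index once means a later run could faiss.read_index() the file
# instead of re-encoding every plan name.
if not os.path.exists(FAISS_INDEX_PATH):
    faiss.write_index(index, FAISS_INDEX_PATH)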

# ---------------------------
# Load Tokenizer & Model with Error Handling
# ---------------------------
st.write("Loading Tokenizer and Model...")
try:
    if os.path.exists(MODEL_DIR):
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR).to(DEVICE)
    else:
        st.warning("Local model not found. Using a pre-trained Hugging Face model.")
        MODEL_NAME = "facebook/bart-large"  # Change as needed
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
    model.eval()  # inference only: disable dropout
except Exception as e:
    st.error(f"Error loading model: {e}")
    st.stop()
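
# Note: the "question: ... context: ..." prompt used in the query section
# follows the T5 SQuAD convention. A base facebook/bart-large checkpoint is
# not tuned for that format, so an instruction-tuned fallback such as
# "google/flan-t5-base" may yield more usable answers.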

# ---------------------------
# Query Processing
# ---------------------------
user_query = st.text_input("Enter your query:")

if user_query:
    query_embedding = embedder.encode([user_query], convert_to_numpy=True)
    _, indices = index.search(query_embedding, TOP_K)
    # FAISS pads results with -1 when fewer than TOP_K vectors are indexed.
    candidate_texts = [texts[i] for i in indices[0] if i != -1]
    selected_context = st.selectbox("Select a retrieved context:", candidate_texts)
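
    # Optional sketch: surface the raw L2 distances as rough relevance scores
    # (lower = closer match). Re-runs the same search to stay self-contained.
    distances, _ = index.search(query_embedding, TOP_K)
    st.caption(f"L2 distances of top matches: {distances[0].round(3).tolist()}")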
    if st.button("Generate Answer", key="generate_button"):
        # T5-style QA prompt pairing the question with the chosen context.
        input_text = f"question: {user_query} context: {selected_context}"
        inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True).to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(inputs, max_length=150)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.write("### Generated Answer:")
        st.write(answer)
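
# Usage: save this script (e.g. as app.py; filename assumed) and launch with
#   streamlit run app.py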