import os

import faiss
import pandas as pd
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ---------------------------
# Configuration & Constants
# ---------------------------
CSV_PATH = r'complicated 21010.csv'
FAISS_INDEX_PATH = 'faiss_index.bin'
MODEL_DIR = r'model fine_tuned 1' # Ensure the path is correct
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOP_K = 5
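# Note: CSV_PATH must point to a file with a "Plan Name" column (used below
# for retrieval), and MODEL_DIR to a seq2seq checkpoint saved with
# save_pretrained. FAISS_INDEX_PATH is reserved for optional index
# persistence (see the sketch after the index is built).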
# ---------------------------
# Load Model & FAISS Index
# ---------------------------
st.title("Health Insurance RAG System")
df = pd.read_csv(CSV_PATH)
texts = df["Plan Name"].astype(str).tolist()  # Simplified text retrieval; the cast guards against NaN entries
st.write("Loading SentenceTransformer...")
embedder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
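# Streamlit re-runs this script on every widget interaction, so the embedder
# is reloaded each time. A minimal caching sketch (assumes Streamlit >= 1.18,
# where st.cache_resource is available):
#
#   @st.cache_resource
#   def load_embedder():
#       return SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
#
#   embedder = load_embedder()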
st.write("Building FAISS index...")
index = faiss.IndexFlatL2(384)
embeddings = embedder.encode(texts, convert_to_numpy=True)
index.add(embeddings)
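# FAISS_INDEX_PATH is defined above but unused: the index is rebuilt from
# scratch on every run. A sketch of optional persistence using FAISS's
# standard serialization helpers (assumes the working directory is writable):
#
#   if os.path.exists(FAISS_INDEX_PATH):
#       index = faiss.read_index(FAISS_INDEX_PATH)
#   else:
#       index.add(embeddings)
#       faiss.write_index(index, FAISS_INDEX_PATH)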
# ---------------------------
# Load Tokenizer & Model with Error Handling
# ---------------------------
st.write("Loading Tokenizer and Model...")
try:
    if os.path.exists(MODEL_DIR):
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR).to(DEVICE)
    else:
        st.warning("Local model not found. Using a pre-trained Hugging Face model.")
        MODEL_NAME = "facebook/bart-large"  # Change as needed
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
except Exception as e:
    st.error(f"Error loading model: {e}")
    st.stop()
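# Caveat: facebook/bart-large is a plain denoising checkpoint, not tuned for
# question answering, so the fallback tends to paraphrase the prompt rather
# than answer it; the fine-tuned model in MODEL_DIR is the intended path.
# The st.cache_resource pattern sketched above for the embedder applies to
# the tokenizer/model pair as well.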
# ---------------------------
# Query Processing
# ---------------------------
user_query = st.text_input("Enter your query:")
if user_query:
    query_embedding = embedder.encode([user_query], convert_to_numpy=True)
    _, indices = index.search(query_embedding, TOP_K)
    candidate_texts = [texts[i] for i in indices[0]]
    selected_context = st.selectbox("Select a retrieved context:", candidate_texts)
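    # index.search also returns squared L2 distances as its first value
    # (discarded as "_" above). A sketch for surfacing them alongside each
    # candidate (illustrative formatting, not part of the original UI):
    #
    #   distances, indices = index.search(query_embedding, TOP_K)
    #   for dist, i in zip(distances[0], indices[0]):
    #       st.write(f"{texts[i]} (L2 distance: {dist:.3f})")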
    if st.button("Generate Answer", key="generate_button"):
        input_text = f"question: {user_query} context: {selected_context}"
        inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True).to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(inputs, max_length=150)
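        # generate() defaults to greedy decoding; beam search often yields
        # more fluent answers. A sketch with illustrative parameters:
        #
        #   outputs = model.generate(inputs, max_length=150, num_beams=4,
        #                            early_stopping=True)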
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.write("### Generated Answer:")
        st.write(answer)
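# To run the app locally (assumes the dependencies above are installed):
#   streamlit run app.py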