# LLM_Hackathon / app.py
# Install required packages
import subprocess
import sys

try:
    import google.generativeai as genai
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "google-generativeai"])
    import google.generativeai as genai
import os
import json
from datasets import Dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import requests
from transformers import set_seed
import torch
import streamlit as st
# Configure Gemini API. The key is read from an environment variable
# ("GOOGLE_API_KEY" is an assumed name) instead of being hardcoded, so the
# secret is not committed to the repository.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# Initialize Gemini model
model = genai.GenerativeModel(model_name="gemini-1.5-flash")
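
# Generation parameters can also be tuned per call; a minimal illustrative
# sketch (the temperature/max_output_tokens values are assumptions, not from
# the original app):
# response = model.generate_content(
#     prompt,
#     generation_config=genai.GenerationConfig(temperature=0.2, max_output_tokens=512),
# )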
# Set seed for reproducibility (seeds the local random/numpy/torch RNGs;
# it does not affect the remote Gemini API)
set_seed(42)
# -----------------------------
# 1. Data Loading and Preprocessing
# -----------------------------
@st.cache_data
def load_json_files(uploaded_files):
    """
    Load and preprocess JSON files from the uploaded files.
    Combines relevant fields into a single text blob for each document.
    """
    data = []
    for uploaded_file in uploaded_files:
        try:
            json_data = json.load(uploaded_file)
            # Combine all relevant fields into a single text blob
            text = "\n".join([
                f"DESCRIPTION: {json_data.get('DESCRIPTION', '')}",
                f"CLINICAL PHARMACOLOGY: {json_data.get('CLINICAL PHARMACOLOGY', '')}",
                f"INDICATIONS AND USAGE: {json_data.get('INDICATIONS AND USAGE', '')}",
                f"CONTRAINDICATIONS: {json_data.get('CONTRAINDICATIONS', '')}",
                f"WARNINGS: {json_data.get('WARNINGS', '')}",
                f"PRECAUTIONS: {json_data.get('PRECAUTIONS', '')}",
                f"ADVERSE REACTIONS: {json_data.get('ADVERSE REACTIONS', '')}",
                f"OVERDOSAGE: {json_data.get('OVERDOSAGE', '')}",
                f"DOSAGE AND ADMINISTRATION: {json_data.get('DOSAGE AND ADMINISTRATION', '')}",
                f"HOW SUPPLIED: {json_data.get('HOW SUPPLIED', '')}",
                f"PACKAGE LABEL.PRINCIPAL DISPLAY PANEL: {json_data.get('PACKAGE LABEL.PRINCIPAL DISPLAY PANEL', '')}",
                f"INGREDIENTS AND APPEARANCE: {json_data.get('INGREDIENTS AND APPEARANCE', '')}",
                f"PRODUCT NAME: {json_data.get('product_name', '')}"
            ])
            data.append({
                "text": text,
                "product_name": json_data.get("product_name", "")
            })
        except json.JSONDecodeError as e:
            st.error(f"Error decoding JSON from file {uploaded_file.name}: {e}")
        except Exception as e:
            st.error(f"Unexpected error processing file {uploaded_file.name}: {e}")
    return Dataset.from_list(data)
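
# Minimal usage sketch (hypothetical values; the keys mirror those read above).
# Kept as comments so the app file stays side-effect free:
# import io
# sample = io.BytesIO(json.dumps({
#     "product_name": "ExampleDrug 10 mg tablets",
#     "DESCRIPTION": "...",
#     "INDICATIONS AND USAGE": "...",
# }).encode("utf-8"))
# sample.name = "example.json"    # load_json_files reports errors via the file's .name
# ds = load_json_files([sample])  # -> Dataset with 'text' and 'product_name' columns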
# -----------------------------
# 2. Generating Embeddings
# -----------------------------
@st.cache_resource
def load_embedding_model():
    embedding_model_name = 'all-MiniLM-L6-v2'  # You can choose a different model if needed
    embedding_model = SentenceTransformer(embedding_model_name)
    return embedding_model
@st.cache_data
def generate_embeddings(dataset, _embedding_model):
    # The leading underscore tells Streamlit's cache not to hash the model
    # argument (SentenceTransformer instances are not hashable by st.cache_data).
    def encode_batch(batch):
        embeddings = _embedding_model.encode(batch['text'], convert_to_tensor=True)
        return {'embeddings': embeddings.cpu().numpy()}
    # Apply the embedding function to the dataset in batches
    dataset = dataset.map(encode_batch, batched=True, batch_size=16)
    return dataset
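
# Sanity check (illustrative, not called by the app): 'all-MiniLM-L6-v2'
# produces 384-dimensional sentence embeddings, so each mapped row should now
# carry a 384-float 'embeddings' vector:
# assert len(dataset[0]['embeddings']) == 384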
# -----------------------------
# 3. Building the FAISS Vector Store
# -----------------------------
@st.cache_resource
def build_faiss_index(dataset):
    embeddings = np.vstack(dataset['embeddings']).astype('float32')
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)  # Using L2 distance
    index.add(embeddings)
    return index
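
# A possible alternative to L2 distance (not used by the app): cosine
# similarity via an inner-product index over L2-normalized vectors. A sketch,
# reusing the 'embeddings' column produced by generate_embeddings().
def build_faiss_cosine_index(dataset):
    embeddings = np.vstack(dataset['embeddings']).astype('float32')
    faiss.normalize_L2(embeddings)  # normalize in place so inner product equals cosine similarity
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index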
# -----------------------------
# 4. Defining the RAG QA Function
# -----------------------------
def rag_qa(query, embedding_model, index, dataset, top_k=5):
    """
    Perform question answering using Retrieval-Augmented Generation (RAG).

    Args:
        query (str): The user's question.
        embedding_model: The SentenceTransformer embedding model.
        index: The FAISS index.
        dataset: The dataset containing documents.
        top_k (int): Number of top documents to retrieve.

    Returns:
        str: The generated answer.
    """
    try:
        # Generate embedding for the query
        query_embedding = embedding_model.encode([query], convert_to_tensor=True).cpu().numpy().astype('float32')
        # Search the FAISS index
        distances, indices = index.search(query_embedding, top_k)
        # Retrieve the top_k documents
        retrieved_docs = [dataset[int(i)]['text'] for i in indices[0]]
        # Prepare the prompt for the Gemini API
        prompt = (
            "You are a knowledgeable assistant specialized in pharmaceuticals. "
            "Use the following information to answer the question accurately.\n\n"
            "Context:\n"
            + "\n\n".join(retrieved_docs) +
            "\n\nQuestion: " + query +
            "\nAnswer:"
        )
        # Call the Gemini API
        response = model.generate_content(prompt)
        # Extract the answer (strip any echoed "Answer:" prefix)
        answer = response.text.split("Answer:")[-1].strip()
        return answer
    except Exception as e:
        st.error(f"Error during RAG QA: {e}")
        return "I'm sorry, I couldn't process your request at the moment."
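
# Optional retrieval-only helper (a sketch, not wired into the UI): returns the
# top_k (product_name, L2 distance) pairs so retrieval quality can be inspected
# independently of the Gemini generation step.
def inspect_retrieval(query, embedding_model, index, dataset, top_k=5):
    query_embedding = embedding_model.encode([query], convert_to_tensor=True).cpu().numpy().astype('float32')
    distances, indices = index.search(query_embedding, top_k)
    return [(dataset[int(i)]['product_name'], float(d))
            for i, d in zip(indices[0], distances[0])]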
# -----------------------------
# 5. Streamlit Interface
# -----------------------------
def main():
    st.title("Retrieval-Augmented Generation (RAG) QA Agent with Gemini API")
    st.sidebar.header("Upload JSON Files")
    uploaded_files = st.sidebar.file_uploader("Upload JSON files", type=["json"], accept_multiple_files=True)
    # A minimal sketch of the remaining flow, assuming the pipeline defined
    # above (load -> embed -> index -> answer); UI strings are placeholders.
    if uploaded_files:
        dataset = load_json_files(uploaded_files)
        embedding_model = load_embedding_model()
        dataset = generate_embeddings(dataset, embedding_model)
        index = build_faiss_index(dataset)
        query = st.text_input("Ask a question about the uploaded drug labels:")
        if query:
            st.write(rag_qa(query, embedding_model, index, dataset))
    else:
        st.info("Please upload one or more JSON files to begin.")

if __name__ == "__main__":
    main()