Spaces:
Sleeping
Sleeping
# Install required packages | |
import subprocess | |
import sys | |
try: | |
import google.generativeai as genai | |
except ImportError: | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "google-generativeai"]) | |
import google.generativeai as genai | |
import os | |
import json | |
from datasets import Dataset | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import numpy as np | |
import requests | |
from transformers import set_seed | |
import torch | |
import streamlit as st | |
# Configure Gemini API.
# The key is read from the environment (e.g. a Hugging Face Space secret)
# instead of being hard-coded: a key committed to source is readable by anyone
# who can see the repository. The key that used to be embedded here should be
# treated as leaked and revoked.
_gemini_api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not _gemini_api_key:
    raise RuntimeError(
        "Gemini API key not found. Set the GEMINI_API_KEY (or GOOGLE_API_KEY) "
        "environment variable before starting the app."
    )
genai.configure(api_key=_gemini_api_key)
# Initialize Gemini model (module-level so rag_qa can use it).
model = genai.GenerativeModel(model_name="gemini-1.5-flash")
# Set seed for reproducibility
set_seed(42)
# ----------------------------- | |
# 1. Data Loading and Preprocessing | |
# ----------------------------- | |
def load_json_files(uploaded_files):
    """
    Load and preprocess JSON files from the uploaded files.

    Each JSON object becomes one document. Its ``text`` joins one
    ``"SECTION: value"`` line per drug-label section below (a missing section
    is rendered with an empty value), followed by a ``PRODUCT NAME`` line.
    A file may contain a single JSON object or a list of objects; each object
    in a list becomes its own document. A file that cannot be processed is
    reported with ``st.error`` and contributes no documents at all.

    Args:
        uploaded_files: iterable of readable file-like objects, e.g. the
            Streamlit ``UploadedFile`` objects from ``st.file_uploader``.

    Returns:
        datasets.Dataset: columns ``text`` and ``product_name``, one row per
        document. When nothing could be loaded the dataset has zero rows but
        still carries both columns.
    """
    # Sections whose JSON key is also the label used in the document text,
    # in output order.
    section_keys = (
        'DESCRIPTION',
        'CLINICAL PHARMACOLOGY',
        'INDICATIONS AND USAGE',
        'CONTRAINDICATIONS',
        'WARNINGS',
        'PRECAUTIONS',
        'ADVERSE REACTIONS',
        'OVERDOSAGE',
        'DOSAGE AND ADMINISTRATION',
        'HOW SUPPLIED',
        'PACKAGE LABEL.PRINCIPAL DISPLAY PANEL',
        'INGREDIENTS AND APPEARANCE',
    )

    def to_record(obj):
        # Combine all relevant fields into a single text blob.
        lines = [f"{key}: {obj.get(key, '')}" for key in section_keys]
        lines.append(f"PRODUCT NAME: {obj.get('product_name', '')}")
        return {"text": "\n".join(lines), "product_name": obj.get("product_name", "")}

    data = []
    for uploaded_file in uploaded_files:
        file_name = getattr(uploaded_file, "name", repr(uploaded_file))
        try:
            # Rewind in case this file object was already read (e.g. on an
            # earlier Streamlit rerun); json.load on an exhausted stream fails.
            if hasattr(uploaded_file, "seek"):
                uploaded_file.seek(0)
            json_data = json.load(uploaded_file)
            objects = json_data if isinstance(json_data, list) else [json_data]
            # Collect per file so a bad entry does not leave half a file loaded.
            file_records = []
            for obj in objects:
                if not isinstance(obj, dict):
                    raise TypeError(f"expected a JSON object, got {type(obj).__name__}")
                file_records.append(to_record(obj))
            data.extend(file_records)
        except json.JSONDecodeError as e:
            st.error(f"Error decoding JSON from file {file_name}: {e}")
        except Exception as e:
            st.error(f"Unexpected error processing file {file_name}: {e}")

    if not data:
        # Dataset.from_list([]) would produce a dataset with no columns at all;
        # keep the expected schema so downstream column access is well-defined.
        return Dataset.from_dict({"text": [], "product_name": []})
    return Dataset.from_list(data)
# ----------------------------- | |
# 2. Generating Embeddings | |
# ----------------------------- | |
@st.cache_resource
def load_embedding_model(model_name='all-MiniLM-L6-v2'):
    """Load the SentenceTransformer used to embed documents and queries.

    Wrapped in ``st.cache_resource`` so Streamlit loads the weights once and
    reuses the same model object across script reruns and sessions, instead of
    reloading it on every user interaction (Streamlit re-executes the whole
    script on each interaction).

    Args:
        model_name (str): sentence-transformers model id. Defaults to
            ``'all-MiniLM-L6-v2'``; any other model can be passed instead.

    Returns:
        SentenceTransformer: the loaded embedding model.
    """
    return SentenceTransformer(model_name)
def generate_embeddings(dataset, embedding_model, batch_size=16):
    """Return a copy of *dataset* with an ``embeddings`` column added.

    Every row's ``text`` field is encoded with *embedding_model*. Encoding runs
    in batches through ``Dataset.map`` so a whole upload is never passed to
    ``encode`` in one call.

    Args:
        dataset (datasets.Dataset): rows with a ``text`` field.
        embedding_model: a SentenceTransformer (or any object whose
            ``encode(list_of_str, convert_to_numpy=True)`` returns a 2-D array).
        batch_size (int): number of rows handed to ``encode`` per map batch.

    Returns:
        datasets.Dataset: the mapped dataset with one embedding vector per row.
    """
    def encode_batch(batch):
        # Request NumPy output directly. Building a torch tensor and then
        # calling .cpu().numpy() yields the same array with extra copies.
        vectors = embedding_model.encode(batch['text'], convert_to_numpy=True)
        return {'embeddings': vectors}

    return dataset.map(encode_batch, batched=True, batch_size=batch_size)
# ----------------------------- | |
# 3. Building the FAISS Vector Store | |
# ----------------------------- | |
def build_faiss_index(dataset):
    """Build an exact (brute-force) L2 FAISS index over the dataset's embeddings.

    Vectors are added in row order, so the i-th vector in the index is the
    i-th row of *dataset*. That is what lets search results be mapped back to
    documents by position.

    Args:
        dataset: object whose ``'embeddings'`` column yields one vector per
            row, normally the output of ``generate_embeddings``.

    Returns:
        faiss.IndexFlatL2: index holding every embedding as float32.

    Raises:
        ValueError: if the column is empty. The embedding dimension cannot be
            inferred from zero rows.
        KeyError: if *dataset* has no ``'embeddings'`` column.
    """
    # list() materialises the column regardless of the container type that
    # column access returns.
    rows = list(dataset['embeddings'])
    if not rows:
        raise ValueError("Cannot build a FAISS index: the dataset contains no embeddings.")
    # FAISS needs a C-contiguous float32 matrix of shape (n_rows, dim).
    embeddings = np.ascontiguousarray(np.vstack(rows), dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])  # Using L2 distance
    index.add(embeddings)
    return index
# ----------------------------- | |
# 4. Defining the RAG QA Function | |
# ----------------------------- | |
def _retrieve_documents(query, embedding_model, index, dataset, top_k):
    """Return the ``text`` of up to *top_k* documents nearest to *query*, best match first."""
    # Never request more neighbours than the index holds. FAISS pads missing
    # result slots with -1, and dataset[-1] resolves to the last row instead of
    # failing, which silently duplicated that document in the context.
    k = min(int(top_k), index.ntotal)
    if k <= 0:
        return []
    query_embedding = np.ascontiguousarray(embedding_model.encode([query]), dtype=np.float32)
    _, indices = index.search(query_embedding, k)
    return [dataset[int(i)]['text'] for i in indices[0] if i >= 0]


def _build_rag_prompt(query, documents):
    """Assemble the Gemini prompt: instructions, retrieved context, then the question."""
    return (
        "You are a knowledgeable assistant specialized in pharmaceuticals. "
        "Use the following information to answer the question accurately.\n\n"
        "Context:\n"
        + "\n\n".join(documents)
        + "\n\nQuestion: " + query
        + "\nAnswer:"
    )


def rag_qa(query, embedding_model, index, dataset, top_k=5):
    """
    Perform question answering using Retrieval-Augmented Generation (RAG).

    The query is embedded and the nearest documents are retrieved from the
    FAISS index. Their texts are placed in a prompt for the module-level
    Gemini ``model``.

    Args:
        query (str): The user's question.
        embedding_model: The SentenceTransformer embedding model (must match
            the one used to build *index*).
        index: The FAISS index; vector i corresponds to row i of *dataset*.
        dataset: The dataset containing documents (rows with a ``text`` field).
        top_k (int): Maximum number of documents to retrieve. It is capped at
            the number of indexed vectors.

    Returns:
        str: The generated answer. If no documents are available, a notice
        saying so. If any error occurs, it is shown with ``st.error`` and an
        apology message is returned; no exception propagates.
    """
    try:
        retrieved_docs = _retrieve_documents(query, embedding_model, index, dataset, top_k)
        if not retrieved_docs:
            return "I couldn't find any documents to answer your question. Please upload JSON files first."
        response = model.generate_content(_build_rag_prompt(query, retrieved_docs))
        answer = response.text.strip()
        # The prompt ends with "Answer:". Drop that label only when the model
        # echoes it at the start. Splitting on every "Answer:" would discard
        # any answer text that precedes a later occurrence of the word.
        if answer.startswith("Answer:"):
            answer = answer[len("Answer:"):].strip()
        return answer
    except Exception as e:
        st.error(f"Error during RAG QA: {e}")
        return "I'm sorry, I couldn't process your request at the moment."
# ----------------------------- | |
# 5. Streamlit Interface | |
# ----------------------------- | |
def main(): | |
st.title("Retrieval-Augmented Generation (RAG) QA Agent with Gemini API") | |
st.sidebar.header("Upload JSON Files") | |
uploaded_files = st.sidebar.file_uploader("Upload JSON files", type=["json"], accept_multiple_files=True) | |
if uploaded |