import os
import json
import re
from multiprocessing import Pool, cpu_count

import numpy as np
import faiss
import gradio as gr
import nltk
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import WordNetLemmatizer

# Download only the NLTK resources this script uses; nltk.download("all")
# works but fetches gigabytes of unneeded data.
nltk.download("punkt")
nltk.download("stopwords")
nltk.download("wordnet")

# Load stopwords and lemmatizer once at module level
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()

# Sentence-embedding model, loaded once at import time; reloading it inside
# every retrieval call would add seconds of latency to each Gradio request.
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


def load_and_preprocess_dataset():
    """Load the MedRAG textbooks dataset."""
    dataset = load_dataset("MedRAG/textbooks")
    print("Dataset loaded successfully.")
    return dataset


def preprocess_text(text):
    """Lowercase, strip special characters, remove stopwords, and lemmatize."""
    text = text.lower()  # Convert to lowercase
    text = re.sub(r"[^\w\s]", "", text)  # Remove special characters
    words = word_tokenize(text)  # Tokenization
    words = [lemmatizer.lemmatize(w) for w in words if w not in stop_words]  # Lemmatization & stopword removal
    return " ".join(words)


def chunk_text(text, chunk_size=3):
    """Split text into chunks of `chunk_size` sentences."""
    sentences = sent_tokenize(text)
    return [" ".join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)]


def generate_embeddings_parallel(chunks):
    """Generate embeddings for text chunks in parallel across CPU cores."""
    # Note: on a GPU, a single batched call -- embed_model.encode(chunks) --
    # is usually faster and avoids CUDA/multiprocessing conflicts.
    with Pool(cpu_count()) as pool:
        embeddings = pool.map(embed_model.encode, chunks)
    return embeddings


def generate_embeddings(dataset):
    """Preprocess, chunk, and embed the dataset.

    Returns the dataset, the flattened chunk list, and one embedding per
    chunk (row-aligned with the chunk list).
    """
    print("Preprocessing dataset...")
    dataset = dataset.map(lambda row: {"cleaned_content": preprocess_text(row["content"])})
    dataset = dataset.map(lambda row: {"chunks": chunk_text(row["cleaned_content"])})

    print("Generating embeddings...")
    # Flatten chunks across documents. The FAISS index is built at chunk
    # level, so the embeddings must stay aligned with this flattened list;
    # mapping embeddings back onto dataset rows by row index would misalign
    # them, since each row contributes several chunks.
    all_chunks = [chunk for row in dataset["train"]["chunks"] for chunk in row]
    embeddings = generate_embeddings_parallel(all_chunks)
    return dataset, all_chunks, np.asarray(embeddings, dtype=np.float32)


def create_faiss_index(embeddings_np):
    """Create and save a FAISS index (exact L2 search) over the chunk embeddings."""
    index = faiss.IndexFlatL2(embeddings_np.shape[1])
    index.add(embeddings_np)
    faiss.write_index(index, "faiss_medical.index")
    print("FAISS index created and saved.")


def load_faiss_index():
    """Load the FAISS index from disk."""
    index = faiss.read_index("faiss_medical.index")
    print("FAISS index loaded.")
    return index


def retrieve_medical_summary(query, index, id_to_text, k=3):
    """Retrieve the k most relevant chunks of medical literature from FAISS."""
    query_embedding = embed_model.encode([query])
    D, I = index.search(np.array(query_embedding).astype("float32"), k)  # distances, chunk ids
    retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]
    return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."
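
# A minimal smoke test of the retrieval step in isolation: a hypothetical
# helper (not called anywhere in this script) that builds a tiny three-chunk
# index and queries it, illustrating how IndexFlatL2 ids map back to chunk text.
def _retrieval_smoke_test():
    chunks = [
        "aspirin inhibits platelet aggregation",
        "insulin lowers blood glucose",
        "statins reduce ldl cholesterol",
    ]
    vecs = np.asarray(embed_model.encode(chunks), dtype=np.float32)
    demo_index = faiss.IndexFlatL2(vecs.shape[1])
    demo_index.add(vecs)
    demo_map = {i: c for i, c in enumerate(chunks)}
    # Should print the aspirin chunk first, since it is nearest in embedding space
    print(retrieve_medical_summary("How does aspirin work?", demo_index, demo_map, k=2))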

def generate_medical_answer_groq(query, index, id_to_text):
    """Generate a medical answer with Groq's API, grounded in retrieved literature."""
    retrieved_summary = retrieve_medical_summary(query, index, id_to_text)
    if not retrieved_summary or retrieved_summary == "No relevant data found.":
        return "No relevant medical data found. Please consult a healthcare professional."

    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
                {
                    "role": "user",
                    "content": (
                        "Summarize the following medical literature and provide a structured medical answer:\n\n"
                        f"### Medical Literature ###\n{retrieved_summary}\n\n"
                        f"### Patient Question ###\n{query}\n\n"
                        "### Medical Advice ###"
                    ),
                },
            ],
            max_tokens=500,
            temperature=0.3,  # Low temperature keeps the answer close to the retrieved text
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"


def ask_medical_question(question):
    """Gradio callback: answer a question using the global index and chunk mapping."""
    return generate_medical_answer_groq(question, index, id_to_text)


def main():
    """Build the index and id-to-text mapping, then launch the Gradio app."""
    global index, id_to_text

    # Load, preprocess, chunk, and embed the dataset
    dataset = load_and_preprocess_dataset()
    dataset, all_chunks, embeddings = generate_embeddings(dataset)

    # Build the FAISS index over the chunk embeddings, then reload it from disk
    create_faiss_index(embeddings)
    index = load_faiss_index()

    # FAISS search returns positions into the flattened chunk list, so the
    # mapping must be chunk-level (one string per FAISS id), not row-level.
    id_to_text = {idx: chunk for idx, chunk in enumerate(all_chunks)}
    with open("id_to_text.json", "w") as f:
        json.dump(id_to_text, f)

    # Launch the Gradio app
    iface = gr.Interface(
        fn=ask_medical_question,
        inputs=gr.Textbox(lines=2, placeholder="Enter your medical question here..."),
        outputs=gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here..."),
        title="Medical Question Answering System",
        description="Ask any medical question, and the AI will provide an answer based on medical literature.",
    )
    iface.launch()


if __name__ == "__main__":
    main()
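
# Usage sketch (the filename app.py is an assumption; the source does not name
# the script):
#
#   export GROQ_API_KEY="your-key"
#   python app.py
#
# Note: json.dump serializes the integer keys of id_to_text as strings, so a
# separate process reloading id_to_text.json must convert them back, e.g.
# {int(k): v for k, v in json.load(f).items()}, to match FAISS's integer ids.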