import os
import gradio as gr
import torch
from transformers import pipeline, BitsAndBytesConfig
from datasets import load_dataset
import pandas as pd
from PIL import Image
from typing import Optional
from pathlib import Path
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders.dataframe import DataFrameLoader
from langchain_text_splitters import CharacterTextSplitter
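
# Dependencies implied by these imports (a plausible requirements set for this
# Space; exact pins are an assumption): torch, transformers, accelerate,
# bitsandbytes, gradio, datasets, pandas, pillow, langchain-community,
# langchain-text-splitters, chromadb, sentence-transformers, pypdf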

# ---------- Configuration ----------
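# MODEL_VARIANT picks the checkpoint suffix; published MedGemma variants
# include "4b-it", "4b-pt", and "27b-text-it" (treat this list as an
# assumption to verify on the Hub).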
MODEL_VARIANT = os.environ.get("MODEL_VARIANT", "4b-it")
MODEL_ID = f"google/medgemma-{MODEL_VARIANT}"
USE_QUANTIZATION = True
LOCAL_DOCS_PATH = Path("./medical/hb_db")
CHROMA_PERSIST_DIR = "./chroma_db"

_pipe = None
_rag_vectorstore = None
_embeddings = None

# ---------- Lazy initialization helpers ----------
def _init_pipeline():
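    """Create the MedGemma pipeline once and cache it for subsequent calls."""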
    global _pipe
    if _pipe is not None:
        return _pipe
    # Extra kwargs forwarded to from_pretrained via the pipeline's model_kwargs
    model_kwargs = {}
    if USE_QUANTIZATION and torch.cuda.is_available():
        # bitsandbytes 4-bit loading requires a CUDA GPU, so gate on GPU
        # availability rather than failing at model load time on CPU-only hosts
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
    # Choose the pipeline task by variant: text-only MedGemma checkpoints
    # (e.g. "27b-text-it") carry "text" in their name; the other
    # instruction-tuned checkpoints are multimodal
    task = "text-generation" if "text" in MODEL_VARIANT else "image-text-to-text"
    print(f"Initializing pipeline: {MODEL_ID} task={task}")
    _pipe = pipeline(
        task,
        model=MODEL_ID,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        model_kwargs=model_kwargs,  # carries quantization_config when enabled
    )
    try:
        # Prefer deterministic decoding; not every checkpoint exposes a
        # generation_config, hence the broad except
        _pipe.model.generation_config.do_sample = False
    except Exception:
        pass
    return _pipe
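
# Note: the google/medgemma-* checkpoints are gated on the Hugging Face Hub;
# the Space needs a token with accepted access terms (e.g. an HF_TOKEN secret).
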
def _init_rag():
    """Build or load a Chroma vectorstore from local files; runs lazily on first request."""
    global _rag_vectorstore, _embeddings
    if _rag_vectorstore is not None:
        return _rag_vectorstore
    docs = []
    # 1) Load a Hugging Face dataset (if available) and convert it to a DataFrame
    try:
        ds = load_dataset("knowrohit07/know_medical_dialogue_v2")
        df = ds["train"].to_pandas()
        if "instruction" in df.columns and "output" in df.columns:
            df["full_dialogue"] = df["instruction"].astype(str) + " \n\n" + df["output"].astype(str)
            loader = DataFrameLoader(df, page_content_column="full_dialogue")
            docs += loader.load()
    except Exception as e:
        print("Warning: could not load HF dataset:", e)
    # 2) Load the local CSV if present
    csv_path = LOCAL_DOCS_PATH / "Final_Dataset.csv"
    if csv_path.exists():
        try:
            csv_loader = CSVLoader(str(csv_path))
            docs += csv_loader.load()
        except Exception as e:
            print("Warning loading CSV:", e)
    # 3) Load any PDFs found in the directory
    if LOCAL_DOCS_PATH.exists() and LOCAL_DOCS_PATH.is_dir():
        for pdf_file in LOCAL_DOCS_PATH.glob("*.pdf"):
            try:
                pdf_loader = PyPDFLoader(str(pdf_file))
                docs += pdf_loader.load()
            except Exception as e:
                print(f"Warning loading PDF {pdf_file}: {e}")
    # 4) If there are still no docs, create a placeholder document
    if len(docs) == 0:
        from langchain.schema import Document
        docs = [Document(page_content="No local documents found. Upload PDFs/CSV into ./medical/hb_db or commit them to the Space repo.")]
    # 5) Split into overlapping chunks
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(docs)
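    # CharacterTextSplitter counts characters (not tokens) and only splits on
    # its separator ("\n\n" by default), so a chunk can exceed chunk_size when
    # a single paragraph is longer than 1000 characters.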
    # 6) Embeddings and Chroma vectorstore
    try:
        _embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        _rag_vectorstore = Chroma.from_documents(chunks, _embeddings, persist_directory=CHROMA_PERSIST_DIR)
        try:
            # Older langchain/Chroma combinations need an explicit persist();
            # newer Chroma versions persist automatically, so tolerate either
            _rag_vectorstore.persist()
        except Exception:
            pass
    except Exception as e:
        print("Error initializing vectorstore:", e)
        _rag_vectorstore = None
    return _rag_vectorstore
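
# Note: this rebuilds the index on every cold start. If CHROMA_PERSIST_DIR
# already holds an index, reloading it with Chroma(persist_directory=...,
# embedding_function=...) instead of re-embedding everything would be cheaper;
# that optimization is only sketched in this comment, not wired in.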

# ---------- Main RAG + generation function ----------
def generate_medgemma_rag_response(query: str, image: Optional[Image.Image] = None) -> str:
    """Generate an answer using RAG + MedGemma, lazily initializing heavy resources."""
    # Ensure the RAG vectorstore is initialized
    vs = _init_rag()
    # Retrieve relevant docs if the vectorstore exists
| context = "" | |
| if vs is not None: | |
| try: | |
| retrieved = vs.similarity_search(query, k=4) | |
| context = "\n\n".join([d.page_content for d in retrieved]) | |
| except Exception as e: | |
| print("Warning during similarity search:", e) | |
| # Construct prompt | |
| rag_prompt = f"You are a respectful, medical AI assistant. Use the provided context and your knowledge to answer and be clear when uncertain.\n\nContext:\n{context}\n\nUser Question: {query}\n\nAnswer:\n" | |
    # Initialize the pipeline lazily
    pipe = _init_pipeline()
    # Build the pipeline input. Multimodal ("image-text-to-text") pipelines
    # expect the chat "messages" format; text-generation pipelines take a
    # plain string prompt.
    if image is not None or pipe.task == "image-text-to-text":
        content = [{"type": "text", "text": rag_prompt}]
        if image is not None:
            content.insert(0, {"type": "image", "image": image})
        messages = [{"role": "user", "content": content}]
        try:
            out = pipe(text=messages, max_new_tokens=512)
        except Exception:
            # Fall back to a plain text prompt if the chat-format call fails
            out = pipe(rag_prompt, max_new_tokens=512)
    else:
        out = pipe(rag_prompt, max_new_tokens=512)
    # Normalize the output: text pipelines return [{"generated_text": str}],
    # chat pipelines return [{"generated_text": [message dicts]}]
    try:
        if isinstance(out, list) and len(out) > 0:
            if isinstance(out[0], dict):
                text = out[0].get("generated_text") or out[0].get("text") or str(out[0])
                if isinstance(text, list):
                    # Chat-style output: take the last (assistant) message content
                    text = text[-1].get("content", str(text))
            else:
                text = str(out[0])
        else:
            text = str(out)
    except Exception:
        text = str(out)
    return text
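
# A hypothetical smoke test (assumes this file is app.py and that model weights
# plus any required HF auth are available):
#   python -c "from app import generate_medgemma_rag_response as ask; print(ask('What are common symptoms of anemia?'))"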

# ---------- Gradio UI ----------
with gr.Blocks() as iface:
    chatbot = gr.Chatbot(label="Ayaresa chat")
    with gr.Row():
        with gr.Column(scale=3):
            txt = gr.Textbox(label="Enter a prompt", placeholder="Type your question here...", lines=2)
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="Image (optional)")
    with gr.Row():
        send = gr.Button("Send")
        clear = gr.Button("Clear")
    # Keep conversation state explicitly
    state = gr.State([])
    def submit_fn(message, image, history):
        history = history or []
        if (not message or message.strip() == "") and image is None:
            return history, "", history
        resp = generate_medgemma_rag_response(message or "", image)
        history.append((message or "", resp))
        return history, "", history

    send.click(submit_fn, inputs=[txt, img, state], outputs=[chatbot, txt, state])
    txt.submit(submit_fn, inputs=[txt, img, state], outputs=[chatbot, txt, state])
    clear.click(lambda: ([], "", []), inputs=None, outputs=[chatbot, txt, state])

if __name__ == "__main__":
    iface.launch()