import os
import sys
import gradio as gr
import torch
from transformers import pipeline, BitsAndBytesConfig
from datasets import load_dataset
import pandas as pd
from PIL import Image
from typing import Optional
from pathlib import Path
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders.dataframe import DataFrameLoader
from langchain_text_splitters import CharacterTextSplitter
# ---------- Configuration ----------
MODEL_VARIANT = os.environ.get("MODEL_VARIANT", "4b-it")
MODEL_ID = f"google/medgemma-{MODEL_VARIANT}"
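# Other published variants can be selected via the environment, e.g. MODEL_VARIANT=27b-text-it
# for the text-only 27B model (variant names are assumed to follow the google/medgemma-* naming)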
USE_QUANTIZATION = True
LOCAL_DOCS_PATH = Path("./medical/hb_db")
CHROMA_PERSIST_DIR = "./chroma_db"
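# Module-level caches: the heavy resources below are created lazily on first request and reused afterwards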
_pipe = None
_rag_vectorstore = None
_embeddings = None
# ---------- Lazy initialization helpers ----------
def _init_pipeline():
global _pipe
if _pipe is not None:
return _pipe
# Model kwargs
model_kwargs = dict(
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
)
    if USE_QUANTIZATION and torch.cuda.is_available():
        try:
            # 4-bit quantization via bitsandbytes needs a CUDA GPU, so it is skipped on CPU-only setups
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        except Exception:
            # bitsandbytes may be missing entirely; fall back to full precision
            pass
    # Multimodal variants use the image-text-to-text task; text-only variants
    # (those with "text" in the name, e.g. "27b-text-it") use plain text generation
    task = "text-generation" if "text" in MODEL_VARIANT else "image-text-to-text"
print(f"Initializing pipeline: {MODEL_ID} task={task}")
    _pipe = pipeline(
        task,
        model=MODEL_ID,
        device_map=model_kwargs.pop("device_map"),
        torch_dtype=model_kwargs.pop("torch_dtype"),
        # Any remaining entries (e.g. quantization_config) are forwarded to from_pretrained
        model_kwargs=model_kwargs,
    )
try:
_pipe.model.generation_config.do_sample = False
except Exception:
pass
return _pipe
def _init_rag():
"""Builds or loads a Chroma vectorstore from local files. This runs lazily on first request."""
global _rag_vectorstore, _embeddings
if _rag_vectorstore is not None:
return _rag_vectorstore
docs = []
# 1) Load a Hugging Face dataset (if available) — convert to a DataFrame
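    # The dataset is expected to expose "instruction" and "output" columns; the column check below guards against schema changes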
try:
ds = load_dataset("knowrohit07/know_medical_dialogue_v2")
        df = ds["train"].to_pandas()
if "instruction" in df.columns and "output" in df.columns:
df["full_dialogue"] = df["instruction"].astype(str) + " \n\n" + df["output"].astype(str)
loader = DataFrameLoader(df, page_content_column="full_dialogue")
docs += loader.load()
except Exception as e:
print("Warning: could not load HF dataset:", e)
# 2) Load local CSV if present
csv_path = LOCAL_DOCS_PATH / "Final_Dataset.csv"
if csv_path.exists():
try:
csv_loader = CSVLoader(str(csv_path))
docs += csv_loader.load()
except Exception as e:
print("Warning loading CSV:", e)
# 3) Load PDFs found in the directory
if LOCAL_DOCS_PATH.exists() and LOCAL_DOCS_PATH.is_dir():
for pdf_file in LOCAL_DOCS_PATH.glob("*.pdf"):
try:
pdf_loader = PyPDFLoader(str(pdf_file))
docs += pdf_loader.load()
except Exception as e:
print(f"Warning loading PDF {pdf_file}: {e}")
# 4) If still no docs, create a placeholder document
if len(docs) == 0:
from langchain.schema import Document
docs = [Document(page_content="No local documents found. Upload PDFs/CSV into ./medical/hb_db or commit them to the Space repo.")]
# 5) Split into chunks
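    # 1000-character chunks with 200 characters of overlap, so each retrieved passage keeps some surrounding context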
splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = splitter.split_documents(docs)
# 6) Embeddings and Chroma vectorstore
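    # all-MiniLM-L6-v2 is a small, general-purpose sentence-embedding model that runs acceptably on CPU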
try:
_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
_rag_vectorstore = Chroma.from_documents(chunks, _embeddings, persist_directory=CHROMA_PERSIST_DIR)
try:
_rag_vectorstore.persist()
except Exception:
pass
except Exception as e:
print("Error initializing vectorstore:", e)
_rag_vectorstore = None
return _rag_vectorstore
# ---------- Main RAG + generation function ----------
def generate_medgemma_rag_response(query: str, image: Optional[Image.Image] = None) -> str:
"""Generate an answer using RAG + MedGemma model. This function will lazily initialize heavy resources."""
# Ensure rag is initialized
vs = _init_rag()
# Retrieve relevant docs if vectorstore exists
context = ""
if vs is not None:
try:
retrieved = vs.similarity_search(query, k=4)
context = "\n\n".join([d.page_content for d in retrieved])
except Exception as e:
print("Warning during similarity search:", e)
# Construct prompt
    rag_prompt = f"You are a respectful medical AI assistant. Use the provided context and your own knowledge to answer the question, and state clearly when you are uncertain.\n\nContext:\n{context}\n\nUser Question: {query}\n\nAnswer:\n"
# Initialize pipeline lazily
pipe = _init_pipeline()
# Build input for the pipeline. The exact expected format can vary by pipeline task.
    if image is not None:
        # image-text-to-text pipelines accept chat-style messages mixing image and text content parts
        messages = [{"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": rag_prompt},
        ]}]
        try:
            out = pipe(text=messages, max_new_tokens=512)
        except Exception:
            # fall back to a plain text prompt if the multimodal call fails
            out = pipe(rag_prompt, max_new_tokens=512)
else:
out = pipe(rag_prompt, max_new_tokens=512)
    # Normalize output: generation pipelines return a list of dicts; chat-style calls
    # return the full message list, whose last entry is the assistant reply
    try:
        if isinstance(out, list) and len(out) > 0:
            first = out[0]
            if isinstance(first, dict):
                text = first.get("generated_text") or first.get("text") or str(first)
            else:
                text = str(first)
            # Chat-style output: generated_text is itself a list of messages
            if isinstance(text, list) and text and isinstance(text[-1], dict):
                text = text[-1].get("content", str(text))
            text = str(text)
        else:
            text = str(out)
    except Exception:
        text = str(out)
return text
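# Example text-only call (assumes the model weights and local documents are available):
#   print(generate_medgemma_rag_response("What are common symptoms of malaria?"))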
# ...existing code...
with gr.Blocks() as iface:
chatbot = gr.Chatbot(label="Ayaresa chat")
with gr.Row():
with gr.Column(scale=3):
txt = gr.Textbox(label="Enter a prompt", placeholder="Type your question here...", lines=2)
with gr.Column(scale=1):
img = gr.Image(type="pil", label="Image (optional)")
with gr.Row():
send = gr.Button("Send")
clear = gr.Button("Clear")
# keep conversation state explicitly
state = gr.State([])
def submit_fn(message, image, history):
history = history or []
if (not message or message.strip() == "") and image is None:
return history, "", history
resp = generate_medgemma_rag_response(message or "", image)
history.append((message or "", resp))
return history, "", history
send.click(submit_fn, inputs=[txt, img, state], outputs=[chatbot, txt, state])
txt.submit(submit_fn, inputs=[txt, img, state], outputs=[chatbot, txt, state])
clear.click(lambda: ([], "", []), inputs=None, outputs=[chatbot, txt, state])
if __name__ == "__main__":
iface.launch()
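# Rough local setup (a sketch; versions not pinned, adjust to your environment):
#   pip install gradio torch transformers bitsandbytes datasets pandas pillow pypdf \
#       langchain-community langchain-text-splitters chromadb sentence-transformers
#   python app.py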