"""
MedDiscover-HF: Hugging Face Spaces-ready Gradio app.
- ZeroGPU-compatible (uses @spaces.GPU for heavy ops)
- MedCPT embeddings + FAISS retrieval over uploaded PDFs
- OSS generator model dropdown (gpt-oss-20b, gemma-3-12b-it, deepseek-vl2-small, granite vision/docling)
"""
import os
import json
import csv
import uuid
from datetime import datetime, timezone
from pathlib import Path
from threading import Thread
from typing import List, Dict, Tuple
import faiss
import gradio as gr
import numpy as np
import spaces
import torch
from PyPDF2 import PdfReader
from transformers import (
    AutoModel,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)
# ----------------------------
# Paths and env configuration
# ----------------------------
BASE_DIR = Path(__file__).parent
DATA_DIR = Path(os.getenv("DATA_DIR") or (BASE_DIR / "data"))
DATA_DIR.mkdir(parents=True, exist_ok=True)
INDEX_PATH = DATA_DIR / "faiss_index.bin"
META_PATH = DATA_DIR / "doc_metadata.json"
LOGS_PATH = DATA_DIR / "logs.jsonl"
HF_TOKEN = os.getenv("HF_TOKEN") # set in Space secrets if needed
HF_HOME = os.getenv("HF_HOME", str(DATA_DIR / ".cache"))
os.environ["HF_HOME"] = HF_HOME
# Force CPU on stateless ZeroGPU environments to avoid CUDA init errors
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
# ----------------------------
# Chunking / PDF utils
# ----------------------------
CHUNK_SIZE = 500
OVERLAP = 50
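# Sliding-window chunking: windows of 500 words advance by chunk_size - overlap = 450
# words, so e.g. a 1,000-word document yields chunks starting at words 0, 450, and 900.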
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = OVERLAP) -> List[str]:
words = text.split()
chunks = []
start = 0
while start < len(words):
end = start + chunk_size
chunk = words[start:end]
if not chunk:
break
chunks.append(" ".join(chunk))
start += (chunk_size - overlap)
return chunks
def extract_text_from_pdf(path: str) -> str:
buff = []
try:
reader = PdfReader(path)
for page in reader.pages:
text = page.extract_text()
if text:
buff.append(text)
except Exception as e: # pragma: no cover
return f"Error reading {path}: {e}"
return "\n".join(buff)
# ----------------------------
# Embeddings: MedCPT encoders
# ----------------------------
MEDCPT_ARTICLE = "ncbi/MedCPT-Article-Encoder"
MEDCPT_QUERY = "ncbi/MedCPT-Query-Encoder"
MAX_ART_LEN = 512
MAX_QUERY_LEN = 64
EMBED_DIM = 768
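# MedCPT is a dual-encoder retriever: the article encoder embeds document chunks and
# the query encoder embeds questions; both produce 768-dim [CLS] vectors that are
# compared by inner product.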
_article_tok = None
_article_model = None
_query_tok = None
_query_model = None
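# Resolved at import time; with CUDA_VISIBLE_DEVICES cleared above this is "cpu",
# so MedCPT inference stays on CPU even inside @spaces.GPU-decorated calls.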
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_medcpt():
    global _article_tok, _article_model, _query_tok, _query_model
    if _article_model is not None and _query_model is not None:
        return
    _article_tok = AutoTokenizer.from_pretrained(MEDCPT_ARTICLE, token=HF_TOKEN)
    _article_model = AutoModel.from_pretrained(MEDCPT_ARTICLE, token=HF_TOKEN)
    _article_model.to(DEVICE).eval()
    _query_tok = AutoTokenizer.from_pretrained(MEDCPT_QUERY, token=HF_TOKEN)
    _query_model = AutoModel.from_pretrained(MEDCPT_QUERY, token=HF_TOKEN)
    _query_model.to(DEVICE).eval()
@spaces.GPU()
def embed_chunks(chunks: List[str]) -> np.ndarray:
load_medcpt()
all_vecs = []
with torch.no_grad():
for i in range(0, len(chunks), 8):
batch = chunks[i : i + 8]
enc = _article_tok(
batch,
truncation=True,
padding=True,
return_tensors="pt",
max_length=MAX_ART_LEN,
).to(DEVICE)
out = _article_model(**enc)
vec = out.last_hidden_state[:, 0, :].cpu().numpy()
all_vecs.append(vec)
if not all_vecs:
return np.array([])
return np.vstack(all_vecs)
@spaces.GPU()
def embed_query(query: str) -> np.ndarray:
load_medcpt()
with torch.no_grad():
enc = _query_tok(
query,
truncation=True,
padding=True,
return_tensors="pt",
max_length=MAX_QUERY_LEN,
).to(DEVICE)
out = _query_model(**enc)
vec = out.last_hidden_state[:, 0, :].cpu().numpy()
return vec
# ----------------------------
# FAISS index helpers
# ----------------------------
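# IndexFlatIP performs exact (brute-force) maximum-inner-product search, which matches
# the dot-product similarity MedCPT's dual encoders are trained for. If cosine
# similarity were preferred instead, both chunk and query vectors would need to be
# L2-normalized first (faiss.normalize_L2).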
def build_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
if embeddings.dtype != np.float32:
embeddings = embeddings.astype(np.float32)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
return index
def save_index(index: faiss.IndexFlatIP, meta: List[Dict]):
faiss.write_index(index, str(INDEX_PATH))
META_PATH.write_text(json.dumps(meta, indent=2), encoding="utf-8")
def load_index() -> Tuple[faiss.IndexFlatIP, List[Dict]]:
if not INDEX_PATH.exists() or not META_PATH.exists():
return None, None
idx = faiss.read_index(str(INDEX_PATH))
meta = json.loads(META_PATH.read_text(encoding="utf-8"))
return idx, meta
def search(index: faiss.IndexFlatIP, meta: List[Dict], query_vec: np.ndarray, k: int) -> List[Dict]:
if query_vec.dtype != np.float32:
query_vec = query_vec.astype(np.float32)
scores, inds = index.search(query_vec, k)
candidates = []
for score, ind in zip(scores[0], inds[0]):
if ind < 0 or ind >= len(meta):
continue
item = dict(meta[ind])
item["retrieval_score"] = float(score)
candidates.append(item)
return candidates
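# Retrieval flow sketch (assumes an index has already been built and saved):
#   idx, meta = load_index()
#   hits = search(idx, meta, embed_query("gene associated with disease X"), k=3)
#   for h in hits:
#       print(h["filename"], h["chunk_id"], round(h["retrieval_score"], 3))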
# ----------------------------
# Model registry for generators
# ----------------------------
class GeneratorWrapper:
def __init__(self, name: str, load_fn, fallback=None, fallback_msg: str | None = None):
self.name = name
self._load_fn = load_fn
self._pipe = None
self._fallback = fallback
self._fallback_msg = fallback_msg
self._note = None
def ensure(self):
if self._pipe is None:
try:
self._pipe = self._load_fn()
self._note = None
except Exception as exc:
print(f"[Generator:{self.name}] load failed: {exc}")
if self._fallback:
print(f"[Generator:{self.name}] falling back to {self._fallback.name}")
self._pipe = self._fallback.ensure()
self._note = self._fallback_msg or f"Falling back to {self._fallback.name}."
else:
raise
return self._pipe
def generate_stream(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
pipe = self.ensure()
streamer = TextIteratorStreamer(pipe.tokenizer, skip_special_tokens=True, skip_prompt=True)
inputs = pipe.tokenizer(prompt, return_tensors="pt")
device = getattr(pipe.model, "device", torch.device("cpu"))
inputs = {k: v.to(device) for k, v in inputs.items()}
gen_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": True,
"temperature": temperature,
"top_p": top_p,
"streamer": streamer,
"return_dict_in_generate": True,
"output_scores": False,
"use_cache": False, # avoid DynamicCache issues on Phi-3 CPU
}
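        # generate() blocks, so it runs on a daemon thread while this method consumes
        # the TextIteratorStreamer below; each item the streamer yields is decoded text.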
        def _run():
            try:
                pipe.model.generate(**inputs, **gen_kwargs)
            except Exception as exc:
                if self._fallback:
                    print(f"[Generator:{self.name}] generate failed: {exc}; falling back to {self._fallback.name}")
                    self._pipe = self._fallback.ensure()
                    note = self._fallback_msg or f"Falling back to {self._fallback.name}."
                    if note:
                        # TextIteratorStreamer.put() expects token ids; push plain
                        # text into the queue via on_finalized_text() instead.
                        streamer.on_finalized_text(note + " ")
                    fb_stream = self._fallback.generate_stream(prompt, max_new_tokens, temperature, top_p)
                    for tok in fb_stream:
                        streamer.on_finalized_text(tok)
                else:
                    print(f"[Generator:{self.name}] generate failed: {exc}")
            streamer.end()
Thread(target=_run, daemon=True).start()
if self._note:
yield self._note + " "
self._note = None
for token in streamer:
yield token
def load_gpt_oss():
raise RuntimeError("gpt-oss-20b is disabled on ZeroGPU (too large)")
def load_tinyllama():
# CPU-friendly small chat model to keep ZeroGPU happy.
return pipeline(
"text-generation",
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
device_map="cpu",
torch_dtype=torch.float32,
)
def load_phi3_mini():
pipe = pipeline(
"text-generation",
model="microsoft/Phi-3-mini-4k-instruct",
device_map="cpu",
torch_dtype=torch.float32,
trust_remote_code=True,
model_kwargs={
"use_cache": False,
"attn_implementation": "eager",
},
)
# Disable cache to avoid DynamicCache.seen_tokens errors on ZeroGPU/CPU.
try:
pipe.model.config.use_cache = False
pipe.model.generation_config.use_cache = False
pipe.model.generation_config.cache_implementation = "static"
except Exception:
pass
return pipe
_tiny_wrapper = GeneratorWrapper("tinyllama-1.1b-chat", load_tinyllama)
_phi_wrapper = GeneratorWrapper(
"phi-3-mini-4k",
load_phi3_mini,
fallback=_tiny_wrapper,
fallback_msg="Phi-3-mini-4k unavailable on this Space (CUDA blocked); falling back to TinyLlama CPU.",
)
GENERATORS = {
"tinyllama-1.1b-chat": _tiny_wrapper,
"phi-3-mini-4k": _phi_wrapper,
}
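# To surface another generator in the dropdown, add a loader that returns a
# transformers text-generation pipeline and register a wrapper (hypothetical sketch;
# "load_my_small_model" is not defined in this file):
#   GENERATORS["my-small-model"] = GeneratorWrapper(
#       "my-small-model", load_my_small_model, fallback=_tiny_wrapper
#   )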
# ----------------------------
# Prompt formatting
# ----------------------------
def format_prompt(query: str, contexts: List[Dict]) -> str:
context_blocks = []
for i, c in enumerate(contexts):
context_blocks.append(
f"--- Context {i+1} (file={c.get('filename','N/A')} chunk={c.get('chunk_id','?')}) ---\n{c.get('text','')}"
)
joined = "\n\n".join(context_blocks) if context_blocks else "None."
prompt = (
"You are MedDiscover, a biomedical assistant. Use ONLY the provided context to answer concisely.\n"
"If the context does not contain the answer, reply: 'Not found in provided documents.'\n\n"
f"{joined}\n\nQuestion: {query}\nAnswer:"
)
return prompt
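# Rendered prompt layout (roughly):
#   <instructions>
#   --- Context 1 (file=<filename> chunk=<id>) ---
#   <chunk text>
#   ...
#   Question: <query>
#   Answer: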
# ----------------------------
# Gradio callbacks
# ----------------------------
def ensure_session_state(session_state):
if not isinstance(session_state, dict):
session_state = {}
if not session_state.get("id"):
session_state["id"] = str(uuid.uuid4())
if "records" not in session_state or not isinstance(session_state["records"], list):
session_state["records"] = []
return session_state
def append_log_record(record: Dict):
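    """Append one query/answer record as a JSON Line to DATA_DIR/logs.jsonl (append-only log)."""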
LOGS_PATH.parent.mkdir(parents=True, exist_ok=True)
with LOGS_PATH.open("a", encoding="utf-8") as f:
f.write(json.dumps(record) + "\n")
@spaces.GPU()
def process_pdfs(files, progress=gr.Progress()):
if not files:
return "Upload PDFs first."
texts = []
meta = []
doc_id = 0
for idx, f in enumerate(files):
progress(((idx) / max(len(files), 1)), desc=f"Reading {Path(f.name).name}")
text = extract_text_from_pdf(f.name)
if not text or text.startswith("Error reading"):
continue
chunks = chunk_text(text)
for cid, chunk in enumerate(chunks):
meta.append({"doc_id": doc_id, "filename": Path(f.name).name, "chunk_id": cid, "text": chunk})
texts.append(chunk)
doc_id += 1
if not texts:
return "No text extracted."
progress(0.7, desc=f"Embedding {len(texts)} chunks")
embeds = embed_chunks(texts)
if embeds.size == 0:
return "Embedding failed."
progress(0.85, desc="Building index")
idx = build_index(embeds)
save_index(idx, meta)
progress(1.0, desc="Ready")
return f"Processed {doc_id} PDFs. Index size={idx.ntotal}, dim={idx.d}. Saved to {DATA_DIR}."
def handle_query(query, model_key, k, max_new_tokens, temperature, top_p, session_state):
session_state = ensure_session_state(session_state)
if not query or query.strip() == "":
return "Enter a query", "No context", session_state
idx, meta = load_index()
if idx is None or meta is None:
return "Index not ready. Process PDFs first.", "No context", session_state
qvec = embed_query(query)
cands = search(idx, meta, qvec, int(k))
prompt = format_prompt(query, cands)
wrapper = GENERATORS[model_key]
stream = wrapper.generate_stream(prompt, int(max_new_tokens), float(temperature), float(top_p))
answer_accum = ""
for chunk in stream:
answer_accum += chunk
yield answer_accum, prompt, session_state
record = {
"timestamp": datetime.utcnow().isoformat() + "Z",
"session_id": session_state["id"],
"question": query,
"answer": answer_accum,
"context_chunks": cands,
"model": model_key,
"k": int(k),
"temperature": float(temperature),
"top_p": float(top_p),
"max_new_tokens": int(max_new_tokens),
}
session_state["records"].append(record)
append_log_record(record)
yield answer_accum, prompt, session_state
def export_session_json(session_state):
session_state = ensure_session_state(session_state)
records = session_state.get("records", [])
if not records:
return None, "No session records to export."
out_path = DATA_DIR / f"session-{session_state['id']}.json"
with out_path.open("w", encoding="utf-8") as f:
json.dump(records, f, ensure_ascii=False, indent=2)
return str(out_path), f"Exported {len(records)} records to JSON."
def export_session_csv(session_state):
session_state = ensure_session_state(session_state)
records = session_state.get("records", [])
if not records:
return None, "No session records to export."
out_path = DATA_DIR / f"session-{session_state['id']}.csv"
fields = [
"timestamp",
"session_id",
"question",
"answer",
"model",
"k",
"temperature",
"top_p",
"max_new_tokens",
"context",
]
with out_path.open("w", encoding="utf-8", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fields)
writer.writeheader()
for rec in records:
ctx = " ||| ".join([c.get("text", "") for c in rec.get("context_chunks", [])])
writer.writerow(
{
"timestamp": rec.get("timestamp", ""),
"session_id": rec.get("session_id", ""),
"question": rec.get("question", ""),
"answer": rec.get("answer", ""),
"model": rec.get("model", ""),
"k": rec.get("k", ""),
"temperature": rec.get("temperature", ""),
"top_p": rec.get("top_p", ""),
"max_new_tokens": rec.get("max_new_tokens", ""),
"context": ctx,
}
)
return str(out_path), f"Exported {len(records)} records to CSV."
def clear_session(session_state):
session_state = ensure_session_state(session_state)
session_state["records"] = []
session_state["id"] = str(uuid.uuid4())
return session_state, "Session cleared."
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="MedDiscover") as demo:
gr.Markdown("# 🩺 MedDiscover\nRetrieval Augmented Generation with Large Language Models for Biomedical Discovery Presented by,\nVatsal Patel, Elena Jolkver, Anne Schwerk\nIU International University of Applied Science, Germany")
with gr.Row():
with gr.Column(scale=1):
api_info = gr.Markdown("")
pdfs = gr.File(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
process_btn = gr.Button("Process PDFs (chunk/embed/index)", variant="primary")
status = gr.Textbox(label="Status", interactive=False)
model_dd = gr.Dropdown(
label="Generator Model",
choices=list(GENERATORS.keys()),
value="tinyllama-1.1b-chat",
interactive=True,
)
k_slider = gr.Slider(1, 10, value=3, step=1, label="Top-k chunks")
max_tokens = gr.Slider(20, 512, value=128, step=8, label="Max new tokens")
temp = gr.Slider(0.1, 1.5, value=0.4, step=0.1, label="Temperature")
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
with gr.Group():
gr.Markdown("Session Logs")
export_json_btn = gr.Button("Export session (JSON)")
export_csv_btn = gr.Button("Export session (CSV)")
clear_session_btn = gr.Button("Clear session")
session_status = gr.Textbox(label="Session status", interactive=False)
download_json = gr.File(label="JSON export", interactive=False)
download_csv = gr.File(label="CSV export", interactive=False)
with gr.Column(scale=2):
query = gr.Textbox(label="Query", lines=3, placeholder="Ask about your documents...")
answer = gr.Textbox(label="Answer (streaming)", lines=6)
            context_box = gr.Textbox(label="Prompt with retrieved context", lines=14)
go_btn = gr.Button("Ask", variant="primary")
session_state = gr.State({"id": None, "records": []})
process_btn.click(fn=process_pdfs, inputs=pdfs, outputs=status, show_progress="full")
go_btn.click(
fn=handle_query,
inputs=[query, model_dd, k_slider, max_tokens, temp, top_p, session_state],
outputs=[answer, context_box, session_state],
concurrency_limit=1,
)
query.submit(
fn=handle_query,
inputs=[query, model_dd, k_slider, max_tokens, temp, top_p, session_state],
outputs=[answer, context_box, session_state],
concurrency_limit=1,
)
export_json_btn.click(fn=export_session_json, inputs=session_state, outputs=[download_json, session_status])
export_csv_btn.click(fn=export_session_csv, inputs=session_state, outputs=[download_csv, session_status])
clear_session_btn.click(fn=clear_session, inputs=session_state, outputs=[session_state, session_status])
if __name__ == "__main__":
demo.queue().launch()