| import os |
| import time |
| from typing import List, Dict, Any, Optional |
| from concurrent.futures import ThreadPoolExecutor |
|
|
| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| from Bio import Entrez |
| import traceback |
| import pandas as pd |
|
|
|
|
| |
| _CACHED_AGENT_KEY = None |
| _CACHED_AGENT = None |
|
|
| |
| _MODEL_CACHE: Dict[str, Dict[str, Any]] = {} |
|
|
|
|
| MODEL_NAME = "hkust-nlp/WebExplorer-8B" |
|
|
|
|
def _get_hf_components(device_str: str) -> Dict[str, Any]:
    """Load (once) and return the tokenizer/model pair for ``device_str``.

    Results are memoised in ``_MODEL_CACHE`` keyed by ``device_str``, so every
    agent created for the same device shares one copy of the weights.

    Args:
        device_str: ``"cuda"`` to load on the GPU (4-bit NF4 quantised, with an
            fp16 fallback if the quantised load raises) or ``"cpu"`` to load
            fp32 weights pinned to the CPU.

    Returns:
        Dict with ``"tokenizer"`` and ``"model"`` entries.
    """
    if device_str in _MODEL_CACHE:
        return _MODEL_CACHE[device_str]

    print(f"Loading model for device: {device_str}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

    # Options shared by every load path. trust_remote_code used to be set only
    # on the GPU paths; keep all paths consistent for the same repository.
    common_kwargs: Dict[str, Any] = {
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
    }

    # Honour the requested device. Previously the GPU path was taken whenever
    # CUDA existed, so a "cpu" request on a CUDA machine still loaded the model
    # onto the GPU (and cached that second GPU copy under the "cpu" key).
    use_gpu = device_str.startswith("cuda") and torch.cuda.is_available()

    if use_gpu:
        try:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                quantization_config=quantization_config,
                device_map="auto",
                **common_kwargs,
            )
        except Exception as e:
            # Any failure of the quantised load falls back to plain fp16 weights.
            print(f"4-bit load failed, falling back to standard half precision: {e}")
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                device_map="auto",
                torch_dtype=torch.float16,
                **common_kwargs,
            )
    else:
        # Pin every module to the CPU: device_map="auto" would still place
        # weights on a visible GPU even though the caller asked for CPU.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map={"": "cpu"},
            torch_dtype=torch.float32,
            **common_kwargs,
        )

    # Some tokenizers ship without a pad token; reuse EOS so generate() can pad.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

    print(f"Model loaded successfully on {device_str}")
    _MODEL_CACHE[device_str] = {"tokenizer": tokenizer, "model": model}
    return _MODEL_CACHE[device_str]
|
|
|
|
class LocalWebExplorerAgent:
    """Medical research agent: optional PubMed retrieval plus a local LLM answer.

    For each query the agent decides whether external evidence is wanted,
    maps the query to a known condition, optionally pulls the top PubMed hits
    and their abstracts, and prompts the shared Hugging Face model with them.
    """

    # Substrings that mark a query as evidence-seeking (case-insensitive).
    # NOTE(review): plain substring match, so e.g. "rate" also hits "moderate".
    _SEARCH_TRIGGERS = (
        "treatment", "survival", "trial", "latest", "guideline",
        "therapy", "diagnosis", "prognosis", "rate", "statistic",
        "study", "research", "clinical", "evidence",
    )

    # Query substring -> condition name used for the PubMed query.
    # The first matching key in insertion order wins.
    _CONDITIONS = {
        "lung": "lung cancer",
        "pancreatic": "pancreatic cancer",
        "breast": "breast cancer",
        "colon": "colorectal cancer",
        # "colorectal" does not contain "colon", so it needs its own key.
        "colorectal": "colorectal cancer",
        "prostate": "prostate cancer",
        "melanoma": "melanoma",
        "diabetes": "diabetes mellitus",
        "heart failure": "heart failure",
        "hypertension": "hypertension",
    }

    # Contact address treated as "Entrez not configured".
    _PLACEHOLDER_EMAIL = "user@example.com"

    def __init__(self, search_targets: List[str], use_cpu: bool):
        """Configure Entrez from the environment and load (or reuse) the model.

        Args:
            search_targets: Domain list; stored on the instance but not
                consulted by any method of this class.
            use_cpu: Force the CPU even when CUDA is available.
        """
        self.search_targets = search_targets
        self.device_str = "cpu" if use_cpu else ("cuda" if torch.cuda.is_available() else "cpu")

        # These are module-level attributes of Bio.Entrez, so they apply to
        # every agent in the process, not just this one.
        Entrez.email = os.getenv("ENTREZ_EMAIL")
        Entrez.api_key = os.getenv("ENTREZ_API_KEY")

        comps = _get_hf_components(self.device_str)
        self.tokenizer = comps["tokenizer"]
        self.model = comps["model"]

        # diagnosis -> list of {"pmid", "title", "url"} PubMed hits.
        self.search_cache: Dict[str, List[Dict[str, str]]] = {}

    @classmethod
    def _entrez_configured(cls) -> bool:
        """Return True when Bio.Entrez has a real (non-placeholder) contact email."""
        return bool(Entrez.email) and Entrez.email != cls._PLACEHOLDER_EMAIL

    def _needs_search(self, query: str) -> bool:
        """Return True if the query mentions any evidence-seeking trigger term."""
        lowered = query.lower()
        return any(term in lowered for term in self._SEARCH_TRIGGERS)

    def _extract_diagnosis(self, query: str) -> str:
        """Map the query to a condition name, or "general medical condition"."""
        query_lower = query.lower()
        for key, value in self._CONDITIONS.items():
            if key in query_lower:
                return value
        return "general medical condition"

    def _pubmed_search(self, diagnosis: str) -> List[Dict[str, str]]:
        """Return up to 3 relevance-sorted PubMed hits for ``diagnosis``.

        Successful results (including empty ones) are cached per diagnosis.
        Returns [] without touching the network when Entrez has no usable
        email, and [] (uncached) when the request or parsing fails.
        """
        if diagnosis in self.search_cache:
            return self.search_cache[diagnosis]

        if not self._entrez_configured():
            return []

        try:
            query = f"{diagnosis} treatment guidelines[Title/Abstract] OR {diagnosis} clinical practice[Title/Abstract]"
            handle = Entrez.esearch(db="pubmed", term=query, retmax=3, sort="relevance")
            try:
                record = Entrez.read(handle)
            finally:
                # Close even if parsing fails; the old code leaked the handle.
                handle.close()

            ids = record.get("IdList", [])
            results: List[Dict[str, str]] = []

            if ids:
                # One esummary call for all IDs to get titles.
                fetch = Entrez.esummary(db="pubmed", id=",".join(ids), retmode="xml")
                try:
                    summary_list = Entrez.read(fetch)
                finally:
                    fetch.close()

                for summary in summary_list:
                    pmid = summary.get("Id", "")
                    title = summary.get("Title", "No title")
                    results.append({
                        "pmid": str(pmid),
                        "title": title,
                        "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                    })

            self.search_cache[diagnosis] = results
            return results

        except Exception as e:
            print(f"PubMed search error: {e}")
            return []

    def _fetch_abstracts(self, pmids: List[str]) -> str:
        """Fetch plain-text abstracts for ``pmids`` concurrently.

        Returns the non-empty abstracts joined by blank lines, or "" when there
        is nothing to fetch or Entrez is not configured. Per-PMID failures are
        logged and skipped.
        """
        if not pmids or not self._entrez_configured():
            return ""

        def fetch_single(pmid: str) -> str:
            try:
                handle = Entrez.efetch(db="pubmed", id=pmid, rettype="abstract", retmode="text")
                try:
                    content = handle.read()
                finally:
                    handle.close()

                if isinstance(content, bytes):
                    content = content.decode('utf-8', errors='ignore')
                return content
            except Exception as e:
                print(f"Error fetching abstract for PMID {pmid}: {e}")
                return ""

        # One request per PMID, at most three in flight at once.
        with ThreadPoolExecutor(max_workers=3) as executor:
            abstracts = list(executor.map(fetch_single, pmids))

        return "\n\n".join(a for a in abstracts if a)

    def _generate(self, prompt: str, max_new_tokens: int = 200) -> str:
        """Greedy-decode a continuation of ``prompt`` and return only the new text.

        The prompt is truncated to at most 1024 tokens before generation.
        """
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
        ).to(self.model.device)

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                use_cache=True,
            )

        # Drop the echoed prompt tokens; keep only what the model generated.
        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    def execute_query(self, query: str, max_turns: int = 3) -> Dict[str, Any]:
        """Answer one query, recording each step as a "turn" in the trace.

        Turn 1 records the search/reason decision. If a search is warranted
        and ``max_turns`` leaves room for a second turn, PubMed is queried
        (turn 2) and the hits' abstracts are fetched. A final reasoning turn
        that generates the answer is always appended, even past ``max_turns``.

        Returns:
            Dict with "model_loaded", "final_answer" (with disclaimer/sources),
            "turns", "total_turns" and "timestamp" (Unix seconds).
        """
        turns: List[Dict[str, Any]] = []
        timestamp = int(time.time())

        diagnosis = self._extract_diagnosis(query)

        needs_search = self._needs_search(query)
        turns.append({
            "turn": 1,
            "action_decision": "search" if needs_search else "reason",
            "tool_calls": [],
        })

        retrieved_docs: List[Dict[str, str]] = []
        abstracts = ""

        # With max_turns == 1 this search step is skipped entirely.
        if needs_search and len(turns) < max_turns:
            retrieved_docs = self._pubmed_search(diagnosis)
            turns.append({
                "turn": 2,
                "action_decision": "search",
                "tool_calls": [{
                    "tool": "pubmed.search",
                    "args": {"diagnosis": diagnosis},
                    "results": [f"PMID {d['pmid']}: {d['title']}" for d in retrieved_docs],
                }],
            })

            if retrieved_docs:
                abstracts = self._fetch_abstracts([d["pmid"] for d in retrieved_docs])

        prompt = self._build_prompt(query, diagnosis, abstracts)
        answer_text = self._generate(prompt, max_new_tokens=200)

        # Trace preview: add an ellipsis only when the text was actually cut.
        preview = answer_text if len(answer_text) <= 100 else answer_text[:100] + "..."
        turns.append({
            "turn": len(turns) + 1,
            "action_decision": "reason",
            "tool_calls": [],
            "response": preview,
        })

        answer_text = self._format_answer(answer_text, query, retrieved_docs)

        return {
            "model_loaded": True,
            "final_answer": answer_text,
            "turns": turns,
            "total_turns": len(turns),
            "timestamp": timestamp,
        }

    def _build_prompt(self, query: str, diagnosis: str, abstracts: str) -> str:
        """Build the generation prompt, embedding up to 1500 chars of abstracts if any."""
        if abstracts:
            return (
                f"Answer this medical question based on the research below.\n\n"
                f"Question: {query}\n\n"
                f"Research on {diagnosis}:\n{abstracts[:1500]}\n\n"
                f"Provide a clear, concise summary of current treatments and outcomes."
            )
        return (
            f"Answer this medical question concisely and accurately.\n\n"
            f"Question: {query}\n\n"
            f"Provide evidence-based information in plain language."
        )

    def _format_answer(self, answer: str, query: str, docs: List[Dict[str, str]]) -> str:
        """Append a disclaimer (for medical-sounding queries) and a markdown source list."""
        medical_terms = ["cancer", "disease", "diabetes", "treatment", "diagnosis", "therapy"]
        if any(term in query.lower() for term in medical_terms):
            answer += "\n\n**Disclaimer:** This is educational information only. Always consult a healthcare professional for medical advice."

        if docs:
            answer += "\n\n**Sources:**\n" + "\n".join(
                f"- [{d['title']}]({d['url']})" for d in docs
            )

        return answer

    def execute_batch(self, queries: List[str], max_turns: int = 3, progress_callback=None) -> List[Dict[str, Any]]:
        """Run ``execute_query`` on each query, isolating per-query failures.

        ``progress_callback(fraction, desc=...)`` is called before each query.
        A failed query yields a result dict with ``model_loaded=False`` and an
        "error" entry instead of aborting the batch.
        """
        results = []
        total = len(queries)

        for idx, query in enumerate(queries):
            if progress_callback:
                progress_callback((idx + 1) / total, desc=f"Processing query {idx + 1}/{total}")

            try:
                result = self.execute_query(query, max_turns=max_turns)
                results.append(result)
            except Exception as e:
                print(f"Error processing query '{query}': {e}")
                results.append({
                    "model_loaded": False,
                    "final_answer": f"Error: {str(e)}",
                    "turns": [],
                    "total_turns": 0,
                    "timestamp": int(time.time()),
                    "error": str(e),
                })

        return results
|
|
|
|
# Domains passed as search_targets when the UI scope is
# "Medical (Trusted sources only)"; "All sources" passes an empty list instead.
# NOTE(review): in this file LocalWebExplorerAgent only stores these
# (self.search_targets); retrieval itself goes to PubMed.
DEFAULT_TARGETS = [
    "nih.gov",
    "cdc.gov",
    "fda.gov",
    "clinicaltrials.gov",
    "medlineplus.gov",
    "who.int",
    "cancerresearchuk.org",
    "esmo.org",
    "cancer.org",
    "cancer.net",
    "mayoclinic.org",
    "mdanderson.org",
    "mskcc.org",
    "dana-farber.org",
    "uptodate.com",
    "ncbi.nlm.nih.gov",
    "healthline.com",
]
|
|
|
|
def get_agent(search_targets: List[str], use_cpu: bool) -> LocalWebExplorerAgent:
    """Return the memoised agent, building a new one only when the config changes.

    A single agent is kept. It is reused when ``use_cpu`` and the sorted
    ``search_targets`` equal those of the previous call. Otherwise a fresh agent
    is constructed and replaces the old one in ``_CACHED_AGENT`` /
    ``_CACHED_AGENT_KEY``.
    """
    global _CACHED_AGENT_KEY, _CACHED_AGENT
    config_key = (tuple(sorted(search_targets)), use_cpu)
    needs_rebuild = _CACHED_AGENT is None or _CACHED_AGENT_KEY != config_key
    if needs_rebuild:
        _CACHED_AGENT = LocalWebExplorerAgent(search_targets=search_targets, use_cpu=use_cpu)
        _CACHED_AGENT_KEY = config_key
    return _CACHED_AGENT
|
|
|
|
def run_query(query: str, domain_scope: str, device_choice: str, max_turns: int, fast_mode: bool, progress=gr.Progress()):
    """Gradio handler for the single-query tab.

    Args:
        query: The user's question; blank input short-circuits with a prompt.
        domain_scope: Radio value; "Medical (Trusted sources only)" selects
            DEFAULT_TARGETS, anything else an empty target list.
        device_choice: "CPU" forces the CPU; any other value lets the agent
            use CUDA when available.
        max_turns: Research depth forwarded to execute_query (fast mode uses 1).
        fast_mode: Skip the PubMed step and cap the answer at 1200 characters.
        progress: Gradio progress tracker.

    Returns:
        (answer markdown, trace dict). On failure the answer is an error
        message and the trace carries the error text and traceback.
    """
    if not query or not query.strip():
        return "Please enter a query.", {}

    progress(0, desc="Loading model...")
    use_cpu = device_choice == "CPU"
    targets = DEFAULT_TARGETS if domain_scope == "Medical (Trusted sources only)" else []
    clean_query = query.strip()

    try:
        agent = get_agent(targets, use_cpu=use_cpu)
        progress(0.2, desc="Processing query...")

        if fast_mode:
            # max_turns=1 alone makes execute_query skip PubMed: the search step
            # only runs while fewer than max_turns turns have been recorded.
            # Do NOT patch agent._needs_search here: the agent is cached and
            # shared, so the old monkeypatch disabled PubMed for every later
            # non-fast query and batch run in the process.
            # (Turn 1 of the trace may still record "search" as the decision.)
            result = agent.execute_query(clean_query, max_turns=1)
            if result.get('final_answer'):
                result['final_answer'] = result['final_answer'][:1200]
        else:
            result = agent.execute_query(clean_query, max_turns=max_turns)

        progress(1.0, desc="Complete!")
        final_answer = result.get('final_answer', '')

        mini_trace = {
            'model_loaded': result.get('model_loaded'),
            'turns': result.get('turns', []),
            'total_turns': result.get('total_turns'),
            'timestamp': result.get('timestamp'),
            'fast_mode': fast_mode,
        }
        return final_answer, mini_trace

    except Exception as e:
        tb = traceback.format_exc()
        print("\n===== ERROR IN run_query =====\n", tb, "\n==============================\n")
        return f"Error: {str(e)}", {"error": str(e), "traceback": tb}
|
|
|
|
def process_batch_file(file, domain_scope: str, device_choice: str, max_turns: int, progress=gr.Progress()):
    """Gradio handler for the batch tab: run every query in an uploaded file.

    Accepts a CSV (uses the 'query' column, else 'question', else the first
    column) or a TXT file with one query per line. Results are written to
    ``batch_results_<unix-time>.csv`` in the working directory and returned
    for preview.

    Args:
        file: gr.File upload value, either a path string or an object exposing
            the path as ``.name``.
        domain_scope: Radio value; "Medical (Trusted sources only)" selects
            DEFAULT_TARGETS, anything else an empty target list.
        device_choice: "CPU" forces the CPU.
        max_turns: Research depth forwarded to each query.
        progress: Gradio progress tracker.

    Returns:
        (markdown status summary, results DataFrame or None on failure).
    """
    if file is None:
        return "Please upload a file.", None

    progress(0, desc="Reading file...")

    try:
        # Accept a plain path (including str subclasses) or a wrapper with .name.
        path = file if isinstance(file, str) else file.name
        lowered_path = path.lower()  # so ".CSV" / ".TXT" uploads are accepted too

        if lowered_path.endswith('.csv'):
            df = pd.read_csv(path)
            if 'query' in df.columns:
                column = df['query']
            elif 'question' in df.columns:
                column = df['question']
            else:
                column = df.iloc[:, 0]
            # Empty cells arrive as NaN floats, which the agent cannot lower();
            # they would only produce error rows. Drop them and blank strings.
            queries = [str(q).strip() for q in column.dropna()]
            queries = [q for q in queries if q]
        elif lowered_path.endswith('.txt'):
            with open(path, 'r', encoding='utf-8') as f:
                queries = [line.strip() for line in f if line.strip()]
        else:
            return "Please upload a CSV or TXT file.", None

        if not queries:
            return "No queries found in file.", None

        progress(0.1, desc=f"Found {len(queries)} queries. Loading model...")

        use_cpu = device_choice == "CPU"
        targets = DEFAULT_TARGETS if domain_scope == "Medical (Trusted sources only)" else []
        agent = get_agent(targets, use_cpu=use_cpu)

        # Map the agent's 0..1 progress into the remaining 10%..100% of the bar.
        results = agent.execute_batch(
            queries,
            max_turns=max_turns,
            progress_callback=lambda p, desc: progress(0.1 + p * 0.9, desc=desc),
        )

        results_data = [
            {
                'Query': q,
                'Answer': result.get('final_answer', 'Error'),
                'Total Turns': result.get('total_turns', 0),
                'Success': result.get('model_loaded', False),
            }
            for q, result in zip(queries, results)
        ]
        results_df = pd.DataFrame(results_data)

        output_path = f"batch_results_{int(time.time())}.csv"
        results_df.to_csv(output_path, index=False)

        progress(1.0, desc="Complete!")

        success_count = sum(r.get('model_loaded', False) for r in results)
        # These emoji were previously mis-encoded; the first one had also split
        # its literal across two lines (an unterminated f-string).
        summary = (
            f"✅ Processed {len(queries)} queries\n\n"
            f"📊 Success rate: {success_count}/{len(results)}\n\n"
            f"💾 Results saved to: `{output_path}`"
        )

        return summary, results_df

    except Exception as e:
        tb = traceback.format_exc()
        print("\n===== ERROR IN process_batch_file =====\n", tb, "\n==============================\n")
        return f"Error processing file: {e}", None
|
|
|
|
| |
# Gradio app: a single-question tab and a batch-upload tab, wired to
# run_query and process_batch_file respectively.
# Emoji in the labels below were previously mis-encoded (rendered as "π¬"/"π").
with gr.Blocks(title="WebExplorer-8B Medical Research") as demo:
    gr.Markdown("""
    # 🔬 WebExplorer-8B Medical Research Assistant
    Ask medical questions or process multiple queries in batch. Powered by AI and PubMed research.
    """)

    with gr.Tabs():
        with gr.Tab("💬 Single Query"):
            with gr.Row():
                query = gr.Textbox(
                    label="Medical Question",
                    lines=3,
                    placeholder="e.g., What are the treatment options for Type 2 diabetes?",
                    scale=4,
                )

            with gr.Row():
                domain_scope = gr.Radio(
                    choices=["Medical (Trusted sources only)", "All sources"],
                    value="Medical (Trusted sources only)",
                    label="Source Scope",
                    scale=2,
                )
                device = gr.Radio(
                    choices=["GPU", "CPU"],
                    value="GPU",
                    label="Device",
                    scale=1,
                )
                max_turns = gr.Slider(
                    minimum=1, maximum=5, value=2, step=1,
                    label="Max Research Depth",
                    scale=1,
                )
                fast_mode = gr.Checkbox(value=True, label="Fast mode (skip PubMed, shorter answer)")

            submit = gr.Button("🔍 Research", variant="primary", size="lg")

            answer = gr.Markdown(label="Answer", height=300)
            # Receives run_query's trace dict but is hidden from users.
            trace = gr.Json(label="Execution Trace", visible=False)

            gr.Markdown("### 📝 Example Questions")
            gr.Examples(
                examples=[
                    ["What are the survival rates for stage IV pancreatic cancer?"],
                    ["How is Type 2 diabetes diagnosed and treated?"],
                    ["What are the latest immunotherapy options for melanoma?"],
                    ["What are the risk factors for colorectal cancer?"],
                ],
                inputs=[query],
            )

            # Input order must match run_query's positional parameters.
            submit.click(
                run_query,
                inputs=[query, domain_scope, device, max_turns, fast_mode],
                outputs=[answer, trace],
            )

        with gr.Tab("📊 Batch Processing"):
            gr.Markdown("""
            ### Process Multiple Queries
            Upload a **CSV** (with 'query' column) or **TXT** file (one query per line).
            """)

            batch_file = gr.File(
                label="Upload File",
                file_types=['.csv', '.txt'],
                scale=2,
            )

            with gr.Row():
                batch_domain = gr.Radio(
                    choices=["Medical (Trusted sources only)", "All sources"],
                    value="Medical (Trusted sources only)",
                    label="Source Scope",
                )
                batch_device = gr.Radio(
                    choices=["GPU", "CPU"],
                    value="GPU",
                    label="Device",
                )
                batch_turns = gr.Slider(
                    minimum=1, maximum=5, value=2, step=1,
                    label="Max Research Depth",
                )

            batch_submit = gr.Button("🚀 Process Batch", variant="primary", size="lg")

            batch_status = gr.Markdown(label="Status")
            batch_results = gr.Dataframe(label="Results Preview", max_height=400)

            # Input order must match process_batch_file's positional parameters.
            batch_submit.click(
                process_batch_file,
                inputs=[batch_file, batch_domain, batch_device, batch_turns],
                outputs=[batch_status, batch_results],
            )

    gr.Markdown("""
    ---
    **Note:** Configure `ENTREZ_EMAIL` environment variable for PubMed access.
    GPU recommended for faster processing (2-5s vs 30-60s on CPU).
    """)
|
|
|
|
if __name__ == "__main__":
    # Bind on all interfaces; the listening port comes from $PORT, else 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    # NOTE(review): passing theme= to launch() (rather than gr.Blocks) needs a
    # Gradio release that accepts it there; confirm against the pinned version.
    demo.launch(
        server_name="0.0.0.0",
        server_port=listen_port,
        theme=gr.themes.Soft(),
    )