import os
os.environ['ANONYMIZED_TELEMETRY'] = 'False'
import zipfile
import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
import re
import anthropic  # You'll need: pip install anthropic
# OR, if using OpenAI: from openai import OpenAI  (pip install openai)
# Extract and load database
DB_PATH = "./medqa_db"
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
    print("πŸ“¦ Extracting database...")
    with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
        z.extractall(".")
    print("βœ… Database extracted")
print("πŸ”Œ Loading ChromaDB...")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"βœ… Loaded {collection.count()} questions")
print("🧠 Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("βœ… Model ready")
# Initialize AI client (choose one)
# Option 1: Claude
claude_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
# Option 2: OpenAI (uncomment if using; the modern client-based API)
# openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# ============================================================================
# Deduplication function (same as before)
# ============================================================================
def deduplicate_results(results, target_count):
    """Drop near-duplicate hits so the final list contains distinct questions."""
    if not results['documents'][0]:
        return results
    documents = results['documents'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]
    selected_indices = []
    for i in range(len(documents)):
        is_duplicate = False
        current_answer = metadatas[i].get('answer', '')
        for j in selected_indices:
            selected_answer = metadatas[j].get('answer', '')
            dist_diff = abs(distances[i] - distances[j])
            # Nearly identical similarity score: almost certainly the same question.
            if dist_diff < 0.08:
                is_duplicate = True
                break
            # Same answer and a similar score: likely a rephrased duplicate.
            if current_answer == selected_answer and dist_diff < 0.15:
                is_duplicate = True
                break
        if not is_duplicate:
            selected_indices.append(i)
        if len(selected_indices) >= target_count:
            break
    return {
        'documents': [[documents[i] for i in selected_indices]],
        'metadatas': [[metadatas[i] for i in selected_indices]],
        'distances': [[distances[i] for i in selected_indices]],
        'ids': [[results['ids'][0][i] for i in selected_indices]] if 'ids' in results else None
    }
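# Note on the shape assumed above: ChromaDB's query() returns one inner list per
# query embedding, e.g. (illustrative values)
#   {'documents': [['doc A', 'doc B']], 'metadatas': [[{'answer': 'x'}, ...]],
#    'distances': [[0.12, 0.31]], 'ids': [['q1', 'q2']]}
# which is why every lookup indexes [0] before the per-result index.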
# ============================================================================
# Search function (same as before)
# ============================================================================
def search(query, num_results=3, source_filter=None):
    emb = model.encode(query).tolist()
    where_clause = None
    if source_filter and source_filter != "all":
        where_clause = {"source": source_filter}
    # Over-fetch so deduplication still leaves enough distinct results.
    fetch_count = min(num_results * 4, 50)
    results = collection.query(
        query_embeddings=[emb],
        n_results=fetch_count,
        where=where_clause
    )
    return deduplicate_results(results, num_results)
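# Usage sketch (hypothetical query text; run after the collection is loaded):
#   hits = search("55-year-old man with crushing substernal chest pain", num_results=3)
#   for doc, meta in zip(hits['documents'][0], hits['metadatas'][0]):
#       print(meta.get('answer_idx', '?'), doc[:80])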
# ============================================================================
# NEW: Parser to extract question structure
# ============================================================================
def parse_question_document(doc_text, metadata):
    """Extract question and choices from document text."""
    lines = doc_text.split('\n')
    question_lines = []
    options_started = False
    options = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        option_match = re.match(r'^([A-E])[\.\)]\s*(.+)$', line)
        if option_match:
            options_started = True
            letter = option_match.group(1)
            text = option_match.group(2).strip()
            options[letter] = text
        elif not options_started:
            question_lines.append(line)
    question_text = ' '.join(question_lines).strip()
    answer_idx = metadata.get('answer_idx', 'N/A')
    return {
        'question': question_text,
        'choices': options,
        'correct_answer': answer_idx
    }
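# Example of the round trip (hypothetical document text):
#   parse_question_document("A 3-year-old boy presents with fever...\n"
#                           "A. Amoxicillin\nB. Azithromycin", {'answer_idx': 'B'})
#   -> {'question': 'A 3-year-old boy presents with fever...',
#       'choices': {'A': 'Amoxicillin', 'B': 'Azithromycin'},
#       'correct_answer': 'B'}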
# ============================================================================
# NEW: AI generation functions
# ============================================================================
def generate_choice_explanations(question, choices, correct_answer):
    """Generate explanations for why each choice is correct/wrong."""
    choices_text = '\n'.join([f"{k}. {v}" for k, v in choices.items()])
    prompt = f"""You are a medical educator. For this USMLE-style question, explain why EACH answer choice is correct or incorrect.

QUESTION:
{question}

ANSWER CHOICES:
{choices_text}

CORRECT ANSWER: {correct_answer}

Provide a 1-2 sentence explanation for EACH choice (A through E) explaining why it is correct or incorrect. Format as:
A. [Choice text] - [Explanation]
B. [Choice text] - [Explanation]
C. [Choice text] - [Explanation]
D. [Choice text] - [Explanation]
E. [Choice text] - [Explanation]"""
    # Using Claude
    message = claude_client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=1000,
        messages=[{"role": "user", "content": prompt}]
    )
    return message.content[0].text
    # OR using OpenAI (uncomment and swap the return above):
    # response = openai_client.chat.completions.create(
    #     model="gpt-4",
    #     messages=[{"role": "user", "content": prompt}],
    #     max_tokens=1000
    # )
    # return response.choices[0].message.content
def generate_similar_question(original_question, choices, correct_answer):
    """Generate a new question based on the exemplar."""
    choices_text = '\n'.join([f"{k}. {v}" for k, v in choices.items()])
    prompt = f"""You are a medical educator. Based on this USMLE-style question, create a NEW similar question that tests the SAME medical concept but with a different clinical scenario.

ORIGINAL QUESTION:
{original_question}

ANSWER CHOICES:
{choices_text}

CORRECT ANSWER: {correct_answer}

Create a NEW question that:
1. Tests the same medical concept
2. Uses a different patient scenario
3. Has 5 answer choices (A-E)
4. Includes explanations for why each choice is correct/incorrect

Format your response EXACTLY as:
NEW QUESTION:
[Your new question text]

ANSWER CHOICES:
A. [Choice A]
B. [Choice B]
C. [Choice C]
D. [Choice D]
E. [Choice E]

CORRECT ANSWER: [Letter]

EXPLANATIONS:
A. [Choice A text] - [Explanation]
B. [Choice B text] - [Explanation]
C. [Choice C text] - [Explanation]
D. [Choice D text] - [Explanation]
E. [Choice E text] - [Explanation]"""
    # Using Claude
    message = claude_client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=2000,
        messages=[{"role": "user", "content": prompt}]
    )
    return message.content[0].text
    # OR using OpenAI (uncomment and swap the return above):
    # response = openai_client.chat.completions.create(
    #     model="gpt-4",
    #     messages=[{"role": "user", "content": prompt}],
    #     max_tokens=2000
    # )
    # return response.choices[0].message.content
# ============================================================================
# NEW: Format complete output
# ============================================================================
def format_complete_output(exemplar_num, parsed, original_explanation, choice_explanations, new_question_text):
    """Format everything into readable plain text."""
    choices_text = '\n'.join([f"{k}. {v}" for k, v in parsed['choices'].items()])
    output = f"""{'='*80}
EXEMPLAR {exemplar_num}
{'='*80}

ORIGINAL QUESTION:
{parsed['question']}

ANSWER CHOICES:
{choices_text}

CORRECT ANSWER: {parsed['correct_answer']}

EXPLANATION FOR EACH CHOICE:
{choice_explanations}
"""
    if original_explanation:
        output += f"\nORIGINAL EXPLANATION FROM DATABASE:\n{original_explanation}\n"
    output += f"""
{'-'*80}
AI-GENERATED SIMILAR QUESTION:
{'-'*80}

{new_question_text}

{'='*80}
"""
    return output
# ============================================================================
# MODIFIED: API endpoint with full generation
# ============================================================================
app = FastAPI()
class SearchRequest(BaseModel):
    query: str
    num_results: int = 3
    source_filter: Optional[str] = None
    generate_ai: bool = True  # Option to skip AI generation for faster response
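# Example JSON body the endpoint accepts (illustrative values):
#   {"query": "management of diabetic ketoacidosis", "num_results": 2,
#    "source_filter": "all", "generate_ai": false}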
@app.post("/search_medqa")
def api_search(req: SearchRequest):
    """Search and return complete formatted exemplars with AI-generated content."""
    print(f"πŸ” Searching for: {req.query}")
    r = search(req.query, req.num_results, req.source_filter)
    if not r['documents'][0]:
        return {"output": "No results found."}
    complete_output = f"SEARCH QUERY: {req.query}\n"
    complete_output += f"FOUND {len(r['documents'][0])} EXEMPLARS\n\n"
    for i in range(len(r['documents'][0])):
        print(f"Processing exemplar {i+1}...")
        doc_text = r['documents'][0][i]
        metadata = r['metadatas'][0][i]
        # Parse the exemplar
        parsed = parse_question_document(doc_text, metadata)
        original_explanation = metadata.get('explanation', '')
        if req.generate_ai:
            # Generate AI content
            print("  Generating choice explanations...")
            choice_explanations = generate_choice_explanations(
                parsed['question'],
                parsed['choices'],
                parsed['correct_answer']
            )
            print("  Generating similar question...")
            new_question = generate_similar_question(
                parsed['question'],
                parsed['choices'],
                parsed['correct_answer']
            )
        else:
            choice_explanations = "(AI generation skipped)"
            new_question = "(AI generation skipped)"
        # Format complete output
        formatted = format_complete_output(
            i + 1,
            parsed,
            original_explanation,
            choice_explanations,
            new_question
        )
        complete_output += formatted
    return {
        "output": complete_output,
        "content_type": "text/plain"
    }
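# Sketch of calling the endpoint locally (port taken from uvicorn.run below):
#   curl -X POST http://localhost:7860/search_medqa \
#        -H "Content-Type: application/json" \
#        -d '{"query": "first-line treatment for hypertension", "num_results": 1}'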
# Gradio UI (simplified - just shows we have it)
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    gr.Markdown("# πŸ₯ MedQA Search with AI Generation")
    query_input = gr.Textbox(label="Query")
    output = gr.Textbox(label="Results", lines=50)

app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
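# To run locally (assumes the file is saved as app.py and ANTHROPIC_API_KEY is set):
#   pip install chromadb sentence-transformers gradio fastapi anthropic uvicorn
#   python app.py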