| | |
| | |
| | |
| | |
| | import os, json |
| | from dataclasses import dataclass, field |
| | from typing import Any, Callable, Dict, List, Optional |
| | from dotenv import load_dotenv |
| | from pydantic import BaseModel, Field |
| | import chainlit as cl |
| | from openai import AsyncOpenAI as _SDKAsyncOpenAI |
| |
|
| | |
| | |
| | |
def set_tracing_disabled(disabled: bool = True):
    """Local no-op stand-in for a tracing switch.

    Nothing is configured; the argument is simply handed back to the caller.
    """
    outcome = disabled
    return outcome
| |
|
def function_tool(func: Callable):
    """Mark *func* as a tool that agents may expose to the model.

    The function object is tagged in place with ``_is_tool = True`` and the
    very same object is returned, so it can be used as a plain decorator.
    """
    setattr(func, "_is_tool", True)
    return func
| |
|
def handoff(*args, **kwargs):
    """Placeholder for agent handoffs: accepts anything and returns None."""
    del args, kwargs  # intentionally unused
    return None
| |
|
class InputGuardrail:
    """Container for a guardrail callable.

    The callable is stored unchanged on ``guardrail_function``; this class
    does not invoke it itself.
    """

    guardrail_function: Callable

    def __init__(self, guardrail_function: Callable):
        self.guardrail_function = guardrail_function
| |
|
@dataclass
class GuardrailFunctionOutput:
    """Result record of a guardrail check.

    ``output_info`` carries arbitrary details; the tripwire fields default to
    "not triggered" with an empty message.
    """

    output_info: Any
    tripwire_triggered: bool = field(default=False)
    tripwire_message: str = field(default="")
| |
|
class InputGuardrailTripwireTriggered(Exception):
    """Raised to signal that an input guardrail tripwire fired."""
| |
|
class AsyncOpenAI:
    """Thin wrapper that builds the OpenAI SDK's async client.

    Args:
        api_key: Key passed straight to the SDK client.
        base_url: Optional endpoint override; only forwarded when non-empty,
            so an empty/None value leaves the SDK's own default in place.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        endpoint_override = {"base_url": base_url} if base_url else {}
        self._client = _SDKAsyncOpenAI(api_key=api_key, **endpoint_override)

    @property
    def client(self):
        """The wrapped SDK client instance."""
        return self._client
| |
|
class OpenAIChatCompletionsModel:
    """Pairs a chat-completions model name with the SDK client used to call it.

    Runner reads ``.model`` (the model name) and ``.client`` (the unwrapped SDK
    client taken from the ``AsyncOpenAI`` wrapper).
    """

    def __init__(self, model: str, openai_client: AsyncOpenAI):
        sdk_client = openai_client.client
        self.model = model
        self.client = sdk_client
| |
|
@dataclass
class Agent:
    """Configuration for one LLM agent executed by ``Runner.run``.

    Field order is unchanged so positional construction keeps working.
    """

    name: str
    instructions: str  # sent as the system message
    model: OpenAIChatCompletionsModel
    # Only callables tagged by function_tool are exposed to the model.
    tools: Optional[List[Callable]] = field(default_factory=list)
    handoff_description: Optional[str] = None  # not read anywhere in this file
    # Optional pydantic model class; Runner validates the final reply against it.
    output_type: Optional[type] = None
    input_guardrails: Optional[List[InputGuardrail]] = field(default_factory=list)

    def tool_specs(self) -> List[Dict[str, Any]]:
        """Build OpenAI "function" tool schemas for every tagged tool.

        Each parameter becomes a JSON-schema property whose type comes from its
        annotation (str/int/float/bool); any other or missing annotation is
        advertised as "string". Parameters without a default are required.
        *args/**kwargs and positional-only parameters are skipped because the
        Runner calls tools with keyword arguments only.

        Returns:
            A list of ``{"type": "function", "function": {...}}`` dicts, in the
            same order as ``self.tools``.
        """
        import inspect

        json_types = {str: "string", int: "integer", float: "number", bool: "boolean"}
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)

        specs: List[Dict[str, Any]] = []
        for tool in self.tools or []:
            if not getattr(tool, "_is_tool", False):
                continue

            properties: Dict[str, Any] = {}
            required: List[str] = []
            for pname, param in inspect.signature(tool).parameters.items():
                if param.kind not in keyword_kinds:
                    continue
                annotation = param.annotation
                # Only plain classes are looked up; typing constructs fall back to "string".
                json_type = json_types.get(annotation, "string") if isinstance(annotation, type) else "string"
                properties[pname] = {"type": json_type}
                if param.default is inspect.Parameter.empty:
                    required.append(pname)

            specs.append({
                "type": "function",
                "function": {
                    "name": tool.__name__,
                    # getdoc() strips the docstring's source indentation before truncation.
                    "description": (inspect.getdoc(tool) or "")[:512],
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required,
                    },
                },
            })
        return specs
| |
|
class Runner:
    """Minimal agent loop over an OpenAI-compatible Chat Completions endpoint.

    Sends the agent's instructions plus the user input, executes any tools the
    model requests, feeds their results back, and stops at the first reply that
    contains no tool calls (or after ``max_turns`` model calls).
    """

    @staticmethod
    async def run(
        agent: Agent,
        user_input: str,
        context: Optional[Dict[str, Any]] = None,
        *,
        max_turns: int = 4,
    ):
        """Run *agent* on *user_input* and return a result object.

        Args:
            agent: Agent whose instructions, model and tools are used.
            user_input: The user's message.
            context: Optional dict echoed back on ``result.context``.
            max_turns: Maximum number of model calls (default 4, the previous
                hard-coded limit).

        Returns:
            An object with ``final_output`` (str), ``context`` (dict) and
            ``final_output_as(t)`` (the parsed ``output_type`` instance when the
            reply validated, otherwise the raw text).

        Raises:
            Whatever the SDK raises for failed API calls. Tool failures are not
            raised; they are reported back to the model as ``{"error": ...}``.
        """
        from types import SimpleNamespace

        msgs: List[Dict[str, Any]] = [
            {"role": "system", "content": agent.instructions},
            {"role": "user", "content": user_input},
        ]
        tools = agent.tool_specs()
        tool_map = {t.__name__: t for t in (agent.tools or []) if getattr(t, "_is_tool", False)}

        # Omit tool parameters entirely for tool-less agents instead of sending
        # explicit nulls, which strict OpenAI-compatible endpoints may reject.
        request_kwargs: Dict[str, Any] = {}
        if tools:
            request_kwargs["tools"] = tools
            request_kwargs["tool_choice"] = "auto"

        for _ in range(max_turns):
            resp = await agent.model.client.chat.completions.create(
                model=agent.model.model,
                messages=msgs,
                **request_kwargs,
            )
            msg = resp.choices[0].message

            if not msg.tool_calls:
                return Runner._build_result(agent, msg.content or "", context)

            # Echo the assistant turn so the tool results below can reference its call ids.
            msgs.append({
                "role": "assistant",
                "content": msg.content or "",
                "tool_calls": [Runner._serialize_tool_call(call) for call in msg.tool_calls],
            })
            for call in msg.tool_calls:
                fn_name = call.function.name
                result = await Runner._invoke_tool(tool_map, fn_name, call.function.arguments)
                msgs.append({
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": fn_name,
                    # default=str: this payload is only read by the model, so a
                    # readable fallback beats crashing on a non-JSON value.
                    "content": json.dumps(result, default=str),
                })

        fallback = SimpleNamespace(
            final_output="Sorry, I couldn't complete the request.",
            context=context or {},
        )
        fallback.final_output_as = lambda t: fallback.final_output
        return fallback

    @staticmethod
    def _serialize_tool_call(call: Any) -> Any:
        """Convert an SDK tool-call object into a plain dict for the next request."""
        dump = getattr(call, "model_dump", None)
        # SDK response objects are pydantic models; dump them explicitly rather
        # than relying on the SDK to serialize model instances nested in dicts.
        return dump(mode="json", exclude_none=True) if callable(dump) else call

    @staticmethod
    async def _invoke_tool(tool_map: Dict[str, Callable], name: str, raw_arguments: Optional[str]) -> Any:
        """Execute one requested tool; every failure becomes an ``{"error": ...}`` result."""
        import inspect

        tool = tool_map.get(name)
        if tool is None:
            return {"error": f"Unknown tool: {name}"}
        try:
            args = json.loads(raw_arguments or "{}")
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            return {"error": f"Invalid JSON arguments: {e}"}
        if not isinstance(args, dict):
            return {"error": "Tool arguments must be a JSON object."}
        try:
            result = tool(**args)
            if inspect.isawaitable(result):  # allow `async def` tools too
                result = await result
        except Exception as e:
            return {"error": str(e)}
        return result

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return *text* without a surrounding Markdown code fence (e.g. ```json ... ```)."""
        stripped = text.strip()
        lines = stripped.splitlines()
        if len(lines) >= 2 and lines[0].startswith("```") and lines[-1].strip() == "```":
            return "\n".join(lines[1:-1]).strip()
        return stripped

    @staticmethod
    def _build_result(agent: Agent, final_text: str, context: Optional[Dict[str, Any]]):
        """Package the final reply, validating it against ``agent.output_type`` when possible."""
        from types import SimpleNamespace

        result = SimpleNamespace(final_output=final_text, context=context or {})
        result.final_output_as = lambda t: final_text

        output_type = agent.output_type
        # isinstance guard: issubclass() raises TypeError for non-class values.
        if isinstance(output_type, type) and issubclass(output_type, BaseModel):
            try:
                data = output_type.model_validate_json(Runner._strip_code_fence(final_text))
            except ValueError:  # pydantic.ValidationError subclasses ValueError
                pass
            else:
                result.final_output = data.model_dump_json()
                result.final_output_as = lambda t: data
        return result
| |
|
| | |
| | |
| | |
# Load variables from a .env file, if present, into the process environment.
load_dotenv()

# Prefer GEMINI_API_KEY; fall back to OPENAI_API_KEY.
# NOTE(review): the client built below targets Google's Gemini endpoint by
# default, so an OpenAI key on its own will presumably not authenticate there.
API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("OPENAI_API_KEY")
if API_KEY is None or API_KEY == "":
    raise RuntimeError(
        "Missing GEMINI_API_KEY (or OPENAI_API_KEY). Add it in the Space secrets or a .env file."
    )

# set_tracing_disabled is the local stub in this file; the call only returns its argument.
set_tracing_disabled(True)
| |
|
# Chat-completions client and model shared by all agents.
# Defaults target Google's OpenAI-compatible Gemini endpoint with gemini-2.5-flash.
# Optional overrides:
#   LLM_BASE_URL - endpoint URL; set it to an empty string to skip the override
#                  entirely (the wrapper only forwards a non-empty base_url, so
#                  the SDK's own default endpoint is used).
#   LLM_MODEL    - model name; empty/unset keeps the default.
external_client: AsyncOpenAI = AsyncOpenAI(
    api_key=API_KEY,
    base_url=os.environ.get(
        "LLM_BASE_URL",
        "https://generativelanguage.googleapis.com/v1beta/openai/",
    ),
)
llm_model: OpenAIChatCompletionsModel = OpenAIChatCompletionsModel(
    model=os.environ.get("LLM_MODEL") or "gemini-2.5-flash",
    openai_client=external_client,
)
| |
|
| | |
| | |
| | |
# One titled group of bullet points inside a TutorResponse.
# Both fields are required (Field(...) marks a field without a default).
class Section(BaseModel):
    title: str = Field(...)
    bullets: List[str] = Field(...)
| |
|
# Structured shape of a full tutor answer: the modality, four teaching
# sections, and a caution line. Every field is required.
class TutorResponse(BaseModel):
    modality: str = Field(...)
    acquisition_overview: Section = Field(...)
    common_artifacts: Section = Field(...)
    preprocessing_methods: Section = Field(...)
    study_tips: Section = Field(...)
    caution: str = Field(...)
| |
|
| | |
| | |
| | |
@function_tool
def infer_modality_from_filename(filename: str) -> dict:
    """
    Guess modality (MRI/X-ray/CT/Ultrasound) from filename keywords.
    Returns: {"modality": "<guess or unknown>"}
    """
    import re

    lowered = (filename or "").lower()

    # (keyword, modality, must_start_word), checked in priority order; first hit wins.
    # Short keywords occur inside ordinary words ("picture" contains "ct",
    # "bandwidth" contains "dwi"), so they only count when not preceded by a
    # letter. Digits, "_", "-", "." or the start of the name are all fine,
    # e.g. "chest_ct", "img01ct", "sub-01_t1w", "echocardio".
    keywords = (
        ("xray", "X-ray", False),
        ("x_ray", "X-ray", False),
        ("x-ray", "X-ray", False),
        ("xr", "X-ray", True),
        ("cxr", "X-ray", True),
        ("mri", "MRI", False),
        ("t1", "MRI", True),
        ("t2", "MRI", True),
        ("flair", "MRI", False),
        ("dwi", "MRI", True),
        ("adc", "MRI", True),
        ("ct", "CT", True),
        ("cta", "CT", True),
        ("ultrasound", "Ultrasound", False),
        ("usg", "Ultrasound", True),
        ("echo", "Ultrasound", True),
    )
    for key, modality, must_start_word in keywords:
        if must_start_word:
            hit = re.search(r"(?<![a-z])" + re.escape(key), lowered) is not None
        else:
            hit = key in lowered
        if hit:
            return {"modality": modality}
    return {"modality": "unknown"}
| |
|
# Educational reference content per canonical modality key; "generic" is the
# fallback. Treat as read-only: imaging_reference_guide returns copies.
_IMAGING_GUIDES: Dict[str, Dict[str, List[str]]] = {
    "xray": {
        "acquisition": [
            "Projection radiography using ionizing radiation.",
            "Common views: AP, PA, lateral; exposure (kVp/mAs) and positioning matter.",
            "Grids/collimation reduce scatter and improve contrast.",
        ],
        "artifacts": [
            "Motion blur; under/overexposure affecting contrast.",
            "Grid cut-off; foreign objects (buttons, jewelry).",
            "Magnification/distortion from object–detector distance.",
        ],
        "preprocessing": [
            "Denoising (median/NLM), histogram equalization.",
            "Window/level selection (bone vs soft tissue) for teaching.",
            "Edge enhancement (unsharp mask) with caution (halo artifacts).",
        ],
        "study_tips": [
            "Use a systematic approach (e.g., ABCDE for chest X-ray).",
            "Compare sides; verify devices, labels, positioning.",
            "Correlate with clinical scenario; keep a checklist.",
        ],
    },
    "mri": {
        "acquisition": [
            "MR uses RF pulses in a strong magnetic field; sequences set contrast.",
            "Key sequences: T1, T2, FLAIR, DWI/ADC, GRE/SWI.",
            "TR/TE/flip angle shape SNR, contrast, time.",
        ],
        "artifacts": [
            "Motion/ghosting (movement, pulsation).",
            "Susceptibility (metal, air-bone interfaces).",
            "Chemical shift, Gibbs ringing.",
            "B0/B1 inhomogeneity causing intensity bias.",
        ],
        "preprocessing": [
            "Bias-field correction (N4).",
            "Denoising (non-local means), registration/normalization.",
            "Skull stripping (brain), intensity standardization.",
        ],
        "study_tips": [
            "Know sequence intent (T1 anatomy, T2 fluid, FLAIR edema).",
            "Check diffusion for acute ischemia (with ADC).",
            "Use consistent windowing for longitudinal comparison.",
        ],
    },
    "ct": {
        "acquisition": [
            "Helical CT reconstructs attenuation in Hounsfield Units.",
            "Kernels (bone vs soft) change sharpness/noise.",
            "Contrast phases (arterial/venous) match the task.",
        ],
        "artifacts": [
            "Beam hardening (streaks), partial volume.",
            "Motion (breathing/cardiac).",
            "Metal artifacts; consider MAR algorithms.",
        ],
        "preprocessing": [
            "Denoising (bilateral/NLM) while preserving edges.",
            "Appropriate window/level (lung, mediastinum, bone).",
            "Iterative reconstruction / metal artifact reduction.",
        ],
        "study_tips": [
            "Use standard planes; scroll systematically.",
            "Compare windows; document sizes/HU as needed.",
            "Correlate phase with the clinical question.",
        ],
    },
    "ultrasound": {
        "acquisition": [
            "Pulse-echo imaging with high-frequency sound; no ionizing radiation.",
            "Higher transducer frequency improves resolution but reduces penetration depth.",
            "Gain, time-gain compensation (TGC), depth and focus shape the image; Doppler adds flow information.",
        ],
        "artifacts": [
            "Acoustic shadowing behind strong reflectors/attenuators; posterior enhancement behind fluid.",
            "Reverberation and mirror-image artifacts from strong interfaces.",
            "Speckle noise; appearance depends on operator, probe angle and gain.",
        ],
        "preprocessing": [
            "Speckle reduction (median or anisotropic diffusion filtering) while preserving edges.",
            "Crop on-screen overlays (text, calipers) before analysis.",
            "Intensity normalization across machines and gain settings.",
        ],
        "study_tips": [
            "Note probe orientation and imaging plane (longitudinal vs transverse).",
            "Describe echogenicity relative to neighboring tissue.",
            "Recognize artifacts first; review cine loops when available.",
        ],
    },
    "generic": {
        "acquisition": [
            "Acquisition parameters define contrast, resolution, and noise.",
            "Positioning and motion control are crucial for quality.",
        ],
        "artifacts": [
            "Motion blur/ghosting; foreign objects and hardware.",
            "Parameter misconfiguration harms interpretability.",
        ],
        "preprocessing": [
            "Denoising and contrast normalization for clarity.",
            "Registration to standard planes for comparison.",
        ],
        "study_tips": [
            "Adopt a checklist; compare across time or sides.",
            "Learn modality-specific knobs (window/level, sequences).",
        ],
    },
}

# Accepted spellings (lower-cased, whitespace-collapsed) -> key in _IMAGING_GUIDES.
_MODALITY_ALIASES: Dict[str, str] = {
    alias: key
    for key, aliases in {
        "xray": ("xray", "x-ray", "x_ray", "x ray", "radiograph", "radiography", "cxr"),
        "mri": ("mri", "mr", "magnetic resonance", "magnetic resonance imaging"),
        "ct": ("ct", "cta", "computed tomography"),
        "ultrasound": ("ultrasound", "us", "usg", "sonography", "echo"),
    }.items()
    for alias in aliases
}


@function_tool
def imaging_reference_guide(modality: str) -> dict:
    """
    Educational points for acquisition, artifacts, preprocessing, and study tips by modality
    (X-ray, MRI, CT, Ultrasound; anything else gets a generic overview).
    Education only (no diagnosis).
    """
    normalized = " ".join(str(modality or "").lower().split())
    key = _MODALITY_ALIASES.get(normalized, "generic")
    # Copy every list so callers that mutate the result cannot corrupt the shared table.
    return {section: list(points) for section, points in _IMAGING_GUIDES[key].items()}
| |
|
@function_tool
def file_facts(filename: str, size_bytes: str) -> dict:
    """
    Returns lightweight file facts: filename and byte size as an integer
    (size_bytes is passed as a string; -1 if it cannot be parsed).
    """
    # size_bytes arrives as text because the tool schema advertises this
    # parameter as a JSON "string"; only conversion failures map to -1.
    try:
        size = int(size_bytes)
    except (TypeError, ValueError):
        size = -1
    return {"filename": filename, "size_bytes": size}
| |
|
| | |
| | |
| | |
# System prompt for the tutor agent. on_message only forwards the uploaded
# file's name and size (never the pixels), so the prompt says so explicitly to
# stop the model from inventing image contents.
tutor_instructions = (
    "You are a Biomedical Imaging Education Tutor. TEACH, do not diagnose.\n"
    "Given an uploaded MRI or X-ray, provide:\n"
    "1) Acquisition overview\n"
    "2) Common artifacts\n"
    "3) Preprocessing methods\n"
    "4) Study tips\n"
    "5) A caution line: education only, no diagnosis\n"
    "Use tools to infer modality from filename and to fetch a modality reference guide.\n"
    "You only receive the uploaded file's name and size as context, not the image itself; "
    "never describe or interpret what the image shows.\n"
    "If unclear, provide a generic overview and ask for clarification.\n"
    "Always respond as concise, well-structured bullet points.\n"
    "Absolutely avoid clinical diagnosis, disease identification, or treatment advice."
)
| |
|
# Main teaching agent: shares llm_model and may call the three tools defined above.
tutor_agent = Agent(
    name="Biomedical Imaging Tutor",
    model=llm_model,
    instructions=tutor_instructions,
    tools=[
        infer_modality_from_filename,
        imaging_reference_guide,
        file_facts,
    ],
)
| |
|
# Verdict the safety classifier is asked to return as JSON. All fields are required.
class SafetyCheck(BaseModel):
    unsafe_medical_advice: bool = Field(...)  # message asks for unsafe medical advice
    requests_diagnosis: bool = Field(...)  # message asks for a medical diagnosis
    pii_included: bool = Field(...)  # message contains personal identifiers
    reasoning: str = Field(...)  # classifier's short justification
| |
|
# Tool-less classifier run before every tutor answer.
# output_type=SafetyCheck makes Runner validate the reply and, on success,
# re-serialize it as clean JSON. The prompt asks for raw JSON because models
# often wrap JSON in Markdown fences, which plain json.loads cannot parse.
guardrail_agent = Agent(
    name="Safety Classifier",
    instructions=(
        "Classify if the user's message requests medical diagnosis or unsafe medical advice, "
        "and if it includes personal identifiers. Respond with a single raw JSON object "
        "(no Markdown code fences, no extra text) with fields: "
        "{unsafe_medical_advice: bool, requests_diagnosis: bool, pii_included: bool, reasoning: string}."
    ),
    model=llm_model,
    output_type=SafetyCheck,
)
| |
|
| | |
| | |
| | |
# Greeting shown at chat start (Markdown; one list entry per rendered line).
WELCOME = "\n".join((
    "🎓 **Multimodal Biomedical Imaging Tutor**",
    "",
    "Upload an **MRI** or **X-ray** image (PNG/JPG). I’ll explain:",
    "• Acquisition (how it’s made)",
    "• Common artifacts (what to watch for)",
    "• Preprocessing for study/teaching",
    "",
    "⚠️ *Education only — I do not provide diagnosis. For clinical concerns, consult a professional.*",
))
| |
|
@cl.on_chat_start
async def on_chat_start():
    """Greet the user, ask for one MRI/X-ray image, and remember its metadata.

    The upload's path, name and size are stored in the Chainlit user session
    under ``last_file_path`` / ``last_file_name`` / ``last_file_size`` for later
    messages. If the prompt yields no file (presumably also when the 180 s
    timeout expires), a hint about general questions is sent instead.
    """
    await cl.Message(content=WELCOME).send()

    prompt = cl.AskFileMessage(
        content="Please upload an **MRI or X-ray** image (PNG/JPG).",
        accept=["image/png", "image/jpeg"],
        max_size_mb=15,
        max_files=1,
        timeout=180,
    )
    uploads = await prompt.send()

    if uploads:
        upload = uploads[0]
        session = cl.user_session
        session.set("last_file_path", upload.path)
        session.set("last_file_name", upload.name)
        session.set("last_file_size", upload.size)
        await cl.Message(
            content=(
                f"Received **{upload.name}** ({upload.size} bytes). "
                "Ask: *“Explain acquisition & artifacts for this image.”*"
            )
        ).send()
        return

    await cl.Message(content="No file uploaded. You can still ask general imaging questions.").send()
| |
|
def _strip_json_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence such as ```json ... ```."""
    stripped = text.strip()
    lines = stripped.splitlines()
    if len(lines) >= 2 and lines[0].startswith("```") and lines[-1].strip() == "```":
        return "\n".join(lines[1:-1]).strip()
    return stripped


def _parse_safety_verdict(raw: Any) -> Dict[str, Any]:
    """Best-effort parse of the safety classifier output into a dict ({} if unparseable)."""
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        return {}
    try:
        data = json.loads(_strip_json_fence(raw))
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return {}
    return data if isinstance(data, dict) else {}


def _flag_is_set(value: Any) -> bool:
    """Interpret a classifier flag; textual "false"/"no" must not count as set."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "1")
    return bool(value)


def _reference_guide_markdown(file_name: Optional[str], file_size: Any) -> str:
    """Render file facts plus the built-in reference guide for the modality guessed from *file_name*."""
    modality = infer_modality_from_filename(file_name or "").get("modality", "unknown")
    guide = imaging_reference_guide(modality)

    def bullets(section: str) -> str:
        return "\n".join(f"- {b}" for b in guide.get(section, []))

    acq = bullets("acquisition")
    art = bullets("artifacts")
    prep = bullets("preprocessing")
    tips = bullets("study_tips")

    return (
        f"### 📁 File\n"
        f"- Name: `{file_name or 'unknown'}`\n"
        f"- Size: `{file_size if file_size is not None else 'unknown'} bytes`\n\n"
        f"### 🔎 Modality (guess)\n- {modality}\n\n"
        f"### 📚 Reference Guide (study)\n"
        f"**Acquisition**\n{acq or '- (general)'}\n\n"
        f"**Common Artifacts**\n{art or '- (general)'}\n\n"
        f"**Preprocessing Ideas**\n{prep or '- (general)'}\n\n"
        f"**Study Tips**\n{tips or '- (general)'}\n\n"
        f"> ⚠️ Education only — no diagnosis.\n"
    )


@cl.on_message
async def on_message(message: cl.Message):
    """Handle one chat message: safety check, tutor answer, and local reference guide.

    1. Runs the safety classifier. If it flags a diagnosis request or unsafe
       medical advice, a refusal is sent and processing stops. Classifier errors
       are logged and the message is allowed through (best-effort guardrail).
    2. Asks the tutor agent, appending the last uploaded file's name/size as
       context. Agent failures are logged and replaced by a fallback text.
    3. Sends the file facts and reference guide followed by the tutor's answer.
    """
    import logging

    logger = logging.getLogger(__name__)

    verdict: Dict[str, Any] = {}
    try:
        safety = await Runner.run(guardrail_agent, message.content)
    except Exception:
        logger.exception("Safety classifier failed; continuing without a verdict")
    else:
        verdict = _parse_safety_verdict(safety.final_output)

    if _flag_is_set(verdict.get("unsafe_medical_advice")) or _flag_is_set(verdict.get("requests_diagnosis")):
        await cl.Message(
            content=(
                "🚫 I can’t provide medical diagnoses or treatment advice.\n"
                "I’m happy to explain **imaging concepts**, **artifacts**, and **preprocessing** for learning."
            )
        ).send()
        return

    file_name = cl.user_session.get("last_file_name")
    file_size = cl.user_session.get("last_file_size")

    context_note = ""
    if file_name:
        context_note += f"The user uploaded a file named '{file_name}'.\n"
    if file_size is not None:
        context_note += f"File size: {file_size} bytes.\n"

    user_query = message.content
    if context_note:
        user_query = f"{user_query}\n\n[Context]\n{context_note}"

    fallback_text = "I couldn’t generate an explanation."
    try:
        result = await Runner.run(tutor_agent, user_query)
        text = result.final_output or fallback_text
    except Exception:
        logger.exception("Tutor agent failed")
        text = fallback_text

    try:
        facts_md = _reference_guide_markdown(file_name, file_size)
    except Exception:
        logger.exception("Could not build the reference guide")
        facts_md = ""

    await cl.Message(content=f"{facts_md}\n---\n{text}").send()
| |
|