from typing import Dict, List, Optional, Tuple, Type, Any
from pathlib import Path
from pydantic import BaseModel, Field
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool


class XRayVQAToolInput(BaseModel):
    """Input schema for the CheXagent Tool."""

    image_paths: List[str] = Field(
..., description="List of paths to chest X-ray images to analyze"
)
prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
max_new_tokens: int = Field(
512, description="Maximum number of tokens to generate in the response"
    )


class XRayVQATool(BaseTool):
    """Tool that leverages CheXagent for comprehensive chest X-ray analysis."""

    name: str = "chest_xray_expert"
description: str = (
"A versatile tool for analyzing chest X-rays. "
"Can perform multiple tasks including: visual question answering, report generation, "
"abnormality detection, comparative analysis, anatomical description, "
"and clinical interpretation. Input should be paths to X-ray images "
"and a natural language prompt describing the analysis needed."
)
args_schema: Type[BaseModel] = XRayVQAToolInput
return_direct: bool = True
cache_dir: Optional[str] = None
device: Optional[str] = None
dtype: torch.dtype = torch.bfloat16
tokenizer: Optional[AutoTokenizer] = None
    model: Optional[AutoModelForCausalLM] = None

    def __init__(
self,
model_name: str = "StanfordAIMI/CheXagent-2-3b",
device: Optional[str] = "cuda",
dtype: torch.dtype = torch.bfloat16,
cache_dir: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize the XRayVQATool.
Args:
model_name: Name of the CheXagent model to use
device: Device to run model on (cuda/cpu)
dtype: Data type for model weights
cache_dir: Directory to cache downloaded models
**kwargs: Additional arguments
"""
super().__init__(**kwargs)
        # Workaround (fragile, but works for now): temporarily spoof the
        # transformers version while the model's remote code is loaded, and
        # restore the real version at the end of __init__.
        original_transformers_version = transformers.__version__
        transformers.__version__ = "4.40.0"
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.dtype = dtype
self.cache_dir = cache_dir
# Load tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
cache_dir=cache_dir,
)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map=self.device,
trust_remote_code=True,
cache_dir=cache_dir,
)
self.model = self.model.to(dtype=self.dtype)
self.model.eval()
        # Restore the real transformers version now that loading is done.
        transformers.__version__ = original_transformers_version

    def _generate_response(self, image_paths: List[str], prompt: str, max_new_tokens: int) -> str:
"""Generate response using CheXagent model.
Args:
image_paths: List of paths to chest X-ray images
prompt: Question or instruction about the images
max_new_tokens: Maximum number of tokens to generate
Returns:
str: Model's response
"""
query = self.tokenizer.from_list_format(
[*[{"image": path} for path in image_paths], {"text": prompt}]
)
conv = [
{"from": "system", "value": "You are a helpful assistant."},
{"from": "human", "value": query},
]
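        # Tokenize the conversation with the model's chat template, appending the
        # generation prompt so the model continues as the assistant, and move the
        # resulting tensor to the target device.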
input_ids = self.tokenizer.apply_chat_template(
conv, add_generation_prompt=True, return_tensors="pt"
).to(device=self.device)
# Run inference
with torch.inference_mode():
output = self.model.generate(
input_ids,
do_sample=False,
num_beams=1,
temperature=1.0,
top_p=1.0,
use_cache=True,
max_new_tokens=max_new_tokens,
)[0]
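        # Decode only the newly generated tokens, dropping the prompt prefix and
        # the final end-of-sequence token.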
response = self.tokenizer.decode(output[input_ids.size(1) : -1])
        return response

    def _run(
self,
image_paths: List[str],
prompt: str,
max_new_tokens: int = 512,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Tuple[Dict[str, Any], Dict]:
"""Execute the chest X-ray analysis.
Args:
image_paths: List of paths to chest X-ray images
prompt: Question or instruction about the images
max_new_tokens: Maximum number of tokens to generate
run_manager: Optional callback manager
Returns:
Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
"""
try:
# Verify image paths
for path in image_paths:
if not Path(path).is_file():
raise FileNotFoundError(f"Image file not found: {path}")
response = self._generate_response(image_paths, prompt, max_new_tokens)
output = {
"response": response,
}
metadata = {
"image_paths": image_paths,
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"analysis_status": "completed",
}
return output, metadata
except Exception as e:
output = {"error": str(e)}
metadata = {
"image_paths": image_paths,
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"analysis_status": "failed",
"error_details": str(e),
}
            return output, metadata

    async def _arun(
self,
image_paths: List[str],
prompt: str,
max_new_tokens: int = 512,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Tuple[Dict[str, Any], Dict]:
"""Async version of _run."""
return self._run(image_paths, prompt, max_new_tokens)
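

# A minimal usage sketch (not part of the original module), assuming the
# StanfordAIMI/CheXagent-2-3b weights can be downloaded and that the image path
# below is replaced with a real chest X-ray file. In a LangChain agent the tool
# would normally be called through its public invoke/run interface; _run is
# called directly here only to show the raw output shape.
if __name__ == "__main__":
    tool = XRayVQATool(device="cuda" if torch.cuda.is_available() else "cpu")
    output, metadata = tool._run(
        image_paths=["/path/to/chest_xray.png"],  # placeholder path
        prompt="Is there evidence of pleural effusion?",
    )
    print(output.get("response") or output.get("error"))
    print(metadata["analysis_status"])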