from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import transformers
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool


class XRayVQAToolInput(BaseModel):
    """Input schema for the CheXagent tool."""

    image_paths: List[str] = Field(
        ..., description="List of paths to chest X-ray images to analyze"
    )
    prompt: str = Field(
        ..., description="Question or instruction about the chest X-ray images"
    )
    max_new_tokens: int = Field(
        512, description="Maximum number of tokens to generate in the response"
    )


class XRayVQATool(BaseTool):
    """Tool that leverages CheXagent for comprehensive chest X-ray analysis."""

    name: str = "chest_xray_expert"
    description: str = (
        "A versatile tool for analyzing chest X-rays. "
        "Can perform multiple tasks including: visual question answering, report generation, "
        "abnormality detection, comparative analysis, anatomical description, "
        "and clinical interpretation. Input should be paths to X-ray images "
        "and a natural language prompt describing the analysis needed."
    )
    args_schema: Type[BaseModel] = XRayVQAToolInput
    return_direct: bool = True

    cache_dir: Optional[str] = None
    device: Optional[str] = None
    dtype: torch.dtype = torch.bfloat16
    tokenizer: Optional[AutoTokenizer] = None
    model: Optional[AutoModelForCausalLM] = None

    def __init__(
        self,
        model_name: str = "StanfordAIMI/CheXagent-2-3b",
        device: Optional[str] = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the XRayVQATool.

        Args:
            model_name: Name of the CheXagent model to use
            device: Device to run the model on (cuda/cpu)
            dtype: Data type for model weights
            cache_dir: Directory to cache downloaded models
            **kwargs: Additional arguments passed to BaseTool
        """
        super().__init__(**kwargs)

        # Dangerous workaround: temporarily spoof the transformers version
        # string, which appears to satisfy a version check in the model's
        # remote code. Restored after loading. Fragile, but works for now.
        original_transformers_version = transformers.__version__
        transformers.__version__ = "4.40.0"

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype
        self.cache_dir = cache_dir

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=self.device,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        self.model = self.model.to(dtype=self.dtype)
        self.model.eval()

        transformers.__version__ = original_transformers_version

    def _generate_response(
        self, image_paths: List[str], prompt: str, max_new_tokens: int
    ) -> str:
        """Generate a response using the CheXagent model.

        Args:
            image_paths: List of paths to chest X-ray images
            prompt: Question or instruction about the images
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            str: The model's response
        """
        # Interleave the image references and the text prompt into one query
        query = self.tokenizer.from_list_format(
            [*[{"image": path} for path in image_paths], {"text": prompt}]
        )
        conv = [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": query},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            conv, add_generation_prompt=True, return_tensors="pt"
        ).to(device=self.device)

        # Run deterministic (greedy) inference
        with torch.inference_mode():
            output = self.model.generate(
                input_ids,
                do_sample=False,
                num_beams=1,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_new_tokens,
            )[0]

        # Decode only the newly generated tokens, dropping the final (EOS) token
        response = self.tokenizer.decode(output[input_ids.size(1) : -1])
        return response

    def _run(
        self,
        image_paths: List[str],
        prompt: str,
        max_new_tokens: int = 512,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Execute the chest X-ray analysis.

        Args:
            image_paths: List of paths to chest X-ray images
            prompt: Question or instruction about the images
            max_new_tokens: Maximum number of tokens to generate
            run_manager: Optional callback manager

        Returns:
            Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
        """
        try:
            # Verify that every image path exists before running inference
            for path in image_paths:
                if not Path(path).is_file():
                    raise FileNotFoundError(f"Image file not found: {path}")

            response = self._generate_response(image_paths, prompt, max_new_tokens)

            output = {
                "response": response,
            }
            metadata = {
                "image_paths": image_paths,
                "prompt": prompt,
                "max_new_tokens": max_new_tokens,
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            output = {"error": str(e)}
            metadata = {
                "image_paths": image_paths,
                "prompt": prompt,
                "max_new_tokens": max_new_tokens,
                "analysis_status": "failed",
                "error_details": str(e),
            }
            return output, metadata

    async def _arun(
        self,
        image_paths: List[str],
        prompt: str,
        max_new_tokens: int = 512,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run; delegates to the synchronous implementation."""
        return self._run(image_paths, prompt, max_new_tokens)
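
# Minimal usage sketch, not part of the tool itself. It shows the expected
# input dict and the (output, metadata) return contract of _run. Assumptions:
# a CUDA-capable GPU is available and "example_cxr.png" is a hypothetical
# placeholder for a real chest X-ray image; the prompt is illustrative only.
if __name__ == "__main__":
    tool = XRayVQATool()
    output, metadata = tool.invoke(
        {
            "image_paths": ["example_cxr.png"],
            "prompt": "Is there evidence of pleural effusion?",
        }
    )
    # On success, output holds "response"; on failure, it holds "error"
    print(output.get("response", output.get("error")))
    print(f"Status: {metadata['analysis_status']}")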