Spaces:
Runtime error
Runtime error
import html
import os
import random
import tempfile
import webbrowser
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForCausalLM,
    AutoModelForImageClassification,
    AutoTokenizer,
)
@dataclass
class PatientMetadata:
    """Patient context attached to a scan analysis.

    Declared as a dataclass so it can be built with keyword arguments,
    which is how the analyzer constructs it.
    """

    # Age in years.
    age: int
    # Free-text smoking status, e.g. "Never Smoker" / "Former Smoker" / "Current Smoker".
    smoking_status: str
    # True when there is a family history of breast cancer.
    family_history: bool
    # "Pre-menopausal" or "Post-menopausal".
    menopause_status: str
    # True when the patient has had a previous mammogram.
    previous_mammogram: bool
    # Density category label, e.g. "A: Almost entirely fatty".
    breast_density: str
    # True when the patient is currently on hormone therapy.
    hormone_therapy: bool
@dataclass
class AnalysisResult:
    """Outcome of the image analysis step.

    Declared as a dataclass so it can be built positionally as
    ``AnalysisResult(has_tumor, tumor_size, metadata)``.
    """

    # True when the detector judged an abnormal area to be present.
    has_tumor: bool
    # Size-class label from the size classifier ("no-tumor", "0.5", "1.0", "1.5").
    tumor_size: str
    # Patient context that accompanies the scan.
    metadata: "PatientMetadata"
class BreastSinogramAnalyzer:
    """Analyze a breast scan image and produce a text report plus a print preview.

    The pipeline runs two image classifiers (abnormality present / size class),
    asks a causal language model to draft a short report, and writes an HTML
    print preview to a temporary directory.
    """

    # Labels for the size classifier, indexed by its arg-max output.
    # NOTE(review): the order is taken from the original code; confirm it
    # matches the model's own id2label mapping.
    _SIZE_LABELS = ("no-tumor", "0.5", "1.0", "1.5")

    # Static stylesheet for the print preview, kept out of the f-string so the
    # braces do not need escaping.
    _REPORT_CSS = """
        @media print {
            body {
                font-family: Arial, sans-serif;
                line-height: 1.6;
                padding: 20px;
                max-width: 800px;
                margin: 0 auto;
            }
            .header {
                text-align: center;
                margin-bottom: 30px;
                border-bottom: 2px solid #000;
                padding-bottom: 10px;
            }
            .date {
                text-align: right;
                margin-bottom: 20px;
            }
            .content {
                margin-bottom: 30px;
            }
            .scan-image {
                text-align: center;
                margin: 20px 0;
            }
            .scan-image img {
                max-width: 500px;
                height: auto;
            }
            .footer {
                margin-top: 50px;
                border-top: 1px solid #000;
                padding-top: 20px;
            }
            @page {
                size: A4;
                margin: 2cm;
            }
            .no-print {
                display: none;
            }
        }
        /* Screen-only styles */
        body {
            font-family: Arial, sans-serif;
            line-height: 1.6;
            padding: 20px;
            max-width: 800px;
            margin: 0 auto;
        }
        .print-button {
            background-color: #007bff;
            color: white;
            padding: 10px 20px;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            margin-bottom: 20px;
        }
        .print-button:hover {
            background-color: #0056b3;
        }
    """

    def __init__(self):
        """Initialize the analyzer with required models."""
        print("Initializing system...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self._init_vision_models()
        self._init_llm()
        print("Initialization complete!")

    def _init_vision_models(self) -> None:
        """Load the abnormality detector and the size classifier with their processors."""
        print("Loading detection models...")
        self.tumor_detector = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier"
        ).to(self.device).eval()
        self.tumor_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_classifier"
        )
        self.size_detector = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        ).to(self.device).eval()
        self.size_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        )

    def _init_llm(self) -> None:
        """Load the Qwen language model and tokenizer used for report generation."""
        print("Loading Qwen language model...")
        self.model_name = "Qwen/QwQ-32B-Preview"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def _generate_synthetic_metadata(self) -> PatientMetadata:
        """Return randomly generated (synthetic) patient metadata.

        The values are drawn with ``random`` and are not derived from the image.
        """
        age = random.randint(40, 75)
        smoking_status = random.choice(["Never Smoker", "Former Smoker", "Current Smoker"])
        family_history = random.choice([True, False])
        menopause_status = "Post-menopausal" if age > 50 else "Pre-menopausal"
        previous_mammogram = random.choice([True, False])
        breast_density = random.choice([
            "A: Almost entirely fatty",
            "B: Scattered fibroglandular",
            "C: Heterogeneously dense",
            "D: Extremely dense",
        ])
        hormone_therapy = random.choice([True, False])
        return PatientMetadata(
            age=age,
            smoking_status=smoking_status,
            family_history=family_history,
            menopause_status=menopause_status,
            previous_mammogram=previous_mammogram,
            breast_density=breast_density,
            hormone_therapy=hormone_therapy,
        )

    @staticmethod
    def _format_risk_factors(metadata) -> str:
        """Return the patient's risk factors as a comma-separated, lower-case phrase.

        Factors that do not apply are skipped rather than joined as empty
        strings and stripped afterwards.
        """
        factors = [
            "family history of breast cancer" if metadata.family_history else "",
            metadata.smoking_status.lower(),
            "currently on hormone therapy" if metadata.hormone_therapy else "",
        ]
        return ", ".join(factor for factor in factors if factor)

    def _process_image(self, image: Image.Image) -> Image.Image:
        """Convert the image to RGB if needed and resize it to 224x224."""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    def _analyze_image(self, image: Image.Image) -> AnalysisResult:
        """Run abnormality detection and size classification on a prepared image.

        Returns an ``AnalysisResult`` carrying the detection flag, the size
        label, and freshly generated synthetic patient metadata.
        """
        metadata = self._generate_synthetic_metadata()

        # Pure inference: no_grad avoids building an autograd graph.
        with torch.no_grad():
            tumor_inputs = self.tumor_processor(image, return_tensors="pt").to(self.device)
            tumor_outputs = self.tumor_detector(**tumor_inputs)
            tumor_probs = tumor_outputs.logits.softmax(dim=-1)[0].cpu()

            size_inputs = self.size_processor(image, return_tensors="pt").to(self.device)
            size_outputs = self.size_detector(**size_inputs)
            size_probs = size_outputs.logits.softmax(dim=-1)[0].cpu()

        # NOTE(review): assumes class index 1 means "abnormal" — confirm
        # against the detector's label mapping.
        # bool() turns the 0-d tensor into a plain Python bool for the dataclass.
        has_tumor = bool(tumor_probs[1] > tumor_probs[0])

        # NOTE(review): the two classifiers are independent, so has_tumor can
        # be True while the size label is "no-tumor"; the report would then
        # print "no-tumor cm". Decide how that disagreement should be shown.
        tumor_size = self._SIZE_LABELS[int(size_probs.argmax())]
        return AnalysisResult(has_tumor, tumor_size, metadata)

    def _generate_medical_report(self, analysis: AnalysisResult) -> str:
        """Draft the findings/recommendations section with the language model.

        Falls back to ``_generate_fallback_report`` when generation raises or
        when the model's answer is shorter than ten words.
        """
        try:
            finding = (
                'Abnormal area detected' if analysis.has_tumor
                else 'No abnormalities detected'
            )
            size_line = (
                f'- Size of abnormal area: {analysis.tumor_size} cm'
                if analysis.has_tumor else ''
            )
            user_prompt = "\n".join([
                "Generate a clear medical report for this breast imaging scan:",
                "Scan Results:",
                f"- Finding: {finding}",
                size_line,
                "Patient Information:",
                f"- Age: {analysis.metadata.age} years",
                f"- Risk factors: {self._format_risk_factors(analysis.metadata)}",
                "Please provide:",
                "1. A clear interpretation of the findings",
                "2. A specific recommendation for next steps",
            ])
            messages = [
                {
                    "role": "system",
                    "content": "You are a radiologist providing clear and straightforward medical reports. Focus on clarity and actionable recommendations.",
                },
                {"role": "user", "content": user_prompt},
            ]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=128,
                temperature=0.3,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
            )
            # Drop the prompt tokens so only the newly generated text is decoded.
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            if len(response.split()) >= 10:
                return f"FINDINGS AND RECOMMENDATIONS:\n{response}"
            return self._generate_fallback_report(analysis)
        except Exception as e:
            # Deliberate best-effort: any generation failure degrades to the
            # deterministic fallback text instead of failing the analysis.
            print(f"Error in report generation: {str(e)}")
            return self._generate_fallback_report(analysis)

    def _generate_fallback_report(self, analysis: AnalysisResult) -> str:
        """Return a fixed-template report that does not need the language model."""
        if not analysis.has_tumor:
            return (
                "FINDINGS AND RECOMMENDATIONS:\n"
                "Finding: No abnormal areas were detected during this scan.\n"
                "Recommendation: Continue with routine screening as per standard guidelines."
            )
        if analysis.tumor_size in ('1.0', '1.5'):
            recommendation = (
                'An immediate follow-up with conventional mammogram and ultrasound is required.'
            )
        else:
            recommendation = 'A follow-up scan is recommended in 6 months.'
        return (
            "FINDINGS AND RECOMMENDATIONS:\n"
            f"Finding: An abnormal area measuring {analysis.tumor_size} cm was detected during the scan.\n"
            f"Recommendation: {recommendation}"
        )

    def _generate_print_preview(self, analysis_text: str, image: Image.Image) -> str:
        """Write an HTML print preview and return the path of the HTML file.

        Each call writes into its own freshly created temporary directory so
        that concurrent requests cannot overwrite each other's scan image or
        report. The directory is not removed here; cleanup is left to the OS
        temp policy.
        """
        report_dir = tempfile.mkdtemp(prefix="breast_report_")
        image_name = 'scan_image.png'
        image.save(os.path.join(report_dir, image_name))

        current_date = datetime.now().strftime("%B %d, %Y")
        # The text contains model output; escape it so it cannot inject markup.
        safe_text = html.escape(analysis_text)

        html_content = "\n".join([
            "<!DOCTYPE html>",
            "<html>",
            "<head>",
            # The file is written as UTF-8 and the text contains non-ASCII symbols.
            '<meta charset="utf-8">',
            "<title>Medical Imaging Report</title>",
            f"<style>{self._REPORT_CSS}</style>",
            "</head>",
            "<body>",
            '<button onclick="window.print()" class="print-button no-print">Print Report</button>',
            '<div class="header">',
            "<h1>Medical Imaging Report</h1>",
            "</div>",
            '<div class="date">',
            f"Report Date: {current_date}",
            "</div>",
            '<div class="scan-image">',
            # Relative URL: the image sits next to the HTML file.
            f'<img src="{image_name}" alt="Scan Image">',
            "</div>",
            '<div class="content">',
            f'<pre style="white-space: pre-wrap; font-family: Arial, sans-serif;">{safe_text}</pre>',
            "</div>",
            '<div class="footer">',
            "<p>This report is generated by an automated analysis system and should be reviewed by a qualified healthcare professional.</p>",
            "</div>",
            "</body>",
            "</html>",
            "",
        ])

        temp_html_path = os.path.join(report_dir, 'report.html')
        with open(temp_html_path, 'w', encoding='utf-8') as f:
            f.write(html_content)
        return temp_html_path

    def analyze(self, image: Image.Image) -> Tuple[str, str]:
        """Run the full pipeline for one uploaded image.

        Returns ``(analysis_text, preview_path)``. On failure the first element
        is an error message starting with "Error during analysis:" and the
        second is an empty string.
        """
        if image is None:
            # The UI passes None when the button is clicked with no upload.
            return "Error during analysis: no image was provided. Please upload an image first.", ""
        try:
            processed_image = self._process_image(image)
            analysis = self._analyze_image(processed_image)
            report = self._generate_medical_report(analysis)

            finding = (
                '⚠️ Abnormal area detected' if analysis.has_tumor
                else '✓ No abnormalities detected'
            )
            size_line = (
                f'Size of abnormal area: {analysis.tumor_size} cm'
                if analysis.has_tumor else ''
            )
            analysis_text = "\n".join([
                "SCAN RESULTS:",
                finding,
                size_line,
                "PATIENT INFORMATION:",
                f"• Age: {analysis.metadata.age} years",
                f"• Risk Factors: {self._format_risk_factors(analysis.metadata)}",
                report,
            ])
            # The preview embeds the original upload, not the resized copy.
            preview_path = self._generate_print_preview(analysis_text, image)
            return analysis_text, preview_path
        except Exception as e:
            # UI boundary: report the failure as text instead of raising.
            return f"Error during analysis: {str(e)}", ""
def open_print_preview(preview_path: str) -> None:
    """Open the print preview in the default browser.

    Does nothing when ``preview_path`` is empty or ``None`` (no analysis has
    produced a preview yet). The browser is opened on the machine running
    this process.
    """
    if not preview_path:
        return
    # as_uri() produces a correctly quoted file URI on every platform, unlike
    # naive "file://" + path concatenation (breaks on Windows and on spaces).
    webbrowser.open(Path(preview_path).resolve().as_uri())
def create_interface() -> gr.Blocks:
    """Build the Gradio Blocks UI and wire it to a new analyzer instance.

    Constructing the analyzer loads all models, so this call is expensive.
    """
    engine = BreastSinogramAnalyzer()

    with gr.Blocks() as demo:
        gr.Markdown("# Breast Imaging Analysis System")
        gr.Markdown("Upload a breast image for analysis and medical assessment.")

        with gr.Row():
            input_image = gr.Image(type="pil", label="Upload Breast Image for Analysis")

        with gr.Row():
            analyze_btn = gr.Button("Analyze Image", variant="primary")
            print_btn = gr.Button("Open Print Preview")

        output_text = gr.Textbox(label="Analysis Results", lines=20)
        # Hidden field that carries the preview file path between the two buttons.
        preview_path = gr.Textbox(visible=False)

        analyze_btn.click(
            fn=engine.analyze,
            inputs=[input_image],
            outputs=[output_text, preview_path],
        )
        print_btn.click(
            fn=open_print_preview,
            inputs=[preview_path],
            outputs=None,
        )

    return demo
# Script entry point: build the UI (which loads the models) and serve it.
if __name__ == "__main__":
    print("Starting application...")
    create_interface().launch(debug=True, share=True)