ArchCoder committed on
Commit 370bf7d · verified · 1 Parent(s): 94a6d6a

Update app.py

Files changed (1):
  1. app.py +112 -112

app.py CHANGED
@@ -10,10 +10,19 @@ from torchvision import transforms
 import torchvision.transforms.functional as TF
 import urllib.request
 import os
+import random
+import kagglehub
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = None
 
+# Download dataset
+dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
+image_path = os.path.join(dataset_path, 'images')
+mask_path = os.path.join(dataset_path, 'masks')
+test_imgs = sorted([f for f in os.listdir(image_path) if f.endswith('.jpg') or f.endswith('.png')])
+test_masks = sorted([f for f in os.listdir(mask_path) if f.endswith('.jpg') or f.endswith('.png')])
+
 # Define your Attention U-Net architecture (from your training code)
 class DoubleConv(nn.Module):
     def __init__(self, in_channels, out_channels):
@@ -65,12 +74,10 @@ class AttentionUNET(nn.Module):
         self.downs = nn.ModuleList()
         self.attentions = nn.ModuleList()
         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
-
         # Down part of UNET
         for feature in features:
             self.downs.append(DoubleConv(in_channels, feature))
             in_channels = feature
-
         # Bottleneck
         self.bottleneck = DoubleConv(features[-1], features[-1]*2)
 
@@ -79,19 +86,15 @@ class AttentionUNET(nn.Module):
             self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
             self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
             self.ups.append(DoubleConv(feature*2, feature))
-
         self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
-
     def forward(self, x):
         skip_connections = []
         for down in self.downs:
             x = down(x)
             skip_connections.append(x)
             x = self.pool(x)
-
         x = self.bottleneck(x)
         skip_connections = skip_connections[::-1]  # reverse list
-
         for idx in range(0, len(self.ups), 2):  # do up and double_conv
             x = self.ups[idx](x)
             skip_connection = skip_connections[idx//2]
@@ -100,7 +103,6 @@ class AttentionUNET(nn.Module):
             skip_connection = self.attentions[idx // 2](skip_connection, x)
             concat_skip = torch.cat((skip_connection, x), dim=1)
             x = self.ups[idx+1](concat_skip)
-
         return self.final_conv(x)
 
 def download_model():
@@ -109,15 +111,15 @@ def download_model():
     model_path = "best_attention_model.pth.tar"
 
     if not os.path.exists(model_path):
-        print("📥 Downloading your trained model...")
+        print("Downloading your trained model...")
         try:
             urllib.request.urlretrieve(model_url, model_path)
-            print("Model downloaded successfully!")
+            print("Model downloaded successfully!")
         except Exception as e:
-            print(f"Failed to download model: {e}")
+            print(f"Failed to download model: {e}")
             return None
     else:
-        print("Model already exists!")
+        print("Model already exists!")
 
     return model_path
 
@@ -126,7 +128,7 @@ def load_your_attention_model():
     global model
     if model is None:
         try:
-            print("🔄 Loading your trained Attention U-Net model...")
+            print("Loading your trained Attention U-Net model...")
 
             # Download model if needed
             model_path = download_model()
@@ -141,9 +143,9 @@ def load_your_attention_model():
             model.load_state_dict(checkpoint["state_dict"])
             model.eval()
 
-            print("Your Attention U-Net model loaded successfully!")
+            print("Your Attention U-Net model loaded successfully!")
         except Exception as e:
-            print(f"Error loading your model: {e}")
+            print(f"Error loading your model: {e}")
             model = None
     return model
 
@@ -161,75 +163,63 @@ def preprocess_for_your_model(image):
 
     return val_test_transform(image).unsqueeze(0)  # Add batch dimension
 
-def create_heatmap_visualization(pred_mask_continuous, original_image):
-    """Create heatmap visualization from continuous prediction values"""
-    # pred_mask_continuous should be the raw sigmoid output (0-1 values)
-    heatmap_np = pred_mask_continuous.cpu().squeeze().numpy()
-
-    # Normalize to 0-255 for better visualization
-    heatmap_normalized = (heatmap_np * 255).astype(np.uint8)
-
-    # Apply colormap (using 'hot' colormap like in medical imaging)
-    heatmap_colored = cv2.applyColorMap(heatmap_normalized, cv2.COLORMAP_HOT)
-    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
-
-    # Convert original image to RGB for overlay
-    if len(original_image.shape) == 2:  # Grayscale
-        original_rgb = cv2.cvtColor(original_image.astype(np.uint8), cv2.COLOR_GRAY2RGB)
-    else:
-        original_rgb = original_image.astype(np.uint8)
-
-    # Create overlay (blend original image with heatmap)
-    alpha = 0.6  # Transparency factor
-    overlay = cv2.addWeighted(original_rgb, 1-alpha, heatmap_colored, alpha, 0)
-
-    return overlay
-
-def predict_tumor(image):
+def predict_tumor(image, mask=None):
     current_model = load_your_attention_model()
 
     if current_model is None:
-        return None, "Failed to load your trained model."
+        return None, "Failed to load your trained model."
     if image is None:
-        return None, "⚠️ Please upload an image first."
+        return None, "Please upload an image first."
 
     try:
-        print("🧠 Processing with YOUR trained Attention U-Net...")
+        print("Processing with YOUR trained Attention U-Net...")
 
         # Use the exact preprocessing from your Colab code
         input_tensor = preprocess_for_your_model(image).to(device)
 
         # Predict using your model (exactly like your Colab code)
         with torch.no_grad():
-            pred_mask_continuous = torch.sigmoid(current_model(input_tensor))  # Keep continuous values for heatmap
-            pred_mask_binary = (pred_mask_continuous > 0.5).float()  # Binary mask for original visualizations
+            pred_mask = torch.sigmoid(current_model(input_tensor))
+            pred_mask_binary = (pred_mask > 0.5).float()
 
-        # Convert to numpy (like your Colab code) - KEEPING ORIGINAL LOGIC
+        # Convert to numpy (like your Colab code)
         pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
+        prob_mask_np = pred_mask.cpu().squeeze().numpy()  # Probability for heatmap
         original_np = np.array(image.convert('L').resize((256, 256)))
 
-        # Create inverted mask for visualization (like your Colab code) - UNCHANGED
+        # Create inverted mask for visualization (like your Colab code)
        inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
 
-        # Create tumor-only image (like your Colab code) - UNCHANGED
+        # Create tumor-only image (like your Colab code)
         tumor_only = np.where(pred_mask_np == 1, original_np, 255)
 
-        # NEW: Create heatmap visualization
-        heatmap_overlay = create_heatmap_visualization(pred_mask_continuous, original_np)
+        # Handle ground truth if provided
+        mask_np = None
+        dice_score = None
+        iou_score = None
+        if mask is not None:
+            mask_transform = transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor()
+            ])
+            mask_tensor = mask_transform(mask).squeeze().numpy()
+            mask_np = (mask_tensor > 0.5).astype(float)
+
+            intersection = np.logical_and(pred_mask_np, mask_np).sum()
+            union = np.logical_or(pred_mask_np, mask_np).sum()
+            iou_score = intersection / (union + 1e-7)
+            dice_score = (2 * intersection) / (pred_mask_np.sum() + mask_np.sum() + 1e-7)
 
-        # Create visualization with 5 panels (original 4 + heatmap)
+        # Create visualization (5-panel layout)
         fig, axes = plt.subplots(1, 5, figsize=(25, 5))
-        fig.suptitle('🧠 Your Attention U-Net Results with Heatmap', fontsize=16, fontweight='bold')
+        fig.suptitle('Your Attention U-Net Results', fontsize=16, fontweight='bold')
 
-        titles = ["Original Image", "Predicted Mask", "Inverted Mask", "Tumor Only", "Prediction Heatmap"]
-        images = [original_np, pred_mask_np * 255, inv_pred_mask_np, tumor_only, heatmap_overlay]
-        cmaps = ['gray', 'hot', 'gray', 'gray', None]  # None for RGB heatmap
+        titles = ["Original Image", "Ground Truth", "Predicted Mask", "Tumor Only", "Heatmap"]
+        images = [original_np, mask_np if mask_np is not None else np.zeros_like(original_np), inv_pred_mask_np, tumor_only, prob_mask_np]
+        cmaps = ['gray', 'gray', 'gray', 'gray', 'hot']
 
         for i, ax in enumerate(axes):
-            if cmaps[i] is not None:
-                ax.imshow(images[i], cmap=cmaps[i])
-            else:
-                ax.imshow(images[i])  # RGB image
+            ax.imshow(images[i], cmap=cmaps[i])
             ax.set_title(titles[i], fontsize=12, fontweight='bold')
             ax.axis('off')
 
@@ -243,68 +233,73 @@ def predict_tumor(image):
 
         result_image = Image.open(buf)
 
-        # Calculate statistics (like your Colab code) - UNCHANGED
+        # Calculate statistics (like your Colab code)
         tumor_pixels = np.sum(pred_mask_np)
         total_pixels = pred_mask_np.size
         tumor_percentage = (tumor_pixels / total_pixels) * 100
 
         # Calculate confidence metrics
-        max_confidence = torch.max(pred_mask_continuous).item()
-        mean_confidence = torch.mean(pred_mask_continuous).item()
+        max_confidence = torch.max(pred_mask).item()
+        mean_confidence = torch.mean(pred_mask).item()
 
         analysis_text = f"""
-## 🧠 Your Attention U-Net Analysis Results
-### 📊 Detection Summary:
-- **Status**: {'🔴 TUMOR DETECTED' if tumor_pixels > 50 else '🟢 NO SIGNIFICANT TUMOR'}
+## Your Attention U-Net Analysis Results
+### Detection Summary:
+- **Status**: {'TUMOR DETECTED' if tumor_pixels > 50 else 'NO SIGNIFICANT TUMOR'}
 - **Tumor Area**: {tumor_percentage:.2f}% of brain region
 - **Tumor Pixels**: {tumor_pixels:,} pixels
 - **Max Confidence**: {max_confidence:.4f}
 - **Mean Confidence**: {mean_confidence:.4f}
-
-### 🔥 New Heatmap Features:
-- **Continuous Predictions**: Shows confidence levels (0-1)
-- **Color Coding**: Red/Yellow = High confidence, Blue/Black = Low confidence
-- **Overlay Visualization**: Heatmap overlaid on original image
-- **Enhanced Analysis**: Better understanding of model uncertainty
-
-### 🔬 Your Model Information:
+"""
+        if dice_score is not None and iou_score is not None:
+            analysis_text += f"""
+- **Dice Score**: {dice_score:.4f}
+- **IoU Score**: {iou_score:.4f}
+"""
+        analysis_text += f"""
+### Model Information:
 - **Architecture**: YOUR trained Attention U-Net
 - **Training Performance**: Dice: 0.8420, IoU: 0.7297
 - **Input**: Grayscale (single channel)
-- **Output**: Binary segmentation mask + Continuous heatmap
+- **Output**: Binary segmentation mask
 - **Device**: {device.type.upper()}
-
-### 🎯 Model Performance:
+### Model Performance:
 - **Training Accuracy**: 98.90%
 - **Best Dice Score**: 0.8420
 - **Best IoU Score**: 0.7297
 - **Training Dataset**: Brain tumor segmentation dataset
-
-### 📈 Processing Details:
+### Processing Details:
 - **Preprocessing**: Resize(256×256) + ToTensor (your exact method)
 - **Threshold**: 0.5 (sigmoid > 0.5)
 - **Architecture**: Attention gates + Skip connections
 - **Features**: [32, 64, 128, 256] channels
-- **Heatmap**: Continuous sigmoid output with hot colormap
-
-### ⚠️ Medical Disclaimer:
+### Medical Disclaimer:
 This is YOUR trained AI model for **research and educational purposes only**.
 Results should be validated by medical professionals. Not for clinical diagnosis.
-
-### 🏆 Model Quality:
-✅ This is your own trained model with proven {tumor_percentage:.2f}% detection capability!
+### Model Quality:
+This is your own trained model with {tumor_percentage:.2f}% detection capability!
 """
 
-        print(f"Your model analysis completed! Tumor area: {tumor_percentage:.2f}%")
+        print(f"Your model analysis completed! Tumor area: {tumor_percentage:.2f}%")
         return result_image, analysis_text
 
     except Exception as e:
-        error_msg = f"Error with your model: {str(e)}"
+        error_msg = f"Error with your model: {str(e)}"
         print(error_msg)
         return None, error_msg
 
+def load_random_sample():
+    if not test_imgs:
+        return None, None, "Dataset not available."
+    rand_idx = random.randint(0, len(test_imgs) - 1)
+    img_path = os.path.join(image_path, test_imgs[rand_idx])
+    msk_path = os.path.join(mask_path, test_masks[rand_idx])  # Assuming paired by index
+    image = Image.open(img_path).convert('L')
+    mask = Image.open(msk_path).convert('L')
+    return image, mask, "Loaded random sample from dataset."
+
 def clear_all():
-    return None, None, "Upload a brain MRI image to test YOUR trained Attention U-Net model with heatmap visualization"
+    return None, None, "Upload a brain MRI image to test YOUR trained Attention U-Net model", None
 
 # Enhanced CSS for your model
 css = """
@@ -324,13 +319,13 @@ css = """
 """
 
 # Create Gradio interface for your model
-with gr.Blocks(css=css, title="🧠 Your Attention U-Net Model with Heatmap", theme=gr.themes.Soft()) as app:
+with gr.Blocks(css=css, title="Your Attention U-Net Model", theme=gr.themes.Soft()) as app:
 
     gr.HTML("""
     <div id="title">
-        <h1>🧠 YOUR Attention U-Net Model with Heatmap</h1>
+        <h1>Your Attention U-Net Model</h1>
         <p style="font-size: 18px; margin-top: 15px;">
-            Using Your Own Trained Model • Dice: 0.8420 • IoU: 0.7297 • Now with Heatmap Visualization
+            Using Your Own Trained Model • Dice: 0.8420 • IoU: 0.7297
         </p>
         <p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
             Loaded from: ArchCoder/the-op-segmenter HuggingFace Space
@@ -338,9 +333,11 @@ with gr.Blocks(css=css, title="🧠 Your Attention U-Net Model with Heatmap", th
     </div>
     """)
 
+    mask_state = gr.State(None)
+
     with gr.Row():
         with gr.Column(scale=1):
-            gr.Markdown("### 📤 Upload Brain MRI")
+            gr.Markdown("### Upload Brain MRI")
 
             image_input = gr.Image(
                 label="Brain MRI Scan",
@@ -350,51 +347,49 @@ with gr.Blocks(css=css, title="🧠 Your Attention U-Net Model with Heatmap", th
             )
 
             with gr.Row():
-                analyze_btn = gr.Button("🔍 Analyze with YOUR Model", variant="primary", scale=2, size="lg")
-                clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)
+                analyze_btn = gr.Button("Analyze with YOUR Model", variant="primary", scale=1, size="lg")
+                random_btn = gr.Button("Load Random Sample", variant="secondary", scale=1, size="lg")
+                clear_btn = gr.Button("Clear", variant="secondary", scale=1)
 
             gr.HTML("""
             <div style="margin-top: 20px; padding: 20px; background: linear-gradient(135deg, #F3E8FF 0%, #EDE9FE 100%); border-radius: 10px; border-left: 4px solid #8B5CF6;">
-                <h4 style="color: #8B5CF6; margin-bottom: 15px;">🏆 Your Model Features:</h4>
+                <h4 style="color: #8B5CF6; margin-bottom: 15px;">Your Model Features:</h4>
                 <ul style="margin: 10px 0; padding-left: 20px; line-height: 1.6;">
                     <li><strong>Personal Model:</strong> Your own trained Attention U-Net</li>
                     <li><strong>Proven Performance:</strong> 84.2% Dice Score, 72.97% IoU</li>
                     <li><strong>Attention Gates:</strong> Advanced feature selection</li>
                     <li><strong>Clean Output:</strong> Binary segmentation masks</li>
-                    <li><strong>NEW: Heatmap:</strong> Continuous confidence visualization</li>
-                    <li><strong>5-Panel View:</strong> Complete analysis with heatmap</li>
+                    <li><strong>5-Panel View:</strong> Complete analysis like your Colab</li>
                 </ul>
             </div>
             """)
 
         with gr.Column(scale=2):
-            gr.Markdown("### 📊 Your Model Results with Heatmap")
+            gr.Markdown("### Your Model Results")
 
             output_image = gr.Image(
-                label="Your Attention U-Net Analysis with Heatmap",
+                label="Your Attention U-Net Analysis",
                 type="pil",
                 height=500
             )
 
             analysis_output = gr.Markdown(
-                value="Upload a brain MRI image to test YOUR trained Attention U-Net model with heatmap visualization.",
+                value="Upload a brain MRI image to test YOUR trained Attention U-Net model.",
                 elem_id="analysis"
             )
-
-    # Footer highlighting your model with heatmap features
+    # Footer highlighting your model
     gr.HTML("""
     <div style="margin-top: 30px; padding: 25px; background-color: #F8FAFC; border-radius: 15px; border: 2px solid #8B5CF6;">
         <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px;">
             <div>
-                <h4 style="color: #8B5CF6; margin-bottom: 15px;">🏆 Your Personal AI Model</h4>
+                <h4 style="color: #8B5CF6; margin-bottom: 15px;">Your Personal AI Model</h4>
                 <p><strong>Architecture:</strong> Attention U-Net with skip connections</p>
                 <p><strong>Performance:</strong> Dice: 0.8420, IoU: 0.7297, Accuracy: 98.90%</p>
                 <p><strong>Training:</strong> Your own dataset-specific training</p>
                 <p><strong>Features:</strong> [32, 64, 128, 256] channel progression</p>
-                <p><strong>NEW:</strong> Continuous heatmap visualization for confidence</p>
             </div>
             <div>
-                <h4 style="color: #DC2626; margin-bottom: 15px;">⚠️ Your Model Disclaimer</h4>
+                <h4 style="color: #DC2626; margin-bottom: 15px;">Your Model Disclaimer</h4>
                 <p style="color: #DC2626; font-weight: 600; line-height: 1.4;">
                     This is YOUR personally trained AI model for <strong>research purposes only</strong>.<br>
                     Results reflect your model's training performance.<br>
@@ -404,7 +399,7 @@ with gr.Blocks(css=css, title="🧠 Your Attention U-Net Model with Heatmap", th
         </div>
         <hr style="margin: 20px 0; border: none; border-top: 2px solid #E5E7EB;">
         <p style="text-align: center; color: #6B7280; margin: 10px 0; font-weight: 600;">
-            🚀 Your Personal Attention U-Net • Downloaded from HuggingFace • Research-Grade Performance • Now with Heatmap! 🔥
+            Your Personal Attention U-Net • Downloaded from HuggingFace • Research-Grade Performance
         </p>
     </div>
     """)
@@ -412,27 +407,32 @@ with gr.Blocks(css=css, title="🧠 Your Attention U-Net Model with Heatmap", th
     # Event handlers
     analyze_btn.click(
         fn=predict_tumor,
-        inputs=[image_input],
+        inputs=[image_input, mask_state],
         outputs=[output_image, analysis_output],
         show_progress=True
     )
 
+    random_btn.click(
+        fn=load_random_sample,
+        inputs=[],
+        outputs=[image_input, mask_state, analysis_output]
+    )
+
     clear_btn.click(
         fn=clear_all,
         inputs=[],
-        outputs=[image_input, output_image, analysis_output]
+        outputs=[image_input, output_image, analysis_output, mask_state]
     )
 
 if __name__ == "__main__":
-    print("🚀 Starting YOUR Attention U-Net Model System with Heatmap...")
-    print("🏆 Using your personally trained model")
-    print("📥 Auto-downloading from HuggingFace...")
-    print("🎯 Expected performance: Dice 0.8420, IoU 0.7297")
-    print("🔥 NEW: Heatmap visualization added!")
+    print("Starting YOUR Attention U-Net Model System...")
+    print("Using your personally trained model")
+    print("Auto-downloading from HuggingFace...")
+    print("Expected performance: Dice 0.8420, IoU 0.7297")
 
     app.launch(
         server_name="0.0.0.0",
         server_port=7860,
         show_error=True,
         share=False
-    )
+    )
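For reference, a minimal standalone sketch of the Dice/IoU scoring this commit adds to predict_tumor, assuming two equal-shape binary NumPy masks and the same 1e-7 smoothing term used in the diff; the helper name dice_iou and the toy arrays are illustrative, not taken from the app:

import numpy as np

def dice_iou(pred_mask_np, mask_np, eps=1e-7):
    """Dice and IoU between two binary masks, mirroring the logic in predict_tumor."""
    intersection = np.logical_and(pred_mask_np, mask_np).sum()
    union = np.logical_or(pred_mask_np, mask_np).sum()
    iou = intersection / (union + eps)
    dice = (2 * intersection) / (pred_mask_np.sum() + mask_np.sum() + eps)
    return dice, iou

# Toy check: prediction covers 2 of 3 ground-truth pixels plus 1 false positive,
# so intersection = 2 and union = 4.
pred = np.array([[1, 1, 0], [1, 0, 0]], dtype=float)
gt   = np.array([[1, 1, 0], [0, 1, 0]], dtype=float)
dice, iou = dice_iou(pred, gt)
print(f"Dice={dice:.4f}, IoU={iou:.4f}")  # Dice=0.6667, IoU=0.5000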