taher30 committed
Commit 9a86e9d
Parent(s): 975ae87

Update app.py

Files changed (1):
  app.py +24 -23
app.py CHANGED
@@ -1,40 +1,41 @@
-import gradio as gr
 import cv2
 import torch
 import numpy as np
 from PIL import Image
 from torchvision import transforms
 from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+import matplotlib.pyplot as plt
 # import segmentation_models_pytorch as smp
 
-def load_model(model_type):
-    # Model loading simplified for clarity
-    model = sam_model_registry[model_type](checkpoint=f"sam_{model_type}_checkpoint.pth")
-    model.to(device='cuda')
-    return SamAutomaticMaskGenerator(model)
-
-def segment_and_classify(image, model_type):
-    model = load_model(model_type)
-    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-
-    # Generate masks
-    masks = model.generate(image_cv)
-
-    # Prepare to store segments
-    segments = []
-
-    # Loop through masks and extract segments
-    for mask_data in masks:
+
+# image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
+def get_masks(model_type, image):
+    if model_type == 'vit_h':
+        sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
+
+        masks_h = mask_generator_h.generate(image)
+    if model_type == 'vit_b':
+        sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
+
+    if model_type == 'vit_l':
+        sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
+
+    mask_generator = SamAutomaticMaskGenerator(sam)
+    masks = mask_generator.generate(image)
+    for i, mask_data in enumerate(masks):
         mask = mask_data['segmentation']
-        segment = image_cv * np.tile(mask[:, :, None], [1, 1, 3])  # Apply mask to the image
-        segments.append(segment)  # Store the segment for classification
-
-    # Here you would call the classification model (e.g., CLIP)
-    # For now, let's just return the first segment for visualization
-    return Image.fromarray(segments[0])
+        color = colors[i]
+        composite_image[mask] = (color[:3] * 255).astype(np.uint8)  # Apply color to mask
 
+    # Combine original image with the composite mask image
+    overlayed_image = (composite_image * 0.5 + image_cv.squeeze().permute(1, 2, 0).cpu().numpy() * 0.5).astype(np.uint8)
+    return overlayed_image
+
+
 iface = gr.Interface(
-    fn=segment_and_classify,
+    fn=get_masks,
     inputs=[gr.inputs.Image(type="pil"), gr.inputs.Dropdown(['vit_h', 'vit_b', 'vit_l'], label="Model Type")],
     outputs=gr.outputs.Image(type="pil"),
     title="SAM Model Segmentation and Classification",