fireedman committed
Commit 605391e
1 Parent(s): 3b39624

Update app.py

Files changed (1)
  1. app.py +48 -103
app.py CHANGED
@@ -1,18 +1,14 @@
 from typing import List
-
 import gradio as gr
 import numpy as np
 import supervision as sv
 import torch
-
 from PIL import Image
-from transformers import pipeline, CLIPProcessor, CLIPModel
-
+from transformers import pipeline
 
-#************
-#Variables globales
+# Global Variables
 MARKDOWN = """
-#SAM
+# SAM - Softly Activated Masks
 """
 EXAMPLES = [
     ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
@@ -22,130 +18,79 @@ EXAMPLES = [
 ]
 
 MIN_AREA_THRESHOLD = 0.01
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Initialize SAM Generator with exception handling
+try:
+    SAM_GENERATOR = pipeline(
+        task="mask-generation",
+        model="facebook/sam-vit-large",
+        device=DEVICE
+    )
+except Exception as e:
+    print(f"Error initializing SAM generator: {e}")
 
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-SAM_GENERATOR = pipeline(
-    task = "mask-generation",
-    model = "facebook/sam-vit-large",
-    device = DEVICE
-)
-
+# Mask Annotators
 SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
-    color = sv.Color.red(),
-    color_lookup = sv.ColorLookup.INDEX
+    color=sv.Color.red(),
+    color_lookup=sv.ColorLookup.INDEX
 )
 
 SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
-    color = sv.Color.white(),
-    color_lookup = sv.ColorLookup.INDEX,
-    opacity = 1
+    color=sv.Color.white(),
+    color_lookup=sv.ColorLookup.INDEX,
+    opacity=1
 )
 
-
-#************
-#funciones de trabajo
-
-def run_sam(image_rgb_pil : Image.Image ) -> sv.Detections:
-    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch = 32)
-    mask = np.array(outputs['masks'])
-    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
-
+# Functions
+def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
+    try:
+        outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
+        mask = np.array(outputs['masks'], dtype=np.uint8)
+        return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
+    except Exception as e:
+        print(f"Error running SAM model: {e}")
+        return sv.Detections(xyxy=[], mask=[])
 
 def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
-    gray_color = np.array([
-        gray_value,
-        gray_value,
-        gray_value
-    ], dtype=np.uint8)
+    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
     return np.where(mask[..., None], image, gray_color)
 
-
-"""
-def filter_detections(image_rgb_pil: Image.Image, detections: sv.Detections) -> sv.Detections:
-    img_rgb_numpy = np.array(image_rgb_pil)
-    filtering_mask = []
-    for xyxy, mask in zip(detections.xyxy, detections.mask):
-        crop = sv.crop_image(
-            image = img_rgb_numpy,
-            xyxy =xyxy
-        )
-        mask_crop = sv.crop_image(
-            image=mask,
-            xyxy=xyxy
-        )
-        masked_crop = reverse_mask_image(
-            image=crop,
-            mask=mask_crop
-        )
-
-    filtering_mask = np.array(
-        filtering_mask
-    )
-    return detections[filtering_mask]
-"""
-
-def inference (image_rgb_pil: Image.Image) -> List[Image.Image]:
+def inference(image_rgb_pil: Image.Image) -> List[Image.Image]:
     width, height = image_rgb_pil.size
     area = width * height
 
-    detections = run_sam(
-        image_rgb_pil
-    )
-    detections = detections[ detections.area /area > MIN_AREA_THRESHOLD ]
-
-    #detections = filter_detections(
-    #    image_rgb_pil=image_rgb_pil,
-    #    detections=detections,
-    #)
-
+    detections = run_sam(image_rgb_pil)
+    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]
+
     blank_image = Image.new("RGB", (width, height), "black")
     return [
-        annotate(
-            image_rgb_pil=image_rgb_pil,
-            detections=detections,
-            annotator=SEMITRANSPARENT_MASK_ANNOTATOR),
-        annotate(
-            image_rgb_pil=blank_image,
-            detections=detections,
-            annotator=SOLID_MASK_ANNOTATOR)
+        SEMITRANSPARENT_MASK_ANNOTATOR.annotate(image_rgb_pil, detections),
+        SOLID_MASK_ANNOTATOR.annotate(blank_image, detections)
     ]
-
-
 #************
 #GRADIO CONSTRUCTION
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
     with gr.Row():
         with gr.Column():
-            input_image = gr.Image(
-                image_mode = 'RGB',
-                type = 'pil',
-                height = 500
-            )
+            input_image = gr.Image(image_mode='RGB', type='pil', height=500)
             submit_button = gr.Button("Pruébalo!!!")
-            gallery = gr.Gallery(
-                label = "Result",
-                object_fit = "scale-down",
-                preview = True
-            )
+            gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)
+
     with gr.Row():
         gr.Examples(
-            examples = EXAMPLES,
-            fn = inference,
-            inputs = [
-                input_image
-            ],
-            outputs = [gallery],
-            cache_examples = False,
-            run_on_click = True
+            examples=EXAMPLES,
+            fn=inference,
+            inputs=[input_image],
+            outputs=[gallery],
+            cache_examples=False,
+            run_on_click=True
        )
     submit_button.click(
        inference,
-       inputs = [
-           input_image
-       ],
-       outputs = gallery
+       inputs=[input_image],
+       outputs=gallery
    )
 
-demo.launch( debug = True, show_error = True )
+demo.launch(debug=False, show_error=True)
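A quick, hedged sketch of the run_sam / inference flow this commit lands, for readers of the diff: the transformers "mask-generation" pipeline returns a dict whose "masks" entry is a list of per-segment H x W boolean masks, which supervision turns into filterable detections. The variable names and the local "dog.jpeg" path below are illustrative, not part of the committed code.

# Hedged sketch of the run_sam -> inference flow in this commit; assumes the
# transformers mask-generation pipeline and the supervision APIs used above.
import numpy as np
import supervision as sv
from PIL import Image
from transformers import pipeline

generator = pipeline(task="mask-generation", model="facebook/sam-vit-large")

image = Image.open("dog.jpeg")  # hypothetical local copy of the EXAMPLES image
outputs = generator(image, points_per_batch=32)

# outputs["masks"] is a list of per-segment boolean masks, one H x W array each
masks = np.array(outputs["masks"])
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)

# Detections.area is the per-mask pixel count when masks are present, so
# dividing by the image area keeps only segments above the size threshold
area = image.width * image.height
detections = detections[detections.area / area > 0.01]  # MIN_AREA_THRESHOLD

Two caveats, stated tentatively: if the pipeline fails to load, the commit's try/except leaves SAM_GENERATOR unbound, so it is run_sam's own except block that keeps the app alive by catching the resulting NameError; and sv.Detections(xyxy=[], mask=[]) passes plain lists where supervision expects numpy arrays, so if the installed supervision version ships it, sv.Detections.empty() is the library's own zero-length constructor.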
 
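The unchanged reverse_mask_image helper relies on one numpy broadcasting trick worth spelling out: mask[..., None] lifts an H x W boolean mask to H x W x 1 so np.where can choose per pixel between the image and a flat gray. A self-contained toy demo, with values invented for illustration:

import numpy as np

image = np.arange(2 * 2 * 3, dtype=np.uint8).reshape(2, 2, 3)  # toy 2x2 RGB image
mask = np.array([[True, False],
                 [False, True]])                               # keep two pixels
gray = np.array([128, 128, 128], dtype=np.uint8)

# mask[..., None] has shape (2, 2, 1) and broadcasts against the (2, 2, 3) image
out = np.where(mask[..., None], image, gray)

# Masked-in pixels keep their RGB values; everything else turns mid-gray
assert (out[0, 0] == image[0, 0]).all() and (out[0, 1] == gray).all()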