fireedman commited on
Commit
f6b477c
1 Parent(s): 246f5ef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import supervision as sv
6
+ import torch
7
+
8
+ from PIL import Image
9
+ from transformers import pipeline, CLIPProcessor, CLIPModel
10
+
11
+
12
+ #************
13
+ #Variables globales
14
+ MARKDOWN = """
15
+ #SAM
16
+ """
17
+ EXAMPLES = [
18
+ ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
19
+ ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5],
20
+ ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5],
21
+ ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6],
22
+ ]
23
+
24
+ MIN_AREA_THRESHOLD = 0.01
25
+
26
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ SAM_GENERATOR = pipeline(
29
+ task = "mask-generation",
30
+ model = "facebook/sam-vit-large",
31
+ device = DEVICE
32
+ )
33
+
34
+ SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
35
+ color = sv.Color.red(),
36
+ color_lookup = sv.ColorLookup.INDEX
37
+ )
38
+
39
+ SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
40
+ color = sv.Color.white(),
41
+ color_lookup = sv.ColorLookup.INDEX,
42
+ opacity = 1
43
+ )
44
+
45
+
46
+ #************
47
+ #funciones de trabajo
48
+
49
+ def run_sam(image_rgb_pil : Image.Image ) -> sv.Detections:
50
+ outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch = 32)
51
+ mask = np.array(outputs['masks'])
52
+ return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
53
+
54
+
55
+ def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
56
+ gray_color = np.array([
57
+ gray_value,
58
+ gray_value,
59
+ gray_value
60
+ ], dtype=np.uint8)
61
+ return np.where(mask[..., None], image, gray_color)
62
+
63
+
64
+ def filter_detections(image_rgb_pil: Image.Image, detections: sv.Detections) -> sv.Detections:
65
+ img_rgb_numpy = np.array(image_rgb_pil)
66
+ filtering_mask = []
67
+ for xyxy, mask in zip(detections.xyxy, detections.mask):
68
+ crop = sv.crop_image(
69
+ image = img_rgb_numpy,
70
+ xyxy =xyxy
71
+ )
72
+ mask_crop = sv.crop_image(
73
+ image=mask,
74
+ xyxy=xyxy
75
+ )
76
+ masked_crop = reverse_mask_image(
77
+ image=crop,
78
+ mask=mask_crop
79
+ )
80
+
81
+ filtering_mask = np.array(
82
+ filtering_mask
83
+ )
84
+ return detections[filtering_mask]
85
+
86
+
87
+ def inference (image_rgb_pil: Image.Image) -> List[Image.Image]:
88
+ width, height = image_rgb_pil.size
89
+ area = width * height
90
+
91
+ detections = run_sam(
92
+ image_rgb_pil
93
+ )
94
+ detections = detections[ detections.area /area > MIN_AREA_THRESHOLD ]
95
+ detections = filter_detections(
96
+ image_rgb_pil=image_rgb_pil,
97
+ detections=detections,
98
+ )
99
+ blank_image = Image.new("RGB", (width, height), "black")
100
+ return [
101
+ annotate(
102
+ image_rgb_pil=image_rgb_pil,
103
+ detections=detections,
104
+ annotator=SEMITRANSPARENT_MASK_ANNOTATOR),
105
+ annotate(
106
+ image_rgb_pil=blank_image,
107
+ detections=detections,
108
+ annotator=SOLID_MASK_ANNOTATOR)
109
+ ]
110
+
111
+
112
+ #************
113
+ #GRADIO CONSTRUCTION
114
+ with gr.Blocks() as demo:
115
+ gr.Markdown(MARKDOWN)
116
+ with gr.Row():
117
+ with gr.Column():
118
+ input_image = gr,Image(
119
+ image_mode = 'RGB',
120
+ type = 'pil',
121
+ height = 500
122
+ )
123
+ submit_button = gr.Button("Pruébalo!!!")
124
+ gallery = gr.Gallery(
125
+ label = "Result",
126
+ object_fit = "scale-down",
127
+ preview = True
128
+ )
129
+ with gr.Row():
130
+ gr.Examples(
131
+ examples = EXAMPLES,
132
+ fn = inference,
133
+ inputs = [
134
+ input_image,
135
+ prompt_text,
136
+ confidence_slider
137
+ ],
138
+ outputs = [gallery],
139
+ cache_examples = True,
140
+ run_on_click = True
141
+ )
142
+ submit_button.click(
143
+ inference,
144
+ inputs = [
145
+ input_image,
146
+ prompt_text,
147
+ confidence_slider
148
+ ],
149
+ outputs = gallery
150
+ )
151
+
152
+ demo.launch( debug = True, show_error = True )