NikhilJoson committed on
Commit 5be4a85
1 Parent(s): e6654e6

Update app.py

Files changed (1)
  1. app.py +233 -22
app.py CHANGED
@@ -2,14 +2,14 @@
 import os
 import random
 import numpy as np
-
-import torch
+import cv2
 import spaces
 import gradio as gr
 
-from diffusers import FluxTransformer2DModel, FluxInpaintPipeline
+import torch
 import google.generativeai as genai
-
+from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
+from diffusers import FluxTransformer2DModel, FluxInpaintPipeline
 
 
 MARKDOWN = """
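The helper code added further down references several names that no hunk in this diff imports (`dataclass`, the `typing` generics, `PIL.Image`, and a `load_image` helper). A minimal sketch of the imports it appears to assume:

```python
# Assumed imports for the new helpers below; not visible in this diff.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

from PIL import Image
from diffusers.utils import load_image  # one plausible source for load_image
```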
@@ -18,6 +18,7 @@ Thanks to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for
 and a big thanks to [Gothos](https://github.com/Gothos) for taking it to the next level by enabling inpainting with the FLUX.
 """
 
+
 #Gemini Setup
 genai.configure(api_key = os.environ['Gemini_API'])
 gemini_flash = genai.GenerativeModel(model_name='gemini-1.5-flash-002')
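`gemini_predict` sits between these hunks; it asks Gemini Flash to name the object the user wants replaced, then trims the reply. A standalone sketch of that pattern, with a hypothetical system message since only the `Query : {prompt}` line of the real prompt is visible in the diff:

```python
def extract_target_object(prompt: str) -> str:
    # Hypothetical reconstruction of gemini_predict's prompt.
    system_message = f"""Identify the single object the user wants replaced in the image.
Reply with only the object's name.
Query : {prompt}"""
    response = gemini_flash.generate_content(system_message)
    return str(response.text).strip()
```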
@@ -43,17 +44,196 @@ def gemini_predict(prompt):
     Query : {prompt}
     """
     response = gemini_flash.generate_content(system_message)
-    return(str(response.text)[:-2])
+    return(str(response.text)[:-1])
+
 
 
 MAX_SEED = np.iinfo(np.int32).max
 DEVICE = "cuda" #if torch.cuda.is_available() else "cpu"
 
-#Setting up Flux (Schnell) Inpainting
-#inpaint_pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
 
-transformer_SchnellReal = FluxTransformer2DModel.from_single_file("https://huggingface.co/SG161222/RealFlux_1.0b_Schnell/blob/main/4%20-%20Schnell%20Transformer%20Version/RealFlux_1.0b_Schnell_Transformer.safetensors", torch_dtype=torch.bfloat16)
-inpaint_pipe = FluxInpaintPipeline.from_pretrained(bfl_repo, transformer=transformer_SchnellReal, torch_dtype=dtype).to(DEVICE)
+###GroundingDINO & SAM Setup
+
+#To store DINO results
+@dataclass
+class BoundingBox:
+    xmin: int
+    ymin: int
+    xmax: int
+    ymax: int
+
+    @property
+    def xyxy(self) -> List[float]:
+        return [self.xmin, self.ymin, self.xmax, self.ymax]
+
+@dataclass
+class DetectionResult:
+    score: float
+    label: str
+    box: BoundingBox
+    mask: Optional[np.array] = None
+
+    @classmethod
+    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
+        return cls(score=detection_dict['score'],
+                   label=detection_dict['label'],
+                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
+                                   ymin=detection_dict['box']['ymin'],
+                                   xmax=detection_dict['box']['xmax'],
+                                   ymax=detection_dict['box']['ymax']))
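`DetectionResult.from_dict` adapts the plain dicts returned by the `zero-shot-object-detection` pipeline. A quick illustration with invented values:

```python
raw = {"score": 0.91, "label": "a sofa.",
       "box": {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 220}}
det = DetectionResult.from_dict(raw)
assert det.box.xyxy == [10, 20, 110, 220]
assert det.mask is None  # filled in later by segment_SAM
```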
+
+#Utility Functions for Mask Generation
+def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
+    # Find contours in the binary mask
+    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    # Find the contour with the largest area
+    largest_contour = max(contours, key=cv2.contourArea)
+
+    # Extract the vertices of the contour
+    polygon = largest_contour.reshape(-1, 2).tolist()
+
+    return polygon
+
+def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
+    """
+    Convert a polygon to a segmentation mask.
+
+    Args:
+    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
+    - image_shape (tuple): Shape of the image (height, width) for the mask.
+
+    Returns:
+    - np.ndarray: Segmentation mask with the polygon filled.
+    """
+    # Create an empty mask
+    mask = np.zeros(image_shape, dtype=np.uint8)
+
+    # Convert polygon to an array of points
+    pts = np.array(polygon, dtype=np.int32)
+
+    # Fill the polygon with white color (255)
+    cv2.fillPoly(mask, [pts], color=(255,))
+
+    return mask
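`mask_to_polygon` and `polygon_to_mask` are near-inverses, up to keeping only the largest contour. A round-trip check on a synthetic mask:

```python
import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
cv2.rectangle(mask, (8, 8), (40, 40), color=255, thickness=-1)  # filled square

poly = mask_to_polygon(mask)                 # vertices of the largest contour
rebuilt = polygon_to_mask(poly, mask.shape)  # filled back in with 255s
# rebuilt matches the original square, boundary pixels aside
```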
+
+def get_boxes(results: DetectionResult) -> List[List[List[float]]]:
+    boxes = []
+    for result in results:
+        xyxy = result.box.xyxy
+        boxes.append(xyxy)
+
+    return [boxes]
+
+def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
+    masks = masks.cpu().float()
+    masks = masks.permute(0, 2, 3, 1)
+    masks = masks.mean(axis=-1)
+    masks = (masks > 0).int()
+    masks = masks.numpy().astype(np.uint8)
+    masks = list(masks)
+
+    #print(masks)
+
+    if polygon_refinement:
+        for idx, mask in enumerate(masks):
+            shape = mask.shape
+            polygon = mask_to_polygon(mask)
+            mask = polygon_to_mask(polygon, shape)
+            masks[idx] = mask
+
+    return masks
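`refine_masks` collapses SAM's per-mask proposal dimension by averaging, thresholds at zero, and optionally snaps each mask to its largest polygon. Note the output convention differs: refined masks come back 0/255 (via `polygon_to_mask`), unrefined ones 0/1. A shape-level sketch with dummy tensors:

```python
import torch

dummy = torch.zeros((2, 3, 64, 64), dtype=torch.bool)  # 2 detections, 3 proposals each
dummy[0, :, 10:30, 10:30] = True
dummy[1, :, 32:48, 20:40] = True

refined = refine_masks(dummy, polygon_refinement=True)
print(len(refined), refined[0].shape, int(refined[0].max()))  # 2 (64, 64) 255
```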
+
+def get_alphacomp_mask(mask, image, random_color=True):
+    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
+    #mask_image_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")
+    mask_image_pil = Image.fromarray(mask).convert("RGBA")
+
+    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))
+
+
+# Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
+detector_id = "IDEA-Research/grounding-dino-tiny"
+object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=SAM_device)
+
+#Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
+segmenter_id = "facebook/sam-vit-base"
+processor = AutoProcessor.from_pretrained(segmenter_id)
+segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(SAM_device)
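`SAM_device` is used on both model-setup lines but is not defined in any hunk of this diff; presumably it is set elsewhere in the file, along the lines of:

```python
# Assumed definition; not visible in this diff.
SAM_device = "cuda" if torch.cuda.is_available() else "cpu"
```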
+
+def detect(image: Image.Image, labels: List[str], threshold: float = 0.3) -> List[Dict[str, Any]]:
+    labels = [label if label.endswith(".") else label+"." for label in labels]
+
+    with torch.no_grad():
+        results = object_detector(image, candidate_labels=labels, threshold=threshold)
+        torch.cuda.empty_cache()
+
+    results = [DetectionResult.from_dict(result) for result in results]
+    #print("DINO results:", results)
+    return results
+
+def segment_SAM(image: Image.Image, detection_results: List[Dict[str, Any]], polygon_refinement: bool = False) -> List[DetectionResult]:
+    boxes = get_boxes(detection_results)
+    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(SAM_device)
+
+    with torch.no_grad():
+        outputs = segmentator(**inputs)
+        torch.cuda.empty_cache()
+
+    masks = processor.post_process_masks(masks=outputs.pred_masks, original_sizes=inputs.original_sizes,
+                                         reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
+
+    #print("Masks:", masks)
+    masks = refine_masks(masks, polygon_refinement)
+
+    for detection_result, mask in zip(detection_results, masks):
+        detection_result.mask = mask
+
+    return detection_results
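Grounding DINO expects period-terminated label phrases, hence the `label+"."` normalization in `detect`. Chained by hand (which is exactly what `grounded_segmentation` below wraps), usage looks like this, with an invented image path:

```python
from PIL import Image

image = Image.open("room.jpg").convert("RGB")        # any RGB test image
detections = detect(image, ["sofa"], threshold=0.3)  # boxes + scores from Grounding DINO
detections = segment_SAM(image, detections, polygon_refinement=True)
print([(d.label, round(d.score, 2), d.mask.shape) for d in detections])
```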
+
+def grounded_segmentation(image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3,
+                          polygon_refinement: bool = False) -> Tuple[np.ndarray, List[DetectionResult]]:
+    if isinstance(image, str):
+        image = load_image(image)
+
+    detections = detect(image, labels, threshold)
+    segmented = segment_SAM(image, detections, polygon_refinement)
+
+    return np.array(image), segmented
+
+def get_finalmask(image_array, detections):
+    for i,d in enumerate(detections):
+        mask_ = d.__getattribute__('mask')
+        if i==0:
+            image_with_mask = get_alphacomp_mask(mask_, image_array)
+        else:
+            image_with_mask += get_alphacomp_mask(mask_, image_array)
+
+    return image_with_mask
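One caution on `get_finalmask`: it accumulates the per-detection RGBA overlays with `+=` on `uint8` arrays, which wraps around where overlays overlap (255 + 255 → 254). A hedged alternative that unions the overlays instead:

```python
import numpy as np

def get_finalmask_max(image_array, detections):
    # Same interface as get_finalmask, but np.maximum avoids uint8 wrap-around.
    image_with_mask = None
    for d in detections:
        overlay = get_alphacomp_mask(d.mask, image_array)
        image_with_mask = overlay if image_with_mask is None else np.maximum(image_with_mask, overlay)
    return image_with_mask
```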
+
+#Preprocessing Mask
+kernel = np.ones((3, 3), np.uint8)  # Taking a matrix of size 3 as the kernel
+def preprocess_mask(pipe, inp_mask, expan_lvl, blur_lvl):
+    if expan_lvl>0:
+        inp_mask = Image.fromarray(cv2.dilate(np.array(inp_mask), kernel, iterations=expan_lvl))
+
+    if blur_lvl>0:
+        inp_mask = pipe.mask_processor.blur(inp_mask, blur_factor=blur_lvl)
+
+    # inp_mask = Image.fromarray(np.array(inp_mask))
+    return inp_mask
+
+
+def generate_mask(inp_image, label, threshold):
+    image_array, segments = grounded_segmentation(image=inp_image, labels=label, threshold=threshold, polygon_refinement=True)
+    inp_mask = get_finalmask(image_array, segments)
+    # print(type(inp_mask))
+    return inp_mask
+
+
+#Setting up Flux (Schnell) Inpainting
+inpaint_pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
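Mask expansion dilates the region so inpainting covers the object's edges; blurring then feathers the seam. The same two steps in isolation, assuming a white-on-black PIL mask (path invented) and diffusers' `VaeImageProcessor.blur` helper:

```python
import cv2
import numpy as np
from PIL import Image

raw_mask = Image.open("mask.png").convert("L")                   # any white-on-black mask
kernel = np.ones((3, 3), np.uint8)
expanded = cv2.dilate(np.array(raw_mask), kernel, iterations=2)  # grow ~1 px per iteration
soft = inpaint_pipe.mask_processor.blur(Image.fromarray(expanded), blur_factor=1)
```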
 
 #Uncomment the following 4 lines, if you want LoRA Realism weights added to the pipeline
 # inpaint_pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
@@ -64,9 +244,10 @@ inpaint_pipe = FluxInpaintPipeline.from_pretrained(bfl_repo, transformer=transfo
 #torch.cuda.empty_cache()
 
 @spaces.GPU()
-def process(input_image_editor, mask_image, input_text, strength, seed, randomize_seed, num_inference_steps, guidance_scale=3.5, progress=gr.Progress(track_tqdm=True)):
+def process(input_image_editor, input_text, strength, seed, randomize_seed, num_inference_steps, guidance_scale, threshold, expan_lvl, blur_lvl, progress=gr.Progress(track_tqdm=True)):
     if not input_text:
         raise gr.Error("Please enter a text prompt.")
+    #Object identification
     item = gemini_predict(input_text)
     #print(item)
 
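The unchanged lines between this hunk and the next (old 73-76 / new 254-257) are where `image`, `width`, and `height` must be produced from the editor input. With a `gr.ImageEditor` component, that plausibly looks like:

```python
# Presumed shape of the elided lines; gr.ImageEditor (type="pil") yields a dict
# with "background", "layers", and "composite" entries.
image = input_image_editor["background"].convert("RGB")
width, height = image.size
```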
@@ -77,15 +258,23 @@ def process(input_image_editor, mask_image, input_text, strength, seed, randomiz
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator(device=DEVICE).manual_seed(seed)
 
-    result = inpaint_pipe(prompt=input_text, image=image, mask_image=mask_image, width=width, height=height,
+
+    #Generating Mask
+    label = [item]
+    gen_mask = generate_mask(image, label, threshold)
+    #Pre-processing Mask, optional
+    if expan_lvl>0 or blur_lvl>0:
+        gen_mask = preprocess_mask(inpaint_pipe, gen_mask, expan_lvl, blur_lvl)
+
+    #Inpainting
+    generator = torch.Generator(device=DEVICE).manual_seed(seed)
+    result = inpaint_pipe(prompt=input_text, image=image, mask_image=gen_mask, width=width, height=height,
                           strength=strength, num_inference_steps=num_inference_steps, generator=generator,
                           guidance_scale=guidance_scale).images[0]
 
 
-    return result, mask_image, seed, item
+    return result, gen_mask, seed, item
 
 with gr.Blocks(theme=gr.themes.Ocean()) as demo:
     gr.Markdown(MARKDOWN)
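As in other diffusers inpainting pipelines, `strength` truncates the denoising schedule rather than acting as an independent knob; with the new defaults, the usual `get_timesteps` arithmetic works out to:

```python
num_inference_steps, strength = 32, 0.8
init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 25.6
t_start = int(max(num_inference_steps - init_timestep, 0))                # 6
print(num_inference_steps - t_start)  # 26 denoising steps actually run
```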
@@ -109,14 +298,14 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
                 strength_slider = gr.Slider(
                     minimum=0.0,
                     maximum=1.0,
-                    value=0.7,
+                    value=0.8,
                     step=0.01,
                     label="Strength"
                 )
                 num_inference_steps = gr.Slider(
                     minimum=1,
                     maximum=100,
-                    value=30,
+                    value=32,
                     step=1,
                     label="Number of inference steps"
                 )
@@ -125,16 +314,38 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
                     minimum=1,
                     maximum=15,
                     step=0.1,
-                    value=3.5,
+                    value=5,
                 )
                 seed_number = gr.Number(
                     label="Seed",
-                    value=42,
+                    value=26,
                     precision=0
                 )
-                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                with gr.Accordion("Upload a mask", open=False):
-                    uploaded_mask_component = gr.Image(label="Already made mask (black pixels will be preserved, white pixels will be redrawn)", sources=["upload"], type="pil")
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+                with gr.Accordion("Mask Settings", open=False):
+                    SAM_threshold = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=0.4,
+                        step=0.01,
+                        label="Threshold"
+                    )
+                    expansion_level = gr.Slider(
+                        minimum=0,
+                        maximum=5,
+                        value=2,
+                        step=1,
+                        label="Mask Expansion level"
+                    )
+                    blur_level = gr.Slider(
+                        minimum=0,
+                        maximum=5,
+                        step=1,
+                        value=1,
+                        label="Mask Blur level"
+                    )
+                # with gr.Accordion("Upload a mask", open=False):
+                #     uploaded_mask_component = gr.Image(label="Already made mask (black pixels will be preserved, white pixels will be redrawn)", sources=["upload"], type="pil")
                 submit_button_component = gr.Button(value='Inpaint', variant='primary')
             with gr.Column(scale=1):
                 output_image_component = gr.Image(type='pil', image_mode='RGB', label='Generated Image')
@@ -145,7 +356,7 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
 
     submit_button_component.click(
         fn=process,
-        inputs=[input_image_component, uploaded_mask_component, input_text_component, strength_slider, seed_number, randomize_seed, num_inference_steps, guidance_scale],
+        inputs=[input_image_component, input_text_component, strength_slider, seed_number, randomize_seed, num_inference_steps, guidance_scale, SAM_threshold, expansion_level, blur_level],
        outputs=[output_image_component, output_mask_component, output_seed, identified_item]
    )
 
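Taken together, the commit turns the Space from upload-your-own-mask inpainting into prompt-driven object replacement. An end-to-end sketch of the new flow outside Gradio, with an invented path and prompt, and settings mirroring the new UI defaults:

```python
import torch
from PIL import Image

prompt = "Replace the sofa with a leather armchair"
image = Image.open("living_room.jpg").convert("RGB")

item = gemini_predict(prompt)                        # e.g. "sofa"
mask = generate_mask(image, [item], threshold=0.4)   # Grounding DINO -> SAM -> union
mask = preprocess_mask(inpaint_pipe, mask, expan_lvl=2, blur_lvl=1)

generator = torch.Generator(device=DEVICE).manual_seed(26)
result = inpaint_pipe(prompt=prompt, image=image, mask_image=mask,
                      width=image.width, height=image.height,
                      strength=0.8, num_inference_steps=32,
                      guidance_scale=5, generator=generator).images[0]
result.save("output.png")
```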
362