Alexander McKinney committed
Commit b4542eb
1 Parent(s): 04bf3ab

interface example


Need to change to Blocks so we can compute segmentation once and diffusion once; only the repeated components run on CPU.
Unsure how to resolve the onclick canvas; need to check what the canvas component can do (see the sketches below and after the diff).
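
A minimal sketch of the Blocks port described above. Hedged: `fn_segmentation` and `fn_diffusion` are hypothetical splits of the single `fn_segmentation_diffusion` in this commit, and the Gradio `Blocks`/`State` API is assumed, not part of the commit itself.

import gradio as gr

# Hypothetical split of fn_segmentation_diffusion so segmentation runs once
# per image and diffusion reuses the cached result held in gr.State.
def fn_segmentation(image, mask_indices):
    ...  # placeholder: panoptic segmentation pass, returns (masked_image, cached_result)

def fn_diffusion(prompt, cached_result):
    ...  # placeholder: inpainting pass over the cached mask

with gr.Blocks() as demo:
    seg_state = gr.State()  # caches the segmentation output between events
    with gr.Row():
        image_in = gr.Image(type='pil')
        masked_out = gr.Image()
        inpainted_out = gr.Image()
    prompt = gr.Textbox(label='Prompt')
    mask_indices = gr.Textbox(label='Mask indices (comma-separated)')
    segment_btn = gr.Button('Segment')
    inpaint_btn = gr.Button('Inpaint')

    # segmentation output flows into state; diffusion reads it back
    segment_btn.click(fn_segmentation, [image_in, mask_indices], [masked_out, seg_state])
    inpaint_btn.click(fn_diffusion, [prompt, seg_state], inpainted_out)

demo.launch()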

Files changed (1)
  1. app.py +107 -48
app.py CHANGED
@@ -12,6 +12,18 @@ from transformers.models.detr.feature_extraction_detr import rgb_to_id
 
 from diffusers import StableDiffusionInpaintPipeline
 
+# TODO: maybe need to port to `Blocks` system
+# allegedly provides:
+# Have multi-step interfaces, in which the output of one model becomes the
+# input to the next model, or have more flexible data flows in general.
+
+# and:
+# Change a component's properties (for example, the choices in a dropdown) or its visibility based on user input
+# https://huggingface.co/course/chapter9/7?fw=pt
+
+# NOTE: bare torch.inference_mode()/torch.no_grad() calls are no-ops (the context managers are discarded)
+torch.set_grad_enabled(False)
+
 def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
     feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
     model = DetrForSegmentation.from_pretrained(model_name)
@@ -29,9 +41,6 @@ def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpaint
 def get_device(try_cuda=True):
     return torch.device('cuda' if try_cuda and torch.cuda.is_available() else 'cpu')
 
-def greet(name):
-    return "Hello " + name + "!"
-
 def min_pool(x: torch.Tensor, kernel_size: int):
     pad_size = (kernel_size - 1) // 2
     return -torch.nn.functional.max_pool2d(-x, kernel_size, (1, 1), padding=pad_size)
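
An aside on the helper above: `min_pool` computes a windowed minimum via the identity min(x) = -max(-x), i.e. a morphological erosion on a binary mask, relying on `max_pool2d` padding with negative infinity; presumably `clean_mask` pairs it with a plain max-pool to open/close the mask. A standalone check of the identity (assumes only `torch`):

import torch

x = torch.arange(9, dtype=torch.float32).reshape(1, 1, 3, 3)
# negate, max-pool, negate back: each output entry is the minimum over its
# 3x3 neighbourhood (stride 1, padding 1, -inf padding never wins the max)
min_pooled = -torch.nn.functional.max_pool2d(-x, 3, (1, 1), padding=1)
print(min_pooled.squeeze())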
@@ -47,55 +56,105 @@ def clean_mask(mask, min_kernel: int = 5, max_kernel: int = 23):
     mask = mask.bool().squeeze().numpy()
     return mask
 
-# iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-# iface.launch()
 device = get_device()
 
 feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
-model = segmentation_model.to(device)
-
-url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-image = Image.open(requests.get(url, stream=True).raw)
-
-# prepare image for the model
-inputs = feature_extractor(images=image, return_tensors="pt").to(device)
-
-# forward pass
-outputs = segmentation_model(**inputs)
-
-processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
-result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
-
-panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize((image.width, image.height))
-panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
-
-panoptic_seg_id = rgb_to_id(panoptic_seg)
-
-print(result['segments_info'])
-
-# cat_mask = (panoptic_seg_id == 1) | (panoptic_seg_id == 5)
-cat_mask = (panoptic_seg_id == 5)
-cat_mask = clean_mask(cat_mask)
-
-masked_image = np.array(image).copy()
-masked_image[cat_mask] = 0
-
-masked_image = Image.fromarray(masked_image)
-masked_image.save('masked_cat.png')
+# segmentation_model = segmentation_model.to(device)
 
 pipe = load_diffusion_pipeline()
 pipe = pipe.to(device)
 
-print(cat_mask)
-
-resize_ratio = 512 / 480
-new_width = int(640 * resize_ratio)
-new_width += 8 - (new_width % 8)
-
-print(new_width)
-cat_mask = Image.fromarray(cat_mask.astype(np.uint8) * 255).convert("RGB").resize((new_width, 512))
-masked_image = masked_image.resize((new_width, 512))
-
-prompt = "Two cats on the sofa together."
-inpainted_image = pipe(height=512, width=new_width, prompt=prompt, image=masked_image, mask_image=cat_mask).images[0]
-inpainted_image.save('inpaint_cat.png')
+# TODO: potentially use `gr.Gallery` to display different masks
+def fn_segmentation_diffusion(prompt, mask_indices, image, max_kernel, min_kernel, num_diffusion_steps):
+    # skip empty entries; int('') would raise ValueError
+    mask_indices = [int(i) for i in mask_indices.split(',') if i.strip()]
+    inputs = feature_extractor(images=image, return_tensors="pt")
+    outputs = segmentation_model(**inputs)
+
+    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
+    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
+
+    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize((image.width, image.height))
+    panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
+
+    class_str = '\n'.join(segmentation_cfg.id2label[s['category_id']] for s in result['segments_info'])
+
+    panoptic_seg_id = rgb_to_id(panoptic_seg)
+
+    # start from an all-False mask so `mask` is defined even when no indices are given
+    mask = np.zeros_like(panoptic_seg_id, dtype=bool)
+    for idx in mask_indices:
+        mask = mask | (panoptic_seg_id == idx)
+    mask = clean_mask(mask, min_kernel=min_kernel, max_kernel=max_kernel)
+
+    masked_image = np.array(image).copy()
+    masked_image[mask] = 0
+
+    masked_image = Image.fromarray(masked_image).resize(image.size)
+    mask = Image.fromarray(mask.astype(np.uint8) * 255).resize(image.size)
+
+    if num_diffusion_steps == 0:
+        return masked_image, masked_image, class_str
+
+    STABLE_DIFFUSION_SMALL_EDGE = 512
+
+    assert masked_image.size == mask.size
+    w, h = masked_image.size
+    is_width_larger = w > h
+    resize_ratio = STABLE_DIFFUSION_SMALL_EDGE / (h if is_width_larger else w)
+
+    new_width = int(w * resize_ratio) if is_width_larger else STABLE_DIFFUSION_SMALL_EDGE
+    new_height = STABLE_DIFFUSION_SMALL_EDGE if is_width_larger else int(h * resize_ratio)
+
+    # Stable Diffusion needs spatial dims divisible by 8; round the long edge up
+    new_width += 8 - (new_width % 8) if is_width_larger else 0
+    new_height += 0 if is_width_larger else 8 - (new_height % 8)
+
+    mask = mask.convert("RGB").resize((new_width, new_height))
+    masked_image = masked_image.convert("RGB").resize((new_width, new_height))
+
+    inpainted_image = pipe(
+        height=new_height,
+        width=new_width,
+        prompt=prompt,
+        image=masked_image,
+        mask_image=mask,
+        num_inference_steps=num_diffusion_steps
+    ).images[0]
+
+    return masked_image, inpainted_image, class_str
+
+
+# iface_segmentation = gr.Interface(
+#     fn=fn_segmentation,
+#     inputs=[
+#         "text",
+#         "text",
+#         gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg"),
+#         gr.Slider(minimum=1, maximum=99, value=23, step=2),
+#         gr.Slider(minimum=1, maximum=99, value=5, step=2),
+#         gr.Slider(minimum=0, maximum=100, value=50, step=1),
+#     ],
+#     outputs=["text", gr.Image(type="pil"), gr.Image(type="pil"), "number", "text"]
+# )
+
+# iface_diffusion = gr.Interface(
+#     fn=fn_diffusion,
+#     inputs=["text", gr.Image(type='pil'), gr.Image(type='pil'), "number", "text"],
+#     outputs=[gr.Image(), gr.Image(), gr.Textbox()]
+# )
+
+# iface = gr.Series(
+#     iface_segmentation, iface_diffusion,
+iface = gr.Interface(
+    fn=fn_segmentation_diffusion,
+    inputs=[
+        "text",
+        "text",
+        gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
+        gr.Slider(minimum=1, maximum=99, value=23, step=2),
+        gr.Slider(minimum=1, maximum=99, value=5, step=2),
+        gr.Slider(minimum=0, maximum=100, value=50, step=1),
+    ],
+    outputs=[gr.Image(), gr.Image(), gr.Textbox(interactive=False)]
+)
+
+iface.launch()
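
On the open onclick-canvas question from the commit message: later Gradio releases expose click coordinates on `gr.Image` through a `select` event, which may make a custom canvas unnecessary. A hedged sketch (assumes a Gradio version with `gr.SelectData`, which likely postdates this commit; `pick_segment` is illustrative, not part of the commit):

import gradio as gr

def pick_segment(evt: gr.SelectData):
    # evt.index holds the clicked pixel as (x, y) in image coordinates;
    # indexing panoptic_seg_id[y, x] would give the segment id to toggle
    x, y = evt.index
    return f"clicked ({x}, {y})"

with gr.Blocks() as demo:
    img = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg")
    clicked = gr.Textbox()
    img.select(pick_segment, None, clicked)

demo.launch()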