Alexander McKinney committed on
Commit
04bf3ab
1 Parent(s): 3e7b7cc

experimenting with segmentation mask and inpainting pipeline

Browse files
Files changed (1) hide show
  1. app.py +96 -2
app.py CHANGED
@@ -1,7 +1,101 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  def greet(name):
4
  return "Hello " + name + "!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import requests
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from skimage.measure import block_reduce
7
+
8
  import gradio as gr
9
 
10
+ from transformers import DetrFeatureExtractor, DetrForSegmentation, DetrConfig
11
+ from transformers.models.detr.feature_extraction_detr import rgb_to_id
12
+
13
+ from diffusers import StableDiffusionInpaintPipeline
14
+
15
def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
    """Load the DETR feature extractor, segmentation model and config.

    Args:
        model_name: Checkpoint identifier handed unchanged to every
            ``from_pretrained`` call.

    Returns:
        A ``(feature_extractor, model, cfg)`` tuple, in that order.
    """
    # Each class is loaded from the same checkpoint; the generator is consumed
    # left to right, so the load order is extractor, model, config.
    loaders = (DetrFeatureExtractor, DetrForSegmentation, DetrConfig)
    feature_extractor, model, cfg = (
        loader.from_pretrained(model_name) for loader in loaders
    )
    return feature_extractor, model, cfg
21
+
22
def load_diffusion_pipeline(
    model_name: str = 'runwayml/stable-diffusion-inpainting',
    half_precision: bool = True,
):
    """Load a Stable Diffusion inpainting pipeline.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        half_precision: When True (the default, matching the previous
            behaviour) load the ``fp16`` revision with ``torch.float16``
            weights. When False, load the checkpoint's default revision and
            dtype instead.

    Returns:
        The ``StableDiffusionInpaintPipeline`` instance. It is not moved to
        any device here; the caller does that.
    """
    # NOTE(review): float16 weights are presumably only practical on a CUDA
    # device, while get_device() may return CPU; callers on CPU should pass
    # half_precision=False -- confirm against the installed torch version.
    if not half_precision:
        return StableDiffusionInpaintPipeline.from_pretrained(model_name)

    return StableDiffusionInpaintPipeline.from_pretrained(
        model_name,
        revision='fp16',
        torch_dtype=torch.float16,
    )
28
+
29
def get_device(try_cuda=True):
    """Pick the torch device to run on.

    Args:
        try_cuda: If true, use CUDA when ``torch.cuda.is_available()``
            reports it; if false, always use the CPU.

    Returns:
        ``torch.device('cuda')`` or ``torch.device('cpu')``.
    """
    if try_cuda and torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
31
+
32
  def greet(name):
33
  return "Hello " + name + "!"
34
 
35
def min_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window minimum of ``x`` via a negated max pool.

    The input is negated, max-pooled with stride 1 and a padding of
    ``(kernel_size - 1) // 2`` on each side, and the result is negated back.

    Args:
        x: Tensor accepted by ``torch.nn.functional.max_pool2d``.
        kernel_size: Side length of the pooling window.

    Returns:
        The pooled tensor.
    """
    half_window = (kernel_size - 1) // 2
    negated = torch.neg(x)
    pooled = torch.nn.functional.max_pool2d(
        negated, kernel_size, stride=(1, 1), padding=half_window
    )
    return torch.neg(pooled)
38
+
39
def max_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window maximum of ``x``.

    Uses ``torch.nn.functional.max_pool2d`` with stride 1 and a padding of
    ``(kernel_size - 1) // 2`` on each side.

    Args:
        x: Tensor accepted by ``torch.nn.functional.max_pool2d``.
        kernel_size: Side length of the pooling window.

    Returns:
        The pooled tensor.
    """
    half_window = (kernel_size - 1) // 2
    pooled = torch.nn.functional.max_pool2d(
        x, kernel_size, stride=(1, 1), padding=half_window
    )
    return pooled
42
+
43
def clean_mask(mask, min_kernel: int = 5, max_kernel: int = 23):
    """Clean up a 2-D binary mask with a min pool followed by a max pool.

    The mask is converted to a float tensor, given two leading singleton
    axes, passed through ``min_pool(min_kernel)`` and then
    ``max_pool(max_kernel)``, and converted back to a boolean NumPy array.

    Args:
        mask: 2-D array-like of booleans (or 0/1 values).
        min_kernel: Window size for the min-pool step.
        max_kernel: Window size for the max-pool step.

    Returns:
        A 2-D boolean ``numpy.ndarray``. With odd kernel sizes (the defaults)
        the pooling padding keeps the spatial size equal to the input's.
    """
    # torch.as_tensor is the supported conversion path (the legacy
    # torch.Tensor(...) constructor is discouraged); the two added leading
    # axes give the 4-D layout that max_pool2d accepts.
    pooled = torch.as_tensor(mask, dtype=torch.float32)[None, None]
    pooled = min_pool(pooled, min_kernel)
    pooled = max_pool(pooled, max_kernel)
    # Strip exactly the two axes added above. A bare squeeze() would also
    # collapse a genuine size-1 spatial axis (e.g. a 1xW mask) and hand back
    # an array of the wrong rank.
    return pooled[0, 0].bool().numpy()
49
+
50
# Experimental end-to-end run: segment a sample image with DETR, blank out one
# segment, then inpaint the blanked region with Stable Diffusion.
# The Gradio interface is disabled while this pipeline is being tried out:
# iface = gr.Interface(fn=greet, inputs="text", outputs="text")
# iface.launch()
device = get_device()

feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
# nn.Module.to() moves the module in place and returns it, so `model` and
# `segmentation_model` refer to the same object.
model = segmentation_model.to(device)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# A timeout stops the script hanging forever on a stalled connection, and
# raise_for_status reports an HTTP error instead of an obscure PIL decode failure.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# prepare image for the model
inputs = feature_extractor(images=image, return_tensors="pt").to(device)

# forward pass (inference only, so autograd bookkeeping is not needed)
with torch.no_grad():
    outputs = segmentation_model(**inputs)

processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

# The PNG encodes segment ids as colours, so it must be resized with
# nearest-neighbour sampling: an interpolating filter would blend neighbouring
# colours and produce ids that belong to no segment along every boundary.
panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize(
    (image.width, image.height), resample=Image.NEAREST
)
panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)

panoptic_seg_id = rgb_to_id(panoptic_seg)

print(result['segments_info'])

# NOTE(review): segment id 5 is presumably one of the cats in this sample
# image -- check it against the segments_info printed above.
# cat_mask = (panoptic_seg_id == 1) | (panoptic_seg_id == 5)
cat_mask = (panoptic_seg_id == 5)
cat_mask = clean_mask(cat_mask)

# Black out the selected segment in a copy of the input image.
masked_image = np.array(image).copy()
masked_image[cat_mask] = 0

masked_image = Image.fromarray(masked_image)
masked_image.save('masked_cat.png')

pipe = load_diffusion_pipeline()
pipe = pipe.to(device)

print(cat_mask)

# Scale to a height of 512 while keeping the aspect ratio, deriving the ratio
# from the actual image size rather than assuming 640x480.
target_height = 512
resize_ratio = target_height / image.height
new_width = int(image.width * resize_ratio)
# Round up to the next multiple of 8; a width that is already a multiple of 8
# is left unchanged.
new_width = -(-new_width // 8) * 8

print(new_width)
cat_mask = (
    Image.fromarray(cat_mask.astype(np.uint8) * 255)
    .convert("RGB")
    .resize((new_width, target_height))
)
masked_image = masked_image.resize((new_width, target_height))

prompt = "Two cats on the sofa together."
inpainted_image = pipe(
    height=target_height,
    width=new_width,
    prompt=prompt,
    image=masked_image,
    mask_image=cat_mask,
).images[0]
inpainted_image.save('inpaint_cat.png')