BertChristiaens committed
Commit dd0ab9f
1 Parent(s): e803877
Files changed (5):
  1. app.py +2 -1
  2. helpers.py +46 -0
  3. models.py +5 -187
  4. pipelines.py +126 -0
  5. segmentation.py +55 -0
app.py CHANGED
@@ -7,7 +7,8 @@ import numpy as np
 import os
 import time
 
-from models import make_image_controlnet, make_inpainting, segment_image
+from models import make_image_controlnet, make_inpainting
+from segmentation import segment_image
 from config import HEIGHT, WIDTH, POS_PROMPT, NEG_PROMPT, COLOR_MAPPING, map_colors, map_colors_rgb
 from palette import COLOR_MAPPING_CATEGORY
 from preprocessing import preprocess_seg_mask, get_image, get_mask
helpers.py ADDED
@@ -0,0 +1,46 @@
+import gc
+import numpy as np
+import torch
+from scipy.signal import fftconvolve
+from PIL import Image
+
+
+def flush():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+def convolution(mask: Image.Image, size=9) -> Image.Image:
+    """Method to blur the mask
+    Args:
+        mask (Image): masking image
+        size (int, optional): size of the blur. Defaults to 9.
+    Returns:
+        Image: blurred mask
+    """
+    mask = np.array(mask.convert("L"))
+    conv = np.ones((size, size)) / size**2
+    mask_blended = fftconvolve(mask, conv, 'same')
+    mask_blended = mask_blended.astype(np.uint8).copy()
+
+    border = size
+
+    # replace borders with original values
+    mask_blended[:border, :] = mask[:border, :]
+    mask_blended[-border:, :] = mask[-border:, :]
+    mask_blended[:, :border] = mask[:, :border]
+    mask_blended[:, -border:] = mask[:, -border:]
+
+    return Image.fromarray(mask_blended).convert("L")
+
+
+def postprocess_image_masking(inpainted: Image.Image, image: Image.Image, mask: Image.Image) -> Image.Image:
+    """Method to postprocess the inpainted image
+    Args:
+        inpainted (Image): inpainted image
+        image (Image): original image
+        mask (Image): mask
+    Returns:
+        Image: inpainted image
+    """
+    final_inpainted = Image.composite(inpainted.convert("RGBA"), image.convert("RGBA"), mask)
+    return final_inpainted.convert("RGB")
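Note: the new module uses np but did not import numpy; the import is added above. As a minimal usage sketch of the two helpers together (file names and the size=9 kernel are illustrative, not from the commit):

# Sketch: soften a mask, then composite the inpainted result over the original.
from PIL import Image
from helpers import convolution, postprocess_image_masking

original = Image.open("room.jpg").convert("RGB")         # source photo
inpainted = Image.open("generated.jpg").convert("RGB")   # diffusion output, same size
mask = Image.open("mask.png").convert("L")               # white = repaint region

soft_mask = convolution(mask, size=9)   # box blur softens the mask seam
blended = postprocess_image_masking(inpainted, original, soft_mask)
blended.save("blended.jpg")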
models.py CHANGED
@@ -8,176 +8,18 @@ import gc
 import time
 import numpy as np
 from PIL import Image
-from time import perf_counter
-from contextlib import contextmanager
-from scipy.signal import fftconvolve
 from PIL import ImageFilter
 
-from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
 from diffusers import ControlNetModel, UniPCMultistepScheduler
-from diffusers import StableDiffusionInpaintPipeline
 
 from config import WIDTH, HEIGHT
 from palette import ade_palette
 from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
+from helpers import flush, postprocess_image_masking, convolution
+from pipelines import ControlNetPipeline, SDPipeline, get_inpainting_pipeline, get_controlnet
 
 LOGGING = logging.getLogger(__name__)
 
-def flush():
-    gc.collect()
-    torch.cuda.empty_cache()
-
-class ControlNetPipeline:
-    def __init__(self):
-        self.in_use = False
-        self.controlnet = ControlNetModel.from_pretrained(
-            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
-
-        self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
-            "runwayml/stable-diffusion-inpainting",
-            controlnet=self.controlnet,
-            safety_checker=None,
-            torch_dtype=torch.float16
-        )
-
-        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
-        self.pipe.enable_xformers_memory_efficient_attention()
-        self.pipe = self.pipe.to("cuda")
-
-        self.waiting_queue = []
-        self.count = 0
-
-    @property
-    def queue_size(self):
-        return len(self.waiting_queue)
-
-    def __call__(self, **kwargs):
-        self.count += 1
-        number = self.count
-
-        self.waiting_queue.append(number)
-
-        # wait until the next number in the queue is the current number
-        while self.waiting_queue[0] != number:
-            print(f"Wait for your turn {number} in queue {self.waiting_queue}")
-            time.sleep(0.5)
-            pass
-
-        # it's your turn, so remove the number from the queue
-        # and call the function
-        print("It's the turn of", self.count)
-        results = self.pipe(**kwargs)
-        self.waiting_queue.pop(0)
-        flush()
-        return results
-
-class SDPipeline:
-    def __init__(self):
-        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
-            "stabilityai/stable-diffusion-2-inpainting",
-            torch_dtype=torch.float16,
-            safety_checker=None,
-        )
-
-        self.pipe.enable_xformers_memory_efficient_attention()
-        self.pipe = self.pipe.to("cuda")
-
-        self.waiting_queue = []
-        self.count = 0
-
-    @property
-    def queue_size(self):
-        return len(self.waiting_queue)
-
-    def __call__(self, **kwargs):
-        self.count += 1
-        number = self.count
-
-        self.waiting_queue.append(number)
-
-        # wait until the next number in the queue is the current number
-        while self.waiting_queue[0] != number:
-            print(f"Wait for your turn {number} in queue {self.waiting_queue}")
-            time.sleep(0.5)
-            pass
-
-        # it's your turn, so remove the number from the queue
-        # and call the function
-        print("It's the turn of", self.count)
-        results = self.pipe(**kwargs)
-        self.waiting_queue.pop(0)
-        flush()
-        return results
-
-
-def convolution(mask: Image.Image, size=9) -> Image:
-    """Method to blur the mask
-    Args:
-        mask (Image): masking image
-        size (int, optional): size of the blur. Defaults to 9.
-    Returns:
-        Image: blurred mask
-    """
-    mask = np.array(mask.convert("L"))
-    conv = np.ones((size, size)) / size**2
-    mask_blended = fftconvolve(mask, conv, 'same')
-    mask_blended = mask_blended.astype(np.uint8).copy()
-
-    border = size
-
-    # replace borders with original values
-    mask_blended[:border, :] = mask[:border, :]
-    mask_blended[-border:, :] = mask[-border:, :]
-    mask_blended[:, :border] = mask[:, :border]
-    mask_blended[:, -border:] = mask[:, -border:]
-
-    return Image.fromarray(mask_blended).convert("L")
-
-
-def postprocess_image_masking(inpainted: Image, image: Image, mask: Image) -> Image:
-    """Method to postprocess the inpainted image
-    Args:
-        inpainted (Image): inpainted image
-        image (Image): original image
-        mask (Image): mask
-    Returns:
-        Image: inpainted image
-    """
-    final_inpainted = Image.composite(inpainted.convert("RGBA"), image.convert("RGBA"), mask)
-    return final_inpainted.convert("RGB")
-
-
-@st.experimental_singleton(max_entries=5)
-def get_controlnet() -> ControlNetModel:
-    """Method to load the controlnet model
-    Returns:
-        ControlNetModel: controlnet model
-    """
-    pipe = ControlNetPipeline()
-    return pipe
-
-
-@st.experimental_singleton(max_entries=5)
-def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
-    """Method to load the segmentation pipeline
-    Returns:
-        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
-    """
-    image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
-    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
-        "openmmlab/upernet-convnext-small")
-    return image_processor, image_segmentor
-
-
-@st.experimental_singleton(max_entries=5)
-def get_inpainting_pipeline() -> StableDiffusionInpaintPipeline:
-    """Method to load the inpainting pipeline
-    Returns:
-        StableDiffusionInpaintPipeline: inpainting pipeline
-    """
-    pipe = SDPipeline()
-    return pipe
-
 
 @torch.inference_mode()
 def make_image_controlnet(image: np.ndarray,
@@ -238,12 +80,13 @@ def make_inpainting(positive_prompt: str,
         List[Image.Image]: list of generated images
     """
     pipe = get_inpainting_pipeline()
+    mask_image = Image.fromarray((mask_image * 255).astype(np.uint8))
    mask_image_postproc = convolution(mask_image)
 
     flush()
     st.success(f"{pipe.queue_size} images in the queue, can take up to {(pipe.queue_size+1) * 10} seconds")
     generated_image = pipe(image=image,
-                           mask_image=Image.fromarray((mask_image * 255).astype(np.uint8)),
+                           mask_image=mask_image,
                            prompt=positive_prompt,
                            negative_prompt=negative_prompt,
                            num_inference_steps=20,
@@ -252,29 +95,4 @@ def make_inpainting(positive_prompt: str,
     ).images[0]
     generated_image = postprocess_image_masking(generated_image, image, mask_image_postproc)
 
-    return image_
-
-
-@torch.inference_mode()
-@torch.autocast('cuda')
-def segment_image(image: Image) -> Image:
-    """Method to segment image
-    Args:
-        image (Image): input image
-    Returns:
-        Image: segmented image
-    """
-    image_processor, image_segmentor = get_segmentation_pipeline()
-    pixel_values = image_processor(image, return_tensors="pt").pixel_values
-    with torch.no_grad():
-        outputs = image_segmentor(pixel_values)
-
-    seg = image_processor.post_process_semantic_segmentation(
-        outputs, target_sizes=[image.size[::-1]])[0]
-    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
-    palette = np.array(ade_palette())
-    for label, color in enumerate(palette):
-        color_seg[seg == label, :] = color
-    color_seg = color_seg.astype(np.uint8)
-    seg_image = Image.fromarray(color_seg).convert('RGB')
-    return seg_image
+    return generated_image
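Besides moving code out to helpers.py, pipelines.py, and segmentation.py, this hunk fixes a real bug: make_inpainting previously returned the undefined name image_ instead of generated_image. A call would now look roughly like this; a sketch only, since the full parameter list of make_inpainting is truncated in the diff and the keyword names below are inferred from the variables used in its body, with file names as placeholders:

# Sketch: inpaint a square region of a 512x512 room photo (assumed signature).
import numpy as np
from PIL import Image
from models import make_inpainting

room = Image.open("room.jpg").convert("RGB").resize((512, 512))
mask = np.zeros((512, 512), dtype=np.float32)
mask[128:384, 128:384] = 1.0  # 1.0 marks the region to repaint

result = make_inpainting(positive_prompt="a red sofa, photorealistic interior",
                         image=room,
                         mask_image=mask,
                         negative_prompt="blurry, low quality")
result.save("room_inpainted.jpg")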
pipelines.py ADDED
@@ -0,0 +1,126 @@
+import logging
+from typing import List, Tuple, Dict
+
+import streamlit as st
+import torch
+import gc
+import time
+import numpy as np
+from PIL import Image
+from time import perf_counter
+from contextlib import contextmanager
+from scipy.signal import fftconvolve
+from PIL import ImageFilter
+
+from diffusers import ControlNetModel, UniPCMultistepScheduler
+from diffusers import StableDiffusionInpaintPipeline
+
+from config import WIDTH, HEIGHT
+from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
+from helpers import flush
+
+LOGGING = logging.getLogger(__name__)
+
+class ControlNetPipeline:
+    def __init__(self):
+        self.in_use = False
+        self.controlnet = ControlNetModel.from_pretrained(
+            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
+
+        self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting",
+            controlnet=self.controlnet,
+            safety_checker=None,
+            torch_dtype=torch.float16
+        )
+
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.enable_xformers_memory_efficient_attention()
+        self.pipe = self.pipe.to("cuda")
+
+        self.waiting_queue = []
+        self.count = 0
+
+    @property
+    def queue_size(self):
+        return len(self.waiting_queue)
+
+    def __call__(self, **kwargs):
+        self.count += 1
+        number = self.count
+
+        self.waiting_queue.append(number)
+
+        # wait until the next number in the queue is the current number
+        while self.waiting_queue[0] != number:
+            print(f"Wait for your turn {number} in queue {self.waiting_queue}")
+            time.sleep(0.5)
+            pass
+
+        # it's your turn, so remove the number from the queue
+        # and call the function
+        print("It's the turn of", self.count)
+        results = self.pipe(**kwargs)
+        self.waiting_queue.pop(0)
+        flush()
+        return results
+
+class SDPipeline:
+    def __init__(self):
+        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-2-inpainting",
+            torch_dtype=torch.float16,
+            safety_checker=None,
+        )
+
+        self.pipe.enable_xformers_memory_efficient_attention()
+        self.pipe = self.pipe.to("cuda")
+
+        self.waiting_queue = []
+        self.count = 0
+
+    @property
+    def queue_size(self):
+        return len(self.waiting_queue)
+
+    def __call__(self, **kwargs):
+        self.count += 1
+        number = self.count
+
+        self.waiting_queue.append(number)
+
+        # wait until the next number in the queue is the current number
+        while self.waiting_queue[0] != number:
+            print(f"Wait for your turn {number} in queue {self.waiting_queue}")
+            time.sleep(0.5)
+            pass
+
+        # it's your turn, so remove the number from the queue
+        # and call the function
+        print("It's the turn of", self.count)
+        results = self.pipe(**kwargs)
+        self.waiting_queue.pop(0)
+        flush()
+        return results
+
+
+@st.experimental_singleton(max_entries=5)
+def get_controlnet():
+    """Method to load the controlnet model
+    Returns:
+        ControlNetModel: controlnet model
+    """
+    pipe = ControlNetPipeline()
+    return pipe
+
+
+@st.experimental_singleton(max_entries=5)
+def get_inpainting_pipeline():
+    """Method to load the inpainting pipeline
+    Returns:
+        StableDiffusionInpaintPipeline: inpainting pipeline
+    """
+    pipe = SDPipeline()
+    return pipe
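Both wrappers serialize GPU access with the same FIFO ticket pattern: each call takes a ticket number, busy-waits until its ticket reaches the head of waiting_queue, runs the pipeline, then pops itself. A stripped-down sketch of just that pattern (illustrative only; no GPU or diffusers needed, and TicketQueue is a name invented here):

# Sketch of the FIFO busy-wait pattern used by ControlNetPipeline/SDPipeline.
import time

class TicketQueue:
    def __init__(self):
        self.waiting_queue = []
        self.count = 0

    def __call__(self, work):
        self.count += 1          # hand out the next ticket number
        number = self.count
        self.waiting_queue.append(number)
        while self.waiting_queue[0] != number:  # spin until first in line
            time.sleep(0.5)
        result = work()          # critical section: one caller at a time
        self.waiting_queue.pop(0)
        return result

queue = TicketQueue()
print(queue(lambda: "done"))  # prints "done" immediately when queue is empty

One caveat on the design: ticket issuance (count += 1 plus append) is not guarded by a lock, so under genuinely concurrent threads a threading.Lock around it would be safer.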
segmentation.py ADDED
@@ -0,0 +1,55 @@
+import logging
+from typing import List, Tuple, Dict
+
+import streamlit as st
+import torch
+import gc
+import numpy as np
+from PIL import Image
+
+from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+
+from palette import ade_palette
+
+LOGGING = logging.getLogger(__name__)
+
+
+def flush():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+@st.experimental_singleton(max_entries=5)
+def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
+    """Method to load the segmentation pipeline
+    Returns:
+        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
+    """
+    image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
+    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
+        "openmmlab/upernet-convnext-small")
+    return image_processor, image_segmentor
+
+
+@torch.inference_mode()
+@torch.autocast('cuda')
+def segment_image(image: Image.Image) -> Image.Image:
+    """Method to segment image
+    Args:
+        image (Image): input image
+    Returns:
+        Image: segmented image
+    """
+    image_processor, image_segmentor = get_segmentation_pipeline()
+    pixel_values = image_processor(image, return_tensors="pt").pixel_values
+    with torch.no_grad():
+        outputs = image_segmentor(pixel_values)
+
+    seg = image_processor.post_process_semantic_segmentation(
+        outputs, target_sizes=[image.size[::-1]])[0]
+    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+    palette = np.array(ade_palette())
+    for label, color in enumerate(palette):
+        color_seg[seg == label, :] = color
+    color_seg = color_seg.astype(np.uint8)
+    seg_image = Image.fromarray(color_seg).convert('RGB')
+    return seg_image
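The moved segment_image behaves as before: it runs the UPerNet segmentor, then paints each predicted class with its ade_palette color. A minimal call, as a sketch (the file name is a placeholder):

# Sketch: produce an ADE20K-colored segmentation map for a room photo.
from PIL import Image
from segmentation import segment_image

room = Image.open("room.jpg").convert("RGB")
seg_map = segment_image(room)  # RGB image, same size as the input
seg_map.save("room_segmentation.png")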