harsh99 committed
Commit 70511fe · 1 Parent(s): b993f12

CatVTON changes on top of stable diffusion

Files changed (5)
  1. agnostic_mask.png +0 -0
  2. diffusion.py +16 -16
  3. garment.jpg +0 -0
  4. person.jpg +0 -0
  5. pipeline.py +327 -114
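
In short: the prompt/CLIP conditioning path is dropped from the UNet in diffusion.py, and garment conditioning is instead injected CatVTON-style in pipeline.py by concatenating the VAE latents of the masked person image and of the garment along the spatial height axis, with the mask and conditioning latents stacked as extra channels. A minimal shape sketch of that data flow, mirroring the `concat_dim = -2` path in the updated pipeline.py below (illustrative only; 512x512 inputs and the 8x-downsampled 4-channel latent space are assumed):

import torch

# Illustrative shapes only (512x512 inputs, 8x VAE downsampling assumed).
masked_latent    = torch.randn(1, 4, 64, 64)   # VAE latent of the masked person image
condition_latent = torch.randn(1, 4, 64, 64)   # VAE latent of the garment image

# CatVTON-style conditioning: concatenate person and garment latents along height (dim=-2).
masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=-2)           # (1, 4, 128, 64)

# The inpainting mask is downsampled to latent resolution and padded with zeros on the garment half.
mask_latent = torch.ones(1, 1, 64, 64)
mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=-2)  # (1, 1, 128, 64)

# Noise, mask and conditioning latents are stacked along channels as the UNet input.
latents = torch.randn_like(masked_latent_concat)
unet_input = torch.cat([latents, mask_latent_concat, masked_latent_concat], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 128, 64])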
agnostic_mask.png ADDED
diffusion.py CHANGED
@@ -50,7 +50,7 @@ class UNET_ResidualBlock(nn.Module):
         return merged + self.residual_layer(residue)
 
 class UNET_AttentionBlock(nn.Module):
-    def __init__(self, n_head, n_embed, d_context=768):
+    def __init__(self, n_head, n_embed):
         super().__init__()
 
         channels=n_head*n_embed
@@ -62,7 +62,7 @@ class UNET_AttentionBlock(nn.Module):
         self.attention_1=SelfAttention(n_head, channels, in_proj_bias=False)
 
         self.layernorm_2=nn.LayerNorm(channels)
-        self.attention_2=CrossAttention(n_head, channels, d_context, in_proj_bias=False)
+        # self.attention_2=CrossAttention(n_head, channels, d_context, in_proj_bias=False)
 
         self.layernorm_3=nn.LayerNorm(channels)
 
@@ -71,7 +71,7 @@ class UNET_AttentionBlock(nn.Module):
 
         self.conv_output=nn.Conv2d(channels, channels, kernel_size=1, padding=0)
 
-    def forward(self, x, context):
+    def forward(self, x):
         residue_long=x
 
         x=self.grpnorm(x)
@@ -92,7 +92,7 @@ class UNET_AttentionBlock(nn.Module):
         residue_short=x
 
         x=self.layernorm_2(x)
-        x=self.attention_2(x, context)
+        # x=self.attention_2(x, context)
 
         x+=residue_short
 
@@ -123,10 +123,10 @@ class Upsample(nn.Module):
 
 # passing arguments to the parent class nn.Sequential, not to your SwitchSequential class directly — because you did not override the __init__ method in SwitchSequential
 class SwitchSequential(nn.Sequential):
-    def forward(self, x, context, time):
+    def forward(self, x, time):
         for layer in self:
             if isinstance(layer, UNET_AttentionBlock):
-                x=layer(x, context)
+                x=layer(x)
             elif isinstance(layer, UNET_ResidualBlock):
                 x=layer(x, time)
             else:
@@ -210,22 +210,22 @@ class UNET(nn.Module):
             SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
         ])
 
-    def forward(self, x, context, time):
+    def forward(self, x, time):
         # x: (Batch_Size, 4, Height / 8, Width / 8)
         # context: (Batch_Size, Seq_Len, Dim)
         # time: (1, 1280)
 
         skip_connections = []
         for layers in self.encoders:
-            x = layers(x, context, time)
+            x = layers(x, time)
             skip_connections.append(x)
 
-        x = self.bottleneck(x, context, time)
+        x = self.bottleneck(x, time)
 
         for layers in self.decoders:
             # Since we always concat with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer
             x = torch.cat((x, skip_connections.pop()), dim=1)
-            x = layers(x, context, time)
+            x = layers(x, time)
 
         return x
 
@@ -251,10 +251,10 @@ class Diffusion(nn.Module):
         self.unet=UNET()
         self.final=UNET_OutputLayer(320, 4)
 
-    def forward(self, latent, context, time):
+    def forward(self, latent, time):
         time=self.time_embedding(time)
 
-        output=self.unet(latent, context, time)
+        output=self.unet(latent, time)
 
         output=self.final(output)
 
@@ -266,7 +266,7 @@ if __name__ == "__main__":
     height = 64
     width = 64
    in_channels = 4
-    context_dim = 768
+    # context_dim = 768
    seq_len = 77
 
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -285,13 +285,13 @@
    print('Time Embedding shape to UNET: ',t.shape)
 
    # Context for cross attention (e.g., text embedding from CLIP or transformer)
-    context = torch.randn(batch_size, seq_len, context_dim).to(device)
+    # context = torch.randn(batch_size, seq_len, context_dim).to(device)
 
-    print('context shape to UNET: ', context.shape)
+    # print('context shape to UNET: ', context.shape)
 
    # Forward pass
    with torch.no_grad():
-        output = model(x, context, t)
+        output = model(x, t)
        print(output)
 
    print("Output shape of UNET:", output.shape)
garment.jpg ADDED
person.jpg ADDED
pipeline.py CHANGED
@@ -1,32 +1,218 @@
+import math
+from typing import List, Union
+import PIL
 import torch
 import numpy as np
 from tqdm import tqdm
 from ddpm import DDPMSampler
+from PIL import Image
 
 WIDTH = 512
 HEIGHT = 512
 LATENTS_WIDTH = WIDTH // 8
 LATENTS_HEIGHT = HEIGHT // 8
 
+def repaint_result(result, person_image, mask_image):
+    result, person, mask = np.array(result), np.array(person_image), np.array(mask_image)
+    # expand the mask to 3 channels & to 0~1
+    mask = np.expand_dims(mask, axis=2)
+    mask = mask / 255.0
+    # mask for result, ~mask for person
+    result_ = result * mask + person * (1 - mask)
+    return Image.fromarray(result_.astype(np.uint8))
+
+
+def prepare_image(image):
+    if isinstance(image, torch.Tensor):
+        # Batch single image
+        if image.ndim == 3:
+            image = image.unsqueeze(0)
+        image = image.to(dtype=torch.float32)
+    else:
+        # preprocess image
+        if isinstance(image, (PIL.Image.Image, np.ndarray)):
+            image = [image]
+        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            image = [np.array(i.convert("RGB"))[None, :] for i in image]
+            image = np.concatenate(image, axis=0)
+        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+            image = np.concatenate([i[None, :] for i in image], axis=0)
+        image = image.transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+    return image
+
+
+def prepare_mask_image(mask_image):
+    if isinstance(mask_image, torch.Tensor):
+        if mask_image.ndim == 2:
+            # Batch and add channel dim for single mask
+            mask_image = mask_image.unsqueeze(0).unsqueeze(0)
+        elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
+            # Single mask, the 0'th dimension is considered to be
+            # the existing batch size of 1
+            mask_image = mask_image.unsqueeze(0)
+        elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
+            # Batch of mask, the 0'th dimension is considered to be
+            # the batching dimension
+            mask_image = mask_image.unsqueeze(1)
+
+        # Binarize mask
+        mask_image[mask_image < 0.5] = 0
+        mask_image[mask_image >= 0.5] = 1
+    else:
+        # preprocess mask
+        if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
+            mask_image = [mask_image]
+
+        if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
+            mask_image = np.concatenate(
+                [np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0
+            )
+            mask_image = mask_image.astype(np.float32) / 255.0
+        elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
+            mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
+
+        mask_image[mask_image < 0.5] = 0
+        mask_image[mask_image >= 0.5] = 1
+        mask_image = torch.from_numpy(mask_image)
+
+    return mask_image
+
+
+def numpy_to_pil(images):
+    """
+    Convert a numpy image or a batch of images to a PIL image.
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+    if images.shape[-1] == 1:
+        # special case for grayscale (single channel) images
+        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+    else:
+        pil_images = [Image.fromarray(image) for image in images]
+
+    return pil_images
+
+
+def tensor_to_image(tensor: torch.Tensor):
+    """
+    Converts a torch tensor to PIL Image.
+    """
+    assert tensor.dim() == 3, "Input tensor should be 3-dimensional."
+    assert tensor.dtype == torch.float32, "Input tensor should be float32."
+    assert (
+        tensor.min() >= 0 and tensor.max() <= 1
+    ), "Input tensor should be in range [0, 1]."
+    tensor = tensor.cpu()
+    tensor = tensor * 255
+    tensor = tensor.permute(1, 2, 0)
+    tensor = tensor.numpy().astype(np.uint8)
+    image = Image.fromarray(tensor)
+    return image
+
+
+def concat_images(images: List[Image.Image], divider: int = 4, cols: int = 4):
+    """
+    Concatenates images horizontally and with
+    """
+    widths = [image.size[0] for image in images]
+    heights = [image.size[1] for image in images]
+    total_width = cols * max(widths)
+    total_width += divider * (cols - 1)
+    # `col` images each row
+    rows = math.ceil(len(images) / cols)
+    total_height = max(heights) * rows
+    # add divider between rows
+    total_height += divider * (len(heights) // cols - 1)
+
+    # all black image
+    concat_image = Image.new("RGB", (total_width, total_height), (0, 0, 0))
+
+    x_offset = 0
+    y_offset = 0
+    for i, image in enumerate(images):
+        concat_image.paste(image, (x_offset, y_offset))
+        x_offset += image.size[0] + divider
+        if (i + 1) % cols == 0:
+            x_offset = 0
+            y_offset += image.size[1] + divider
+
+    return concat_image
+
+def resize_and_crop(image, size):
+    # Crop to size ratio
+    w, h = image.size
+    target_w, target_h = size
+    if w / h < target_w / target_h:
+        new_w = w
+        new_h = w * target_h // target_w
+    else:
+        new_h = h
+        new_w = h * target_w // target_h
+    image = image.crop(
+        ((w - new_w) // 2, (h - new_h) // 2, (w + new_w) // 2, (h + new_h) // 2)
+    )
+    # resize
+    image = image.resize(size, Image.LANCZOS)
+    return image
+
+
+def resize_and_padding(image, size):
+    # Padding to size ratio
+    w, h = image.size
+    target_w, target_h = size
+    if w / h < target_w / target_h:
+        new_h = target_h
+        new_w = w * target_h // h
+    else:
+        new_w = target_w
+        new_h = h * target_w // w
+    image = image.resize((new_w, new_h), Image.LANCZOS)
+    # padding
+    padding = Image.new("RGB", size, (255, 255, 255))
+    padding.paste(image, ((target_w - new_w) // 2, (target_h - new_h) // 2))
+    return padding
+
+def check_inputs(image, condition_image, mask, width, height):
+    if isinstance(image, torch.Tensor) and isinstance(condition_image, torch.Tensor) and isinstance(mask, torch.Tensor):
+        return image, condition_image, mask
+    assert image.size == mask.size, "Image and mask must have the same size"
+    image = resize_and_crop(image, (width, height))
+    mask = resize_and_crop(mask, (width, height))
+    condition_image = resize_and_padding(condition_image, (width, height))
+    return image, condition_image, mask
+
+
+def compute_vae_encodings(image_tensor, encoder, device):
+    """Encode image using VAE encoder"""
+    # Generate random noise for encoding
+    encoder_noise = torch.randn(
+        (image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8),
+        device=device,
+    )
+
+    # Encode using your custom encoder
+    latent = encoder(image_tensor, encoder_noise)
+    return latent
+
+
 def generate(
-    prompt,
-    uncond_prompt=None,
-    input_image=None,
-    strength=0.8,
-    do_cfg=True,
-    cfg_scale=7.5,
-    sampler_name="ddpm",
-    n_inference_steps=50,
+    image: Union[PIL.Image.Image, torch.Tensor],
+    condition_image: Union[PIL.Image.Image, torch.Tensor],
+    mask: Union[PIL.Image.Image, torch.Tensor],
+    num_inference_steps: int = 50,
+    guidance_scale: float = 2.5,
+    height: int = 1024,
+    width: int = 768,
     models={},
+    sampler_name="ddpm",
     seed=None,
     device=None,
     idle_device=None,
-    tokenizer=None,
+    **kwargs
 ):
     with torch.no_grad():
-        if not 0 < strength <= 1:
-            raise ValueError("strength must be between 0 and 1")
-
         if idle_device:
             to_idle = lambda x: x.to(idle_device)
         else:
@@ -39,121 +225,125 @@ def generate(
         else:
             generator.manual_seed(seed)
 
-        clip = models["clip"]
-        clip.to(device)
+        concat_dim = -2 # FIXME: y axis concat
+        # Prepare inputs to Tensor
+        image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)
+        image = prepare_image(image).to(device)
+        condition_image = prepare_image(condition_image).to(device)
+        mask = prepare_mask_image(mask).to(device)
+        # Mask image
+        masked_image = image * (mask < 0.5)
+
+        # VAE encoding
+        encoder = models.get('encoder', None)
+        if encoder is None:
+            raise ValueError("Encoder model not found in models dictionary")
 
-        if do_cfg:
-            # Convert into a list of length Seq_Len=77
-            cond_tokens = tokenizer.batch_encode_plus(
-                [prompt], padding="max_length", max_length=77
-            ).input_ids
-            # (Batch_Size, Seq_Len)
-            cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
-            # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
-            cond_context = clip(cond_tokens)
-            # Convert into a list of length Seq_Len=77
-            uncond_tokens = tokenizer.batch_encode_plus(
-                [uncond_prompt], padding="max_length", max_length=77
-            ).input_ids
-            # (Batch_Size, Seq_Len)
-            uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
-            # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
-            uncond_context = clip(uncond_tokens)
-            # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
-            context = torch.cat([cond_context, uncond_context])
-        else:
-            # Convert into a list of length Seq_Len=77
-            tokens = tokenizer.batch_encode_plus(
-                [prompt], padding="max_length", max_length=77
-            ).input_ids
-            # (Batch_Size, Seq_Len)
-            tokens = torch.tensor(tokens, dtype=torch.long, device=device)
-            # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
-            context = clip(tokens)
-        to_idle(clip)
+        encoder.to(device)
+        masked_latent = compute_vae_encodings(masked_image, encoder, device)
+        condition_latent = compute_vae_encodings(condition_image, encoder, device)
+        to_idle(encoder)
+
+        # Concatenate latents
+        masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)
 
+        mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")
+        del image, mask, condition_image
+        mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)
+
+        # Initialize latents
+        latents = torch.randn(
+            masked_latent_concat.shape,
+            generator=generator,
+            device=masked_latent_concat.device,
+            dtype=masked_latent_concat.dtype
+        )
+
+        # Prepare timesteps
         if sampler_name == "ddpm":
             sampler = DDPMSampler(generator)
-            sampler.set_inference_timesteps(n_inference_steps)
+            sampler.set_inference_timesteps(num_inference_steps)
         else:
-            raise ValueError("Unknown sampler value %s. ")
-
-        latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
-
-        if input_image:
-            encoder = models["encoder"]
-            encoder.to(device)
+            raise ValueError("Unknown sampler value %s. " % sampler_name)
 
-            input_image_tensor = input_image.resize((WIDTH, HEIGHT))
+        timesteps = sampler.timesteps
+        latents = sampler.add_noise(latents, timesteps[0])
 
-            # (Height, Width, Channel)
-            input_image_tensor = np.array(input_image_tensor)
-
-            # (Height, Width, Channel) -> (Height, Width, Channel)
-            input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
-            input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
-
-            # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
-            input_image_tensor = input_image_tensor.unsqueeze(0)
-
-            # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
-            input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
-
-            # (Batch_Size, 4, Latents_Height, Latents_Width)
-            encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
-            latents = encoder(input_image_tensor, encoder_noise)
-
-            # Add noise to the latents (the encoded input image)
-            # (Batch_Size, 4, Latents_Height, Latents_Width)
-            sampler.set_strength(strength=strength)
-            latents = sampler.add_noise(latents, sampler.timesteps[0])
-            to_idle(encoder)
-        else:
-            # (Batch_Size, 4, Latents_Height, Latents_Width)
-            latents = torch.randn(latents_shape, generator=generator, device=device)
-
-        diffusion = models["diffusion"]
-        diffusion.to(device)
-
-        timesteps = tqdm(sampler.timesteps)
-        for i, timestep in enumerate(timesteps):
-            # (1, 320)
-            time_embedding = get_time_embedding(timestep).to(device)
+        # Classifier-Free Guidance
+        do_classifier_free_guidance = guidance_scale > 1.0
+        if do_classifier_free_guidance:
+            masked_latent_concat = torch.cat(
+                [
+                    torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
+                    masked_latent_concat,
+                ]
+            )
+            mask_latent_concat = torch.cat([mask_latent_concat] * 2)
 
-            # (Batch_Size, 4, Latents_Height, Latents_Width)
-            model_input = latents
-
-            if do_cfg:
-                # (Batch_Size, 4, Latents_Height, Latents_Width) -> (2 * Batch_Size, 4, Latents_Height, Latents_Width)
-                model_input = model_input.repeat(2, 1, 1, 1)
-
-            # model_output is the predicted noise
-            # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
-
-            model_output = diffusion(model_input, context, time_embedding)
-
-            if do_cfg:
-                output_cond, output_uncond = model_output.chunk(2)
-                model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
-
-            # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
-            latents = sampler.step(timestep, latents, model_output)
+        # Denoising loop - Fixed: removed self references and incorrect scheduler calls
+        num_warmup_steps = 0 # For simple DDPM, no warmup needed
+
+        with tqdm(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
+
+                # prepare the input for the inpainting model
+                inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1)
+
+                # predict the noise residual
+                diffusion = models.get('diffusion', None)
+                if diffusion is None:
+                    raise ValueError("Diffusion model not found in models dictionary")
+
+                diffusion.to(device)
+
+                # Create time embedding for the current timestep
+                time_embedding = get_time_embedding(t.item()).unsqueeze(0).to(device)
+                if do_classifier_free_guidance:
+                    time_embedding = torch.cat([time_embedding] * 2)
+
+                noise_pred = diffusion(
+                    inpainting_latent_model_input,
+                    time_embedding
+                )
+
+                to_idle(diffusion)
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = sampler.step(t, latents, noise_pred)
+
+                # Update progress bar
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps):
+                    progress_bar.update()
 
-        to_idle(diffusion)
+        # Decode the final latents
+        latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
 
-        decoder = models["decoder"]
+        decoder = models.get('decoder', None)
+        if decoder is None:
+            raise ValueError("Decoder model not found in models dictionary")
+
         decoder.to(device)
-        # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 3, Height, Width)
-        images = decoder(latents)
 
+        image = decoder(latents.to(device))
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        image = numpy_to_pil(image)
+
         to_idle(decoder)
-
-        images = rescale(images, (-1, 1), (0, 255), clamp=True)
-        # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
-        images = images.permute(0, 2, 3, 1)
-        images = images.to("cpu", torch.uint8).numpy()
-        return images[0]
+
+        return image
 
+
 def rescale(x, old_range, new_range, clamp=False):
     old_min, old_max = old_range
     new_min, new_max = new_range
@@ -169,6 +359,29 @@ def get_time_embedding(timestep):
     freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
     # Shape: (1, 160)
     x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
-    # Shape: (1, 160 * 2)
+    # Shape: (1, 160 * 2) -> (1, 320)
     return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
 
+if __name__ == "__main__":
+    # Example usage
+    image = Image.open("example_image.jpg").convert("RGB")
+    condition_image = Image.open("example_condition_image.jpg").convert("RGB")
+    mask = Image.open("example_mask.png").convert("L")
+
+    # Resize images to the desired dimensions
+    image, condition_image, mask = check_inputs(image, condition_image, mask, WIDTH, HEIGHT)
+
+    # Generate image
+    generated_image = generate(
+        image=image,
+        condition_image=condition_image,
+        mask=mask,
+        num_inference_steps=50,
+        guidance_scale=2.5,
+        width=WIDTH,
+        height=HEIGHT,
+        device="cuda" # or "cpu"
+    )
+
+    generated_image[0].save("generated_image.png")
+
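
For context, `generate` only looks up `models['encoder']`, `models['diffusion']` and `models['decoder']`, so the caller is responsible for assembling that dictionary. The sketch below shows one plausible wiring using the person.jpg, garment.jpg and agnostic_mask.png files added in this commit; the `VAE_Encoder`/`VAE_Decoder` module names and any weight loading are assumptions about the surrounding repo, not part of this diff:

import torch
from PIL import Image
from encoder import VAE_Encoder      # assumed module names from the surrounding repo
from decoder import VAE_Decoder
from diffusion import Diffusion
from pipeline import generate

device = "cuda" if torch.cuda.is_available() else "cpu"
models = {
    "encoder": VAE_Encoder(),    # read as models['encoder'] inside generate()
    "diffusion": Diffusion(),    # models['diffusion']
    "decoder": VAE_Decoder(),    # models['decoder']
}
# (Load pretrained weights into the three modules here; not shown in this commit.)

person = Image.open("person.jpg").convert("RGB")
garment = Image.open("garment.jpg").convert("RGB")
mask = Image.open("agnostic_mask.png").convert("L")

result = generate(
    image=person,
    condition_image=garment,
    mask=mask,
    num_inference_steps=50,
    guidance_scale=2.5,
    width=512,
    height=512,
    models=models,
    device=device,
)
result[0].save("tryon_result.png")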