QinLei086 committed on
Commit
15acbf0
1 Parent(s): f8372bd

Upload 28 files

.gitignore ADDED
@@ -0,0 +1 @@
+ pretrain_weights/
README.md CHANGED
@@ -1,3 +1,5 @@
+ -- LSDM for Crack Segmentation dataset expansion
+
  ---
  title: LSDM
  emoji: 💻
app.py ADDED
@@ -0,0 +1,58 @@
+ import gradio as gr
+ from evolution import random_walk
+ from generate import generate
+
+ def process_random_walk(img):
+     img1, _ = random_walk(img)
+     return img1
+
+ def process_first_generation(img1, model_path="pretrain_weights/b2m/unet_ema"):
+     generated_images = generate(img1, model_path)
+     return generated_images[0]
+
+ def process_second_generation(img1, model_path="pretrain_weights/m2i/unet_ema"):
+     generated_images = generate(img1, model_path)
+     return generated_images[0]
+
+ # Create the Gradio interface
+ with gr.Blocks() as app:
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(value="figs/4.png", image_mode='L', type='numpy', label="Upload Grayscale Image")
+             process_button_1 = gr.Button("1. Process Evolution")
+
+         with gr.Column():
+             output_image_1 = gr.Image(value="figs/4_1.png", image_mode='L', type="numpy", label="After Evolution Image", sources=[])
+             process_button_2 = gr.Button("2. Generate Masks")
+
+     with gr.Row():
+         with gr.Column():
+             output_image_3 = gr.Image(value="figs/4_1_mask.png", image_mode='L', type="numpy", label="Generated Mask Image", sources=[])
+             process_button_3 = gr.Button("3. Generate Images")
+         with gr.Column():
+             output_image_5 = gr.Image(value="figs/4_1.jpg", type="numpy", image_mode='RGB', label="Final Generated Image 1", sources=[])
+
+     process_button_1.click(
+         process_random_walk,
+         inputs=[input_image],
+         outputs=[output_image_1]
+     )
+
+     process_button_2.click(
+         process_first_generation,
+         inputs=[output_image_1],
+         outputs=[output_image_3]
+     )
+
+     process_button_3.click(
+         process_second_generation,
+         inputs=[output_image_3],
+         outputs=[output_image_5]
+     )
+
+ app.launch()
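
For reference, a minimal headless sketch of the same three-step pipeline the demo above drives, assuming `random_walk` returns `(evolved_image, extra)` and `generate` returns a list of images, as they are used in app.py; file names and the uint8-array output are illustrative assumptions:

import numpy as np
from PIL import Image

from evolution import random_walk
from generate import generate

img = np.array(Image.open("figs/4.png").convert("L"))         # grayscale input image
evolved, _ = random_walk(img)                                  # 1. evolve the crack structure
mask = generate(evolved, "pretrain_weights/b2m/unet_ema")[0]   # 2. structure -> mask
image = generate(mask, "pretrain_weights/m2i/unet_ema")[0]     # 3. mask -> image
Image.fromarray(np.asarray(image)).save("generated.png")       # assumes a uint8 array output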
diffusion_module/__pycache__/nn.cpython-39.pyc ADDED
Binary file (6.26 kB)

diffusion_module/__pycache__/unet.cpython-39.pyc ADDED
Binary file (29.6 kB)

diffusion_module/__pycache__/unet_2d_blocks.cpython-39.pyc ADDED
Binary file (57.3 kB)

diffusion_module/__pycache__/unet_2d_sdm.cpython-39.pyc ADDED
Binary file (10.7 kB)
 
diffusion_module/nn.py ADDED
@@ -0,0 +1,183 @@
+ """
+ Various utilities for neural networks.
+ """
+
+ import math
+
+ import torch as th
+ import torch.nn as nn
+
+
+ def convert_module_to_f16(l):
+     """
+     Convert primitive modules to float16.
+     """
+     if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+         l.weight.data = l.weight.data.half()
+         if l.bias is not None:
+             l.bias.data = l.bias.data.half()
+
+
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+ class SiLU(nn.Module):
+     def forward(self, x):
+         return x * th.sigmoid(x)
+
+
+ class GroupNorm32(nn.GroupNorm):
+     def forward(self, x):
+         # print(x.float().dtype)
+         return super().forward(x).type(x.dtype)
+
+
+ def conv_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D convolution module.
+     """
+     if dims == 1:
+         return nn.Conv1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.Conv2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.Conv3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def linear(*args, **kwargs):
+     """
+     Create a linear module.
+     """
+     return nn.Linear(*args, **kwargs)
+
+
+ def avg_pool_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D average pooling module.
+     """
+     if dims == 1:
+         return nn.AvgPool1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.AvgPool2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.AvgPool3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def update_ema(target_params, source_params, rate=0.99):
+     """
+     Update target parameters to be closer to those of source parameters using
+     an exponential moving average.
+
+     :param target_params: the target parameter sequence.
+     :param source_params: the source parameter sequence.
+     :param rate: the EMA rate (closer to 1 means slower).
+     """
+     for targ, src in zip(target_params, source_params):
+         targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def scale_module(module, scale):
+     """
+     Scale the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().mul_(scale)
+     return module
+
+
+ def mean_flat(tensor):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+ def normalization(channels):
+     """
+     Make a standard normalization layer.
+
+     :param channels: number of input channels.
+     :return: an nn.Module for normalization.
+     """
+     return GroupNorm32(32, channels)
+
+
+ def timestep_embedding(timesteps, dim, max_period=10000):
+     """
+     Create sinusoidal timestep embeddings.
+
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     half = dim // 2
+     freqs = th.exp(
+         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+     ).to(device=timesteps.device)
+     args = timesteps[:, None].float() * freqs[None]
+     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+     return embedding
+
+
+ def checkpoint(func, inputs, params, flag):
+     """
+     Evaluate a function without caching intermediate activations, allowing for
+     reduced memory at the expense of extra compute in the backward pass.
+
+     :param func: the function to evaluate.
+     :param inputs: the argument sequence to pass to `func`.
+     :param params: a sequence of parameters `func` depends on but does not
+                    explicitly take as arguments.
+     :param flag: if False, disable gradient checkpointing.
+     """
+     if flag:
+         args = tuple(inputs) + tuple(params)
+         # return th.utils.checkpoint.checkpoint.apply(func, inputs)
+         return CheckpointFunction.apply(func, len(inputs), *args)
+     else:
+         return func(*inputs)
+
+
+ class CheckpointFunction(th.autograd.Function):
+     @staticmethod
+     def forward(ctx, run_function, length, *args):
+         ctx.run_function = run_function
+         ctx.input_tensors = list(args[:length])
+         ctx.input_params = list(args[length:])
+         with th.no_grad():
+             output_tensors = ctx.run_function(*ctx.input_tensors)
+         return output_tensors
+
+     @staticmethod
+     def backward(ctx, *output_grads):
+         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+         with th.enable_grad():
+             # Fixes a bug where the first op in run_function modifies the
+             # Tensor storage in place, which is not allowed for detach()'d
+             # Tensors.
+             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+             output_tensors = ctx.run_function(*shallow_copies)
+         input_grads = th.autograd.grad(
+             output_tensors,
+             ctx.input_tensors + ctx.input_params,
+             output_grads,
+             allow_unused=True,
+         )
+         del ctx.input_tensors
+         del ctx.input_params
+         del output_tensors
+         return (None, None) + input_grads
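
For orientation, a minimal sketch of how two of the helpers above are typically exercised; the import path assumes the repository root is on PYTHONPATH, and the shapes and values are illustrative only:

import torch as th
from diffusion_module.nn import timestep_embedding, update_ema

t = th.randint(0, 1000, (8,))          # one diffusion timestep per batch element
emb = timestep_embedding(t, dim=256)   # sinusoidal embedding, shape [8, 256]
assert emb.shape == (8, 256)

# EMA update pulls target weights toward source weights by (1 - rate).
target = [th.zeros(4)]
source = [th.ones(4)]
update_ema(target, source, rate=0.99)  # target[0] becomes 0.01 everywhere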
diffusion_module/unet.py ADDED
@@ -0,0 +1,1315 @@
1
+ from abc import abstractmethod
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .nn import (
11
+ SiLU,
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ convert_module_to_f16
20
+ )
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from dataclasses import dataclass
26
+
27
+ @dataclass
28
+ class UNet2DOutput(BaseOutput):
29
+ """
30
+ Args:
31
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
32
+ Hidden states output. Output of last layer of model.
33
+ """
34
+
35
+ sample: th.FloatTensor
36
+
37
+
38
+ class AttentionPool2d(nn.Module):
39
+ """
40
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ spacial_dim: int,
46
+ embed_dim: int,
47
+ num_heads_channels: int,
48
+ output_dim: int = None,
49
+ ):
50
+ super().__init__()
51
+ self.positional_embedding = nn.Parameter(
52
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
53
+ )
54
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
55
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
56
+ self.num_heads = embed_dim // num_heads_channels
57
+ self.attention = QKVAttention(self.num_heads)
58
+
59
+ def forward(self, x):
60
+ b, c, *_spatial = x.shape
61
+ x = x.reshape(b, c, -1) # NC(HW)
62
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
63
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
64
+ x = self.qkv_proj(x)
65
+ x = self.attention(x)
66
+ x = self.c_proj(x)
67
+ return x[:, :, 0]
68
+
69
+
70
+ class TimestepBlock(nn.Module):
71
+ """
72
+ Any module where forward() takes timestep embeddings as a second argument.
73
+ """
74
+
75
+ @abstractmethod
76
+ def forward(self, x, emb):
77
+ """
78
+ Apply the module to `x` given `emb` timestep embeddings.
79
+ """
80
+
81
+ class CondTimestepBlock(nn.Module):
82
+ """
83
+ Any module where forward() takes timestep embeddings as a second argument.
84
+ """
85
+
86
+ @abstractmethod
87
+ def forward(self, x, cond, emb):
88
+ """
89
+ Apply the module to `x` given `emb` timestep embeddings.
90
+ """
91
+ """
92
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock, CondTimestepBlock):
93
+
94
+ def forward(self, x, cond, emb):
95
+ for layer in self:
96
+ if isinstance(layer, CondTimestepBlock):
97
+ x = layer(x, cond, emb)
98
+ elif isinstance(layer, TimestepBlock):
99
+ x = layer(x, emb)
100
+ else:
101
+ x = layer(x)
102
+ return x
103
+ """
104
+
105
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock, CondTimestepBlock):
106
+ def forward(self, x, cond, emb):
107
+ outputs_list = []  # create an empty list to collect the extra outputs
108
+ for layer in self:
109
+ if isinstance(layer, CondTimestepBlock):
110
+ # call the layer and check whether its output is a tuple
111
+ result = layer(x, cond, emb)
112
+ if isinstance(result, tuple) and len(result) == 2:
113
+ x, additional_output = result
114
+ outputs_list.append(additional_output)  # append the extra output to the list
115
+ else:
116
+ x = result
117
+ elif isinstance(layer, TimestepBlock):
118
+ x = layer(x, emb)
119
+ else:
120
+ x = layer(x)
121
+
122
+ if outputs_list == []:
123
+ return x
124
+ else:
125
+ return x, outputs_list  # return the final x and the list of all extra outputs
126
+
127
+
128
+
129
+ class Upsample(nn.Module):
130
+ """
131
+ An upsampling layer with an optional convolution.
132
+
133
+ :param channels: channels in the inputs and outputs.
134
+ :param use_conv: a bool determining if a convolution is applied.
135
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
136
+ upsampling occurs in the inner-two dimensions.
137
+ """
138
+
139
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
140
+ super().__init__()
141
+ self.channels = channels
142
+ self.out_channels = out_channels or channels
143
+ self.use_conv = use_conv
144
+ self.dims = dims
145
+ if use_conv:
146
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
147
+
148
+ def forward(self, x):
149
+ assert x.shape[1] == self.channels
150
+ if self.dims == 3:
151
+ x = F.interpolate(
152
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
153
+ )
154
+ else:
155
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
156
+ if self.use_conv:
157
+ x = self.conv(x)
158
+ return x
159
+
160
+
161
+ class Downsample(nn.Module):
162
+ """
163
+ A downsampling layer with an optional convolution.
164
+
165
+ :param channels: channels in the inputs and outputs.
166
+ :param use_conv: a bool determining if a convolution is applied.
167
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
168
+ downsampling occurs in the inner-two dimensions.
169
+ """
170
+
171
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
172
+ super().__init__()
173
+ self.channels = channels
174
+ self.out_channels = out_channels or channels
175
+ self.use_conv = use_conv
176
+ self.dims = dims
177
+ stride = 2 if dims != 3 else (1, 2, 2)
178
+ if use_conv:
179
+ self.op = conv_nd(
180
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
181
+ )
182
+ else:
183
+ assert self.channels == self.out_channels
184
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
185
+
186
+ def forward(self, x):
187
+ assert x.shape[1] == self.channels
188
+ return self.op(x)
189
+
190
+
191
+ class SPADEGroupNorm(nn.Module):
192
+ def __init__(self, norm_nc, label_nc, eps = 1e-5,debug = False):
193
+ super().__init__()
194
+ self.debug = debug
195
+ self.norm = nn.GroupNorm(32, norm_nc, affine=False) # 32/16
196
+
197
+ self.eps = eps
198
+ nhidden = 128
199
+ self.mlp_shared = nn.Sequential(
200
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
201
+ nn.ReLU(),
202
+ )
203
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
204
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
205
+
206
+ def forward(self, x, segmap):
207
+ # Part 1. generate parameter-free normalized activations
208
+ x = self.norm(x)
209
+
210
+ # Part 2. produce scaling and bias conditioned on semantic map
211
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
212
+ actv = self.mlp_shared(segmap)
213
+ gamma = self.mlp_gamma(actv)
214
+ beta = self.mlp_beta(actv)
215
+
216
+ # apply scale and bias
217
+ if self.debug:
218
+ return x * (1 + gamma) + beta, (beta.detach().cpu(), gamma.detach().cpu())
219
+ else:
220
+ return x * (1 + gamma) + beta
221
+
222
+
223
+ class AdaIN(nn.Module):
224
+ def __init__(self, num_features):
225
+ super().__init__()
226
+ self.instance_norm = th.nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
227
+
228
+ def forward(self, x, alpha, gamma):
229
+ assert x.shape[:2] == alpha.shape[:2] == gamma.shape[:2]
230
+ norm = self.instance_norm(x)
231
+ return alpha * norm + gamma
232
+
233
+ class RESAILGroupNorm(nn.Module):
234
+ def __init__(self, norm_nc, label_nc, guidance_nc, eps = 1e-5):
235
+ super().__init__()
236
+
237
+ self.norm = nn.GroupNorm(32, norm_nc, affine=False) # 32/16
238
+
239
+ # SPADE
240
+ self.eps = eps
241
+ nhidden = 128
242
+ self.mask_mlp_shared = nn.Sequential(
243
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
244
+ nn.ReLU(),
245
+ )
246
+
247
+ self.mask_mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
248
+ self.mask_mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
249
+
250
+
251
+ # Guidance
252
+
253
+ self.conv_s = th.nn.Conv2d(label_nc, nhidden * 2, 3, 2)
254
+ self.pool_s = th.nn.AdaptiveAvgPool2d(1)
255
+ self.conv_s2 = th.nn.Conv2d(nhidden * 2, nhidden * 2, 1, 1)
256
+
257
+ self.conv1 = th.nn.Conv2d(guidance_nc, nhidden, 3, 1, padding=1)
258
+ self.adaIn1 = AdaIN(norm_nc * 2)
259
+ self.relu1 = nn.ReLU()
260
+
261
+ self.conv2 = th.nn.Conv2d(nhidden, nhidden, 3, 1, padding=1)
262
+ self.adaIn2 = AdaIN(norm_nc * 2)
263
+ self.relu2 = nn.ReLU()
264
+ self.conv3 = th.nn.Conv2d(nhidden, nhidden, 3, 1, padding=1)
265
+
266
+ self.guidance_mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
267
+ self.guidance_mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
268
+
269
+ self.blending_gamma = nn.Parameter(th.zeros(1), requires_grad=True)
270
+ self.blending_beta = nn.Parameter(th.zeros(1), requires_grad=True)
271
+ self.norm_nc = norm_nc
272
+
273
+ def forward(self, x, segmap, guidance):
274
+ # Part 1. generate parameter-free normalized activations
275
+ x = self.norm(x)
276
+ # Part 2. produce scaling and bias conditioned on semantic map
277
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
278
+ mask_actv = self.mask_mlp_shared(segmap)
279
+ mask_gamma = self.mask_mlp_gamma(mask_actv)
280
+ mask_beta = self.mask_mlp_beta(mask_actv)
281
+
282
+
283
+ # Part 3. produce scaling and bias conditioned on feature guidance
284
+ guidance = F.interpolate(guidance, size=x.size()[2:], mode='bilinear')
285
+
286
+ f_s_1 = self.conv_s(segmap)
287
+ c1 = self.pool_s(f_s_1)
288
+ c2 = self.conv_s2(c1)
289
+
290
+ f1 = self.conv1(guidance)
291
+
292
+ f1 = self.adaIn1(f1, c1[:, : 128, ...], c1[:, 128:, ...])
293
+ f2 = self.relu1(f1)
294
+
295
+ f2 = self.conv2(f2)
296
+ f2 = self.adaIn2(f2, c2[:, : 128, ...], c2[:, 128:, ...])
297
+ f2 = self.relu2(f2)
298
+ guidance_actv = self.conv3(f2)
299
+
300
+ guidance_gamma = self.guidance_mlp_gamma(guidance_actv)
301
+ guidance_beta = self.guidance_mlp_beta(guidance_actv)
302
+
303
+ gamma_alpha = F.sigmoid(self.blending_gamma)
304
+ beta_alpha = F.sigmoid(self.blending_beta)
305
+
306
+ gamma_final = gamma_alpha * guidance_gamma + (1 - gamma_alpha) * mask_gamma
307
+ beta_final = beta_alpha * guidance_beta + (1 - beta_alpha) * mask_beta
308
+ out = x * (1 + gamma_final) + beta_final
309
+
310
+ # apply scale and bias
311
+ return out
312
+
313
+ class SPMGroupNorm(nn.Module):
314
+ def __init__(self, norm_nc, label_nc, feature_nc, eps = 1e-5):
315
+ super().__init__()
316
+ print("use SPM")
317
+
318
+ self.norm = nn.GroupNorm(32, norm_nc, affine=False) # 32/16
319
+
320
+ # SPADE
321
+ self.eps = eps
322
+ nhidden = 128
323
+ self.mask_mlp_shared = nn.Sequential(
324
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
325
+ nn.ReLU(),
326
+ )
327
+
328
+ self.mask_mlp_gamma1 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
329
+ self.mask_mlp_beta1 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
330
+
331
+ self.mask_mlp_gamma2 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
332
+ self.mask_mlp_beta2 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
333
+
334
+
335
+ # Feature
336
+ self.feature_mlp_shared = nn.Sequential(
337
+ nn.Conv2d(feature_nc, nhidden, kernel_size=3, padding=1),
338
+ nn.ReLU(),
339
+ )
340
+
341
+ self.feature_mlp_gamma1 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
342
+ self.feature_mlp_beta1 = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
343
+
344
+
345
+ def forward(self, x, segmap, guidance):
346
+ # Part 1. generate parameter-free normalized activations
347
+ x = self.norm(x)
348
+ # Part 2. produce scaling and bias conditioned on semantic map
349
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
350
+ mask_actv = self.mask_mlp_shared(segmap)
351
+ mask_gamma1 = self.mask_mlp_gamma1(mask_actv)
352
+ mask_beta1 = self.mask_mlp_beta1(mask_actv)
353
+
354
+ mask_gamma2 = self.mask_mlp_gamma2(mask_actv)
355
+ mask_beta2 = self.mask_mlp_beta2(mask_actv)
356
+
357
+
358
+ # Part 3. produce scaling and bias conditioned on feature guidance
359
+ guidance = F.interpolate(guidance, size=x.size()[2:], mode='bilinear')
360
+ feature_actv = self.feature_mlp_shared(guidance)
361
+ feature_gamma1 = self.feature_mlp_gamma1(feature_actv)
362
+ feature_beta1 = self.feature_mlp_beta1(feature_actv)
363
+
364
+ gamma_final = feature_gamma1 * (1 + mask_gamma1) + mask_beta1
365
+ beta_final = feature_beta1 * (1 + mask_gamma2) + mask_beta2
366
+
367
+ out = x * (1 + gamma_final) + beta_final
368
+
369
+ # apply scale and bias
370
+ return out
371
+
372
+
373
+ class ResBlock(TimestepBlock):
374
+ """
375
+ A residual block that can optionally change the number of channels.
376
+
377
+ :param channels: the number of input channels.
378
+ :param emb_channels: the number of timestep embedding channels.
379
+ :param dropout: the rate of dropout.
380
+ :param out_channels: if specified, the number of out channels.
381
+ :param use_conv: if True and out_channels is specified, use a spatial
382
+ convolution instead of a smaller 1x1 convolution to change the
383
+ channels in the skip connection.
384
+ :param dims: determines if the signal is 1D, 2D, or 3D.
385
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
386
+ :param up: if True, use this block for upsampling.
387
+ :param down: if True, use this block for downsampling.
388
+ """
389
+
390
+ def __init__(
391
+ self,
392
+ channels,
393
+ emb_channels,
394
+ dropout,
395
+ out_channels=None,
396
+ use_conv=False,
397
+ use_scale_shift_norm=False,
398
+ dims=2,
399
+ use_checkpoint=False,
400
+ up=False,
401
+ down=False,
402
+ ):
403
+ super().__init__()
404
+ self.channels = channels
405
+ self.emb_channels = emb_channels
406
+ self.dropout = dropout
407
+ self.out_channels = out_channels or channels
408
+ self.use_conv = use_conv
409
+ self.use_checkpoint = use_checkpoint
410
+ self.use_scale_shift_norm = use_scale_shift_norm
411
+
412
+ self.in_layers = nn.Sequential(
413
+ normalization(channels),
414
+ SiLU(),
415
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
416
+ )
417
+
418
+ self.updown = up or down
419
+
420
+ if up:
421
+ self.h_upd = Upsample(channels, False, dims)
422
+ self.x_upd = Upsample(channels, False, dims)
423
+ elif down:
424
+ self.h_upd = Downsample(channels, False, dims)
425
+ self.x_upd = Downsample(channels, False, dims)
426
+ else:
427
+ self.h_upd = self.x_upd = nn.Identity()
428
+
429
+ self.emb_layers = nn.Sequential(
430
+ SiLU(),
431
+ linear(
432
+ emb_channels,
433
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
434
+ ),
435
+ )
436
+ self.out_layers = nn.Sequential(
437
+ normalization(self.out_channels),
438
+ SiLU(),
439
+ nn.Dropout(p=dropout),
440
+ zero_module(
441
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
442
+ ),
443
+ )
444
+
445
+ if self.out_channels == channels:
446
+ self.skip_connection = nn.Identity()
447
+ elif use_conv:
448
+ self.skip_connection = conv_nd(
449
+ dims, channels, self.out_channels, 3, padding=1
450
+ )
451
+ else:
452
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
453
+
454
+ def forward(self, x, emb):
455
+ """
456
+ Apply the block to a Tensor, conditioned on a timestep embedding.
457
+
458
+ :param x: an [N x C x ...] Tensor of features.
459
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
460
+ :return: an [N x C x ...] Tensor of outputs.
461
+ """
462
+
463
+ return th.utils.checkpoint.checkpoint(self._forward, x ,emb)
464
+ # return checkpoint(
465
+ # self._forward, (x, emb), self.parameters(), self.use_checkpoint
466
+ # )
467
+
468
+ def _forward(self, x, emb):
469
+ if self.updown:
470
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
471
+ h = in_rest(x)
472
+ h = self.h_upd(h)
473
+ x = self.x_upd(x)
474
+ h = in_conv(h)
475
+ else:
476
+ h = self.in_layers(x)
477
+ emb_out = self.emb_layers(emb)#.type(h.dtype)
478
+ while len(emb_out.shape) < len(h.shape):
479
+ emb_out = emb_out[..., None]
480
+ if self.use_scale_shift_norm:
481
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
482
+ scale, shift = th.chunk(emb_out, 2, dim=1)
483
+ h = out_norm(h) * (1 + scale) + shift
484
+ h = out_rest(h)
485
+ else:
486
+ h = h + emb_out
487
+ h = self.out_layers(h)
488
+ return self.skip_connection(x) + h
489
+
490
+ class SDMResBlock(CondTimestepBlock):
491
+ """
492
+ A residual block that can optionally change the number of channels.
493
+
494
+ :param channels: the number of input channels.
495
+ :param emb_channels: the number of timestep embedding channels.
496
+ :param dropout: the rate of dropout.
497
+ :param out_channels: if specified, the number of out channels.
498
+ :param use_conv: if True and out_channels is specified, use a spatial
499
+ convolution instead of a smaller 1x1 convolution to change the
500
+ channels in the skip connection.
501
+ :param dims: determines if the signal is 1D, 2D, or 3D.
502
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
503
+ :param up: if True, use this block for upsampling.
504
+ :param down: if True, use this block for downsampling.
505
+ """
506
+
507
+ def __init__(
508
+ self,
509
+ channels,
510
+ emb_channels,
511
+ dropout,
512
+ c_channels=3,
513
+ out_channels=None,
514
+ use_conv=False,
515
+ use_scale_shift_norm=False,
516
+ dims=2,
517
+ use_checkpoint=False,
518
+ up=False,
519
+ down=False,
520
+ SPADE_type = "spade",
521
+ guidance_nc = None,
522
+ debug = False
523
+ ):
524
+ super().__init__()
525
+ self.channels = channels
526
+ self.guidance_nc = guidance_nc
527
+ self.emb_channels = emb_channels
528
+ self.dropout = dropout
529
+ self.out_channels = out_channels or channels
530
+ self.use_conv = use_conv
531
+ self.use_checkpoint = use_checkpoint
532
+ self.use_scale_shift_norm = use_scale_shift_norm
533
+ self.SPADE_type = SPADE_type
534
+ self.debug = debug
535
+ if self.SPADE_type == "spade":
536
+ self.in_norm = SPADEGroupNorm(channels, c_channels, debug=self.debug)
537
+ elif self.SPADE_type == "RESAIL":
538
+ self.in_norm = RESAILGroupNorm(channels, c_channels, guidance_nc)
539
+ elif self.SPADE_type == "SPM":
540
+ self.in_norm = SPMGroupNorm(channels, c_channels, guidance_nc)
541
+ self.in_layers = nn.Sequential(
542
+ SiLU(),
543
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
544
+ )
545
+
546
+ self.updown = up or down
547
+
548
+ if up:
549
+ self.h_upd = Upsample(channels, False, dims)
550
+ self.x_upd = Upsample(channels, False, dims)
551
+ elif down:
552
+ self.h_upd = Downsample(channels, False, dims)
553
+ self.x_upd = Downsample(channels, False, dims)
554
+ else:
555
+ self.h_upd = self.x_upd = nn.Identity()
556
+
557
+ self.emb_layers = nn.Sequential(
558
+ SiLU(),
559
+ linear(
560
+ emb_channels,
561
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
562
+ ),
563
+ )
564
+
565
+ if self.SPADE_type == "spade":
566
+ self.out_norm = SPADEGroupNorm(self.out_channels, c_channels,debug=self.debug)
567
+ elif self.SPADE_type == "RESAIL":
568
+ self.out_norm = RESAILGroupNorm(self.out_channels, c_channels, guidance_nc)
569
+ elif self.SPADE_type == "SPM":
570
+ self.out_norm = SPMGroupNorm(self.out_channels, c_channels, guidance_nc)
571
+
572
+ self.out_layers = nn.Sequential(
573
+ SiLU(),
574
+ nn.Dropout(p=dropout),
575
+ zero_module(
576
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
577
+ ),
578
+ )
579
+
580
+ if self.out_channels == channels:
581
+ self.skip_connection = nn.Identity()
582
+ elif use_conv:
583
+ self.skip_connection = conv_nd(
584
+ dims, channels, self.out_channels, 3, padding=1
585
+ )
586
+ else:
587
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
588
+
589
+ def forward(self, x, cond, emb):
590
+ """
591
+ Apply the block to a Tensor, conditioned on a timestep embedding.
592
+
593
+ :param x: an [N x C x ...] Tensor of features.
594
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
595
+ :return: an [N x C x ...] Tensor of outputs.
596
+ """
597
+ return th.utils.checkpoint.checkpoint(self._forward, x, cond, emb)
598
+ # return checkpoint(
599
+ # self._forward, (x, cond, emb), self.parameters(), self.use_checkpoint
600
+ # )
601
+
602
+ def _forward(self, x, cond, emb):
603
+ if self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
604
+ assert self.guidance_nc is not None, "Please set guidance_nc when you use RESAIL"
605
+ guidance = x[: ,x.shape[1] - self.guidance_nc:, ...]
606
+ else:
607
+ guidance = None
608
+ if self.updown:
609
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
610
+ if self.SPADE_type == "spade":
611
+ if not self.debug:
612
+ h = self.in_norm(x, cond)
613
+ else:
614
+ h, (b1,g1) = self.in_norm(x, cond)
615
+ elif self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
616
+ h = self.in_norm(x, cond, guidance)
617
+
618
+ h = in_rest(h)
619
+ h = self.h_upd(h)
620
+ x = self.x_upd(x)
621
+ h = in_conv(h)
622
+ else:
623
+ if self.SPADE_type == "spade":
624
+ if not self.debug:
625
+ h = self.in_norm(x, cond)
626
+ else:
627
+ h, (b1,g1) = self.in_norm(x, cond)
628
+ elif self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
629
+ h = self.in_norm(x, cond, guidance)
630
+ h = self.in_layers(h)
631
+
632
+ emb_out = self.emb_layers(emb)#.type(h.dtype)
633
+ while len(emb_out.shape) < len(h.shape):
634
+ emb_out = emb_out[..., None]
635
+ if self.use_scale_shift_norm:
636
+ scale, shift = th.chunk(emb_out, 2, dim=1)
637
+ if self.SPADE_type == "spade":
638
+ if not self.debug:
639
+ h = self.out_norm(h, cond)
640
+ else:
641
+ h, (b2,g2) = self.out_norm(h, cond)
642
+ elif self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
643
+ h = self.out_norm(h, cond, guidance)
644
+
645
+ h = h * (1 + scale) + shift
646
+ h = self.out_layers(h)
647
+ else:
648
+ h = h + emb_out
649
+ if self.SPADE_type == "spade":
650
+ h = self.out_norm(h, cond)
651
+ elif self.SPADE_type == "RESAIL" or self.SPADE_type == "SPM":
652
+ h = self.out_norm(x, cond, guidance)
653
+
654
+ h = self.out_layers(h)
655
+ if self.debug:
656
+ extra = {(b1,g1),(b2,g2)}
657
+ return self.skip_connection(x) + h, extra
658
+ else:
659
+ return self.skip_connection(x) + h
660
+
661
+ class AttentionBlock(nn.Module):
662
+ """
663
+ An attention block that allows spatial positions to attend to each other.
664
+
665
+ Originally ported from here, but adapted to the N-d case.
666
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
667
+ """
668
+
669
+ def __init__(
670
+ self,
671
+ channels,
672
+ num_heads=1,
673
+ num_head_channels=-1,
674
+ use_checkpoint=False,
675
+ use_new_attention_order=False,
676
+ ):
677
+ super().__init__()
678
+ self.channels = channels
679
+ if num_head_channels == -1:
680
+ self.num_heads = num_heads
681
+ else:
682
+ assert (
683
+ channels % num_head_channels == 0
684
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
685
+ self.num_heads = channels // num_head_channels
686
+ self.use_checkpoint = use_checkpoint
687
+ self.norm = normalization(channels)
688
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
689
+ if use_new_attention_order:
690
+ # split qkv before split heads
691
+ self.attention = QKVAttention(self.num_heads)
692
+ else:
693
+ # split heads before split qkv
694
+ self.attention = QKVAttentionLegacy(self.num_heads)
695
+
696
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
697
+
698
+ def forward(self, x):
699
+ return th.utils.checkpoint.checkpoint(self._forward, x)
700
+ #return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
701
+
702
+ def _forward(self, x):
703
+ b, c, *spatial = x.shape
704
+ x = x.reshape(b, c, -1)
705
+ qkv = self.qkv(self.norm(x))
706
+ h = self.attention(qkv)
707
+ h = self.proj_out(h)
708
+ return (x + h).reshape(b, c, *spatial)
709
+
710
+
711
+ def count_flops_attn(model, _x, y):
712
+ """
713
+ A counter for the `thop` package to count the operations in an
714
+ attention operation.
715
+ Meant to be used like:
716
+ macs, params = thop.profile(
717
+ model,
718
+ inputs=(inputs, timestamps),
719
+ custom_ops={QKVAttention: QKVAttention.count_flops},
720
+ )
721
+ """
722
+ b, c, *spatial = y[0].shape
723
+ num_spatial = int(np.prod(spatial))
724
+ # We perform two matmuls with the same number of ops.
725
+ # The first computes the weight matrix, the second computes
726
+ # the combination of the value vectors.
727
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
728
+ model.total_ops += th.DoubleTensor([matmul_ops])
729
+
730
+
731
+ class QKVAttentionLegacy(nn.Module):
732
+ """
733
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
734
+ """
735
+
736
+ def __init__(self, n_heads):
737
+ super().__init__()
738
+ self.n_heads = n_heads
739
+
740
+ def forward(self, qkv):
741
+ """
742
+ Apply QKV attention.
743
+
744
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
745
+ :return: an [N x (H * C) x T] tensor after attention.
746
+ """
747
+ bs, width, length = qkv.shape
748
+ assert width % (3 * self.n_heads) == 0
749
+ ch = width // (3 * self.n_heads)
750
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
751
+ scale = 1 / math.sqrt(math.sqrt(ch))
752
+ weight = th.einsum(
753
+ "bct,bcs->bts", q * scale, k * scale
754
+ ) # More stable with f16 than dividing afterwards
755
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
756
+ a = th.einsum("bts,bcs->bct", weight, v)
757
+ return a.reshape(bs, -1, length)
758
+
759
+ @staticmethod
760
+ def count_flops(model, _x, y):
761
+ return count_flops_attn(model, _x, y)
762
+
763
+
764
+ class QKVAttention(nn.Module):
765
+ """
766
+ A module which performs QKV attention and splits in a different order.
767
+ """
768
+
769
+ def __init__(self, n_heads):
770
+ super().__init__()
771
+ self.n_heads = n_heads
772
+
773
+ def forward(self, qkv):
774
+ """
775
+ Apply QKV attention.
776
+
777
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
778
+ :return: an [N x (H * C) x T] tensor after attention.
779
+ """
780
+ bs, width, length = qkv.shape
781
+ assert width % (3 * self.n_heads) == 0
782
+ ch = width // (3 * self.n_heads)
783
+ q, k, v = qkv.chunk(3, dim=1)
784
+ scale = 1 / math.sqrt(math.sqrt(ch))
785
+ weight = th.einsum(
786
+ "bct,bcs->bts",
787
+ (q * scale).view(bs * self.n_heads, ch, length),
788
+ (k * scale).view(bs * self.n_heads, ch, length),
789
+ ) # More stable with f16 than dividing afterwards
790
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
791
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
792
+ return a.reshape(bs, -1, length)
793
+
794
+ @staticmethod
795
+ def count_flops(model, _x, y):
796
+ return count_flops_attn(model, _x, y)
797
+
798
+
799
+ class UNetModel(ModelMixin, ConfigMixin):
800
+ """
801
+ The full UNet model with attention and timestep embedding.
802
+
803
+ :param in_channels: channels in the input Tensor.
804
+ :param model_channels: base channel count for the model.
805
+ :param out_channels: channels in the output Tensor.
806
+ :param num_res_blocks: number of residual blocks per downsample.
807
+ :param attention_resolutions: a collection of downsample rates at which
808
+ attention will take place. May be a set, list, or tuple.
809
+ For example, if this contains 4, then at 4x downsampling, attention
810
+ will be used.
811
+ :param dropout: the dropout probability.
812
+ :param channel_mult: channel multiplier for each level of the UNet.
813
+ :param conv_resample: if True, use learned convolutions for upsampling and
814
+ downsampling.
815
+ :param dims: determines if the signal is 1D, 2D, or 3D.
816
+ :param num_classes: if specified (as an int), then this model will be
817
+ class-conditional with `num_classes` classes.
818
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
819
+ :param num_heads: the number of attention heads in each attention layer.
820
+ :param num_heads_channels: if specified, ignore num_heads and instead use
821
+ a fixed channel width per attention head.
822
+ :param num_heads_upsample: works with num_heads to set a different number
823
+ of heads for upsampling. Deprecated.
824
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
825
+ :param resblock_updown: use residual blocks for up/downsampling.
826
+ :param use_new_attention_order: use a different attention pattern for potentially
827
+ increased efficiency.
828
+ """
829
+
830
+ _supports_gradient_checkpointing = True
831
+ @register_to_config
832
+ def __init__(
833
+ self,
834
+ image_size,
835
+ in_channels,
836
+ model_channels,
837
+ out_channels,
838
+ num_res_blocks,
839
+ attention_resolutions,
840
+ dropout=0,
841
+ channel_mult=(1, 2, 4, 8),
842
+ conv_resample=True,
843
+ dims=2,
844
+ num_classes=None,
845
+ use_checkpoint=False,
846
+ use_fp16=True,
847
+ num_heads=1,
848
+ num_head_channels=-1,
849
+ num_heads_upsample=-1,
850
+ use_scale_shift_norm=False,
851
+ resblock_updown=False,
852
+ use_new_attention_order=False,
853
+ mask_emb="resize",
854
+ SPADE_type="spade",
855
+ debug = False
856
+ ):
857
+ super().__init__()
858
+
859
+ if num_heads_upsample == -1:
860
+ num_heads_upsample = num_heads
861
+
862
+ self.sample_size = image_size
863
+ self.in_channels = in_channels
864
+ self.model_channels = model_channels
865
+ self.out_channels = out_channels
866
+ self.num_res_blocks = num_res_blocks
867
+ self.attention_resolutions = attention_resolutions
868
+ self.dropout = dropout
869
+ self.channel_mult = channel_mult
870
+ self.conv_resample = conv_resample
871
+ self.num_classes = num_classes
872
+ self.use_checkpoint = use_checkpoint
873
+ self.num_heads = num_heads
874
+ self.num_head_channels = num_head_channels
875
+ self.num_heads_upsample = num_heads_upsample
876
+
877
+ self.debug = debug
878
+
879
+ self.mask_emb = mask_emb
880
+
881
+ time_embed_dim = model_channels * 4
882
+ self.time_embed = nn.Sequential(
883
+ linear(model_channels, time_embed_dim),
884
+ SiLU(),
885
+ linear(time_embed_dim, time_embed_dim),
886
+ )
887
+
888
+ ch = input_ch = int(channel_mult[0] * model_channels)
889
+ self.input_blocks = nn.ModuleList(
890
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] #ch=256
891
+ )
892
+ self._feature_size = ch
893
+ input_block_chans = [ch]
894
+ ds = 1
895
+ for level, mult in enumerate(channel_mult):
896
+ for _ in range(num_res_blocks):
897
+ layers = [
898
+ ResBlock(
899
+ ch,
900
+ time_embed_dim,
901
+ dropout,
902
+ out_channels=int(mult * model_channels),
903
+ dims=dims,
904
+ use_checkpoint=use_checkpoint,
905
+ use_scale_shift_norm=use_scale_shift_norm,
906
+ )
907
+ ]
908
+ ch = int(mult * model_channels)
909
+ #print(ds)
910
+ if ds in attention_resolutions:
911
+ layers.append(
912
+ AttentionBlock(
913
+ ch,
914
+ use_checkpoint=use_checkpoint,
915
+ num_heads=num_heads,
916
+ num_head_channels=num_head_channels,
917
+ use_new_attention_order=use_new_attention_order,
918
+ )
919
+ )
920
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
921
+ self._feature_size += ch
922
+ input_block_chans.append(ch)
923
+ if level != len(channel_mult) - 1:
924
+ out_ch = ch
925
+ self.input_blocks.append(
926
+ TimestepEmbedSequential(
927
+ ResBlock(
928
+ ch,
929
+ time_embed_dim,
930
+ dropout,
931
+ out_channels=out_ch,
932
+ dims=dims,
933
+ use_checkpoint=use_checkpoint,
934
+ use_scale_shift_norm=use_scale_shift_norm,
935
+ down=True,
936
+ )
937
+ if resblock_updown
938
+ else Downsample(
939
+ ch, conv_resample, dims=dims, out_channels=out_ch
940
+ )
941
+ )
942
+ )
943
+ ch = out_ch
944
+ input_block_chans.append(ch)
945
+ ds *= 2
946
+ self._feature_size += ch
947
+ self.middle_block = TimestepEmbedSequential(
948
+ SDMResBlock(
949
+ ch,
950
+ time_embed_dim,
951
+ dropout,
952
+ c_channels=num_classes if mask_emb == "resize" else num_classes*4,
953
+ dims=dims,
954
+ use_checkpoint=use_checkpoint,
955
+ use_scale_shift_norm=use_scale_shift_norm,
956
+ ),
957
+ AttentionBlock(
958
+ ch,
959
+ use_checkpoint=use_checkpoint,
960
+ num_heads=num_heads,
961
+ num_head_channels=num_head_channels,
962
+ use_new_attention_order=use_new_attention_order,
963
+ ),
964
+ SDMResBlock(
965
+ ch,
966
+ time_embed_dim,
967
+ dropout,
968
+ c_channels=num_classes if mask_emb == "resize" else num_classes*4 ,
969
+ dims=dims,
970
+ use_checkpoint=use_checkpoint,
971
+ use_scale_shift_norm=use_scale_shift_norm,
972
+ ),
973
+ )
974
+ self._feature_size += ch
975
+
976
+ self.output_blocks = nn.ModuleList([])
977
+ for level, mult in list(enumerate(channel_mult))[::-1]:
978
+ for i in range(num_res_blocks + 1):
979
+ ich = input_block_chans.pop()
980
+ #print(ch, ich)
981
+ layers = [
982
+ SDMResBlock(
983
+ ch + ich,
984
+ time_embed_dim,
985
+ dropout,
986
+ c_channels=num_classes if mask_emb == "resize" else num_classes*4,
987
+ out_channels=int(model_channels * mult),
988
+ dims=dims,
989
+ use_checkpoint=use_checkpoint,
990
+ use_scale_shift_norm=use_scale_shift_norm,
991
+ SPADE_type=SPADE_type,
992
+ guidance_nc = ich,
993
+ debug=self.debug,
994
+ )
995
+ ]
996
+ ch = int(model_channels * mult)
997
+ #print(ds)
998
+ if ds in attention_resolutions:
999
+ layers.append(
1000
+ AttentionBlock(
1001
+ ch,
1002
+ use_checkpoint=use_checkpoint,
1003
+ num_heads=num_heads_upsample,
1004
+ num_head_channels=num_head_channels,
1005
+ use_new_attention_order=use_new_attention_order,
1006
+ )
1007
+ )
1008
+ if level and i == num_res_blocks:
1009
+ out_ch = ch
1010
+ layers.append(
1011
+ SDMResBlock(
1012
+ ch,
1013
+ time_embed_dim,
1014
+ dropout,
1015
+ c_channels=num_classes if mask_emb == "resize" else num_classes*4,
1016
+ out_channels=out_ch,
1017
+ dims=dims,
1018
+ use_checkpoint=use_checkpoint,
1019
+ use_scale_shift_norm=use_scale_shift_norm,
1020
+ up=True,
1021
+ debug=self.debug
1022
+ )
1023
+ if resblock_updown
1024
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1025
+ )
1026
+ ds //= 2
1027
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
1028
+ self._feature_size += ch
1029
+
1030
+ self.out = nn.Sequential(
1031
+ normalization(ch),
1032
+ SiLU(),
1033
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
1034
+ )
1035
+ def _set_gradient_checkpointing(self, module, value=False):
1036
+ #if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
1037
+ module.gradient_checkpointing = value
1038
+ def forward(self, x, y=None, timesteps=None ):
1039
+ """
1040
+ Apply the model to an input batch.
1041
+
1042
+ :param x: an [N x C x ...] Tensor of inputs.
1043
+ :param timesteps: a 1-D batch of timesteps.
1044
+ :param y: an [N] Tensor of labels, if class-conditional.
1045
+ :return: an [N x C x ...] Tensor of outputs.
1046
+ """
1047
+ assert (y is not None) == (
1048
+ self.num_classes is not None
1049
+ ), "must specify y if and only if the model is class-conditional"
1050
+
1051
+ hs = []
1052
+ if not th.is_tensor(timesteps):
1053
+ timesteps = th.tensor([timesteps], dtype=th.long, device=x.device)
1054
+ elif th.is_tensor(timesteps) and len(timesteps.shape) == 0:
1055
+ timesteps = timesteps[None].to(x.device)
1056
+
1057
+ timesteps = timestep_embedding(timesteps, self.model_channels).type(x.dtype).to(x.device)
1058
+ emb = self.time_embed(timesteps)
1059
+
1060
+ y = y.type(self.dtype)
1061
+ h = x.type(self.dtype)
1062
+ for module in self.input_blocks:
1063
+ # input_blocks do not apply any ops to y
1064
+ h = module(h, y, emb)
1065
+ #print(h.shape)
1066
+ hs.append(h)
1067
+
1068
+ h = self.middle_block(h, y, emb)
1069
+
1070
+ if self.debug:
1071
+ extra_list = []
1072
+
1073
+ for module in self.output_blocks:
1074
+ temp = hs.pop()
1075
+
1076
+ #print("before:", h.shape, temp.shape)
1077
+ # copy padding to match the downsample size
1078
+ if h.shape[2] != temp.shape[2]:
1079
+ p1d = (0, 0, 0, 1)
1080
+ h = F.pad(h, p1d, "replicate")
1081
+
1082
+ if h.shape[3] != temp.shape[3]:
1083
+ p2d = (0, 1, 0, 0)
1084
+ h = F.pad(h, p2d, "replicate")
1085
+ #print("after:", h.shape, temp.shape)
1086
+
1087
+ h = th.cat([h, temp], dim=1)
1088
+ if self.debug:
1089
+ h, extra = module(h, y, emb)
1090
+ extra_list.append(extra)
1091
+ else:
1092
+ h = module(h, y, emb)
1093
+
1094
+ h = h.type(x.dtype)
1095
+
1096
+ if not self.debug:
1097
+ return UNet2DOutput(sample=self.out(h))
1098
+ else:
1099
+ return UNet2DOutput(sample=self.out(h)), extra_list
1100
+
1101
+
1102
+ class SuperResModel(UNetModel):
1103
+ """
1104
+ A UNetModel that performs super-resolution.
1105
+
1106
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
1107
+ """
1108
+
1109
+ def __init__(self, image_size, in_channels, *args, **kwargs):
1110
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
1111
+
1112
+ def forward(self, x, cond, timesteps, low_res=None, **kwargs):
1113
+ _, _, new_height, new_width = x.shape
1114
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
1115
+ x = th.cat([x, upsampled], dim=1)
1116
+ return super().forward(x, cond, timesteps, **kwargs)
1117
+
1118
+
1119
+ class EncoderUNetModel(nn.Module):
1120
+ """
1121
+ The half UNet model with attention and timestep embedding.
1122
+
1123
+ For usage, see UNet.
1124
+ """
1125
+
1126
+ def __init__(
1127
+ self,
1128
+ image_size,
1129
+ in_channels,
1130
+ model_channels,
1131
+ out_channels,
1132
+ num_res_blocks,
1133
+ attention_resolutions,
1134
+ dropout=0,
1135
+ channel_mult=(1, 2, 4, 8),
1136
+ conv_resample=True,
1137
+ dims=2,
1138
+ use_checkpoint=False,
1139
+ use_fp16=False,
1140
+ num_heads=1,
1141
+ num_head_channels=-1,
1142
+ num_heads_upsample=-1,
1143
+ use_scale_shift_norm=False,
1144
+ resblock_updown=False,
1145
+ use_new_attention_order=False,
1146
+ pool="adaptive",
1147
+ ):
1148
+ super().__init__()
1149
+
1150
+ if num_heads_upsample == -1:
1151
+ num_heads_upsample = num_heads
1152
+
1153
+ self.in_channels = in_channels
1154
+ self.model_channels = model_channels
1155
+ self.out_channels = out_channels
1156
+ self.num_res_blocks = num_res_blocks
1157
+ self.attention_resolutions = attention_resolutions
1158
+ self.dropout = dropout
1159
+ self.channel_mult = channel_mult
1160
+ self.conv_resample = conv_resample
1161
+ self.use_checkpoint = use_checkpoint
1162
+ self.num_heads = num_heads
1163
+ self.num_head_channels = num_head_channels
1164
+ self.num_heads_upsample = num_heads_upsample
1165
+
1166
+ time_embed_dim = model_channels * 4
1167
+ self.time_embed = nn.Sequential(
1168
+ linear(model_channels, time_embed_dim),
1169
+ SiLU(),
1170
+ linear(time_embed_dim, time_embed_dim),
1171
+ )
1172
+
1173
+ ch = int(channel_mult[0] * model_channels)
1174
+ self.input_blocks = nn.ModuleList(
1175
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
1176
+ )
1177
+ self._feature_size = ch
1178
+ input_block_chans = [ch]
1179
+ ds = 1
1180
+ for level, mult in enumerate(channel_mult):
1181
+ for _ in range(num_res_blocks):
1182
+ layers = [
1183
+ ResBlock(
1184
+ ch,
1185
+ time_embed_dim,
1186
+ dropout,
1187
+ out_channels=int(mult * model_channels),
1188
+ dims=dims,
1189
+ use_checkpoint=use_checkpoint,
1190
+ use_scale_shift_norm=use_scale_shift_norm,
1191
+ )
1192
+ ]
1193
+ ch = int(mult * model_channels)
1194
+ if ds in attention_resolutions:
1195
+ layers.append(
1196
+ AttentionBlock(
1197
+ ch,
1198
+ use_checkpoint=use_checkpoint,
1199
+ num_heads=num_heads,
1200
+ num_head_channels=num_head_channels,
1201
+ use_new_attention_order=use_new_attention_order,
1202
+ )
1203
+ )
1204
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1205
+ self._feature_size += ch
1206
+ input_block_chans.append(ch)
1207
+ if level != len(channel_mult) - 1:
1208
+ out_ch = ch
1209
+ self.input_blocks.append(
1210
+ TimestepEmbedSequential(
1211
+ ResBlock(
1212
+ ch,
1213
+ time_embed_dim,
1214
+ dropout,
1215
+ out_channels=out_ch,
1216
+ dims=dims,
1217
+ use_checkpoint=use_checkpoint,
1218
+ use_scale_shift_norm=use_scale_shift_norm,
1219
+ down=True,
1220
+ )
1221
+ if resblock_updown
1222
+ else Downsample(
1223
+ ch, conv_resample, dims=dims, out_channels=out_ch
1224
+ )
1225
+ )
1226
+ )
1227
+ ch = out_ch
1228
+ input_block_chans.append(ch)
1229
+ ds *= 2
1230
+ self._feature_size += ch
1231
+
1232
+ self.middle_block = TimestepEmbedSequential(
1233
+ ResBlock(
1234
+ ch,
1235
+ time_embed_dim,
1236
+ dropout,
1237
+ dims=dims,
1238
+ use_checkpoint=use_checkpoint,
1239
+ use_scale_shift_norm=use_scale_shift_norm,
1240
+ ),
1241
+ AttentionBlock(
1242
+ ch,
1243
+ use_checkpoint=use_checkpoint,
1244
+ num_heads=num_heads,
1245
+ num_head_channels=num_head_channels,
1246
+ use_new_attention_order=use_new_attention_order,
1247
+ ),
1248
+ ResBlock(
1249
+ ch,
1250
+ time_embed_dim,
1251
+ dropout,
1252
+ dims=dims,
1253
+ use_checkpoint=use_checkpoint,
1254
+ use_scale_shift_norm=use_scale_shift_norm,
1255
+ ),
1256
+ )
1257
+ self._feature_size += ch
1258
+ self.pool = pool
1259
+ if pool == "adaptive":
1260
+ self.out = nn.Sequential(
1261
+ normalization(ch),
1262
+ SiLU(),
1263
+ nn.AdaptiveAvgPool2d((1, 1)),
1264
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
1265
+ nn.Flatten(),
1266
+ )
1267
+ elif pool == "attention":
1268
+ assert num_head_channels != -1
1269
+ self.out = nn.Sequential(
1270
+ normalization(ch),
1271
+ SiLU(),
1272
+ AttentionPool2d(
1273
+ (image_size // ds), ch, num_head_channels, out_channels
1274
+ ),
1275
+ )
1276
+ elif pool == "spatial":
1277
+ self.out = nn.Sequential(
1278
+ nn.Linear(self._feature_size, 2048),
1279
+ nn.ReLU(),
1280
+ nn.Linear(2048, self.out_channels),
1281
+ )
1282
+ elif pool == "spatial_v2":
1283
+ self.out = nn.Sequential(
1284
+ nn.Linear(self._feature_size, 2048),
1285
+ normalization(2048),
1286
+ SiLU(),
1287
+ nn.Linear(2048, self.out_channels),
1288
+ )
1289
+ else:
1290
+ raise NotImplementedError(f"Unexpected {pool} pooling")
1291
+ def forward(self, x, timesteps):
1292
+ """
1293
+ Apply the model to an input batch.
1294
+
1295
+ :param x: an [N x C x ...] Tensor of inputs.
1296
+ :param timesteps: a 1-D batch of timesteps.
1297
+ :return: an [N x K] Tensor of outputs.
1298
+ """
1299
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1300
+
1301
+ results = []
1302
+ h = x.type(self.dtype)
1303
+ for module in self.input_blocks:
1304
+ h = module(h, emb)
1305
+ if self.pool.startswith("spatial"):
1306
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1307
+ h = self.middle_block(h, emb)
1308
+ if self.pool.startswith("spatial"):
1309
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1310
+ h = th.cat(results, axis=-1)
1311
+ return self.out(h)
1312
+ else:
1313
+ h = h.type(x.dtype)
1314
+ return self.out(h)
1315
+
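
A hedged sketch of constructing the UNetModel defined above; the hyperparameter values are illustrative assumptions (not the repository's shipped config), the import path assumes the repo root is on PYTHONPATH, and `y` is the semantic mask condition consumed by the SPADE normalization, shaped [N x num_classes x H x W] when `mask_emb="resize"`:

import torch as th
from diffusion_module.unet import UNetModel

model = UNetModel(
    image_size=64,                 # illustrative, not the repo's config
    in_channels=3,
    model_channels=128,
    out_channels=3,
    num_res_blocks=2,
    attention_resolutions=(4, 8),  # attend at 4x and 8x downsampling
    channel_mult=(1, 2, 4, 8),
    num_classes=2,                 # e.g. crack / background mask channels
    num_heads=4,
)

x = th.randn(1, 3, 64, 64)               # noisy sample
y = th.zeros(1, 2, 64, 64)               # one-hot semantic mask condition
t = th.tensor([10])                      # diffusion timestep
out = model(x, y=y, timesteps=t).sample  # output tensor of shape [1, 3, 64, 64]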
diffusion_module/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusion_module/unet_2d_sdm.py ADDED
@@ -0,0 +1,357 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+ from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+ from .unet_2d_blocks import UNetSDMMidBlock2D, get_down_block, get_up_block, UNetSDMMidBlock2D
25
+ from diffusers.loaders import UNet2DConditionLoadersMixin
26
+
27
+ @dataclass
28
+ class UNet2DOutput(BaseOutput):
29
+ """
30
+ Args:
31
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
32
+ Hidden states output. Output of last layer of model.
33
+ """
34
+
35
+ sample: torch.FloatTensor
36
+
37
+ class SDMUNet2DModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
38
+ r"""
39
+ UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
40
+
41
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
42
+ implements for all models (such as downloading or saving, etc.)
43
+
44
+ Parameters:
45
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
46
+ Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
47
+ 1)`.
48
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
49
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
50
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
51
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
52
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
53
+ flip_sin_to_cos (`bool`, *optional*, defaults to :
54
+ obj:`True`): Whether to flip sin to cos for fourier time embedding.
55
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
56
+ obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
57
+ types.
58
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
59
+ The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
60
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
61
+ obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
62
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
63
+ obj:`(224, 448, 672, 896)`): Tuple of block output channels.
64
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
65
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
66
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
67
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
68
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
69
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
70
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
71
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
72
+ for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
73
+ class_embed_type (`str`, *optional*, defaults to None):
74
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
75
+ `"timestep"`, or `"identity"`.
76
+ num_class_embeds (`int`, *optional*, defaults to None):
77
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
78
+ class conditioning with `class_embed_type` equal to `None`.
79
+ """
80
+
81
+ _supports_gradient_checkpointing = True
82
+ @register_to_config
83
+ def __init__(
84
+ self,
85
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
86
+ in_channels: int = 3,
87
+ out_channels: int = 3,
88
+ center_input_sample: bool = False,
89
+ time_embedding_type: str = "positional",
90
+ freq_shift: int = 0,
91
+ flip_sin_to_cos: bool = True,
92
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
93
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
94
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
95
+ layers_per_block: int = 2,
96
+ mid_block_scale_factor: float = 1,
97
+ downsample_padding: int = 1,
98
+ act_fn: str = "silu",
99
+ attention_head_dim: Optional[int] = 8,
100
+ norm_num_groups: int = 32,
101
+ norm_eps: float = 1e-5,
102
+ resnet_time_scale_shift: str = "scale_shift",
103
+ add_attention: bool = True,
104
+ class_embed_type: Optional[str] = None,
105
+ num_class_embeds: Optional[int] = None,
106
+ segmap_channels: int = 34,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.sample_size = sample_size
111
+ self.segmap_channels = segmap_channels
112
+ time_embed_dim = block_out_channels[0] * 4
113
+
114
+ # Check inputs
115
+ if len(down_block_types) != len(up_block_types):
116
+ raise ValueError(
117
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
118
+ )
119
+
120
+ if len(block_out_channels) != len(down_block_types):
121
+ raise ValueError(
122
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
123
+ )
124
+
125
+ # input
126
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
127
+
128
+ # time
129
+ if time_embedding_type == "fourier":
130
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
131
+ timestep_input_dim = 2 * block_out_channels[0]
132
+ elif time_embedding_type == "positional":
133
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
134
+ timestep_input_dim = block_out_channels[0]
135
+
136
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
137
+
138
+ # class embedding
139
+ if class_embed_type is None and num_class_embeds is not None:
140
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
141
+ elif class_embed_type == "timestep":
142
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
143
+ elif class_embed_type == "identity":
144
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
145
+ else:
146
+ self.class_embedding = None
147
+
148
+ self.down_blocks = nn.ModuleList([])
149
+ self.mid_block = None
150
+ self.up_blocks = nn.ModuleList([])
151
+
152
+ # down
153
+ output_channel = block_out_channels[0]
154
+ for i, down_block_type in enumerate(down_block_types):
155
+ input_channel = output_channel
156
+ output_channel = block_out_channels[i]
157
+ is_final_block = i == len(block_out_channels) - 1
158
+
159
+ down_block = get_down_block(
160
+ down_block_type,
161
+ num_layers=layers_per_block,
162
+ in_channels=input_channel,
163
+ out_channels=output_channel,
164
+ temb_channels=time_embed_dim,
165
+ add_downsample=not is_final_block,
166
+ resnet_eps=norm_eps,
167
+ resnet_act_fn=act_fn,
168
+ resnet_groups=norm_num_groups,
169
+ attn_num_head_channels=attention_head_dim,
170
+ downsample_padding=downsample_padding,
171
+ resnet_time_scale_shift=resnet_time_scale_shift,
172
+ )
173
+ self.down_blocks.append(down_block)
174
+
175
+ # mid
176
+ self.mid_block = UNetSDMMidBlock2D(
177
+ in_channels=block_out_channels[-1],
178
+ temb_channels=time_embed_dim,
179
+ resnet_eps=norm_eps,
180
+ resnet_act_fn=act_fn,
181
+ output_scale_factor=mid_block_scale_factor,
182
+ resnet_time_scale_shift=resnet_time_scale_shift,
183
+ attn_num_head_channels=attention_head_dim,
184
+ resnet_groups=norm_num_groups,
185
+ add_attention=add_attention,
186
+ segmap_channels=segmap_channels,
187
+ )
188
+
189
+ # up
190
+ reversed_block_out_channels = list(reversed(block_out_channels))
191
+ output_channel = reversed_block_out_channels[0]
192
+ for i, up_block_type in enumerate(up_block_types):
193
+ prev_output_channel = output_channel
194
+ output_channel = reversed_block_out_channels[i]
195
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
196
+
197
+ is_final_block = i == len(block_out_channels) - 1
198
+
199
+ up_block = get_up_block(
200
+ up_block_type,
201
+ num_layers=layers_per_block + 1,
202
+ in_channels=input_channel,
203
+ out_channels=output_channel,
204
+ prev_output_channel=prev_output_channel,
205
+ temb_channels=time_embed_dim,
206
+ add_upsample=not is_final_block,
207
+ resnet_eps=norm_eps,
208
+ resnet_act_fn=act_fn,
209
+ resnet_groups=norm_num_groups,
210
+ attn_num_head_channels=attention_head_dim,
211
+ resnet_time_scale_shift=resnet_time_scale_shift,
212
+ segmap_channels=segmap_channels,
213
+ )
214
+ self.up_blocks.append(up_block)
215
+ prev_output_channel = output_channel
216
+
217
+ # out
218
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
219
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
220
+ self.conv_act = nn.SiLU()
221
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
222
+ def _set_gradient_checkpointing(self, module, value=False):
223
+ #if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
224
+ module.gradient_checkpointing = value
225
+ def forward(
226
+ self,
227
+ sample: torch.FloatTensor,
228
+ segmap: torch.FloatTensor,
229
+ timestep: Union[torch.Tensor, float, int],
230
+ class_labels: Optional[torch.Tensor] = None,
231
+ return_dict: bool = True,
232
+ ) -> Union[UNet2DOutput, Tuple]:
233
+ r"""
234
+ Args:
235
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
236
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
237
+ class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
238
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
239
+ return_dict (`bool`, *optional*, defaults to `True`):
240
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
241
+
242
+ Returns:
243
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
244
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
245
+ """
246
+ # 0. center input if necessary
247
+ if self.config.center_input_sample:
248
+ sample = 2 * sample - 1.0
249
+
250
+ # 1. time
251
+ #print(timestep.shape)
252
+ timesteps = timestep
253
+ if not torch.is_tensor(timesteps):
254
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
255
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
256
+ timesteps = timesteps[None].to(sample.device)
257
+
258
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
259
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
260
+
261
+ t_emb = self.time_proj(timesteps)
262
+
263
+ # timesteps does not contain any weights and will always return f32 tensors
264
+ # but time_embedding might actually be running in fp16. so we need to cast here.
265
+ # there might be better ways to encapsulate this.
266
+ t_emb = t_emb.to(dtype=self.dtype)
267
+ emb = self.time_embedding(t_emb)
268
+
269
+ if self.class_embedding is not None:
270
+ if class_labels is None:
271
+ raise ValueError("class_labels should be provided when doing class conditioning")
272
+
273
+ if self.config.class_embed_type == "timestep":
274
+ class_labels = self.time_proj(class_labels)
275
+
276
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
277
+ emb = emb + class_emb
278
+
279
+ # 2. pre-process
280
+ skip_sample = sample
281
+ sample = self.conv_in(sample)
282
+
283
+ # 3. down
284
+ down_block_res_samples = (sample,)
285
+ for downsample_block in self.down_blocks:
286
+ if hasattr(downsample_block, "skip_conv"):
287
+ sample, res_samples, skip_sample = downsample_block(
288
+ hidden_states=sample, temb=emb, skip_sample=skip_sample,segmap=segmap
289
+ )
290
+ else:
291
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
292
+
293
+ down_block_res_samples += res_samples
294
+
295
+ # 4. mid
296
+ sample = self.mid_block(sample, segmap, emb)
297
+
298
+ # 5. up
299
+ skip_sample = None
300
+ for upsample_block in self.up_blocks:
301
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
302
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
303
+
304
+ if hasattr(upsample_block, "skip_conv"):
305
+ sample, skip_sample = upsample_block(sample, segmap, res_samples, emb, skip_sample)
306
+ else:
307
+ sample = upsample_block(sample, segmap, res_samples, emb)
308
+
309
+ # 6. post-process
310
+ sample = self.conv_norm_out(sample)
311
+ sample = self.conv_act(sample)
312
+ sample = self.conv_out(sample)
313
+
314
+ if skip_sample is not None:
315
+ sample += skip_sample
316
+
317
+ if self.config.time_embedding_type == "fourier":
318
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
319
+ sample = sample / timesteps
320
+
321
+ if not return_dict:
322
+ return (sample,)
323
+
324
+ return UNet2DOutput(sample=sample)
325
+
326
+
327
+ if __name__ == "__main__":
328
+ path = 'output.txt'
329
+ f = open(path, 'w')
330
+
331
+ unet = SDMUNet2DModel(
332
+ sample_size=270,
333
+ in_channels=3,
334
+ out_channels=3,
335
+ layers_per_block=2,
336
+ block_out_channels=(256, 256, 512, 1024, 1024),
337
+ down_block_types=(
338
+ "ResnetDownsampleBlock2D",
339
+ "ResnetDownsampleBlock2D",
340
+ "ResnetDownsampleBlock2D",
341
+ "AttnDownBlock2D",
342
+ "AttnDownBlock2D",
343
+ ),
344
+ up_block_types=(
345
+ "SDMAttnUpBlock2D",
346
+ "SDMAttnUpBlock2D",
347
+ "SDMResnetUpsampleBlock2D",
348
+ "SDMResnetUpsampleBlock2D",
349
+ "SDMResnetUpsampleBlock2D",
350
+ ),
351
+ segmap_channels=34+1
352
+ )
353
+
354
+ print(unet,file=f)
355
+ f.close()
356
+
357
+ #summary(unet, [(1, 3, 270, 360), (1, 3, 270, 360), (2,)], device="cpu")
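A minimal forward-pass sketch of the SDMUNet2DModel defined above, reusing the block configuration from the `__main__` example with randomly initialized weights; the 64x64 input, the batch size, and the 35-channel segmap are illustrative assumptions rather than values fixed by this repository:

    import torch
    from diffusion_module.unet_2d_sdm import SDMUNet2DModel

    unet = SDMUNet2DModel(
        sample_size=64,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(256, 256, 512, 1024, 1024),
        down_block_types=(
            "ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D",
            "ResnetDownsampleBlock2D", "AttnDownBlock2D", "AttnDownBlock2D",
        ),
        up_block_types=(
            "SDMAttnUpBlock2D", "SDMAttnUpBlock2D",
            "SDMResnetUpsampleBlock2D", "SDMResnetUpsampleBlock2D", "SDMResnetUpsampleBlock2D",
        ),
        segmap_channels=34 + 1,
    )

    sample = torch.randn(1, 3, 64, 64)    # noisy input; height/width divisible by 2**4
    segmap = torch.zeros(1, 35, 64, 64)   # one-hot segmentation condition
    timestep = torch.tensor([10])         # broadcast to the batch inside forward()
    out = unet(sample, segmap, timestep).sample
    print(out.shape)                      # expected: torch.Size([1, 3, 64, 64])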
diffusion_module/utils/LSDMPipeline_expandDataset.py ADDED
@@ -0,0 +1,179 @@
1
+ import inspect
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+
6
+ from diffusers.models import UNet2DModel, VQModel
7
+ from diffusers.schedulers import DDIMScheduler
8
+ from diffusers.utils import randn_tensor
9
+ from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
10
+ import copy
11
+
12
+ class SDMLDMPipeline(DiffusionPipeline):
13
+ r"""
14
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
15
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
16
+
17
+ Parameters:
18
+ vae ([`VQModel`]):
19
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
20
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
21
+ scheduler ([`SchedulerMixin`]):
22
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
23
+ """
24
+
25
+ def __init__(self, vae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler, torch_dtype=torch.float16, resolution=512, resolution_type="city"):
26
+ super().__init__()
27
+ self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
28
+ self.torch_dtype = torch_dtype
29
+ self.resolution = resolution
30
+ self.resolution_type = resolution_type
31
+ @torch.no_grad()
32
+ def __call__(
33
+ self,
34
+ segmap = None,
35
+ batch_size: int = 8,
36
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
37
+ eta: float = 0.0,
38
+ num_inference_steps: int = 1000,
39
+ output_type: Optional[str] = "pil",
40
+ return_dict: bool = True,
41
+ every_step_save: int = None,
42
+ s: int = 1,
43
+ num_evolution_per_mask = 10,
44
+ debug = False,
45
+ **kwargs,
46
+ ) -> Union[Tuple, ImagePipelineOutput]:
47
+ r"""
48
+ Args:
49
+ batch_size (`int`, *optional*, defaults to 1):
50
+ Number of images to generate.
51
+ generator (`torch.Generator`, *optional*):
52
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
53
+ to make generation deterministic.
54
+ num_inference_steps (`int`, *optional*, defaults to 50):
55
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
56
+ expense of slower inference.
57
+ output_type (`str`, *optional*, defaults to `"pil"`):
58
+ The output format of the generated image. Choose between
59
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
60
+ return_dict (`bool`, *optional*, defaults to `True`):
61
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
62
+
63
+ Returns:
64
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.model.ImagePipelineOutput`] if `return_dict` is
65
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
66
+ """
67
+ # self.unet.config.sample_size = (64, 64) # (135,180)
68
+ # self.unet.config.sample_size = (135,180)
69
+ if self.resolution_type == "crack":
70
+ self.unet.config.sample_size = (64,64)
71
+ elif self.resolution_type == "crack_256":
72
+ self.unet.config.sample_size = (256,256)
73
+ else:
74
+ sc = 1080 // self.resolution
75
+ latent_size = (self.resolution // 4, 1440 // (sc*4))
76
+ self.unet.config.sample_size = latent_size
77
+ #
78
+ if not isinstance(self.unet.config.sample_size, tuple):
79
+ self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
80
+
81
+ if segmap is None:
82
+ print("No segmap provided; using an empty segmap as input")
83
+ segmap = torch.zeros(batch_size,self.unet.config.segmap_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1])
84
+ segmap = segmap.to(self.device).type(self.torch_dtype)
85
+ if batch_size == 1 and num_evolution_per_mask > batch_size:
86
+ latents = randn_tensor(
87
+ (num_evolution_per_mask, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
88
+ generator=generator,
89
+ )
90
+ else:
91
+ latents = randn_tensor(
92
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
93
+ generator=generator,
94
+ )
95
+ latents = latents.to(self.device).type(self.torch_dtype)
96
+
97
+ # scale the initial noise by the standard deviation required by the scheduler (need to check)
98
+ latents = latents * self.scheduler.init_noise_sigma
99
+
100
+ self.scheduler.set_timesteps(num_inference_steps=num_inference_steps)
101
+
102
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
103
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
104
+
105
+ extra_kwargs = {}
106
+ if accepts_eta:
107
+ extra_kwargs["eta"] = eta
108
+
109
+ step_latent = []
110
+ learn_sigma = True if hasattr(self.scheduler, "variance_type") else False
111
+ if debug:
112
+ extra_list_list = []
113
+ self.unet.debug=True
114
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
115
+
116
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
117
+ # predict the noise residual
118
+ if debug:
119
+ output, extra_list = self.unet(latent_model_input, segmap, t)
120
+ noise_prediction = output.sample
121
+ extra_list_list.append(extra_list)
122
+ else:
123
+ noise_prediction = self.unet(latent_model_input, segmap, t).sample
124
+ # compute the previous noisy sample x_t -> x_t-1
125
+
126
+
127
+ if learn_sigma and "learn" in self.scheduler.variance_type:
128
+ model_pred, var_pred = torch.split(noise_prediction, latents.shape[1], dim=1)
129
+ else:
130
+ model_pred = noise_prediction
131
+ if s > 1.0:
132
+ if debug:
133
+ model_output_zero = self.unet(latent_model_input, torch.zeros_like(segmap), t)[0].sample
134
+ else:
135
+ model_output_zero = self.unet(latent_model_input, torch.zeros_like(segmap), t).sample
136
+ if learn_sigma and "learn" in self.scheduler.variance_type:
137
+ model_output_zero,_ = torch.split(model_output_zero, latents.shape[1], dim=1)
138
+ model_pred = model_pred + s * (model_pred - model_output_zero)
139
+ if learn_sigma and "learn" in self.scheduler.variance_type:
140
+ recombined = torch.cat((model_pred, var_pred), dim=1)
141
+ # when applying a different scheduler, pass the mean only
142
+ if learn_sigma and "learn" in self.scheduler.variance_type:
143
+ latents = self.scheduler.step(recombined, t, latents, **extra_kwargs).prev_sample
144
+ else:
145
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
146
+
147
+ if every_step_save is not None:
148
+ if (i+1) % every_step_save == 0:
149
+ step_latent.append(copy.deepcopy(latents))
150
+
151
+ if debug:
152
+ return extra_list_list[-1]
153
+
154
+ # decode the image latents with the VAE
155
+ if every_step_save is not None:
156
+ image = []
157
+ for i, l in enumerate(step_latent):
158
+ l /= self.vae.config.scaling_factor # (0.18215)
159
+ #latents /= 7.706491063029163
160
+ l = self.vae.decode(l, segmap)
161
+ l = (l / 2 + 0.5).clamp(0, 1)
162
+ l = l.cpu().permute(0, 2, 3, 1).numpy()
163
+ if output_type == "pil":
164
+ l = self.numpy_to_pil(l)
165
+ image.append(l)
166
+ else:
167
+ latents /= self.vae.config.scaling_factor#(0.18215)
168
+ #latents /= 7.706491063029163
169
+ # image = self.vae.decode(latents, segmap).sample
170
+ image = self.vae.decode(latents, return_dict=False)[0]
171
+ image = (image / 2 + 0.5).clamp(0, 1)
172
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
173
+ if output_type == "pil":
174
+ image = self.numpy_to_pil(image)
175
+
176
+ if not return_dict:
177
+ return (image,)
178
+
179
+ return ImagePipelineOutput(images=image)
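A hedged end-to-end sketch of how the SDMLDMPipeline above could be assembled for mask-conditioned generation; the checkpoint paths are placeholders, and the guidance scale, batch size, and step count are illustrative assumptions:

    import torch
    from diffusers import AutoencoderKL, DDPMScheduler
    from diffusion_module.unet_2d_sdm import SDMUNet2DModel
    from diffusion_module.utils.LSDMPipeline_expandDataset import SDMLDMPipeline

    unet = SDMUNet2DModel.from_pretrained("path/to/unet_ema")   # placeholder path
    vae = AutoencoderKL.from_pretrained("path/to/vae")          # placeholder path
    scheduler = DDPMScheduler(variance_type="learned_range")    # matches the learned-variance branch above

    pipe = SDMLDMPipeline(vae=vae, unet=unet, scheduler=scheduler,
                          torch_dtype=torch.float32, resolution_type="crack")
    pipe = pipe.to("cuda")

    # one-hot segmentation maps: (batch, segmap_channels, 64, 64) for resolution_type="crack";
    # assumes the UNet predicts both mean and variance (out_channels == 2 * latent channels)
    segmap = torch.zeros(2, unet.config.segmap_channels, 64, 64)
    out = pipe(segmap=segmap, batch_size=2, num_inference_steps=50, s=1.5)
    out.images[0].save("generated.png")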
diffusion_module/utils/Pipline.py ADDED
@@ -0,0 +1,361 @@
1
+ import inspect
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+
6
+ from diffusers.models import UNet2DModel, VQModel
7
+ from diffusers.schedulers import DDIMScheduler
8
+ from diffusers.utils import randn_tensor
9
+ from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
10
+ import copy
11
+
12
+ class LDMPipeline(DiffusionPipeline):
13
+ r"""
14
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
15
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
16
+
17
+ Parameters:
18
+ vae ([`VQModel`]):
19
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
20
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
21
+ scheduler ([`SchedulerMixin`]):
22
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
23
+ """
24
+
25
+ def __init__(self, vae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler, torch_dtype=torch.float16):
26
+ super().__init__()
27
+ self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
28
+ self.torch_dtype = torch_dtype
29
+
30
+ @torch.no_grad()
31
+ def __call__(
32
+ self,
33
+ batch_size: int = 8,
34
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
35
+ eta: float = 0.0,
36
+ num_inference_steps: int = 1000,
37
+ output_type: Optional[str] = "pil",
38
+ return_dict: bool = True,
39
+ **kwargs,
40
+ ) -> Union[Tuple, ImagePipelineOutput]:
41
+ r"""
42
+ Args:
43
+ batch_size (`int`, *optional*, defaults to 1):
44
+ Number of images to generate.
45
+ generator (`torch.Generator`, *optional*):
46
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
47
+ to make generation deterministic.
48
+ num_inference_steps (`int`, *optional*, defaults to 50):
49
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
50
+ expense of slower inference.
51
+ output_type (`str`, *optional*, defaults to `"pil"`):
52
+ The output format of the generated image. Choose between
53
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
54
+ return_dict (`bool`, *optional*, defaults to `True`):
55
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
56
+
57
+ Returns:
58
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.model.ImagePipelineOutput`] if `return_dict` is
59
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
60
+ """
61
+ if not isinstance(self.unet.config.sample_size,tuple):
62
+ self.unet.config.sample_size = (self.unet.config.sample_size,self.unet.config.sample_size)
63
+
64
+ latents = randn_tensor(
65
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
66
+ generator=generator,
67
+ )
68
+ latents = latents.to(self.device).type(self.torch_dtype)
69
+
70
+ # scale the initial noise by the standard deviation required by the scheduler (need to check)
71
+ latents = latents * self.scheduler.init_noise_sigma
72
+
73
+ self.scheduler.set_timesteps(num_inference_steps)
74
+
75
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
76
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
77
+
78
+ extra_kwargs = {}
79
+ if accepts_eta:
80
+ extra_kwargs["eta"] = eta
81
+
82
+ for t in self.progress_bar(self.scheduler.timesteps):
83
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
84
+ # predict the noise residual
85
+ noise_prediction = self.unet(latent_model_input, t).sample
86
+ # compute the previous noisy sample x_t -> x_t-1
87
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
88
+
89
+ # decode the image latents with the VAE
90
+ latents /= self.vae.config.scaling_factor#(0.18215)
91
+ image = self.vae.decode(latents).sample
92
+
93
+ image = (image / 2 + 0.5).clamp(0, 1)
94
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
95
+ if output_type == "pil":
96
+ image = self.numpy_to_pil(image)
97
+
98
+ if not return_dict:
99
+ return (image,)
100
+
101
+ return ImagePipelineOutput(images=image)
102
+
103
+
104
+ class SDMLDMPipeline(DiffusionPipeline):
105
+ r"""
106
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
107
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
108
+
109
+ Parameters:
110
+ vae ([`VQModel`]):
111
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
112
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
113
+ scheduler ([`SchedulerMixin`]):
114
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
115
+ """
116
+
117
+ def __init__(self, vae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler, torch_dtype=torch.float16, resolution=512, resolution_type="city"):
118
+ super().__init__()
119
+ self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
120
+ self.torch_dtype = torch_dtype
121
+ self.resolution = resolution
122
+ self.resolution_type = resolution_type
123
+ @torch.no_grad()
124
+ def __call__(
125
+ self,
126
+ segmap = None,
127
+ batch_size: int = 8,
128
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
129
+ eta: float = 0.0,
130
+ num_inference_steps: int = 1000,
131
+ output_type: Optional[str] = "pil",
132
+ return_dict: bool = True,
133
+ every_step_save: int = None,
134
+ s: int = 1,
135
+ **kwargs,
136
+ ) -> Union[Tuple, ImagePipelineOutput]:
137
+ r"""
138
+ Args:
139
+ batch_size (`int`, *optional*, defaults to 1):
140
+ Number of images to generate.
141
+ generator (`torch.Generator`, *optional*):
142
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
143
+ to make generation deterministic.
144
+ num_inference_steps (`int`, *optional*, defaults to 50):
145
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
146
+ expense of slower inference.
147
+ output_type (`str`, *optional*, defaults to `"pil"`):
148
+ The output format of the generated image. Choose between
149
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
150
+ return_dict (`bool`, *optional*, defaults to `True`):
151
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
152
+
153
+ Returns:
154
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.model.ImagePipelineOutput`] if `return_dict` is
155
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
156
+ """
157
+ # self.unet.config.sample_size = (64, 64) # (135,180)
158
+ # self.unet.config.sample_size = (135,180)
159
+ if self.resolution_type == "crack":
160
+ self.unet.config.sample_size = (64,64)
161
+ elif self.resolution_type == "crack_256":
162
+ self.unet.config.sample_size = (256,256)
163
+ else:
164
+ sc = 1080 // self.resolution
165
+ latent_size = (self.resolution // 4, 1440 // (sc*4))
166
+ self.unet.config.sample_size = latent_size
167
+ #
168
+ if not isinstance(self.unet.config.sample_size, tuple):
169
+ self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
170
+
171
+ if segmap is None:
172
+ print("No segmap provided; using an empty segmap as input")
173
+ segmap = torch.zeros(batch_size,self.unet.config.segmap_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1])
174
+ segmap = segmap.to(self.device).type(self.torch_dtype)
175
+ latents = randn_tensor(
176
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
177
+ generator=generator,
178
+ )
179
+ latents = latents.to(self.device).type(self.torch_dtype)
180
+
181
+ # scale the initial noise by the standard deviation required by the scheduler (need to check)
182
+ latents = latents * self.scheduler.init_noise_sigma
183
+
184
+ self.scheduler.set_timesteps(num_inference_steps=num_inference_steps)
185
+
186
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
187
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
188
+
189
+ extra_kwargs = {}
190
+ if accepts_eta:
191
+ extra_kwargs["eta"] = eta
192
+
193
+ step_latent = []
194
+ learn_sigma = True if hasattr(self.scheduler, "variance_type") else False
195
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
196
+
197
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
198
+ # predict the noise residual
199
+ noise_prediction = self.unet(latent_model_input, segmap, t).sample
200
+ # compute the previous noisy sample x_t -> x_t-1
201
+
202
+
203
+ if learn_sigma and "learn" in self.scheduler.variance_type:
204
+ model_pred, var_pred = torch.split(noise_prediction, latents.shape[1], dim=1)
205
+ else:
206
+ model_pred = noise_prediction
207
+ if s > 1.0:
208
+ model_output_zero = self.unet(latent_model_input, torch.zeros_like(segmap), t).sample
209
+ if learn_sigma and "learn" in self.scheduler.variance_type:
210
+ model_output_zero,_ = torch.split(model_output_zero, latents.shape[1], dim=1)
211
+ model_pred = model_pred + s * (model_pred - model_output_zero)
212
+ if learn_sigma and "learn" in self.scheduler.variance_type:
213
+ recombined = torch.cat((model_pred, var_pred), dim=1)
214
+ # when applying a different scheduler, pass the mean only
215
+ if learn_sigma and "learn" in self.scheduler.variance_type:
216
+ latents = self.scheduler.step(recombined, t, latents, **extra_kwargs).prev_sample
217
+ else:
218
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
219
+
220
+ if every_step_save is not None:
221
+ if (i+1) % every_step_save == 0:
222
+ step_latent.append(copy.deepcopy(latents))
223
+
224
+ # decode the image latents with the VAE
225
+ if every_step_save is not None:
226
+ image = []
227
+ for i, l in enumerate(step_latent):
228
+ l /= self.vae.config.scaling_factor # (0.18215)
229
+ #latents /= 7.706491063029163
230
+ l = self.vae.decode(l, segmap)
231
+ l = (l / 2 + 0.5).clamp(0, 1)
232
+ l = l.cpu().permute(0, 2, 3, 1).numpy()
233
+ if output_type == "pil":
234
+ l = self.numpy_to_pil(l)
235
+ image.append(l)
236
+ else:
237
+ latents /= self.vae.config.scaling_factor#(0.18215)
238
+ #latents /= 7.706491063029163
239
+ # image = self.vae.decode(latents, segmap).sample
240
+ image = self.vae.decode(latents, return_dict=False)[0]
241
+ image = (image / 2 + 0.5).clamp(0, 1)
242
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
243
+ if output_type == "pil":
244
+ image = self.numpy_to_pil(image)
245
+
246
+ if not return_dict:
247
+ return (image,)
248
+
249
+ return ImagePipelineOutput(images=image)
250
+
251
+
252
+ class SDMPipeline(DiffusionPipeline):
253
+ r"""
254
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
255
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
256
+
257
+ Parameters:
258
+ vae ([`VQModel`]):
259
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
260
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
261
+ scheduler ([`SchedulerMixin`]):
262
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
263
+ """
264
+
265
+ def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler, torch_dtype=torch.float16, vae=None):
266
+ super().__init__()
267
+ self.register_modules(unet=unet, scheduler=scheduler)
268
+ self.torch_dtype = torch_dtype
269
+
270
+ @torch.no_grad()
271
+ def __call__(
272
+ self,
273
+ segmap = None,
274
+ batch_size: int = 8,
275
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
276
+ eta: float = 0.0,
277
+ num_inference_steps: int = 1000,
278
+ output_type: Optional[str] = "pil",
279
+ return_dict: bool = True,
280
+ s: int = 1,
281
+ **kwargs,
282
+ ) -> Union[Tuple, ImagePipelineOutput]:
283
+ r"""
284
+ Args:
285
+ batch_size (`int`, *optional*, defaults to 1):
286
+ Number of images to generate.
287
+ generator (`torch.Generator`, *optional*):
288
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
289
+ to make generation deterministic.
290
+ num_inference_steps (`int`, *optional*, defaults to 50):
291
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
292
+ expense of slower inference.
293
+ output_type (`str`, *optional*, defaults to `"pil"`):
294
+ The output format of the generated image. Choose between
295
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
296
+ return_dict (`bool`, *optional*, defaults to `True`):
297
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
298
+
299
+ Returns:
300
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.model.ImagePipelineOutput`] if `return_dict` is
301
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
302
+ """
303
+ self.unet.config.sample_size = (270,360)
304
+ if not isinstance(self.unet.config.sample_size, tuple):
305
+ self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
306
+
307
+ if segmap is None:
308
+ print("No segmap provided; using an empty segmap as input")
309
+ segmap = torch.zeros(batch_size,self.unet.config.segmap_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1])
310
+ segmap = segmap.to(self.device).type(self.torch_dtype)
311
+ latents = randn_tensor(
312
+ (batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1]),
313
+ generator=generator,
314
+ )
315
+
316
+ latents = latents.to(self.device).type(self.torch_dtype)
317
+
318
+ # scale the initial noise by the standard deviation required by the scheduler (need to check)
319
+ latents = latents * self.scheduler.init_noise_sigma
320
+
321
+ self.scheduler.set_timesteps(num_inference_steps)
322
+
323
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
324
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
325
+
326
+ extra_kwargs = {}
327
+ if accepts_eta:
328
+ extra_kwargs["eta"] = eta
329
+
330
+ for t in self.progress_bar(self.scheduler.timesteps):
331
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
332
+ # predict the noise residual
333
+ noise_prediction = self.unet(latent_model_input, segmap, t).sample
334
+
335
+ #noise_prediction = noise_prediction[]
336
+
337
+ if s > 1.0:
338
+ model_output_zero = self.unet(latent_model_input, torch.zeros_like(segmap), t).sample
339
+ noise_prediction[:, :3] = model_output_zero[:, :3] + s * (noise_prediction[:, :3] - model_output_zero[:, :3])
340
+
341
+ #noise_prediction = noise_prediction[:, :3]
342
+
343
+ # compute the previous noisy sample x_t -> x_t-1
344
+ #breakpoint()
345
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
346
+
347
+ # decode the image latents with the VAE
348
+ # latents /= self.vae.config.scaling_factor#(0.18215)
349
+ # image = self.vae.decode(latents).sample
350
+ image = latents
351
+ #image = (image + 1) / 2.0
352
+ image = (image / 2 + 0.5).clamp(0, 1)
353
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
354
+ if output_type == "pil":
355
+ image = self.numpy_to_pil(image)
356
+
357
+ if not return_dict:
358
+ return (image,)
359
+
360
+ return ImagePipelineOutput(images=image)
361
+
diffusion_module/utils/__pycache__/LSDMPipeline_expandDataset.cpython-39.pyc ADDED
Binary file (5.82 kB). View file
 
diffusion_module/utils/__pycache__/Pipline.cpython-310.pyc ADDED
Binary file (8.22 kB). View file
 
diffusion_module/utils/__pycache__/Pipline.cpython-39.pyc ADDED
Binary file (8.52 kB). View file
 
diffusion_module/utils/__pycache__/loss.cpython-39.pyc ADDED
Binary file (4.06 kB). View file
 
diffusion_module/utils/loss.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Helpers for various likelihood-based losses. These are ported from the original
3
+ Ho et al. diffusion models codebase:
4
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ import torch as th
10
+
11
+
12
+ def normal_kl(mean1, logvar1, mean2, logvar2):
13
+ """
14
+ Compute the KL divergence between two gaussians.
15
+
16
+ Shapes are automatically broadcasted, so batches can be compared to
17
+ scalars, among other use cases.
18
+ """
19
+ tensor = None
20
+ for obj in (mean1, logvar1, mean2, logvar2):
21
+ if isinstance(obj, th.Tensor):
22
+ tensor = obj
23
+ break
24
+ assert tensor is not None, "at least one argument must be a Tensor"
25
+
26
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
27
+ # Tensors, but it does not work for th.exp().
28
+ logvar1, logvar2 = [
29
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30
+ for x in (logvar1, logvar2)
31
+ ]
32
+
33
+ return 0.5 * (
34
+ -1.0
35
+ + logvar2
36
+ - logvar1
37
+ + th.exp(logvar1 - logvar2)
38
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39
+ )
40
+
41
+
42
+ def approx_standard_normal_cdf(x):
43
+ """
44
+ A fast approximation of the cumulative distribution function of the
45
+ standard normal.
46
+ """
47
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48
+
49
+
50
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51
+ """
52
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
53
+ given image.
54
+
55
+ :param x: the target images. It is assumed that this was uint8 values,
56
+ rescaled to the range [-1, 1].
57
+ :param means: the Gaussian mean Tensor.
58
+ :param log_scales: the Gaussian log stddev Tensor.
59
+ :return: a tensor like x of log probabilities (in nats).
60
+ """
61
+ assert x.shape == means.shape == log_scales.shape
62
+ centered_x = x - means
63
+ inv_stdv = th.exp(-log_scales)
64
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65
+ cdf_plus = approx_standard_normal_cdf(plus_in)
66
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67
+ cdf_min = approx_standard_normal_cdf(min_in)
68
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70
+ cdf_delta = cdf_plus - cdf_min
71
+ log_probs = th.where(
72
+ x < -0.999,
73
+ log_cdf_plus,
74
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75
+ )
76
+ assert log_probs.shape == x.shape
77
+ return log_probs
78
+
79
+ def variance_KL_loss(latents, noisy_latents, timesteps, model_pred_mean, model_pred_var, noise_scheduler,posterior_mean_coef1, posterior_mean_coef2, posterior_log_variance_clipped):
80
+ model_pred_mean = model_pred_mean.detach()
81
+ true_mean = (
82
+ posterior_mean_coef1.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * latents
83
+ + posterior_mean_coef2.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * noisy_latents
84
+ )
85
+
86
+ true_log_variance_clipped = posterior_log_variance_clipped.to(device=timesteps.device)[timesteps].float()[
87
+ ..., None, None, None]
88
+
89
+ if noise_scheduler.variance_type == "learned":
90
+ model_log_variance = model_pred_var
91
+ #model_pred_var = th.exp(model_log_variance)
92
+ else:
93
+ min_log = true_log_variance_clipped
94
+ max_log = th.log(noise_scheduler.betas.to(device=timesteps.device)[timesteps].float()[..., None, None, None])
95
+ frac = (model_pred_var + 1) / 2
96
+ model_log_variance = frac * max_log + (1 - frac) * min_log
97
+ #model_pred_var = th.exp(model_log_variance)
98
+
99
+ sqrt_recip_alphas_cumprod = th.sqrt(1.0 / noise_scheduler.alphas_cumprod)
100
+ sqrt_recipm1_alphas_cumprod = th.sqrt(1.0 / noise_scheduler.alphas_cumprod - 1)
101
+
102
+ pred_xstart = (sqrt_recip_alphas_cumprod.to(device=timesteps.device)[timesteps].float()[
103
+ ..., None, None, None] * noisy_latents
104
+ - sqrt_recipm1_alphas_cumprod.to(device=timesteps.device)[timesteps].float()[
105
+ ..., None, None, None] * model_pred_mean)
106
+
107
+ model_mean = (
108
+ posterior_mean_coef1.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * pred_xstart
109
+ + posterior_mean_coef2.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * noisy_latents
110
+ )
111
+
112
+ # model_mean = out["mean"] model_log_variance = out["log_variance"]
113
+ kl = normal_kl(
114
+ true_mean, true_log_variance_clipped, model_mean, model_log_variance
115
+ )
116
+ kl = kl.mean() / np.log(2.0)
117
+
118
+ decoder_nll = -discretized_gaussian_log_likelihood(
119
+ latents, means=model_mean, log_scales=0.5 * model_log_variance
120
+ )
121
+ assert decoder_nll.shape == latents.shape
122
+ decoder_nll = decoder_nll.mean() / np.log(2.0)
123
+
124
+ # At the first timestep return the decoder NLL,
125
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
126
+ kl_loss = th.where((timesteps == 0), decoder_nll, kl).mean()
127
+ return kl_loss
128
+
129
+ def get_variance(noise_scheduler):
130
+ alphas_cumprod_prev = th.cat([th.tensor([1.0]), noise_scheduler.alphas_cumprod[:-1]])
131
+
132
+ posterior_mean_coef1 = (
133
+ noise_scheduler.betas * th.sqrt(alphas_cumprod_prev) / (1.0 - noise_scheduler.alphas_cumprod)
134
+ )
135
+
136
+ posterior_mean_coef2 = (
137
+ (1.0 - alphas_cumprod_prev)
138
+ * th.sqrt(noise_scheduler.alphas)
139
+ / (1.0 - noise_scheduler.alphas_cumprod)
140
+ )
141
+
142
+ posterior_variance = (
143
+ noise_scheduler.betas * (1.0 - alphas_cumprod_prev) / (1.0 - noise_scheduler.alphas_cumprod)
144
+ )
145
+ posterior_log_variance_clipped = th.log(
146
+ th.cat([posterior_variance[1][..., None], posterior_variance[1:]])
147
+ )
148
+ #res = posterior_log_variance_clipped.to(device=timesteps.device)[timesteps].float()
149
+ return posterior_mean_coef1, posterior_mean_coef2, posterior_log_variance_clipped #res[..., None, None, None]
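For context, a hedged sketch of how get_variance and variance_KL_loss above might be combined with the usual noise-prediction objective during training; the scheduler configuration and tensor shapes are illustrative assumptions, and the random model output stands in for a UNet that predicts both mean and variance (2 * C channels):

    import torch
    import torch.nn.functional as F
    from diffusers import DDPMScheduler
    from diffusion_module.utils.loss import get_variance, variance_KL_loss

    scheduler = DDPMScheduler(variance_type="learned_range")
    coef1, coef2, log_var_clipped = get_variance(scheduler)

    latents = torch.rand(4, 3, 64, 64) * 2 - 1           # assumed to lie roughly in [-1, 1]
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (4,))
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    model_output = torch.randn(4, 6, 64, 64)              # stand-in for unet(noisy_latents, segmap, t).sample
    pred_mean, pred_var = torch.split(model_output, latents.shape[1], dim=1)

    mse = F.mse_loss(pred_mean, noise)                    # epsilon-prediction term
    kl = variance_KL_loss(latents, noisy_latents, timesteps, pred_mean, pred_var,
                          scheduler, coef1, coef2, log_var_clipped)
    loss = mse + kl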
diffusion_module/utils/noise_sampler.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
+ from sklearn.mixture import GaussianMixture
2
+
3
+ def get_noise_sampler(sample_type='gau'):
4
+ if sample_type == 'gau':
5
+ sampler = lambda latnt_sz: torch.randn_like(latnt_sz)
6
+ elif sample_type == 'gau_offset':
7
+ sampler = lambda latnt_sz: torch.randn_like(latnt_sz) + (torch.randn_like(latnt_sz))
8
+ ...
9
+ elif sample_type == 'gmm':
10
+ ...
11
+ else:
12
+ ...
13
+ return sampler
14
+
15
+ if __name__ == "__main__":
16
+ ...
diffusion_module/utils/scheduler_factory.py ADDED
@@ -0,0 +1,300 @@
1
+ from functools import partial
2
+ from diffusers import DDPMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler, DPMSolverSinglestepScheduler
3
+ from diffusers.pipeline_utils import DiffusionPipeline
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from typing import List, Optional, Tuple, Union
7
+ import numpy as np
8
+ from diffusers.schedulers.scheduling_utils import SchedulerOutput
9
+ from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
10
+ from diffusers.utils import randn_tensor, BaseOutput
11
+
12
+
13
+ ### Testing the DDPM Scheduler for Variant
14
+ class ModifiedDDPMScheduler(DDPMScheduler):
15
+ def __init__(self, *args, **kwargs):
16
+ super().__init__(*args, **kwargs)
17
+
18
+ def step(
19
+ self,
20
+ model_output: torch.FloatTensor,
21
+ timestep: int,
22
+ sample: torch.FloatTensor,
23
+ generator=None,
24
+ return_dict: bool = True,
25
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
26
+ """
27
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
28
+ process from the learned model outputs (most often the predicted noise).
29
+
30
+ Args:
31
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
32
+ timestep (`int`): current discrete timestep in the diffusion chain.
33
+ sample (`torch.FloatTensor`):
34
+ current instance of sample being created by diffusion process.
35
+ generator: random number generator.
36
+ return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
37
+
38
+ Returns:
39
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
40
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
41
+ returning a tuple, the first element is the sample tensor.
42
+
43
+ """
44
+ t = timestep
45
+
46
+ prev_t = self.previous_timestep(t)
47
+
48
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
49
+ print("Condition is triggered")
50
+
51
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
52
+ # [2,3, 64, 128]
53
+ else:
54
+ predicted_variance = None
55
+
56
+ # 1. compute alphas, betas
57
+ alpha_prod_t = self.alphas_cumprod[t]
58
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
59
+ beta_prod_t = 1 - alpha_prod_t
60
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
61
+ current_alpha_t = alpha_prod_t / alpha_prod_t_prev
62
+ current_beta_t = 1 - current_alpha_t
63
+
64
+ # 2. compute predicted original sample from predicted noise also called
65
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
66
+ if self.config.prediction_type == "epsilon":
67
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
68
+
69
+ elif self.config.prediction_type == "sample":
70
+ pred_original_sample = model_output
71
+ elif self.config.prediction_type == "v_prediction":
72
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
73
+ else:
74
+ raise ValueError(
75
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
76
+ " `v_prediction` for the DDPMScheduler."
77
+ )
78
+
79
+ # 3. Clip or threshold "predicted x_0"
80
+ if self.config.thresholding:
81
+ pred_original_sample = self._threshold_sample(pred_original_sample)
82
+ elif self.config.clip_sample:
83
+ pred_original_sample = pred_original_sample.clamp(
84
+ -self.config.clip_sample_range, self.config.clip_sample_range
85
+ )
86
+
87
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
88
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
89
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
90
+ current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
91
+
92
+ # 5. Compute predicted previous sample µ_t
93
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
94
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
95
+
96
+ # 6. Add noise
97
+ variance = 0
98
+ if t > 0:
99
+ device = model_output.device
100
+ variance_noise = randn_tensor(
101
+ model_output.shape, generator=generator, device=device, dtype=model_output.dtype
102
+ )
103
+ if self.variance_type == "fixed_small_log":
104
+ variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
105
+
106
+ elif self.variance_type == "learned_range":
107
+ variance = self._get_variance(t, predicted_variance=predicted_variance)
108
+ variance = torch.exp(0.5 * variance) * variance_noise
109
+
110
+ else:
111
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
112
+
113
+ pred_prev_sample = pred_prev_sample + variance
114
+ print(pred_prev_sample.shape)
115
+ if not return_dict:
116
+ return (pred_prev_sample,)
117
+
118
+ return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
119
+
120
+
121
+ class ModifiedUniPCScheduler(UniPCMultistepScheduler):
122
+ '''
123
+ This is the modification of UniPCMultistepScheduler, which is the same as UniPCMultistepScheduler except for the _get_variance function.
124
+ '''
125
+ def __init__(self, variance_type: str = "fixed_small", *args, **kwargs):
126
+ super().__init__(*args, **kwargs)
127
+ self.custom_timesteps = False
128
+ self.variance_type=variance_type
129
+ self.config.timestep_spacing="leading"
130
+ def previous_timestep(self, timestep):
131
+ if self.custom_timesteps:
132
+ index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
133
+ if index == self.timesteps.shape[0] - 1:
134
+ prev_t = torch.tensor(-1)
135
+ else:
136
+ prev_t = self.timesteps[index + 1]
137
+ else:
138
+ num_inference_steps = (
139
+ self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
140
+ )
141
+ prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
142
+
143
+ return prev_t
144
+
145
+ def _get_variance(self, t, predicted_variance=None, variance_type="learned_range"):
146
+ prev_t = self.previous_timestep(t)
147
+
148
+ alpha_prod_t = self.alphas_cumprod[t]
149
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
150
+ current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
151
+
152
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
153
+
154
+ variance = torch.clamp(variance, min=1e-20)
155
+
156
+ if variance_type is None:
157
+ variance_type = self.config.variance_type
158
+
159
+ if variance_type == "fixed_small":
160
+ variance = variance
161
+ elif variance_type == "fixed_small_log":
162
+ variance = torch.log(variance)
163
+ variance = torch.exp(0.5 * variance)
164
+ elif variance_type == "fixed_large":
165
+ variance = current_beta_t
166
+ elif variance_type == "fixed_large_log":
167
+ variance = torch.log(current_beta_t)
168
+ elif variance_type == "learned":
169
+ return predicted_variance
170
+ elif variance_type == "learned_range":
171
+ min_log = torch.log(variance)
172
+ max_log = torch.log(current_beta_t)
173
+ frac = (predicted_variance + 1) / 2
174
+ variance = frac * max_log + (1 - frac) * min_log
175
+
176
+ return variance
177
+
178
+ def step(self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True) -> Union[SchedulerOutput, Tuple]:
179
+
180
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
181
+ print("condition using predicted_variance is triggered")
182
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
183
+ else:
184
+ predicted_variance = None
185
+
186
+ super_output = super().step(model_output, timestep, sample, return_dict=False)
187
+ prev_sample = super_output[0]
188
+ # breakpoint()
189
+ variance = 0
190
+ if timestep > 0:
191
+ device = model_output.device
192
+ variance_noise = randn_tensor(
193
+ model_output.shape, generator=None, device=device, dtype=model_output.dtype
194
+ )
195
+ if self.variance_type == "fixed_small_log":
196
+ variance = self._get_variance(timestep, predicted_variance=predicted_variance) * variance_noise
197
+ elif self.variance_type == "learned_range":
198
+ # breakpoint()
199
+ variance = self._get_variance(timestep, predicted_variance=predicted_variance)
200
+ variance = torch.exp(0.5 * variance) * variance_noise
201
+ # breakpoint()
202
+ else:
203
+ variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * variance_noise
204
+
205
+
206
+ # breakpoint()
207
+ print("time step is ", timestep)
208
+ prev_sample = prev_sample + variance
209
+
210
+ if not return_dict:
211
+ return (prev_sample,)
212
+
213
+ return DDPMSchedulerOutput(prev_sample=prev_sample,pred_original_sample=prev_sample)
214
+
215
+ #return SchedulerOutput(prev_sample=prev_sample)
216
+
217
+
218
+ def build_proc(sch_cfg=None, _sch=None, **kwargs):
219
+ if kwargs:
220
+ return _sch(**kwargs)
221
+
222
+ type_str = str(type(sch_cfg))
223
+ if 'dict' in type_str:
224
+ return _sch.from_config(**sch_cfg)
225
+ return _sch.from_config(sch_cfg, subfolder="scheduler")
226
+
227
+ scheduler_factory = {
228
+ 'UniPC' : partial(build_proc, _sch=UniPCMultistepScheduler),
229
+ 'modifiedUniPC' : partial(build_proc, _sch=ModifiedUniPCScheduler),
230
+ # DPM family
231
+ 'DDPM' : partial(build_proc, _sch=DDPMScheduler),
232
+ 'DPMSolver' : partial(build_proc, _sch=DPMSolverMultistepScheduler, algorithm_type='dpmsolver'),
233
+ 'DPMSolver++' : partial(build_proc, _sch=DPMSolverMultistepScheduler),
234
+ 'DPMSolverSingleStep' : partial(build_proc, _sch=DPMSolverSinglestepScheduler)
235
+
236
+ }
237
+
238
+ def scheduler_setup(pipe : DiffusionPipeline = None, scheduler_type : str = 'UniPC', from_config=None, **kwargs):
239
+ if not isinstance(pipe, DiffusionPipeline):
240
+ raise TypeError(f'pipe should be DiffusionPipeline, but given {type(pipe)}\n')
241
+
242
+ sch_cfg = from_config if from_config else pipe.scheduler.config
243
+ #sch_cfg = diffusers.configuration_utils.FrozenDict({**sch_cfg, 'solver_order':3})
244
+ #pipe.scheduler = scheduler_factory[scheduler_type](**kwargs) if kwargs \
245
+ # else scheduler_factory[scheduler_type](sch_cfg)
246
+
247
+ # pipe.scheduler = DPMSolverSinglestepScheduler()
248
+ # #pipe.scheduler = DDPMScheduler(beta_schedule="linear", variance_type="learned_range")
249
+ # print(pipe.scheduler)
250
+ print("Scheduler type in Scheduler_factory.py is Hard-coded to modifyUniPC, Please change it back to AutoDetect functionality if you want to change scheudler")
251
+ pipe.scheduler = ModifiedUniPCScheduler(variance_type="learned_range", )
252
+ # pipe.scheduler = ModifiedDDPMScheduler(beta_schedule="linear", variance_type="learned_range")
253
+
254
+ #pipe.scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
255
+ #pipe.scheduler._get_variance = _get_variance
256
+ return pipe
257
+
258
+ # unittest of scheduler..
259
+ if __name__ == "__main__":
260
+ def ld_mod():
261
+ noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
262
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to("cuda").to(torch.float16)
263
+ unet = SDMUNet2DModel.from_pretrained("/data/harry/Data_generation/diffusers-main/examples/VAESDM/LDM-sdm-model/checkpoint-46000", subfolder="unet").to("cuda").to(torch.float16)
264
+ return noise_scheduler, vae, unet
265
+
266
+ from Pipline import SDMLDMPipeline
267
+ from diffusers import StableDiffusionPipeline
268
+ import torch
269
+
270
+ path = "CompVis/stable-diffusion-v1-4"
271
+ pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
272
+
273
+ # change scheduler
274
+ # customized args : once you customized, customize forever ~ no from_config
275
+ #pipe = scheduler_setup(pipe, 'DPMSolver++', thresholding=True)
276
+ # from_config
277
+ pipe = scheduler_setup(pipe, 'DPMSolverSingleStep')
278
+
279
+ pipe = pipe.to("cuda")
280
+ prompt = "a highly realistic photo of green turtle"
281
+ generator = torch.manual_seed(0)
282
+ # only 15 steps are needed for good results => 2-4 seconds on GPU
283
+ image = pipe(prompt, generator=generator, num_inference_steps=15).images[0]
284
+ # save image
285
+ image.save("turtle.png")
286
+
287
+ '''
288
+ # load & wrap submodules into pipe-API
289
+ noise_scheduler, vae, unet = ld_mod()
290
+ pipe = SDMLDMPipeline(
291
+ unet=unet,
292
+ vqvae=vae,
293
+ scheduler=noise_scheduler,
294
+ torch_dtype=torch.float16
295
+ )
296
+
297
+ # change scheduler
298
+ pipe = scheduler_setup(pipe, 'DPMSolverSingleStep')
299
+ pipe = pipe.to("cuda")
300
+ '''
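The "learned_range" branch above does not treat the predicted value as a variance directly: the network output v in [-1, 1] is mapped to a log-variance interpolated between the log of the clamped posterior variance (min) and log(beta_t) (max), and `step` then converts it to a standard deviation via exp(0.5 * logvar). A minimal numeric sketch of that mapping (standalone illustration; the concrete numbers and variable names are made up, not taken from the repo):

```python
import torch

# Illustrative bounds for one timestep (hypothetical values):
log_beta_tilde = torch.log(torch.tensor(1e-4))   # lower bound: clamped posterior variance
log_beta = torch.log(torch.tensor(2e-4))         # upper bound: current beta_t
v = torch.tensor(0.5)                            # hypothetical "learned_range" output in [-1, 1]

frac = (v + 1) / 2                               # map [-1, 1] -> [0, 1]
log_var = frac * log_beta + (1 - frac) * log_beta_tilde
std = torch.exp(0.5 * log_var)                   # what step() multiplies the noise by
print(float(log_var), float(std))
```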
evolution.py ADDED
@@ -0,0 +1,102 @@
+ import random
+ import math
+ import numpy as np
+ from PIL import Image
+ from skimage.draw import line
+ from skimage import morphology
+ import cv2
+
+ def line_crosses_cracks(start, end, img):
+     rr, cc = line(start[0], start[1], end[0], end[1])
+     # Exclude the starting point from the line coordinates
+     if len(rr) > 1 and len(cc) > 1:
+         return np.any(img[rr[1:], cc[1:]] == 255)
+     return False
+
+ def random_walk(img_array, k=8, m=0.1, min_steps=50, max_steps=200, length=2, degree_range=30, seed=None):
+
+     if seed is not None:
+         random.seed(seed)
+         np.random.seed(seed)
+
+     img_array = cv2.ximgproc.thinning(img_array)
+
+     rows, cols = img_array.shape
+     # Find all white pixels (existing cracks)
+     white_pixels = np.column_stack(np.where(img_array == 255))
+     original_crack_count = len(white_pixels)  # Count of original crack pixels
+
+     # Select k random starting points from the white pixels
+     if white_pixels.size == 0:
+         raise ValueError("No initial crack pixels found in the image.")
+     if k > len(white_pixels):
+         raise ValueError("k is greater than the number of existing crack pixels.")
+     initial_points = white_pixels[random.sample(range(len(white_pixels)), k)]
+
+     # Initialize step count for each initial point with a random value between min_steps and max_steps
+     step_counts = {i: random.randint(min_steps, max_steps) for i in range(k)}
+     # Initialize main direction for each initial point (0 to 360 degrees)
+     main_angles = {i: random.uniform(0, 360) for i in range(k)}
+
+     grown_crack_count = 0  # Count of newly grown crack pixels
+
+     # Start the random walk for each initial point
+     for idx, point in enumerate(initial_points):
+         current_pos = tuple(point)
+         current_steps = 0
+         while current_steps < step_counts[idx]:
+             # Check the crack ratio
+             current_ratio = np.sum(img_array == 255) / (rows * cols)
+             if current_ratio >= m:
+                 return img_array, {'original_crack_count': original_crack_count, 'grown_crack_count': grown_crack_count}
+
+             # Generate a random direction within the fan-shaped area around the main angle
+             main_angle = main_angles[idx]
+             angle = math.radians(main_angle + random.uniform(-degree_range, degree_range))
+
+             # Determine the next position with the specified length
+             delta_row = length * math.sin(angle)
+             delta_col = length * math.cos(angle)
+             next_pos = (int(current_pos[0] + delta_row), int(current_pos[1] + delta_col))
+
+             # Check if the line from the current to the next position crosses existing cracks
+             if 0 <= next_pos[0] < rows and 0 <= next_pos[1] < cols and not line_crosses_cracks(current_pos, next_pos, img_array):
+                 # Draw a line from the current position to the next position
+                 rr, cc = line(current_pos[0], current_pos[1], next_pos[0], next_pos[1])
+                 img_array[rr, cc] = 255  # Set the pixels along the line to white
+                 grown_crack_count += len(rr)  # Update the count of grown crack pixels
+                 current_pos = next_pos
+                 current_steps += 1
+             else:
+                 # If the line crosses existing cracks or the next position is outside the boundaries, stop the walk for this point
+                 break
+
+     return img_array, {'original_crack_count': original_crack_count, 'grown_crack_count': grown_crack_count}
+
+ # The rest of the test code remains the same.
+ # You can use this function in your test code to generate the image and get the counts.
+
+
+ # test code
+ if __name__ == "__main__":
+     # Updated parameters
+     k = 8  # Number of initial white pixels to start the random walk
+     m = 0.1  # Maximum ratio of crack pixels
+     min_steps = 50
+     max_steps = 200
+     img_path = '/data/leiqin/diffusion/huggingface_diffusers/crack_label_creator/random_walk/thindata_256/2.png'
+     img = Image.open(img_path)
+     img_array = np.array(img)
+     length = 2
+
+     # Perform the modified random walk
+     result_img_array_mod, pixels_dict = random_walk(img_array.copy(), k, m, min_steps, max_steps, length)
+
+     # Convert the result to an image
+     result_img_mod = Image.fromarray(result_img_array_mod.astype('uint8'))
+
+     # Save the resulting image
+     result_img_path_mod = 'results.png'
+     result_img_mod.save(result_img_path_mod)
+     print(pixels_dict)
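A minimal usage sketch for `random_walk` (illustrative only; the parameter values and output path are made up, and the repo's own test block above reads a real skeleton mask from disk instead of building a synthetic one):

```python
import numpy as np
import cv2
from evolution import random_walk

# Build a synthetic binary mask with one straight seed crack.
seed_mask = np.zeros((256, 256), dtype=np.uint8)
cv2.line(seed_mask, (40, 128), (200, 128), 255, thickness=2)

# Grow up to k new branches from random points on the thinned seed crack,
# stopping when the crack ratio reaches m or a walk hits an existing crack.
grown, counts = random_walk(seed_mask, k=4, m=0.05, min_steps=30, max_steps=100,
                            length=2, degree_range=30, seed=0)
print(counts)  # {'original_crack_count': ..., 'grown_crack_count': ...}
cv2.imwrite("grown_mask.png", grown)
```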
figs/4.png ADDED
figs/4_1.jpg ADDED
figs/4_1.png ADDED
figs/4_1_mask.png ADDED
generate.py ADDED
@@ -0,0 +1,106 @@
+ from diffusers.schedulers import UniPCMultistepScheduler
+ from diffusers import AutoencoderKL
+ from diffusion_module.unet import UNetModel
+ import torch
+ from diffusion_module.utils.LSDMPipeline_expandDataset import SDMLDMPipeline
+ from accelerate import Accelerator
+ from evolution import random_walk
+ import cv2
+ import numpy as np
+
+ def mask2onehot(data, num_classes):
+     # change data type for scatter_
+     data = data.to(dtype=torch.int64)
+
+     # create one-hot label map
+     label_map = data
+     bs, _, h, w = label_map.size()
+     input_label = torch.FloatTensor(bs, num_classes, h, w).zero_().to(data.device)
+     input_semantics = input_label.scatter_(1, label_map, 1.0)
+
+     return input_semantics
+
+ def generate(img, pretrain_weight, seed=None):
+
+     noise_scheduler = UniPCMultistepScheduler()
+     vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+     latent_size = (64, 64)
+     unet = UNetModel(
+         image_size=latent_size,
+         in_channels=vae.config.latent_channels,
+         model_channels=256,
+         out_channels=vae.config.latent_channels,
+         num_res_blocks=2,
+         attention_resolutions=(2, 4, 8),
+         dropout=0,
+         channel_mult=(1, 2, 3, 4),
+         num_heads=8,
+         num_head_channels=-1,
+         num_heads_upsample=-1,
+         use_scale_shift_norm=True,
+         resblock_updown=True,
+         use_new_attention_order=False,
+         num_classes=151,
+         mask_emb="resize",
+         use_checkpoint=True,
+         SPADE_type="spade",
+     )
+
+     unet = unet.from_pretrained(pretrain_weight)
+     device = 'cpu'
+     if device != 'cpu':
+         mixed_precision = "fp16"
+     else:
+         mixed_precision = "no"
+
+     accelerator = Accelerator(
+         mixed_precision=mixed_precision,
+         cpu=True if device == 'cpu' else False  # '==' rather than 'is' for string comparison
+     )
+
+     weight_dtype = torch.float32
+     if accelerator.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+
+     unet, vae = accelerator.prepare(unet, vae)
+     vae.to(device=accelerator.device, dtype=weight_dtype)
+     pipeline = SDMLDMPipeline(
+         vae=accelerator.unwrap_model(vae),
+         unet=accelerator.unwrap_model(unet),
+         scheduler=noise_scheduler,
+         torch_dtype=weight_dtype,
+         resolution_type="crack"
+     )
+     """
+     if accelerator.device != 'cpu':
+         pipeline.enable_xformers_memory_efficient_attention()
+     """
+     pipeline = pipeline.to(accelerator.device)
+     pipeline.set_progress_bar_config(disable=False)
+
+     if seed is None:
+         generator = None
+     else:
+         generator = torch.Generator(device=accelerator.device).manual_seed(seed)
+
+     resized_s = cv2.resize(img, (64, 64), interpolation=cv2.INTER_AREA)
+     # Binarize the grayscale mask to {0, 255}
+     _, binary_s = cv2.threshold(resized_s, 1, 255, cv2.THRESH_BINARY)
+     # Convert to {0, 1}
+     tensor_s = torch.from_numpy(binary_s / 255)
+     # h,w -> 1,1,h,w
+     tensor_s = tensor_s.unsqueeze(0).unsqueeze(0)
+     onehot_skeletons = []
+     onehot_s = mask2onehot(tensor_s, 151)
+     onehot_skeletons.append(onehot_s)
+
+     onehot_skeletons = torch.stack(onehot_skeletons, dim=1).squeeze(0)
+     onehot_skeletons = onehot_skeletons.to(dtype=weight_dtype, device=accelerator.device)
+
+     images = pipeline(onehot_skeletons, generator=generator, batch_size=1,
+                       num_inference_steps=20, s=1.5,
+                       num_evolution_per_mask=1).images
+
+     return images
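For reference, `mask2onehot` only scatters integer class ids into a one-hot channel dimension. A tiny standalone sketch of that behaviour (shapes and values are invented for illustration):

```python
import torch

# A 1x1x2x2 label map with class ids 0..2, encoded into 3 one-hot channels.
mask = torch.tensor([[[[0, 1],
                       [2, 0]]]])                      # shape (bs=1, 1, h=2, w=2)
onehot = torch.zeros(1, 3, 2, 2).scatter_(1, mask, 1.0)
print(onehot[0, :, 0, 1])                              # tensor([0., 1., 0.]) -> pixel (0, 1) is class 1
```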
requirements.txt ADDED
@@ -0,0 +1,137 @@
+ accelerate==0.25.0
+ aiofiles==23.2.1
+ aiohttp==3.9.1
+ aiosignal==1.3.1
+ albumentations==1.3.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ antlr4-python3-runtime==4.9.3
+ anyio==4.2.0
+ appdirs==1.4.4
+ async-timeout==4.0.3
+ attrs==23.2.0
+ blobfile==2.1.1
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.0
+ cycler==0.12.1
+ datasets==2.16.1
+ diffusers==0.16.1
+ dill==0.3.7
+ docker-pycreds==0.4.0
+ einops==0.7.0
+ exceptiongroup==1.2.0
+ fastapi==0.109.2
+ ffmpy==0.3.1
+ filelock==3.13.1
+ fonttools==4.47.0
+ frozenlist==1.4.1
+ fsspec==2023.10.0
+ gitdb==4.0.11
+ GitPython==3.1.41
+ gradio==4.16.0
+ gradio_client==0.8.1
+ h11==0.14.0
+ httpcore==1.0.2
+ httpx==0.26.0
+ huggingface-hub==0.20.1
+ idna==3.6
+ imageio==2.33.1
+ importlib-metadata==7.0.1
+ importlib-resources==6.1.1
+ Jinja2==3.1.2
+ joblib==1.3.2
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ lazy_loader==0.3
+ lightning-utilities==0.10.0
+ lxml==4.9.4
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib==3.8.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.4
+ multiprocess==0.70.15
+ networkx==3.2.1
+ numpy==1.26.2
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.18.1
+ nvidia-nvjitlink-cu12==12.3.101
+ nvidia-nvtx-cu12==12.1.105
+ omegaconf==2.3.0
+ opencv-contrib-python==4.9.0.80
+ opencv-python-headless==4.9.0.80
+ orjson==3.9.13
+ packaging==23.2
+ pandas==2.1.4
+ pillow==10.2.0
+ protobuf==4.25.2
+ psutil==5.9.7
+ pyarrow==14.0.2
+ pyarrow-hotfix==0.6
+ pycryptodomex==3.20.0
+ pydantic==2.6.0
+ pydantic_core==2.16.1
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ python-multipart==0.0.7
+ pytorch-lightning==2.1.3
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ qudida==0.0.4
+ referencing==0.33.0
+ regex==2023.12.25
+ requests==2.31.0
+ rich==13.7.0
+ rpds-py==0.17.1
+ ruff==0.2.0
+ safetensors==0.4.1
+ scikit-image==0.22.0
+ scikit-learn==1.3.2
+ scipy==1.11.4
+ semantic-version==2.10.0
+ sentry-sdk==1.39.2
+ setproctitle==1.3.3
+ shellingham==1.5.4
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.0
+ starlette==0.36.3
+ sympy==1.12
+ threadpoolctl==3.2.0
+ tifffile==2023.12.9
+ tokenizers==0.15.0
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.1.2
+ torchaudio==2.1.2
+ torchmetrics==1.2.1
+ torchvision==0.16.2
+ tqdm==4.66.1
+ transformers==4.36.2
+ triton==2.1.0
+ typer==0.9.0
+ typing_extensions==4.9.0
+ tzdata==2023.4
+ urllib3==2.1.0
+ uvicorn==0.27.0.post1
+ wandb==0.16.2
+ websockets==11.0.3
+ xformers==0.0.23.post1
+ xxhash==3.4.1
+ yarl==1.9.4
+ zipp==3.17.0