tfwang committed
Commit
bd366ed
1 Parent(s): 0ac0aaf

add app file

app.py ADDED
@@ -0,0 +1,227 @@
+ """
+ Gradio demo: generate images from sketches or semantic masks with
+ pretrained PITI diffusion models.
+ """
+ import argparse
+ 
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torch as th
+ from einops import rearrange
+ from PIL import Image
+ from torchvision.utils import make_grid
+ 
+ from glide_text2im import dist_util
+ from glide_text2im.glide_util import sample
+ from glide_text2im.image_datasets_sketch import get_tensor
+ from glide_text2im.script_util import (
+     model_and_diffusion_defaults,
+     create_model_and_diffusion,
+     args_to_dict,
+     add_dict_to_argparser,
+ )
+ 
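+ # The demo runs PITI's two-stage cascade: the base diffusion model draws
+ # 64x64 samples with classifier-free guidance, and the upsampling diffusion
+ # model then lifts them to 256x256 (see the two sample() calls in run()).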
+ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
+     parser, parser_up = create_argparser()
+ 
+     args = parser.parse_args()
+     args_up = parser_up.parse_args()
+     dist_util.setup_dist()
+ 
+     if mode == 'sketch':
+         args.mode = 'coco-edge'
+         args_up.mode = 'coco-edge'
+         args.model_path = './ckpt/base_edge.pt'
+         args.sr_model_path = './ckpt/upsample_edge.pt'
+ 
+     elif mode == 'mask':
+         args.mode = 'coco'
+         args_up.mode = 'coco'
+         args.model_path = './ckpt/base_mask.pt'
+         args.sr_model_path = './ckpt/upsample_mask.pt'
+ 
+     args.val_data_dir = image
+     args.sample_c = sample_c
+     args.num_samples = num_samples
+ 
+     options = args_to_dict(args, model_and_diffusion_defaults(0.).keys())
+     model, diffusion = create_model_and_diffusion(**options)
+ 
+     options_up = args_to_dict(args_up, model_and_diffusion_defaults(True).keys())
+     model_up, diffusion_up = create_model_and_diffusion(**options_up)
+ 
+     if args.model_path:
+         print('loading model')
+         model_ckpt = dist_util.load_state_dict(args.model_path, map_location="cpu")
+         model.load_state_dict(model_ckpt, strict=True)
+ 
+     if args.sr_model_path:
+         print('loading sr model')
+         model_ckpt2 = dist_util.load_state_dict(args.sr_model_path, map_location="cpu")
+         model_up.load_state_dict(model_ckpt2, strict=True)
+ 
+     model.to(dist_util.dev())
+     model_up.to(dist_util.dev())
+     model.eval()
+     model_up.eval()
+ 
+     ########### dataset
+     if args.mode == 'coco':
+         # Semantic-mask conditioning: feed the RGB mask directly.
+         label_pil = image.convert("RGB").resize((256, 256), Image.NEAREST)
+         label_tensor = get_tensor()(label_pil)
+         data_dict = {"ref": label_tensor.unsqueeze(0).repeat(args.num_samples, 1, 1, 1)}
+ 
+     elif args.mode == 'coco-edge':
+         # Sketch conditioning: convert the binary line drawing into an L1
+         # distance-transform map, which gives the model a smoother signal
+         # than the raw edges.
+         label_pil = image.convert("L").resize((256, 256), Image.NEAREST)
+         im_dist = cv2.distanceTransform(255 - np.array(label_pil), cv2.DIST_L1, 3)
+         im_dist = np.clip(im_dist, 0, 255).astype(np.uint8)
+         im_dist = Image.fromarray(im_dist).convert("RGB")
+         label_tensor = get_tensor()(im_dist)[:1]
+         data_dict = {"ref": label_tensor.unsqueeze(0).repeat(args.num_samples, 1, 1, 1)}
+ 
+     print("sampling...")
+ 
+     sampled_imgs = []
+     grid_imgs = []
+     img_id = 0
+     while img_id < args.num_samples:
+         model_kwargs = data_dict
+         with th.no_grad():
+             # Stage 1: base model generates 64x64 samples with
+             # classifier-free guidance.
+             samples_lr = sample(
+                 glide_model=model,
+                 glide_options=options,
+                 side_x=64,
+                 side_y=64,
+                 prompt=model_kwargs,
+                 batch_size=args.num_samples,
+                 guidance_scale=args.sample_c,
+                 device=dist_util.dev(),
+                 prediction_respacing=str(sample_step),
+                 upsample_enabled=False,
+                 upsample_temp=0.997,
+                 mode=args.mode,
+             )
+ 
+             samples_lr = samples_lr.clamp(-1, 1)
+ 
+             # Quantize to 8-bit and renormalize to [-1, 1] for the upsampler.
+             tmp = (127.5 * (samples_lr + 1.0)).int()
+             model_kwargs['low_res'] = tmp / 127.5 - 1.
+ 
+             # Stage 2: upsampling model lifts the samples to 256x256.
+             samples_hr = sample(
+                 glide_model=model_up,
+                 glide_options=options_up,
+                 side_x=256,
+                 side_y=256,
+                 prompt=model_kwargs,
+                 batch_size=args.num_samples,
+                 guidance_scale=1,
+                 device=dist_util.dev(),
+                 prediction_respacing="fast27",
+                 upsample_enabled=True,
+                 upsample_temp=0.997,
+                 mode=args.mode,
+             )
+ 
+             for hr in samples_hr:
+                 hr = 255. * rearrange((hr.cpu().numpy() + 1.0) * 0.5, 'c h w -> h w c')
+                 sample_img = Image.fromarray(hr.astype(np.uint8))
+                 sampled_imgs.append(sample_img)
+                 img_id += 1
+ 
+             grid_imgs.append(samples_hr)
+ 
+     grid = torch.stack(grid_imgs, 0)
+     grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+     grid = make_grid(grid, nrow=2)
+     # to image
+     grid = 255. * rearrange((grid + 1.0) * 0.5, 'c h w -> h w c').cpu().numpy()
+ 
+     return Image.fromarray(grid.astype(np.uint8))
+ 
+ 
+ def create_argparser():
+     defaults = dict(
+         data_dir="",
+         val_data_dir="",
+         model_path="./base_edge.pt",
+         sr_model_path="./upsample_edge.pt",
+         encoder_path="",
+         schedule_sampler="uniform",
+         lr=1e-4,
+         weight_decay=0.0,
+         lr_anneal_steps=0,
+         batch_size=2,
+         microbatch=-1,  # -1 disables microbatches
+         ema_rate="0.9999",  # comma-separated list of EMA values
+         log_interval=100,
+         save_interval=20000,
+         resume_checkpoint="",
+         use_fp16=False,
+         fp16_scale_growth=1e-3,
+         sample_c=1.,
+         sample_respacing="100",
+         uncond_p=0.2,
+         num_samples=3,
+         finetune_decoder=False,
+         mode='',
+     )
+ 
+     defaults_up = dict(defaults)  # copy, so the upsampler defaults don't alias the base dict
+     defaults.update(model_and_diffusion_defaults())
+     parser = argparse.ArgumentParser()
+     add_dict_to_argparser(parser, defaults)
+ 
+     defaults_up.update(model_and_diffusion_defaults(True))
+     parser_up = argparse.ArgumentParser()
+     add_dict_to_argparser(parser_up, defaults_up)
+ 
+     return parser, parser_up
+ 
+ image = gr.outputs.Image(type="pil", label="Sampled results")
+ css = ".output-image{height: 528px !important} .output-carousel .output-image{height:272px !important} a{text-decoration: underline}"
+ demo = gr.Interface(
+     fn=run,
+     inputs=[
+         gr.inputs.Image(type="pil", label="Input Sketch"),
+         # gr.Image(image_mode="L", source="canvas", type="pil", shape=(256, 256), invert_colors=False, tool="editor"),
+         gr.inputs.Radio(label="Input Mode - The type of your input", choices=["mask", "sketch"], default="sketch"),
+         gr.inputs.Slider(label="sample_c - The strength of classifier-free guidance", default=1.4, minimum=1.0, maximum=2.0),
+         gr.inputs.Slider(label="Number of samples - How many samples to generate", default=4, step=1, minimum=1, maximum=16),
+         gr.inputs.Slider(label="Number of steps - How many diffusion steps to use", default=100, step=10, minimum=50, maximum=1000),
+     ],
+     outputs=[image],
+     css=css,
+     title="Generate images from sketches with PITI",
+     description="<div>Upload a sketch map or a semantic map and press Submit to generate images based on your input.</div>",
+ )
+ 
+ demo.launch(enable_queue=True)
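Since the script launches the interface at import time, a local smoke test would first guard the demo.launch(...) call behind if __name__ == "__main__":. With that change, a sketch of exercising the pipeline directly (file names are hypothetical, and the ./ckpt checkpoints are assumed to be present):

    from PIL import Image
    from app import run

    sketch = Image.open("example_sketch.png")   # any black-on-white line drawing
    grid = run(sketch, mode="sketch", sample_c=1.4, num_samples=2, sample_step=100)
    grid.save("samples.png")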
glide_text2im/__init__.py ADDED
@@ -0,0 +1 @@
+
glide_text2im/adv.py ADDED
@@ -0,0 +1,213 @@
+ import functools
+ 
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+ 
+ from . import dist_util
+ from .nn import mean_flat
+ 
+ 
+ class AdversarialLoss(nn.Module):
+     def __init__(self, gan_type='WGAN_GP', gan_k=1, lr_dis=1e-5):
+         super(AdversarialLoss, self).__init__()
+ 
+         self.gan_type = gan_type
+         self.gan_k = gan_k
+ 
+         model = NLayerDiscriminator().to(dist_util.dev())
+         self.discriminator = DDP(
+             model,
+             device_ids=[dist_util.dev()],
+             output_device=dist_util.dev(),
+             broadcast_buffers=False,
+             bucket_cap_mb=128,
+             find_unused_parameters=False,
+         )
+ 
+         if gan_type in ['WGAN_GP', 'GAN']:
+             self.optimizer = optim.Adam(
+                 self.discriminator.parameters(),
+                 lr=lr_dis,
+             )
+ 
+     def forward(self, fake, real):
+         # Update the discriminator gan_k times on detached generator output.
+         fake_detach = fake.detach()
+         for _ in range(self.gan_k):
+             self.optimizer.zero_grad()
+             d_fake = self.discriminator(fake_detach)
+             d_real = self.discriminator(real)
+             if self.gan_type.find('WGAN') >= 0:
+                 loss_d = (d_fake - d_real).mean()
+                 if self.gan_type.find('GP') >= 0:
+                     # WGAN-GP gradient penalty on random interpolates.
+                     epsilon = torch.rand(real.size(0), 1, 1, 1).to(dist_util.dev())
+                     epsilon = epsilon.expand(real.size())
+                     hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
+                     hat.requires_grad = True
+                     d_hat = self.discriminator(hat)
+                     gradients = torch.autograd.grad(
+                         outputs=d_hat.sum(), inputs=hat,
+                         retain_graph=True, create_graph=True, only_inputs=True
+                     )[0]
+                     gradients = gradients.view(gradients.size(0), -1)
+                     gradient_norm = gradients.norm(2, dim=1)
+                     gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
+                     loss_d += gradient_penalty
+ 
+             # Discriminator update
+             loss_d.backward()
+             self.optimizer.step()
+ 
+         # Generator loss (only the WGAN variants are supported here).
+         d_fake_for_g = self.discriminator(fake)
+         if self.gan_type.find('WGAN') >= 0:
+             loss_g = -d_fake_for_g
+ 
+         return mean_flat(loss_g)
+ 
+ 
+ def conv3x3(in_channels, out_channels, stride=1):
+     return nn.Conv2d(in_channels, out_channels, kernel_size=3,
+                      stride=stride, padding=1, bias=True)
+ 
+ 
+ def conv7x7(in_channels, out_channels, stride=1):
+     return nn.Conv2d(in_channels, out_channels, kernel_size=7,
+                      stride=stride, padding=3, bias=True)
+ 
+ 
+ class Discriminator(nn.Module):
+     def __init__(self):
+         super(Discriminator, self).__init__()
+         self.conv1 = conv7x7(3, 32)
+         self.norm1 = nn.InstanceNorm2d(32, affine=True)
+         self.LReLU1 = nn.LeakyReLU(0.2)
+ 
+         self.conv2 = conv3x3(32, 32, 2)
+         self.norm2 = nn.InstanceNorm2d(32, affine=True)
+         self.LReLU2 = nn.LeakyReLU(0.2)
+ 
+         self.conv3 = conv3x3(32, 64)
+         self.norm3 = nn.InstanceNorm2d(64, affine=True)
+         self.LReLU3 = nn.LeakyReLU(0.2)
+ 
+         self.conv4 = conv3x3(64, 64, 2)
+         self.norm4 = nn.InstanceNorm2d(64, affine=True)
+         self.LReLU4 = nn.LeakyReLU(0.2)
+ 
+         self.conv5 = conv3x3(64, 128)
+         self.norm5 = nn.InstanceNorm2d(128, affine=True)
+         self.LReLU5 = nn.LeakyReLU(0.2)
+ 
+         self.conv6 = conv3x3(128, 128, 2)
+         self.norm6 = nn.InstanceNorm2d(128, affine=True)
+         self.LReLU6 = nn.LeakyReLU(0.2)
+ 
+         self.conv7 = conv3x3(128, 256)
+         self.norm7 = nn.InstanceNorm2d(256, affine=True)
+         self.LReLU7 = nn.LeakyReLU(0.2)
+ 
+         self.conv8 = conv3x3(256, 256, 2)
+         self.norm8 = nn.InstanceNorm2d(256, affine=True)
+         self.LReLU8 = nn.LeakyReLU(0.2)
+ 
+         self.conv9 = conv3x3(256, 512)
+         self.norm9 = nn.InstanceNorm2d(512, affine=True)
+         self.LReLU9 = nn.LeakyReLU(0.2)
+ 
+         self.conv10 = conv3x3(512, 512, 2)
+         self.norm10 = nn.InstanceNorm2d(512, affine=True)
+         self.LReLU10 = nn.LeakyReLU(0.2)
+ 
+         self.conv11 = conv3x3(512, 32)
+         self.norm11 = nn.InstanceNorm2d(32, affine=True)
+         self.LReLU11 = nn.LeakyReLU(0.2)
+         self.conv12 = conv3x3(32, 1)
+ 
+     def forward(self, x):
+         x = self.LReLU1(self.norm1(self.conv1(x)))
+         x = self.LReLU2(self.norm2(self.conv2(x)))
+         x = self.LReLU3(self.norm3(self.conv3(x)))
+         x = self.LReLU4(self.norm4(self.conv4(x)))
+         x = self.LReLU5(self.norm5(self.conv5(x)))
+         x = self.LReLU6(self.norm6(self.conv6(x)))
+         x = self.LReLU7(self.norm7(self.conv7(x)))
+         x = self.LReLU8(self.norm8(self.conv8(x)))
+         x = self.LReLU9(self.norm9(self.conv9(x)))
+         x = self.LReLU10(self.norm10(self.conv10(x)))
+ 
+         x = self.LReLU11(self.norm11(self.conv11(x)))
+         x = self.conv12(x)
+ 
+         return x
+ 
+ 
+ def get_norm_layer(norm_type='instance'):
+     """Return a normalization layer
+ 
+     Parameters:
+         norm_type (str) -- the name of the normalization layer: batch | instance | none
+ 
+     For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
+     For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
+     """
+     if norm_type == 'batch':
+         norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
+     elif norm_type == 'instance':
+         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+     elif norm_type == 'none':
+         def norm_layer(x): return nn.Identity()
+     else:
+         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+     return norm_layer
+ 
+ 
+ class NLayerDiscriminator(nn.Module):
+     """Defines a PatchGAN discriminator"""
+ 
+     def __init__(self, input_nc=3, ndf=64, n_layers=3):
+         """Construct a PatchGAN discriminator
+ 
+         Parameters:
+             input_nc (int) -- the number of channels in input images
+             ndf (int)      -- the number of filters in the last conv layer
+             n_layers (int) -- the number of conv layers in the discriminator
+         """
+         super(NLayerDiscriminator, self).__init__()
+         norm_layer = get_norm_layer(norm_type='instance')
+         if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
+             use_bias = norm_layer.func == nn.InstanceNorm2d
+         else:
+             use_bias = norm_layer == nn.InstanceNorm2d
+ 
+         kw = 4
+         padw = 1
+         sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+         nf_mult = 1
+         nf_mult_prev = 1
+         for n in range(1, n_layers):  # gradually increase the number of filters
+             nf_mult_prev = nf_mult
+             nf_mult = min(2 ** n, 8)
+             sequence += [
+                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+                 norm_layer(ndf * nf_mult),
+                 nn.LeakyReLU(0.2, True)
+             ]
+ 
+         nf_mult_prev = nf_mult
+         nf_mult = min(2 ** n_layers, 8)
+         sequence += [
+             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+             norm_layer(ndf * nf_mult),
+             nn.LeakyReLU(0.2, True)
+         ]
+ 
+         sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
+         self.model = nn.Sequential(*sequence)
+ 
+     def forward(self, input):
+         """Standard forward."""
+         return self.model(input)
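Because AdversarialLoss wraps its PatchGAN critic in DistributedDataParallel, a process group must exist before it is constructed. A minimal single-GPU sketch (shapes are illustrative; inputs are expected in [-1, 1]):

    import torch as th
    from glide_text2im import dist_util
    from glide_text2im.adv import AdversarialLoss

    dist_util.setup_dist()                    # initialize torch.distributed first
    adv = AdversarialLoss(gan_type='WGAN_GP', gan_k=1)

    fake = th.randn(4, 3, 64, 64, device=dist_util.dev())  # generator output
    real = th.randn(4, 3, 64, 64, device=dist_util.dev())  # ground-truth batch
    g_loss = adv(fake, real)   # steps the critic, returns per-sample generator loss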
glide_text2im/dist_util.py ADDED
@@ -0,0 +1,83 @@
+ """
+ Helpers for distributed training.
+ """
+ 
+ import io
+ import os
+ import socket
+ 
+ import blobfile as bf
+ from mpi4py import MPI
+ import torch as th
+ import torch.distributed as dist
+ 
+ # Change this to reflect your cluster layout.
+ # The GPU for a given rank is (rank % GPUS_PER_NODE).
+ GPUS_PER_NODE = th.cuda.device_count()
+ 
+ SETUP_RETRY_COUNT = 3
+ 
+ 
+ def setup_dist():
+     """
+     Setup a distributed process group.
+     """
+     if dist.is_initialized():
+         return
+ 
+     comm = MPI.COMM_WORLD
+     backend = "gloo" if not th.cuda.is_available() else "nccl"
+ 
+     if backend == "gloo":
+         hostname = "localhost"
+     else:
+         hostname = socket.gethostbyname(socket.getfqdn())
+     if not os.environ.get("MASTER_ADDR"):
+         os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
+     os.environ["RANK"] = str(comm.rank)
+     os.environ["WORLD_SIZE"] = str(comm.size)
+     port = comm.bcast(_find_free_port(), root=0)
+     if not os.environ.get("MASTER_PORT"):
+         os.environ["MASTER_PORT"] = str(port)
+     if th.cuda.is_available():
+         th.cuda.set_device(dev())
+     dist.init_process_group(backend=backend, init_method="env://")
+ 
+ 
+ def dev():
+     """
+     Get the device to use for torch.distributed.
+     """
+     if th.cuda.is_available():
+         return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
+     return th.device("cpu")
+ 
+ 
+ def load_state_dict(path, **kwargs):
+     """
+     Load a PyTorch file without redundant fetches across MPI ranks.
+     """
+     if MPI.COMM_WORLD.Get_rank() == 0:
+         with bf.BlobFile(path, "rb") as f:
+             data = f.read()
+     else:
+         data = None
+     data = MPI.COMM_WORLD.bcast(data)
+     return th.load(io.BytesIO(data), **kwargs)
+ 
+ 
+ def sync_params(params):
+     """
+     Synchronize a sequence of Tensors across ranks from rank 0.
+     """
+     for p in params:
+         with th.no_grad():
+             dist.broadcast(p, 0)
+ 
+ 
+ def _find_free_port():
+     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+     try:
+         s.bind(("", 0))
+         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+         return s.getsockname()[1]
+     finally:
+         s.close()
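setup_dist() derives rank and world size from MPI.COMM_WORLD, so the helpers also work in a single process (world size 1); a typical sketch:

    from glide_text2im import dist_util

    dist_util.setup_dist()              # no-op if a group is already initialized
    device = dist_util.dev()            # cuda:{rank % GPUS_PER_NODE} or cpu
    state = dist_util.load_state_dict("./ckpt/base_edge.pt", map_location="cpu")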
glide_text2im/fp16_util.py ADDED
@@ -0,0 +1,76 @@
+ """
+ Helpers to train with 16-bit precision.
+ """
+ 
+ import torch.nn as nn
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+ 
+ 
+ def convert_module_to_f16(l):
+     """
+     Convert primitive modules to float16.
+     """
+     if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+         l.weight.data = l.weight.data.half()
+         l.bias.data = l.bias.data.half()
+ 
+ 
+ def convert_module_to_f32(l):
+     """
+     Convert primitive modules to float32, undoing convert_module_to_f16().
+     """
+     if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+         l.weight.data = l.weight.data.float()
+         l.bias.data = l.bias.data.float()
+ 
+ 
+ def make_master_params(model_params):
+     """
+     Copy model parameters into a (differently-shaped) list of full-precision
+     parameters.
+     """
+     master_params = _flatten_dense_tensors(
+         [param.detach().float() for param in model_params]
+     )
+     master_params = nn.Parameter(master_params)
+     master_params.requires_grad = True
+     return [master_params]
+ 
+ 
+ def model_grads_to_master_grads(model_params, master_params):
+     """
+     Copy the gradients from the model parameters into the master parameters
+     from make_master_params().
+     """
+     master_params[0].grad = _flatten_dense_tensors(
+         [param.grad.data.detach().float() for param in model_params]
+     )
+ 
+ 
+ def master_params_to_model_params(model_params, master_params):
+     """
+     Copy the master parameter data back into the model parameters.
+     """
+     # Without copying to a list, if a generator is passed, this will
+     # silently not copy any parameters.
+     model_params = list(model_params)
+ 
+     for param, master_param in zip(
+         model_params, unflatten_master_params(model_params, master_params)
+     ):
+         param.detach().copy_(master_param)
+ 
+ 
+ def unflatten_master_params(model_params, master_params):
+     """
+     Unflatten the master parameters to look like model_params.
+     """
+     return _unflatten_dense_tensors(master_params[0].detach(), model_params)
+ 
+ 
+ def zero_grad(model_params):
+     for param in model_params:
+         # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
+         if param.grad is not None:
+             param.grad.detach_()
+             param.grad.zero_()
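These helpers implement the flat-master-weights fp16 recipe: run forward/backward in half precision, step the optimizer on a single flattened fp32 copy, then copy the result back. A minimal sketch of one step (loss scaling omitted; assumes a CUDA device, since fp16 convolutions are generally GPU-only):

    import torch as th
    import torch.nn as nn
    from glide_text2im.fp16_util import (
        convert_module_to_f16, make_master_params, model_grads_to_master_grads,
        master_params_to_model_params, zero_grad,
    )

    model = nn.Conv2d(3, 8, 3, padding=1).to("cuda")
    model.apply(convert_module_to_f16)                # fp16 weights for fwd/bwd
    model_params = list(model.parameters())
    master_params = make_master_params(model_params)  # one flat fp32 Parameter
    opt = th.optim.AdamW(master_params, lr=1e-4)

    x = th.randn(2, 3, 16, 16, device="cuda").half()
    loss = model(x).float().pow(2).mean()
    zero_grad(model_params)
    loss.backward()
    model_grads_to_master_grads(model_params, master_params)    # fp16 grads -> fp32
    opt.step()
    master_params_to_model_params(model_params, master_params)  # fp32 weights -> fp16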
glide_text2im/gaussian_diffusion.py ADDED
@@ -0,0 +1,946 @@
+ """
+ This code started out as a PyTorch port of Ho et al.'s diffusion models:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
+ 
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
+ """
+ 
+ import enum
+ import math
+ 
+ import numpy as np
+ import torch as th
+ 
+ from .nn import mean_flat
+ from .losses import normal_kl, discretized_gaussian_log_likelihood
+ 
+ 
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+     betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+     warmup_time = int(num_diffusion_timesteps * warmup_frac)
+     betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+     return betas
+ 
+ 
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+     """
+     This is the deprecated API for creating beta schedules.
+ 
+     See get_named_beta_schedule() for the new library of schedules.
+     """
+     if beta_schedule == "quad":
+         betas = (
+             np.linspace(
+                 beta_start ** 0.5,
+                 beta_end ** 0.5,
+                 num_diffusion_timesteps,
+                 dtype=np.float64,
+             )
+             ** 2
+         )
+     elif beta_schedule == "linear":
+         betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+     elif beta_schedule == "warmup10":
+         betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+     elif beta_schedule == "warmup50":
+         betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+     elif beta_schedule == "const":
+         betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+     elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
+         betas = 1.0 / np.linspace(
+             num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+         )
+     else:
+         raise NotImplementedError(beta_schedule)
+     assert betas.shape == (num_diffusion_timesteps,)
+     return betas
+ 
+ 
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+     """
+     Get a pre-defined beta schedule for the given name.
+ 
+     The beta schedule library consists of beta schedules which remain similar
+     in the limit of num_diffusion_timesteps.
+     Beta schedules may be added, but should not be removed or changed once
+     they are committed, to maintain backwards compatibility.
+     """
+     if schedule_name == "linear":
+         scale = 1000 / num_diffusion_timesteps
+         return get_beta_schedule(
+             "linear",
+             beta_start=scale * 0.0001,
+             beta_end=scale * 0.02,
+             num_diffusion_timesteps=num_diffusion_timesteps,
+         )
+     elif schedule_name == "squaredcos_cap_v2":
+         return betas_for_alpha_bar(
+             num_diffusion_timesteps,
+             lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+         )
+     else:
+         raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+ 
+ 
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+     """
+     Create a beta schedule that discretizes the given alpha_t_bar function,
+     which defines the cumulative product of (1-beta) over time from t = [0,1].
+ 
+     :param num_diffusion_timesteps: the number of betas to produce.
+     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                       produces the cumulative product of (1-beta) up to that
+                       part of the diffusion process.
+     :param max_beta: the maximum beta to use; use values lower than 1 to
+                      prevent singularities.
+     """
+     betas = []
+     for i in range(num_diffusion_timesteps):
+         t1 = i / num_diffusion_timesteps
+         t2 = (i + 1) / num_diffusion_timesteps
+         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+     return np.array(betas)
+ 
+ 
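For intuition, the two named schedules can be inspected directly; a quick sketch:

    from glide_text2im.gaussian_diffusion import get_named_beta_schedule

    linear = get_named_beta_schedule("linear", 1000)
    cosine = get_named_beta_schedule("squaredcos_cap_v2", 1000)
    print(linear[0], linear[-1])   # 0.0001 0.02 at 1000 steps (scale = 1)
    print(cosine.max())            # capped at max_beta = 0.999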
+ class ModelMeanType(enum.Enum):
+     """
+     Which type of output the model predicts.
+     """
+ 
+     PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
+     START_X = enum.auto()  # the model predicts x_0
+     EPSILON = enum.auto()  # the model predicts epsilon
+ 
+ 
+ class ModelVarType(enum.Enum):
+     """
+     What is used as the model's output variance.
+ 
+     The LEARNED_RANGE option has been added to allow the model to predict
+     values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+     """
+ 
+     LEARNED = enum.auto()
+     FIXED_SMALL = enum.auto()
+     FIXED_LARGE = enum.auto()
+     LEARNED_RANGE = enum.auto()
+ 
+ 
+ class LossType(enum.Enum):
+     MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
+     RESCALED_MSE = enum.auto()  # use raw MSE loss (with RESCALED_KL when learning variances)
+     KL = enum.auto()  # use the variational lower-bound
+     RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB
+ 
+     def is_vb(self):
+         return self == LossType.KL or self == LossType.RESCALED_KL
+ 
+ 
+ class GaussianDiffusion:
+     """
+     Utilities for training and sampling diffusion models.
+ 
+     Ported directly from here, and then adapted over time for further experimentation.
+     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ 
+     :param betas: a 1-D numpy array of betas for each diffusion timestep,
+                   starting at T and going to 1.
+     :param model_mean_type: a ModelMeanType determining what the model outputs.
+     :param model_var_type: a ModelVarType determining how variance is output.
+     :param loss_type: a LossType determining the loss function to use.
+     :param rescale_timesteps: if True, pass floating point timesteps into the
+                               model so that they are always scaled like in the
+                               original paper (0 to 1000).
+     """
+ 
+     def __init__(
+         self,
+         *,
+         betas,
+         model_mean_type,
+         model_var_type,
+         loss_type,
+         rescale_timesteps=False,
+     ):
+         self.model_mean_type = model_mean_type
+         self.model_var_type = model_var_type
+         self.loss_type = loss_type
+         self.rescale_timesteps = rescale_timesteps
+ 
+         # Use float64 for accuracy.
+         betas = np.array(betas, dtype=np.float64)
+         self.betas = betas
+         assert len(betas.shape) == 1, "betas must be 1-D"
+         assert (betas > 0).all() and (betas <= 1).all()
+ 
+         self.num_timesteps = int(betas.shape[0])
+ 
+         alphas = 1.0 - betas
+         self.alphas_cumprod = np.cumprod(alphas, axis=0)
+         self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+         self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+         assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+ 
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+         self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+         self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+         self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+ 
+         # calculations for posterior q(x_{t-1} | x_t, x_0)
+         self.posterior_variance = (
+             betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+         )
+         # log calculation clipped because the posterior variance is 0 at the
+         # beginning of the diffusion chain.
+         self.posterior_log_variance_clipped = np.log(
+             np.append(self.posterior_variance[1], self.posterior_variance[1:])
+         )
+         self.posterior_mean_coef1 = (
+             betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+         )
+         self.posterior_mean_coef2 = (
+             (1.0 - self.alphas_cumprod_prev)
+             * np.sqrt(alphas)
+             / (1.0 - self.alphas_cumprod)
+         )
+ 
+     def q_mean_variance(self, x_start, t):
+         """
+         Get the distribution q(x_t | x_0).
+ 
+         :param x_start: the [N x C x ...] tensor of noiseless inputs.
+         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+         """
+         mean = (
+             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+         )
+         variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+         log_variance = _extract_into_tensor(
+             self.log_one_minus_alphas_cumprod, t, x_start.shape
+         )
+         return mean, variance, log_variance
+ 
+     def q_sample(self, x_start, t, noise=None):
+         """
+         Diffuse the data for a given number of diffusion steps.
+ 
+         In other words, sample from q(x_t | x_0).
+ 
+         :param x_start: the initial data batch.
+         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+         :param noise: if specified, the split-out normal noise.
+         :return: A noisy version of x_start.
+         """
+         if noise is None:
+             noise = th.randn_like(x_start)
+         assert noise.shape == x_start.shape
+         return (
+             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+             + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+             * noise
+         )
+ 
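+     # Note: q_sample implements the closed-form marginal
+     #     x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise,
+     # so for unit-variance data the variance of x_t stays ~1 at every t and
+     # x_T is (almost) pure standard Gaussian noise.
+ 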
+     def q_posterior_mean_variance(self, x_start, x_t, t):
+         """
+         Compute the mean and variance of the diffusion posterior:
+ 
+             q(x_{t-1} | x_t, x_0)
+         """
+         assert x_start.shape == x_t.shape
+         posterior_mean = (
+             _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+             + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+         )
+         posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+         posterior_log_variance_clipped = _extract_into_tensor(
+             self.posterior_log_variance_clipped, t, x_t.shape
+         )
+         assert (
+             posterior_mean.shape[0]
+             == posterior_variance.shape[0]
+             == posterior_log_variance_clipped.shape[0]
+             == x_start.shape[0]
+         )
+         return posterior_mean, posterior_variance, posterior_log_variance_clipped
+ 
+     def p_mean_variance(
+         self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+     ):
+         """
+         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+         the initial x, x_0.
+ 
+         :param model: the model, which takes a signal and a batch of timesteps
+                       as input.
+         :param x: the [N x C x ...] tensor at time t.
+         :param t: a 1-D Tensor of timesteps.
+         :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+         :param denoised_fn: if not None, a function which applies to the
+             x_start prediction before it is used to sample. Applies before
+             clip_denoised.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :return: a dict with the following keys:
+                  - 'mean': the model mean output.
+                  - 'variance': the model variance output.
+                  - 'log_variance': the log of 'variance'.
+                  - 'pred_xstart': the prediction for x_0.
+         """
+         if model_kwargs is None:
+             model_kwargs = {}
+ 
+         B, C = x.shape[:2]
+         assert t.shape == (B,)
+         model_output = model(x, self._scale_timesteps(t), **model_kwargs)
+ 
+         if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+             assert model_output.shape == (B, C * 2, *x.shape[2:])
+             model_output, model_var_values = th.split(model_output, C, dim=1)
+             if self.model_var_type == ModelVarType.LEARNED:
+                 model_log_variance = model_var_values
+                 model_variance = th.exp(model_log_variance)
+             else:
+                 min_log = _extract_into_tensor(
+                     self.posterior_log_variance_clipped, t, x.shape
+                 )
+                 max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+                 # The model_var_values is [-1, 1] for [min_var, max_var].
+                 frac = (model_var_values + 1) / 2
+                 model_log_variance = frac * max_log + (1 - frac) * min_log
+                 model_variance = th.exp(model_log_variance)
+         else:
+             model_variance, model_log_variance = {
+                 # for fixedlarge, we set the initial (log-)variance like so
+                 # to get a better decoder log likelihood.
+                 ModelVarType.FIXED_LARGE: (
+                     np.append(self.posterior_variance[1], self.betas[1:]),
+                     np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+                 ),
+                 ModelVarType.FIXED_SMALL: (
+                     self.posterior_variance,
+                     self.posterior_log_variance_clipped,
+                 ),
+             }[self.model_var_type]
+             model_variance = _extract_into_tensor(model_variance, t, x.shape)
+             model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+ 
+         def process_xstart(x):
+             if denoised_fn is not None:
+                 x = denoised_fn(x)
+             if clip_denoised:
+                 return x.clamp(-1, 1)
+             return x
+ 
+         if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+             pred_xstart = process_xstart(
+                 self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+             )
+             model_mean = model_output
+         elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
+             if self.model_mean_type == ModelMeanType.START_X:
+                 pred_xstart = process_xstart(model_output)
+             else:
+                 pred_xstart = process_xstart(
+                     self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+                 )
+             model_mean, _, _ = self.q_posterior_mean_variance(
+                 x_start=pred_xstart, x_t=x, t=t
+             )
+         else:
+             raise NotImplementedError(self.model_mean_type)
+ 
+         assert (
+             model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+         )
+         return {
+             "mean": model_mean,
+             "variance": model_variance,
+             "log_variance": model_log_variance,
+             "pred_xstart": pred_xstart,
+         }
+ 
+     def _predict_xstart_from_eps(self, x_t, t, eps):
+         assert x_t.shape == eps.shape
+         return (
+             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+             - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+         )
+ 
+     def _predict_xstart_from_xprev(self, x_t, t, xprev):
+         assert x_t.shape == xprev.shape
+         return (  # (xprev - coef2*x_t) / coef1
+             _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
+             - _extract_into_tensor(
+                 self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
+             )
+             * x_t
+         )
+ 
+     def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+         return (
+             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+             - pred_xstart
+         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+ 
+     def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+         """
+         Compute the mean for the previous step, given a function cond_fn that
+         computes the gradient of a conditional log probability with respect to
+         x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+         condition on y.
+ 
+         This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+         """
+         gradient = cond_fn(x, t, **model_kwargs)
+         new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+         return new_mean
+ 
+     def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+         """
+         Compute what the p_mean_variance output would have been, should the
+         model's score function be conditioned by cond_fn.
+ 
+         See condition_mean() for details on cond_fn.
+         Unlike condition_mean(), this instead uses the conditioning strategy
+         from Song et al. (2020).
+         """
+         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ 
+         eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+         eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+ 
+         out = p_mean_var.copy()
+         out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+         out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+         return out
+ 
+     def _scale_timesteps(self, t):
+         if self.rescale_timesteps:
+             return t.float() * (1000.0 / self.num_timesteps)
+         return t
+ 
+     def p_sample(
+         self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None
+     ):
+         """
+         Sample x_{t-1} from the model at the given timestep.
+ 
+         :param model: the model to sample from.
+         :param x: the current tensor at x_{t-1}.
+         :param t: the value of t, starting at 0 for the first diffusion step.
+         :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+         :param denoised_fn: if not None, a function which applies to the
+             x_start prediction before it is used to sample.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :return: a dict containing the following keys:
+                  - 'sample': a random sample from the model.
+                  - 'pred_xstart': a prediction of x_0.
+         """
+         out = self.p_mean_variance(
+             model,
+             x,
+             t,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             model_kwargs=model_kwargs,
+         )
+         noise = th.randn_like(x)
+         nonzero_mask = (
+             (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+         )  # no noise when t == 0
+         if cond_fn is not None:
+             out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+ 
+     def p_sample_loop(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=True,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         call_back=None,
+         start_step=None,
+     ):
+         """
+         Generate samples from the model.
+ 
+         :param model: the model module.
+         :param shape: the shape of the samples, (N, C, H, W).
+         :param noise: if specified, the noise from the encoder to sample.
+                       Should be of the same shape as `shape`.
+         :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+         :param denoised_fn: if not None, a function which applies to the
+             x_start prediction before it is used to sample.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :param device: if specified, the device to create the samples on.
+                        If not specified, use a model parameter's device.
+         :param progress: if True, show a tqdm progress bar.
+         :return: a non-differentiable batch of samples.
+         """
+         final = None
+         for sample in self.p_sample_loop_progressive(
+             model,
+             shape,
+             noise=noise,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             cond_fn=cond_fn,
+             model_kwargs=model_kwargs,
+             device=device,
+             progress=progress,
+             call_back=call_back,
+             start_step=start_step,
+         ):
+             final = sample
+         return final["sample"]
+ 
+     def p_sample_loop_progressive(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=True,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         call_back=None,
+         start_step=None,
+     ):
+         """
+         Generate samples from the model and yield intermediate samples from
+         each timestep of diffusion.
+ 
+         Arguments are the same as p_sample_loop().
+         Returns a generator over dicts, where each dict is the return value of
+         p_sample().
+         """
+         if device is None:
+             device = next(model.parameters()).device
+         assert isinstance(shape, (tuple, list))
+         if noise is not None:
+             img = noise
+         else:
+             img = th.randn(*shape, device=device)
+         indices = list(range(self.num_timesteps if not start_step else start_step))[::-1]
+ 
+         if progress:
+             # Lazy import so that we don't depend on tqdm.
+             from tqdm.auto import tqdm
+ 
+             indices = tqdm(indices)
+ 
+         for i in indices:
+             t = th.tensor([i] * shape[0], device=device)
+             with th.no_grad():
+                 out = self.p_sample(
+                     model,
+                     img,
+                     t,
+                     clip_denoised=clip_denoised,
+                     denoised_fn=denoised_fn,
+                     cond_fn=cond_fn,
+                     model_kwargs=model_kwargs,
+                 )
+                 if call_back:
+                     call_back(out, t[0].item())
+                 yield out
+                 img = out["sample"]
+ 
+     def ddim_sample(
+         self,
+         model,
+         x,
+         t,
+         clip_denoised=True,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         eta=0.0,
+     ):
+         """
+         Sample x_{t-1} from the model using DDIM.
+ 
+         Same usage as p_sample().
+         """
+         out = self.p_mean_variance(
+             model,
+             x,
+             t,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             model_kwargs=model_kwargs,
+         )
+ 
+         if cond_fn is not None:
+             out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+         # Usually our model outputs epsilon, but we re-derive it
+         # in case we used x_start or x_prev prediction.
+         eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+         alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+         sigma = (
+             eta
+             * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+             * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+         )
+         # Equation 12.
+         noise = th.randn_like(x)
+         mean_pred = (
+             out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+             + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+         )
+         nonzero_mask = (
+             (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+         )  # no noise when t == 0
+         sample = mean_pred + nonzero_mask * sigma * noise
+         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+ 
+     def ddim_reverse_sample(
+         self,
+         model,
+         x,
+         t,
+         clip_denoised=True,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         eta=0.0,
+     ):
+         """
+         Sample x_{t+1} from the model using the DDIM reverse ODE.
+         """
+         assert eta == 0.0, "Reverse ODE only for deterministic path"
+         out = self.p_mean_variance(
+             model,
+             x,
+             t,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             model_kwargs=model_kwargs,
+         )
+         if cond_fn is not None:
+             out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+         # Usually our model outputs epsilon, but we re-derive it
+         # in case we used x_start or x_prev prediction.
+         eps = (
+             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+             - out["pred_xstart"]
+         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+         alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+ 
+         # Equation 12. reversed
+         mean_pred = (
+             out["pred_xstart"] * th.sqrt(alpha_bar_next)
+             + th.sqrt(1 - alpha_bar_next) * eps
+         )
+ 
+         return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+ 
+     def ddim_sample_loop(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=True,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         eta=0.0,
+     ):
+         """
+         Generate samples from the model using DDIM.
+ 
+         Same usage as p_sample_loop().
+         """
+         final = None
+         for sample in self.ddim_sample_loop_progressive(
+             model,
+             shape,
+             noise=noise,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             cond_fn=cond_fn,
+             model_kwargs=model_kwargs,
+             device=device,
+             progress=progress,
+             eta=eta,
+         ):
+             final = sample
+         return final["sample"]
+ 
+     def ddim_sample_loop_progressive(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=True,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         eta=0.0,
+     ):
+         """
+         Use DDIM to sample from the model and yield intermediate samples from
+         each timestep of DDIM.
+ 
+         Same usage as p_sample_loop_progressive().
+         """
+         if device is None:
+             device = next(model.parameters()).device
+         assert isinstance(shape, (tuple, list))
+         if noise is not None:
+             img = noise
+         else:
+             img = th.randn(*shape, device=device)
+         indices = list(range(self.num_timesteps))[::-1]
+ 
+         if progress:
+             # Lazy import so that we don't depend on tqdm.
+             from tqdm.auto import tqdm
+ 
+             indices = tqdm(indices)
+ 
+         for i in indices:
+             t = th.tensor([i] * shape[0], device=device)
+             with th.no_grad():
+                 out = self.ddim_sample(
+                     model,
+                     img,
+                     t,
+                     clip_denoised=clip_denoised,
+                     denoised_fn=denoised_fn,
+                     cond_fn=cond_fn,
+                     model_kwargs=model_kwargs,
+                     eta=eta,
+                 )
+                 yield out
+                 img = out["sample"]
+ 
+     def _vb_terms_bpd(
+         self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+     ):
+         """
+         Get a term for the variational lower-bound.
+ 
+         The resulting units are bits (rather than nats, as one might expect).
+         This allows for comparison to other papers.
+ 
+         :return: a dict with the following keys:
+                  - 'output': a shape [N] tensor of NLLs or KLs.
+                  - 'pred_xstart': the x_0 predictions.
+         """
+         true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+             x_start=x_start, x_t=x_t, t=t
+         )
+         out = self.p_mean_variance(
+             model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+         )
+         kl = normal_kl(
+             true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+         )
+         kl = mean_flat(kl) / np.log(2.0)
+ 
+         decoder_nll = -discretized_gaussian_log_likelihood(
+             x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+         )
+         assert decoder_nll.shape == x_start.shape
+         decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+ 
+         # At the first timestep return the decoder NLL,
+         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+         output = th.where((t == 0), decoder_nll, kl)
+         return {"output": output, "pred_xstart": out["pred_xstart"]}
+ 
+     def training_losses(self, model, x_start, t, vgg, adv, model_kwargs=None, noise=None):
+         """
+         Compute training losses for a single timestep.
+ 
+         :param model: the model to evaluate loss on.
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :param t: a batch of timestep indices.
+         :param vgg: if not None, a perceptual (VGG) loss applied to the
+             predicted x_0, weighted by 0.004.
+         :param adv: if not None, an adversarial loss applied to the predicted
+             x_0, weighted by 0.01.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :param noise: if specified, the specific Gaussian noise to try to remove.
+         :return: a dict with the key "loss" containing a tensor of shape [N].
+                  Some mean or variance settings may also have other keys.
+         """
+         if model_kwargs is None:
+             model_kwargs = {}
+         if noise is None:
+             noise = th.randn_like(x_start)
+         x_t = self.q_sample(x_start, t, noise=noise)
+ 
+         terms = {}
+ 
+         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+             terms["loss"] = self._vb_terms_bpd(
+                 model=model,
+                 x_start=x_start,
+                 x_t=x_t,
+                 t=t,
+                 clip_denoised=False,
+                 model_kwargs=model_kwargs,
+             )["output"]
+             if self.loss_type == LossType.RESCALED_KL:
+                 terms["loss"] *= self.num_timesteps
+         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+             model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
+ 
+             if self.model_var_type in [
+                 ModelVarType.LEARNED,
+                 ModelVarType.LEARNED_RANGE,
+             ]:
+                 B, C = x_t.shape[:2]
+                 assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+                 model_output, model_var_values = th.split(model_output, C, dim=1)
+                 # Learn the variance using the variational bound, but don't let
+                 # it affect our mean prediction.
+                 frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+                 terms["vb"] = self._vb_terms_bpd(
+                     model=lambda *args, r=frozen_out: r,
+                     x_start=x_start,
+                     x_t=x_t,
+                     t=t,
+                     clip_denoised=False,
+                 )["output"]
+                 if self.loss_type == LossType.RESCALED_MSE:
+                     # Divide by 1000 for equivalence with initial implementation.
+                     # Without a factor of 1/1000, the VB term hurts the MSE term.
+                     terms["vb"] *= self.num_timesteps / 1000.0
+ 
+             target = {
+                 ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+                     x_start=x_start, x_t=x_t, t=t
+                 )[0],
+                 ModelMeanType.START_X: x_start,
+                 ModelMeanType.EPSILON: noise,
+             }[self.model_mean_type]
+             assert model_output.shape == target.shape == x_start.shape
+             terms["mse"] = mean_flat((target - model_output) ** 2)
+ 
+             if vgg is not None:
+                 terms["perceptual"] = vgg(self._predict_xstart_from_eps(x_t=x_t, t=t, eps=model_output).clamp(-1, 1), x_start) * 0.004
+             else:
+                 terms["perceptual"] = 0.
+ 
+             if adv is not None:
+                 terms["adv"] = adv(self._predict_xstart_from_eps(x_t=x_t, t=t, eps=model_output).clamp(-1, 1), x_start) * 0.01
+             else:
+                 terms["adv"] = 0.
+ 
+             if "vb" in terms:
+                 terms["loss"] = terms["mse"] + terms["vb"] + terms["perceptual"] + terms["adv"]
+             else:
+                 terms["loss"] = terms["mse"]
+         else:
+             raise NotImplementedError(self.loss_type)
+ 
+         return terms
+ 
+     def _prior_bpd(self, x_start):
+         """
+         Get the prior KL term for the variational lower-bound, measured in
+         bits-per-dim.
+ 
+         This term can't be optimized, as it only depends on the encoder.
+ 
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :return: a batch of [N] KL values (in bits), one per batch element.
+         """
+         batch_size = x_start.shape[0]
+         t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+         qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+         kl_prior = normal_kl(
+             mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+         )
+         return mean_flat(kl_prior) / np.log(2.0)
+ 
+     def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+         """
+         Compute the entire variational lower-bound, measured in bits-per-dim,
+         as well as other related quantities.
+ 
+         :param model: the model to evaluate loss on.
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :param clip_denoised: if True, clip denoised samples.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+ 
+         :return: a dict containing the following keys:
+                  - total_bpd: the total variational lower-bound, per batch element.
+                  - prior_bpd: the prior term in the lower-bound.
+                  - vb: an [N x T] tensor of terms in the lower-bound.
+                  - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+                  - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+         """
+         device = x_start.device
+         batch_size = x_start.shape[0]
+ 
+         vb = []
+         xstart_mse = []
+         mse = []
+         for t in list(range(self.num_timesteps))[::-1]:
+             t_batch = th.tensor([t] * batch_size, device=device)
+             noise = th.randn_like(x_start)
+             x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+             # Calculate VLB term at the current timestep
+             with th.no_grad():
+                 out = self._vb_terms_bpd(
+                     model,
+                     x_start=x_start,
+                     x_t=x_t,
+                     t=t_batch,
+                     clip_denoised=clip_denoised,
+                     model_kwargs=model_kwargs,
+                 )
+             vb.append(out["output"])
+             xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+             eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+             mse.append(mean_flat((eps - noise) ** 2))
+ 
+         vb = th.stack(vb, dim=1)
+         xstart_mse = th.stack(xstart_mse, dim=1)
+         mse = th.stack(mse, dim=1)
+ 
+         prior_bpd = self._prior_bpd(x_start)
+         total_bpd = vb.sum(dim=1) + prior_bpd
+         return {
+             "total_bpd": total_bpd,
+             "prior_bpd": prior_bpd,
+             "vb": vb,
+             "xstart_mse": xstart_mse,
+             "mse": mse,
+         }
+ 
+ 
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
+     """
+     Extract values from a 1-D numpy array for a batch of indices.
+ 
+     :param arr: the 1-D numpy array.
+     :param timesteps: a tensor of indices into the array to extract.
+     :param broadcast_shape: a larger shape of K dimensions with the batch
+                             dimension equal to the length of timesteps.
+     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+     """
+     res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+     while len(res.shape) < len(broadcast_shape):
+         res = res[..., None]
+     return res.expand(broadcast_shape)
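In this repo the class is normally built via script_util.create_gaussian_diffusion (as glide_util.sample does below), but it can also be constructed directly; a toy sketch with a stand-in epsilon predictor:

    import torch as th
    from glide_text2im import gaussian_diffusion as gd

    diffusion = gd.GaussianDiffusion(
        betas=gd.get_named_beta_schedule("squaredcos_cap_v2", 100),
        model_mean_type=gd.ModelMeanType.EPSILON,
        model_var_type=gd.ModelVarType.FIXED_SMALL,
        loss_type=gd.LossType.MSE,
    )

    def toy_model(x, t):                 # stand-in for a UNet epsilon predictor
        return th.zeros_like(x)

    # device must be passed explicitly, since a plain function has no parameters
    samples = diffusion.p_sample_loop(toy_model, (1, 3, 8, 8), device="cpu")
    print(samples.shape)                 # torch.Size([1, 3, 8, 8])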
glide_text2im/glide_util.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ from typing import Tuple
+ from . import dist_util
+ import PIL
+ import numpy as np
+ import torch as th
+ from .script_util import (
+     create_gaussian_diffusion,
+     create_model_and_diffusion,
+     model_and_diffusion_defaults,
+ )
+
+ # Sample from the base model.
+
+ #@th.inference_mode()
+ def sample(
+     glide_model,
+     glide_options,
+     side_x,
+     side_y,
+     prompt,
+     batch_size=1,
+     guidance_scale=4,
+     device="cpu",
+     prediction_respacing="100",
+     upsample_enabled=False,
+     upsample_temp=0.997,
+     mode='',
+ ):
+
+     eval_diffusion = create_gaussian_diffusion(
+         steps=glide_options["diffusion_steps"],
+         learn_sigma=glide_options["learn_sigma"],
+         noise_schedule=glide_options["noise_schedule"],
+         predict_xstart=glide_options["predict_xstart"],
+         rescale_timesteps=glide_options["rescale_timesteps"],
+         rescale_learned_sigmas=glide_options["rescale_learned_sigmas"],
+         timestep_respacing=prediction_respacing,
+     )
+
+     # Create the classifier-free guidance tokens (empty)
+     full_batch_size = batch_size * 2
+     cond_ref = prompt['ref']
+     uncond_ref = th.ones_like(cond_ref)
+
+     model_kwargs = {}
+     model_kwargs['ref'] = th.cat([cond_ref, uncond_ref], 0).to(dist_util.dev())
+
+     def cfg_model_fn(x_t, ts, **kwargs):
+         half = x_t[: len(x_t) // 2]
+         combined = th.cat([half, half], dim=0)
+         model_out = glide_model(combined, ts, **kwargs)
+         eps, rest = model_out[:, :3], model_out[:, 3:]
+         cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
+
+         half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
+
+         eps = th.cat([half_eps, half_eps], dim=0)
+         return th.cat([eps, rest], dim=1)
+
+
+     if upsample_enabled:
+         model_kwargs['low_res'] = prompt['low_res'].to(dist_util.dev())
+         noise = th.randn((batch_size, 3, side_y, side_x), device=device) * upsample_temp
+         model_fn = glide_model  # just use the base model, no need for CFG.
+         model_kwargs['ref'] = model_kwargs['ref'][:batch_size]
+
+         samples = eval_diffusion.p_sample_loop(
+             model_fn,
+             (batch_size, 3, side_y, side_x),  # only thing that's changed
+             noise=noise,
+             device=device,
+             clip_denoised=True,
+             progress=False,
+             model_kwargs=model_kwargs,
+             cond_fn=None,
+         )[:batch_size]
+
+     else:
+         model_fn = cfg_model_fn  # so we use CFG for the base model.
+         noise = th.randn((batch_size, 3, side_y, side_x), device=device)
+         noise = th.cat([noise, noise], 0)
+         samples = eval_diffusion.p_sample_loop(
+             model_fn,
+             (full_batch_size, 3, side_y, side_x),  # only thing that's changed
+             noise=noise,
+             device=device,
+             clip_denoised=True,
+             progress=False,
+             model_kwargs=model_kwargs,
+             cond_fn=None,
+         )[:batch_size]
+
+     return samples
+
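
`cfg_model_fn` above is standard classifier-free guidance: the batch is doubled so the conditional and unconditional predictions come from one forward pass, then recombined as eps = eps_uncond + s * (eps_cond - eps_uncond). A minimal sketch of calling `sample`; the `model` and `options` objects are placeholders for a loaded base model and its options dict, and the `ref` tensor is an arbitrary stand-in for a preprocessed condition map:

    ref = th.ones(1, 1, 256, 256)      # placeholder, e.g. an edge map in [-1, 1]
    out = sample(
        glide_model=model,             # hypothetical loaded model
        glide_options=options,         # hypothetical options dict
        side_x=64,
        side_y=64,
        prompt={'ref': ref},
        batch_size=1,
        guidance_scale=4,
        device='cuda',
        prediction_respacing='100',
    )                                  # -> a [1, 3, 64, 64] tensor of samples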
glide_text2im/logger.py ADDED
@@ -0,0 +1,497 @@
+ """
+ Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
+ https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
+ """
+
+ import os
+ import sys
+ import shutil
+ import os.path as osp
+ import json
+ import time
+ import datetime
+ import tempfile
+ import warnings
+ from collections import defaultdict
+ from contextlib import contextmanager
+
+ DEBUG = 10
+ INFO = 20
+ WARN = 30
+ ERROR = 40
+
+ DISABLED = 50
+
+
+ class KVWriter(object):
+     def writekvs(self, kvs):
+         raise NotImplementedError
+
+
+ class SeqWriter(object):
+     def writeseq(self, seq):
+         raise NotImplementedError
+
+
+ class HumanOutputFormat(KVWriter, SeqWriter):
+     def __init__(self, filename_or_file):
+         if isinstance(filename_or_file, str):
+             self.file = open(filename_or_file, "wt")
+             self.own_file = True
+         else:
+             assert hasattr(filename_or_file, "read"), (
+                 "expected file or str, got %s" % filename_or_file
+             )
+             self.file = filename_or_file
+             self.own_file = False
+
+     def writekvs(self, kvs):
+         # Create strings for printing
+         key2str = {}
+         for (key, val) in sorted(kvs.items()):
+             if hasattr(val, "__float__"):
+                 valstr = "%-8.3g" % val
+             else:
+                 valstr = str(val)
+             key2str[self._truncate(key)] = self._truncate(valstr)
+
+         # Find max widths
+         if len(key2str) == 0:
+             print("WARNING: tried to write empty key-value dict")
+             return
+         else:
+             keywidth = max(map(len, key2str.keys()))
+             valwidth = max(map(len, key2str.values()))
+
+         # Write out the data
+         dashes = "-" * (keywidth + valwidth + 7)
+         lines = [dashes]
+         for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
+             lines.append(
+                 "| %s%s | %s%s |"
+                 % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
+             )
+         lines.append(dashes)
+         self.file.write("\n".join(lines) + "\n")
+
+         # Flush the output to the file
+         self.file.flush()
+
+     def _truncate(self, s):
+         maxlen = 30
+         return s[: maxlen - 3] + "..." if len(s) > maxlen else s
+
+     def writeseq(self, seq):
+         seq = list(seq)
+         for (i, elem) in enumerate(seq):
+             self.file.write(elem)
+             if i < len(seq) - 1:  # add space unless this is the last one
+                 self.file.write(" ")
+         self.file.write("\n")
+         self.file.flush()
+
+     def close(self):
+         if self.own_file:
+             self.file.close()
+
+
+ class JSONOutputFormat(KVWriter):
+     def __init__(self, filename):
+         self.file = open(filename, "wt")
+
+     def writekvs(self, kvs):
+         for k, v in sorted(kvs.items()):
+             if hasattr(v, "dtype"):
+                 kvs[k] = float(v)
+         self.file.write(json.dumps(kvs) + "\n")
+         self.file.flush()
+
+     def close(self):
+         self.file.close()
+
+
+ class CSVOutputFormat(KVWriter):
+     def __init__(self, filename):
+         self.file = open(filename, "w+t")
+         self.keys = []
+         self.sep = ","
+
+     def writekvs(self, kvs):
+         # Add our current row to the history
+         extra_keys = list(kvs.keys() - self.keys)
+         extra_keys.sort()
+         if extra_keys:
+             self.keys.extend(extra_keys)
+             self.file.seek(0)
+             lines = self.file.readlines()
+             self.file.seek(0)
+             for (i, k) in enumerate(self.keys):
+                 if i > 0:
+                     self.file.write(",")
+                 self.file.write(k)
+             self.file.write("\n")
+             for line in lines[1:]:
+                 self.file.write(line[:-1])
+                 self.file.write(self.sep * len(extra_keys))
+                 self.file.write("\n")
+         for (i, k) in enumerate(self.keys):
+             if i > 0:
+                 self.file.write(",")
+             v = kvs.get(k)
+             if v is not None:
+                 self.file.write(str(v))
+         self.file.write("\n")
+         self.file.flush()
+
+     def close(self):
+         self.file.close()
+
+
+ class TensorBoardOutputFormat(KVWriter):
+     """
+     Dumps key/value pairs into TensorBoard's numeric format.
+     """
+
+     def __init__(self, dir):
+         os.makedirs(dir, exist_ok=True)
+         self.dir = dir
+         self.step = 1
+         prefix = "events"
+         path = osp.join(osp.abspath(dir), prefix)
+         import tensorflow as tf
+         from tensorflow.python import pywrap_tensorflow
+         from tensorflow.core.util import event_pb2
+         from tensorflow.python.util import compat
+
+         self.tf = tf
+         self.event_pb2 = event_pb2
+         self.pywrap_tensorflow = pywrap_tensorflow
+         self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
+
+     def writekvs(self, kvs):
+         def summary_val(k, v):
+             kwargs = {"tag": k, "simple_value": float(v)}
+             return self.tf.Summary.Value(**kwargs)
+
+         summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
+         event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
+         event.step = (
+             self.step
+         )  # is there any reason why you'd want to specify the step?
+         self.writer.WriteEvent(event)
+         self.writer.Flush()
+         self.step += 1
+
+     def close(self):
+         if self.writer:
+             self.writer.Close()
+             self.writer = None
+
+
+ def make_output_format(format, ev_dir, log_suffix=""):
+     os.makedirs(ev_dir, exist_ok=True)
+     if format == "stdout":
+         return HumanOutputFormat(sys.stdout)
+     elif format == "log":
+         return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
+     elif format == "json":
+         return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
+     elif format == "csv":
+         return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
+     elif format == "tensorboard":
+         return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
+     else:
+         raise ValueError("Unknown format specified: %s" % (format,))
+
+
+ # ================================================================
+ # API
+ # ================================================================
+
+
+ def logkv(key, val):
+     """
+     Log a value of some diagnostic
+     Call this once for each diagnostic quantity, each iteration
+     If called many times, last value will be used.
+     """
+     get_current().logkv(key, val)
+
+
+ def logkv_mean(key, val):
+     """
+     The same as logkv(), but if called many times, values averaged.
+     """
+     get_current().logkv_mean(key, val)
+
+
+ def logkvs(d):
+     """
+     Log a dictionary of key-value pairs
+     """
+     for (k, v) in d.items():
+         logkv(k, v)
+
+
+ def dumpkvs():
+     """
+     Write all of the diagnostics from the current iteration
+     """
+     return get_current().dumpkvs()
+
+
+ def getkvs():
+     return get_current().name2val
+
+
+ def log(*args, level=INFO):
+     """
+     Write the sequence of args, with no separators, to the console and
+     output files (if you've configured an output file).
+     """
+     get_current().log(*args, level=level)
+
+
+ def debug(*args):
+     log(*args, level=DEBUG)
+
+
+ def info(*args):
+     log(*args, level=INFO)
+
+
+ def warn(*args):
+     log(*args, level=WARN)
+
+
+ def error(*args):
+     log(*args, level=ERROR)
+
+
+ def set_level(level):
+     """
+     Set logging threshold on current logger.
+     """
+     get_current().set_level(level)
+
+
+ def set_comm(comm):
+     get_current().set_comm(comm)
+
+
+ def save_args(args):
+     get_current().save_args(args)
+
+
+ def get_dir():
+     """
+     Get directory that log files are being written to.
+     will be None if there is no output directory (i.e., if you didn't call start)
+     """
+     return get_current().get_dir()
+
+
+ record_tabular = logkv
+ dump_tabular = dumpkvs
+
+
+ @contextmanager
+ def profile_kv(scopename):
+     logkey = "wait_" + scopename
+     tstart = time.time()
+     try:
+         yield
+     finally:
+         get_current().name2val[logkey] += time.time() - tstart
+
+
+ def profile(n):
+     """
+     Usage:
+     @profile("my_func")
+     def my_func(): code
+     """
+
+     def decorator_with_name(func):
+         def func_wrapper(*args, **kwargs):
+             with profile_kv(n):
+                 return func(*args, **kwargs)
+
+         return func_wrapper
+
+     return decorator_with_name
+
+
+ # ================================================================
+ # Backend
+ # ================================================================
+
+
+ def get_current():
+     if Logger.CURRENT is None:
+         _configure_default_logger()
+
+     return Logger.CURRENT
+
+
+ class Logger(object):
+     DEFAULT = None  # A logger with no output files. (See right below class definition)
+     # So that you can still log to the terminal without setting up any output files
+     CURRENT = None  # Current logger being used by the free functions above
+
+     def __init__(self, dir, output_formats, comm=None):
+         self.name2val = defaultdict(float)  # values this iteration
+         self.name2cnt = defaultdict(int)
+         self.level = INFO
+         self.dir = dir
+         self.output_formats = output_formats
+         self.comm = comm
+
+     def save_args(self, args):
+         with open(osp.join(self.dir, "args.json"), 'w') as f:
+             json.dump(args, f)
+
+     # Logging API, forwarded
+     # ----------------------------------------
+     def logkv(self, key, val):
+         self.name2val[key] = val
+
+     def logkv_mean(self, key, val):
+         oldval, cnt = self.name2val[key], self.name2cnt[key]
+         self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
+         self.name2cnt[key] = cnt + 1
+
+     def dumpkvs(self):
+         if self.comm is None:
+             d = self.name2val
+         else:
+             d = mpi_weighted_mean(
+                 self.comm,
+                 {
+                     name: (val, self.name2cnt.get(name, 1))
+                     for (name, val) in self.name2val.items()
+                 },
+             )
+             if self.comm.rank != 0:
+                 d["dummy"] = 1  # so we don't get a warning about empty dict
+         out = d.copy()  # Return the dict for unit testing purposes
+         for fmt in self.output_formats:
+             if isinstance(fmt, KVWriter):
+                 fmt.writekvs(d)
+         self.name2val.clear()
+         self.name2cnt.clear()
+         return out
+
+     def log(self, *args, level=INFO):
+         if self.level <= level:
+             self._do_log(args)
+
+     # Configuration
+     # ----------------------------------------
+     def set_level(self, level):
+         self.level = level
+
+     def set_comm(self, comm):
+         self.comm = comm
+
+     def get_dir(self):
+         return self.dir
+
+     def close(self):
+         for fmt in self.output_formats:
+             fmt.close()
+
+     # Misc
+     # ----------------------------------------
+     def _do_log(self, args):
+         for fmt in self.output_formats:
+             if isinstance(fmt, SeqWriter):
+                 fmt.writeseq(map(str, args))
+
+
+ def get_rank_without_mpi_import():
+     # check environment variables here instead of importing mpi4py
+     # to avoid calling MPI_Init() when this module is imported
+     for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
+         if varname in os.environ:
+             return int(os.environ[varname])
+     return 0
+
+
+ def mpi_weighted_mean(comm, local_name2valcount):
+     """
+     Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
+     Perform a weighted average over dicts that are each on a different node
+     Input: local_name2valcount: dict mapping key -> (value, count)
+     Returns: key -> mean
+     """
+     all_name2valcount = comm.gather(local_name2valcount)
+     if comm.rank == 0:
+         name2sum = defaultdict(float)
+         name2count = defaultdict(float)
+         for n2vc in all_name2valcount:
+             for (name, (val, count)) in n2vc.items():
+                 try:
+                     val = float(val)
+                 except ValueError:
+                     if comm.rank == 0:
+                         warnings.warn(
+                             "WARNING: tried to compute mean on non-float {}={}".format(
+                                 name, val
+                             )
+                         )
+                 else:
+                     name2sum[name] += val * count
+                     name2count[name] += count
+         return {name: name2sum[name] / name2count[name] for name in name2sum}
+     else:
+         return {}
+
+
+ def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
+     """
+     If comm is provided, average all numerical stats across that comm
+     """
+     if dir is None:
+         dir = os.getenv("LOGDIR")
+
+     assert isinstance(dir, str)
+     dir = os.path.expanduser(dir)
+     os.makedirs(os.path.expanduser(dir), exist_ok=True)
+
+     rank = get_rank_without_mpi_import()
+     if rank > 0:
+         log_suffix = log_suffix + "-rank%03i" % rank
+
+     if format_strs is None:
+         if rank == 0:
+             format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
+         else:
+             format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
+     format_strs = filter(None, format_strs)
+     output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
+
+     Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
+     if output_formats:
+         log("Logging to %s" % dir)
+
+
+ def _configure_default_logger():
+     configure()
+     Logger.DEFAULT = Logger.CURRENT
+
+
+ def reset():
+     if Logger.CURRENT is not Logger.DEFAULT:
+         Logger.CURRENT.close()
+         Logger.CURRENT = Logger.DEFAULT
+         log("Reset logger")
+
+
+ @contextmanager
+ def scoped_configure(dir=None, format_strs=None, comm=None):
+     prevlogger = Logger.CURRENT
+     configure(dir=dir, format_strs=format_strs, comm=comm)
+     try:
+         yield
+     finally:
+         Logger.CURRENT.close()
+         Logger.CURRENT = prevlogger
+
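
A minimal sketch of the logging API above (the directory is an arbitrary example; any writable path works):

    from glide_text2im import logger

    logger.configure(dir="/tmp/glide_logs", format_strs=["stdout", "csv"])
    for step in range(3):
        logger.logkv("step", step)
        logger.logkv_mean("loss", 0.1 * step)
        logger.dumpkvs()   # one row per call, to stdout and progress.csv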
glide_text2im/losses.py ADDED
@@ -0,0 +1,77 @@
+ """
+ Helpers for various likelihood-based losses. These are ported from the original
+ Ho et al. diffusion models codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
+ """
+
+ import numpy as np
+
+ import torch as th
+
+
+ def normal_kl(mean1, logvar1, mean2, logvar2):
+     """
+     Compute the KL divergence between two gaussians.
+
+     Shapes are automatically broadcasted, so batches can be compared to
+     scalars, among other use cases.
+     """
+     tensor = None
+     for obj in (mean1, logvar1, mean2, logvar2):
+         if isinstance(obj, th.Tensor):
+             tensor = obj
+             break
+     assert tensor is not None, "at least one argument must be a Tensor"
+
+     # Force variances to be Tensors. Broadcasting helps convert scalars to
+     # Tensors, but it does not work for th.exp().
+     logvar1, logvar2 = [
+         x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+         for x in (logvar1, logvar2)
+     ]
+
+     return 0.5 * (
+         -1.0
+         + logvar2
+         - logvar1
+         + th.exp(logvar1 - logvar2)
+         + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+     )
+
+
+ def approx_standard_normal_cdf(x):
+     """
+     A fast approximation of the cumulative distribution function of the
+     standard normal.
+     """
+     return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+     """
+     Compute the log-likelihood of a Gaussian distribution discretizing to a
+     given image.
+
+     :param x: the target images. It is assumed that this was uint8 values,
+               rescaled to the range [-1, 1].
+     :param means: the Gaussian mean Tensor.
+     :param log_scales: the Gaussian log stddev Tensor.
+     :return: a tensor like x of log probabilities (in nats).
+     """
+     assert x.shape == means.shape == log_scales.shape
+     centered_x = x - means
+     inv_stdv = th.exp(-log_scales)
+     plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+     cdf_plus = approx_standard_normal_cdf(plus_in)
+     min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+     cdf_min = approx_standard_normal_cdf(min_in)
+     log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+     log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+     cdf_delta = cdf_plus - cdf_min
+     log_probs = th.where(
+         x < -0.999,
+         log_cdf_plus,
+         th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+     )
+     assert log_probs.shape == x.shape
+     return log_probs
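
As a sanity check, `normal_kl` with unit variances reduces to the textbook value (mean1 - mean2)^2 / 2:

    import torch as th

    kl = normal_kl(th.tensor([0.0, 1.0]), 0.0, 0.0, 0.0)
    # 0.5 * (-1 + 0 - 0 + exp(0) + mean1^2 * exp(0)) = mean1^2 / 2
    assert th.allclose(kl, th.tensor([0.0, 0.5]))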
glide_text2im/nn.py ADDED
@@ -0,0 +1,177 @@
+ """
+ Various utilities for neural networks.
+ """
+
+ import math
+
+ import torch as th
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+ class SiLU(nn.Module):
+     def forward(self, x):
+         return x * th.sigmoid(x)
+
+
+ class GroupNorm32(nn.GroupNorm):
+     def __init__(self, num_groups, num_channels, swish, eps=1e-5):
+         super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
+         self.swish = swish
+
+     def forward(self, x):
+         y = super().forward(x.float()).to(x.dtype)
+         if self.swish == 1.0:
+             y = F.silu(y)
+         elif self.swish:
+             # th.sigmoid instead of the deprecated F.sigmoid alias.
+             y = y * th.sigmoid(y * float(self.swish))
+         return y
+
+
+ def conv_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D convolution module.
+     """
+     if dims == 1:
+         return nn.Conv1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.Conv2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.Conv3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def linear(*args, **kwargs):
+     """
+     Create a linear module.
+     """
+     return nn.Linear(*args, **kwargs)
+
+
+ def avg_pool_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D average pooling module.
+     """
+     if dims == 1:
+         return nn.AvgPool1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.AvgPool2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.AvgPool3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def update_ema(target_params, source_params, rate=0.99):
+     """
+     Update target parameters to be closer to those of source parameters using
+     an exponential moving average.
+
+     :param target_params: the target parameter sequence.
+     :param source_params: the source parameter sequence.
+     :param rate: the EMA rate (closer to 1 means slower).
+     """
+     for targ, src in zip(target_params, source_params):
+         targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def scale_module(module, scale):
+     """
+     Scale the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().mul_(scale)
+     return module
+
+
+ def mean_flat(tensor):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+ def normalization(channels, swish=0.0):
+     """
+     Make a standard normalization layer.
+
+     :param channels: number of input channels.
+     :return: an nn.Module for normalization.
+     """
+     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
+
+
+ def timestep_embedding(timesteps, dim, max_period=10000):
+     """
+     Create sinusoidal timestep embeddings.
+
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     half = dim // 2
+     freqs = th.exp(
+         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+     ).to(device=timesteps.device)
+     args = timesteps[:, None].float() * freqs[None]
+     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+     return embedding
+
+
+ def checkpoint(func, inputs, params, flag):
+     """
+     Evaluate a function without caching intermediate activations, allowing for
+     reduced memory at the expense of extra compute in the backward pass.
+
+     :param func: the function to evaluate.
+     :param inputs: the argument sequence to pass to `func`.
+     :param params: a sequence of parameters `func` depends on but does not
+                    explicitly take as arguments.
+     :param flag: if False, disable gradient checkpointing.
+     """
+     if flag:
+         args = tuple(inputs) + tuple(params)
+         return CheckpointFunction.apply(func, len(inputs), *args)
+     else:
+         return func(*inputs)
+
+
+ class CheckpointFunction(th.autograd.Function):
+     @staticmethod
+     def forward(ctx, run_function, length, *args):
+         ctx.run_function = run_function
+         ctx.input_tensors = list(args[:length])
+         ctx.input_params = list(args[length:])
+         with th.no_grad():
+             output_tensors = ctx.run_function(*ctx.input_tensors)
+         return output_tensors
+
+     @staticmethod
+     def backward(ctx, *output_grads):
+         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+         with th.enable_grad():
+             # Fixes a bug where the first op in run_function modifies the
+             # Tensor storage in place, which is not allowed for detach()'d
+             # Tensors.
+             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+             output_tensors = ctx.run_function(*shallow_copies)
+         input_grads = th.autograd.grad(
+             output_tensors,
+             ctx.input_tensors + ctx.input_params,
+             output_grads,
+             allow_unused=True,
+         )
+         del ctx.input_tensors
+         del ctx.input_params
+         del output_tensors
+         return (None, None) + input_grads
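
A quick shape check of `timestep_embedding` (standalone apart from this module):

    import torch as th

    emb = timestep_embedding(th.tensor([0, 500, 999]), dim=192)
    # emb.shape == (3, 192): 96 cosine and 96 sine features per timestep,
    # at geometrically spaced frequencies between 1 and 1/max_period.
    assert emb.shape == (3, 192)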
glide_text2im/resample.py ADDED
@@ -0,0 +1,154 @@
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+ import torch as th
+ import torch.distributed as dist
+
+
+ def create_named_schedule_sampler(name, diffusion):
+     """
+     Create a ScheduleSampler from a library of pre-defined samplers.
+
+     :param name: the name of the sampler.
+     :param diffusion: the diffusion object to sample for.
+     """
+     if name == "uniform":
+         return UniformSampler(diffusion)
+     elif name == "loss-second-moment":
+         return LossSecondMomentResampler(diffusion)
+     else:
+         raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+ class ScheduleSampler(ABC):
+     """
+     A distribution over timesteps in the diffusion process, intended to reduce
+     variance of the objective.
+
+     By default, samplers perform unbiased importance sampling, in which the
+     objective's mean is unchanged.
+     However, subclasses may override sample() to change how the resampled
+     terms are reweighted, allowing for actual changes in the objective.
+     """
+
+     @abstractmethod
+     def weights(self):
+         """
+         Get a numpy array of weights, one per diffusion step.
+
+         The weights needn't be normalized, but must be positive.
+         """
+
+     def sample(self, batch_size, device):
+         """
+         Importance-sample timesteps for a batch.
+
+         :param batch_size: the number of timesteps.
+         :param device: the torch device to save to.
+         :return: a tuple (timesteps, weights):
+                  - timesteps: a tensor of timestep indices.
+                  - weights: a tensor of weights to scale the resulting losses.
+         """
+         w = self.weights()
+         p = w / np.sum(w)
+         indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+         indices = th.from_numpy(indices_np).long().to(device)
+         weights_np = 1 / (len(p) * p[indices_np])
+         weights = th.from_numpy(weights_np).float().to(device)
+         return indices, weights
+
+
+ class UniformSampler(ScheduleSampler):
+     def __init__(self, diffusion):
+         self.diffusion = diffusion
+         self._weights = np.ones([diffusion.num_timesteps])
+
+     def weights(self):
+         return self._weights
+
+
+ class LossAwareSampler(ScheduleSampler):
+     def update_with_local_losses(self, local_ts, local_losses):
+         """
+         Update the reweighting using losses from a model.
+
+         Call this method from each rank with a batch of timesteps and the
+         corresponding losses for each of those timesteps.
+         This method will perform synchronization to make sure all of the ranks
+         maintain the exact same reweighting.
+
+         :param local_ts: an integer Tensor of timesteps.
+         :param local_losses: a 1D Tensor of losses.
+         """
+         batch_sizes = [
+             th.tensor([0], dtype=th.int32, device=local_ts.device)
+             for _ in range(dist.get_world_size())
+         ]
+         dist.all_gather(
+             batch_sizes,
+             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+         )
+
+         # Pad all_gather batches to be the maximum batch size.
+         batch_sizes = [x.item() for x in batch_sizes]
+         max_bs = max(batch_sizes)
+
+         timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+         loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+         dist.all_gather(timestep_batches, local_ts)
+         dist.all_gather(loss_batches, local_losses)
+         timesteps = [
+             x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+         ]
+         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+         self.update_with_all_losses(timesteps, losses)
+
+     @abstractmethod
+     def update_with_all_losses(self, ts, losses):
+         """
+         Update the reweighting using losses from a model.
+
+         Sub-classes should override this method to update the reweighting
+         using losses from the model.
+
+         This method directly updates the reweighting without synchronizing
+         between workers. It is called by update_with_local_losses from all
+         ranks with identical arguments. Thus, it should have deterministic
+         behavior to maintain state across workers.
+
+         :param ts: a list of int timesteps.
+         :param losses: a list of float losses, one per timestep.
+         """
+
+
+ class LossSecondMomentResampler(LossAwareSampler):
+     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+         self.diffusion = diffusion
+         self.history_per_term = history_per_term
+         self.uniform_prob = uniform_prob
+         self._loss_history = np.zeros(
+             [diffusion.num_timesteps, history_per_term], dtype=np.float64
+         )
+         # np.int64 instead of the np.int alias removed in NumPy >= 1.24.
+         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+     def weights(self):
+         if not self._warmed_up():
+             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+         weights /= np.sum(weights)
+         weights *= 1 - self.uniform_prob
+         weights += self.uniform_prob / len(weights)
+         return weights
+
+     def update_with_all_losses(self, ts, losses):
+         for t, loss in zip(ts, losses):
+             if self._loss_counts[t] == self.history_per_term:
+                 # Shift out the oldest loss term.
+                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
+                 self._loss_history[t, -1] = loss
+             else:
+                 self._loss_history[t, self._loss_counts[t]] = loss
+                 self._loss_counts[t] += 1
+
+     def _warmed_up(self):
+         return (self._loss_counts == self.history_per_term).all()
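
A minimal usage sketch; the `diffusion` object here is a placeholder for anything exposing `num_timesteps`:

    import torch as th

    sampler = create_named_schedule_sampler("uniform", diffusion)
    t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))
    # t: 8 sampled timestep indices; weights: importance weights that keep
    # the weighted loss an unbiased estimate of the uniform-timestep loss.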
glide_text2im/respace.py ADDED
@@ -0,0 +1,134 @@
+ import numpy as np
+ import torch as th
+
+ from .gaussian_diffusion import GaussianDiffusion
+
+
+ def space_timesteps(num_timesteps, section_counts):
+     """
+     Create a list of timesteps to use from an original diffusion process,
+     given the number of timesteps we want to take from equally-sized portions
+     of the original process.
+
+     For example, if there's 300 timesteps and the section counts are [10,15,20]
+     then the first 100 timesteps are strided to be 10 timesteps, the second 100
+     are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+     If the stride is a string starting with "ddim", then the fixed striding
+     from the DDIM paper is used, and only one section is allowed.
+
+     :param num_timesteps: the number of diffusion steps in the original
+                           process to divide up.
+     :param section_counts: either a list of numbers, or a string containing
+                            comma-separated numbers, indicating the step count
+                            per section. As a special case, use "ddimN" where N
+                            is a number of steps to use the striding from the
+                            DDIM paper.
+     :return: a set of diffusion steps from the original process to use.
+     """
+     if isinstance(section_counts, str):
+         if section_counts.startswith("ddim"):
+             desired_count = int(section_counts[len("ddim"):])
+             for i in range(1, num_timesteps):
+                 if len(range(0, num_timesteps, i)) == desired_count:
+                     return set(range(0, num_timesteps, i))
+             raise ValueError(
+                 f"cannot create exactly {desired_count} steps with an integer stride"
+             )
+         elif section_counts == "fast27":
+             steps = space_timesteps(num_timesteps, "10,10,3,2,2")
+             # Help reduce DDIM artifacts from noisiest timesteps.
+             steps.remove(num_timesteps - 1)
+             steps.add(num_timesteps - 3)
+             return steps
+         section_counts = [int(x) for x in section_counts.split(",")]
+     size_per = num_timesteps // len(section_counts)
+     extra = num_timesteps % len(section_counts)
+     start_idx = 0
+     all_steps = []
+     for i, section_count in enumerate(section_counts):
+         size = size_per + (1 if i < extra else 0)
+         if size < section_count:
+             raise ValueError(
+                 f"cannot divide section of {size} steps into {section_count}"
+             )
+         if section_count <= 1:
+             frac_stride = 1
+         else:
+             frac_stride = (size - 1) / (section_count - 1)
+         cur_idx = 0.0
+         taken_steps = []
+         for _ in range(section_count):
+             taken_steps.append(start_idx + round(cur_idx))
+             cur_idx += frac_stride
+         all_steps += taken_steps
+         start_idx += size
+     return set(all_steps)
+
+
+ class SpacedDiffusion(GaussianDiffusion):
+     """
+     A diffusion process which can skip steps in a base diffusion process.
+
+     :param use_timesteps: a collection (sequence or set) of timesteps from the
+                           original diffusion process to retain.
+     :param kwargs: the kwargs to create the base diffusion process.
+     """
+
+     def __init__(self, use_timesteps, **kwargs):
+         self.use_timesteps = set(use_timesteps)
+         self.timestep_map = []
+         self.original_num_steps = len(kwargs["betas"])
+
+         base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
+         last_alpha_cumprod = 1.0
+         new_betas = []
+         for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+             if i in self.use_timesteps:
+                 new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+                 last_alpha_cumprod = alpha_cumprod
+                 self.timestep_map.append(i)
+         kwargs["betas"] = np.array(new_betas)
+         super().__init__(**kwargs)
+
+     def p_mean_variance(
+         self, model, *args, **kwargs
+     ):  # pylint: disable=signature-differs
+         return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+     def condition_mean(self, cond_fn, *args, **kwargs):
+         return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def condition_score(self, cond_fn, *args, **kwargs):
+         return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def training_losses(
+         self, model, *args, **kwargs
+     ):  # pylint: disable=signature-differs
+         return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+     def _wrap_model(self, model):
+         if isinstance(model, _WrappedModel):
+             return model
+         return _WrappedModel(
+             model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
+         )
+
+     def _scale_timesteps(self, t):
+         # Scaling is done by the wrapped model.
+         return t
+
+
+ class _WrappedModel:
+     def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+         self.model = model
+         self.timestep_map = timestep_map
+         self.rescale_timesteps = rescale_timesteps
+         self.original_num_steps = original_num_steps
+
+     def __call__(self, x, ts, **kwargs):
+         map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+         new_ts = map_tensor[ts]
+         if self.rescale_timesteps:
+             new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+         return self.model(x, new_ts, **kwargs)
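
The docstring example for `space_timesteps` can be checked directly:

    steps = space_timesteps(300, "10,15,20")
    assert len(steps) == 45     # 10 + 15 + 20 steps over three 100-step sections

    steps = space_timesteps(1000, "ddim50")
    assert steps == set(range(0, 1000, 20))   # fixed DDIM stride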
glide_text2im/script_util.py ADDED
@@ -0,0 +1,281 @@
+ import argparse
+ import inspect
+
+ from . import gaussian_diffusion as gd
+ from .respace import SpacedDiffusion, space_timesteps
+ from .text2im_model import (
+     SuperResText2ImModel,
+     Text2ImModel,
+ )
+
+
+ def model_and_diffusion_defaults(super_res=0):
+     """
+     Defaults for image training.
+     """
+     result = dict(
+         image_size=64,
+         num_channels=192,
+         num_res_blocks=3,
+         channel_mult="",
+         num_heads=1,
+         num_head_channels=64,
+         num_heads_upsample=-1,
+         attention_resolutions="32,16,8",
+         dropout=0.1,
+         text_ctx=128,
+         xf_width=512,
+         xf_layers=16,
+         xf_heads=8,
+         xf_final_ln=True,
+         xf_padding=True,
+         learn_sigma=True,  ##
+         sigma_small=False,  ##
+         diffusion_steps=1000,
+         noise_schedule="squaredcos_cap_v2",
+         timestep_respacing="",
+         use_kl=False,  ##
+         predict_xstart=False,
+         rescale_timesteps=True,
+         rescale_learned_sigmas=True,
+         use_fp16=False,  ##
+         use_scale_shift_norm=True,
+         resblock_updown=True,
+         cache_text_emb=False,
+         inpaint=False,
+         super_res=0,
+         mode='',
+     )
+     if super_res:
+         result.update(
+             dict(
+                 image_size=256,
+                 num_res_blocks=2,
+                 noise_schedule="linear",
+                 super_res=super_res,
+             )
+         )
+     return result
+
+
+ def create_model_and_diffusion(
+     image_size=64,
+     num_channels=192,
+     num_res_blocks=3,
+     channel_mult="",
+     num_heads=1,
+     num_head_channels=64,
+     num_heads_upsample=-1,
+     attention_resolutions="32,16,8",
+     dropout=0.1,
+     text_ctx=128,
+     xf_width=512,
+     xf_layers=16,
+     xf_heads=8,
+     xf_final_ln=True,
+     xf_padding=True,
+     learn_sigma=False,  ##
+     sigma_small=False,  ##
+     diffusion_steps=1000,
+     noise_schedule="squaredcos_cap_v2",
+     timestep_respacing="",
+     use_kl=False,  ##
+     predict_xstart=False,
+     rescale_timesteps=True,
+     rescale_learned_sigmas=True,
+     use_fp16=False,  ##
+     use_scale_shift_norm=True,
+     resblock_updown=True,
+     cache_text_emb=False,
+     inpaint=False,
+     super_res=False,
+     mode='',
+ ):
+     model = create_model(
+         image_size,
+         num_channels,
+         num_res_blocks,
+         learn_sigma=learn_sigma,
+         channel_mult=channel_mult,
+         use_fp16=use_fp16,
+         attention_resolutions=attention_resolutions,
+         num_heads=num_heads,
+         num_head_channels=num_head_channels,
+         num_heads_upsample=num_heads_upsample,
+         use_scale_shift_norm=use_scale_shift_norm,
+         dropout=dropout,
+         text_ctx=text_ctx,
+         xf_width=xf_width,
+         xf_layers=xf_layers,
+         xf_heads=xf_heads,
+         xf_final_ln=xf_final_ln,
+         xf_padding=xf_padding,
+         resblock_updown=resblock_updown,
+         cache_text_emb=cache_text_emb,
+         inpaint=inpaint,
+         super_res=super_res,
+         mode=mode,
+     )
+     diffusion = create_gaussian_diffusion(
+         steps=diffusion_steps,
+         learn_sigma=learn_sigma,
+         sigma_small=sigma_small,
+         noise_schedule=noise_schedule,
+         use_kl=use_kl,
+         predict_xstart=predict_xstart,
+         rescale_timesteps=rescale_timesteps,
+         rescale_learned_sigmas=rescale_learned_sigmas,
+         timestep_respacing=timestep_respacing,
+     )
+     return model, diffusion
+
+
+ def create_model(
+     image_size,
+     num_channels,
+     num_res_blocks,
+     learn_sigma,
+     channel_mult,
+     use_fp16,
+     attention_resolutions,
+     num_heads,
+     num_head_channels,
+     num_heads_upsample,
+     use_scale_shift_norm,
+     dropout,
+     text_ctx,
+     xf_width,
+     xf_layers,
+     xf_heads,
+     xf_final_ln,
+     xf_padding,
+     resblock_updown,
+     cache_text_emb,
+     inpaint,
+     super_res,
+     mode,
+ ):
+     if channel_mult == "":
+         if image_size == 256:
+             channel_mult = (1, 1, 2, 2, 4, 4)
+         elif image_size == 128:
+             channel_mult = (1, 1, 2, 3, 4)
+         elif image_size == 64:
+             channel_mult = (1, 2, 3, 4)
+         else:
+             raise ValueError(f"unsupported image size: {image_size}")
+     else:
+         channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
+         assert 2 ** (len(channel_mult) + 2) == image_size
+
+     attention_ds = []
+     for res in attention_resolutions.split(","):
+         attention_ds.append(image_size // int(res))
+
+     if super_res:
+         model_cls = SuperResText2ImModel
+     else:
+         model_cls = Text2ImModel
+
+     n_class = 3
+     if mode == 'ade20k' or mode == 'coco':
+         n_class = 3
+     elif mode == 'depth-normal':
+         n_class = 6
+     elif mode == 'coco-edge' or mode == 'flickr-edge':
+         n_class = 1
+
+     return model_cls(
+         text_ctx=text_ctx,
+         xf_width=xf_width,
+         xf_layers=xf_layers,
+         xf_heads=xf_heads,
+         xf_final_ln=xf_final_ln,
+         model_channels=num_channels,
+         out_channels=(3 if not learn_sigma else 6),
+         num_res_blocks=num_res_blocks,
+         attention_resolutions=tuple(attention_ds),
+         dropout=dropout,
+         channel_mult=channel_mult,
+         use_fp16=use_fp16,
+         num_heads=num_heads,
+         num_heads_upsample=num_heads_upsample,
+         num_head_channels=num_head_channels,
+         use_scale_shift_norm=use_scale_shift_norm,
+         resblock_updown=resblock_updown,
+         in_channels=3,
+         n_class=n_class,
+         image_size=image_size,
+     )
+
+
+ def create_gaussian_diffusion(
+     *,
+     steps=1000,
+     learn_sigma=False,
+     sigma_small=False,
+     noise_schedule="linear",
+     use_kl=False,
+     predict_xstart=False,
+     rescale_timesteps=False,
+     rescale_learned_sigmas=False,
+     timestep_respacing="",
+ ):
+     betas = gd.get_named_beta_schedule(noise_schedule, steps)
+     if use_kl:
+         loss_type = gd.LossType.RESCALED_KL
+     elif rescale_learned_sigmas:
+         loss_type = gd.LossType.RESCALED_MSE
+     else:
+         loss_type = gd.LossType.MSE
+     if not timestep_respacing:
+         timestep_respacing = [steps]
+     return SpacedDiffusion(
+         use_timesteps=space_timesteps(steps, timestep_respacing),
+         betas=betas,
+         model_mean_type=(
+             gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+         ),
+         model_var_type=(
+             (
+                 gd.ModelVarType.FIXED_LARGE
+                 if not sigma_small
+                 else gd.ModelVarType.FIXED_SMALL
+             )
+             if not learn_sigma
+             else gd.ModelVarType.LEARNED_RANGE
+         ),
+         loss_type=loss_type,
+         rescale_timesteps=rescale_timesteps,
+     )
+
+
+ def add_dict_to_argparser(parser, default_dict):
+     for k, v in default_dict.items():
+         v_type = type(v)
+         if v is None:
+             v_type = str
+         elif isinstance(v, bool):
+             v_type = str2bool
+         parser.add_argument(f"--{k}", default=v, type=v_type)
+
+
+ def args_to_dict(args, keys=None):
+     if keys is None:
+         keys = vars(args)
+     return {k: getattr(args, k) for k in keys}
+
+
+ def str2bool(v):
+     """
+     https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+     """
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ("yes", "true", "t", "y", "1"):
+         return True
+     elif v.lower() in ("no", "false", "f", "n", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError("boolean value expected")
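
A minimal sketch of the defaults-to-argparse round trip these helpers implement, mirroring how the repo's scripts build their models:

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, model_and_diffusion_defaults())
    args = parser.parse_args(["--use_fp16", "true"])   # str2bool handles booleans
    options = args_to_dict(args, model_and_diffusion_defaults().keys())
    model, diffusion = create_model_and_diffusion(**options)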
glide_text2im/text2im_model.py ADDED
@@ -0,0 +1,212 @@
+ import torch as th
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import random
+ from .nn import timestep_embedding
+ from .unet import UNetModel
+ from .xf import LayerNorm, Transformer, convert_module_to_f16
+ from timm.models.vision_transformer import PatchEmbed
+
+
+ class Text2ImModel(nn.Module):
+     def __init__(
+         self,
+         text_ctx,
+         xf_width,
+         xf_layers,
+         xf_heads,
+         xf_final_ln,
+         model_channels,
+         out_channels,
+         num_res_blocks,
+         attention_resolutions,
+         dropout,
+         channel_mult,
+         use_fp16,
+         num_heads,
+         num_heads_upsample,
+         num_head_channels,
+         use_scale_shift_norm,
+         resblock_updown,
+         in_channels=3,
+         n_class=3,
+         image_size=64,
+     ):
+         super().__init__()
+         self.encoder = Encoder(img_size=image_size, patch_size=image_size // 16, in_chans=n_class,
+                                xf_width=xf_width, xf_layers=8, xf_heads=xf_heads, model_channels=model_channels)
+
+         self.in_channels = in_channels
+         self.decoder = Text2ImUNet(
+             in_channels,
+             model_channels,
+             out_channels,
+             num_res_blocks,
+             attention_resolutions,
+             dropout=dropout,
+             channel_mult=channel_mult,
+             use_fp16=use_fp16,
+             num_heads=num_heads,
+             num_heads_upsample=num_heads_upsample,
+             num_head_channels=num_head_channels,
+             use_scale_shift_norm=use_scale_shift_norm,
+             resblock_updown=resblock_updown,
+             encoder_channels=xf_width,
+         )
+
+     def forward(self, xt, timesteps, ref=None, uncond_p=0.0):
+         latent_outputs = self.encoder(ref, uncond_p)
+         pred = self.decoder(xt, timesteps, latent_outputs)
+         return pred
+
+
+ class Text2ImUNet(UNetModel):
+     def __init__(
+         self,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self.transformer_proj = nn.Linear(512, self.model_channels * 4)  ###
+
+     def forward(self, x, timesteps, latent_outputs):
+         hs = []
+         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+         xf_proj, xf_out = latent_outputs["xf_proj"], latent_outputs["xf_out"]
+
+         xf_proj = self.transformer_proj(xf_proj)  ###
+         emb = emb + xf_proj.to(emb)
+
+         h = x.type(self.dtype)
+         for module in self.input_blocks:
+             h = module(h, emb, xf_out)
+             hs.append(h)
+         h = self.middle_block(h, emb, xf_out)
+         for module in self.output_blocks:
+             h = th.cat([h, hs.pop()], dim=1)
+             h = module(h, emb, xf_out)
+         h = h.type(x.dtype)
+         h = self.out(h)
+         return h
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         img_size,
+         patch_size,
+         in_chans,
+         xf_width,
+         xf_layers,
+         xf_heads,
+         model_channels,
+     ):
+         super().__init__()
+         self.transformer = Transformer(
+             xf_width,
+             xf_layers,
+             xf_heads,
+         )
+
+         self.cnn = CNN(in_chans)
+         self.final_ln = LayerNorm(xf_width)
+
+         self.cls_token = nn.Parameter(th.empty(1, 1, xf_width, dtype=th.float32))
+         self.positional_embedding = nn.Parameter(th.empty(1, 256 + 1, xf_width, dtype=th.float32))
+
+     def forward(self, ref, uncond_p=0.0):
+         x = self.cnn(ref)
+         x = x.flatten(2).transpose(1, 2)
+
+         x = x + self.positional_embedding[:, 1:, :]
+
+         cls_token = self.cls_token + self.positional_embedding[:, :1, :]
+         cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+         x = th.cat((x, cls_tokens), dim=1)
+
+         xf_out = self.transformer(x)
+         if self.final_ln is not None:
+             xf_out = self.final_ln(xf_out)
+
+         xf_proj = xf_out[:, -1]
+         xf_out = xf_out[:, :-1].permute(0, 2, 1)  # NLC -> NCL
+
+         outputs = dict(xf_proj=xf_proj, xf_out=xf_out)
+         return outputs
+
+
+ class SuperResText2ImModel(Text2ImModel):
+     """
+     A text2im model that performs super-resolution.
+     Expects an extra kwarg `low_res` to condition on a low-resolution image.
+     """
+
+     def __init__(self, *args, **kwargs):
+         if "in_channels" in kwargs:
+             kwargs = dict(kwargs)
+             kwargs["in_channels"] = kwargs["in_channels"] * 2
+         else:
+             # Curse you, Python. Or really, just curse positional arguments :|.
+             args = list(args)
+             args[1] = args[1] * 2
+         super().__init__(*args, **kwargs)
+
+     def forward(self, x, timesteps, low_res=None, **kwargs):
+         _, _, new_height, new_width = x.shape
+         upsampled = F.interpolate(
+             low_res, (new_height, new_width), mode="bilinear", align_corners=False
+         )
+
+         # ##########
+         # upsampled = upsampled + th.randn_like(upsampled)*0.0005*th.log(1 + 0.1* timesteps.reshape(timesteps.shape[0], 1,1,1))
+         # ##########
+
+         x = th.cat([x, upsampled], dim=1)
+         return super().forward(x, timesteps, **kwargs)
+
+
+ def conv3x3(in_channels, out_channels, stride=1):
+     return nn.Conv2d(in_channels, out_channels, kernel_size=3,
+                      stride=stride, padding=1, bias=True)
+
+
+ def conv7x7(in_channels, out_channels, stride=1):
+     return nn.Conv2d(in_channels, out_channels, kernel_size=7,
+                      stride=stride, padding=3, bias=True)
+
+
+ class CNN(nn.Module):
+     def __init__(self, in_channels=3):
+         super(CNN, self).__init__()
+         self.conv1 = conv7x7(in_channels, 32)  # 256
+         self.norm1 = nn.InstanceNorm2d(32, affine=True)
+         self.LReLU1 = nn.LeakyReLU(0.2)
+
+         self.conv2 = conv3x3(32, 64, 2)  # 128
+         self.norm2 = nn.InstanceNorm2d(64, affine=True)
+         self.LReLU2 = nn.LeakyReLU(0.2)
+
+         self.conv3 = conv3x3(64, 128, 2)  # 64
+         self.norm3 = nn.InstanceNorm2d(128, affine=True)
+         self.LReLU3 = nn.LeakyReLU(0.2)
+
+         self.conv4 = conv3x3(128, 256, 2)  # 32
+         self.norm4 = nn.InstanceNorm2d(256, affine=True)
+         self.LReLU4 = nn.LeakyReLU(0.2)
+
+         self.conv5 = conv3x3(256, 512, 2)  # 16
+         self.norm5 = nn.InstanceNorm2d(512, affine=True)
+         self.LReLU5 = nn.LeakyReLU(0.2)
+
+         self.conv6 = conv3x3(512, 512, 1)
+
+     def forward(self, x):
+         x = self.LReLU1(self.norm1(self.conv1(x)))
+         x = self.LReLU2(self.norm2(self.conv2(x)))
+         x = self.LReLU3(self.norm3(self.conv3(x)))
+         x = self.LReLU4(self.norm4(self.conv4(x)))
+         x = self.LReLU5(self.norm5(self.conv5(x)))
+         x = self.conv6(x)
+         return x
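
For a 256x256 reference, the `CNN` above halves resolution four times (256 -> 128 -> 64 -> 32 -> 16), so `Encoder.forward` flattens a 16x16, 512-channel map into 16 * 16 = 256 tokens; with the class token appended, this matches the `256 + 1` positional embedding. A quick shape check of the CNN alone:

    import torch as th

    feats = CNN(in_channels=1)(th.randn(2, 1, 256, 256))
    assert feats.shape == (2, 512, 16, 16)    # -> 256 tokens after flattening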
glide_text2im/train_util.py ADDED
@@ -0,0 +1,475 @@
+ import copy
+ import functools
+ import os
+
+ import blobfile as bf
+ import numpy as np
+ import torch as th
+ import torch.distributed as dist
+ from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+ from torch.optim import AdamW
+ from .glide_util import sample
+ from . import dist_util, logger
+ from .fp16_util import (
+     make_master_params,
+     master_params_to_model_params,
+     model_grads_to_master_grads,
+     unflatten_master_params,
+     zero_grad,
+ )
+ from .nn import update_ema
+ from .vgg import VGG
+ from .adv import AdversarialLoss
+ from .resample import LossAwareSampler, UniformSampler
+ import glob
+ import torchvision.utils as tvu
+ import PIL.Image as Image
+
+ # For ImageNet experiments, this was a good default value.
+ # We found that the lg_loss_scale quickly climbed to
+ # 20-21 within the first ~1K steps of training.
+ INITIAL_LOG_LOSS_SCALE = 20.0
+
+
+ class TrainLoop:
+     def __init__(
+         self,
+         model,
+         glide_options,
+         diffusion,
+         data,
+         val_data,
+         batch_size,
+         microbatch,
+         lr,
+         ema_rate,
+         log_interval,
+         save_interval,
+         resume_checkpoint,
+         use_fp16=False,
+         fp16_scale_growth=1e-3,
+         schedule_sampler=None,
+         weight_decay=0.0,
+         lr_anneal_steps=0,
+         finetune_decoder=False,
+         mode='',
+         use_vgg=False,
+         use_gan=False,
+         uncond_p=0,
+         super_res=0,
+     ):
+         self.model = model
+         self.glide_options = glide_options
+         self.diffusion = diffusion
+         self.data = data
+         self.val_data = val_data
+         self.batch_size = batch_size
+         self.microbatch = microbatch if microbatch > 0 else batch_size
+         self.lr = lr
+         self.ema_rate = (
+             [ema_rate]
+             if isinstance(ema_rate, float)
+             else [float(x) for x in ema_rate.split(",")]
+         )
+         self.log_interval = log_interval
+         self.save_interval = save_interval
+         self.resume_checkpoint = find_resume_checkpoint(resume_checkpoint)
+         self.use_fp16 = use_fp16
+         self.fp16_scale_growth = fp16_scale_growth
+         self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
+         self.weight_decay = weight_decay
+         self.lr_anneal_steps = lr_anneal_steps
+         self.step = 0
+         self.resume_step = 0
+         self.global_batch = self.batch_size * dist.get_world_size()
+
+         if use_vgg:
+             self.vgg = VGG(conv_index='22').to(dist_util.dev())
+             print('use perc')
+         else:
+             self.vgg = None
+
+         if use_gan:
+             self.adv = AdversarialLoss()
+             print('use adv')
+         else:
+             self.adv = None
+
+         self.super_res = super_res
+
+         self.uncond_p = uncond_p
+         self.mode = mode
+
+         self.finetune_decoder = finetune_decoder
+         if finetune_decoder:
+             self.optimize_model = self.model
+         else:
+             self.optimize_model = self.model.encoder
+
+         self.model_params = list(self.optimize_model.parameters())
+         self.master_params = self.model_params
+         self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
+         self.sync_cuda = th.cuda.is_available()
+         self._load_and_sync_parameters()
+         if self.use_fp16:
+             self._setup_fp16()
+
+         self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
+         if self.resume_step:
+             self._load_optimizer_state()
+             # Model was resumed, either due to a restart or a checkpoint
+             # being specified at the command line.
+             self.ema_params = [
+                 self._load_ema_parameters(rate) for rate in self.ema_rate
+             ]
+         else:
+             self.ema_params = [
+                 copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
+             ]
+
+         if th.cuda.is_available():
+             self.use_ddp = True
+             self.ddp_model = DDP(
+                 self.model,
+                 device_ids=[dist_util.dev()],
+                 output_device=dist_util.dev(),
+                 broadcast_buffers=False,
+                 bucket_cap_mb=128,
+                 find_unused_parameters=False,
+             )
+         else:
+             if dist.get_world_size() > 1:
+                 logger.warn(
+                     "Distributed training requires CUDA. "
+                     "Gradients will not be synchronized properly!"
+                 )
+             self.use_ddp = False
+             self.ddp_model = self.model
+
+     def _load_and_sync_parameters(self):
+         resume_checkpoint = self.resume_checkpoint
+
+         if resume_checkpoint:
+             self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
+             if dist.get_rank() == 0:
+                 logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
+             self.model.load_state_dict(th.load(resume_checkpoint, map_location="cpu"), strict=False)
+
+         dist_util.sync_params(self.model.parameters())
+
+     def _load_ema_parameters(self, rate):
+         ema_params = copy.deepcopy(self.master_params)
+
+         main_checkpoint = self.resume_checkpoint
+         ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
+         if ema_checkpoint:
+             if dist.get_rank() == 0:
+                 logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
+             state_dict = th.load(ema_checkpoint, map_location=dist_util.dev())
+             ema_params = self._state_dict_to_master_params(state_dict)
+
+         # dist_util.sync_params(ema_params)
+         return ema_params
+
+     def _load_optimizer_state(self):
+         main_checkpoint = self.resume_checkpoint
+         opt_checkpoint = bf.join(
+             bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
+         )
+         if bf.exists(opt_checkpoint):
+             logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
+             state_dict = th.load(opt_checkpoint, map_location="cpu")
+             try:
+                 self.opt.load_state_dict(state_dict)
+             except Exception:
+                 logger.log("optimizer state is incompatible; starting from a fresh optimizer")
+
+     def _setup_fp16(self):
+         self.master_params = make_master_params(self.model_params)
+         self.model.convert_to_fp16()
+
+     def run_loop(self):
+         while (
+             not self.lr_anneal_steps
+             or self.step <= self.lr_anneal_steps
+         ):
+             batch, model_kwargs = next(self.data)
+
+             # uncond_p = 0
+             # if self.super_res:
+             #     uncond_p = 0
+             # elif self.finetune_decoder:
+             #     uncond_p = self.uncond_p
+             # elif self.step > self.lr_anneal_steps - 40000:
+             #     uncond_p = self.uncond_p
+
+             self.run_step(batch, model_kwargs)
+             if self.step % self.log_interval == 0:
+                 logger.dumpkvs()
+             if self.step % self.save_interval == 0:
+                 self.save()
+                 self.val(self.step)
+             self.step += 1
+
+         if (self.step - 1) % self.save_interval != 0:
+             self.save()
+
+     def run_step(self, batch, model_kwargs):
+         self.forward_backward(batch, model_kwargs)
+         if self.use_fp16:
+             self.optimize_fp16()
+         else:
+             self.optimize_normal()
+         self.log_step()
+
+     def forward_backward(self, batch, model_kwargs):
+         zero_grad(self.model_params)
+         for i in range(0, batch.shape[0], self.microbatch):
+             micro = batch[i : i + self.microbatch].to(dist_util.dev())
+             micro_cond = {
+                 n: model_kwargs[n][i : i + self.microbatch].to(dist_util.dev())
+                 for n in model_kwargs
+                 if n in ['ref', 'low_res']
+             }
+             last_batch = (i + self.microbatch) >= batch.shape[0]
+             t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
+
+             # Warm up on the plain diffusion loss; the perceptual (VGG) and
+             # adversarial terms are only enabled after the first 100 steps.
+             if self.step < 100:
+                 vgg_loss = None
+                 adv_loss = None
+             else:
+                 vgg_loss = self.vgg
+                 adv_loss = self.adv
+             compute_losses = functools.partial(
+                 self.diffusion.training_losses,
+                 self.ddp_model,
+                 micro,
+                 t,
+                 vgg_loss,
+                 adv_loss,
+                 model_kwargs=micro_cond,
+             )
+
+             if last_batch or not self.use_ddp:
+                 losses = compute_losses()
+             else:
+                 with self.ddp_model.no_sync():
+                     losses = compute_losses()
+
+             if isinstance(self.schedule_sampler, LossAwareSampler):
+                 self.schedule_sampler.update_with_local_losses(
+                     t, losses["loss"].detach()
+                 )
+
+             loss = (losses["loss"] * weights).mean()
+             log_loss_dict(
+                 self.diffusion, t, {k: v * weights for k, v in losses.items()}
+             )
+             if self.use_fp16:
+                 loss_scale = 2 ** self.lg_loss_scale
+                 (loss * loss_scale).backward()
+             else:
+                 loss.backward()
+
+     def val(self, step):
+         # DDP wraps the network in `.module`; fall back to the bare model
+         # when training without CUDA/DDP.
+         inner_model = self.ddp_model.module if self.use_ddp else self.ddp_model
+         inner_model.eval()
+         if dist.get_rank() == 0:
+             print("sampling...")
+
+         s_path = os.path.join(logger.get_dir(), 'results')
+         os.makedirs(s_path, exist_ok=True)
+         img_id = 0
+         guidance_scale = self.glide_options['sample_c']
+
+         while True:
+             if img_id >= self.glide_options['num_samples']:
+                 break
+
+             batch, model_kwargs = next(self.val_data)
+             with th.no_grad():
+                 samples = sample(
+                     glide_model=inner_model,
+                     glide_options=self.glide_options,
+                     side_x=self.glide_options['image_size'],
+                     side_y=self.glide_options['image_size'],
+                     prompt=model_kwargs,
+                     batch_size=self.glide_options['batch_size'] // 2,
+                     guidance_scale=guidance_scale,
+                     device=dist_util.dev(),
+                     prediction_respacing=self.glide_options['sample_respacing'],
+                     upsample_enabled=self.glide_options['super_res'],
+                     upsample_temp=0.997,
+                     mode=self.mode,
+                 )
+
+             samples = samples.cpu()
+
+             ref = model_kwargs['ref_ori']
+             # LR = model_kwargs['low_res'].cpu()
+
+             for i in range(samples.size(0)):
+                 out_path = os.path.join(s_path, f"{dist.get_rank()}_{img_id}_step{step}_{guidance_scale}_output.png")
+                 tvu.save_image((samples[i] + 1) * 0.5, out_path)
+
+                 out_path = os.path.join(s_path, f"{dist.get_rank()}_{img_id}_step{step}_{guidance_scale}_gt.png")
+                 tvu.save_image((batch[i] + 1) * 0.5, out_path)
+
+                 out_path = os.path.join(s_path, f"{dist.get_rank()}_{img_id}_step{step}_{guidance_scale}_ref.png")
+                 tvu.save_image((ref[i] + 1) * 0.5, out_path)
+
+                 # out_path = os.path.join(s_path, f"{dist.get_rank()}_{img_id}_step{step}_{guidance_scale}_lr.png")
+                 # tvu.save_image((LR[i] + 1) * 0.5, out_path)
+
+                 img_id += 1
+         inner_model.train()
+
+     def optimize_fp16(self):
+         if any(not th.isfinite(p.grad).all() for p in self.model_params):
+             self.lg_loss_scale -= 1
+             logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
+             return
+
+         model_grads_to_master_grads(self.model_params, self.master_params)
+         self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
+         self._log_grad_norm()
+         self._anneal_lr()
+         self.opt.step()
+         for rate, params in zip(self.ema_rate, self.ema_params):
+             update_ema(params, self.master_params, rate=rate)
+         master_params_to_model_params(self.model_params, self.master_params)
+         self.lg_loss_scale += self.fp16_scale_growth
+
+     def optimize_normal(self):
+         self._log_grad_norm()
+         self._anneal_lr()
+         self.opt.step()
+         for rate, params in zip(self.ema_rate, self.ema_params):
+             update_ema(params, self.master_params, rate=rate)
+
+     def _log_grad_norm(self):
+         sqsum = 0.0
+         for p in self.master_params:
+             sqsum += (p.grad ** 2).sum().item()
+         logger.logkv_mean("grad_norm", np.sqrt(sqsum))
+
+     def _anneal_lr(self):
+         # Learning-rate annealing is intentionally disabled.
+         return
+
+     def log_step(self):
+         logger.logkv("step", self.step + self.resume_step)
+         logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
+         if self.use_fp16:
+             logger.logkv("lg_loss_scale", self.lg_loss_scale)
+
+     def save(self):
+         def save_checkpoint(rate, params):
+             state_dict = self._master_params_to_state_dict(params)
+             if dist.get_rank() == 0:
+                 logger.log(f"saving model {rate}...")
+                 if not rate:
+                     filename = f"model{(self.step+self.resume_step):06d}.pt"
+                 else:
+                     filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
+                 with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
+                     th.save(state_dict, f)
+
+         save_checkpoint(0, self.master_params)
+         for rate, params in zip(self.ema_rate, self.ema_params):
+             save_checkpoint(rate, params)
+
+         if dist.get_rank() == 0:
+             with bf.BlobFile(
+                 bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
+                 "wb",
+             ) as f:
+                 th.save(self.opt.state_dict(), f)
+
+         dist.barrier()
+
+     def _master_params_to_state_dict(self, master_params):
+         if self.use_fp16:
+             master_params = unflatten_master_params(
+                 list(self.optimize_model.parameters()), master_params
+             )
+         state_dict = self.optimize_model.state_dict()
+         for i, (name, _value) in enumerate(self.optimize_model.named_parameters()):
+             assert name in state_dict
+             state_dict[name] = master_params[i]
+         return state_dict
+
+     def _state_dict_to_master_params(self, state_dict):
+         params = [state_dict[name] for name, _ in self.optimize_model.named_parameters()]
+         if self.use_fp16:
+             return make_master_params(params)
+         else:
+             return params
+
+
+ def parse_resume_step_from_filename(filename):
+     """
+     Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
+     checkpoint's number of steps.
+     """
+     filename = filename.split('/')[-1]
+     assert filename.endswith(".pt")
+     filename = filename[:-3]
+     if filename.startswith("model"):
+         split = filename[5:]
+     elif filename.startswith("ema"):
+         split = filename.split("_")[-1]
+     else:
+         return 0
+     try:
+         return int(split)
+     except ValueError:
+         return 0
+
+
+ def get_blob_logdir():
+     p = os.path.join(logger.get_dir(), "checkpoints")
+     os.makedirs(p, exist_ok=True)
+     return p
+
+
+ def find_resume_checkpoint(resume_checkpoint):
+     # On your infrastructure, you may want to override this to automatically
+     # discover the latest checkpoint on your blob storage, etc.
+     if not resume_checkpoint:
+         return None
+     if "ROOT" in resume_checkpoint:
+         maybe_root = os.environ.get("AMLT_MAP_INPUT_DIR")
+         maybe_root = "OUTPUT/log" if not maybe_root else maybe_root
+         root = os.path.join(maybe_root, "checkpoints")
+         resume_checkpoint = resume_checkpoint.replace("ROOT", root)
+     if "LATEST" in resume_checkpoint:
+         files = glob.glob(resume_checkpoint.replace("LATEST", "*.pt"))
+         if not files:
+             return None
+         return max(files, key=parse_resume_step_from_filename)
+     return resume_checkpoint
+
+
+ def find_ema_checkpoint(main_checkpoint, step, rate):
+     if main_checkpoint is None:
+         return None
+     filename = f"ema_{rate}_{(step):06d}.pt"
+     path = bf.join(bf.dirname(main_checkpoint), filename)
+     if bf.exists(path):
+         return path
+     return None
+
+
+ def log_loss_dict(diffusion, ts, losses):
+     for key, values in losses.items():
+         logger.logkv_mean(key, values.mean().item())
+         # Log the quantiles (four quartiles, in particular).
+         for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
+             quartile = int(4 * sub_t / diffusion.num_timesteps)
+             logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
+
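
The fp16 path above is classic dynamic loss scaling: the loss is multiplied by 2 ** lg_loss_scale before backward, the scale drops by one whenever a non-finite gradient appears (and the step is skipped), and it creeps back up by fp16_scale_growth after every successful step. A minimal self-contained sketch of the same scheme, stripped of the master-parameter and AdamW machinery; the names `toy_params` and `toy_step` are illustrative, not part of this repository:

    import torch

    lg_loss_scale = 20.0            # matches INITIAL_LOG_LOSS_SCALE above
    fp16_scale_growth = 1e-3
    toy_params = [torch.randn(8, 8, dtype=torch.float16, requires_grad=True)]

    def toy_step(lr=1e-4):
        global lg_loss_scale
        loss = sum((p.float() ** 2).sum() for p in toy_params)
        (loss * 2 ** lg_loss_scale).backward()       # scaled backward pass
        if any(not torch.isfinite(p.grad).all() for p in toy_params):
            lg_loss_scale -= 1                       # overflow: shrink scale, skip step
            for p in toy_params:
                p.grad = None
            return
        for p in toy_params:
            grad = p.grad.float() / 2 ** lg_loss_scale   # un-scale in fp32
            p.data -= lr * grad.to(p.dtype)
            p.grad = None
        lg_loss_scale += fp16_scale_growth           # slowly probe a larger scale
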
glide_text2im/unet.py ADDED
@@ -0,0 +1,635 @@
+ import math
+ from abc import abstractmethod
+
+ import torch as th
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
+ from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module
+
+
+ class TimestepBlock(nn.Module):
+     """
+     Any module where forward() takes timestep embeddings as a second argument.
+     """
+
+     @abstractmethod
+     def forward(self, x, emb):
+         """
+         Apply the module to `x` given `emb` timestep embeddings.
+         """
+
+
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+     """
+     A sequential module that passes timestep embeddings to the children that
+     support it as an extra input.
+     """
+
+     def forward(self, x, emb, encoder_out=None):
+         for layer in self:
+             if isinstance(layer, TimestepBlock):
+                 x = layer(x, emb)
+             elif isinstance(layer, AttentionBlock):
+                 x = layer(x, encoder_out)
+             else:
+                 x = layer(x)
+         return x
+
+
+ class Upsample(nn.Module):
+     """
+     An upsampling layer with an optional convolution.
+
+     :param channels: channels in the inputs and outputs.
+     :param use_conv: a bool determining if a convolution is applied.
+     :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+         upsampling occurs in the inner-two dimensions.
+     """
+
+     def __init__(self, channels, use_conv, dims=2, out_channels=None):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.dims = dims
+         if use_conv:
+             self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+     def forward(self, x):
+         assert x.shape[1] == self.channels
+         if self.dims == 3:
+             x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
+         else:
+             x = F.interpolate(x, scale_factor=2, mode="nearest")
+         if self.use_conv:
+             x = self.conv(x)
+         return x
+
+
+ class Downsample(nn.Module):
+     """
+     A downsampling layer with an optional convolution.
+
+     :param channels: channels in the inputs and outputs.
+     :param use_conv: a bool determining if a convolution is applied.
+     :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+         downsampling occurs in the inner-two dimensions.
+     """
+
+     def __init__(self, channels, use_conv, dims=2, out_channels=None):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.dims = dims
+         stride = 2 if dims != 3 else (1, 2, 2)
+         if use_conv:
+             self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
+         else:
+             assert self.channels == self.out_channels
+             self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+     def forward(self, x):
+         assert x.shape[1] == self.channels
+         return self.op(x)
+
+
+ class ResBlock(TimestepBlock):
+     """
+     A residual block that can optionally change the number of channels.
+
+     :param channels: the number of input channels.
+     :param emb_channels: the number of timestep embedding channels.
+     :param dropout: the rate of dropout.
+     :param out_channels: if specified, the number of out channels.
+     :param use_conv: if True and out_channels is specified, use a spatial
+         convolution instead of a smaller 1x1 convolution to change the
+         channels in the skip connection.
+     :param dims: determines if the signal is 1D, 2D, or 3D.
+     :param use_checkpoint: if True, use gradient checkpointing on this module.
+     :param up: if True, use this block for upsampling.
+     :param down: if True, use this block for downsampling.
+     """
+
+     def __init__(
+         self,
+         channels,
+         emb_channels,
+         dropout,
+         out_channels=None,
+         use_conv=False,
+         use_scale_shift_norm=False,
+         dims=2,
+         use_checkpoint=False,
+         up=False,
+         down=False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.emb_channels = emb_channels
+         self.dropout = dropout
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.use_checkpoint = use_checkpoint
+         self.use_scale_shift_norm = use_scale_shift_norm
+
+         self.in_layers = nn.Sequential(
+             normalization(channels, swish=1.0),
+             nn.Identity(),
+             conv_nd(dims, channels, self.out_channels, 3, padding=1),
+         )
+
+         self.updown = up or down
+
+         if up:
+             self.h_upd = Upsample(channels, False, dims)
+             self.x_upd = Upsample(channels, False, dims)
+         elif down:
+             self.h_upd = Downsample(channels, False, dims)
+             self.x_upd = Downsample(channels, False, dims)
+         else:
+             self.h_upd = self.x_upd = nn.Identity()
+
+         self.emb_layers = nn.Sequential(
+             nn.SiLU(),
+             linear(
+                 emb_channels,
+                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+             ),
+         )
+         self.out_layers = nn.Sequential(
+             normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+             nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+             nn.Dropout(p=dropout),
+             zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+         )
+
+         if self.out_channels == channels:
+             self.skip_connection = nn.Identity()
+         elif use_conv:
+             self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+         else:
+             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+     def forward(self, x, emb):
+         """
+         Apply the block to a Tensor, conditioned on a timestep embedding.
+
+         :param x: an [N x C x ...] Tensor of features.
+         :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+         :return: an [N x C x ...] Tensor of outputs.
+         """
+         if self.updown:
+             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+             h = in_rest(x)
+             h = self.h_upd(h)
+             x = self.x_upd(x)
+             h = in_conv(h)
+         else:
+             h = self.in_layers(x)
+         emb_out = self.emb_layers(emb).type(h.dtype)
+         while len(emb_out.shape) < len(h.shape):
+             emb_out = emb_out[..., None]
+         if self.use_scale_shift_norm:
+             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+             scale, shift = th.chunk(emb_out, 2, dim=1)
+             h = out_norm(h) * (1 + scale) + shift
+             h = out_rest(h)
+         else:
+             h = h + emb_out
+             h = self.out_layers(h)
+         return self.skip_connection(x) + h
+
+
+ class AttentionBlock(nn.Module):
+     """
+     An attention block that allows spatial positions to attend to each other.
+
+     Originally ported from here, but adapted to the N-d case.
+     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+     """
+
+     def __init__(
+         self,
+         channels,
+         num_heads=1,
+         num_head_channels=-1,
+         use_checkpoint=False,
+         encoder_channels=None,
+     ):
+         super().__init__()
+         self.channels = channels
+         if num_head_channels == -1:
+             self.num_heads = num_heads
+         else:
+             assert (
+                 channels % num_head_channels == 0
+             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+             self.num_heads = channels // num_head_channels
+         self.use_checkpoint = use_checkpoint
+         self.norm = normalization(channels, swish=0.0)
+         self.qkv = conv_nd(1, channels, channels * 3, 1)
+         self.attention = QKVAttention(self.num_heads)
+
+         if encoder_channels is not None:
+             self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+     def forward(self, x, encoder_out=None):
+         b, c, *spatial = x.shape
+         qkv = self.qkv(self.norm(x).view(b, c, -1))
+         if encoder_out is not None:
+             encoder_out = self.encoder_kv(encoder_out)
+             h = self.attention(qkv, encoder_out)
+         else:
+             h = self.attention(qkv)
+         h = self.proj_out(h)
+         return x + h.reshape(b, c, *spatial)
+
+
+ class QKVAttention(nn.Module):
+     """
+     A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
+     """
+
+     def __init__(self, n_heads):
+         super().__init__()
+         self.n_heads = n_heads
+
+     def forward(self, qkv, encoder_kv=None):
+         """
+         Apply QKV attention.
+
+         :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+         :param encoder_kv: an optional [N x (H * 2 * C) x S] tensor of extra
+             Ks and Vs to attend over.
+         :return: an [N x (H * C) x T] tensor after attention.
+         """
+         bs, width, length = qkv.shape
+         assert width % (3 * self.n_heads) == 0
+         ch = width // (3 * self.n_heads)
+         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+         if encoder_kv is not None:
+             assert encoder_kv.shape[1] == self.n_heads * ch * 2
+             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+             k = th.cat([ek, k], dim=-1)
+             v = th.cat([ev, v], dim=-1)
+         # Scaling q and k by ch ** -0.25 each is algebraically the same as
+         # dividing the attention logits by sqrt(ch), but the intermediate
+         # products stay within fp16 range.
+         scale = 1 / math.sqrt(math.sqrt(ch))
+         weight = th.einsum(
+             "bct,bcs->bts", q * scale, k * scale
+         )  # More stable with f16 than dividing afterwards
+         weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+         a = th.einsum("bts,bcs->bct", weight, v)
+         return a.reshape(bs, -1, length)
+
+
+ class UNetModel(nn.Module):
+     """
+     The full UNet model with attention and timestep embedding.
+
+     :param in_channels: channels in the input Tensor.
+     :param model_channels: base channel count for the model.
+     :param out_channels: channels in the output Tensor.
+     :param num_res_blocks: number of residual blocks per downsample.
+     :param attention_resolutions: a collection of downsample rates at which
+         attention will take place. May be a set, list, or tuple.
+         For example, if this contains 4, then at 4x downsampling, attention
+         will be used.
+     :param dropout: the dropout probability.
+     :param channel_mult: channel multiplier for each level of the UNet.
+     :param conv_resample: if True, use learned convolutions for upsampling and
+         downsampling.
+     :param dims: determines if the signal is 1D, 2D, or 3D.
+     :param num_classes: if specified (as an int), then this model will be
+         class-conditional with `num_classes` classes.
+     :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+     :param num_heads: the number of attention heads in each attention layer.
+     :param num_head_channels: if specified, ignore num_heads and instead use
+         a fixed channel width per attention head.
+     :param num_heads_upsample: works with num_heads to set a different number
+         of heads for upsampling. Deprecated.
+     :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+     :param resblock_updown: use residual blocks for up/downsampling.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         model_channels,
+         out_channels,
+         num_res_blocks,
+         attention_resolutions,
+         dropout=0,
+         channel_mult=(1, 2, 4, 8),
+         conv_resample=True,
+         dims=2,
+         num_classes=None,
+         use_checkpoint=False,
+         use_fp16=False,
+         num_heads=1,
+         num_head_channels=-1,
+         num_heads_upsample=-1,
+         use_scale_shift_norm=False,
+         resblock_updown=False,
+         encoder_channels=None,
+     ):
+         super().__init__()
+
+         if num_heads_upsample == -1:
+             num_heads_upsample = num_heads
+
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+         self.attention_resolutions = attention_resolutions
+         self.dropout = dropout
+         self.channel_mult = channel_mult
+         self.conv_resample = conv_resample
+         self.num_classes = num_classes
+         self.use_checkpoint = use_checkpoint
+         self.dtype = th.float16 if use_fp16 else th.float32
+         self.num_heads = num_heads
+         self.num_head_channels = num_head_channels
+         self.num_heads_upsample = num_heads_upsample
+
+         time_embed_dim = model_channels * 4
+         self.time_embed = nn.Sequential(
+             linear(model_channels, time_embed_dim),
+             nn.SiLU(),
+             linear(time_embed_dim, time_embed_dim),
+         )
+
+         if self.num_classes is not None:
+             self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+         ch = input_ch = int(channel_mult[0] * model_channels)
+         self.input_blocks = nn.ModuleList(
+             [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+         )
+         self._feature_size = ch
+         input_block_chans = [ch]
+         ds = 1
+         for level, mult in enumerate(channel_mult):
+             for _ in range(num_res_blocks):
+                 layers = [
+                     ResBlock(
+                         ch,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=int(mult * model_channels),
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                     )
+                 ]
+                 ch = int(mult * model_channels)
+                 if ds in attention_resolutions:
+                     layers.append(
+                         AttentionBlock(
+                             ch,
+                             use_checkpoint=use_checkpoint,
+                             num_heads=num_heads,
+                             num_head_channels=num_head_channels,
+                             encoder_channels=encoder_channels,
+                         )
+                     )
+                 self.input_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+                 input_block_chans.append(ch)
+             if level != len(channel_mult) - 1:
+                 out_ch = ch
+                 self.input_blocks.append(
+                     TimestepEmbedSequential(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             down=True,
+                         )
+                         if resblock_updown
+                         else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                     )
+                 )
+                 ch = out_ch
+                 input_block_chans.append(ch)
+                 ds *= 2
+                 self._feature_size += ch
+
+         self.middle_block = TimestepEmbedSequential(
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+             AttentionBlock(
+                 ch,
+                 use_checkpoint=use_checkpoint,
+                 num_heads=num_heads,
+                 num_head_channels=num_head_channels,
+                 encoder_channels=encoder_channels,
+             ),
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+         )
+         self._feature_size += ch
+
+         self.output_blocks = nn.ModuleList([])
+         for level, mult in list(enumerate(channel_mult))[::-1]:
+             for i in range(num_res_blocks + 1):
+                 ich = input_block_chans.pop()
+                 layers = [
+                     ResBlock(
+                         ch + ich,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=int(model_channels * mult),
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                     )
+                 ]
+                 ch = int(model_channels * mult)
+                 if ds in attention_resolutions:
+                     layers.append(
+                         AttentionBlock(
+                             ch,
+                             use_checkpoint=use_checkpoint,
+                             num_heads=num_heads_upsample,
+                             num_head_channels=num_head_channels,
+                             encoder_channels=encoder_channels,
+                         )
+                     )
+                 if level and i == num_res_blocks:
+                     out_ch = ch
+                     layers.append(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             up=True,
+                         )
+                         if resblock_updown
+                         else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                     )
+                     ds //= 2
+                 self.output_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+
+         self.out = nn.Sequential(
+             normalization(ch, swish=1.0),
+             nn.Identity(),
+             zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+         )
+         self.use_fp16 = use_fp16
+
+     def convert_to_fp16(self):
+         """
+         Convert the torso of the model to float16.
+         """
+         self.input_blocks.apply(convert_module_to_f16)
+         self.middle_block.apply(convert_module_to_f16)
+         self.output_blocks.apply(convert_module_to_f16)
+
+     def convert_to_fp32(self):
+         """
+         Convert the torso of the model to float32.
+         """
+         self.input_blocks.apply(convert_module_to_f32)
+         self.middle_block.apply(convert_module_to_f32)
+         self.output_blocks.apply(convert_module_to_f32)
+
+     def forward(self, x, timesteps, y=None):
+         """
+         Apply the model to an input batch.
+
+         :param x: an [N x C x ...] Tensor of inputs.
+         :param timesteps: a 1-D batch of timesteps.
+         :param y: an [N] Tensor of labels, if class-conditional.
+         :return: an [N x C x ...] Tensor of outputs.
+         """
+         assert (y is not None) == (
+             self.num_classes is not None
+         ), "must specify y if and only if the model is class-conditional"
+
+         hs = []
+         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+         if self.num_classes is not None:
+             assert y.shape == (x.shape[0],)
+             emb = emb + self.label_emb(y)
+
+         h = x.type(self.dtype)
+         for module in self.input_blocks:
+             h = module(h, emb)
+             hs.append(h)
+         h = self.middle_block(h, emb)
+         for module in self.output_blocks:
+             h = th.cat([h, hs.pop()], dim=1)
+             h = module(h, emb)
+         h = h.type(x.dtype)
+         return self.out(h)
+
+
+ class SuperResUNetModel(UNetModel):
+     """
+     A UNetModel that performs super-resolution.
+
+     Expects an extra kwarg `low_res` to condition on a low-resolution image.
+     """
+
+     def __init__(self, *args, **kwargs):
+         if "in_channels" in kwargs:
+             kwargs = dict(kwargs)
+             kwargs["in_channels"] = kwargs["in_channels"] * 2
+         else:
+             # Curse you, Python. Or really, just curse positional arguments :|.
+             args = list(args)
+             args[1] = args[1] * 2
+         super().__init__(*args, **kwargs)
+
+     def forward(self, x, timesteps, low_res=None, **kwargs):
+         _, _, new_height, new_width = x.shape
+         upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+         x = th.cat([x, upsampled], dim=1)
+         return super().forward(x, timesteps, **kwargs)
+
+
+ class InpaintUNetModel(UNetModel):
+     """
+     A UNetModel which can perform inpainting.
+     """
+
+     def __init__(self, *args, **kwargs):
+         if "in_channels" in kwargs:
+             kwargs = dict(kwargs)
+             kwargs["in_channels"] = kwargs["in_channels"] * 2 + 1
+         else:
+             # Curse you, Python. Or really, just curse positional arguments :|.
+             args = list(args)
+             args[1] = args[1] * 2 + 1
+         super().__init__(*args, **kwargs)
+
+     def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs):
+         if inpaint_image is None:
+             inpaint_image = th.zeros_like(x)
+         if inpaint_mask is None:
+             inpaint_mask = th.zeros_like(x[:, :1])
+         return super().forward(
+             th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1),
+             timesteps,
+             **kwargs,
+         )
+
+
+ class SuperResInpaintUNetModel(UNetModel):
+     """
+     A UNetModel which can perform both upsampling and inpainting.
+     """
+
+     def __init__(self, *args, **kwargs):
+         if "in_channels" in kwargs:
+             kwargs = dict(kwargs)
+             kwargs["in_channels"] = kwargs["in_channels"] * 3 + 1
+         else:
+             # Curse you, Python. Or really, just curse positional arguments :|.
+             args = list(args)
+             args[1] = args[1] * 3 + 1
+         super().__init__(*args, **kwargs)
+
+     def forward(
+         self,
+         x,
+         timesteps,
+         inpaint_image=None,
+         inpaint_mask=None,
+         low_res=None,
+         **kwargs,
+     ):
+         if inpaint_image is None:
+             inpaint_image = th.zeros_like(x)
+         if inpaint_mask is None:
+             inpaint_mask = th.zeros_like(x[:, :1])
+         _, _, new_height, new_width = x.shape
+         upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+         return super().forward(
+             th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1),
+             timesteps,
+             **kwargs,
+         )
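
A quick shape check for the UNet above (a sketch; these hyperparameters are illustrative, not the ones used by the released checkpoints). With channel_mult=(1, 2, 4) the features are downsampled twice, so attention_resolutions=(4,) puts self-attention on the coarsest 16x16 grid, in addition to the always-present middle-block attention:

    import torch as th
    from glide_text2im.unet import UNetModel

    model = UNetModel(
        in_channels=3,
        model_channels=64,
        out_channels=6,               # e.g. 3 epsilon channels + 3 learned-variance channels
        num_res_blocks=2,
        attention_resolutions=(4,),
        channel_mult=(1, 2, 4),
    )
    x = th.randn(2, 3, 64, 64)        # a batch of noised images
    t = th.randint(0, 1000, (2,))     # one diffusion timestep per image
    print(model(x, t).shape)          # torch.Size([2, 6, 64, 64])
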
glide_text2im/vgg.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ from .nn import mean_flat
+
+
+ # input image range [-1, 1]
+ class VGG(nn.Module):
+     def __init__(self, conv_index='22', rgb_range=1):
+         super(VGG, self).__init__()
+         vgg_features = models.vgg19(pretrained=True).features
+         modules = [m for m in vgg_features]
+         if conv_index.find('22') >= 0:
+             self.vgg = nn.Sequential(*modules[:8])
+         elif conv_index.find('54') >= 0:
+             self.vgg = nn.Sequential(*modules[:35])
+
+         vgg_mean = (0.485, 0.456, 0.406)
+         vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
+         self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std)
+         for p in self.parameters():
+             p.requires_grad = False
+
+     def forward(self, sr, hr):
+         def _forward(x):
+             x = self.sub_mean(x)
+             x = self.vgg(x)
+             return x
+
+         sr = (sr + 1.) / 2.
+         hr = (hr + 1.) / 2.
+
+         vgg_sr = _forward(sr)
+         with torch.no_grad():
+             vgg_hr = _forward(hr.detach())
+
+         loss = mean_flat((vgg_sr - vgg_hr) ** 2)
+
+         return loss
+
+
+ class MeanShift(nn.Conv2d):
+     def __init__(
+         self,
+         rgb_range,
+         rgb_mean=(0.4488, 0.4371, 0.4040),
+         rgb_std=(1.0, 1.0, 1.0),
+         sign=-1,
+     ):
+         super(MeanShift, self).__init__(3, 3, kernel_size=1)
+         std = torch.Tensor(rgb_std)
+         self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
+         self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
+         for p in self.parameters():
+             p.requires_grad = False
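
A minimal usage sketch for the perceptual loss above. Constructing VGG downloads torchvision's pretrained VGG19 weights on first use; the random tensors below merely stand in for generated (sr) and reference (hr) batches in the model's [-1, 1] range:

    import torch
    from glide_text2im.vgg import VGG

    perc = VGG(conv_index='22').eval()
    sr = torch.rand(4, 3, 256, 256) * 2 - 1   # stand-in for model outputs
    hr = torch.rand(4, 3, 256, 256) * 2 - 1   # stand-in for ground truth
    loss = perc(sr, hr)                        # MSE in VGG19 conv2_2 feature space
    print(loss.shape)                          # torch.Size([4]), one value per image
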
glide_text2im/xf.py ADDED
@@ -0,0 +1,123 @@
+ """
+ Transformer implementation adapted from CLIP ViT:
+ https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py
+ """
+
+ import math
+
+ import torch as th
+ import torch.nn as nn
+
+
+ def convert_module_to_f16(l):
+     """
+     Convert primitive modules to float16.
+     """
+     if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+         l.weight.data = l.weight.data.half()
+         if l.bias is not None:
+             l.bias.data = l.bias.data.half()
+
+
+ class LayerNorm(nn.LayerNorm):
+     """
+     Implementation that supports fp16 inputs but fp32 gains/biases.
+     """
+
+     def forward(self, x: th.Tensor):
+         return super().forward(x.float()).to(x.dtype)
+
+
+ class MultiheadAttention(nn.Module):
+     def __init__(self, width, heads):
+         super().__init__()
+         self.width = width
+         self.heads = heads
+         self.c_qkv = nn.Linear(width, width * 3)
+         self.c_proj = nn.Linear(width, width)
+         self.attention = QKVMultiheadAttention(heads)
+
+     def forward(self, x):
+         x = self.c_qkv(x)
+         x = self.attention(x)
+         x = self.c_proj(x)
+         return x
+
+
+ class MLP(nn.Module):
+     def __init__(self, width):
+         super().__init__()
+         self.width = width
+         self.c_fc = nn.Linear(width, width * 4)
+         self.c_proj = nn.Linear(width * 4, width)
+         self.gelu = nn.GELU()
+
+     def forward(self, x):
+         return self.c_proj(self.gelu(self.c_fc(x)))
+
+
+ class QKVMultiheadAttention(nn.Module):
+     def __init__(self, n_heads: int):
+         super().__init__()
+         self.n_heads = n_heads
+
+     def forward(self, qkv):
+         bs, n_ctx, width = qkv.shape
+         attn_ch = width // self.n_heads // 3
+         scale = 1 / math.sqrt(math.sqrt(attn_ch))
+         qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
+         q, k, v = th.split(qkv, attn_ch, dim=-1)
+         weight = th.einsum(
+             "bthc,bshc->bhts", q * scale, k * scale
+         )  # More stable with f16 than dividing afterwards
+         wdtype = weight.dtype
+         weight = th.softmax(weight.float(), dim=-1).type(wdtype)
+         return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         width: int,
+         heads: int,
+     ):
+         super().__init__()
+
+         self.attn = MultiheadAttention(
+             width,
+             heads,
+         )
+         self.ln_1 = LayerNorm(width)
+         self.mlp = MLP(width)
+         self.ln_2 = LayerNorm(width)
+
+     def forward(self, x: th.Tensor):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         width: int,
+         layers: int,
+         heads: int,
+     ):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.resblocks = nn.ModuleList(
+             [
+                 ResidualAttentionBlock(
+                     width,
+                     heads,
+                 )
+                 for _ in range(layers)
+             ]
+         )
+
+     def forward(self, x: th.Tensor):
+         for block in self.resblocks:
+             x = block(x)
+         return x
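
And a shape check for the transformer above; the width, layers, and heads values here are illustrative:

    import torch as th
    from glide_text2im.xf import Transformer

    xf = Transformer(width=512, layers=4, heads=8)
    tokens = th.randn(2, 128, 512)    # [batch, n_ctx, width] token embeddings
    print(xf(tokens).shape)           # torch.Size([2, 128, 512])
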