williamberman committed
Commit aa67e5e
1 Parent(s): e4ea387

vae working

Files changed (2):
  1. sdxl.py +1 -1
  2. sdxl_models.py +18 -28
sdxl.py CHANGED
@@ -9,7 +9,6 @@ import torch
 import torch.nn.functional as F
 import torchvision.transforms
 import torchvision.transforms.functional as TF
-import wandb
 import webdataset as wds
 from PIL import Image
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -17,6 +16,7 @@ from torch.utils.data import default_collate
 from transformers import (CLIPTextModel, CLIPTextModelWithProjection,
                           CLIPTokenizerFast)
 
+import wandb
 from diffusion import (default_num_train_timesteps,
                        euler_ode_solver_diffusion_loop, make_sigmas)
 from sdxl_models import (SDXLAdapter, SDXLControlNet, SDXLControlNetFull,
sdxl_models.py CHANGED
@@ -6,7 +6,7 @@ import safetensors.torch
 import torch
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
-import xformers
+import xformers.ops
 from PIL import Image
 from torch import nn
 
@@ -62,17 +62,17 @@ class SDXLVae(nn.Module, ModelUtils):
             # 128 -> 128
             nn.ModuleDict(dict(
                 resnets=nn.ModuleList([ResnetBlock2D(128, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6)]),
-                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)))]),
+                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(128, 128, kernel_size=3, stride=2)))]),
             )),
             # 128 -> 256
             nn.ModuleDict(dict(
                 resnets=nn.ModuleList([ResnetBlock2D(128, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6)]),
-                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)))]),
+                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=2)))]),
             )),
             # 256 -> 512
             nn.ModuleDict(dict(
                 resnets=nn.ModuleList([ResnetBlock2D(256, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
-                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)))]),
+                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=2)))]),
             )),
             # 512 -> 512
             nn.ModuleDict(dict(resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]))),
@@ -151,6 +151,7 @@ class SDXLVae(nn.Module, ModelUtils):
             h = resnet(h)
 
             if "downsamplers" in down_block:
+                h = F.pad(h, pad=(0, 1, 0, 1), mode="constant", value=0)
                 h = down_block["downsamplers"][0]["conv"](h)
 
         h = self.encoder["mid_block"]["resnets"][0](h)
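
The padding change above is the substance of the "vae working" fix: the pretrained SD/SDXL VAE encoder downsamples with asymmetric padding (one pixel of zeros on the right and bottom only), so a conv built with symmetric padding=1 silently mismatches the weights it loads. A minimal sketch of the difference, with a hypothetical input size for comparison:

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 128, 64, 64)  # hypothetical feature map

# Symmetric padding, as before this commit.
sym = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)

# Asymmetric padding, as after this commit: identical weights, no built-in
# padding, one pixel of zeros added on the right and bottom in forward().
asym = nn.Conv2d(128, 128, kernel_size=3, stride=2)
asym.load_state_dict(sym.state_dict())

out_sym = sym(x)
out_asym = asym(F.pad(x, pad=(0, 1, 0, 1), mode="constant", value=0))

print(out_sym.shape, out_asym.shape)   # both torch.Size([1, 128, 32, 32])
print(torch.allclose(out_sym, out_asym))  # False: border pixels differ

Because the output shapes agree, the symmetric version loads pretrained weights without any error; only the reconstructions come out wrong, which is presumably why the bug only surfaced in the VAE's outputs.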
@@ -1333,49 +1334,38 @@ class Attention(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
 
     def forward(self, hidden_states, encoder_hidden_states=None):
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channels, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
-
-        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, hidden_states, encoder_hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
-
-        return hidden_states
+        return attention(self.to_q, self.to_k, self.to_v, self.to_out, 64, hidden_states, encoder_hidden_states)
 
 
 class VaeMidBlockAttention(nn.Module):
     def __init__(self, channels):
         super().__init__()
         self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)
-        self.to_q = nn.Linear(channels, channels, bias=True)
-        self.to_k = nn.Linear(channels, channels, bias=True)
-        self.to_v = nn.Linear(channels, channels, bias=True)
+        self.to_q = nn.Linear(channels, channels)
+        self.to_k = nn.Linear(channels, channels)
+        self.to_v = nn.Linear(channels, channels)
         self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
+        self.head_dim = channels
 
     def forward(self, hidden_states):
-        input_ndim = hidden_states.ndim
+        residual = hidden_states
 
-        if input_ndim == 4:
-            batch_size, channels, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
+        batch_size, channels, height, width = hidden_states.shape
+        hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
 
         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, hidden_states)
+        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, hidden_states)
 
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
+        hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
+
+        hidden_states = hidden_states + residual
 
         return hidden_states
 
 
-def attention(to_q, to_k, to_v, to_out, hidden_states, encoder_hidden_states=None):
+def attention(to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
     batch_size, q_seq_len, channels = hidden_states.shape
-    head_dim = 64
 
     if encoder_hidden_states is not None:
         kv = encoder_hidden_states
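
The attention refactor threads head_dim through the shared attention helper instead of hard-coding 64: the transformer Attention blocks keep 64-dim heads, while VaeMidBlockAttention passes head_dim = channels, i.e. a single attention head, and now owns its own 4D-to-3D reshape and residual add. The helper's body lies mostly outside this hunk; given the new import xformers.ops, a plausible sketch of what such a helper does (attention_sketch is a hypothetical stand-in, not the file's actual implementation):

import xformers.ops

def attention_sketch(to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
    batch_size, q_seq_len, channels = hidden_states.shape
    num_heads = channels // head_dim  # 1 head when head_dim == channels (VAE mid block)

    # Self-attention unless encoder states are provided (cross-attention).
    kv = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
    kv_seq_len = kv.shape[1]

    # Project and split heads into the (batch, seq, heads, head_dim) layout
    # that xformers.ops.memory_efficient_attention expects.
    query = to_q(hidden_states).reshape(batch_size, q_seq_len, num_heads, head_dim)
    key = to_k(kv).reshape(batch_size, kv_seq_len, num_heads, head_dim)
    value = to_v(kv).reshape(batch_size, kv_seq_len, num_heads, head_dim)

    out = xformers.ops.memory_efficient_attention(query, key, value)
    out = out.reshape(batch_size, q_seq_len, channels)

    return to_out(out)

With that signature, Attention.forward collapses to the one-line delegation in the hunk above, and self.head_dim = channels keeps the VAE's 512-channel mid block as one wide head rather than eight 64-dim heads.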
 