pawlo2013 committed
Commit 5086590 • Parent: 1f82180

fixed code readability
app.py CHANGED
@@ -1,13 +1,11 @@
 import gradio as gr
 from PIL import Image
-import numpy as np
 from torchvision import transforms
 from load_model import sample
 import torch
-import glob
 import random
 import os
-import pathlib
+

 device = "cuda" if torch.cuda.is_available() else "cpu"
 device = "mps" if torch.backends.mps.is_available() else device
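A note on the device fallback kept at the end of app.py: because the second assignment runs unconditionally, MPS (Apple Silicon) takes precedence over CUDA whenever both backends report available. A standalone sketch of the same selection logic, assuming plain PyTorch (illustrative, not part of the commit):

import torch

# Prefer CUDA over CPU, then let an available MPS backend override the choice,
# mirroring the two-step fallback in app.py.
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device
print(torch.zeros(1, device=device).device)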
load_model.py CHANGED
@@ -1,12 +1,10 @@
-from models.structure.Unet_3 import Unet
+from models.structure.Advanced_Conditional_Unet import Unet
 from diffusers import DDPMScheduler
 import torch
 import os
 import glob
-from tqdm import tqdm
 from torchvision import transforms
 import pathlib
-from torchvision.utils import save_image
 from safetensors.torch import load_model, save_model
 import time as tm
 
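For reference, the safetensors helpers that load_model.py imports can be exercised on their own. A minimal sketch, assuming an installed safetensors package and a throwaway file path (hypothetical, not part of the commit):

import torch
from safetensors.torch import load_model, save_model

net = torch.nn.Linear(4, 4)
save_model(net, "weights.safetensors")  # serialize the state dict to a safetensors file
load_model(net, "weights.safetensors")  # restore the weights in place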
models/structure/{Unet_3.py → Advanced_Conditional_Unet.py} RENAMED
@@ -1,14 +1,8 @@
-import math
-from inspect import isfunction
 from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
 import torch
-from torch import nn, einsum
+from torch import nn
 import torch.nn.functional as F
-from .Advanced_Network_Helpers_3 import *
-from transformers import PreTrainedModel
+from .Advanced_Network_Helpers import *


 class Unet(nn.Module):
models/structure/Advanced_Network_Helpers.py CHANGED
@@ -143,23 +143,13 @@ class Attention(nn.Module):
         self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)

-    def forward(self, x, cross_attend=None):
+    def forward(self, x):
         b, c, h, w = x.shape

-        if cross_attend is not None:
-            assert cross_attend.shape == x.shape
-
-            q_att = self.to_q(x)
-            k_att = self.to_k(cross_attend)
-            v_att = self.to_v(cross_attend)
-            q = rearrange(q_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-            k = rearrange(k_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-            v = rearrange(v_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-        else:
-            qkv = self.to_qkv(x).chunk(3, dim=1)
-            q, k, v = map(
-                lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-            )
+        qkv = self.to_qkv(x).chunk(3, dim=1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+        )
         q = q * self.scale

         sim = einsum("b h d i, b h d j -> b h i j", q, k)
@@ -173,7 +163,7 @@ class Attention(nn.Module):


 class LinearCrossAttention(nn.Module):
-    def __init__(self, dim, heads=12, dim_head=128) -> None:
+    def __init__(self, dim, heads=4, dim_head=32) -> None:
         super().__init__()
         self.scale = dim_head**-0.5
         self.heads = heads
@@ -210,25 +200,12 @@ class LinearAttention(nn.Module):
         self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
         self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

-    def forward(self, x, cross_attend=None):
+    def forward(self, x):
         b, c, h, w = x.shape
-        if cross_attend is not None:
-            assert (
-                cross_attend.shape == x.shape
-            ), f"cross_attend must be same shape as x is {cross_attend.shape} and x is {x.shape}"
-
-            q_att = self.to_q(x)
-            k_att = self.to_k(cross_attend)
-            v_att = self.to_v(cross_attend)
-            q = rearrange(q_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-            k = rearrange(k_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-            v = rearrange(v_att, "b (h c) x y -> b h c (x y)", h=self.heads)
-
-        else:
-            qkv = self.to_qkv(x).chunk(3, dim=1)
-            q, k, v = map(
-                lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-            )
+        qkv = self.to_qkv(x).chunk(3, dim=1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+        )
         # calculate the softmax with respect to columns softmax of equivalent to q^T with respect to last dim
         q = q.softmax(dim=-2)
         # calculate the softmax with respect to rows of k
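The refactored forward methods above keep only the self-attention path. As a reading aid, here is a self-contained sketch of the linear self-attention computation that LinearAttention now performs, using the defaults shown in the diff (heads=4, dim_head=32); it mirrors the code in the diff rather than defining the repository's actual module:

import torch
from torch import nn
from einops import rearrange


class LinearSelfAttentionSketch(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        # One projection, split into q, k, v along the channel axis.
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q.softmax(dim=-2) * self.scale  # normalize q over the feature axis
        k = k.softmax(dim=-1)  # normalize k over the spatial axis
        # Linear attention aggregates k against v first, so the cost grows linearly
        # with the number of spatial positions instead of quadratically.
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


print(LinearSelfAttentionSketch(64)(torch.randn(1, 64, 16, 16)).shape)
# torch.Size([1, 64, 16, 16])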
models/structure/Advanced_Network_Helpers_2.py DELETED
@@ -1,232 +0,0 @@
-import math
-from inspect import isfunction
-from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-
-
-def exists(x):
-    return x is not None
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x, *args, **kwargs):
-        return self.fn(x, *args, **kwargs) + x
-
-
-def Upsample(dim):
-    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-
-def Downsample(dim):
-    return nn.Conv2d(dim, dim, 4, 2, 1)
-
-
-class SinusoidalPositionEmbeddings(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, time):
-        device = time.device
-        half_dim = self.dim // 2
-        embeddings = math.log(10000) / (half_dim - 1)
-        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
-        embeddings = time[:, None] * embeddings[None, :]
-        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
-        return embeddings
-
-
-class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups=8):
-        super().__init__()
-        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
-        self.norm = nn.GroupNorm(groups, dim_out)
-        self.act = nn.SiLU()
-
-    def forward(self, x, scale_shift=None):
-        x = self.proj(x)
-        x = self.norm(x)
-
-        if exists(scale_shift):
-            scale, shift = scale_shift
-            x = x * (scale + 1) + shift
-
-        x = self.act(x)
-        return x
-
-
-class ResnetBlock(nn.Module):
-    """https://arxiv.org/abs/1512.03385"""
-
-    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
-        super().__init__()
-        self.mlp = (
-            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
-            if exists(time_emb_dim)
-            else None
-        )
-
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb=None):
-        h = self.block1(x)
-
-        if exists(self.mlp) and exists(time_emb):
-            time_emb = self.mlp(time_emb)
-            h = rearrange(time_emb, "b c -> b c 1 1") + h
-
-        h = self.block2(h)
-        return h + self.res_conv(x)
-
-
-class ConvNextBlock(nn.Module):
-    """https://arxiv.org/abs/2201.03545"""
-
-    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
-        super().__init__()
-        self.mlp = (
-            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
-            if exists(time_emb_dim)
-            else None
-        )
-
-        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
-
-        self.net = nn.Sequential(
-            nn.GroupNorm(1, dim) if norm else nn.Identity(),
-            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
-            nn.GELU(),
-            nn.GroupNorm(1, dim_out * mult),
-            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
-        )
-
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb=None):
-        h = self.ds_conv(x)
-
-        if exists(self.mlp) and exists(time_emb):
-            assert exists(time_emb), "time embedding must be passed in"
-            condition = self.mlp(time_emb)
-            h = h + rearrange(condition, "b c -> b c 1 1")
-
-        h = self.net(h)
-        return h + self.res_conv(x)
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-        q, k, v = map(
-            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-        )
-        q = q * self.scale
-
-        sim = einsum("b h d i, b h d j -> b h i j", q, k)
-        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b h i j, b h d j -> b h i d", attn, v)
-        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
-
-        return self.to_out(out)
-
-
-class LinearCrossAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32) -> None:
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
-        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x, cross_attend):
-        b, c, h, w = x.shape
-        q = self.to_q(x)
-        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
-        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
-        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
-        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
-        q = q * self.scale
-        sim = einsum("b h d i, b h d j -> b h i j", q, k)
-        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-        attn = sim.softmax(dim=-1)
-        out = einsum("b h i j, b h d j -> b h i d", attn, v)
-        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
-        return self.out(out)
-
-
-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-        q, k, v = map(
-            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-        )
-        # calculate the softmax with respect to columns softmax of equivalent to q^T with respect to last dim
-        q = q.softmax(dim=-2)
-        # calculate the softmax with respect to rows of k
-        k = k.softmax(dim=-1)
-        # normalize the values in the attention matrix
-        q = q * self.scale
-        # dot product of q and v matrices
-        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
-        # dot product of context and q
-        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
-        # rearrange the output to match the pytorch convention
-        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
-        return self.to_out(out)
-
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.fn = fn
-        self.norm = nn.GroupNorm(1, dim)
-
-    def forward(self, x, *args, **kwargs):
-        x = self.norm(x)
-        return self.fn(x, *args, **kwargs)
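The deleted helper file above includes the SinusoidalPositionEmbeddings module (presumably retained in Advanced_Network_Helpers.py, which the diff only shows in part). Its timestep encoding can be checked in isolation; a standalone sketch of the same arithmetic, with an assumed embedding width of 8:

import math
import torch

dim = 8  # assumed embedding width
time = torch.tensor([0.0, 1.0, 2.0])  # a batch of three timesteps
half_dim = dim // 2
# Geometric frequency ladder, exactly as in SinusoidalPositionEmbeddings.forward.
freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
emb = time[:, None] * freqs[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
print(emb.shape)  # torch.Size([3, 8])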
models/structure/Advanced_Network_Helpers_3.py DELETED
@@ -1,232 +0,0 @@
-import math
-from inspect import isfunction
-from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-
-
-def exists(x):
-    return x is not None
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x, *args, **kwargs):
-        return self.fn(x, *args, **kwargs) + x
-
-
-def Upsample(dim):
-    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-
-def Downsample(dim):
-    return nn.Conv2d(dim, dim, 4, 2, 1)
-
-
-class SinusoidalPositionEmbeddings(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, time):
-        device = time.device
-        half_dim = self.dim // 2
-        embeddings = math.log(10000) / (half_dim - 1)
-        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
-        embeddings = time[:, None] * embeddings[None, :]
-        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
-        return embeddings
-
-
-class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups=8):
-        super().__init__()
-        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
-        self.norm = nn.GroupNorm(groups, dim_out)
-        self.act = nn.SiLU()
-
-    def forward(self, x, scale_shift=None):
-        x = self.proj(x)
-        x = self.norm(x)
-
-        if exists(scale_shift):
-            scale, shift = scale_shift
-            x = x * (scale + 1) + shift
-
-        x = self.act(x)
-        return x
-
-
-class ResnetBlock(nn.Module):
-    """https://arxiv.org/abs/1512.03385"""
-
-    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
-        super().__init__()
-        self.mlp = (
-            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
-            if exists(time_emb_dim)
-            else None
-        )
-
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb=None):
-        h = self.block1(x)
-
-        if exists(self.mlp) and exists(time_emb):
-            time_emb = self.mlp(time_emb)
-            h = rearrange(time_emb, "b c -> b c 1 1") + h
-
-        h = self.block2(h)
-        return h + self.res_conv(x)
-
-
-class ConvNextBlock(nn.Module):
-    """https://arxiv.org/abs/2201.03545"""
-
-    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
-        super().__init__()
-        self.mlp = (
-            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
-            if exists(time_emb_dim)
-            else None
-        )
-
-        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
-
-        self.net = nn.Sequential(
-            nn.GroupNorm(1, dim) if norm else nn.Identity(),
-            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
-            nn.GELU(),
-            nn.GroupNorm(1, dim_out * mult),
-            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
-        )
-
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb=None):
-        h = self.ds_conv(x)
-
-        if exists(self.mlp) and exists(time_emb):
-            assert exists(time_emb), "time embedding must be passed in"
-            condition = self.mlp(time_emb)
-            h = h + rearrange(condition, "b c -> b c 1 1")
-
-        h = self.net(h)
-        return h + self.res_conv(x)
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-        q, k, v = map(
-            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-        )
-        q = q * self.scale
-
-        sim = einsum("b h d i, b h d j -> b h i j", q, k)
-        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b h i j, b h d j -> b h i d", attn, v)
-        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
-
-        return self.to_out(out)
-
-
-class LinearCrossAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32) -> None:
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_kv = nn.Conv2d(dim, hidden_dim * 2, 1, bias=False)
-        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x, cross_attend):
-        b, c, h, w = x.shape
-        q = self.to_q(x)
-        k, v = self.to_kv(cross_attend).chunk(2, dim=1)
-        q = rearrange(q, "b (h c) x y -> b h c (x y)", h=self.heads)
-        k = rearrange(k, "b (h c) x y -> b h c (x y)", h=self.heads)
-        v = rearrange(v, "b (h c) x y -> b h c (x y)", h=self.heads)
-        q = q * self.scale
-        sim = einsum("b h d i, b h d j -> b h i j", q, k)
-        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-        attn = sim.softmax(dim=-1)
-        out = einsum("b h i j, b h d j -> b h i d", attn, v)
-        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
-        return self.out(out)
-
-
-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_k = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_v = nn.Conv2d(dim, hidden_dim, 1, bias=False)
-        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-        q, k, v = map(
-            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-        )
-        # calculate the softmax with respect to columns softmax of equivalent to q^T with respect to last dim
-        q = q.softmax(dim=-2)
-        # calculate the softmax with respect to rows of k
-        k = k.softmax(dim=-1)
-        # normalize the values in the attention matrix
-        q = q * self.scale
-        # dot product of q and v matrices
-        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
-        # dot product of context and q
-        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
-        # rearrange the output to match the pytorch convention
-        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
-        return self.to_out(out)
-
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.fn = fn
-        self.norm = nn.GroupNorm(1, dim)
-
-    def forward(self, x, *args, **kwargs):
-        x = self.norm(x)
-        return self.fn(x, *args, **kwargs)
models/structure/Unet.py DELETED
@@ -1,152 +0,0 @@
-import math
-from inspect import isfunction
-from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-from .Advanced_Network_Helpers import *
-
-
-class Unet(nn.Module):
-    def __init__(
-        self,
-        dim,
-        init_dim=None,
-        out_dim=None,
-        dim_mults=(1, 2, 4, 8),
-        channels=3,
-        with_time_emb=True,
-        resnet_block_groups=8,
-        use_convnext=True,
-        convnext_mult=2,
-    ):
-        super().__init__()
-
-        # determine dimensions
-        self.channels = channels  # since we are concatenating the images and the conditionings along the channel dimension
-
-        init_dim = default(init_dim, dim // 3 * 2)
-        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
-        self.conditioning_init = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
-        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
-        in_out = list(zip(dims[:-1], dims[1:]))
-        self.in_out = in_out
-
-        if use_convnext:
-            block_klass = partial(ConvNextBlock, mult=convnext_mult)
-        else:
-            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
-
-        # time embeddings
-        if with_time_emb:
-            time_dim = dim * 4
-            self.time_mlp = nn.Sequential(
-                SinusoidalPositionEmbeddings(dim),
-                nn.Linear(dim, time_dim),
-                nn.GELU(),
-                nn.Linear(time_dim, time_dim),
-            )
-        else:
-            time_dim = None
-            self.time_mlp = None
-
-        # layers
-        self.downs = nn.ModuleList([])
-        self.ups = nn.ModuleList([])
-        self.conditioning_encoder = nn.ModuleList([])
-        num_resolutions = len(in_out)
-        self.num_resolutions = num_resolutions
-
-        # conditioning encoder
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.conditioning_encoder.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.downs.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
-                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        mid_dim = dims[-1]
-
-        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-        self.cross_attention = Residual(PreNorm(mid_dim, LinearCrossAttention(mid_dim)))
-        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            is_last = ind >= (num_resolutions - 1)
-            self.ups.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
-                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
-                        Upsample(dim_in) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        out_dim = default(out_dim, channels)
-        self.final_conv = nn.Sequential(
-            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
-        )
-
-    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
-        x = torch.cat((x, explicit_conditioning), dim=1)
-        conditioning = torch.cat((implicit_conditioning, explicit_conditioning), dim=1)
-        x = self.init_conv(x)
-
-        conditioning = self.conditioning_init(conditioning)
-
-        t = self.time_mlp(time) if exists(self.time_mlp) else None
-
-        h = []
-
-        # conditioning encoder
-
-        for block1, attn, downsample in self.conditioning_encoder:
-            conditioning = block1(conditioning)
-            conditioning = attn(conditioning)
-            conditioning = downsample(conditioning)
-
-        for block1, block2, attn, downsample in self.downs:
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            h.append(x)
-            x = downsample(x)
-
-        # bottleneck
-        x = self.mid_block1(x, t)
-        x = self.cross_attention(x, conditioning)
-        x = self.mid_block2(x, t)
-
-        for block1, block2, attn, upsample in self.ups:
-            x = torch.cat((x, h.pop()), dim=1)
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            x = upsample(x)
-
-        return self.final_conv(x)
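The deleted Unet above documents the conditioning scheme in its forward pass: the noisy input and the explicit conditioning are concatenated along the channel axis before init_conv, and the implicit/explicit pair forms a second stream for the conditioning encoder. A standalone shape sketch of that concatenation, assuming 3-channel 64x64 images (illustrative only, not the repository's model):

import torch

x = torch.randn(2, 3, 64, 64)         # noisy input batch
explicit = torch.randn(2, 3, 64, 64)  # explicit conditioning image
implicit = torch.randn(2, 3, 64, 64)  # implicit conditioning image

# init_conv and conditioning_init both expect channels * 2 input planes.
x_in = torch.cat((x, explicit), dim=1)
conditioning = torch.cat((implicit, explicit), dim=1)
print(x_in.shape, conditioning.shape)  # torch.Size([2, 6, 64, 64]) twice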
models/structure/Unet_2.py DELETED
@@ -1,152 +0,0 @@
-import math
-from inspect import isfunction
-from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-from .Advanced_Network_Helpers_2 import *
-
-
-class Unet(nn.Module):
-    def __init__(
-        self,
-        dim,
-        init_dim=None,
-        out_dim=None,
-        dim_mults=(1, 2, 4, 8),
-        channels=3,
-        with_time_emb=True,
-        resnet_block_groups=8,
-        use_convnext=True,
-        convnext_mult=2,
-    ):
-        super().__init__()
-
-        # determine dimensions
-        self.channels = channels  # since we are concatenating the images and the conditionings along the channel dimension
-
-        init_dim = default(init_dim, dim // 3 * 2)
-        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
-        self.conditioning_init = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
-        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
-        in_out = list(zip(dims[:-1], dims[1:]))
-        self.in_out = in_out
-
-        if use_convnext:
-            block_klass = partial(ConvNextBlock, mult=convnext_mult)
-        else:
-            block_klass = partial(ResnetBlock, groups=resnet_block_groups)
-
-        # time embeddings
-        if with_time_emb:
-            time_dim = dim * 4
-            self.time_mlp = nn.Sequential(
-                SinusoidalPositionEmbeddings(dim),
-                nn.Linear(dim, time_dim),
-                nn.GELU(),
-                nn.Linear(time_dim, time_dim),
-            )
-        else:
-            time_dim = None
-            self.time_mlp = None
-
-        # layers
-        self.downs = nn.ModuleList([])
-        self.ups = nn.ModuleList([])
-        self.conditioning_encoder = nn.ModuleList([])
-        num_resolutions = len(in_out)
-        self.num_resolutions = num_resolutions
-
-        # conditioning encoder
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.conditioning_encoder.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.downs.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
-                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        mid_dim = dims[-1]
-
-        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-        self.cross_attention = Residual(PreNorm(mid_dim, LinearCrossAttention(mid_dim)))
-        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            is_last = ind >= (num_resolutions - 1)
-            self.ups.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
-                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
-                        Upsample(dim_in) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        out_dim = default(out_dim, channels)
-        self.final_conv = nn.Sequential(
-            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
-        )
-
-    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
-        x = torch.cat((x, explicit_conditioning), dim=1)
-        conditioning = torch.cat((implicit_conditioning, explicit_conditioning), dim=1)
-        x = self.init_conv(x)
-
-        conditioning = self.conditioning_init(conditioning)
-
-        t = self.time_mlp(time) if exists(self.time_mlp) else None
-
-        h = []
-
-        # conditioning encoder
-
-        for block1, attn, downsample in self.conditioning_encoder:
-            conditioning = block1(conditioning)
-            conditioning = attn(conditioning)
-            conditioning = downsample(conditioning)
-
-        for block1, block2, attn, downsample in self.downs:
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            h.append(x)
-            x = downsample(x)
-
-        # bottleneck
-        x = self.mid_block1(x, t)
-        x = self.cross_attention(x, conditioning)
-        x = self.mid_block2(x, t)
-
-        for block1, block2, attn, upsample in self.ups:
-            x = torch.cat((x, h.pop()), dim=1)
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            x = upsample(x)
-
-        return self.final_conv(x)
models/structure/hf_compatible_model.py DELETED
@@ -1,192 +0,0 @@
-from transformers import PretrainedConfig, PreTrainedModel
-import math
-from inspect import isfunction
-from functools import partial
-import matplotlib.pyplot as plt
-from tqdm.auto import tqdm
-from einops import rearrange
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-from transformers import PreTrainedModel
-from .Advanced_Network_Helpers_3 import *
-import os
-
-
-class UnetConfig(PretrainedConfig):
-    model_type = "unet"
-
-    def __init__(
-        self,
-        dim=64,
-        init_dim=None,
-        out_dim=None,
-        dim_mults=(1, 2, 4, 8),
-        channels=3,
-        with_time_emb=True,
-        resnet_block_groups=8,
-        use_convnext=True,
-        convnext_mult=2,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.dim = dim
-        self.init_dim = init_dim
-        self.out_dim = out_dim
-        self.dim_mults = dim_mults
-        self.channels = channels
-        self.with_time_emb = with_time_emb
-        self.resnet_block_groups = resnet_block_groups
-        self.use_convnext = use_convnext
-        self.convnext_mult = convnext_mult
-
-
-class Unet(PreTrainedModel):
-    config_class = UnetConfig
-
-    def __init__(
-        self,
-        config,
-    ):
-        super().__init__(config)
-
-        # determine dimensions
-        self.channels = (
-            config.channels
-        )  # since we are concatenating the images and the conditionings along the channel dimension
-
-        init_dim = default(config.init_dim, config.dim // 3 * 2)
-        self.init_conv = nn.Conv2d(self.channels * 2, init_dim, 7, padding=3)
-        self.conditioning_init = nn.Conv2d(self.channels, init_dim, 7, padding=3)
-        dims = [init_dim, *map(lambda m: config.dim * m, config.dim_mults)]
-        in_out = list(zip(dims[:-1], dims[1:]))
-        self.in_out = in_out
-
-        if config.use_convnext:
-            block_klass = partial(ConvNextBlock, mult=config.convnext_mult)
-        else:
-            block_klass = partial(ResnetBlock, groups=config.resnet_block_groups)
-
-        # time embeddings
-        if config.with_time_emb:
-            time_dim = config.dim * 4
-            self.time_mlp = nn.Sequential(
-                SinusoidalPositionEmbeddings(config.dim),
-                nn.Linear(config.dim, time_dim),
-                nn.GELU(),
-                nn.Linear(time_dim, time_dim),
-            )
-        else:
-            time_dim = None
-            self.time_mlp = None
-
-        # layers
-        self.downs = nn.ModuleList([])
-        self.ups = nn.ModuleList([])
-        self.conditioning_encoder = nn.ModuleList([])
-        num_resolutions = len(in_out)
-        self.num_resolutions = num_resolutions
-
-        # conditioning encoder
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.conditioning_encoder.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.downs.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
-                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        mid_dim = dims[-1]
-
-        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-        self.cross_attention_1 = Residual(
-            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
-        )
-        self.cross_attention_2 = Residual(
-            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
-        )
-        self.cross_attention_3 = Residual(
-            PreNorm(mid_dim, LinearCrossAttention(mid_dim))
-        )
-        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            is_last = ind >= (num_resolutions - 1)
-            self.ups.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
-                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
-                        Upsample(dim_in) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        out_dim = default(config.out_dim, config.channels)
-        self.final_conv = nn.Sequential(
-            block_klass(config.dim, config.dim), nn.Conv2d(config.dim, out_dim, 1)
-        )
-
-    def forward(self, x, time, implicit_conditioning, explicit_conditioning):
-        x = torch.cat((x, explicit_conditioning), dim=1)
-
-        x = self.init_conv(x)
-
-        conditioning = self.conditioning_init(implicit_conditioning)
-
-        t = self.time_mlp(time) if exists(self.time_mlp) else None
-
-        h = []
-
-        # conditioning encoder
-
-        for block1, attn, downsample in self.conditioning_encoder:
-            conditioning = block1(conditioning)
-            conditioning = attn(conditioning)
-            conditioning = downsample(conditioning)
-
-        for block1, block2, attn, downsample in self.downs:
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            h.append(x)
-            x = downsample(x)
-
-        # reverse the c list
-
-        # bottleneck
-
-        x = self.cross_attention_1(x, conditioning)
-        x = self.mid_block1(x, t)
-        x = self.cross_attention_2(x, conditioning)
-        x = self.mid_block2(x, t)
-        x = self.cross_attention_3(x, conditioning)
-
-        for block1, block2, attn, upsample in self.ups:
-            x = torch.cat((x, h.pop()), dim=1)
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            x = upsample(x)
-
-        return self.final_conv(x)
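The deleted hf_compatible_model.py wrapped the Unet in the standard transformers config/model pairing (a PretrainedConfig subclass with a model_type, and a PreTrainedModel subclass pointing at it via config_class). For readers unfamiliar with that pattern, a minimal self-contained sketch, using toy names and a toy linear layer rather than the repository's model:

from transformers import PretrainedConfig, PreTrainedModel
from torch import nn


class ToyConfig(PretrainedConfig):
    model_type = "toy"  # registers the config under a model-type string

    def __init__(self, dim=64, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim


class ToyModel(PreTrainedModel):
    config_class = ToyConfig  # ties the model to its config class

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(config.dim, config.dim)

    def forward(self, x):
        return self.proj(x)


model = ToyModel(ToyConfig())
model.save_pretrained("toy-model")  # writes config.json plus the weights file
reloaded = ToyModel.from_pretrained("toy-model")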
requirements.txt CHANGED
@@ -1,14 +1,10 @@
 einops
 datasets
-matplotlib
 tqdm
 accelerate
-jax[cpu]
 torchinfo
-wandb
-ema_pytorch
-lpips
-pyyaml
 diffusers
 transformers
-torch-ema
+pathlib
+safetensors
+
results/sample.png CHANGED