- diffusion_utilities.py +250 -0
- requirements.txt +1 -0
diffusion_utilities.py
ADDED
@@ -0,0 +1,250 @@
import torch
import torch.nn as nn
import numpy as np
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import os
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image


class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()

        # Check if input and output channels are the same for the residual connection
        self.same_channels = in_channels == out_channels

        # Flag for whether or not to use residual connection
        self.is_res = is_res

        # First convolutional layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),   # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),   # Batch normalization
            nn.GELU(),   # GELU activation function
        )

        # Second convolutional layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),   # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),   # Batch normalization
            nn.GELU(),   # GELU activation function
        )

        # 1x1 convolution to match dimensions for the residual connection.
        # Created once here (rather than inside forward) so its weights are
        # registered with the module and actually get trained.
        if self.is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If using residual connection
        if self.is_res:
            # Apply first convolutional layer
            x1 = self.conv1(x)

            # Apply second convolutional layer
            x2 = self.conv2(x1)

            # If input and output channels are the same, add residual connection directly
            if self.same_channels:
                out = x + x2
            else:
                # Otherwise, project the input through the 1x1 convolution before adding
                out = self.shortcut(x) + x2
            # print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")

            # Normalize output tensor (divide by sqrt(2) to keep the variance of the sum stable)
            return out / 1.414

        # If not using residual connection, return output of second convolutional layer
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2

    # Method to get the number of output channels for this block
    def get_out_channels(self):
        return self.conv2[0].out_channels

    # Method to set the number of output channels for this block
    # (note: this only updates the stored attributes; existing weights are not resized)
    def set_out_channels(self, out_channels):
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels


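# Illustrative shape check (hypothetical sizes, not part of the training code):
#   block = ResidualConvBlock(3, 64, is_res=True)
#   y = block(torch.randn(1, 3, 16, 16))   # -> (1, 64, 16, 16), added via the 1x1 shortcut

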
class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()

        # Create a list of layers for the upsampling block
        # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        # Concatenate the input tensor x with the skip connection tensor along the channel dimension
        x = torch.cat((x, skip), 1)

        # Pass the concatenated tensor through the sequential model and return the output
        x = self.model(x)
        return x


class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()

        # Create a list of layers for the downsampling block
        # Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling
        layers = [
            ResidualConvBlock(in_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Pass the input through the sequential model and return the output
        return self.model(x)

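# Illustrative shape check (hypothetical sizes, not part of the training code):
# a down stage halves the spatial resolution, and the matching up stage
# consumes the concatenation of the upsampled features and the skip tensor:
#   x0   = torch.randn(1, 64, 16, 16)
#   down = UnetDown(64, 128)
#   x1   = down(x0)            # -> (1, 128, 8, 8)
#   up   = UnetUp(256, 128)    # 256 = 128 (x1) + 128 (skip) after torch.cat
#   x2   = up(x1, x1)          # -> (1, 128, 16, 16)
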
class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        This class defines a generic feed-forward neural network for embedding input data of
        dimensionality input_dim into an embedding space of dimensionality emb_dim.
        '''
        self.input_dim = input_dim

        # define the layers for the network
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]

        # create a PyTorch sequential model consisting of the defined layers
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # flatten the input tensor
        x = x.view(-1, self.input_dim)
        # apply the model layers to the flattened tensor
        return self.model(x)

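# Illustrative usage (hypothetical values): embed a batch of scalar timesteps
# into a 128-dimensional vector, e.g. for conditioning a U-Net:
#   temb = EmbedFC(1, 128)
#   t    = torch.rand(4, 1)    # 4 normalized timesteps in [0, 1)
#   e    = temb(t)             # -> (4, 128)
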
def unorm(x):
    # unity norm: rescales x to the range [0,1]
    # assumes x has shape (h, w, 3)
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))
    return (x - xmin) / (xmax - xmin)

def norm_all(store, n_t, n_s):
    # runs unity norm on all timesteps of all samples
    nstore = np.zeros_like(store)
    for t in range(n_t):
        for s in range(n_s):
            nstore[t, s] = unorm(store[t, s])
    return nstore

def norm_torch(x_all):
    # runs unity norm on all timesteps of all samples
    # input is (n_samples, 3, h, w), the torch image format
    x = x_all.cpu().numpy()
    xmax = x.max((2, 3))
    xmin = x.min((2, 3))
    xmax = np.expand_dims(xmax, (2, 3))
    xmin = np.expand_dims(xmin, (2, 3))
    nstore = (x - xmin) / (xmax - xmin)
    return torch.from_numpy(nstore)

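# Illustrative usage (hypothetical values): rescale a batch of generated
# images so each channel lies in [0, 1] before saving or plotting:
#   imgs = torch.randn(8, 3, 16, 16)
#   viz  = norm_torch(imgs)    # -> (8, 3, 16, 16), per-channel range [0, 1]
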
def gen_tst_context(n_cfeat):
    """
    Generate test context vectors
    """
    vec = torch.tensor([
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
        [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0],  # human, non-human, food, spell, side-facing
    ])
    return len(vec), vec

|
183 |
+
def plot_grid(x,n_sample,n_rows,save_dir,w):
|
184 |
+
# x:(n_sample, 3, h, w)
|
185 |
+
ncols = n_sample//n_rows
|
186 |
+
grid = make_grid(norm_torch(x), nrow=ncols) # curiously, nrow is number of columns.. or number of items in the row.
|
187 |
+
save_image(grid, save_dir + f"run_image_w{w}.png")
|
188 |
+
print('saved image at ' + save_dir + f"run_image_w{w}.png")
|
189 |
+
return grid
|
190 |
+
|
def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    ncols = n_sample // nrows
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)                              # change to Numpy image format (h,w,channels) vs (channels,h,w)
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)    # unity norm to put in range [0,1] for plt.imshow

    # create gif of images evolving over time, based on x_gen_store
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, figsize=(ncols, nrows))
    def animate_diff(i, store):
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                axs[row, col].clear()
                axs[row, col].set_xticks([])
                axs[row, col].set_yticks([])
                plots.append(axs[row, col].imshow(store[i, (row * ncols) + col]))
        return plots
    ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200, blit=False, repeat=True, frames=nsx_gen_store.shape[0])
    plt.close()
    if save:
        ani.save(save_dir + f"{fn}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + save_dir + f"{fn}_w{w}.gif")
    return ani


class CustomDataset(Dataset):
    def __init__(self, sfilename, lfilename, transform, null_context=False):
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    # Return the number of images in the dataset
    def __len__(self):
        return len(self.sprites)

    # Get the image and label at a given index
    def __getitem__(self, idx):
        # Return the image and label as a tuple
        # (fall back to the raw array if no transform is given, so image is always defined)
        image = self.sprites[idx]
        if self.transform:
            image = self.transform(image)
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    def getshapes(self):
        # return shapes of data and labels
        return self.sprites_shape, self.slabel_shape

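# Illustrative usage (file names are hypothetical placeholders, not the actual
# dataset shipped with this Space):
#   dataset = CustomDataset("sprites.npy", "sprite_labels.npy", transform)
#   image, label = dataset[0]   # image: transformed tensor, label: int64 tensor
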
transform = transforms.Compose([
    transforms.ToTensor(),                  # from [0,255] to range [0.0,1.0]
    transforms.Normalize((0.5,), (0.5,)),   # range [-1,1]
])
requirements.txt
CHANGED
@@ -1,2 +1,3 @@
+ipython
 torch
 torchvision