debisoft commited on
Commit
442c1b2
1 Parent(s): 79eb231
Files changed (2) hide show
  1. diffusion_utilities.py +250 -0
  2. requirements.txt +1 -0
diffusion_utilities.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from torchvision.utils import save_image, make_grid
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.animation import FuncAnimation, PillowWriter
7
+ import os
8
+ import torchvision.transforms as transforms
9
+ from torch.utils.data import Dataset
10
+ from PIL import Image
11
+
12
+
13
+ class ResidualConvBlock(nn.Module):
14
+ def __init__(
15
+ self, in_channels: int, out_channels: int, is_res: bool = False
16
+ ) -> None:
17
+ super().__init__()
18
+
19
+ # Check if input and output channels are the same for the residual connection
20
+ self.same_channels = in_channels == out_channels
21
+
22
+ # Flag for whether or not to use residual connection
23
+ self.is_res = is_res
24
+
25
+ # First convolutional layer
26
+ self.conv1 = nn.Sequential(
27
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1
28
+ nn.BatchNorm2d(out_channels), # Batch normalization
29
+ nn.GELU(), # GELU activation function
30
+ )
31
+
32
+ # Second convolutional layer
33
+ self.conv2 = nn.Sequential(
34
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1
35
+ nn.BatchNorm2d(out_channels), # Batch normalization
36
+ nn.GELU(), # GELU activation function
37
+ )
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+
41
+ # If using residual connection
42
+ if self.is_res:
43
+ # Apply first convolutional layer
44
+ x1 = self.conv1(x)
45
+
46
+ # Apply second convolutional layer
47
+ x2 = self.conv2(x1)
48
+
49
+ # If input and output channels are the same, add residual connection directly
50
+ if self.same_channels:
51
+ out = x + x2
52
+ else:
53
+ # If not, apply a 1x1 convolutional layer to match dimensions before adding residual connection
54
+ shortcut = nn.Conv2d(x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0).to(x.device)
55
+ out = shortcut(x) + x2
56
+ #print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")
57
+
58
+ # Normalize output tensor
59
+ return out / 1.414
60
+
61
+ # If not using residual connection, return output of second convolutional layer
62
+ else:
63
+ x1 = self.conv1(x)
64
+ x2 = self.conv2(x1)
65
+ return x2
66
+
67
+ # Method to get the number of output channels for this block
68
+ def get_out_channels(self):
69
+ return self.conv2[0].out_channels
70
+
71
+ # Method to set the number of output channels for this block
72
+ def set_out_channels(self, out_channels):
73
+ self.conv1[0].out_channels = out_channels
74
+ self.conv2[0].in_channels = out_channels
75
+ self.conv2[0].out_channels = out_channels
76
+
77
+
78
+
79
+ class UnetUp(nn.Module):
80
+ def __init__(self, in_channels, out_channels):
81
+ super(UnetUp, self).__init__()
82
+
83
+ # Create a list of layers for the upsampling block
84
+ # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers
85
+ layers = [
86
+ nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
87
+ ResidualConvBlock(out_channels, out_channels),
88
+ ResidualConvBlock(out_channels, out_channels),
89
+ ]
90
+
91
+ # Use the layers to create a sequential model
92
+ self.model = nn.Sequential(*layers)
93
+
94
+ def forward(self, x, skip):
95
+ # Concatenate the input tensor x with the skip connection tensor along the channel dimension
96
+ x = torch.cat((x, skip), 1)
97
+
98
+ # Pass the concatenated tensor through the sequential model and return the output
99
+ x = self.model(x)
100
+ return x
101
+
102
+
103
+ class UnetDown(nn.Module):
104
+ def __init__(self, in_channels, out_channels):
105
+ super(UnetDown, self).__init__()
106
+
107
+ # Create a list of layers for the downsampling block
108
+ # Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling
109
+ layers = [ResidualConvBlock(in_channels, out_channels), ResidualConvBlock(out_channels, out_channels), nn.MaxPool2d(2)]
110
+
111
+ # Use the layers to create a sequential model
112
+ self.model = nn.Sequential(*layers)
113
+
114
+ def forward(self, x):
115
+ # Pass the input through the sequential model and return the output
116
+ return self.model(x)
117
+
118
+ class EmbedFC(nn.Module):
119
+ def __init__(self, input_dim, emb_dim):
120
+ super(EmbedFC, self).__init__()
121
+ '''
122
+ This class defines a generic one layer feed-forward neural network for embedding input data of
123
+ dimensionality input_dim to an embedding space of dimensionality emb_dim.
124
+ '''
125
+ self.input_dim = input_dim
126
+
127
+ # define the layers for the network
128
+ layers = [
129
+ nn.Linear(input_dim, emb_dim),
130
+ nn.GELU(),
131
+ nn.Linear(emb_dim, emb_dim),
132
+ ]
133
+
134
+ # create a PyTorch sequential model consisting of the defined layers
135
+ self.model = nn.Sequential(*layers)
136
+
137
+ def forward(self, x):
138
+ # flatten the input tensor
139
+ x = x.view(-1, self.input_dim)
140
+ # apply the model layers to the flattened tensor
141
+ return self.model(x)
142
+
143
+ def unorm(x):
144
+ # unity norm. results in range of [0,1]
145
+ # assume x (h,w,3)
146
+ xmax = x.max((0,1))
147
+ xmin = x.min((0,1))
148
+ return(x - xmin)/(xmax - xmin)
149
+
150
+ def norm_all(store, n_t, n_s):
151
+ # runs unity norm on all timesteps of all samples
152
+ nstore = np.zeros_like(store)
153
+ for t in range(n_t):
154
+ for s in range(n_s):
155
+ nstore[t,s] = unorm(store[t,s])
156
+ return nstore
157
+
158
+ def norm_torch(x_all):
159
+ # runs unity norm on all timesteps of all samples
160
+ # input is (n_samples, 3,h,w), the torch image format
161
+ x = x_all.cpu().numpy()
162
+ xmax = x.max((2,3))
163
+ xmin = x.min((2,3))
164
+ xmax = np.expand_dims(xmax,(2,3))
165
+ xmin = np.expand_dims(xmin,(2,3))
166
+ nstore = (x - xmin)/(xmax - xmin)
167
+ return torch.from_numpy(nstore)
168
+
169
+ def gen_tst_context(n_cfeat):
170
+ """
171
+ Generate test context vectors
172
+ """
173
+ vec = torch.tensor([
174
+ [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing
175
+ [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing
176
+ [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing
177
+ [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing
178
+ [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0], # human, non-human, food, spell, side-facing
179
+ [1,0,0,0,0], [0,1,0,0,0], [0,0,1,0,0], [0,0,0,1,0], [0,0,0,0,1], [0,0,0,0,0]] # human, non-human, food, spell, side-facing
180
+ )
181
+ return len(vec), vec
182
+
183
+ def plot_grid(x,n_sample,n_rows,save_dir,w):
184
+ # x:(n_sample, 3, h, w)
185
+ ncols = n_sample//n_rows
186
+ grid = make_grid(norm_torch(x), nrow=ncols) # curiously, nrow is number of columns.. or number of items in the row.
187
+ save_image(grid, save_dir + f"run_image_w{w}.png")
188
+ print('saved image at ' + save_dir + f"run_image_w{w}.png")
189
+ return grid
190
+
191
+ def plot_sample(x_gen_store,n_sample,nrows,save_dir, fn, w, save=False):
192
+ ncols = n_sample//nrows
193
+ sx_gen_store = np.moveaxis(x_gen_store,2,4) # change to Numpy image format (h,w,channels) vs (channels,h,w)
194
+ nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample) # unity norm to put in range [0,1] for np.imshow
195
+
196
+ # create gif of images evolving over time, based on x_gen_store
197
+ fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,figsize=(ncols,nrows))
198
+ def animate_diff(i, store):
199
+ print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
200
+ plots = []
201
+ for row in range(nrows):
202
+ for col in range(ncols):
203
+ axs[row, col].clear()
204
+ axs[row, col].set_xticks([])
205
+ axs[row, col].set_yticks([])
206
+ plots.append(axs[row, col].imshow(store[i,(row*ncols)+col]))
207
+ return plots
208
+ ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200, blit=False, repeat=True, frames=nsx_gen_store.shape[0])
209
+ plt.close()
210
+ if save:
211
+ ani.save(save_dir + f"{fn}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
212
+ print('saved gif at ' + save_dir + f"{fn}_w{w}.gif")
213
+ return ani
214
+
215
+
216
+ class CustomDataset(Dataset):
217
+ def __init__(self, sfilename, lfilename, transform, null_context=False):
218
+ self.sprites = np.load(sfilename)
219
+ self.slabels = np.load(lfilename)
220
+ print(f"sprite shape: {self.sprites.shape}")
221
+ print(f"labels shape: {self.slabels.shape}")
222
+ self.transform = transform
223
+ self.null_context = null_context
224
+ self.sprites_shape = self.sprites.shape
225
+ self.slabel_shape = self.slabels.shape
226
+
227
+ # Return the number of images in the dataset
228
+ def __len__(self):
229
+ return len(self.sprites)
230
+
231
+ # Get the image and label at a given index
232
+ def __getitem__(self, idx):
233
+ # Return the image and label as a tuple
234
+ if self.transform:
235
+ image = self.transform(self.sprites[idx])
236
+ if self.null_context:
237
+ label = torch.tensor(0).to(torch.int64)
238
+ else:
239
+ label = torch.tensor(self.slabels[idx]).to(torch.int64)
240
+ return (image, label)
241
+
242
+ def getshapes(self):
243
+ # return shapes of data and labels
244
+ return self.sprites_shape, self.slabel_shape
245
+
246
+ transform = transforms.Compose([
247
+ transforms.ToTensor(), # from [0,255] to range [0.0,1.0]
248
+ transforms.Normalize((0.5,), (0.5,)) # range [-1,1]
249
+
250
+ ])
requirements.txt CHANGED
@@ -1,2 +1,3 @@
 
1
  torch
2
  torchvision
 
1
+ ipython
2
  torch
3
  torchvision