Abdullah-Nazhat commited on
Commit
c1b3fbf
1 Parent(s): a24d7a2

Upload 2 files

Browse files
Files changed (2) hide show
  1. approximator.py +271 -0
  2. train.py +195 -0
approximator.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch import nn, einsum
4
+ from torch.nn import functional as F
5
+ from einops.layers.torch import Rearrange
6
+ from einops import rearrange, reduce
7
+ from math import ceil
8
+
9
+
10
+ class FeedForward(nn.Module):
11
+ def __init__(self, dim, hidden_dim, dropout):
12
+ super().__init__()
13
+ self.net = nn.Sequential(
14
+ nn.Linear(dim, hidden_dim),
15
+ nn.GELU(),
16
+ nn.Dropout(dropout),
17
+ nn.Linear(hidden_dim, dim),
18
+ nn.Dropout(dropout)
19
+ )
20
+ def forward(self, x):
21
+ return self.net(x)
22
+
23
+
24
+ # helper functions
25
+
26
+ def exists(val):
27
+ return val is not None
28
+
29
+ def moore_penrose_iter_pinv(x, iters = 6):
30
+ device = x.device
31
+
32
+ abs_x = torch.abs(x)
33
+ col = abs_x.sum(dim = -1)
34
+ row = abs_x.sum(dim = -2)
35
+ z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))
36
+
37
+ I = torch.eye(x.shape[-1], device = device)
38
+ I = rearrange(I, 'i j -> () i j')
39
+
40
+ for _ in range(iters):
41
+ xz = x @ z
42
+ z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))
43
+
44
+ return z
45
+
46
+ # main attention class
47
+
48
+ class NystromAttention(nn.Module):
49
+ def __init__(
50
+ self,
51
+ dim,
52
+ dim_head = 64,
53
+ heads = 8,
54
+ num_landmarks = 256,
55
+ pinv_iterations = 6,
56
+ residual = True,
57
+ residual_conv_kernel = 33,
58
+ eps = 1e-8,
59
+ dropout = 0.
60
+ ):
61
+ super().__init__()
62
+ self.eps = eps
63
+ inner_dim = heads * dim_head
64
+
65
+ self.num_landmarks = num_landmarks
66
+ self.pinv_iterations = pinv_iterations
67
+
68
+ self.heads = heads
69
+ self.scale = dim_head ** -0.5
70
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
71
+
72
+ self.to_out = nn.Sequential(
73
+ nn.Linear(inner_dim, dim),
74
+ nn.Dropout(dropout)
75
+ )
76
+
77
+ self.residual = residual
78
+ if residual:
79
+ kernel_size = residual_conv_kernel
80
+ padding = residual_conv_kernel // 2
81
+ self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False)
82
+
83
+ def forward(self, x, mask = None, return_attn = False):
84
+ b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps
85
+
86
+ # pad so that sequence can be evenly divided into m landmarks
87
+
88
+ remainder = n % m
89
+ if remainder > 0:
90
+ padding = m - (n % m)
91
+ x = F.pad(x, (0, 0, padding, 0), value = 0)
92
+
93
+ if exists(mask):
94
+ mask = F.pad(mask, (padding, 0), value = False)
95
+
96
+ # derive query, keys, values
97
+
98
+ q, k, v = self.to_qkv(x).chunk(3, dim = -1)
99
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
100
+
101
+ # set masked positions to 0 in queries, keys, values
102
+
103
+ if exists(mask):
104
+ mask = rearrange(mask, 'b n -> b () n')
105
+ q, k, v = map(lambda t: t * mask[..., None], (q, k, v))
106
+
107
+ q = q * self.scale
108
+
109
+ # generate landmarks by sum reduction, and then calculate mean using the mask
110
+
111
+ l = ceil(n / m)
112
+ landmark_einops_eq = '... (n l) d -> ... n d'
113
+ q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
114
+ k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)
115
+
116
+ # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean
117
+
118
+ divisor = l
119
+ if exists(mask):
120
+ mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l)
121
+ divisor = mask_landmarks_sum[..., None] + eps
122
+ mask_landmarks = mask_landmarks_sum > 0
123
+
124
+ # masked mean (if mask exists)
125
+
126
+ q_landmarks /= divisor
127
+ k_landmarks /= divisor
128
+
129
+ # similarities
130
+
131
+ einops_eq = '... i d, ... j d -> ... i j'
132
+ sim1 = einsum(einops_eq, q, k_landmarks)
133
+ sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
134
+ sim3 = einsum(einops_eq, q_landmarks, k)
135
+
136
+ # masking
137
+
138
+ if exists(mask):
139
+ mask_value = -torch.finfo(q.dtype).max
140
+ sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
141
+ sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
142
+ sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)
143
+
144
+ # eq (15) in the paper and aggregate values
145
+
146
+ attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3))
147
+ attn2_inv = moore_penrose_iter_pinv(attn2, iters)
148
+
149
+ out = (attn1 @ attn2_inv) @ (attn3 @ v)
150
+
151
+ # add depth-wise conv residual of values
152
+
153
+ if self.residual:
154
+ out += self.res_conv(v)
155
+
156
+ # merge and combine heads
157
+
158
+ out = rearrange(out, 'b h n d -> b n (h d)', h = h)
159
+ out = self.to_out(out)
160
+ out = out[:, -n:]
161
+
162
+ if return_attn:
163
+ attn = attn1 @ attn2_inv @ attn3
164
+ return out, attn
165
+
166
+ return out
167
+
168
+
169
+
170
+
171
+
172
+
173
+
174
+ class NystromBlock(nn.Module):
175
+ def __init__(self,dim,dim_ffn, dropout):
176
+ super().__init__()
177
+ self.Nystrom = NystromAttention(
178
+ dim,
179
+ dim_head = 64,
180
+ heads = 4,
181
+ num_landmarks = 32,
182
+ pinv_iterations = 3,
183
+ residual = True,
184
+ residual_conv_kernel = 33,
185
+ eps = 1e-8,
186
+ dropout = dropout)
187
+ self.norm = nn.LayerNorm(dim)
188
+
189
+ self.ffn = FeedForward(dim,dim_ffn,dropout)
190
+
191
+ def forward(self, x):
192
+ res = x
193
+ x = self.norm(x)
194
+ x = self.Nystrom(x)
195
+ x = res + x
196
+ res = x
197
+ x = self.norm(x)
198
+ x = self.ffn(x)
199
+ out = x + res
200
+ return out
201
+
202
+
203
+
204
+
205
+ class ApproximatorGatingUnit(nn.Module):
206
+ def __init__(self,d_model,d_ffn,dropout):
207
+ super().__init__()
208
+ #self.proj = nn.Linear(d_model, d_model)
209
+ self.Approx_1 = NystromBlock(d_model,d_ffn,dropout)
210
+ self.Approx_2 = NystromBlock(d_model,d_ffn,dropout)
211
+
212
+
213
+
214
+
215
+ def forward(self, x):
216
+ u, v = x, x
217
+ u = self.Approx_1(u)
218
+ v = self.Approx_2(v)
219
+ out = u * v
220
+ return out
221
+
222
+
223
+ class ApproximatorBlock(nn.Module):
224
+ def __init__(self, d_model, d_ffn,dropout):
225
+ super().__init__()
226
+
227
+ self.norm = nn.LayerNorm(d_model)
228
+ self.agu = ApproximatorGatingUnit(d_model,d_ffn,dropout)
229
+ self.ffn = FeedForward(d_model,d_ffn,dropout)
230
+ def forward(self, x):
231
+ residual = x
232
+ x = self.norm(x)
233
+ x = self.agu(x)
234
+ x = x + residual
235
+ residual = x
236
+ x = self.norm(x)
237
+ x = self.ffn(x)
238
+ out = x + residual
239
+ return out
240
+
241
+
242
+
243
+
244
+
245
+
246
+
247
+
248
+
249
+ class Approximator(nn.Module):
250
+ def __init__(self, d_model, d_ffn, num_layers,dropout):
251
+ super().__init__()
252
+
253
+ self.model = nn.Sequential(
254
+
255
+ *[ApproximatorBlock(d_model,d_ffn,dropout) for _ in range(num_layers)],
256
+
257
+
258
+ )
259
+
260
+ def forward(self, x):
261
+
262
+ x = self.model(x)
263
+
264
+ return x
265
+
266
+
267
+
268
+
269
+
270
+
271
+
train.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #imports
2
+
3
+ import os
4
+ import csv
5
+ import torch
6
+ from torch import nn
7
+ from torch.utils.data import DataLoader
8
+ from torchvision import datasets
9
+ from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, Compose
10
+ from approximator import Approximator
11
+
12
+ # data transforms
13
+
14
+ transform = Compose([
15
+ RandomCrop(32, padding=4),
16
+ RandomHorizontalFlip(),
17
+ ToTensor(),
18
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
19
+
20
+ ])
21
+
22
+ training_data = datasets.CIFAR10(
23
+ root='data',
24
+ train=True,
25
+ download=True,
26
+ transform=transform
27
+ )
28
+
29
+ test_data = datasets.CIFAR10(
30
+ root='data',
31
+ train=False,
32
+ download=True,
33
+ transform=transform
34
+ )
35
+ # create dataloaders
36
+
37
+ batch_size = 128
38
+
39
+ train_dataloader = DataLoader(training_data, batch_size=batch_size,shuffle=True)
40
+ test_dataloader = DataLoader(test_data, batch_size=batch_size)
41
+
42
+
43
+ for X, y in test_dataloader:
44
+ print(f"Shape of X [N,C,H,W]:{X.shape}")
45
+ print(f"Shape of y:{y.shape}{y.dtype}")
46
+ break
47
+
48
+
49
+
50
+ # size checking for loading images
51
+ def check_sizes(image_size, patch_size):
52
+ sqrt_num_patches, remainder = divmod(image_size, patch_size)
53
+ assert remainder == 0, "`image_size` must be divisibe by `patch_size`"
54
+ num_patches = sqrt_num_patches ** 2
55
+ return num_patches
56
+
57
+
58
+
59
+ # create model
60
+ # Get cpu or gpu device for training.
61
+
62
+ device = "cuda" if torch.cuda.is_available() else "cpu"
63
+ print(f"using {device} device")
64
+
65
+ # model definition
66
+
67
+ class ApproximatorImageClassification(Approximator):
68
+ def __init__(
69
+ self,
70
+ image_size=32,
71
+ patch_size=4,
72
+ in_channels=3,
73
+ num_classes=10,
74
+ d_model=256,
75
+ d_ffn=512,
76
+ num_layers=4,
77
+ dropout=0.5
78
+ ):
79
+ num_patches = check_sizes(image_size, patch_size)
80
+ super().__init__(d_model, d_ffn, num_layers,dropout)
81
+ self.patcher = nn.Conv2d(
82
+ in_channels, d_model, kernel_size=patch_size, stride=patch_size
83
+ )
84
+ self.classifier = nn.Linear(d_model, num_classes)
85
+
86
+ def forward(self, x):
87
+
88
+ patches = self.patcher(x)
89
+ batch_size, num_channels, _, _ = patches.shape
90
+ patches = patches.permute(0, 2, 3, 1)
91
+ patches = patches.view(batch_size, -1, num_channels)
92
+ embedding = self.model(patches)
93
+ embedding = embedding.mean(dim=1) # global average pooling
94
+ out = self.classifier(embedding)
95
+ return out
96
+
97
+ model = ApproximatorImageClassification().to(device)
98
+ print(model)
99
+
100
+ # Optimizer
101
+
102
+ loss_fn = nn.CrossEntropyLoss()
103
+ optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
104
+
105
+
106
+ # Training Loop
107
+
108
+ def train(dataloader, model, loss_fn, optimizer):
109
+ size = len(dataloader.dataset)
110
+ num_batches = len(dataloader)
111
+ model.train()
112
+ train_loss = 0
113
+ correct = 0
114
+ for batch, (X,y) in enumerate(dataloader):
115
+ X, y = X.to(device), y.to(device)
116
+
117
+ #compute prediction error
118
+ pred = model(X)
119
+ loss = loss_fn(pred,y)
120
+
121
+ # backpropagation
122
+ optimizer.zero_grad()
123
+ loss.backward()
124
+ optimizer.step()
125
+ train_loss += loss.item()
126
+ _, labels = torch.max(pred.data, 1)
127
+ correct += labels.eq(y.data).type(torch.float).sum()
128
+
129
+
130
+
131
+
132
+ if batch % 100 == 0:
133
+ loss, current = loss.item(), batch * len(X)
134
+ print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
135
+
136
+ train_loss /= num_batches
137
+ train_accuracy = 100. * correct.item() / size
138
+ print(train_accuracy)
139
+ return train_loss,train_accuracy
140
+
141
+
142
+
143
+ # Test loop
144
+
145
+ def test(dataloader, model, loss_fn):
146
+ size = len(dataloader.dataset)
147
+ num_batches = len(dataloader)
148
+ model.eval()
149
+ test_loss = 0
150
+ correct = 0
151
+ with torch.no_grad():
152
+ for X,y in dataloader:
153
+ X,y = X.to(device), y.to(device)
154
+ pred = model(X)
155
+ test_loss += loss_fn(pred, y).item()
156
+ correct += (pred.argmax(1) == y).type(torch.float).sum().item()
157
+ test_loss /= num_batches
158
+ correct /= size
159
+ print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
160
+ test_accuracy = 100*correct
161
+ return test_loss, test_accuracy
162
+
163
+
164
+
165
+ # apply train and test
166
+
167
+ logname = "/home/abdullah/Desktop/Proposals_experiments/Approximator/Experiments_cifar10/logs_approximator/logs_cifar10.csv"
168
+ if not os.path.exists(logname):
169
+ with open(logname, 'w') as logfile:
170
+ logwriter = csv.writer(logfile, delimiter=',')
171
+ logwriter.writerow(['epoch', 'train loss', 'train acc',
172
+ 'test loss', 'test acc'])
173
+
174
+
175
+ epochs = 100
176
+ for epoch in range(epochs):
177
+ print(f"Epoch {epoch+1}\n-----------------------------------")
178
+ train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
179
+ # learning rate scheduler
180
+ #if scheduler is not None:
181
+ # scheduler.step()
182
+ test_loss, test_acc = test(test_dataloader, model, loss_fn)
183
+ with open(logname, 'a') as logfile:
184
+ logwriter = csv.writer(logfile, delimiter=',')
185
+ logwriter.writerow([epoch+1, train_loss, train_acc,
186
+ test_loss, test_acc])
187
+ print("Done!")
188
+
189
+ # saving trained model
190
+
191
+ path = "/home/abdullah/Desktop/Proposals_experiments/Approximator/Experiments_cifar10/weights_approximator"
192
+ model_name = "ApproximatorImageClassification_cifar10"
193
+ torch.save(model.state_dict(), f"{path}/{model_name}.pth")
194
+ print(f"Saved Model State to {path}/{model_name}.pth ")
195
+