Jyothirmai committed on
Commit
cc8f2c4
1 Parent(s): 74ff854

Delete models_debugger.py

Files changed (1)
  1. models_debugger.py +0 -816
models_debugger.py DELETED
@@ -1,816 +0,0 @@
- import torch
- import torch.nn as nn
- import torchvision
- import numpy as np
- from torch.autograd import Variable
- from torchvision.models.vgg import model_urls as vgg_model_urls
- import torchvision.models as models
-
- from utils.tcn import *
-
-
- class DenseNet121(nn.Module):
-     def __init__(self, classes=14, pretrained=True):
-         super(DenseNet121, self).__init__()
-         self.model = torchvision.models.densenet121(pretrained=pretrained)
-         num_in_features = self.model.classifier.in_features
-         self.model.classifier = nn.Sequential(
-             nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
-             # nn.Sigmoid()
-         )
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class DenseNet161(nn.Module):
-     def __init__(self, classes=156, pretrained=True):
-         super(DenseNet161, self).__init__()
-         self.model = torchvision.models.densenet161(pretrained=pretrained)
-         num_in_features = self.model.classifier.in_features
-         self.model.classifier = nn.Sequential(
-             self.__init_linear(in_features=num_in_features, out_features=classes),
-             # nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class DenseNet169(nn.Module):
-     def __init__(self, classes=156, pretrained=True):
-         super(DenseNet169, self).__init__()
-         self.model = torchvision.models.densenet169(pretrained=pretrained)
-         num_in_features = self.model.classifier.in_features
-         self.model.classifier = nn.Sequential(
-             self.__init_linear(in_features=num_in_features, out_features=classes),
-             # nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class DenseNet201(nn.Module):
-     def __init__(self, classes=156, pretrained=True):
-         super(DenseNet201, self).__init__()
-         self.model = torchvision.models.densenet201(pretrained=pretrained)
-         num_in_features = self.model.classifier.in_features
-         self.model.classifier = nn.Sequential(
-             self.__init_linear(in_features=num_in_features, out_features=classes),
-             nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return sigmoid class probabilities for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class ResNet18(nn.Module):
-     def __init__(self, classes=156, pretrained=True):
-         super(ResNet18, self).__init__()
-         self.model = torchvision.models.resnet18(pretrained=pretrained)
-         num_in_features = self.model.fc.in_features
-         self.model.fc = nn.Sequential(
-             self.__init_linear(in_features=num_in_features, out_features=classes),
-             # nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class ResNet34(nn.Module):
-     def __init__(self, classes=156, pretrained=True):
-         super(ResNet34, self).__init__()
-         self.model = torchvision.models.resnet34(pretrained=pretrained)
-         num_in_features = self.model.fc.in_features
-         self.model.fc = nn.Sequential(
-             self.__init_linear(in_features=num_in_features, out_features=classes),
-             # nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class ResNet50(nn.Module):
-     def __init__(self, classes=156, pretrained=True):
-         super(ResNet50, self).__init__()
-         self.model = torchvision.models.resnet50(pretrained=pretrained)
-         num_in_features = self.model.fc.in_features
-         self.model.fc = nn.Sequential(
-             self.__init_linear(in_features=num_in_features, out_features=classes),
-             # nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class ResNet101(nn.Module):
-     def __init__(self, classes=156, pretrained=True):
-         super(ResNet101, self).__init__()
-         self.model = torchvision.models.resnet101(pretrained=pretrained)
-         num_in_features = self.model.fc.in_features
-         self.model.fc = nn.Sequential(
-             self.__init_linear(in_features=num_in_features, out_features=classes),
-             # nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class ResNet152(nn.Module):
-     def __init__(self, classes=156, pretrained=True):
-         super(ResNet152, self).__init__()
-         self.model = torchvision.models.resnet152(pretrained=pretrained)
-         num_in_features = self.model.fc.in_features
-         self.model.fc = nn.Sequential(
-             self.__init_linear(in_features=num_in_features, out_features=classes),
-             # nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class VGG19(nn.Module):
-     def __init__(self, classes=14, pretrained=True):
-         super(VGG19, self).__init__()
-         self.model = torchvision.models.vgg19_bn(pretrained=pretrained)
-         self.model.classifier = nn.Sequential(
-             self.__init_linear(in_features=25088, out_features=4096),
-             nn.ReLU(),
-             nn.Dropout(0.5),
-             self.__init_linear(in_features=4096, out_features=4096),
-             nn.ReLU(),
-             nn.Dropout(0.5),
-             self.__init_linear(in_features=4096, out_features=classes),
-             # nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
- class VGG(nn.Module):
-     def __init__(self, tags_num):
-         super(VGG, self).__init__()
-         vgg_model_urls['vgg19'] = vgg_model_urls['vgg19'].replace('https://', 'http://')
-         self.vgg19 = models.vgg19(pretrained=True)
-         vgg19_classifier = list(self.vgg19.classifier.children())[:-1]
-         self.classifier = nn.Sequential(*vgg19_classifier)
-         self.fc = nn.Linear(4096, tags_num)
-         self.fc.apply(self.init_weights)
-         self.bn = nn.BatchNorm1d(tags_num, momentum=0.1)
-         # self.init_weights()
-
-     def init_weights(self, m):
-         if type(m) == nn.Linear:
-             self.fc.weight.data.normal_(0, 0.1)
-             self.fc.bias.data.fill_(0)
-
-     def forward(self, images) -> object:
-         """Return tag scores for a batch of images."""
-         visual_feats = self.vgg19.features(images)
-         tags_classifier = visual_feats.view(visual_feats.size(0), -1)
-         tags_classifier = self.bn(self.fc(self.classifier(tags_classifier)))
-         return tags_classifier
-
-
- class InceptionV3(nn.Module):
-     def __init__(self, classes=156, pretrained=True):
-         super(InceptionV3, self).__init__()
-         self.model = torchvision.models.inception_v3(pretrained=pretrained)
-         # Inception v3 exposes its final layer as `fc`, not `classifier`
-         num_in_features = self.model.fc.in_features
-         self.model.fc = nn.Sequential(
-             self.__init_linear(in_features=num_in_features, out_features=classes),
-             # nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return class logits for a batch of images."""
-         x = self.model(x)
-         return x
-
-
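Note: torchvision's Inception v3 expects 299x299 inputs and, in train() mode, also returns auxiliary logits. A minimal smoke-test sketch (assuming this deleted module is importable as models_debugger):

    import torch
    from models_debugger import InceptionV3

    model = InceptionV3(classes=156, pretrained=True)
    model.eval()  # in train() mode inception_v3 would also return aux logits

    with torch.no_grad():
        logits = model(torch.randn(2, 3, 299, 299))  # Inception v3 wants 299x299
    print(logits.shape)  # torch.Size([2, 156])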
- class CheXNetDenseNet121(nn.Module):
-     def __init__(self, classes=14, pretrained=True):
-         super(CheXNetDenseNet121, self).__init__()
-         self.densenet121 = torchvision.models.densenet121(pretrained=pretrained)
-         num_in_features = self.densenet121.classifier.in_features
-         self.densenet121.classifier = nn.Sequential(
-             nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
-             nn.Sigmoid()
-         )
-
-     def forward(self, x) -> object:
-         """Return sigmoid class probabilities for a batch of images."""
-         x = self.densenet121(x)
-         return x
-
-
- class CheXNet(nn.Module):
-     def __init__(self, classes=156):
-         super(CheXNet, self).__init__()
-         self.densenet121 = CheXNetDenseNet121(classes=14)
-         self.densenet121 = torch.nn.DataParallel(self.densenet121).cuda()
-         self.densenet121.load_state_dict(torch.load('./models/CheXNet.pth.tar')['state_dict'])
-         self.densenet121.module.densenet121.classifier = nn.Sequential(
-             self.__init_linear(1024, classes),
-             nn.Sigmoid()
-         )
-
-     def __init_linear(self, in_features, out_features):
-         func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
-         func.weight.data.normal_(0, 0.1)
-         return func
-
-     def forward(self, x) -> object:
-         """Return sigmoid class probabilities for a batch of images."""
-         x = self.densenet121(x)
-         return x
-
-
- class ModelFactory(object):
-     def __init__(self, model_name, pretrained, classes):
-         self.model_name = model_name
-         self.pretrained = pretrained
-         self.classes = classes
-
-     def create_model(self):
-         if self.model_name == 'VGG19':
-             _model = VGG19(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'DenseNet121':
-             _model = DenseNet121(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'DenseNet161':
-             _model = DenseNet161(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'DenseNet169':
-             _model = DenseNet169(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'DenseNet201':
-             _model = DenseNet201(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'CheXNet':
-             _model = CheXNet(classes=self.classes)
-         elif self.model_name == 'ResNet18':
-             _model = ResNet18(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'ResNet34':
-             _model = ResNet34(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'ResNet50':
-             _model = ResNet50(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'ResNet101':
-             _model = ResNet101(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'ResNet152':
-             _model = ResNet152(pretrained=self.pretrained, classes=self.classes)
-         elif self.model_name == 'VGG':
-             _model = VGG(tags_num=self.classes)
-         else:
-             _model = CheXNet(classes=self.classes)
-
-         return _model
-
-
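For reference, a minimal sketch of how this factory was typically driven (hypothetical usage; 'ResNet18' is just one of the registered names, and pretrained=True downloads torchvision weights):

    import torch
    from models_debugger import ModelFactory

    # Build a 14-class classifier through the factory.
    factory = ModelFactory(model_name='ResNet18', pretrained=True, classes=14)
    model = factory.create_model()

    # One forward pass on a dummy batch: (N, 3, 224, 224) -> (N, 14) logits.
    logits = model(torch.randn(2, 3, 224, 224))
    print(logits.shape)  # torch.Size([2, 14])

Note that unrecognized names fall through to the CheXNet branch, which requires ./models/CheXNet.pth.tar and a CUDA device.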
- class EncoderCNN(nn.Module):
-     def __init__(self, embed_size, pretrained=True):
-         super(EncoderCNN, self).__init__()
-         # TODO Extract Image features from CNN based on other models
-         resnet = models.resnet152(pretrained=pretrained)
-         modules = list(resnet.children())[:-1]
-         self.resnet = nn.Sequential(*modules)
-         self.linear = nn.Linear(resnet.fc.in_features, embed_size)
-         self.bn = nn.BatchNorm1d(embed_size, momentum=0.1)
-         self.__init_weights()
-
-     def __init_weights(self):
-         self.linear.weight.data.normal_(0.0, 0.1)
-         self.linear.bias.data.fill_(0)
-
-     def forward(self, images) -> object:
-         """Return embedded visual features for a batch of images."""
-         features = self.resnet(images)
-         features = Variable(features.data)
-         features = features.view(features.size(0), -1)
-         features = self.bn(self.linear(features))
-         return features
-
-
- class DecoderRNN(nn.Module):
-     def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
-         super(DecoderRNN, self).__init__()
-         self.embed = nn.Embedding(vocab_size, embed_size)
-         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
-         self.linear = nn.Linear(hidden_size, vocab_size)
-         self.__init_weights()
-         self.n_max = n_max
-
-     def __init_weights(self):
-         self.embed.weight.data.uniform_(-0.1, 0.1)
-         self.linear.weight.data.uniform_(-0.1, 0.1)
-         self.linear.bias.data.fill_(0)
-
-     def forward(self, features, captions) -> object:
-         """Return next-word logits given image features and a caption prefix."""
-         embeddings = self.embed(captions)
-         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
-         hidden, _ = self.lstm(embeddings)
-         outputs = self.linear(hidden[:, -1, :])
-         return outputs
-
-     def sample(self, features, start_tokens):
-         sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
-         predicted = start_tokens
-         embeddings = features
-         embeddings = embeddings.unsqueeze(1)
-
-         for i in range(self.n_max):
-             predicted = self.embed(predicted)
-             embeddings = torch.cat([embeddings, predicted], dim=1)
-             hidden_states, _ = self.lstm(embeddings)
-             hidden_states = hidden_states[:, -1, :]
-             outputs = self.linear(hidden_states)
-             predicted = torch.max(outputs, 1)[1]
-             sampled_ids[:, i] = predicted
-             predicted = predicted.unsqueeze(1)
-         return sampled_ids
-
-
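A minimal greedy-decoding sketch for the EncoderCNN/DecoderRNN pair above (assumptions: the module is importable as models_debugger; the vocabulary size and the start-token id of 1 are made up):

    import torch
    from models_debugger import EncoderCNN, DecoderRNN

    encoder = EncoderCNN(embed_size=256, pretrained=False)
    decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=1000, num_layers=1)
    encoder.eval()  # BatchNorm1d statistics are unreliable on tiny batches in train mode

    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        features = encoder(images)             # (2, 256)
        start = torch.ones(2, 1).long()        # assumed <start> token id = 1
        ids = decoder.sample(features, start)  # numpy array of shape (2, 50)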
- class VisualFeatureExtractor(nn.Module):
-     def __init__(self, pretrained=False):
-         super(VisualFeatureExtractor, self).__init__()
-         resnet = models.resnet152(pretrained=pretrained)
-         modules = list(resnet.children())[:-1]
-         self.resnet = nn.Sequential(*modules)
-         self.out_features = resnet.fc.in_features
-
-     def forward(self, images) -> object:
-         """Return flattened pooled ResNet features for a batch of images."""
-         features = self.resnet(images)
-         features = features.view(features.size(0), -1)
-         return features
-
-
- class MLC(nn.Module):
-     def __init__(self, classes=156, sementic_features_dim=512, fc_in_features=2048, k=10):
-         super(MLC, self).__init__()
-         self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
-         self.embed = nn.Embedding(classes, sementic_features_dim)
-         self.k = k
-         self.softmax = nn.Softmax(dim=-1)
-
-     def forward(self, visual_features) -> object:
-         """Return tag probabilities and the embeddings of the top-k tags."""
-         tags = self.softmax(self.classifier(visual_features))
-         semantic_features = self.embed(torch.topk(tags, self.k)[1])
-         return tags, semantic_features
-
-
- class CoAttention(nn.Module):
-     def __init__(self, embed_size=512, hidden_size=512, visual_size=2048):
-         super(CoAttention, self).__init__()
-         self.W_v = nn.Linear(in_features=visual_size, out_features=visual_size)
-         self.bn_v = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)
-
-         self.W_v_h = nn.Linear(in_features=hidden_size, out_features=visual_size)
-         self.bn_v_h = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)
-
-         self.W_v_att = nn.Linear(in_features=visual_size, out_features=visual_size)
-         self.bn_v_att = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)
-
-         self.W_a = nn.Linear(in_features=hidden_size, out_features=hidden_size)
-         self.bn_a = nn.BatchNorm1d(num_features=10, momentum=0.1)
-
-         self.W_a_h = nn.Linear(in_features=hidden_size, out_features=hidden_size)
-         self.bn_a_h = nn.BatchNorm1d(num_features=1, momentum=0.1)
-
-         self.W_a_att = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
-         self.bn_a_att = nn.BatchNorm1d(num_features=10, momentum=0.1)
-
-         self.W_fc = nn.Linear(in_features=visual_size + hidden_size, out_features=embed_size)
-         self.bn_fc = nn.BatchNorm1d(num_features=embed_size, momentum=0.1)
-
-         self.tanh = nn.Tanh()
-         self.softmax = nn.Softmax()
-
-     def forward(self, visual_features, semantic_features, h_sent) -> object:
-         """Fuse visual and semantic features into a context vector (training only)."""
-         W_v = self.bn_v(self.W_v(visual_features))
-         W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))
-
-         alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
-         v_att = torch.mul(alpha_v, visual_features)
-         # v_att = torch.mul(alpha_v, visual_features).sum(1).unsqueeze(1)
-
-         W_a_h = self.bn_a_h(self.W_a_h(h_sent))
-         W_a = self.bn_a(self.W_a(semantic_features))
-         alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
-         a_att = torch.mul(alpha_a, semantic_features).sum(1)
-         # a_att = (alpha_a * semantic_features).sum(1)
-         ctx = self.bn_fc(self.W_fc(torch.cat([v_att, a_att], dim=1)))
-         # return self.W_fc(self.bn_fc(torch.cat([v_att, a_att], dim=1)))
-         return ctx, v_att
-
-
- class SentenceLSTM(nn.Module):
-     def __init__(self, embed_size=512, hidden_size=512, num_layers=1):
-         super(SentenceLSTM, self).__init__()
-         self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers)
-         self.W_t_h = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
-         self.bn_t_h = nn.BatchNorm1d(num_features=1, momentum=0.1)
-
-         self.W_t_ctx = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
-         self.bn_t_ctx = nn.BatchNorm1d(num_features=1, momentum=0.1)
-
-         self.W_stop_s_1 = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
-         self.bn_stop_s_1 = nn.BatchNorm1d(num_features=1, momentum=0.1)
-
-         self.W_stop_s = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
-         self.bn_stop_s = nn.BatchNorm1d(num_features=1, momentum=0.1)
-
-         self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
-         self.bn_stop = nn.BatchNorm1d(num_features=1, momentum=0.1)
-
-         self.W_topic = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
-         self.bn_topic = nn.BatchNorm1d(num_features=1, momentum=0.1)
-
-         self.W_topic_2 = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
-         self.bn_topic_2 = nn.BatchNorm1d(num_features=1, momentum=0.1)
-
-         self.sigmoid = nn.Sigmoid()
-         self.tanh = nn.Tanh()
-
-     # def forward(self, ctx, prev_hidden_state, states=None) -> object:
-     #     """
-     #     Only training
-     #     :rtype: object
-     #     """
-     #     ctx = ctx.unsqueeze(1)
-     #     hidden_state, states = self.lstm(ctx, states)
-     #     topic = self.bn_topic(self.W_topic(self.sigmoid(self.bn_t_h(self.W_t_h(hidden_state))
-     #                                                     + self.bn_t_ctx(self.W_t_ctx(ctx)))))
-     #     p_stop = self.bn_stop(self.W_stop(self.sigmoid(self.bn_stop_s_1(self.W_stop_s_1(prev_hidden_state))
-     #                                                    + self.bn_stop_s(self.W_stop_s(hidden_state)))))
-     #     return topic, p_stop, hidden_state, states
-
-     def forward(self, ctx, prev_hidden_state, states=None) -> object:
-         """Return the topic vector, stop signal, and updated LSTM state (v2)."""
-         ctx = ctx.unsqueeze(1)
-         hidden_state, states = self.lstm(ctx, states)
-         topic = self.bn_topic(self.W_topic(self.tanh(self.bn_t_h(self.W_t_h(hidden_state)
-                                                                  + self.W_t_ctx(ctx)))))
-         p_stop = self.bn_stop(self.W_stop(self.tanh(self.bn_stop_s(self.W_stop_s_1(prev_hidden_state)
-                                                                    + self.W_stop_s(hidden_state)))))
-         return topic, p_stop, hidden_state, states
-
-
- class SentenceTCN(nn.Module):
-     def __init__(self,
-                  input_channel=10,
-                  embed_size=512,
-                  output_size=512,
-                  nhid=512,
-                  levels=8,
-                  kernel_size=2,
-                  dropout=0):
-         super(SentenceTCN, self).__init__()
-         channel_sizes = [nhid] * levels
-         self.tcn = TCN(input_size=input_channel,
-                        output_size=output_size,
-                        num_channels=channel_sizes,
-                        kernel_size=kernel_size,
-                        dropout=dropout)
-         self.W_t_h = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
-         self.W_t_ctx = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
-         self.W_stop_s_1 = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
-         self.W_stop_s = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
-         self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
-         self.t_w = nn.Linear(in_features=5120, out_features=2, bias=True)
-         self.tanh = nn.Tanh()
-
-     def forward(self, ctx, prev_output) -> object:
-         """Return the topic vector, stop signal, and raw TCN output."""
-         output = self.tcn.forward(ctx)
-         topic = self.tanh(self.W_t_h(output) + self.W_t_ctx(ctx[:, -1, :]).squeeze(1))
-         p_stop = self.W_stop(self.tanh(self.W_stop_s_1(prev_output) + self.W_stop_s(output)))
-         return topic, p_stop, output
-
-
- class WordLSTM(nn.Module):
-     def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
-         super(WordLSTM, self).__init__()
-         self.embed = nn.Embedding(vocab_size, embed_size)
-         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
-         self.linear = nn.Linear(hidden_size, vocab_size)
-         self.__init_weights()
-         self.n_max = n_max
-         self.vocab_size = vocab_size
-
-     def __init_weights(self):
-         self.embed.weight.data.uniform_(-0.1, 0.1)
-         self.linear.weight.data.uniform_(-0.1, 0.1)
-         self.linear.bias.data.fill_(0)
-
-     def forward(self, topic_vec, captions) -> object:
-         """Return next-word logits given a topic vector and a caption prefix."""
-         embeddings = self.embed(captions)
-         embeddings = torch.cat((topic_vec, embeddings), 1)
-         hidden, _ = self.lstm(embeddings)
-         outputs = self.linear(hidden[:, -1, :])
-         return outputs
-
-     def val(self, features, start_tokens):
-         samples = torch.zeros((np.shape(features)[0], self.n_max, self.vocab_size))
-         samples[:, 0, start_tokens[0]] = 1
-         predicted = start_tokens
-         embeddings = features
-
-         for i in range(1, self.n_max):
-             predicted = self.embed(predicted)
-             embeddings = torch.cat([embeddings, predicted], dim=1)
-             hidden_states, _ = self.lstm(embeddings)
-             hidden_states = hidden_states[:, -1, :]
-             outputs = self.linear(hidden_states)
-             samples[:, i, :] = outputs
-             predicted = torch.max(outputs, 1)[1]
-             predicted = predicted.unsqueeze(1)
-         return samples
-
-     def sample(self, features, start_tokens):
-         sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
-         sampled_ids[:, 0] = start_tokens.view(-1,)
-         predicted = start_tokens
-         embeddings = features
-
-         for i in range(1, self.n_max):
-             predicted = self.embed(predicted)
-             embeddings = torch.cat([embeddings, predicted], dim=1)
-             hidden_states, _ = self.lstm(embeddings)
-             hidden_states = hidden_states[:, -1, :]
-             outputs = self.linear(hidden_states)
-             predicted = torch.max(outputs, 1)[1]
-             sampled_ids[:, i] = predicted
-             predicted = predicted.unsqueeze(1)
-         return sampled_ids
-
-
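These pieces compose into the hierarchical report generator: co-attention builds a context vector, the sentence LSTM emits one topic per sentence plus a stop signal, and the word LSTM decodes each topic into words. A greedy inference sketch under assumptions the code itself does not fix (stop label 1 means "stop", <start> id is 1, at most 6 sentences):

    import torch
    from models_debugger import (VisualFeatureExtractor, MLC, CoAttention,
                                 SentenceLSTM, WordLSTM)

    extractor, mlc, co_att = VisualFeatureExtractor(), MLC(), CoAttention()
    sent_lstm = SentenceLSTM()
    word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)

    images = torch.randn(4, 3, 224, 224)
    h_sent = torch.zeros(4, 1, 512)  # previous sentence-level hidden state
    states, report = None, []

    with torch.no_grad():
        feats = extractor(images)                     # (4, 2048)
        tags, semantic = mlc(feats)                   # (4, 156), (4, 10, 512)
        for _ in range(6):                            # at most 6 sentences
            ctx, _ = co_att(feats, semantic, h_sent)  # (4, 512)
            topic, p_stop, h_sent, states = sent_lstm(ctx, h_sent, states)
            if p_stop.squeeze(1).argmax(dim=1)[0] == 1:  # assumed stop label = 1
                break
            start = torch.ones(4, 1).long()           # assumed <start> id = 1
            report.append(word_lstm.sample(topic, start))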
- class WordTCN(nn.Module):
-     def __init__(self,
-                  input_channel=11,
-                  vocab_size=1000,
-                  embed_size=512,
-                  output_size=512,
-                  nhid=512,
-                  levels=8,
-                  kernel_size=2,
-                  dropout=0,
-                  n_max=50):
-         super(WordTCN, self).__init__()
-         self.vocab_size = vocab_size
-         self.embed_size = embed_size
-         self.output_size = output_size
-         channel_sizes = [nhid] * levels
-         self.kernel_size = kernel_size
-         self.dropout = dropout
-         self.n_max = n_max
-         self.embed = nn.Embedding(vocab_size, embed_size)
-         self.W_out = nn.Linear(in_features=output_size, out_features=vocab_size, bias=True)
-         self.tcn = TCN(input_size=input_channel,
-                        output_size=output_size,
-                        num_channels=channel_sizes,
-                        kernel_size=kernel_size,
-                        dropout=dropout)
-
-     def forward(self, topic_vec, captions) -> object:
-         """Return word logits given a topic vector and embedded captions."""
-         captions = self.embed(captions)
-         embeddings = torch.cat([topic_vec, captions], dim=1)
-         output = self.tcn.forward(embeddings)
-         words = self.W_out(output)
-         return words
-
-
- if __name__ == '__main__':
-     import warnings
-     warnings.filterwarnings("ignore")
-     images = torch.randn((4, 3, 224, 224))
-     captions = torch.ones((4, 10)).long()
-     hidden_state = torch.randn((4, 1, 512))
-
-     print("images:{}".format(images.shape))
-     print("captions:{}".format(captions.shape))
-     print("hidden_states:{}".format(hidden_state.shape))
-
-     extractor = VisualFeatureExtractor()
-     visual_features = extractor.forward(images)
-     print("visual_features:{}".format(visual_features.shape))
-
-     mlc = MLC()
-     tags, semantic_features = mlc.forward(visual_features)
-     print("tags:{}".format(tags.shape))
-     print("semantic_features:{}".format(semantic_features.shape))
-
-     co_att = CoAttention()
-     ctx, v_att = co_att.forward(visual_features, semantic_features, hidden_state)
-     print("ctx:{}".format(ctx.shape))
-     print("v_att:{}".format(v_att.shape))
-
-     sent_lstm = SentenceLSTM()
-     topic, p_stop, hidden_state, states = sent_lstm.forward(ctx, hidden_state)
-     print("Topic:{}".format(topic.shape))
-     print("P_STOP:{}".format(p_stop.shape))
-
-     word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)
-     words = word_lstm.forward(topic, captions)
-     print("words:{}".format(words.shape))
-
-     # Expected Output
-     # images: torch.Size([4, 3, 224, 224])
-     # captions: torch.Size([4, 10])
-     # hidden_states: torch.Size([4, 1, 512])
-     # visual_features: torch.Size([4, 2048])
-     # tags: torch.Size([4, 156])
-     # semantic_features: torch.Size([4, 10, 512])
-     # ctx: torch.Size([4, 512])
-     # v_att: torch.Size([4, 2048])
-     # Topic: torch.Size([4, 1, 512])
-     # P_STOP: torch.Size([4, 1, 2])
-     # words: torch.Size([4, 100])
-
-     # images = torch.randn((4, 3, 224, 224))
-     # captions = torch.ones((4, 3, 10)).long()
-     # prev_outputs = torch.randn((4, 512))
-     # now_words = torch.ones((4, 1))
-     #
-     # ctx_records = torch.zeros((4, 10, 512))
-     # captions = torch.zeros((4, 10)).long()
-     #
-     # print("images:{}".format(images.shape))
-     # print("captions:{}".format(captions.shape))
-     # print("hidden_states:{}".format(prev_outputs.shape))
-     #
-     # extractor = VisualFeatureExtractor()
-     # visual_features = extractor.forward(images)
-     # print("visual_features:{}".format(visual_features.shape))
-     #
-     # mlc = MLC()
-     # tags, semantic_features = mlc.forward(visual_features)
-     # print("tags:{}".format(tags.shape))
-     # print("semantic_features:{}".format(semantic_features.shape))
-     #
-     # co_att = CoAttention()
-     # ctx = co_att.forward(visual_features, semantic_features, prev_outputs)
-     # print("ctx:{}".format(ctx.shape))
-     #
-     # ctx_records[:, 0, :] = ctx
-     #
-     # sent_tcn = SentenceTCN()
-     # topic, p_stop, prev_outputs = sent_tcn.forward(ctx_records, prev_outputs)
-     # print("Topic:{}".format(topic.shape))
-     # print("P_STOP:{}".format(p_stop.shape))
-     # print("Prev_Outputs:{}".format(prev_outputs.shape))
-     #
-     # captions[:, 0] = now_words.view(-1,)
-     #
-     # word_tcn = WordTCN()
-     # words = word_tcn.forward(topic, captions)
-     # print("words:{}".format(words.shape))
-