Update app.py
app.py CHANGED
@@ -6,28 +6,74 @@ import imageio
 import cv2
 
 
-class …
-    …  [the deleted class body (old lines 9–26) is not recoverable from the page]
+class RelationModuleMultiScale(torch.nn.Module):
+
+    def __init__(self, img_feature_dim, num_bottleneck, num_frames):
+        super(RelationModuleMultiScale, self).__init__()
+        self.subsample_num = 3
+        self.img_feature_dim = img_feature_dim
+        self.scales = [i for i in range(num_frames, 1, -1)]
+        self.relations_scales = []
+        self.subsample_scales = []
+        for scale in self.scales:
+            relations_scale = self.return_relationset(num_frames, scale)
+            self.relations_scales.append(relations_scale)
+            self.subsample_scales.append(min(self.subsample_num, len(relations_scale)))
+        self.num_frames = num_frames
+        self.fc_fusion_scales = nn.ModuleList()  # one fusion MLP per temporal scale
+        for i in range(len(self.scales)):
+            scale = self.scales[i]
+            fc_fusion = nn.Sequential(nn.ReLU(), nn.Linear(scale * self.img_feature_dim, num_bottleneck), nn.ReLU())
+            self.fc_fusion_scales += [fc_fusion]
 
     def forward(self, input):
-        …  [the two deleted lines of the old forward body (old lines 29–30) are not recoverable]
+        act_scale_1 = input[:, self.relations_scales[0][0], :]
+        act_scale_1 = act_scale_1.view(act_scale_1.size(0), self.scales[0] * self.img_feature_dim)
+        act_scale_1 = self.fc_fusion_scales[0](act_scale_1)
+        act_scale_1 = act_scale_1.unsqueeze(1)
+        act_all = act_scale_1.clone()
+        for scaleID in range(1, len(self.scales)):
+            act_relation_all = torch.zeros_like(act_scale_1)
+            num_total_relations = len(self.relations_scales[scaleID])
+            num_select_relations = self.subsample_scales[scaleID]
+            idx_relations_evensample = [int(ceil(i * num_total_relations / num_select_relations)) for i in range(num_select_relations)]
+            for idx in idx_relations_evensample:
+                act_relation = input[:, self.relations_scales[scaleID][idx], :]
+                act_relation = act_relation.view(act_relation.size(0), self.scales[scaleID] * self.img_feature_dim)
+                act_relation = self.fc_fusion_scales[scaleID](act_relation)
+                act_relation = act_relation.unsqueeze(1)
+                act_relation_all += act_relation
+            act_all = torch.cat((act_all, act_relation_all), 1)
+        return act_all
+
+    def return_relationset(self, num_frames, num_frames_relation):
+        import itertools
+        return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation))
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--dataset', default='Sprite', help='datasets')
+parser.add_argument('--data_root', default='dataset', help='root directory for data')
+parser.add_argument('--num_class', type=int, default=15, help='the number of classes for the jester dataset')
+parser.add_argument('--input_type', default='image', choices=['feature', 'image'], help='the type of input')
+parser.add_argument('--src', default='domain_1', help='source domain')
+parser.add_argument('--tar', default='domain_2', help='target domain')
+parser.add_argument('--num_segments', type=int, default=8, help='the number of frame segments')
+parser.add_argument('--backbone', type=str, default="dcgan", choices=['dcgan', 'resnet101', 'I3Dpretrain', 'I3Dfinetune'], help='backbone')
+parser.add_argument('--channels', default=3, type=int, help='input channels for image inputs')
+parser.add_argument('--add_fc', default=1, type=int, metavar='M', help='number of additional fc layers (excluding the last fc layer) (e.g. 0, 1, 2)')
+parser.add_argument('--fc_dim', type=int, default=1024, help='dimension of added fc')
+parser.add_argument('--frame_aggregation', type=str, default='trn', choices=['rnn', 'trn'], help='aggregation of frame features (none if baseline_type is not video)')
+parser.add_argument('--dropout_rate', default=0.5, type=float, help='dropout ratio for frame-level feature (default: 0.5)')
+parser.add_argument('--f_dim', type=int, default=512, help='dim of f')
+parser.add_argument('--z_dim', type=int, default=512, help='dimensionality of z_t')
+parser.add_argument('--f_rnn_layers', type=int, default=1, help='number of layers (content lstm)')
+parser.add_argument('--use_bn', type=str, default='none', choices=['none', 'AdaBN', 'AutoDIAL'], help='normalization-based methods')
+parser.add_argument('--prior_sample', type=str, default='random', choices=['random', 'post'], help='how to sample prior')
+parser.add_argument('--batch_size', default=128, type=int, help='batch size')
+parser.add_argument('--use_attn', type=str, default='TransAttn', choices=['none', 'TransAttn', 'general'], help='attention mechanism')
+parser.add_argument('--data_threads', type=int, default=5, help='number of data loading threads')
+opt = parser.parse_args(args=[])
 
 
 def display_gif(file_name, save_name):
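For readers skimming the diff: the added return_relationset helper enumerates every temporally ordered subset of frame indices of a given size via itertools.combinations. A quick illustration (not part of the commit):

    import itertools
    # What return_relationset(4, 2) would yield: every 2-frame relation
    # over 4 frames, in temporal order.
    print(list(itertools.combinations(range(4), 2)))
    # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]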
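And a minimal smoke test of the new module, as a sketch only: the batch size and feature dimensions below are illustrative assumptions, and it presumes the imports above this hunk already provide torch, nn, and ceil (which forward() calls).

    import torch
    import torch.nn as nn
    from math import ceil  # forward() uses ceil; assumed imported near the top of app.py

    batch, num_frames, feat_dim, bottleneck = 4, 8, 256, 512  # illustrative values only
    trn = RelationModuleMultiScale(img_feature_dim=feat_dim, num_bottleneck=bottleneck, num_frames=num_frames)
    frames = torch.randn(batch, num_frames, feat_dim)  # one feature vector per frame
    out = trn(frames)
    # One fused activation per scale (8, 7, ..., 2), i.e. 7 scales for num_frames=8.
    print(out.shape)  # torch.Size([4, 7, 512])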