Superxixixi committed on
Commit b98cec2
1 Parent(s): a25806a

Upload 5 files

Files changed (5)
  1. attentionLayer.py +39 -0
  2. audioEncoder.py +108 -0
  3. convLayer.py +42 -0
  4. loconet_encoder.py +90 -0
  5. visualEncoder.py +199 -0
attentionLayer.py ADDED
@@ -0,0 +1,39 @@

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import MultiheadAttention


class attentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.1):
        super(attentionLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, d_model * 4)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_model * 4, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = F.relu

    def forward(self, src, tar, adjust=False, attn_mask=None):
        # type: (Tensor, Tensor, bool, Optional[Tensor]) -> Tensor
        # src and tar must share sequence length and feature dimension.
        src = src.transpose(0, 1)  # B, T, C -> T, B, C
        tar = tar.transpose(0, 1)  # B, T, C -> T, B, C
        # `adjust` swaps which stream provides the query vs. key/value.
        if adjust:
            src2 = self.self_attn(src, tar, tar, attn_mask=attn_mask, key_padding_mask=None)[0]
        else:
            src2 = self.self_attn(tar, src, src, attn_mask=attn_mask, key_padding_mask=None)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        src = src.transpose(0, 1)  # T, B, C -> B, T, C
        return src
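
For quick reference, a minimal smoke test (not part of the upload; shapes are illustrative, and both inputs must share T and C):

import torch
from attentionLayer import attentionLayer

layer = attentionLayer(d_model=128, nhead=8)
audio = torch.randn(2, 25, 128)    # B, T, C
video = torch.randn(2, 25, 128)
out = layer(src=audio, tar=video)  # cross-attention over the two streams
print(out.shape)                   # torch.Size([2, 25, 128])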
audioEncoder.py ADDED
@@ -0,0 +1,108 @@

import torch
import torch.nn as nn
import torch.nn.functional as F


class SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class SELayer(nn.Module):
    """Squeeze-and-Excitation: channel-wise re-weighting via a bottleneck MLP."""

    def __init__(self, channel, reduction=8):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)   # squeeze: global average pool
        y = self.fc(y).view(b, c, 1, 1)   # excite: per-channel gate in (0, 1)
        return x * y


class audioEncoder(nn.Module):
    """SE-ResNet encoder over a 2-D audio representation (e.g. MFCCs)."""

    def __init__(self, layers, num_filters, **kwargs):
        super(audioEncoder, self).__init__()
        block = SEBasicBlock
        self.inplanes = num_filters[0]

        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=7, stride=(2, 1), padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, num_filters[0], layers[0])
        self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
        self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
        self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(1, 1))
        out_dim = num_filters[3] * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the residual when the spatial size or channel count changes.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = torch.mean(x, dim=2, keepdim=True)      # average out the frequency axis
        x = x.view((x.size()[0], x.size()[1], -1))  # B, C, T
        x = x.transpose(1, 2)                       # B, T, C

        return x
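
A minimal usage sketch (not part of the upload); the depths, filter counts, and MFCC layout below are assumptions for illustration, not values pinned by this commit:

import torch
from audioEncoder import audioEncoder

enc = audioEncoder(layers=[3, 4, 6, 3], num_filters=[16, 32, 64, 128])
mfcc = torch.randn(2, 1, 13, 400)  # B, 1, n_mfcc, time (assumed layout)
feats = enc(mfcc)
print(feats.shape)                 # torch.Size([2, 100, 128]) -> B, T, C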
convLayer.py ADDED
@@ -0,0 +1,42 @@

import torch
import torch.nn as nn
from torch.nn import functional as F


class ConvLayer(nn.Module):

    def __init__(self, cfg):
        super(ConvLayer, self).__init__()
        self.cfg = cfg
        self.s = cfg.num_speakers
        # Mixes information across the speaker axis with an (s x 7) kernel.
        self.conv2d = torch.nn.Conv2d(256, 256 * self.s, (self.s, 7), padding=(0, 3))
        # The line below is the speaker-parallel variant (93.88):
        # self.conv2d = torch.nn.Conv2d(256, 256 * self.s, (3, 7), padding=(0, 3))
        self.ln = torch.nn.LayerNorm(256)
        self.conv2d_1x1 = torch.nn.Conv2d(256, 512, (1, 1), padding=(0, 0))
        self.conv2d_1x1_2 = torch.nn.Conv2d(512, 256, (1, 1), padding=(0, 0))
        self.gelu = nn.GELU()

    def forward(self, x, b, s):
        identity = x                   # b*s, t, c
        t = x.shape[1]
        c = x.shape[2]
        out = x.view(b, s, t, c)
        out = out.permute(0, 3, 1, 2)  # b, c, s, t

        out = self.conv2d(out)         # b, s*c, 1, t
        out = out.view(b, c, s, t)
        out = out.permute(0, 2, 3, 1)  # b, s, t, c
        out = self.ln(out)
        out = out.permute(0, 3, 1, 2)  # b, c, s, t
        out = self.conv2d_1x1(out)
        out = self.gelu(out)
        out = self.conv2d_1x1_2(out)   # b, c, s, t

        out = out.permute(0, 2, 3, 1)     # b, s, t, c
        out = out.reshape(b * s, t, c)    # reshape (not view): the permuted tensor is non-contiguous

        out += identity                   # residual connection

        return out, b, s
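
A minimal usage sketch (not part of the upload); SimpleNamespace stands in for the real config, which only needs num_speakers here:

from types import SimpleNamespace
import torch
from convLayer import ConvLayer

cfg = SimpleNamespace(num_speakers=3)
layer = ConvLayer(cfg)
b, s, t = 2, 3, 50              # s must equal cfg.num_speakers
x = torch.randn(b * s, t, 256)  # stacked per-speaker features
out, b, s = layer(x, b, s)
print(out.shape)                # torch.Size([6, 50, 256])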
loconet_encoder.py ADDED
@@ -0,0 +1,90 @@

import torch
import torch.nn as nn

from attentionLayer import attentionLayer
from convLayer import ConvLayer
from torchvggish import vggish
from visualEncoder import visualFrontend, visualConv1D, visualTCN


class locoencoder(nn.Module):

    def __init__(self, cfg):
        super(locoencoder, self).__init__()
        self.cfg = cfg
        # Visual temporal encoder
        self.visualFrontend = visualFrontend(cfg)  # visual frontend (3D conv + ResNet-18)
        self.visualTCN = visualTCN()               # visual temporal network (TCN)
        self.visualConv1D = visualConv1D()         # visual temporal network (Conv1D)

        # Audio encoder: pretrained VGGish, without its pre/post-processing
        urls = {
            'vggish':
                "https://github.com/harritaylor/torchvggish/releases/download/v0.1/vggish-10086976.pth"
        }
        self.audioEncoder = vggish.VGGish(urls, preprocess=False, postprocess=False)
        self.audio_pool = nn.AdaptiveAvgPool1d(1)

        # Audio-visual cross-attention
        self.crossA2V = attentionLayer(d_model=128, nhead=8)
        self.crossV2A = attentionLayer(d_model=128, nhead=8)

        # Audio-visual self-attention: alternating speaker-conv and attention blocks
        num_layers = self.cfg.av_layers
        layers = nn.ModuleList()
        for _ in range(num_layers):
            layers.append(ConvLayer(cfg))
            layers.append(attentionLayer(d_model=256, nhead=8))
        self.convAV = layers

    def forward_visual_frontend(self, x):
        B, T, W, H = x.shape
        x = x.view(B * T, 1, 1, W, H)
        x = (x / 255 - 0.4161) / 0.1688  # normalize grayscale frames
        x = self.visualFrontend(x)
        x = x.view(B, T, 512)
        x = x.transpose(1, 2)
        x = self.visualTCN(x)
        x = self.visualConv1D(x)
        x = x.transpose(1, 2)
        return x

    def forward_audio_frontend(self, x):
        t = x.shape[-2]
        numFrames = t // 4  # 4 audio feature frames per video frame
        pad = 8 - (t % 8)
        x = torch.nn.functional.pad(x, (0, 0, 0, pad), "constant")
        # x = x.unsqueeze(1).transpose(2, 3)
        x = self.audioEncoder(x)

        b, c, t2, freq = x.shape
        x = x.view(b * c, t2, freq)
        x = self.audio_pool(x)                 # collapse the frequency axis
        x = x.view(b, c, t2)[:, :, :numFrames]
        x = x.permute(0, 2, 1)                 # B, T, C
        return x

    def forward_cross_attention(self, x1, x2):
        x1_c = self.crossA2V(src=x1, tar=x2, adjust=self.cfg.adjust_attention)
        x2_c = self.crossV2A(src=x2, tar=x1, adjust=self.cfg.adjust_attention)
        return x1_c, x2_c

    def forward_audio_visual_backend(self, x1, x2, b=1, s=1):
        x = torch.cat((x1, x2), 2)  # B*S, T, 2C
        # Even entries are ConvLayers (speaker axis); odd entries are attention layers.
        for i, layer in enumerate(self.convAV):
            if i % 2 == 0:
                x, b, s = layer(x, b, s)
            else:
                x = layer(src=x, tar=x)

        x = torch.reshape(x, (-1, 256))
        return x

    def forward_audio_backend(self, x):
        x = torch.reshape(x, (-1, 128))
        return x

    def forward_visual_backend(self, x):
        x = torch.reshape(x, (-1, 128))
        return x
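
A hypothetical usage sketch (not part of the upload). The config fields are inferred from the attributes read above; instantiation fetches the VGGish weights, and the audio frontend depends on the bundled torchvggish returning 4-D feature maps, so only the visual and fusion paths are exercised here:

from types import SimpleNamespace
import torch
from loconet_encoder import locoencoder

cfg = SimpleNamespace(num_speakers=2, av_layers=2, adjust_attention=False)
model = locoencoder(cfg)                   # downloads VGGish weights on first run

frames = torch.randn(2, 25, 112, 112)      # (B*S, T, H, W): one clip, two speakers
v = model.forward_visual_frontend(frames)  # (2, 25, 128)
a = torch.randn(2, 25, 128)                # stand-in for VGGish audio features
a_c, v_c = model.forward_cross_attention(a, v)
av = model.forward_audio_visual_backend(a_c, v_c, b=1, s=2)
print(av.shape)                            # torch.Size([50, 256])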
visualEncoder.py ADDED
@@ -0,0 +1,199 @@

##
# ResNet-18 pretrained network to extract lip embeddings.
# This code is modified from https://github.com/lordmartian/deep_avsr
##

import torch
import torch.nn as nn
import torch.nn.functional as F
from attentionLayer import attentionLayer


class ResNetLayer(nn.Module):
    """
    A ResNet layer used to build the ResNet network.
    Architecture:
    --> conv-bn-relu -> conv -> + -> bn-relu -> conv-bn-relu -> conv -> + -> bn-relu -->
         |                      ^                    |                  ^
         ----> downsample ------'                    -------------------'
    """

    def __init__(self, inplanes, outplanes, stride):
        super(ResNetLayer, self).__init__()
        self.conv1a = nn.Conv2d(inplanes,
                                outplanes,
                                kernel_size=3,
                                stride=stride,
                                padding=1,
                                bias=False)
        self.bn1a = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
        self.conv2a = nn.Conv2d(outplanes,
                                outplanes,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.stride = stride
        if self.stride != 1:
            # 1x1 projection for the strided residual branch.
            self.downsample = nn.Conv2d(inplanes,
                                        outplanes,
                                        kernel_size=(1, 1),
                                        stride=stride,
                                        bias=False)
        self.outbna = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)

        self.conv1b = nn.Conv2d(outplanes,
                                outplanes,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.bn1b = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
        self.conv2b = nn.Conv2d(outplanes,
                                outplanes,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.outbnb = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)

    def forward(self, inputBatch):
        batch = F.relu(self.bn1a(self.conv1a(inputBatch)))
        batch = self.conv2a(batch)
        if self.stride == 1:
            residualBatch = inputBatch
        else:
            residualBatch = self.downsample(inputBatch)
        batch = batch + residualBatch
        intermediateBatch = batch
        batch = F.relu(self.outbna(batch))

        batch = F.relu(self.bn1b(self.conv1b(batch)))
        batch = self.conv2b(batch)
        residualBatch = intermediateBatch
        batch = batch + residualBatch
        outputBatch = F.relu(self.outbnb(batch))
        return outputBatch


class ResNet(nn.Module):
    """
    An 18-layer ResNet architecture.
    """

    def __init__(self):
        super(ResNet, self).__init__()
        self.layer1 = ResNetLayer(64, 64, stride=1)
        self.layer2 = ResNetLayer(64, 128, stride=2)
        self.layer3 = ResNetLayer(128, 256, stride=2)
        self.layer4 = ResNetLayer(256, 512, stride=2)
        self.avgpool = nn.AvgPool2d(kernel_size=(4, 4), stride=(1, 1))

    def forward(self, inputBatch):
        batch = self.layer1(inputBatch)
        batch = self.layer2(batch)
        batch = self.layer3(batch)
        batch = self.layer4(batch)
        outputBatch = self.avgpool(batch)
        return outputBatch


class GlobalLayerNorm(nn.Module):
    """Layer normalization over both the channel and time axes."""

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))   # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
        var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        gLN_y = self.gamma * (y - mean) / torch.pow(var + 1e-8, 0.5) + self.beta
        return gLN_y


class visualFrontend(nn.Module):
    """
    A visual feature extraction module. Generates a 512-dim feature vector per video frame.
    Architecture: a 3D convolution block followed by an 18-layer ResNet.
    """

    def __init__(self, cfg):
        super(visualFrontend, self).__init__()
        self.cfg = cfg
        self.frontend3D = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3),
                      bias=False), nn.BatchNorm3d(64, momentum=0.01, eps=0.001), nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
        self.resnet = ResNet()

    def forward(self, inputBatch):
        # (N, 1, 1, W, H) -> (1, 1, N, W, H): treat the frame stack as one 3-D volume.
        inputBatch = inputBatch.transpose(0, 1).transpose(1, 2)
        batchsize = inputBatch.shape[0]
        batch = self.frontend3D(inputBatch)

        batch = batch.transpose(1, 2)
        batch = batch.reshape(batch.shape[0] * batch.shape[1], batch.shape[2], batch.shape[3],
                              batch.shape[4])
        outputBatch = self.resnet(batch)
        outputBatch = outputBatch.reshape(batchsize, -1, 512)
        outputBatch = outputBatch.transpose(1, 2)
        outputBatch = outputBatch.transpose(1, 2).transpose(0, 1)  # (N, 1, 512): one vector per frame
        return outputBatch


class DSConv1d(nn.Module):
    """Depthwise-separable 1-D convolution block with a residual connection."""

    def __init__(self):
        super(DSConv1d, self).__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Conv1d(512, 512, 3, stride=1, padding=1, dilation=1, groups=512, bias=False),
            nn.PReLU(),
            GlobalLayerNorm(512),
            nn.Conv1d(512, 512, 1, bias=False),
        )

    def forward(self, x):
        out = self.net(x)
        return out + x


class visualTCN(nn.Module):

    def __init__(self):
        super(visualTCN, self).__init__()
        stacks = []
        for _ in range(5):
            stacks += [DSConv1d()]
        self.net = nn.Sequential(*stacks)  # visual temporal network (V-TCN)

    def forward(self, x):
        out = self.net(x)
        return out


class visualConv1D(nn.Module):

    def __init__(self):
        super(visualConv1D, self).__init__()
        self.net = nn.Sequential(
            nn.Conv1d(512, 256, 5, stride=1, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, 128, 1),
        )

    def forward(self, x):
        out = self.net(x)
        return out
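
A shape walkthrough for the visual stack (not part of the upload; the 5-D input layout matches how loconet_encoder.py calls visualFrontend):

from types import SimpleNamespace
import torch
from visualEncoder import visualFrontend, visualTCN, visualConv1D

frontend = visualFrontend(cfg=SimpleNamespace())  # cfg is stored but not read here
clip = torch.randn(25, 1, 1, 112, 112)            # 25 grayscale 112x112 frames
emb = frontend(clip)                              # (25, 1, 512): one vector per frame

seq = emb.reshape(1, 25, 512).transpose(1, 2)     # (B, 512, T) for the 1-D convs
out = visualConv1D()(visualTCN()(seq))
print(out.shape)                                  # torch.Size([1, 128, 25])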