File size: 1,155 Bytes
02cacbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch.nn as nn
import torch
from models.util import MyResNet34

class audio2poseLSTM(nn.Module):
    """Autoregressive audio-to-pose regressor.

    An embedding of ``x["img"]`` seeds a 2-layer LSTM that is stepped once per
    audio frame. Each step consumes the current audio-frame embedding
    concatenated with the previous LSTM output. The initial embedding and every
    LSTM output are projected to a 6-d vector by ``self.output``.
    """

    def __init__(self):
        super().__init__()

        # Both encoders must yield 256-d embeddings: forward() concatenates
        # them into the LSTM's 512-d input (256 + 256).
        self.em_pose = MyResNet34(256, 1)
        self.em_audio = MyResNet34(256, 1)
        self.lstm = nn.LSTM(512, 256, num_layers=2, bias=True, batch_first=True)

        # Projects each 256-d embedding / hidden output to a 6-d vector.
        self.output = nn.Linear(256, 6)

    def forward(self, x):
        """Predict a sequence of 6-d outputs from an image and audio frames.

        Args:
            x: dict with
                ``"img"`` - input to ``em_pose``. The first dim of its
                embedding is taken as the batch size (presumably an image
                batch ``(bs, 1, H, W)`` - TODO confirm against MyResNet34).
                ``"audio"`` - tensor of shape ``(bs, seqlen, num, dims)``.
                Each frame is encoded independently as a ``(1, num, dims)``
                map.

        Returns:
            Tensor of shape ``(bs, seqlen + 1, 6)``. Index 0 is the projection
            of the image embedding; index ``i + 1`` is the projection of the
            LSTM output after consuming audio frame ``i``.
        """
        pose_em = self.em_pose(x["img"])
        bs = pose_em.shape[0]

        # Zero (h0, c0) laid out as (num_layers, batch, hidden). The state is
        # taken from the LSTM itself, and dtype/device come from the embedding,
        # so the model also works in half/double precision. It is a constant,
        # so it needs no gradient.
        hidden = self.lstm.hidden_size
        zero_state = pose_em.new_zeros((self.lstm.num_layers, bs, hidden))
        cur_state = (zero_state, zero_state)

        bs, seqlen, num, dims = x["audio"].shape

        # Encode all frames in one batched call, then restore (bs, seqlen, E).
        audio = x["audio"].reshape(-1, 1, num, dims)
        audio_em = self.em_audio(audio).reshape(bs, seqlen, -1)

        img_em = pose_em
        result = [self.output(img_em).unsqueeze(1)]
        for i in range(seqlen):
            # One LSTM step: [audio frame i ; previous output] -> next output.
            step_in = torch.cat((audio_em[:, i:i + 1], img_em.unsqueeze(1)), dim=2)
            img_em, cur_state = self.lstm(step_in, cur_state)
            img_em = img_em.reshape(-1, hidden)  # (bs, 1, H) -> (bs, H)
            result.append(self.output(img_em).unsqueeze(1))

        return torch.cat(result, dim=1)