DmitrMakeev's picture
Upload 7 files
02cacbe
raw
history blame
1.16 kB
import torch.nn as nn
import torch
from models.util import MyResNet34
class audio2poseLSTM(nn.Module):
    """Autoregressive audio-to-pose regressor.

    An image encoder (``em_pose``) produces an initial 256-d embedding. For
    every audio frame, a 2-layer LSTM consumes the concatenation
    [audio-frame embedding, previous embedding] (512-d). Its 256-d output
    becomes the next step's embedding and is also projected by a linear head
    to 6 values per step (presumably a 6-DoF head pose -- confirm against the
    training code).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): MyResNet34(256, 1) is assumed to map a 1-channel input
        # to a 256-d embedding; forward() relies on 256-d outputs -- confirm.
        self.em_pose = MyResNet34(256, 1)
        self.em_audio = MyResNet34(256, 1)
        # LSTM input is 512 = 256 (audio embedding) + 256 (previous embedding).
        self.lstm = nn.LSTM(512, 256, num_layers=2, bias=True, batch_first=True)
        self.output = nn.Linear(256, 6)

    def forward(self, x):
        """Roll the pose sequence out over the audio frames.

        Args:
            x: dict with
                ``"img"`` -- input to ``em_pose``; its embedding must be
                    (batch, 256), as required by the concatenation below.
                ``"audio"`` -- 4-D tensor (batch, seqlen, num, dims); each
                    frame is fed to ``em_audio`` as a (1, num, dims) map.

        Returns:
            Tensor of shape (batch, seqlen + 1, 6). Step 0 is the head applied
            to the image embedding; step t > 0 is the prediction after the
            t-th audio frame.
        """
        hidden = self.lstm.hidden_size
        # The LSTM input is [audio embedding, previous embedding], so the audio
        # embedding width is whatever remains after the hidden size.
        audio_dim = self.lstm.input_size - hidden

        pose_em = self.em_pose(x["img"])
        # The initial (h0, c0) state is a constant. It is created directly on
        # the embedding's device and dtype (no CPU round-trip, works for
        # half/double models) and without requires_grad, so outputs of a
        # frozen model do not needlessly track gradients.
        zero_state = pose_em.new_zeros((self.lstm.num_layers, pose_em.shape[0], hidden))
        state = (zero_state, zero_state)

        audio = x["audio"]
        audio_bs, seqlen, num, dims = audio.shape
        # Encode every frame independently as a 1-channel (num, dims) map,
        # then restore the (batch, seqlen, feature) layout.
        audio_em = self.em_audio(audio.reshape(-1, 1, num, dims)).reshape(
            audio_bs, seqlen, audio_dim
        )

        emb = pose_em
        outputs = [self.output(emb).unsqueeze(1)]
        for t in range(seqlen):
            # One LSTM step: this frame's audio embedding plus the previous
            # embedding; the LSTM output feeds back as the next embedding.
            step_in = torch.cat((audio_em[:, t:t + 1], emb.unsqueeze(1)), dim=2)
            step_out, state = self.lstm(step_in, state)
            emb = step_out.reshape(-1, hidden)
            outputs.append(self.output(emb).unsqueeze(1))
        return torch.cat(outputs, dim=1)