AudioToken / modules /fga /fga_model.py
guyyariv
AudioTokenDemo
1b92e8f
raw history blame
No virus
9.75 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from atten import Atten
class FGA(nn.Module):
    """Factor Graph Attention model that scores candidate answers.

    The question, candidate answers, history questions/answers and caption are
    embedded with one shared word-embedding table and encoded by separate LSTMs.
    ``Atten`` attends jointly over these utilities plus the image features, and
    ``simnet`` maps the concatenated per-answer features to one score per
    candidate answer.
    """

    def __init__(self, vocab_size, word_embed_dim, hidden_ques_dim, hidden_ans_dim,
                 hidden_hist_dim, hidden_cap_dim, hidden_img_dim):
        '''
        Factor Graph Attention
        :param vocab_size: vocabulary size (the embedding table gets two extra rows)
        :param word_embed_dim: dimension of the shared word embeddings
        :param hidden_ques_dim: hidden size of the question LSTM
        :param hidden_ans_dim: hidden size of the answer LSTM
        :param hidden_hist_dim: hidden size of the history-question and history-answer LSTMs
        :param hidden_cap_dim: hidden size of the caption LSTM
        :param hidden_img_dim: size of the image utility, i.e. the feature dimension
            of the image features ``i_e`` passed to ``forward``
        '''
        super().__init__()
        print("Init FGA with vocab size %s, word embed %s, hidden ques %s, hidden ans %s,"
              " hidden hist %s, hidden cap %s, hidden img %s" % (vocab_size, word_embed_dim,
                                                                   hidden_ques_dim,
                                                                   hidden_ans_dim,
                                                                   hidden_hist_dim,
                                                                   hidden_cap_dim,
                                                                   hidden_img_dim))
        self.hidden_ques_dim = hidden_ques_dim
        self.hidden_ans_dim = hidden_ans_dim
        self.hidden_cap_dim = hidden_cap_dim
        self.hidden_img_dim = hidden_img_dim
        self.hidden_hist_dim = hidden_hist_dim
        # The table has vocab_size + 2 rows. Row 0 is the padding row (padding_idx=0).
        # The last id is kept as a stop id for the history sequences.
        # Presumably word ids occupy 1..vocab_size -- TODO confirm.
        # (The attribute name typo is kept so existing checkpoints still load.)
        self.word_embedddings = nn.Embedding(vocab_size + 1 + 1, word_embed_dim, padding_idx=0)
        self.lstm_ques = nn.LSTM(word_embed_dim, self.hidden_ques_dim, batch_first=True)
        self.lstm_ans = nn.LSTM(word_embed_dim, self.hidden_ans_dim, batch_first=True)
        self.lstm_hist_ques = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True)
        self.lstm_hist_ans = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True)
        self.lstm_hist_cap = nn.LSTM(word_embed_dim, self.hidden_cap_dim, batch_first=True)
        # Fuses the attended history question and history answer of each round.
        self.qahistnet = nn.Sequential(
            nn.Linear(self.hidden_hist_dim * 2, self.hidden_hist_dim),
            nn.ReLU(inplace=True)
        )
        # Width of the per-answer feature built in forward(). It is the sum of:
        #   answer LSTM state + attended question + attended answers
        #   + attended image + attended caption
        #   + nqa_per_dial fused history vectors.
        # The hard-coded 9 must therefore equal nqa_per_dial, otherwise the final
        # view to concat_dim in forward() fails.
        self.concat_dim = self.hidden_ques_dim + self.hidden_ans_dim + \
            self.hidden_ans_dim + self.hidden_img_dim + \
            self.hidden_cap_dim + self.hidden_hist_dim * 9
        # Scoring MLP: concat_dim -> concat_dim/2 -> concat_dim/4 -> 1 score.
        self.simnet = nn.Sequential(
            nn.Linear(self.concat_dim, self.concat_dim // 2, bias=False),
            nn.BatchNorm1d(self.concat_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.concat_dim // 2, self.concat_dim // 4, bias=False),
            nn.BatchNorm1d(self.concat_dim // 4),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(self.concat_dim // 4, 1)
        )
        # Weight sharing for the history utilities. Each dict key is a utility index
        # (4 = history questions, 5 = history answers, following util_e below).
        # Each value is (9, [0, 1]): presumably 9 shared copies, one per history
        # round, each connected to utilities 0 (answers) and 1 (question) only.
        # Note: for efficiency the shared (history) utilities are connected to the
        # answer and question utilities only; connecting shared factors to each
        # other is not supported (!)
        sharing_factor_weights = {4: (9, [0, 1]),
                                  5: (9, [0, 1])}
        self.mul_atten = Atten(util_e=[self.hidden_ans_dim,   # Answer modal
                                       self.hidden_ques_dim,  # Question modal
                                       self.hidden_cap_dim,   # Caption modal
                                       self.hidden_img_dim,   # Image modal
                                       self.hidden_hist_dim,  # Question-history modal
                                       self.hidden_hist_dim   # Answer-history modal
                                       ],
                               sharing_factor_weights=sharing_factor_weights,
                               # Spatial size of each utility, used for pairwise
                               # normalization (original note: "use force for adaptive").
                               sizes=[100,  # number of candidate answers
                                      21,   # question length
                                      41,   # caption length
                                      37,   # image regions (original note said 36; the
                                            # extra slot is unexplained here -- confirm)
                                      21,   # history-question length
                                      21    # history-answer length
                                      ],
                               prior_flag=True,
                               pairwise_flag=True)

    def forward(self, input_ques, input_ans, input_hist_ques, input_hist_ans, input_hist_cap,
                input_ques_length, input_ans_length, input_cap_length, i_e):
        """
        Score every candidate answer for each dialog in the batch.

        All helper tensors (indices, priors) are created on the same device as the
        encoded inputs, so the model runs on CPU as well as on any GPU.

        :param input_ques: question token ids, (batch, q_len)
        :param input_ans: candidate answer token ids, (batch, n_options, a_len)
        :param input_hist_ques: history question token ids, (batch, nqa_per_dial, nwords_per_qa)
        :param input_hist_ans: history answer token ids, (batch, nqa_per_dial, nwords_per_qa)
        :param input_hist_cap: caption token ids, (batch, nwords_per_cap)
        :param input_ques_length: question lengths, (batch,). The attention prior is
            placed at position length - 1.
        :param input_ans_length: answer lengths, batch * n_options elements. Each answer
            is represented by its LSTM output at position length - 1.
        :param input_cap_length: caption lengths, batch elements. The attention prior
            is placed at position length - 1.
        :param i_e: image features, presumably (batch, n_regions, hidden_img_dim) -- TODO confirm
        :return: answer scores, (batch, n_options)
        """
        n_options = input_ans.size(1)
        batch_size = input_ques.size(0)
        nqa_per_dial, nwords_per_qa = input_hist_ques.size(1), input_hist_ques.size(2)
        nwords_per_cap = input_hist_cap.size(1)
        max_length_input_ans = input_ans.size(-1)
        assert batch_size == input_hist_ques.size(0) == input_hist_ans.size(0) == input_ques.size(0) == \
            input_ans.size(0) == input_hist_cap.size(0)
        assert nqa_per_dial == input_hist_ques.size(1) == input_hist_ans.size(1)
        assert nwords_per_qa == input_hist_ques.size(2) == input_hist_ans.size(2)

        # Word embeddings. Answers and history are flattened so that every sequence
        # is one LSTM batch row:
        #   q_we:         (batch, q_len, word_embed_dim)
        #   a_we:         (batch*n_options, a_len, word_embed_dim)
        #   hq_we, ha_we: (batch*nqa_per_dial, nwords_per_qa, word_embed_dim)
        #   c_we:         (batch, nwords_per_cap, word_embed_dim)
        q_we = self.word_embedddings(input_ques)
        a_we = self.word_embedddings(input_ans.view(-1, max_length_input_ans))
        hq_we = self.word_embedddings(input_hist_ques.view(-1, nwords_per_qa))
        ha_we = self.word_embedddings(input_hist_ans.view(-1, nwords_per_qa))
        c_we = self.word_embedddings(input_hist_cap.view(-1, nwords_per_cap))

        # Re-compact LSTM weights so the fused (cuDNN) code path can be used,
        # e.g. after the module was replicated.
        self.lstm_ques.flatten_parameters()
        self.lstm_ans.flatten_parameters()
        self.lstm_hist_ques.flatten_parameters()
        self.lstm_hist_ans.flatten_parameters()
        self.lstm_hist_cap.flatten_parameters()

        i_feat = i_e
        # The final (h, c) states are stored on the module. They keep the autograd
        # graph alive until the next forward call.
        q_seq, self.hidden_ques = self.lstm_ques(q_we)
        a_seq, self.hidden_ans = self.lstm_ans(a_we)
        hq_seq, self.hidden_hist_ques = self.lstm_hist_ques(hq_we)
        ha_seq, self.hidden_hist_ans = self.lstm_hist_ans(ha_we)
        cap_seq, self.hidden_cap = self.lstm_hist_cap(c_we)

        # Allocate helper tensors next to the activations instead of forcing CUDA.
        device = q_seq.device

        # Lengths index the last valid time step (length - 1); they drive the priors.
        q_len = input_ques_length.data - 1
        c_len = input_cap_length.data.view(-1) - 1
        ans_index = torch.arange(n_options * batch_size, device=device)
        ans_len = input_ans_length.data.view(-1) - 1
        # Each answer is represented by the LSTM output at its last valid token.
        ans_seq = a_seq[ans_index, ans_len, :]
        ans_seq = ans_seq.view(batch_size, n_options, self.hidden_ans_dim)

        # Priors: a large value (100) at the last question/caption token, and
        # uniform ones over the answers and the image regions.
        batch_index = torch.arange(batch_size, device=device)
        q_prior = torch.zeros(batch_size, q_seq.size(1), device=device)
        q_prior[batch_index, q_len] = 100
        c_prior = torch.zeros(batch_size, cap_seq.size(1), device=device)
        c_prior[batch_index, c_len] = 100
        ans_prior = torch.ones(batch_size, ans_seq.size(1), device=device)
        img_prior = torch.ones(batch_size, i_feat.size(1), device=device)

        (ans_atten, ques_atten, cap_atten, img_atten, hq_atten, ha_atten) = \
            self.mul_atten([ans_seq, q_seq, cap_seq, i_feat, hq_seq, ha_seq],
                           priors=[ans_prior, q_prior, c_prior, img_prior, None, None])

        # Broadcast the per-dialog attended vectors to every candidate answer.
        ques_atten = torch.unsqueeze(ques_atten, 1).expand(batch_size, n_options,
                                                           self.hidden_ques_dim)
        cap_atten = torch.unsqueeze(cap_atten, 1).expand(batch_size, n_options,
                                                         self.hidden_cap_dim)
        img_atten = torch.unsqueeze(img_atten, 1).expand(batch_size, n_options,
                                                         self.hidden_img_dim)
        ans_atten = torch.unsqueeze(ans_atten, 1).expand(batch_size, n_options,
                                                         self.hidden_ans_dim)

        # Fuse each history round's attended question and answer.
        input_qahistnet = torch.cat((hq_atten, ha_atten), 1)
        # input_qahistnet: (nqa_per_dial*batch x 2*hidden_hist_dim)
        output_qahistnet = self.qahistnet(input_qahistnet)
        # output_qahistnet: (nqa_per_dial*batch x hidden_hist_dim)
        output_qahistnet = output_qahistnet.view(batch_size,
                                                 nqa_per_dial * self.hidden_hist_dim)
        # output_qahistnet: (batch x nqa_per_dial*hidden_hist_dim), broadcast per answer
        output_qahistnet = torch.unsqueeze(output_qahistnet, 1).expand(
            batch_size, n_options, nqa_per_dial * self.hidden_hist_dim)

        # Concatenate along the feature dimension; the width must equal self.concat_dim.
        input_qa = torch.cat((ans_seq, ques_atten, ans_atten, img_atten,
                              output_qahistnet, cap_atten), 2)
        input_qa = input_qa.view(batch_size * n_options, self.concat_dim)
        out_scores = self.simnet(input_qa)
        out_scores = out_scores.squeeze(dim=1)
        out_scores = out_scores.view(batch_size, n_options)
        return out_scores