guyyariv
AudioTokenDemo
1b92e8f
#!/usr/bin/env python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from itertools import product, permutations, combinations_with_replacement, chain
class Unary(nn.Module):
    def __init__(self, embed_size):
        """
        Captures local entity information
        :param embed_size: the embedding dimension
        """
        super(Unary, self).__init__()
        # 1x1 convolutions act as per-entity linear layers over the embedding channels.
        self.embed = nn.Conv1d(embed_size, embed_size, 1)
        self.feature_reduce = nn.Conv1d(embed_size, 1, 1)

    def forward(self, X):
        """
        Compute one scalar potential per entity.
        :param X: entity embeddings; expected (batch, num_entities, embed_size)
            since the last axis is fed to the Conv1d channels after a transpose
        :return: potentials of shape (batch, num_entities)
        """
        # Conv1d expects channels first: (batch, embed_size, num_entities).
        X = X.transpose(1, 2)
        X_embed = self.embed(X)
        # F.dropout defaults to training=True; tie it to the module mode so that
        # dropout is disabled under .eval().
        X_nl_embed = F.dropout(F.relu(X_embed), training=self.training)
        X_poten = self.feature_reduce(X_nl_embed)
        return X_poten.squeeze(1)
class Pairwise(nn.Module):
    def __init__(self, embed_x_size, x_spatial_dim=None, embed_y_size=None, y_spatial_dim=None):
        """
        Captures interaction between utilities or entities of the same utility
        :param embed_x_size: the embedding dimension of the first utility
        :param x_spatial_dim: the spatial dimension of the first utility for batch norm and weighted marginalization
        :param embed_y_size: the embedding dimension of the second utility (none for self-interactions)
        :param y_spatial_dim: the spatial dimension of the second utility for batch norm and weighted marginalization
        """
        super(Pairwise, self).__init__()
        # Fall back to the first utility's embedding size only when no second size was given
        # (self-interaction). The fallback must key on embed_y_size itself, not on y_spatial_dim.
        embed_y_size = embed_y_size if embed_y_size is not None else embed_x_size
        self.y_spatial_dim = y_spatial_dim if y_spatial_dim is not None else x_spatial_dim
        # Both utilities are projected into a common space of the larger dimension.
        self.embed_size = max(embed_x_size, embed_y_size)
        self.x_spatial_dim = x_spatial_dim
        self.embed_X = nn.Conv1d(embed_x_size, self.embed_size, 1)
        self.embed_Y = nn.Conv1d(embed_y_size, self.embed_size, 1)
        if x_spatial_dim is not None:
            # With known spatial sizes, the similarity matrix is batch-normalized as a flat vector
            # and marginalized with learned weights instead of a plain mean.
            self.normalize_S = nn.BatchNorm1d(self.x_spatial_dim * self.y_spatial_dim)
            self.margin_X = nn.Conv1d(self.y_spatial_dim, 1, 1)
            self.margin_Y = nn.Conv1d(self.x_spatial_dim, 1, 1)

    def forward(self, X, Y=None):
        """
        Compute pairwise potentials from the cosine-similarity matrix of the two utilities.
        :param X: first utility; expected (batch, num_x, embed_x_size)
        :param Y: second utility, expected (batch, num_y, embed_y_size); ``None`` means
            self-interaction of ``X``
        :return: ``X_poten`` of shape (batch, num_x) when ``Y`` is None, otherwise the tuple
            ``(X_poten, Y_poten)`` with ``Y_poten`` of shape (batch, num_y)
        """
        X_t = X.transpose(1, 2)
        Y_t = Y.transpose(1, 2) if Y is not None else X_t
        X_embed = self.embed_X(X_t)
        Y_embed = self.embed_Y(Y_t)
        # L2-normalize along the channel axis so that S holds cosine similarities.
        X_norm = F.normalize(X_embed)
        Y_norm = F.normalize(Y_embed)
        S = X_norm.transpose(1, 2).bmm(Y_norm)  # (batch, num_x, num_y)
        if self.x_spatial_dim is not None:
            S = self.normalize_S(S.view(-1, self.x_spatial_dim * self.y_spatial_dim)) \
                .view(-1, self.x_spatial_dim, self.y_spatial_dim)
            X_poten = self.margin_X(S.transpose(1, 2)).transpose(1, 2).squeeze(2)
            Y_poten = self.margin_Y(S).transpose(1, 2).squeeze(2)
        else:
            X_poten = S.mean(dim=2, keepdim=False)
            Y_poten = S.mean(dim=1, keepdim=False)
        if Y is None:
            return X_poten
        return X_poten, Y_poten
class Atten(nn.Module):
    def __init__(self, util_e, sharing_factor_weights=None, prior_flag=False,
                 sizes=None, size_force=False, pairwise_flag=True,
                 unary_flag=True, self_flag=True):
        """
        The class performs an attention on a given list of utilities representation.
        :param util_e: the embedding dimension of each utility
        :param sharing_factor_weights: to share weights, provide a dict
            {idx: (num_utils, connected_utils)}, where connected_utils lists the
            indices of the utilities that the shared utility ``idx`` interacts with.
            Note, for efficiency, the shared utils (e.g., history) are connected to ans
            and question only.
            TODO: connections between shared utils
            ``None`` or an empty list/dict means no sharing.
        :param prior_flag: is prior factor provided
        :param sizes: the spatial dimension of each utility (used for batch-norm and
            weighted marginalization); ``None`` or empty means unknown sizes
        :param size_force: force spatial size with adaptive avg pooling.
        :param pairwise_flag: use pairwise interaction between utilities
        :param unary_flag: use local information
        :param self_flag: use self interactions between a utility's entities
        """
        super(Atten, self).__init__()
        self.util_e = util_e
        self.prior_flag = prior_flag
        self.n_utils = len(util_e)
        self.spatial_pool = nn.ModuleDict()
        self.un_models = nn.ModuleList()
        self.self_flag = self_flag
        self.pairwise_flag = pairwise_flag
        self.unary_flag = unary_flag
        self.size_force = size_force
        if not sizes:
            sizes = [None for _ in util_e]
        # Normalize to a dict: the legacy default was an empty list, which has no .items().
        self.sharing_factor_weights = dict(sharing_factor_weights) if sharing_factor_weights else {}

        # Unary models, plus the pooling that forces the provided size.
        for idx, e_dim in enumerate(util_e):
            self.un_models.append(Unary(e_dim))
            if self.size_force:
                self.spatial_pool[str(idx)] = nn.AdaptiveAvgPool1d(sizes[idx])

        # Pairwise models: str(idx) for self-interactions, str((idx1, idx2)) with idx1 <= idx2
        # for interactions between utilities.
        self.pp_models = nn.ModuleDict()
        for ((idx1, e_dim_1), (idx2, e_dim_2)) \
                in combinations_with_replacement(enumerate(util_e), 2):
            if self.self_flag and idx1 == idx2:
                self.pp_models[str(idx1)] = Pairwise(e_dim_1, sizes[idx1])
            elif pairwise_flag:
                # Skip pairs that the sharing specification leaves unconnected.
                if idx1 in self.sharing_factor_weights \
                        and idx2 not in self.sharing_factor_weights[idx1][1]:
                    continue
                if idx2 in self.sharing_factor_weights \
                        and idx1 not in self.sharing_factor_weights[idx2][1]:
                    continue
                # NOTE(review): with self_flag=False this also registers an (idx, idx) module
                # that forward never uses; left in place to keep state_dict keys unchanged.
                self.pp_models[str((idx1, idx2))] = Pairwise(e_dim_1, sizes[idx1], e_dim_2, sizes[idx2])

        # Count how many potentials feed each utility's reducer (a bias-free 1x1 conv).
        self.reduce_potentials = nn.ModuleList()
        self.num_of_potentials = dict()
        self.default_num_of_potentials = 0
        if self.self_flag:
            self.default_num_of_potentials += 1
        if self.unary_flag:
            self.default_num_of_potentials += 1
        if self.prior_flag:
            self.default_num_of_potentials += 1
        for idx in range(self.n_utils):
            self.num_of_potentials[idx] = self.default_num_of_potentials
        # All other utilities
        if pairwise_flag:
            for idx, (num_utils, connected_utils) in self.sharing_factor_weights.items():
                for c_u in connected_utils:
                    # A connected util receives one potential per shared instance;
                    # the shared util receives one potential per connected util.
                    self.num_of_potentials[c_u] += num_utils
                    self.num_of_potentials[idx] += 1
            for k in self.num_of_potentials:
                if k not in self.sharing_factor_weights:
                    self.num_of_potentials[k] += (self.n_utils - 1) \
                        - len(self.sharing_factor_weights)
        for idx in range(self.n_utils):
            self.reduce_potentials.append(nn.Conv1d(self.num_of_potentials[idx],
                                                    1, 1, bias=False))

    def forward(self, utils, priors=None):
        """
        Attend over every utility.
        :param utils: one tensor per utility; non-shared utilities are presumably
            (batch, num_entities, embed) -- TODO confirm against callers
        :param priors: one prior per utility (entries may be None) when prior_flag is
            set, otherwise None
        :return: list with one attended vector per utility, shape (batch, embed)
        """
        assert self.n_utils == len(utils)
        assert (priors is None and not self.prior_flag) \
            or (priors is not None
                and self.prior_flag
                and len(priors) == self.n_utils)
        # Work on copies so the caller's sequences are never overwritten (and tuples work).
        utils = list(utils)
        if priors is not None:
            priors = list(priors)
        b_size = utils[0].size(0)
        util_factors = dict()
        attention = list()

        # Force size: a constant size is required by the pairwise batch normalization.
        if self.size_force:
            for i, (num_utils, _) in self.sharing_factor_weights.items():
                if str(i) not in self.spatial_pool.keys():
                    continue
                high_util = utils[i]
                # Fold the shared instances into the batch axis before pooling.
                high_util = high_util.view(num_utils * b_size, high_util.size(2), high_util.size(3))
                high_util = high_util.transpose(1, 2)
                utils[i] = self.spatial_pool[str(i)](high_util).transpose(1, 2)
            for i in range(self.n_utils):
                if i in self.sharing_factor_weights \
                        or str(i) not in self.spatial_pool.keys():
                    continue
                utils[i] = utils[i].transpose(1, 2)
                utils[i] = self.spatial_pool[str(i)](utils[i]).transpose(1, 2)
                if self.prior_flag and priors[i] is not None:
                    priors[i] = self.spatial_pool[str(i)](priors[i].unsqueeze(1)).squeeze(1)

        # Handle shared weights.
        for i, (num_utils, connected_list) in self.sharing_factor_weights.items():
            if self.unary_flag:
                util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
            if self.self_flag:
                util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))
            if self.pairwise_flag:
                for j in connected_list:
                    other_util = utils[j]
                    # Repeat the connected util once per shared instance so batch sizes match.
                    expanded_util = other_util.unsqueeze(1).expand(
                        b_size, num_utils, other_util.size(1), other_util.size(2)
                    ).contiguous().view(b_size * num_utils, other_util.size(1), other_util.size(2))
                    if i < j:
                        factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], expanded_util)
                    else:
                        factor_ji, factor_ij = self.pp_models[str((j, i))](expanded_util, utils[i])
                    util_factors.setdefault(i, []).append(factor_ij)
                    util_factors.setdefault(j, []).append(
                        factor_ji.view(b_size, num_utils, factor_ji.size(1)))

        # Handle local factors.
        for i in range(self.n_utils):
            if i in self.sharing_factor_weights:
                continue
            if self.unary_flag:
                util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
            if self.self_flag:
                util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))

        # Joint factors between non-shared utilities.
        if self.pairwise_flag:
            for (i, j) in combinations_with_replacement(range(self.n_utils), 2):
                if i == j \
                        or i in self.sharing_factor_weights \
                        or j in self.sharing_factor_weights:
                    continue
                factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], utils[j])
                util_factors.setdefault(i, []).append(factor_ij)
                util_factors.setdefault(j, []).append(factor_ji)

        # Perform attention: reduce the potentials, softmax over entities, weighted sum.
        for i in range(self.n_utils):
            if self.prior_flag:
                # zeros_like already matches the device/dtype of the factors, so no .cuda() is needed.
                prior = priors[i] \
                    if priors[i] is not None \
                    else torch.zeros_like(util_factors[i][0], requires_grad=False)
                util_factors[i].append(prior)
            util_factors[i] = torch.cat([p if len(p.size()) == 3 else p.unsqueeze(1)
                                         for p in util_factors[i]], dim=1)
            util_factors[i] = self.reduce_potentials[i](util_factors[i]).squeeze(1)
            util_factors[i] = F.softmax(util_factors[i], dim=1).unsqueeze(2)
            attention.append(torch.bmm(utils[i].transpose(1, 2), util_factors[i]).squeeze(2))
        return attention
class NaiveAttention(nn.Module):
    def __init__(self):
        """
        Ablation baseline: replaces learned attention with plain (prior-weighted) pooling.
        """
        super(NaiveAttention, self).__init__()

    def forward(self, utils, priors):
        """
        Pool each utility either by its prior weights or by a plain mean.
        :param utils: sequence of utilities; an entry that is exactly a ``tuple`` is
            treated as spatial and only its second element is used
        :param priors: sequence of priors aligned with ``utils`` (entries may be None)
        :return: ``(atten, spatial_atten)``, the pooled non-spatial and spatial utilities
        """
        pooled = []
        spatial_pooled = []
        for util, prior in zip(utils, priors):
            if type(util) is tuple:
                feats = util[1]
                n_elem = feats.shape[0]
                if prior is None:
                    spatial_pooled.append(feats.mean(2))
                else:
                    # Flatten everything but the last two axes into one batch axis for bmm.
                    flat_feats = feats.view(-1, feats.shape[-2], feats.shape[-1])
                    flat_prior = prior.view(-1, prior.shape[-2], prior.shape[-1])
                    weighted = torch.bmm(flat_prior.transpose(1, 2), flat_feats).squeeze(2)
                    spatial_pooled.append(
                        weighted.view(n_elem, -1, flat_feats.shape[-2], flat_feats.shape[-1]))
                continue
            if prior is None:
                pooled.append(util.mean(1))
            else:
                # Prior-weighted sum over the entity axis.
                pooled.append(torch.bmm(util.transpose(1, 2), prior.unsqueeze(2)).squeeze(2))
        return pooled, spatial_pooled