emotion_recognition / model_LAV.py
nouamanetazi's picture
nouamanetazi HF staff
initial commit
ff43e05
raw
history blame
10 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.fc import MLP
from layers.layer_norm import LayerNorm
# ------------------------------------
# ---------- Masking sequence --------
# ------------------------------------
def make_mask(feature):
    """Build a boolean padding mask from a feature tensor.

    A position counts as padding when the absolute values along the last
    dimension sum to exactly zero.

    Args:
        feature: Tensor whose last dimension is reduced away.

    Returns:
        Boolean tensor, ``True`` at all-zero positions, with two singleton
        dimensions inserted at indices 1 and 2.
    """
    magnitude = feature.abs().sum(dim=-1)
    is_padding = magnitude.eq(0)
    return is_padding.unsqueeze(1).unsqueeze(2)
# ------------------------------
# ---------- Flattening --------
# ------------------------------
class AttFlat(nn.Module):
    """Attention-based flattening of a sequence into ``flat_glimpse`` vectors.

    An MLP scores every sequence position once per glimpse. The scores are
    softmax-normalised over dim 1 and used as weights for a sum of the
    input over that same dimension.

    Args:
        args: Configuration object; ``hidden_size``, ``ff_size`` and
            ``dropout_r`` are read from it.
        flat_glimpse: Number of attention glimpses (output size of the MLP).
        merge: When true, the glimpses are concatenated and passed through
            a linear layer producing ``2 * args.hidden_size`` features.
    """

    def __init__(self, args, flat_glimpse, merge=False):
        super().__init__()
        self.args = args
        self.merge = merge
        self.flat_glimpse = flat_glimpse

        self.mlp = MLP(
            in_size=args.hidden_size,
            mid_size=args.ff_size,
            out_size=flat_glimpse,
            dropout_r=args.dropout_r,
            use_relu=True,
        )

        if merge:
            merged_in = args.hidden_size * flat_glimpse
            self.linear_merge = nn.Linear(merged_in, args.hidden_size * 2)

    def forward(self, x, x_mask):
        """Flatten ``x`` with learned attention weights.

        Args:
            x: Input features; assumed (batch, seq, hidden) -- TODO confirm.
            x_mask: Boolean mask as produced by ``make_mask`` (``True`` =
                ignore), or ``None`` to attend over every position.

        Returns:
            With ``merge``: the output of ``linear_merge`` applied to the
            concatenated glimpses. Otherwise the glimpses stacked along
            dim 1, i.e. one weighted sum of ``x`` per glimpse.
        """
        scores = self.mlp(x)
        if x_mask is not None:
            # Drop mask dims 1 and 2 (when singleton) and append a trailing
            # dim so the mask broadcasts across the glimpse dimension.
            seq_mask = x_mask.squeeze(1).squeeze(1).unsqueeze(2)
            scores = scores.masked_fill(seq_mask, -1e9)
        weights = F.softmax(scores, dim=1)

        glimpses = [
            torch.sum(weights[:, :, g:g + 1] * x, dim=1)
            for g in range(self.flat_glimpse)
        ]

        if not self.merge:
            # Stacking yields (glimpse, batch, ...); swap to batch-first.
            return torch.stack(glimpses).transpose(0, 1)
        return self.linear_merge(torch.cat(glimpses, dim=1))
# ------------------------
# ---- Self Attention ----
# ------------------------
class SA(nn.Module):
    """Self-attention layer: multi-head attention, then a feed-forward net.

    Each sub-layer's output goes through dropout, is added to the
    sub-layer's input, and the sum is layer-normalised.

    Args:
        args: Configuration object; ``dropout_r`` and ``hidden_size`` are
            read here, and ``args`` is forwarded to ``MHAtt`` and ``FFN``.
    """

    def __init__(self, args):
        super().__init__()

        self.mhatt = MHAtt(args)
        self.ffn = FFN(args)

        self.dropout1 = nn.Dropout(args.dropout_r)
        self.norm1 = LayerNorm(args.hidden_size)

        self.dropout2 = nn.Dropout(args.dropout_r)
        self.norm2 = LayerNorm(args.hidden_size)

    def forward(self, y, y_mask):
        """Apply self-attention and the feed-forward net to ``y``.

        Args:
            y: Input sequence; used as value, key and query.
            y_mask: Attention mask handed to the attention module
                (may be ``None``).

        Returns:
            The transformed sequence.
        """
        attended = self.mhatt(y, y, y, y_mask)
        y = self.norm1(y + self.dropout1(attended))

        transformed = self.ffn(y)
        return self.norm2(y + self.dropout2(transformed))
# -------------------------------
# ---- Self Guided Attention ----
# -------------------------------
class SGA(nn.Module):
    """Guided-attention layer: self-attention on ``x``, attention from
    ``x`` over ``y``, then a feed-forward net.

    Each of the three sub-layers is followed by dropout, a residual
    addition and layer normalisation.

    Args:
        args: Configuration object; ``dropout_r`` and ``hidden_size`` are
            read here, and ``args`` is forwarded to ``MHAtt`` and ``FFN``.
    """

    def __init__(self, args):
        super().__init__()

        self.mhatt1 = MHAtt(args)
        self.mhatt2 = MHAtt(args)
        self.ffn = FFN(args)

        self.dropout1 = nn.Dropout(args.dropout_r)
        self.norm1 = LayerNorm(args.hidden_size)

        self.dropout2 = nn.Dropout(args.dropout_r)
        self.norm2 = LayerNorm(args.hidden_size)

        self.dropout3 = nn.Dropout(args.dropout_r)
        self.norm3 = LayerNorm(args.hidden_size)

    def forward(self, x, y, x_mask, y_mask):
        """Update ``x`` using itself and the guiding sequence ``y``.

        Args:
            x: Sequence being updated (query of both attention steps).
            y: Guiding sequence (key and value of the second step).
            x_mask: Mask used for the self-attention over ``x``.
            y_mask: Mask used for the attention over ``y``.

        Returns:
            The updated ``x``.
        """
        self_attended = self.mhatt1(v=x, k=x, q=x, mask=x_mask)
        x = self.norm1(x + self.dropout1(self_attended))

        guided = self.mhatt2(v=y, k=y, q=x, mask=y_mask)
        x = self.norm2(x + self.dropout2(guided))

        transformed = self.ffn(x)
        return self.norm3(x + self.dropout3(transformed))
# ------------------------------
# ---- Multi-Head Attention ----
# ------------------------------
class MHAtt(nn.Module):
    """Multi-head scaled dot-product attention.

    Values, keys and queries are each linearly projected, split into
    ``args.multi_head`` heads, attended per head, re-assembled and passed
    through a final linear layer.

    Args:
        args: Configuration object providing ``hidden_size``,
            ``multi_head`` and ``dropout_r``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of ``multi_head``;
            the per-head reshape would otherwise fail or silently
            mis-align features.
    """

    def __init__(self, args):
        super().__init__()
        if args.hidden_size % args.multi_head != 0:
            raise ValueError(
                "hidden_size ({}) must be divisible by multi_head ({})".format(
                    args.hidden_size, args.multi_head
                )
            )
        self.args = args

        self.linear_v = nn.Linear(args.hidden_size, args.hidden_size)
        self.linear_k = nn.Linear(args.hidden_size, args.hidden_size)
        self.linear_q = nn.Linear(args.hidden_size, args.hidden_size)
        self.linear_merge = nn.Linear(args.hidden_size, args.hidden_size)

        self.dropout = nn.Dropout(args.dropout_r)

    def _split_heads(self, projected, n_batches):
        """Reshape (batch, seq, hidden) into (batch, heads, seq, head_dim)."""
        n_heads = self.args.multi_head
        # Integer division: avoids the float round-trip of int(a / b).
        head_dim = self.args.hidden_size // n_heads
        return projected.view(n_batches, -1, n_heads, head_dim).transpose(1, 2)

    def forward(self, v, k, q, mask):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            v: Value sequence, last dimension ``hidden_size``.
            k: Key sequence, last dimension ``hidden_size``.
            q: Query sequence, last dimension ``hidden_size``.
            mask: Boolean tensor broadcastable to the per-head score matrix
                (``True`` = position is ignored), or ``None``.

        Returns:
            Tensor of shape (batch, query_len, hidden_size).
        """
        n_batches = q.size(0)

        v = self._split_heads(self.linear_v(v), n_batches)
        k = self._split_heads(self.linear_k(k), n_batches)
        q = self._split_heads(self.linear_q(q), n_batches)

        atted = self.att(v, k, q, mask)
        # Back to (batch, seq, heads, head_dim), then fuse the heads.
        atted = atted.transpose(1, 2).contiguous().view(
            n_batches, -1, self.args.hidden_size
        )
        return self.linear_merge(atted)

    def att(self, value, key, query, mask):
        """Scaled dot-product attention on already head-split tensors.

        Scores are ``query @ key^T / sqrt(d_k)``; masked positions are set
        to ``-1e9`` before the softmax, and dropout is applied to the
        attention map before it weights ``value``.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        att_map = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(att_map, value)
# ---------------------------
# ---- Feed Forward Nets ----
# ---------------------------
class FFN(nn.Module):
    """Feed-forward sub-layer backed by the project's ``MLP``.

    The MLP is configured with ``in_size`` and ``out_size`` equal to
    ``args.hidden_size``, ``mid_size`` equal to ``args.ff_size``, dropout
    rate ``args.dropout_r`` and ReLU enabled.
    """

    def __init__(self, args):
        super().__init__()
        hidden = args.hidden_size
        self.mlp = MLP(
            in_size=hidden,
            mid_size=args.ff_size,
            out_size=hidden,
            dropout_r=args.dropout_r,
            use_relu=True,
        )

    def forward(self, x):
        """Return the MLP applied to ``x``."""
        transformed = self.mlp(x)
        return transformed
# ---------------------------
# ---- FF + norm -----------
# ---------------------------
class FFAndNorm(nn.Module):
    """Layer norm, then a residual feed-forward step with a second norm.

    Args:
        args: Configuration object; ``hidden_size`` and ``dropout_r`` are
            read here, and ``args`` is forwarded to ``FFN``.
    """

    def __init__(self, args):
        super().__init__()

        self.ffn = FFN(args)
        self.norm1 = LayerNorm(args.hidden_size)
        self.dropout2 = nn.Dropout(args.dropout_r)
        self.norm2 = LayerNorm(args.hidden_size)

    def forward(self, x):
        """Normalise ``x``, add the dropped-out FFN output, normalise again."""
        normed = self.norm1(x)
        update = self.dropout2(self.ffn(normed))
        return self.norm2(normed + update)
class Block(nn.Module):
    """One encoder block over three streams ``x``, ``y`` and ``z``.

    Judging by the attribute names, ``x`` is the language stream, ``y`` the
    audio stream and ``z`` the video stream. ``x`` is refined with
    self-attention (``SA``); ``y`` and ``z`` are each refined with guided
    attention (``SGA``) that attends over ``x``. Every block except the
    last additionally re-projects each stream with an ``AttFlat`` whose
    glimpse count is that stream's configured sequence length, adds the
    result back with dropout and layer-normalises it.

    Args:
        args: Configuration object; ``layer``, ``hidden_size``,
            ``dropout_r`` and the ``*_seq_len`` fields are read here.
        i: Zero-based index of this block; the block with index
            ``args.layer - 1`` is treated as the last one.
    """

    def __init__(self, args, i):
        super().__init__()
        self.args = args

        self.sa1 = SA(args)
        self.sa2 = SGA(args)
        self.sa3 = SGA(args)

        self.last = (i == args.layer - 1)
        if not self.last:
            self.att_lang = AttFlat(args, args.lang_seq_len, merge=False)
            self.att_audio = AttFlat(args, args.audio_seq_len, merge=False)
            self.att_vid = AttFlat(args, args.video_seq_len, merge=False)
            self.norm_l = LayerNorm(args.hidden_size)
            self.norm_a = LayerNorm(args.hidden_size)
            self.norm_v = LayerNorm(args.hidden_size)
            self.dropout = nn.Dropout(args.dropout_r)

    def forward(self, x, x_mask, y, y_mask, z, z_mask):
        """Run the block on the three streams.

        Args:
            x, y, z: The three feature sequences.
            x_mask, y_mask, z_mask: Their padding masks (each may be
                ``None``).

        Returns:
            Tuple ``(x, y, z)`` of updated sequences.
        """
        # All three attention layers see the *incoming* x, so they must be
        # evaluated before x is overwritten below.
        ax = self.sa1(x, x_mask)
        ay = self.sa2(y, x, y_mask, x_mask)
        az = self.sa3(z, x, z_mask, x_mask)

        x = ax + x
        y = ay + y
        z = az + z

        if self.last:
            return x, y, z

        ax = self.att_lang(x, x_mask)
        ay = self.att_audio(y, y_mask)
        # The video stream must be masked with its own mask; the audio mask
        # was passed here before, which is wrong whenever the two differ.
        az = self.att_vid(z, z_mask)

        return (
            self.norm_l(x + self.dropout(ax)),
            self.norm_a(y + self.dropout(ay)),
            self.norm_v(z + self.dropout(az)),
        )
class Model_LAV(nn.Module):
    """Classifier over three input streams ``x``, ``y`` and ``z``.

    ``x`` holds token ids that are embedded and encoded by an LSTM; ``y``
    and ``z`` are feature sequences (audio and video, per the config field
    names) projected to the hidden size by linear adapters. The three
    streams pass through ``args.layer`` ``Block`` modules, are each
    flattened to one vector by an ``AttFlat``, summed, layer-normalised and
    classified by a linear layer.

    Args:
        args: Configuration object. ``task`` selects the output size:
            ``"sentiment"`` gives 2 classes when ``task_binary`` is truthy
            and 7 otherwise; ``"emotion"`` gives 6.
        vocab_size: Number of rows in the embedding table.
        pretrained_emb: NumPy array with the pretrained embedding weights
            (GloVe, per the original comment); copied into the embedding.

    Raises:
        ValueError: If ``args.task`` is neither ``"sentiment"`` nor
            ``"emotion"``. Without this check no classifier would be
            created and the failure would only surface in ``forward``.
    """

    def __init__(self, args, vocab_size, pretrained_emb):
        super().__init__()

        # Resolve the output size first so a bad task fails before any
        # sub-module is built.
        if args.task == "sentiment":
            n_classes = 2 if args.task_binary else 7
        elif args.task == "emotion":
            n_classes = 6
        else:
            raise ValueError(
                "Unsupported task {!r}; expected 'sentiment' or "
                "'emotion'".format(args.task)
            )

        self.args = args

        # Word embedding, initialised from the pretrained weights.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=args.word_embed_size,
        )
        self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))

        # Language encoder.
        self.lstm_x = nn.LSTM(
            input_size=args.word_embed_size,
            hidden_size=args.hidden_size,
            num_layers=1,
            batch_first=True,
        )

        # Feature size to hidden size for the two non-language streams.
        self.adapter_y = nn.Linear(args.audio_feat_size, args.hidden_size)
        self.adapter_z = nn.Linear(args.video_feat_size, args.hidden_size)

        # Encoder blocks.
        self.enc_list = nn.ModuleList(
            [Block(args, i) for i in range(args.layer)]
        )

        # Flattening features before projection.
        self.attflat_ac = AttFlat(args, 1, merge=True)
        self.attflat_vid = AttFlat(args, 1, merge=True)
        self.attflat_lang = AttFlat(args, 1, merge=True)

        # Classification layers.
        self.proj_norm = LayerNorm(2 * args.hidden_size)
        self.proj = nn.Linear(2 * args.hidden_size, n_classes)

    def forward(self, x, y, z):
        """Compute class scores for a batch.

        Args:
            x: Token-id tensor; id 0 is treated as padding by the mask.
            y: Feature sequence fed to ``adapter_y``; all-zero positions
                are treated as padding.
            z: Feature sequence fed to ``adapter_z``; all-zero positions
                are treated as padding.

        Returns:
            Output of the final linear layer (one score per class).
        """
        # Masks are derived from the raw inputs, before any projection.
        x_mask = make_mask(x.unsqueeze(2))
        y_mask = make_mask(y)
        z_mask = make_mask(z)

        x, _ = self.lstm_x(self.embedding(x))
        y = self.adapter_y(y)
        z = self.adapter_z(z)

        for i, block in enumerate(self.enc_list):
            # Padding masks are only supplied to the first block; later
            # blocks run unmasked (behaviour kept from the original code).
            if i == 0:
                x_m, y_m, z_m = x_mask, y_mask, z_mask
            else:
                x_m, y_m, z_m = None, None, None
            x, y, z = block(x, x_m, y, y_m, z, z_m)

        x = self.attflat_lang(x, None)
        y = self.attflat_ac(y, None)
        z = self.attflat_vid(z, None)

        # Classification layers.
        proj_feat = self.proj_norm(x + y + z)
        return self.proj(proj_feat)