# Babel/Optimus: code/modules/spacefusion.py
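"""SpaceFusion head on top of the Optimus VAE.

SpaceFusion (Gao et al., 2019, "Jointly Optimizing Diversity and Relevance in
Neural Response Generation") aligns the latent spaces of an autoencoder (AE)
and a sequence-to-sequence (S2S) model: paired AE/S2S latents are pulled
together, latents of different examples are pushed apart, and the decoder is
additionally trained on noisy interpolations between the two latents.
"""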
from .vae import VAE

import copy
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
def set_trainable(module, value):
    # Enable or disable gradients for every parameter of `module`.
    for param in module.parameters():
        param.requires_grad = value
class SpaceFusion(VAE):
def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args):
super(SpaceFusion, self).__init__(encoder, decoder, tokenizer_encoder, tokenizer_decoder, args)
        children = [v for v in encoder.encoder.layer.children()]  # the encoder's BertLayer modules (12 for bert-base)
        self.num_s2s_bert_layer = args.num_s2s_bert_layer
        self.S2S_layers = nn.ModuleList([copy.deepcopy(c) for c in children[-args.num_s2s_bert_layer:]])  # copies of the last num_s2s_bert_layer encoder layers, used as the S2S branch
        self.S2S_pooler = copy.deepcopy(encoder.pooler)
        self.ix_turn_sep = tokenizer_encoder.convert_tokens_to_ids('[SEP]')  # id of the turn-separator token
        if args.freeze_bert:
            print('@' * 20 + f' freezing the first {args.num_frozen_bert_layer} BERT layers')
            for child in children[:args.num_frozen_bert_layer]:
                set_trainable(child, False)
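
    # Fields read from `args` in this class: num_s2s_bert_layer (how many top
    # encoder layers are duplicated into the S2S branch), freeze_bert and
    # num_frozen_bert_layer (optionally freeze the bottom encoder layers), and
    # beta (weight of the regularization term in forward()).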
    def ids2speaker(self, ids):
        # Label each token with its speaker: 0 for speaker A, 1 for speaker B.
        # The speaker toggles after every turn-separator token ([SEP]).
N, T = ids.shape
speaker = np.zeros((N, T))
sep = ids == self.ix_turn_sep
for i in range(N):
is_B = False # start with speaker A
for t in range(T):
speaker[i,t] = int(is_B)
if sep[i,t].item():
is_B = not is_B
# make sure the final speaker is speaker B (so response is always speaker A)
if not is_B:
speaker = 1 - speaker
return torch.LongTensor(speaker).to(ids.device)
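
    # Worked example (assuming '[SEP]' maps to id 102, as in bert-base-uncased;
    # the actual id comes from the tokenizer):
    #   ids     = [[ 11, 102, 22, 102, 33 ]]
    #   labels  = [[  0,   0,  1,   1,  0 ]]   # toggles after each [SEP]
    # The final token would be speaker A, so the labels are flipped:
    #   speaker = [[  1,   1,  0,   0,  1 ]]   # last turn ends as speaker B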
    def forward(self, inputs_src, inputs_tgt, labels_tgt, return_vec=False):  # all inputs are [batch, time] id tensors
        # Toggle the encoder config so it returns all hidden states (and no
        # attentions); outputs become (last_hidden_state, pooled_output, all_hidden_states).
        self.encoder.encoder.output_attentions = False
        self.encoder.encoder.output_hidden_states = True
        # AE branch: encode the target utterance with the full BERT encoder.
        mask = (inputs_tgt > 0).float().to(inputs_src.device)
        outputs = self.encoder(inputs_tgt, attention_mask=mask)
        z_AE, _ = self.connect(outputs[1])  # pooled output -> latent z
        z_AE = z_AE.squeeze(1)
        # S2S branch: run the shared encoder on the source context, then feed the
        # hidden states below the duplicated top layers through the S2S-specific layers.
        mask = (inputs_src > 0).float()
        speaker = self.ids2speaker(inputs_src)
        outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker)
        _, _, all_layer_hidden = outputs  # (last_hidden_state, pooled_output, all_hidden_states)
        seq_z_prev = all_layer_hidden[-self.num_s2s_bert_layer - 1]  # hidden states entering the first S2S layer
        for s2s in self.S2S_layers:
            layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
            seq_z_prev = layer_outputs[0]

        z_S2S = self.S2S_pooler(layer_outputs[0])  # pool with the S2S-specific pooler
        z_S2S, _ = self.connect(z_S2S)
        z_S2S = z_S2S.squeeze(1)
if return_vec:
return z_AE, z_S2S
        # Interpolation/smoothness: decode from a random convex combination of the
        # two latents, perturbed with Gaussian noise.
        u = torch.rand(z_AE.shape[0], 1, device=inputs_tgt.device)
        z_interp = u * z_AE + (1 - u) * z_S2S
        std = 0.1
        noise = torch.randn_like(z_interp) * std
        z_interp = z_interp + noise
        # Reconstruction: decode the target from each of the three latents and
        # average the three losses.
        loss_rec = 0
        for z in [z_AE, z_S2S, z_interp]:
            past = z
            outputs = self.decoder(input_ids=labels_tgt, past=past, labels=labels_tgt, label_ignore=self.pad_token_id)
            loss_rec = loss_rec + outputs[0]
        loss_rec = loss_rec / 3
        # Fusion/regularization: pull paired AE/S2S latents together and push the
        # latents of different examples apart, scaled by sqrt(latent dim).
        L_pull = self.dist_pair(z_AE, z_S2S)
        L_push = torch.stack([self.dist_batch(z) for z in [z_AE, z_S2S]]).min()
        loss_reg = (L_pull - L_push * 2) / np.sqrt(z_AE.shape[-1])
loss = loss_rec + self.args.beta * loss_reg
return loss_rec, loss_reg, loss
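
    # Hypothetical training-step sketch (`model`, `optimizer`, and the input
    # tensors are illustrative, not defined in this file):
    #
    #   loss_rec, loss_reg, loss = model(inputs_src, inputs_tgt, labels_tgt)
    #   optimizer.zero_grad()
    #   loss.backward()
    #   optimizer.step()
    #
    # With return_vec=True the call instead returns the (z_AE, z_S2S) pair.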
    def sent2latent(self, inputs_src):
        # Map a source context to its S2S latent; mirrors the S2S branch of forward().
        # Toggle the encoder config so it returns all hidden states.
        self.encoder.encoder.output_attentions = False
        self.encoder.encoder.output_hidden_states = True
        # S2S branch, as in forward()
        mask = (inputs_src > 0).float()
        speaker = self.ids2speaker(inputs_src)
        outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker)
        _, _, all_layer_hidden = outputs  # (last_hidden_state, pooled_output, all_hidden_states)
        seq_z_prev = all_layer_hidden[-self.num_s2s_bert_layer - 1]  # hidden states entering the first S2S layer
        for s2s in self.S2S_layers:
            layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
            seq_z_prev = layer_outputs[0]

        z_S2S = self.S2S_pooler(layer_outputs[0])  # pool with the S2S-specific pooler
        z_S2S, _ = self.connect(z_S2S)
        z_S2S = z_S2S.squeeze(1)
return z_S2S
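
    # Inference sketch (illustrative): z = model.sent2latent(inputs_src) yields
    # the latent from which a response can be decoded, without needing a target.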
    def dist_pair(self, a, b):
        # Mean Euclidean distance between corresponding rows of a and b.
        return F.pairwise_distance(a, b).mean()
    def dist_batch(self, vec):
        # Mean nearest-neighbor distance within the batch. The (near-zero)
        # self-distance is dropped so it cannot win the min.
        n = vec.shape[0]
        dmin = []
        for i in range(n):
            dd = F.pairwise_distance(vec[i:i + 1, :].repeat(n, 1), vec)
            dd = torch.cat([dd[:i], dd[i + 1:]])  # exclude the i == i pair
            dmin.append(dd.min())
        return torch.stack(dmin).mean()
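
    # A vectorized sketch equivalent to dist_batch, using torch.cdist
    # (illustrative only; `dist_batch_cdist` is not used elsewhere in this file):
    #
    #   def dist_batch_cdist(vec):
    #       d = torch.cdist(vec, vec)  # [n, n] pairwise Euclidean distances
    #       d = d + torch.diag(vec.new_full((vec.shape[0],), float('inf')))
    #       return d.min(dim=1).values.mean()  # mean nearest-neighbor distance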