import copy

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from .vae import VAE

def set_trainable(module, value):
    # Enable or disable gradient updates for every parameter in `module`.
    for param in module.parameters():
        param.requires_grad = value

class SpaceFusion(VAE):
    def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args):
        super(SpaceFusion, self).__init__(encoder, decoder, tokenizer_encoder, tokenizer_decoder, args)
        children = [v for v in encoder.encoder.layer.children()]  # list of 12 BertLayer
        self.num_s2s_bert_layer = args.num_s2s_bert_layer
        # S2S branch: trainable copies of the last num_s2s_bert_layer encoder layers
        self.S2S_layers = nn.ModuleList([copy.deepcopy(c) for c in children[-args.num_s2s_bert_layer:]])
        self.S2S_pooler = copy.deepcopy(encoder.pooler)
        self.ix_turn_sep = tokenizer_encoder.convert_tokens_to_ids('[SEP]')
        if args.freeze_bert:
            print('@' * 20 + f' freezing BERT {args.num_frozen_bert_layer} layers')
            for child in children[:args.num_frozen_bert_layer]:
                set_trainable(child, False)

    def ids2speaker(self, ids):
        # Build token_type_ids: 0 for speaker A, 1 for speaker B, flipping the
        # role at every [SEP] token
        N, T = ids.shape
        speaker = np.zeros((N, T))
        sep = ids == self.ix_turn_sep
        for i in range(N):
            is_B = False  # start with speaker A
            for t in range(T):
                speaker[i, t] = int(is_B)
                if sep[i, t].item():
                    is_B = not is_B
        # normalize so the labeling is consistent across examples: the final
        # speaker is always labeled B, so the response side is always speaker A
        if not is_B:
            speaker = 1 - speaker
        return torch.LongTensor(speaker).to(ids.device)
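
    # Worked example for ids2speaker (hypothetical token ids; assume [SEP] = 102):
    # ids = [[11, 12, 102, 21, 22, 102]] is first labeled [[0, 0, 0, 1, 1, 1]];
    # is_B is False after the trailing [SEP], so the final flip yields
    # [[1, 1, 1, 0, 0, 0]], keeping the role assignment consistent across examples.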

    def forward(self, inputs_src, inputs_tgt, labels_tgt, return_vec=False):  # [batch, time]
        # toggle config to get the desired encoder output
        self.encoder.encoder.output_attentions = False
        self.encoder.encoder.output_hidden_states = True

        # AE branch: encode the target sentence itself
        mask = (inputs_tgt > 0).float().to(inputs_src.device)
        outputs = self.encoder(inputs_tgt, attention_mask=mask)
        z_AE, _ = self.connect(outputs[1])  # pooled [CLS] representation -> latent
        z_AE = z_AE.squeeze(1)
        # S2S branch: encode the source (context) with speaker ids as
        # token_type_ids, then run the dedicated S2S copies of the top layers
        mask = (inputs_src > 0).float()
        speaker = self.ids2speaker(inputs_src)
        outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker)
        _, _, all_layer_states = outputs  # last_hidden, pooled, all hidden states
        # hidden states just below the copied layers (e.g. layer 11 of 12 when
        # a single S2S layer is used)
        seq_z_prev = all_layer_states[-self.num_s2s_bert_layer - 1]
        for s2s in self.S2S_layers:
            layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
            seq_z_prev = layer_outputs[0]
        z_S2S = self.S2S_pooler(layer_outputs[0])  # dedicated S2S pooler (deep-copied in __init__)
        z_S2S, _ = self.connect(z_S2S)
        z_S2S = z_S2S.squeeze(1)

        if return_vec:
            return z_AE, z_S2S
        # interpolation/smoothness: sample a random point on the segment
        # between z_AE and z_S2S and perturb it with Gaussian noise
        u = torch.FloatTensor(np.random.random((z_AE.shape[0], 1))).to(inputs_tgt.device)
        z_interp = u * z_AE + (1 - u) * z_S2S
        std = 0.1
        noise = torch.FloatTensor(np.random.normal(size=z_interp.shape) * std).to(z_interp.device)
        z_interp = z_interp + noise
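
        # The noisy interpolation implements SpaceFusion's smoothness term:
        # decoding from points between z_AE and z_S2S keeps the path between
        # the two latent spaces decodable instead of letting them drift apart.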
        # reconstruction: decode the target from each of the three latents
        loss_rec = 0
        for z in [z_AE, z_S2S, z_interp]:
            past = z  # latent conditioning for the decoder
            outputs = self.decoder(input_ids=labels_tgt, past=past, labels=labels_tgt, label_ignore=self.pad_token_id)
            loss_rec = loss_rec + outputs[0]
        loss_rec = loss_rec / 3

        # fusion/regularization: pull paired latents together, push distinct
        # latents apart; scale by sqrt(latent width) to keep beta comparable
        L_pull = self.dist_pair(z_AE, z_S2S)
        L_push = torch.stack([self.dist_batch(z) for z in [z_AE, z_S2S]]).min()
        loss_reg = (L_pull - L_push * 2) / np.sqrt(z_interp.shape[-1])
        loss = loss_rec + self.args.beta * loss_reg
        return loss_rec, loss_reg, loss
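
    # Loss sketch: with d the latent width,
    #     loss = (L_AE + L_S2S + L_interp) / 3 + beta * (L_pull - 2 * L_push) / sqrt(d)
    # L_pull draws each (z_AE, z_S2S) pair together; L_push keeps distinct
    # latents within a batch apart, fusing the AE and S2S spaces.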

    def sent2latent(self, inputs_src):
        # Map a source (context) to its S2S latent, mirroring the S2S branch
        # of forward()
        self.encoder.encoder.output_attentions = False
        self.encoder.encoder.output_hidden_states = True
        mask = (inputs_src > 0).float()
        speaker = self.ids2speaker(inputs_src)
        outputs = self.encoder(inputs_src, attention_mask=mask, token_type_ids=speaker)
        _, _, all_layer_states = outputs  # last_hidden, pooled, all hidden states
        seq_z_prev = all_layer_states[-self.num_s2s_bert_layer - 1]
        for s2s in self.S2S_layers:
            layer_outputs = s2s(seq_z_prev, attention_mask=mask.unsqueeze(1).unsqueeze(1))
            seq_z_prev = layer_outputs[0]
        z_S2S = self.S2S_pooler(layer_outputs[0])
        z_S2S, _ = self.connect(z_S2S)
        z_S2S = z_S2S.squeeze(1)
        return z_S2S

    def dist_pair(self, a, b):
        # mean Euclidean distance between corresponding rows of a and b
        return F.pairwise_distance(a, b).mean()

    def dist_batch(self, vec):
        # mean, over the batch, of the distance to the nearest *other* vector;
        # the self-distance (identically ~0) is masked out before taking the
        # min, otherwise the push term would always collapse to zero
        n = vec.shape[0]
        dmin = []
        for i in range(n):
            dd = F.pairwise_distance(vec[i:i + 1, :].repeat(n, 1), vec)
            dd[i] = float('inf')  # exclude the trivial i == i pair
            dmin.append(dd.min())
        return torch.stack(dmin).mean()
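

# Minimal standalone sketch of the two regularizer geometries (not part of the
# model; batch size 8 and width 32 are illustrative). Because of the relative
# import above, this module is normally imported as part of its package rather
# than executed directly.
if __name__ == '__main__':
    a = torch.randn(8, 32)  # stand-in for z_AE
    b = torch.randn(8, 32)  # stand-in for z_S2S
    pull = F.pairwise_distance(a, b).mean()  # dist_pair: pull paired latents together
    # dist_batch with the self-pair masked: nearest-other-neighbour distance
    d = torch.cdist(a, a) + torch.eye(8) * 1e9
    push = d.min(dim=1).values.mean()
    print(f'pull={pull.item():.3f}  push={push.item():.3f}')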