# Xingqian Xu
# New app first commit
# 2fbcf51
import copy
import math

import numpy as np
import numpy.random as npr
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.model_zoo.common.get_model import get_model, register
from lib.model_zoo.common import utils
from .optimus_models.tokenization_gpt2 import GPT2Tokenizer
# Module-level name tag. NOTE(review): presumably consumed by the model-zoo
# tooling that imports this module — confirm against the loader.
symbol = "optimus"
@register('optimus_vae')
class optimus_vae(nn.Module):
    """Optimus-style sentence VAE with a standard-normal prior.

    ``self.encoder`` yields a pooled feature (``outputs[1]``), which
    ``self.encoder.linear`` maps to the concatenated mean and log-variance of a
    diagonal Gaussian posterior q(z|x). ``self.decoder`` is conditioned on the
    latent code through its ``past`` argument.
    """

    def __init__(self, encoder, decoder, tokenizer_encoder, tokenizer_decoder, args):
        """Build (or adopt) the sub-modules, set up tokenizers and the prior.

        Args:
            encoder, decoder: ``nn.Module`` instances, or configs that are
                passed to ``get_model()`` to build them.
            tokenizer_encoder, tokenizer_decoder: used as-is when they are
                ``nn.Module`` instances, otherwise passed to ``get_model()``
                with ``verbose=False``.
            args: namespace. ``args.latent_size`` sets the latent size; other
                methods read ``fb_mode``, ``dim_target_kl``,
                ``length_weighted_loss``, ``beta``, ``mh_burn_in``,
                ``mh_thin`` and ``mh_std``.
        """
        super().__init__()
        self.encoder = encoder if isinstance(encoder, nn.Module) else get_model()(encoder)
        self.decoder = decoder if isinstance(decoder, nn.Module) else get_model()(decoder)
        # NOTE(review): tokenizers are usually not nn.Module subclasses, so an
        # already-built tokenizer object would still be routed to get_model();
        # confirm that callers only pass configs here.
        self.tokenizer_encoder = tokenizer_encoder \
            if isinstance(tokenizer_encoder, nn.Module) \
            else get_model()(tokenizer_encoder, verbose=False)
        self.tokenizer_decoder = tokenizer_decoder \
            if isinstance(tokenizer_decoder, nn.Module) \
            else get_model()(tokenizer_decoder, verbose=False)

        # Register <PAD>/<BOS>/<EOS> as special tokens on any GPT-2 tokenizer.
        gpt2_special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
        if isinstance(self.tokenizer_encoder, GPT2Tokenizer):
            self.tokenizer_encoder.add_special_tokens(gpt2_special_tokens_dict)
        if isinstance(self.tokenizer_decoder, GPT2Tokenizer):
            self.tokenizer_decoder.add_special_tokens(gpt2_special_tokens_dict)

        self.args = args
        self.nz = args.latent_size

        self.eos_token_id = self.tokenizer_decoder.convert_tokens_to_ids(
            [self.tokenizer_decoder.eos_token])[0]
        self.pad_token_id = self.tokenizer_decoder.convert_tokens_to_ids(
            [self.tokenizer_decoder.pad_token])[0]

        # Standard normal prior N(0, I). Its parameters are plain CPU tensors
        # (not registered buffers); eval_prior_dist moves them to the query's
        # device/dtype on use.
        loc = torch.zeros(self.nz)
        scale = torch.ones(self.nz)
        self.prior = torch.distributions.normal.Normal(loc, scale)

    def connect(self, bert_fea, nsamples=1):
        """Map encoder features to sampled latents and their KL to the prior.

        Args:
            bert_fea: pooled encoder features; ``self.encoder.linear`` maps
                them to ``2 * nz`` values (mean and log-variance).
            nsamples: number of latent samples per example.

        Returns:
            z: latent samples, shape ``[batch, nsamples, nz]``.
            KL: analytic KL(q(z|x) || N(0, I)) per example, shape ``[batch]``.
        """
        # (batch_size, nz) each
        mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
        # (batch, nsamples, nz)
        z = self.reparameterize(mean, logvar, nsamples)
        KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
        return z, KL

    def connect_deterministic(self, bert_fea, nsamples=1):
        """Like ``connect`` but with the posterior log-variance forced to 0.

        Returns:
            z: latent samples ``mean + eps`` (unit std), shape
                ``[batch, nsamples, nz]``.
            KL: KL(N(mean, I) || N(0, I)) per example, shape ``[batch]``.
        """
        # (batch_size, nz) each
        mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
        # Built out-of-place: ``chunk`` returns views of the linear output and
        # autograd forbids in-place writes (e.g. ``fill_``) to such views.
        logvar = torch.zeros_like(logvar)
        # (batch, nsamples, nz)
        z = self.reparameterize(mean, logvar, nsamples)
        KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
        return z, KL

    def reparameterize(self, mu, logvar, nsamples=1):
        """Draw reparameterized samples from N(mu, exp(logvar)).

        Args:
            mu: mean, shape ``(batch, nz)``.
            logvar: log-variance, shape ``(batch, nz)``.
            nsamples: number of samples per row.

        Returns:
            ``mu + eps * std`` with ``eps ~ N(0, I)``, shape
            ``(batch, nsamples, nz)``.
        """
        batch_size, nz = mu.size()
        std = logvar.mul(0.5).exp()
        mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
        std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
        eps = torch.zeros_like(std_expd).normal_()
        return mu_expd + torch.mul(eps, std_expd)

    def forward(self, inputs, labels):
        """Compute reconstruction, KL and total losses for one batch.

        Args:
            inputs: encoder token ids; positions with id > 0 are attended to.
            labels: decoder token ids, used both as decoder input and as
                reconstruction target; ``self.pad_token_id`` marks padding.

        Returns:
            ``(loss_rec, loss_kl, loss)`` with
            ``loss = loss_rec [/ sentence length] + args.beta * loss_kl``.

        Raises:
            ValueError: if ``args.fb_mode`` is not 0, 1 or 2.
        """
        attention_mask = (inputs > 0).float()
        # Non-pad target positions; use the same id the decoder ignores.
        reconstruction_mask = (labels != self.pad_token_id).float()
        sent_length = torch.sum(reconstruction_mask, dim=1)

        outputs = self.encoder(inputs, attention_mask)
        pooled_hidden_fea = outputs[1]  # model outputs are always tuples in pytorch-transformers

        fb_mode = self.args.fb_mode
        if fb_mode == 0:
            # Plain VAE: full KL to the prior.
            latent_z, loss_kl = self.connect(pooled_hidden_fea)
        elif fb_mode == 1:
            # Per-dimension KL; only dimensions above args.dim_target_kl count.
            mu, logvar = self.encoder.linear(pooled_hidden_fea).chunk(2, -1)
            latent_z = self.reparameterize(mu, logvar, nsamples=1)
            loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)
            kl_mask = (loss_kl > self.args.dim_target_kl).float()
            loss_kl = (kl_mask * loss_kl).sum(dim=1)
        elif fb_mode == 2:
            # Deterministic encoder variance (log-variance fixed at 0).
            latent_z, loss_kl = self.connect_deterministic(pooled_hidden_fea)
        else:
            raise ValueError("unsupported fb_mode: {}".format(fb_mode))
        latent_z = latent_z.squeeze(1)

        # Decoding
        outputs = self.decoder(input_ids=labels, past=latent_z, labels=labels, label_ignore=self.pad_token_id)
        loss_rec = outputs[0]  # model outputs are always tuples in pytorch-transformers

        if self.args.length_weighted_loss:
            loss = loss_rec / sent_length + self.args.beta * loss_kl
        else:
            loss = loss_rec + self.args.beta * loss_kl
        return loss_rec, loss_kl, loss

    def encoder_sample(self, bert_fea, nsamples):
        """Sample latents from q(z|x) given pooled encoder features.

        Returns:
            z: samples, shape ``[batch, nsamples, nz]``.
            (mu, logvar): posterior parameters, each ``[batch, nz]``.
        """
        # (batch_size, nz) each. No squeeze here: squeezing dim 0 would drop the
        # batch dimension when batch_size == 1 and break reparameterize().
        mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)
        return z, (mu, logvar)

    def encode_stats(self, x):
        """Return ``self.encoder.encode_stats(x)``.

        Returns:
            Tensor1: the mean of latent z with shape [batch, nz]
            Tensor2: the logvar of latent z with shape [batch, nz]
        """
        return self.encoder.encode_stats(x)

    def decode(self, z, strategy, K=10):
        """Generate samples from z with the given decoding strategy.

        Args:
            z: [batch, nsamples, nz]
            strategy: "beam", "greedy" or "sample"
            K: the beam width parameter

        Returns:
            A list of decoded word sequences.

        Raises:
            ValueError: for an unknown strategy.
        """
        if strategy == "beam":
            return self.decoder.beam_search_decode(z, K)
        elif strategy == "greedy":
            return self.decoder.greedy_decode(z)
        elif strategy == "sample":
            return self.decoder.sample_decode(z)
        else:
            raise ValueError("the decoding strategy is not supported")

    def reconstruct(self, x, decoding_strategy="greedy", K=5):
        """Reconstruct input x through the inference network and decoder.

        Args:
            x: (batch, *)
            decoding_strategy: "beam", "greedy" or "sample"
            K: the beam width parameter

        Returns:
            A list of decoded word sequences.
        """
        # NOTE(review): sample_from_inference relies on self.encoder.sample;
        # confirm the encoder provides it.
        z = self.sample_from_inference(x).squeeze(1)
        return self.decode(z, decoding_strategy, K)

    def log_probability(self, x, z):
        """Return the negated decoder loss for targets x conditioned on z.

        Args:
            x: token ids, (batch_size, seq_len)
            z: latent codes passed to the decoder as ``past``

        Returns:
            ``-outputs[0]`` of the decoder. NOTE(review): assumed to be a
            per-example log p(x|z) — confirm the decoder returns a per-example loss.
        """
        outputs = self.decoder(input_ids=x, past=z, labels=x, label_ignore=self.pad_token_id)
        loss_rec = outputs[0]
        return -loss_rec

    def loss_iw(self, x0, x1, nsamples=50, ns=1):
        """Importance-weighted log-likelihood plus reconstruction and KL terms.

        Args:
            x0: encoder input ids, (batch, *)
            x1: decoder input ids for the same sentences
            nsamples: total number of importance samples
            ns: samples drawn per chunk (bounds memory use)

        Returns:
            log_prob_iw: IW estimate of log p(x), shape [batch]
            log_gen_iw: mean log p(x|z) over the samples, shape [batch]
            KL: analytic KL(q(z|x) || p(z)), shape [batch]
        """
        # encoding into bert features
        bert_fea = self.encoder(x0)[1]
        # (batch_size, nz) each
        mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
        KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

        ll_tmp, rc_tmp = [], []
        for _ in range(int(nsamples / ns)):
            # (batch, ns, nz)
            z = self.reparameterize(mu, logvar, ns)
            # [batch, ns]
            log_prior = self.eval_prior_dist(z)
            log_gen = self.eval_cond_ll(x1, z)
            log_infer = self.eval_inference_dist(z, (mu, logvar))
            # eval_cond_ll returns one value per (example, sample) pair, flattened.
            log_gen = log_gen.unsqueeze(0).contiguous().view(z.shape[0], -1)
            rc_tmp.append(log_gen)
            ll_tmp.append(log_gen + log_prior - log_infer)

        log_prob_iw = torch.logsumexp(torch.cat(ll_tmp, dim=-1), dim=-1) - math.log(nsamples)
        log_gen_iw = torch.mean(torch.cat(rc_tmp, dim=-1), dim=-1)
        return log_prob_iw, log_gen_iw, KL

    def nll_iw(self, x0, x1, nsamples, ns=1):
        """Importance-weighted estimate of log p(x).

        Args:
            x0, x1: two tokenizations of the same data (encoder / decoder ids),
                each with shape (batch, *).
            nsamples: total number of samples used for the estimate.
            ns: samples drawn per chunk, to bound memory use.

        Returns:
            The estimate of log p(x), shape [batch].
        """
        # The encoder pass does not depend on the sample index; run it once so
        # every chunk draws from the same q(z|x).
        pooled_hidden_fea = self.encoder(x0)[1]
        tmp = []
        for _ in range(int(nsamples / ns)):
            # [batch, ns, nz]; param holds (mu, logvar) of q(z|x)
            z, param = self.encoder_sample(pooled_hidden_fea, ns)
            # [batch, ns]
            log_comp_ll = self.eval_complete_ll(x1, z)
            log_infer_ll = self.eval_inference_dist(z, param)
            tmp.append(log_comp_ll - log_infer_ll)
        ll_iw = torch.logsumexp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples)
        return ll_iw

    def KL(self, x):
        """Return the analytic KL(q(z|x) || p(z)) for encoder input ids x, shape [batch]."""
        bert_fea = self.encoder(x)[1]
        _, KL = self.connect(bert_fea, 1)
        return KL

    def eval_prior_dist(self, zrange):
        """Log-density of z under the standard-normal prior, summed over the last dim.

        Args:
            zrange: tensor of latent points, (..., nz)

        Returns:
            Tensor of shape (...).
        """
        # self.prior is built from CPU tensors; match the query's device/dtype.
        prior = torch.distributions.normal.Normal(
            self.prior.loc.to(zrange), self.prior.scale.to(zrange))
        return prior.log_prob(zrange).sum(dim=-1)

    def eval_complete_ll(self, x, z):
        """Compute log p(z, x) = log p(z) + log p(x|z).

        Args:
            x: input with shape [batch, seq_len]
            z: evaluation points with shape [batch, nsamples, nz]

        Returns:
            log p(z, x) with shape [batch, nsamples]
        """
        # [batch, nsamples]
        log_prior = self.eval_prior_dist(z)
        # eval_cond_ll flattens (batch, nsamples) into one dim; restore it
        # before adding, otherwise the sum silently broadcasts.
        log_gen = self.eval_cond_ll(x, z).reshape(log_prior.shape)
        return log_prior + log_gen

    def eval_cond_ll(self, x, z):
        """Compute log p(x|z).

        When z is 3-D ``[batch, nsamples, nz]``, x is repeated per sample and
        both are flattened to ``batch * nsamples`` rows before decoding.
        """
        x_shape = list(x.size())
        z_shape = list(z.size())
        if len(z_shape) == 3:
            x = x.unsqueeze(1).repeat(1, z_shape[1], 1).contiguous().view(x_shape[0] * z_shape[1], x_shape[-1])
            z = z.contiguous().view(x_shape[0] * z_shape[1], z_shape[-1])
        return self.log_probability(x, z)

    def eval_log_model_posterior(self, x, grid_z):
        """Evaluate the model posterior log p(z|x) on a grid of z points.

        Args:
            x: tensor, or a tuple whose first item is the data tensor
            grid_z: latent points, shape (k^2, nz), where k=(zmax - zmin)/pace

        Returns:
            log p(z|x), normalized over the grid, shape [batch_size, k^2]
        """
        try:
            batch_size = x.size(0)
        except (AttributeError, TypeError):
            # presumably x is a (data, lengths) tuple rather than a tensor
            batch_size = x[0].size(0)

        # (batch_size, k^2, nz)
        grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous()
        # (batch_size, k^2)
        log_comp = self.eval_complete_ll(x, grid_z)
        # normalize to posterior
        log_posterior = log_comp - torch.logsumexp(log_comp, dim=1, keepdim=True)
        return log_posterior

    def sample_from_inference(self, x, nsamples=1):
        """Sample from the inference network via ``self.encoder.sample``.

        Returns:
            samples with shape (batch_size, nsamples, nz)
        """
        z, _ = self.encoder.sample(x, nsamples)
        return z

    def sample_from_posterior(self, x, nsamples):
        """Draw samples from the model posterior p(z|x) with Metropolis-Hastings.

        Uses a Gaussian random-walk proposal with std ``args.mh_std``. The first
        ``args.mh_burn_in`` iterations are discarded, and after that every
        ``args.mh_thin``-th state is kept.

        Returns:
            The collected states, concatenated along dim 1.
        """
        # Initialise the chain from the inference network. NOTE(review): this
        # calls self.encoder.sample_from_inference, while this class defines its
        # own sample_from_inference; confirm which one is intended.
        cur = self.encoder.sample_from_inference(x, 1)
        cur_ll = self.eval_complete_ll(x, cur)
        total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin
        samples = []
        for iter_ in range(total_iter):
            proposal = torch.normal(mean=cur,
                                    std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std))
            # [batch_size, 1]
            proposal_ll = self.eval_complete_ll(x, proposal)
            ratio = proposal_ll - cur_ll
            accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size()))
            uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_()
            # [batch_size, 1]; 1 where the proposal is accepted
            mask = (uniform_t < accept_prob).float()
            mask_ = mask.unsqueeze(2)
            cur = mask_ * proposal + (1 - mask_) * cur
            cur_ll = mask * proposal_ll + (1 - mask) * cur_ll
            if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0:
                samples.append(cur.unsqueeze(1))
        return torch.cat(samples, dim=1)

    def calc_model_posterior_mean(self, x, grid_z):
        """Compute E_{z ~ p(z|x)}[z] by weighting grid points with the posterior.

        Args:
            x: [batch, *]
            grid_z: latent points, shape (k^2, nz), where k=(zmax - zmin)/pace

        Returns:
            The posterior mean, shape [batch, nz]
        """
        # [batch, K^2]
        log_posterior = self.eval_log_model_posterior(x, grid_z)
        posterior = log_posterior.exp()
        # [batch, nz]
        return torch.mul(posterior.unsqueeze(2), grid_z.unsqueeze(0)).sum(1)

    def calc_infer_mean(self, x):
        """Return the mean of q(z|x) for encoder input ids x, shape [batch, nz]."""
        # The encoder returns a tuple whose item 1 is the pooled feature; the
        # posterior parameters come from the latent projection.
        bert_fea = self.encoder(x)[1]
        mean, _ = self.encoder.linear(bert_fea).chunk(2, -1)
        return mean

    def eval_inference_dist(self, z, param):
        """Compute log q(z|x) for a diagonal Gaussian posterior.

        Args:
            z: latent points, shape [batch, nsamples, nz]
            param: (mu, logvar), each [batch, nz]

        Returns:
            log q(z|x) with shape [batch, nsamples]
        """
        nz = z.size(2)
        mu, logvar = param
        # (batch_size, 1, nz)
        mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
        var = logvar.exp()
        # (batch_size, nsamples, nz)
        dev = z - mu
        # (batch_size, nsamples)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
            0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
        return log_density

    def calc_mi(self, test_data_batch, args):
        """Estimate the mutual information I(x; z) under q (``calc_mi_v3``).

        ``I = E_x E_{q(z|x)}[log q(z|x)] - E_{q(z)}[log q(z)]``, where q(z) is
        approximated by the mixture of all q(z|x_j) from ``test_data_batch``.

        Args:
            test_data_batch: iterable of 3-tuples whose first item is the
                encoder input ids.
            args: namespace providing ``args.device``.

        Returns:
            A scalar tensor holding the MI estimate.
        """
        num_examples = 0
        mu_batch_list, logvar_batch_list = [], []
        neg_entropy = 0.
        for batch_data in test_data_batch:
            x0, _, _ = batch_data
            x0 = x0.to(args.device)
            # encoding into bert features
            bert_fea = self.encoder(x0)[1]
            # (batch_size, nz)
            mu, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
            x_batch, nz = mu.size()
            num_examples += x_batch
            # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
            neg_entropy += (-0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)).sum().item()
            mu_batch_list += [mu.cpu()]
            logvar_batch_list += [logvar.cpu()]
        neg_entropy = neg_entropy / num_examples

        # Parameters of every q(z|x_j); loop-invariant for the density pass below.
        mu_all = torch.cat(mu_batch_list, dim=0).to(args.device)
        logvar_all = torch.cat(logvar_batch_list, dim=0).to(args.device)
        n_total, nz = mu_all.size()
        # (1, n_total, nz) so every z sample is scored against every component
        mu_all, logvar_all = mu_all.unsqueeze(0), logvar_all.unsqueeze(0)
        var_all = logvar_all.exp()

        num_examples = 0
        log_qz = 0.
        for mu_i, logvar_i in zip(mu_batch_list, logvar_batch_list):
            mu_i, logvar_i = mu_i.to(args.device), logvar_i.to(args.device)
            # [z_batch, 1, nz]
            z_samples = self.reparameterize(mu_i, logvar_i, 1)
            z_samples = z_samples.view(-1, 1, nz)
            num_examples += z_samples.size(0)
            # (z_batch, n_total, nz)
            dev = z_samples - mu_all
            # (z_batch, n_total)
            log_density = -0.5 * ((dev ** 2) / var_all).sum(dim=-1) - \
                0.5 * (nz * math.log(2 * math.pi) + logvar_all.sum(-1))
            # log q(z): aggregate posterior, [z_batch]
            log_qz += (torch.logsumexp(log_density, dim=1) - math.log(n_total)).sum(-1)
        log_qz /= num_examples
        mi = neg_entropy - log_qz
        return mi

    def calc_au(self, eval_dataloader, args, delta=0.01):
        """Count active latent units.

        A unit is active when the variance of its posterior mean across the
        data is >= ``delta``. The dataloader is iterated twice: once to compute
        the mean and once to compute the variance.

        Returns:
            (number of active units, per-dimension variance tensor [nz])
        """
        cnt = 0
        for batch_data in eval_dataloader:
            x0, _, _ = batch_data
            x0 = x0.to(args.device)
            # encoding into bert features
            bert_fea = self.encoder(x0)[1]
            # (batch_size, nz)
            mean, _ = self.encoder.linear(bert_fea).chunk(2, -1)
            if cnt == 0:
                means_sum = mean.sum(dim=0, keepdim=True)
            else:
                means_sum = means_sum + mean.sum(dim=0, keepdim=True)
            cnt += mean.size(0)

        # (1, nz)
        mean_mean = means_sum / cnt

        cnt = 0
        for batch_data in eval_dataloader:
            x0, _, _ = batch_data
            x0 = x0.to(args.device)
            # encoding into bert features
            bert_fea = self.encoder(x0)[1]
            # (batch_size, nz)
            mean, _ = self.encoder.linear(bert_fea).chunk(2, -1)
            if cnt == 0:
                var_sum = ((mean - mean_mean) ** 2).sum(dim=0)
            else:
                var_sum = var_sum + ((mean - mean_mean) ** 2).sum(dim=0)
            cnt += mean.size(0)

        # (nz), unbiased variance
        au_var = var_sum / (cnt - 1)
        return (au_var >= delta).sum().item(), au_var
from .optimus_models.optimus_bert import BertForLatentConnector_XX
@register('optimus_bert_connector')
class optimus_bert_connector(BertForLatentConnector_XX):
    """``BertForLatentConnector_XX`` registered under 'optimus_bert_connector'.

    Adds no behavior of its own beyond the parent class.
    """
    pass
from .optimus_models.tokenization_bert import BertTokenizer
@register('optimus_bert_tokenizer')
class optimus_bert_tokenizer(BertTokenizer):
    """``BertTokenizer`` registered under 'optimus_bert_tokenizer'.

    Adds no behavior of its own beyond the parent class.
    """
    pass
from .optimus_models.optimus_gpt2 import GPT2ForLatentConnector_XX
@register('optimus_gpt2_connector')
class optimus_gpt2_connector(GPT2ForLatentConnector_XX):
    """``GPT2ForLatentConnector_XX`` registered under 'optimus_gpt2_connector'.

    Adds no behavior of its own beyond the parent class.
    """
    pass
from .optimus_models.tokenization_gpt2 import GPT2Tokenizer
@register('optimus_gpt2_tokenizer')
class optimus_gpt2_tokenizer(GPT2Tokenizer):
    """``GPT2Tokenizer`` registered under 'optimus_gpt2_tokenizer'.

    Adds no behavior of its own beyond the parent class. Being a GPT2Tokenizer
    subclass, it receives the <PAD>/<BOS>/<EOS> special tokens in
    ``optimus_vae.__init__``.
    """
    pass
##############################
# some helpers for inference #
##############################
def sample_single_sequence_conditional(
        model,
        context,
        past=None,
        temperature=1,
        top_k=0,
        top_p=0.0,
        eos_token=50829,
        max_length=30, ):
    """Autoregressively sample one token sequence from ``model``.

    Args:
        model: decoder called as ``model(input_ids=..., past=...)``; item 0 of
            its output is indexed as ``[batch, position, vocab]`` logits.
        context: 1-D LongTensor of prompt token ids.
        past: optional conditioning tensor for one sample (e.g. a latent code);
            a leading batch dimension of 1 is added before it is passed on.
            ``None`` is passed through unchanged.
        temperature: divisor applied to the next-token logits.
        top_k, top_p: filtering parameters forwarded to ``top_k_top_p_filtering``.
        eos_token: token id that stops generation when sampled.
            NOTE(review): the default 50829 does not look like one of the ids
            added for GPT-2 special tokens; callers should pass it explicitly.
        max_length: maximum length of the returned sequence, context included.
            When reached, the last position is overwritten with ``eos_token``.

    Returns:
        1-D LongTensor: the context followed by the sampled tokens.
    """
    if past is not None:
        past = past.unsqueeze(0)
    generated = context.unsqueeze(0)
    with torch.no_grad():
        while True:
            # The whole prefix is re-fed every step (no key/value caching).
            outputs = model(input_ids=generated, past=past)
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
            if next_token[0].item() == eos_token:
                break
            if generated.shape[1] >= max_length:
                # Force-terminate so callers can always strip a trailing EOS.
                generated[0, -1] = eos_token
                break
    return generated.squeeze(0)
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    ``logits`` is modified in place and also returned.

    Args:
        logits: 1-D logits tensor of shape (vocabulary size,).
        top_k: if > 0, keep only the k tokens with the highest logits.
        top_p: if 0 < top_p < 1, keep the smallest set of top tokens whose
            cumulative probability reaches top_p (nucleus filtering, Holtzman
            et al., http://arxiv.org/abs/1904.09751). Values >= 1 keep every
            token and skip the sort entirely.
        filter_value: value written into removed positions.

    Returns:
        The filtered ``logits`` tensor.

    Raises:
        ValueError: if ``logits`` is not 1-D.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if logits.dim() != 1:
        # Batch size 1 only; a batched version is possible but less clear.
        raise ValueError("expected 1-D logits, got shape {}".format(tuple(logits.shape)))
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest one.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    # top_p >= 1 keeps everything. Skipping it avoids a full-vocabulary sort and
    # float round-off in cumsum (> 1.0) that could drop tail tokens.
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is also kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
########################
# compatible to vd 2.0 #
########################
@register('optimus_vae_next')
class optimus_vae_next(optimus_vae):
    """``optimus_vae`` with text-in / text-out helpers (vd 2.0 compatible API).

    ``encode`` maps raw strings to posterior means and ``decode`` samples one
    sentence per latent code. NOTE(review): ``decode`` overrides the parent's
    ``decode(z, strategy, K)`` with a different signature, so inherited callers
    such as ``reconstruct`` do not work on this class.
    """

    def get_device(self):
        """Return the device of the encoder's latent projection weights."""
        return self.encoder.linear.weight.device

    def encode(self, text, max_length=77):
        """Encode a batch of strings to the mean of q(z|x).

        Args:
            text: iterable of strings; each is lower-cased and tokenized.
            max_length: maximum number of tokens kept per sentence. It is
                applied before the tokenizer's special tokens are added.

        Returns:
            Posterior means after ``.squeeze(1)``, i.e. ``(batch, nz)`` for nz > 1.
        """
        tokenizer = self.tokenizer_encoder
        token_ids = []
        for sentence in text:
            pieces = tokenizer.tokenize(sentence.lower())[:max_length]
            ids = [tokenizer._convert_token_to_id(piece) for piece in pieces]
            ids = tokenizer.add_special_tokens_single_sentence(ids)
            token_ids.append(torch.LongTensor(ids))
        # Right-pad with id 0 so the attention mask (id > 0) below ignores padding.
        token_id = torch.nn.utils.rnn.pad_sequence(token_ids, batch_first=True, padding_value=0.0)
        token_id = token_id.to(self.get_device())
        pooled = self.encoder(token_id, attention_mask=(token_id > 0).float())[1]
        z_mu, z_logvar = self.encoder.linear(pooled).chunk(2, -1)
        return z_mu.squeeze(1)

    @torch.no_grad()
    def decode(self, z, temperature=1.0):
        """Sample one sentence per latent code.

        Args:
            z: tensor iterated along dim 0; each row is one latent code.
            temperature: softmax temperature used during sampling.

        Returns:
            List of strings, with the first and last whitespace-separated words
            (the <BOS>/<EOS> markers) removed.
        """
        bos_token = self.tokenizer_decoder.encode('<BOS>')
        eos_token = self.tokenizer_decoder.encode('<EOS>')
        context_tokens = torch.LongTensor(bos_token).to(z.device)
        sentences = []
        for zi in z:
            out = sample_single_sequence_conditional(
                model=self.decoder,
                context=context_tokens,
                past=zi, temperature=temperature,
                top_k=0, top_p=1.0,
                max_length=30,
                eos_token=eos_token[0],)
            text = self.tokenizer_decoder.decode(out.tolist(), clean_up_tokenization_spaces=True)
            sentences.append(' '.join(text.split()[1:-1]))
        return sentences