# Provenance (Hugging Face file viewer residue): uploaded by Songyou,
# commit f3b11f9 ("add files"), file size 1.3 kB.
import torch
from torch.autograd import Variable
from models.transformer.module.subsequent_mask import subsequent_mask
def decode(model, src, src_mask, max_len, type):
    """Autoregressively decode target sequences for a batch of sources.

    Every sequence starts with the start token (id 1). Each step feeds the
    partial output back through the decoder and appends one token, until
    ``max_len`` tokens are produced or every sequence in the batch has
    emitted the end token (id 2).

    Args:
        model: seq2seq model exposing ``encode``, ``decode`` and
            ``generator`` (``generator`` is assumed to return
            log-probabilities — they are exponentiated below).
        src: source token tensor of shape [batch_size, src_len].
        src_mask: source mask, passed through to ``model``.
        max_len: maximum decoded length, including the start token.
        type: decoding strategy — 'greedy' (argmax) or 'multinomial'
            (sample from the output distribution). NOTE: shadows the
            builtin ``type``; name kept for caller compatibility.

    Returns:
        Tensor of shape [batch_size, <= max_len] of decoded token ids,
        beginning with the start token.

    Raises:
        ValueError: if ``type`` is not a recognised strategy.
    """
    # One start token (id 1) per batch element; ys shape [batch_size, 1].
    ys = torch.ones(src.shape[0], 1).type_as(src.data)
    encoder_outputs = model.encode(src, src_mask)
    # Tracks which sequences have already produced the end token.
    finished = torch.zeros(src.shape[0], dtype=torch.bool)
    for _ in range(max_len - 1):
        # No gradients anywhere in the step: the generator call and exp
        # were previously outside no_grad and built a useless graph.
        with torch.no_grad():
            # Causal mask stops each position attending to the future.
            out = model.decode(encoder_outputs, src_mask, ys,
                               subsequent_mask(ys.size(1)).type_as(src.data))
            log_prob = model.generator(out[:, -1])
            prob = torch.exp(log_prob)
        if type == 'greedy':
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.unsqueeze(-1)  # [batch_size, 1]
        elif type == 'multinomial':
            next_word = torch.multinomial(prob, 1)  # [batch_size, 1]
        else:
            # Previously an unknown strategy crashed with NameError.
            raise ValueError(f"unknown decoding type: {type!r}")
        ys = torch.cat([ys, next_word], dim=1)  # [batch_size, step + 1]
        finished = finished | (next_word.squeeze(-1).to('cpu') == 2)
        if finished.all():  # every sequence hit the end token
            break
    return ys