import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from espnet.lm.lm_utils import make_lexical_tree
from espnet.nets.pytorch_backend.nets_utils import to_device


# Definition of a multi-level (subword/word) language model
class MultiLevelLM(nn.Module):
    logzero = -10000000000.0
    zero = 1.0e-10

    def __init__(
        self,
        wordlm,
        subwordlm,
        word_dict,
        subword_dict,
        subwordlm_weight=0.8,
        oov_penalty=1.0,
        open_vocab=True,
    ):
        super(MultiLevelLM, self).__init__()
        self.wordlm = wordlm
        self.subwordlm = subwordlm
        self.word_eos = word_dict["<eos>"]
        self.word_unk = word_dict["<unk>"]
        self.var_word_eos = torch.LongTensor([self.word_eos])
        self.var_word_unk = torch.LongTensor([self.word_unk])
        self.space = subword_dict["<space>"]
        self.eos = subword_dict["<eos>"]
        # lexical prefix tree over the word vocabulary, spelled with subword units
        self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
        self.log_oov_penalty = math.log(oov_penalty)
        self.open_vocab = open_vocab
        self.subword_dict_size = len(subword_dict)
        self.subwordlm_weight = subwordlm_weight
        self.normalized = True

    def forward(self, state, x):
        # update the decoder state with the new input label x
        if state is None:  # make initial states and log-prob vectors
            self.var_word_eos = to_device(x, self.var_word_eos)
            self.var_word_unk = to_device(x, self.var_word_unk)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            wlm_logprobs = F.log_softmax(z_wlm, dim=1)
            clm_state, z_clm = self.subwordlm(None, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
            new_node = self.lexroot
            clm_logprob = 0.0
            xi = self.space
        else:
            clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:  # the node is a word end
                    w = to_device(x, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update the wordlm state and log-prob vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                wlm_logprobs = F.log_softmax(z_wlm, dim=1)
                new_node = self.lexroot  # move back to the tree root
                clm_logprob = 0.0
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
                clm_logprob += log_y[0, xi]
            elif self.open_vocab:  # no path in the tree: open-vocabulary mode
                new_node = None
                clm_logprob += log_y[0, xi]
            else:  # if open_vocab is disabled, return zero probabilities
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero)
                )
                return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y

            clm_state, z_clm = self.subwordlm(clm_state, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight

        # apply word-level probabilities for <space> and <eos> labels
        if xi != self.space:
            if new_node is not None and new_node[1] >= 0:  # the new node is a word end
                wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
            else:
                wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
            log_y[:, self.space] = wlm_logprob
            log_y[:, self.eos] = wlm_logprob
        else:
            log_y[:, self.space] = self.logzero
            log_y[:, self.eos] = self.logzero

        return (
            (clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)),
            log_y,
        )

    def final(self, state):
        # compute the word-level log probability of <eos> given the final state
        clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
        if node is not None and node[1] >= 0:  # the final node is a word end
            w = to_device(wlm_logprobs, torch.LongTensor([node[1]]))
        else:  # this node is not a word end, which means <unk>
            w = self.var_word_unk
        wlm_state, z_wlm = self.wordlm(wlm_state, w)
        return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
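
# A minimal usage sketch (illustrative only, not part of the original module).
# It assumes `word_rnnlm` and `char_rnnlm` are pretrained ESPnet RNNLM objects
# whose `.predictor` follows the `(state, x) -> (state, logits)` interface used
# above, and that `word_dict`/`char_dict` map "<eos>", "<unk>" and "<space>"
# to integer ids; `hypothesis` is a placeholder for the subword ids proposed
# by beam search:
#
#     lm = MultiLevelLM(word_rnnlm.predictor, char_rnnlm.predictor,
#                       word_dict, char_dict, subwordlm_weight=0.8)
#     state = None
#     for token_id in hypothesis:
#         state, log_y = lm(state, torch.LongTensor([token_id]))
#     score = lm.final(state)              # word-level <eos> log probability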


# Definition of a look-ahead word language model
class LookAheadWordLM(nn.Module):
    logzero = -10000000000.0
    zero = 1.0e-10

    def __init__(
        self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True
    ):
        super(LookAheadWordLM, self).__init__()
        self.wordlm = wordlm
        self.word_eos = word_dict["<eos>"]
        self.word_unk = word_dict["<unk>"]
        self.var_word_eos = torch.LongTensor([self.word_eos])
        self.var_word_unk = torch.LongTensor([self.word_unk])
        self.space = subword_dict["<space>"]
        self.eos = subword_dict["<eos>"]
        # lexical prefix tree over the word vocabulary, spelled with subword units
        self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
        self.oov_penalty = oov_penalty
        self.open_vocab = open_vocab
        self.subword_dict_size = len(subword_dict)
        self.zero_tensor = torch.FloatTensor([self.zero])
        self.normalized = True

    def forward(self, state, x):
        # update the decoder state with the new input label x
        if state is None:  # make initial states and a cumulative probability vector
            self.var_word_eos = to_device(x, self.var_word_eos)
            self.var_word_unk = to_device(x, self.var_word_unk)
            self.zero_tensor = to_device(x, self.zero_tensor)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
            new_node = self.lexroot
            xi = self.space
        else:
            wlm_state, cumsum_probs, node = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:  # the node is a word end
                    w = to_device(x, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update the wordlm state and cumulative probability vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
                new_node = self.lexroot  # move back to the tree root
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
            elif self.open_vocab:  # no path in the tree: open-vocabulary mode
                new_node = None
            else:  # if open_vocab is disabled, return zero probabilities
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero)
                )
                return (wlm_state, None, None), log_y

        if new_node is not None:
            succ, wid, wids = new_node  # successors, word id, and word-id range
            # probability mass of the words sharing the current prefix
            sum_prob = (
                (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
                if wids is not None
                else 1.0
            )
            if sum_prob < self.zero:
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero)
                )
                return (wlm_state, cumsum_probs, new_node), log_y
            # use the <unk> probability scaled by the OOV penalty as a default value
            unk_prob = (
                cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
            )
            y = to_device(
                x,
                torch.full(
                    (1, self.subword_dict_size), float(unk_prob) * self.oov_penalty
                ),
            )
            # compute look-ahead transition probabilities to the child nodes
            for cid, nd in succ.items():
                y[:, cid] = (
                    cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
                ) / sum_prob
            # apply word-level probabilities for <space> and <eos> labels
            if wid >= 0:
                wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
                y[:, self.space] = wlm_prob
                y[:, self.eos] = wlm_prob
            elif xi == self.space:
                y[:, self.space] = self.zero
                y[:, self.eos] = self.zero
            log_y = torch.log(torch.max(y, self.zero_tensor))  # clip to avoid log(0)
        else:  # no path in the tree: uniform (zero log-prob) scores
            log_y = to_device(x, torch.zeros(1, self.subword_dict_size))
        return (wlm_state, cumsum_probs, new_node), log_y

    def final(self, state):
        # compute the word-level log probability of <eos> given the final state
        wlm_state, cumsum_probs, node = state
        if node is not None and node[1] >= 0:  # the final node is a word end
            w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
        else:  # this node is not a word end, which means <unk>
            w = self.var_word_unk
        wlm_state, z_wlm = self.wordlm(wlm_state, w)
        return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
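

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): drive LookAheadWordLM with
    # a toy word-level LM and tiny hand-made dictionaries.  `_ToyWordLM` is a
    # stand-in for an ESPnet RNNLM predictor, i.e. any module implementing the
    # `(state, x) -> (state, logits)` interface over the word vocabulary.
    class _ToyWordLM(nn.Module):
        def __init__(self, n_vocab, n_units=8):
            super(_ToyWordLM, self).__init__()
            self.embed = nn.Embedding(n_vocab, n_units)
            self.out = nn.Linear(n_units, n_vocab)

        def forward(self, state, x):
            # stateless dummy: ignore the recurrent state, return untrained logits
            return state, self.out(self.embed(x))

    # words must be spellable with the subword units for make_lexical_tree
    subword_dict = {"<blank>": 0, "a": 1, "b": 2, "<space>": 3, "<eos>": 4}
    word_dict = {"<blank>": 0, "<unk>": 1, "ab": 2, "b": 3, "<eos>": 4}

    lm = LookAheadWordLM(_ToyWordLM(len(word_dict)), word_dict, subword_dict)
    state = None
    for token in ["a", "b", "<space>"]:
        state, log_y = lm(state, torch.LongTensor([subword_dict[token]]))
        print(token, log_y.shape)  # (1, len(subword_dict)) log-prob vector per step
    print("final <eos> log-prob:", lm.final(state))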