import math

import chainer
import chainer.functions as F

from espnet.lm.lm_utils import make_lexical_tree


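# Multi-level (subword/word) language model: scores label sequences with a
# subword-level LM inside words and rescores each completed word with a
# word-level LM at the word boundary (<space> or <eos>).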
class MultiLevelLM(chainer.Chain):
    logzero = -10000000000.0
    zero = 1.0e-10

    def __init__(
        self,
        wordlm,
        subwordlm,
        word_dict,
        subword_dict,
        subwordlm_weight=0.8,
        oov_penalty=1.0,
        open_vocab=True,
    ):
        super(MultiLevelLM, self).__init__()
        self.wordlm = wordlm
        self.subwordlm = subwordlm
        self.word_eos = word_dict["<eos>"]
        self.word_unk = word_dict["<unk>"]
        self.xp_word_eos = self.xp.full(1, self.word_eos, "i")
        self.xp_word_unk = self.xp.full(1, self.word_unk, "i")
        self.space = subword_dict["<space>"]
        self.eos = subword_dict["<eos>"]
        self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
        self.log_oov_penalty = math.log(oov_penalty)
        self.open_vocab = open_vocab
        self.subword_dict_size = len(subword_dict)
        self.subwordlm_weight = subwordlm_weight
        self.normalized = True

    def __call__(self, state, x):
        # update the state with the input label x
        if state is None:  # make initial states and log-prob vectors
            wlm_state, z_wlm = self.wordlm(None, self.xp_word_eos)
            wlm_logprobs = F.log_softmax(z_wlm).data
            clm_state, z_clm = self.subwordlm(None, x)
            log_y = F.log_softmax(z_clm).data * self.subwordlm_weight
            new_node = self.lexroot
            clm_logprob = 0.0
            xi = self.space
        else:
            clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:  # check if the node is word end
                    w = self.xp.full(1, node[1], "i")
                else:  # this node is not a word end, which means <unk>
                    w = self.xp_word_unk
                # update the wordlm state and log-prob vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                wlm_logprobs = F.log_softmax(z_wlm).data
                new_node = self.lexroot  # move back to the tree root
                clm_logprob = 0.0
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
                clm_logprob += log_y[0, xi]
            elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
                new_node = None
                clm_logprob += log_y[0, xi]
            else:  # if open_vocab is disabled, return zero probabilities
                log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f")
                return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y

            clm_state, z_clm = self.subwordlm(clm_state, x)
            log_y = F.log_softmax(z_clm).data * self.subwordlm_weight
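
        # apply word-level probabilities for the <space> and <eos> labels: at a word
        # boundary the accumulated subword log-probability is replaced by the word LM
        # log-probability of the completed word (or an OOV-penalized <unk> score)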
        if xi != self.space:
            if new_node is not None and new_node[1] >= 0:
                wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
            else:
                wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
            log_y[:, self.space] = wlm_logprob
            log_y[:, self.eos] = wlm_logprob
        else:
            log_y[:, self.space] = self.logzero
            log_y[:, self.eos] = self.logzero

        return (
            clm_state,
            wlm_state,
            wlm_logprobs,
            new_node,
            log_y,
            clm_logprob,
        ), log_y

    def final(self, state):
        clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
        if node is not None and node[1] >= 0:  # check if the node is word end
            w = self.xp.full(1, node[1], "i")
        else:  # this node is not a word end, which means <unk>
            w = self.xp_word_unk
        wlm_state, z_wlm = self.wordlm(wlm_state, w)
        return F.log_softmax(z_wlm).data[:, self.word_eos]


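# Look-ahead word language model: derives subword-level scores directly from a
# word-level LM by distributing word probabilities over a lexical prefix tree
# (word-probability look-ahead), so no separate subword LM is required.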
class LookAheadWordLM(chainer.Chain):
    logzero = -10000000000.0
    zero = 1.0e-10

    def __init__(
        self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True
    ):
        super(LookAheadWordLM, self).__init__()
        self.wordlm = wordlm
        self.word_eos = word_dict["<eos>"]
        self.word_unk = word_dict["<unk>"]
        self.xp_word_eos = self.xp.full(1, self.word_eos, "i")
        self.xp_word_unk = self.xp.full(1, self.word_unk, "i")
        self.space = subword_dict["<space>"]
        self.eos = subword_dict["<eos>"]
        self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
        self.oov_penalty = oov_penalty
        self.open_vocab = open_vocab
        self.subword_dict_size = len(subword_dict)
        self.normalized = True

    def __call__(self, state, x):
        # update the state with the input label x
        if state is None:  # make initial states and cumulative probability vector
            wlm_state, z_wlm = self.wordlm(None, self.xp_word_eos)
            cumsum_probs = self.xp.cumsum(F.softmax(z_wlm).data, axis=1)
            new_node = self.lexroot
            xi = self.space
        else:
            wlm_state, cumsum_probs, node = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:  # check if the node is word end
                    w = self.xp.full(1, node[1], "i")
                else:  # this node is not a word end, which means <unk>
                    w = self.xp_word_unk
                # update the wordlm state and cumulative probability vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                cumsum_probs = self.xp.cumsum(F.softmax(z_wlm).data, axis=1)
                new_node = self.lexroot  # move back to the tree root
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
            elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
                new_node = None
            else:  # if open_vocab is disabled, return zero probabilities
                log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f")
                return (wlm_state, None, None), log_y
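
        # Word look-ahead over the lexical prefix tree: each node is a tuple
        # (successors, word_id, (low, high)), where (low, high) brackets the ids of
        # the words reachable through this node. Assuming the word dictionary assigns
        # contiguous ids to words sharing a prefix, cumsum_probs[:, high] minus
        # cumsum_probs[:, low] is the word LM probability mass of the whole subtree.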
        if new_node is not None:
            succ, wid, wids = new_node
            # compute the probability mass of all words reachable from this node
            sum_prob = (
                (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
                if wids is not None
                else 1.0
            )
            if sum_prob < self.zero:
                log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f")
                return (wlm_state, cumsum_probs, new_node), log_y

            # set the scaled <unk> probability as a default value
            unk_prob = (
                cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
            )
            y = self.xp.full(
                (1, self.subword_dict_size), unk_prob * self.oov_penalty, "f"
            )

            # compute transition probabilities to the child nodes
            for cid, nd in succ.items():
                y[:, cid] = (
                    cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
                ) / sum_prob

            # apply word-level probabilities for the <space> and <eos> labels
            if wid >= 0:
                wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
                y[:, self.space] = wlm_prob
                y[:, self.eos] = wlm_prob
            elif xi == self.space:
                y[:, self.space] = self.zero
                y[:, self.eos] = self.zero
            log_y = self.xp.log(
                self.xp.clip(y, self.zero, None)
            )  # clip to avoid log(0)
        else:  # if no path in the tree, the transition probability is one
            log_y = self.xp.zeros((1, self.subword_dict_size), "f")
        return (wlm_state, cumsum_probs, new_node), log_y

    def final(self, state):
        wlm_state, cumsum_probs, node = state
        if node is not None and node[1] >= 0:  # check if the node is word end
            w = self.xp.full(1, node[1], "i")
        else:  # this node is not a word end, which means <unk>
            w = self.xp_word_unk
        wlm_state, z_wlm = self.wordlm(wlm_state, w)
        return F.log_softmax(z_wlm).data[:, self.word_eos]
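

# Usage sketch (illustration only): both wrappers expose the same stepwise interface
# used during decoding. The names below (word_rnnlm, word_dict, char_dict, x0, x1)
# are assumed placeholders, not defined in this module.
#
#   lm = LookAheadWordLM(word_rnnlm, word_dict, char_dict, oov_penalty=0.0001)
#   state, log_y = lm(None, x0)   # x0: first subword label id, shape (1,) int array
#   state, log_y = lm(state, x1)  # feed subsequent subword labels one at a time
#   eos_logp = lm.final(state)    # word-level <eos> log-probability at sentence end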