"""Onmt NMT Model base class definition"""
import torch.nn as nn


class BaseModel(nn.Module):
    """
    Core trainable object in OpenNMT. Implements a trainable interface
    for a simple, generic encoder / decoder or decoder only model.
    """

    def __init__(self, encoder, decoder):
        super(BaseModel, self).__init__()

    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
        """Forward propagate a `src` and `tgt` pair for training.
        Possibly initialized with a beginning decoder state.

        Args:
            src (Tensor): A source sequence passed to the encoder.
                Typically for inputs this will be a padded ``LongTensor``
                of size ``(len, batch, features)``. However, it may be an
                image or other generic input depending on the encoder.
            tgt (LongTensor): A target sequence passed to the decoder.
                Size ``(tgt_len, batch, features)``.
            lengths (LongTensor): The src lengths, pre-padding ``(batch,)``.
            bptt (Boolean): A flag indicating if truncated bptt is set.
                If not set, the decoder state is initialized with
                ``init_state``.
            with_align (Boolean): A flag indicating whether to also return
                the alignment attention. Only valid for the transformer
                decoder.

        Returns:
            (FloatTensor, dict[str, FloatTensor]):

            * decoder output ``(tgt_len, batch, hidden)``
            * dictionary of attention distributions, each of shape
              ``(tgt_len, batch, src_len)``
        """
        raise NotImplementedError

    def update_dropout(self, dropout):
        raise NotImplementedError

    def count_parameters(self, log=print):
        raise NotImplementedError


class NMTModel(BaseModel):
    """
    Core trainable object in OpenNMT. Implements a trainable interface
    for a simple, generic encoder + decoder model.

    Args:
        encoder (onmt.encoders.EncoderBase): an encoder object
        decoder (onmt.decoders.DecoderBase): a decoder object
    """

    def __init__(self, encoder, decoder):
        super(NMTModel, self).__init__(encoder, decoder)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
        # Exclude the last target token: with teacher forcing the decoder
        # is trained to predict the next token at each position.
        dec_in = tgt[:-1]

        enc_state, memory_bank, lengths = self.encoder(src, lengths)

        if not bptt:
            self.decoder.init_state(src, memory_bank, enc_state)
        dec_out, attns = self.decoder(dec_in, memory_bank,
                                      memory_lengths=lengths,
                                      with_align=with_align)
        return dec_out, attns

    def update_dropout(self, dropout):
        self.encoder.update_dropout(dropout)
        self.decoder.update_dropout(dropout)

    def count_parameters(self, log=print):
        """Count number of parameters in model (& print with `log` callback).

        Returns:
            (int, int):

            * encoder side parameter count
            * decoder side parameter count
        """
        enc, dec = 0, 0
        for name, param in self.named_parameters():
            if 'encoder' in name:
                enc += param.nelement()
            else:
                dec += param.nelement()
        if callable(log):
            log('encoder: {}'.format(enc))
            log('decoder: {}'.format(dec))
            log('* number of parameters: {}'.format(enc + dec))
        return enc, dec
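

# Illustrative usage sketch, not part of the OpenNMT-py API: the toy
# encoder/decoder below only mimic the interface that ``NMTModel.forward``
# relies on -- an encoder returning ``(enc_state, memory_bank, lengths)``
# and a decoder exposing ``init_state`` and returning ``(dec_out, attns)``.
# Real models are built from onmt.encoders / onmt.decoders instead.
def _example_nmt_forward():
    import torch

    class _ToyEncoder(nn.Module):
        def __init__(self):
            super(_ToyEncoder, self).__init__()
            # Only here so count_parameters has something to count.
            self.proj = nn.Linear(4, 8)

        def forward(self, src, lengths):
            # Pretend hidden states of shape (src_len, batch, hidden).
            memory_bank = torch.zeros(src.size(0), src.size(1), 8)
            return None, memory_bank, lengths

    class _ToyDecoder(nn.Module):
        def __init__(self):
            super(_ToyDecoder, self).__init__()
            self.proj = nn.Linear(8, 8)

        def init_state(self, src, memory_bank, enc_state):
            pass

        def forward(self, tgt, memory_bank, memory_lengths=None,
                    with_align=False):
            dec_out = torch.zeros(tgt.size(0), tgt.size(1), 8)
            attns = {"std": torch.zeros(tgt.size(0), tgt.size(1),
                                        memory_bank.size(0))}
            return dec_out, attns

    model = NMTModel(_ToyEncoder(), _ToyDecoder())
    src = torch.zeros(5, 2, 1, dtype=torch.long)  # (src_len, batch, feats)
    tgt = torch.zeros(6, 2, 1, dtype=torch.long)  # (tgt_len, batch, feats)
    lengths = torch.tensor([5, 5])
    dec_out, attns = model(src, tgt, lengths)
    # dec_out is (tgt_len - 1, batch, hidden): the last target token is
    # dropped before decoding (``dec_in = tgt[:-1]`` above).
    enc_params, dec_params = model.count_parameters(log=None)
    return dec_out.shape, attns["std"].shape, enc_params, dec_params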


class LanguageModel(BaseModel):
    """
    Core trainable object in OpenNMT. Implements a trainable interface
    for a simple, generic decoder only model.
    Currently TransformerLMDecoder is the only LM decoder implemented.

    Args:
        decoder (onmt.decoders.TransformerLMDecoder): a transformer decoder
    """

    def __init__(self, encoder=None, decoder=None):
        super(LanguageModel, self).__init__(encoder, decoder)
        if encoder is not None:
            raise ValueError("LanguageModel should not be used "
                             "with an encoder")
        self.decoder = decoder

    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
        """Forward propagate a `src` and `tgt` pair for training.
        Possibly initialized with a beginning decoder state.

        Args:
            src (Tensor): A source sequence passed to the decoder.
                Typically for inputs this will be a padded ``LongTensor``
                of size ``(len, batch, features)``. However, it may be an
                image or other generic input depending on the decoder.
            tgt (LongTensor): A target sequence passed to the decoder.
                Size ``(tgt_len, batch, features)``.
            lengths (LongTensor): The src lengths, pre-padding ``(batch,)``.
            bptt (Boolean): A flag indicating if truncated bptt is set.
                If not set, the decoder state is initialized with
                ``init_state``.
            with_align (Boolean): A flag indicating whether to also return
                the alignment attention. Only valid for the transformer
                decoder.

        Returns:
            (FloatTensor, dict[str, FloatTensor]):

            * decoder output ``(tgt_len, batch, hidden)``
            * dictionary of attention distributions, each of shape
              ``(tgt_len, batch, src_len)``
        """
        if not bptt:
            self.decoder.init_state()
        dec_out, attns = self.decoder(
            src, memory_bank=None, memory_lengths=lengths,
            with_align=with_align
        )
        return dec_out, attns

    def update_dropout(self, dropout):
        self.decoder.update_dropout(dropout)

    def count_parameters(self, log=print):
        """Count number of parameters in model (& print with `log` callback).

        Returns:
            (int, int):

            * encoder side parameter count (always 0 for this
              decoder-only model)
            * decoder side parameter count
        """
        enc, dec = 0, 0
        for name, param in self.named_parameters():
            if "decoder" in name:
                dec += param.nelement()

        if callable(log):
            log("encoder: {}".format(enc))
            log("decoder: {}".format(dec))
            log("* number of parameters: {}".format(enc + dec))
        return enc, dec
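

# Illustrative usage sketch, not part of the OpenNMT-py API: a stand-in
# decoder exposing the interface ``LanguageModel.forward`` expects --
# ``init_state()`` with no arguments and no memory bank. Real language
# models use onmt.decoders.TransformerLMDecoder.
def _example_lm_forward():
    import torch

    class _ToyLMDecoder(nn.Module):
        def __init__(self):
            super(_ToyLMDecoder, self).__init__()
            self.proj = nn.Linear(8, 8)

        def init_state(self):
            pass

        def forward(self, src, memory_bank=None, memory_lengths=None,
                    with_align=False):
            dec_out = torch.zeros(src.size(0), src.size(1), 8)
            attns = {"std": torch.zeros(src.size(0), src.size(1),
                                        src.size(0))}
            return dec_out, attns

    model = LanguageModel(decoder=_ToyLMDecoder())
    src = torch.zeros(7, 2, 1, dtype=torch.long)  # (src_len, batch, feats)
    lengths = torch.tensor([7, 7])
    # For a decoder-only model the source sequence itself is decoded;
    # ``tgt`` is unused inside ``forward`` above.
    dec_out, attns = model(src, None, lengths)
    return dec_out.shape, attns["std"].shape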