""" Onmt NMT Model base class definition """
import torch.nn as nn


class BaseModel(nn.Module):
    """
    Core trainable object in OpenNMT. Implements a trainable interface
    for a simple, generic encoder/decoder or decoder-only model.
    """

    def __init__(self, encoder, decoder):
        super(BaseModel, self).__init__()

    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
        """Forward propagate a `src` and `tgt` pair for training.
        Possibly initialized with a beginning decoder state.

        Args:
            src (Tensor): A source sequence passed to the encoder.
                Typically a padded `LongTensor` of size
                ``(len, batch, features)``; however, it may be an
                image or other generic input, depending on the encoder.
            tgt (LongTensor): A target sequence passed to the decoder.
                Size ``(tgt_len, batch, features)``.
            lengths (LongTensor): The src lengths, pre-padding ``(batch,)``.
            bptt (Boolean): A flag indicating whether truncated BPTT is
                set. If False, the decoder state is initialized with
                ``init_state``.
            with_align (Boolean): A flag indicating whether to output
                alignments. Only valid for the transformer decoder.

        Returns:
            (FloatTensor, dict[str, FloatTensor]):

            * decoder output ``(tgt_len, batch, hidden)``
            * dictionary of attention dists, each ``(tgt_len, batch, src_len)``
        """
        raise NotImplementedError

    def update_dropout(self, dropout):
        raise NotImplementedError

    def count_parameters(self, log=print):
        raise NotImplementedError


class NMTModel(BaseModel):
    """
    Core trainable object in OpenNMT. Implements a trainable interface
    for a simple, generic encoder + decoder model.

    Args:
        encoder (onmt.encoders.EncoderBase): an encoder object
        decoder (onmt.decoders.DecoderBase): a decoder object
    """

    def __init__(self, encoder, decoder):
        super(NMTModel, self).__init__(encoder, decoder)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
        dec_in = tgt[:-1]  # exclude last target from inputs (teacher forcing)
        enc_state, memory_bank, lengths = self.encoder(src, lengths)
        if not bptt:
            # (Re)initialize the decoder state unless continuing a
            # truncated-BPTT sequence.
            self.decoder.init_state(src, memory_bank, enc_state)
        dec_out, attns = self.decoder(dec_in, memory_bank,
                                      memory_lengths=lengths,
                                      with_align=with_align)
        return dec_out, attns

    def update_dropout(self, dropout):
        self.encoder.update_dropout(dropout)
        self.decoder.update_dropout(dropout)

    def count_parameters(self, log=print):
        """Count number of parameters in model (& print with `log` callback).

        Returns:
            (int, int):

            * encoder side parameter count
            * decoder side parameter count
        """
        enc, dec = 0, 0
        for name, param in self.named_parameters():
            if 'encoder' in name:
                enc += param.nelement()
            else:
                dec += param.nelement()
        if callable(log):
            log('encoder: {}'.format(enc))
            log('decoder: {}'.format(dec))
            log('* number of parameters: {}'.format(enc + dec))
        return enc, dec
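

# A minimal usage sketch (not part of the original module): it shows the
# interface NMTModel expects from its encoder and decoder. ToyEncoder and
# ToyDecoder below are hypothetical stand-ins for the real classes in
# onmt.encoders / onmt.decoders: the encoder must return
# (enc_state, memory_bank, lengths), and the decoder must expose
# init_state() and return (dec_out, attns).
def _nmt_model_usage_sketch():
    import torch

    hidden = 8

    class ToyEncoder(nn.Module):
        def __init__(self):
            super(ToyEncoder, self).__init__()
            self.emb = nn.Embedding(100, hidden)

        def forward(self, src, lengths):
            # src: (src_len, batch, 1) -> memory_bank: (src_len, batch, hidden)
            memory_bank = self.emb(src.squeeze(-1))
            enc_state = memory_bank.mean(0, keepdim=True)  # dummy final state
            return enc_state, memory_bank, lengths

    class ToyDecoder(nn.Module):
        def __init__(self):
            super(ToyDecoder, self).__init__()
            self.emb = nn.Embedding(100, hidden)

        def init_state(self, src, memory_bank, enc_state):
            pass  # a real decoder caches its initial state here

        def forward(self, tgt, memory_bank, memory_lengths=None,
                    with_align=False):
            dec_out = self.emb(tgt.squeeze(-1))
            attns = {"std": torch.zeros(tgt.size(0), tgt.size(1),
                                        memory_bank.size(0))}
            return dec_out, attns

    model = NMTModel(ToyEncoder(), ToyDecoder())
    src = torch.randint(0, 100, (5, 2, 1))  # (src_len, batch, features)
    tgt = torch.randint(0, 100, (6, 2, 1))  # (tgt_len, batch, features)
    lengths = torch.tensor([5, 5])
    dec_out, attns = model(src, tgt, lengths)
    assert dec_out.shape == (5, 2, hidden)  # tgt[:-1] drops one step
    enc, dec = model.count_parameters(log=lambda msg: None)
    assert enc == dec == 100 * hidden  # one embedding table on each side
    return dec_out, attns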


class LanguageModel(BaseModel):
    """
    Core trainable object in OpenNMT. Implements a trainable interface
    for a simple, generic decoder-only model.
    Currently, TransformerLMDecoder is the only LM decoder implemented.

    Args:
        decoder (onmt.decoders.TransformerLMDecoder): a transformer decoder
    """

    def __init__(self, encoder=None, decoder=None):
        super(LanguageModel, self).__init__(encoder, decoder)
        if encoder is not None:
            raise ValueError("LanguageModel should not be used "
                             "with an encoder")
        self.decoder = decoder

    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
        """Forward propagate a `src` and `tgt` pair for training.
        Possibly initialized with a beginning decoder state.

        Args:
            src (Tensor): A source sequence passed to the decoder.
                Typically a padded `LongTensor` of size
                ``(len, batch, features)``; however, it may be an
                image or other generic input, depending on the decoder.
            tgt (LongTensor): A target sequence passed to the decoder.
                Size ``(tgt_len, batch, features)``.
            lengths (LongTensor): The src lengths, pre-padding ``(batch,)``.
            bptt (Boolean): A flag indicating whether truncated BPTT is
                set. If False, the decoder state is initialized with
                ``init_state``.
            with_align (Boolean): A flag indicating whether to output
                alignments. Only valid for the transformer decoder.

        Returns:
            (FloatTensor, dict[str, FloatTensor]):

            * decoder output ``(tgt_len, batch, hidden)``
            * dictionary of attention dists, each ``(tgt_len, batch, src_len)``
        """
        if not bptt:
            self.decoder.init_state()
        dec_out, attns = self.decoder(
            src, memory_bank=None, memory_lengths=lengths,
            with_align=with_align
        )
        return dec_out, attns

    def update_dropout(self, dropout):
        self.decoder.update_dropout(dropout)

    def count_parameters(self, log=print):
        """Count number of parameters in model (& print with `log` callback).

        Returns:
            (int, int):

            * encoder side parameter count (always 0 for a LanguageModel)
            * decoder side parameter count
        """
        enc, dec = 0, 0
        for name, param in self.named_parameters():
            if "decoder" in name:
                dec += param.nelement()
        if callable(log):
            # No encoder in an LM; the seq2seq count format is kept.
            log("encoder: {}".format(enc))
            log("decoder: {}".format(dec))
            log("* number of parameters: {}".format(enc + dec))
        return enc, dec
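

# A minimal usage sketch (not part of the original module): it shows the
# decoder-only path, where LanguageModel forwards `src` straight to the
# decoder with no memory bank. ToyLMDecoder is a hypothetical stand-in for
# onmt.decoders.TransformerLMDecoder; it only mimics the expected
# interface: init_state() plus a forward returning (dec_out, attns).
def _language_model_usage_sketch():
    import torch

    hidden = 8

    class ToyLMDecoder(nn.Module):
        def __init__(self):
            super(ToyLMDecoder, self).__init__()
            self.emb = nn.Embedding(100, hidden)

        def init_state(self):
            pass  # a real decoder resets its cached state here

        def forward(self, tgt, memory_bank=None, memory_lengths=None,
                    with_align=False):
            dec_out = self.emb(tgt.squeeze(-1))
            attns = {"std": torch.zeros(tgt.size(0), tgt.size(1), 1)}
            return dec_out, attns

    lm = LanguageModel(decoder=ToyLMDecoder())
    src = torch.randint(0, 100, (7, 2, 1))  # (len, batch, features)
    lengths = torch.tensor([7, 7])
    dec_out, attns = lm(src, tgt=None, lengths=lengths)
    assert dec_out.shape == (7, 2, hidden)  # no input shift in the LM path
    return dec_out, attns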