""" Onmt NMT Model base class definition """ import torch.nn as nn class BaseModel(nn.Module): """ Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder / decoder or decoder only model. """ def __init__(self, encoder, decoder): super(BaseModel, self).__init__() def forward(self, src, tgt, lengths, bptt=False, with_align=False): """Forward propagate a `src` and `tgt` pair for training. Possible initialized with a beginning decoder state. Args: src (Tensor): A source sequence passed to encoder. typically for inputs this will be a padded `LongTensor` of size ``(len, batch, features)``. However, may be an image or other generic input depending on encoder. tgt (LongTensor): A target sequence passed to decoder. Size ``(tgt_len, batch, features)``. lengths(LongTensor): The src lengths, pre-padding ``(batch,)``. bptt (Boolean): A flag indicating if truncated bptt is set. If reset then init_state with_align (Boolean): A flag indicating whether output alignment, Only valid for transformer decoder. Returns: (FloatTensor, dict[str, FloatTensor]): * decoder output ``(tgt_len, batch, hidden)`` * dictionary attention dists of ``(tgt_len, batch, src_len)`` """ raise NotImplementedError def update_dropout(self, dropout): raise NotImplementedError def count_parameters(self, log=print): raise NotImplementedError class NMTModel(BaseModel): """ Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder + decoder model. Args: encoder (onmt.encoders.EncoderBase): an encoder object decoder (onmt.decoders.DecoderBase): a decoder object """ def __init__(self, encoder, decoder): super(NMTModel, self).__init__(encoder, decoder) self.encoder = encoder self.decoder = decoder def forward(self, src, tgt, lengths, bptt=False, with_align=False): dec_in = tgt[:-1] # exclude last target from inputs enc_state, memory_bank, lengths = self.encoder(src, lengths) if not bptt: self.decoder.init_state(src, memory_bank, enc_state) dec_out, attns = self.decoder(dec_in, memory_bank, memory_lengths=lengths, with_align=with_align) return dec_out, attns def update_dropout(self, dropout): self.encoder.update_dropout(dropout) self.decoder.update_dropout(dropout) def count_parameters(self, log=print): """Count number of parameters in model (& print with `log` callback). Returns: (int, int): * encoder side parameter count * decoder side parameter count """ enc, dec = 0, 0 for name, param in self.named_parameters(): if 'encoder' in name: enc += param.nelement() else: dec += param.nelement() if callable(log): log('encoder: {}'.format(enc)) log('decoder: {}'.format(dec)) log('* number of parameters: {}'.format(enc + dec)) return enc, dec class LanguageModel(BaseModel): """ Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic decoder only model. Currently TransformerLMDecoder is the only LM decoder implemented Args: decoder (onmt.decoders.TransformerLMDecoder): a transformer decoder """ def __init__(self, encoder=None, decoder=None): super(LanguageModel, self).__init__(encoder, decoder) if encoder is not None: raise ValueError("LanguageModel should not be used" "with an encoder") self.decoder = decoder def forward(self, src, tgt, lengths, bptt=False, with_align=False): """Forward propagate a `src` and `tgt` pair for training. Possible initialized with a beginning decoder state. Args: src (Tensor): A source sequence passed to decoder. typically for inputs this will be a padded `LongTensor` of size ``(len, batch, features)``. 


class LanguageModel(BaseModel):
    """Core trainable object in OpenNMT. Implements a trainable interface
    for a simple, generic decoder only model.
    Currently TransformerLMDecoder is the only LM decoder implemented.

    Args:
        decoder (onmt.decoders.TransformerLMDecoder): a transformer decoder
    """

    def __init__(self, encoder=None, decoder=None):
        super(LanguageModel, self).__init__(encoder, decoder)
        if encoder is not None:
            raise ValueError("LanguageModel should not be used "
                             "with an encoder")
        self.decoder = decoder

    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
        """Forward propagate a `src` and `tgt` pair for training.
        Possibly initialized with a beginning decoder state.

        Args:
            src (Tensor): A source sequence passed to decoder.
                Typically for input this will be a padded `LongTensor`
                of size ``(len, batch, features)``; however, it may be
                an image or other generic input depending on decoder.
            tgt (LongTensor): A target sequence passed to decoder.
                Size ``(tgt_len, batch, features)``.
            lengths (LongTensor): The src lengths, pre-padding ``(batch,)``.
            bptt (bool): A flag indicating if truncated bptt is set.
                If False, the decoder state is (re-)initialized via
                `init_state`.
            with_align (bool): A flag indicating whether to output alignment;
                only valid for transformer decoder.

        Returns:
            (FloatTensor, dict[str, FloatTensor]):

            * decoder output ``(tgt_len, batch, hidden)``
            * dictionary of attention distributions, each
              ``(tgt_len, batch, src_len)``
        """
        if not bptt:
            self.decoder.init_state()
        dec_out, attns = self.decoder(
            src, memory_bank=None, memory_lengths=lengths,
            with_align=with_align
        )
        return dec_out, attns

    def update_dropout(self, dropout):
        self.decoder.update_dropout(dropout)

    def count_parameters(self, log=print):
        """Count number of parameters in model (& print with `log` callback).

        Returns:
            (int, int):

            * encoder side parameter count (always 0 for a LanguageModel)
            * decoder side parameter count
        """
        enc, dec = 0, 0
        for name, param in self.named_parameters():
            if "decoder" in name:
                dec += param.nelement()
        if callable(log):
            # No encoder in an LM; seq2seq count formatting kept
            log("encoder: {}".format(enc))
            log("decoder: {}".format(dec))
            log("* number of parameters: {}".format(enc + dec))
        return enc, dec
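

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the library: LanguageModel only needs a
# decoder exposing `init_state()` and the call signature used above.
# `_ToyLMDecoder` and `_demo_lm_forward` are hypothetical stand-ins for
# onmt.decoders.TransformerLMDecoder and a real training step.


class _ToyLMDecoder(nn.Module):
    """Hypothetical decoder-only module matching LanguageModel.forward."""

    def __init__(self, vocab_size, hidden_size):
        super(_ToyLMDecoder, self).__init__()
        self.emb = nn.Embedding(vocab_size, hidden_size)

    def init_state(self):
        pass  # a real LM decoder would reset its cached state here

    def forward(self, tgt, memory_bank=None, memory_lengths=None,
                with_align=False):
        dec_out = self.emb(tgt.squeeze(-1))  # (tgt_len, batch, hidden)
        return dec_out, {}


def _demo_lm_forward():
    """Run one hypothetical decoder-only forward pass and check shapes."""
    import torch
    src = torch.randint(0, 100, (7, 4, 1))   # (len, batch, features)
    lengths = torch.full((4,), 7, dtype=torch.long)
    lm = LanguageModel(decoder=_ToyLMDecoder(100, 16))
    dec_out, _ = lm(src, src, lengths)       # an LM reads `src` only
    assert dec_out.shape == (7, 4, 16)
    # count_parameters attributes every parameter to the decoder side
    enc_count, dec_count = lm.count_parameters(log=None)
    assert enc_count == 0 and dec_count == 100 * 16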