MarianMT
Bugs: If you see something strange, file a GitHub issue and assign @sshleifer. Translations should be similar to, but not identical to, the output in the test set linked to in each model card.
Implementation Notes
Each model is about 298 MB on disk, and there are more than 1,000 models.
The list of supported language pairs can be found here.
The models were originally trained by Jörg Tiedemann using the Marian C++ library, which supports fast training and translation.
All models are transformer encoder-decoders with 6 layers in each component. Each model’s performance is documented in a model card.
The 80 OPUS models that require BPE preprocessing are not supported.
- The modeling code is the same as BartForConditionalGeneration with a few minor modifications:
  - static (sinusoid) positional embeddings (MarianConfig.static_position_embeddings=True)
  - a new final_logits_bias (MarianConfig.add_bias_logits=True)
  - no layernorm_embedding (MarianConfig.normalize_embedding=False)
  - the model starts generating with pad_token_id (whose token embedding is all zeros) as the prefix (Bart uses <s/>).
- Code to bulk convert models can be found in convert_marian_to_pytorch.py.
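The static (sinusoid) positional embeddings in the first modification are fixed functions of the position rather than learned weights. The sketch below computes them with the standard formula from "Attention Is All You Need", in plain Python for illustration; the function name is hypothetical, and the library's own implementation may arrange the sine and cosine columns differently (for example, all sines in the first half of the vector), so treat it as the idea rather than a drop-in replacement.

```python
import math

def sinusoidal_position_embeddings(num_positions, dim):
    """Return a num_positions x dim table of fixed (non-learned) position embeddings.

    Interleaved layout: column 2i holds sin(pos / 10000**(2i/dim)) and
    column 2i+1 holds cos of the same angle.
    """
    table = []
    for pos in range(num_positions):
        row = []
        for j in range(dim):
            # Columns 2i and 2i+1 share the same frequency.
            angle = pos / (10000 ** (2 * (j // 2) / dim))
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        table.append(row)
    return table

emb = sinusoidal_position_embeddings(num_positions=4, dim=6)
# Position 0 is [0.0, 1.0, 0.0, 1.0, 0.0, 1.0] because sin(0) = 0 and cos(0) = 1.
```

Because nothing is learned, the table can be recomputed for any position, which is why it needs no trained weights in the checkpoint.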
Naming
All model names use the following format: Helsinki-NLP/opus-mt-{src}-{tgt}
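Since the naming scheme is mechanical, a model id can be built from, or split back into, its src and tgt codes. The helpers below are a hypothetical sketch, not part of transformers; parse_model_name assumes that neither code contains a hyphen, which holds for the codes described on this page (groups and collections use _ instead).

```python
ORG = "Helsinki-NLP"
PREFIX = "opus-mt-"

def marian_model_name(src, tgt):
    """Build an id of the form Helsinki-NLP/opus-mt-{src}-{tgt}."""
    return f"{ORG}/{PREFIX}{src}-{tgt}"

def parse_model_name(model_id):
    """Split an id back into (src, tgt); assumes the codes themselves contain no '-'."""
    org, _, name = model_id.partition("/")
    if org != ORG or not name.startswith(PREFIX):
        raise ValueError(f"not an OPUS-MT model id: {model_id!r}")
    parts = name[len(PREFIX):].split("-")
    if len(parts) != 2:
        raise ValueError(f"expected exactly one '-' between src and tgt: {model_id!r}")
    return parts[0], parts[1]

print(marian_model_name("en", "de"))                        # Helsinki-NLP/opus-mt-en-de
print(parse_model_name("Helsinki-NLP/opus-mt-en-ROMANCE"))  # ('en', 'ROMANCE')
```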
The language codes used to name models are inconsistent. Two-letter codes can usually be found here; three-letter codes require googling “language code {code}”.
Codes formatted like es_AR are usually code_{region}. That one is Spanish documents from Argentina.
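A code such as es_AR can be split on its first underscore into a base language and a region. The helper below is hypothetical and only encodes the usual code_{region} convention: it does not check that the second part really is a region, so a code like zh_yue from the group mapping further down is split just as mechanically.

```python
def split_lang_code(code):
    """Split 'es_AR' into ('es', 'AR'); a bare code such as 'es' gives ('es', None)."""
    lang, sep, region = code.partition("_")
    return lang, (region if sep else None)

print(split_lang_code("es_AR"))  # ('es', 'AR')
print(split_lang_code("es"))     # ('es', None)
```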
Multilingual Models
- All model names use the following format: Helsinki-NLP/opus-mt-{src}-{tgt}:
  - If src is in all caps, the model supports multiple input languages. You can figure out which ones by looking at the model card or the Group Members mapping.
  - If tgt is in all caps, the model can output multiple languages, and you should specify a language code by prepending the desired output language to the src_text. You can see a tokenizer’s supported language codes in tokenizer.supported_language_codes.
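The capitalization rule and the >>code<< prefix used in the example below can be written as two small helpers. They are hypothetical, not transformers API. is_group treats a segment as a group only when every letter is uppercase, so a regional code such as es_AR is not mistaken for one; note that lowercase collections joined with _ (such as en_el_es_fi) also return False even though, as described further down, they still require language codes.

```python
def is_group(code):
    """True when every letter in a src/tgt segment is uppercase, e.g. 'ROMANCE' or 'NORTH_EU'."""
    return code.isupper()

def with_target_code(texts, lang_code):
    """Prepend the >>lang_code<< token that selects the output language."""
    return [f">>{lang_code}<< {text}" for text in texts]

tgt = "ROMANCE"  # the tgt segment of Helsinki-NLP/opus-mt-en-ROMANCE
if is_group(tgt):
    batch = with_target_code(["This should go to portuguese"], "pt")
    # batch == ['>>pt<< This should go to portuguese']
```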
Example of translating English to several Romance languages, using language codes:
from transformers import MarianMTModel, MarianTokenizer
# Each source sentence starts with the >>code<< of the desired output language.
src_text = [
    '>>fr<< this is a sentence in english that we want to translate to french',
    '>>pt<< This should go to portuguese',
    '>>es<< And this to Spanish'
]
model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
tokenizer = MarianTokenizer.from_pretrained(model_name)
print(tokenizer.supported_language_codes)  # the >>code<< prefixes this model accepts
model = MarianMTModel.from_pretrained(model_name)
# Tokenize the batch (return_tensors defaults to 'pt') and generate translations.
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text))
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
# ["c'est une phrase en anglais que nous voulons traduire en français",
#  'Isto deve ir para o português.',
#  'Y esto al español']
Sometimes, models were trained on collections of languages that do not resolve to a group. In this case, _ is used as a separator for src or tgt, as in 'Helsinki-NLP/opus-mt-en_el_es_fi-en_el_es_fi'. These still require language codes.
There are many supported regional language codes, like >>es_ES<< (Spain) and >>es_AR<< (Argentina), that do not seem to change translations. I have not found these to provide different results from just using >>es<<.
- For example:
  - Helsinki-NLP/opus-mt-NORTH_EU-NORTH_EU: translates from all NORTH_EU languages (see mapping) to all NORTH_EU languages. Use a special language code like >>de<< to specify the output language.
  - Helsinki-NLP/opus-mt-ROMANCE-en: translates from many Romance languages to English; no codes are needed since there is only one tgt language.
Group Members mapping:
GROUP_MEMBERS = {
'ZH': ['cmn', 'cn', 'yue', 'ze_zh', 'zh_cn', 'zh_CN', 'zh_HK', 'zh_tw', 'zh_TW', 'zh_yue', 'zhs', 'zht', 'zh'],
'ROMANCE': ['fr', 'fr_BE', 'fr_CA', 'fr_FR', 'wa', 'frp', 'oc', 'ca', 'rm', 'lld', 'fur', 'lij', 'lmo', 'es', 'es_AR', 'es_CL', 'es_CO', 'es_CR', 'es_DO', 'es_EC', 'es_ES', 'es_GT', 'es_HN', 'es_MX', 'es_NI', 'es_PA', 'es_PE', 'es_PR', 'es_SV', 'es_UY', 'es_VE', 'pt', 'pt_br', 'pt_BR', 'pt_PT', 'gl', 'lad', 'an', 'mwl', 'it', 'it_IT', 'co', 'nap', 'scn', 'vec', 'sc', 'ro', 'la'],
'NORTH_EU': ['de', 'nl', 'fy', 'af', 'da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
'SCANDINAVIA': ['da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
'SAMI': ['se', 'sma', 'smj', 'smn', 'sms'],
'NORWAY': ['nb_NO', 'nb', 'nn_NO', 'nn', 'nog', 'no_nb', 'no'],
'CELTIC': ['ga', 'cy', 'br', 'gd', 'kw', 'gv']
}
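To answer questions such as "which group models cover Danish?", the mapping can be searched directly. A minimal sketch with hypothetical helper names, using an abbreviated copy of the mapping so that the block runs on its own:

```python
# Abbreviated copy of GROUP_MEMBERS; the full mapping has more groups and members.
GROUP_MEMBERS = {
    'NORTH_EU': ['de', 'nl', 'fy', 'af', 'da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
    'SCANDINAVIA': ['da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
    'CELTIC': ['ga', 'cy', 'br', 'gd', 'kw', 'gv'],
}

def expand_group(code):
    """Return the member codes of a group name, or [code] for a plain language code."""
    return GROUP_MEMBERS.get(code, [code])

def groups_containing(lang):
    """Return the names of all groups whose member list includes lang."""
    return [name for name, members in GROUP_MEMBERS.items() if lang in members]

print(groups_containing('da'))  # ['NORTH_EU', 'SCANDINAVIA']
print(expand_group('en'))       # ['en']
```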
Code to see available pretrained models:
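from transformers.hf_api import HfApi
model_list = HfApi().model_list()  # metadata for every model on the model hub
org = "Helsinki-NLP"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
suffix = [x.split('/')[1] for x in model_ids]  # drop the "Helsinki-NLP/" prefix
# Multilingual models have an all-caps src or tgt, so their names are not all lowercase.
multi_models = [f'{org}/{s}' for s in suffix if s != s.lower()]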
MarianMTModel
PyTorch version of marian-nmt’s transformer.h (C++). Designed for the OPUS-NMT translation checkpoints.
Model API is identical to BartForConditionalGeneration.
Available models are listed at Model List.
This class inherits nearly all functionality from BartForConditionalGeneration; see that page for method signatures.
MarianConfig
class transformers.MarianConfig(activation_dropout=0.0, extra_pos_embeddings=2, activation_function='gelu', vocab_size=50265, d_model=1024, encoder_ffn_dim=4096, encoder_layers=12, encoder_attention_heads=16, decoder_ffn_dim=4096, decoder_layers=12, decoder_attention_heads=16, encoder_layerdrop=0.0, decoder_layerdrop=0.0, attention_dropout=0.0, dropout=0.1, max_position_embeddings=1024, init_std=0.02, classifier_dropout=0.0, num_labels=3, is_encoder_decoder=True, pad_token_id=1, bos_token_id=0, eos_token_id=2, normalize_before=False, add_final_layer_norm=False, scale_embedding=False, normalize_embedding=True, static_position_embeddings=False, add_bias_logits=False, force_bos_token_to_be_generated=False, **common_kwargs)
The defaults in this signature are those of BartConfig; each pretrained OPUS-MT checkpoint sets its own values in its config (for example, 6 encoder and decoder layers, as noted in the implementation notes).
MarianTokenizer
class transformers.MarianTokenizer(vocab, source_spm, target_spm, source_lang=None, target_lang=None, unk_token='<unk>', eos_token='</s>', pad_token='<pad>', model_max_length=512, **kwargs)
SentencePiece tokenizer for Marian. Source and target languages have different SPM models. The logic is to use the relevant source_spm or target_spm to encode text as pieces, then look up each piece in a vocab dictionary.
Examples:
>>> from transformers import MarianTokenizer
>>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> src_texts = ["I am a small frog.", "Tom asked his teacher for advice."]
>>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]  # optional
>>> batch_enc = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts)
>>> # keys  [input_ids, attention_mask, labels].
>>> # model(**batch_enc) should work
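The two-step logic in the MarianTokenizer description (choose the source or target SPM, split text into pieces, then look each piece up in the vocab) can be illustrated with toy components. Everything here is hypothetical: toy_spm stands in for a real SentencePiece model by emitting one piece per whitespace-separated word, the vocab ids are arbitrary, and appending the eos id at the end is an assumption modeled on the eos_token parameter.

```python
def toy_spm(text):
    """Stand-in for a SentencePiece model: one piece per word, prefixed with the
    U+2581 symbol that SentencePiece uses to mark the start of a word."""
    return ["\u2581" + word for word in text.split()]

def encode(text, source_spm, target_spm, vocab, is_target=False,
           unk_token="<unk>", eos_token="</s>"):
    """Encode text with the SPM for its side, then look each piece up in vocab."""
    spm = target_spm if is_target else source_spm  # source and target use different SPMs
    pieces = spm(text)
    unk_id = vocab[unk_token]
    ids = [vocab.get(piece, unk_id) for piece in pieces]  # unknown pieces map to <unk>
    return ids + [vocab[eos_token]]  # close the sequence with the eos id

# Toy vocab with arbitrary ids.
vocab = {"</s>": 0, "<unk>": 1, "<pad>": 2, "\u2581I": 3, "\u2581am": 4, "\u2581a": 5}
ids = encode("I am a frog", toy_spm, toy_spm, vocab)
# ids == [3, 4, 5, 1, 0]  ("frog" is not in the toy vocab, so it maps to <unk>)
```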
prepare_seq2seq_batch(src_texts: List[str], tgt_texts: Optional[List[str]] = None, max_length: Optional[int] = None, max_target_length: Optional[int] = None, return_tensors: str = 'pt', truncation=True, padding='longest', **unused) → transformers.tokenization_utils_base.BatchEncoding
Arguments:
- src_texts (list): list of documents to summarize or source language texts.
- tgt_texts (list, optional): list of tgt language texts or summaries.
- max_length (int, optional): Controls the maximum length for encoder inputs (documents to summarize or source language texts). If left unset or set to None, this will use the predefined model maximum length if a maximum length is required by one of the truncation/padding parameters. If the model has no specific maximum input length (like XLNet), truncation/padding to a maximum length will be deactivated.
- max_target_length (int, optional): Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or set to None, this will use the max_length value.
- padding (bool, str or PaddingStrategy, optional, defaults to 'longest'): Activates and controls padding. Accepts the following values:
  - True or 'longest' (default): Pad to the longest sequence in the batch (or no padding if only a single sequence is provided).
  - 'max_length': Pad to a maximum length specified with the argument max_length or to the maximum acceptable input length for the model if that argument is not provided.
  - False or 'do_not_pad': No padding (i.e., can output a batch with sequences of different lengths).
- return_tensors (str or TensorType, optional, defaults to 'pt'): If set, will return tensors instead of lists of Python integers. Acceptable values are:
  - 'tf': Return TensorFlow tf.constant objects.
  - 'pt': Return PyTorch torch.Tensor objects.
  - 'np': Return NumPy np.ndarray objects.
- truncation (bool, str or TruncationStrategy, optional, defaults to True): Activates and controls truncation. Accepts the following values:
  - True or 'longest_first' (default): Truncate to a maximum length specified with the argument max_length or to the maximum acceptable input length for the model if that argument is not provided. This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided.
  - 'only_first': Truncate to a maximum length specified with the argument max_length or to the maximum acceptable input length for the model if that argument is not provided. This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  - 'only_second': Truncate to a maximum length specified with the argument max_length or to the maximum acceptable input length for the model if that argument is not provided. This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
  - False or 'do_not_truncate': No truncation (i.e., can output a batch with sequence lengths greater than the model maximum admissible input size).
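To make the truncation=True and padding='longest' defaults in the signature concrete, here is a toy, pure-Python illustration on lists of token ids. It is not how the tokenizer is implemented: the helper name, the arbitrary pad_id, and padding on the right are assumptions, the real tokenizer also accounts for special tokens when truncating, and the pair strategies ('only_first', 'only_second') are not modeled.

```python
def pad_and_truncate(batch, pad_id, max_length=None):
    """Truncate each sequence to max_length (if given), then pad to the longest one.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens and 0 for padding.
    """
    if max_length is not None:
        batch = [seq[:max_length] for seq in batch]
    longest = max(len(seq) for seq in batch)
    input_ids = [seq + [pad_id] * (longest - len(seq)) for seq in batch]
    attention_mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in batch]
    return input_ids, attention_mask

ids, mask = pad_and_truncate([[5, 6, 7, 0], [8, 0]], pad_id=9)
# ids  == [[5, 6, 7, 0], [8, 0, 9, 9]]
# mask == [[1, 1, 1, 1], [1, 1, 0, 0]]
```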
Return:
BatchEncoding: A BatchEncoding with the following fields:
- input_ids – List of token ids to be fed to the encoder.
- attention_mask – List of indices specifying which tokens should be attended to by the model.
- decoder_input_ids – List of token ids to be fed to the decoder.
- decoder_attention_mask – List of indices specifying which tokens should be attended to by the decoder. This does not include the causal mask, which is built by the model.
The full set of keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask] will only be returned if tgt_texts is passed. Otherwise, input_ids and attention_mask will be the only keys.
Prepare model inputs for translation. For best performance, translate one sentence at a time.
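Following the advice to translate one sentence at a time, a longer document can be split into sentences and passed as a list with one sentence per entry, as in the src_texts examples. The regex splitter below is a naive, hypothetical helper (it will mis-split abbreviations such as "e.g."); a dedicated sentence segmenter is a better choice in practice.

```python
import re

def split_sentences(document):
    """Naively split text after '.', '!' or '?' when followed by whitespace."""
    parts = re.split(r"(?<=[.!?])\s+", document.strip())
    return [p for p in parts if p]

sentences = split_sentences("I am a small frog. Tom asked his teacher for advice!  Is it raining?")
# ['I am a small frog.', 'Tom asked his teacher for advice!', 'Is it raining?']
# Each entry can then be one element of the src_texts list given to prepare_seq2seq_batch.
```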