# Source path: OFA-Image_Caption/fairseq/fairseq/tasks/translation_from_pretrained_xlm.py
# Last commit: 8437114 ("update") by JustinLin610
# File size: 1.29 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.tasks.translation import TranslationConfig, TranslationTask
from . import register_task
@dataclass
class TranslationFromPretrainedXLMConfig(TranslationConfig):
    """Configuration for the ``translation_from_pretrained_xlm`` task.

    Declares no fields of its own; every option is inherited unchanged
    from :class:`TranslationConfig`.
    """
@register_task(
    "translation_from_pretrained_xlm", dataclass=TranslationFromPretrainedXLMConfig
)
class TranslationFromPretrainedXLMTask(TranslationTask):
    """
    TranslationTask variant whose dictionaries are MaskedLMDictionary objects.

    This class is identical to TranslationTask except that dictionaries are
    loaded through the MaskedLMDictionary class. Data binarized with
    MaskedLMDictionary can therefore be read back. Use this task for the
    entire pipeline when training an NMT model from a pretrained XLM
    checkpoint:
    binarizing the NMT data, training NMT from the pretrained XLM checkpoint,
    and evaluating the resulting model.
    """

    @classmethod
    def load_dictionary(cls, filename):
        """Load a masked LM dictionary from ``filename``.

        Args:
            filename (str): the file to read the dictionary from

        Returns:
            Whatever ``MaskedLMDictionary.load(filename)`` returns.
        """
        dictionary_class = MaskedLMDictionary
        return dictionary_class.load(filename)