# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.tasks.translation import (
    TranslationTask, TranslationConfig
)

# ``examples.simultaneous_translation`` is imported only for its side effects
# (hence the ``noqa``: the name itself is never used in this module). A
# failure is recorded instead of raised so that this module stays importable;
# the problem is reported when one of the tasks below is constructed.
try:
    import examples.simultaneous_translation  # noqa

    import_successful = True
    _import_error = None
except Exception as err:
    # Catch ``Exception`` rather than ``BaseException`` so that
    # KeyboardInterrupt / SystemExit raised during the import are not swallowed.
    import_successful = False
    # Keep the original failure so check_import() can chain it as the cause.
    _import_error = err

logger = logging.getLogger(__name__)


def check_import(flag):
    """Raise ``ImportError`` unless *flag* is truthy.

    Args:
        flag: whether ``examples.simultaneous_translation`` was imported
            successfully; callers in this module pass ``import_successful``.

    Raises:
        ImportError: if *flag* is falsy. When the original import failure was
            captured at module load time, it is attached as ``__cause__``.
    """
    if not flag:
        error = ImportError(
            "'examples.simultaneous_translation' is not correctly imported. "
            "Please consider `pip install -e $FAIRSEQ_DIR`."
        )
        if _import_error is not None:
            raise error from _import_error
        raise error


@register_task("simul_speech_to_text")
class SimulSpeechToTextTask(SpeechToTextTask):
    """``SpeechToTextTask`` registered under the name ``simul_speech_to_text``.

    Construction raises ``ImportError`` (via ``check_import``) if
    ``examples.simultaneous_translation`` could not be imported.
    """

    def __init__(self, args, tgt_dict):
        check_import(import_successful)
        super().__init__(args, tgt_dict)


@register_task("simul_text_to_text", dataclass=TranslationConfig)
class SimulTextToTextTask(TranslationTask):
    """``TranslationTask`` registered under the name ``simul_text_to_text``.

    Uses ``TranslationConfig`` as its config dataclass. Construction raises
    ``ImportError`` (via ``check_import``) if
    ``examples.simultaneous_translation`` could not be imported.
    """

    def __init__(self, cfg, src_dict, tgt_dict):
        check_import(import_successful)
        super().__init__(cfg, src_dict, tgt_dict)