Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # pylint: disable=C0116 | |
| # pylint: disable=C0301 | |
| from collections import defaultdict | |
| import torch | |
| from lhotse import MonoCut | |
| from lhotse.cut import Cut, MixedCut | |
| from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample | |
| from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn | |
| from nemo.collections.common.prompts.formatter import Modality, PromptFormatter | |
| class T5NMTPromptFormatter(PromptFormatter): | |
| """ | |
| The default prompt format for Megatron T5 based neural machine translation models. | |
| Based on: https://github.com/NVIDIA/NeMo/blob/ad5ef750e351edbb5eeb7eb6df2d0c804819600f/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py#L790 | |
| """ | |
| NAME = "t5nmt" | |
| OUTPUT_ROLE = "assistant" | |
| TEMPLATE = { | |
| "user": { | |
| "template": f"|target_lang| |message|", | |
| "slots": { | |
| "target_lang": Modality.Text, | |
| "message": Modality.Text, | |
| }, | |
| }, | |
| OUTPUT_ROLE: { | |
| "template": f"|message|", | |
| "slots": { | |
| "message": Modality.Text, | |
| }, | |
| }, | |
| } | |
| def encode_turn(self, prompt_template: str, expected_slots: dict, slot_values: dict) -> list[int]: | |
| # Automatically adds "<" and ">" to target lang token for T5 NMT. | |
| # Based on: https://github.com/NVIDIA/NeMo/blob/ad5ef750e351edbb5eeb7eb6df2d0c804819600f/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py#L307 | |
| if (val := slot_values.get("target_lang")) is not None: | |
| if not val.startswith("<") or not val.endswith(">"): | |
| slot_values["target_lang"] = f"<{val}>" | |
| return super().encode_turn( | |
| prompt_template=prompt_template, expected_slots=expected_slots, slot_values=slot_values | |
| ) | |
| def t5nmt(cut: Cut, prompt: T5NMTPromptFormatter) -> dict[str, torch.Tensor]: | |
| ans = defaultdict(list) | |
| if isinstance(cut, MixedCut): | |
| cut = cut._first_non_padding_cut | |
| if not isinstance(cut, MonoCut): | |
| raise TypeError( | |
| f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" | |
| ) | |
| if hasattr(cut, "context"): | |
| context = cut.context | |
| elif hasattr(cut, "default_context"): | |
| context = cut.default_context | |
| else: | |
| raise RuntimeError("Missing context/default_context custom field in cut: {cut}") | |
| turns = [ | |
| dict( | |
| role="user", | |
| # "message" slot is the audio portion of the cut; currently it is populated inside model's forward. | |
| slots={"target_lang": context, "message": ""}, | |
| ), | |
| ] | |
| if len(cut.supervisions) > 0 and cut.supervisions[0].text is not None: | |
| turns.append( | |
| dict( | |
| role=prompt.OUTPUT_ROLE, | |
| slots={"message": cut.supervisions[0].text}, | |
| ) | |
| ) | |
| return prompt.encode_dialog(turns) | |
| def t5nmt_src_tgt_text_example(example: SourceTargetTextExample, prompt: T5NMTPromptFormatter): | |
| ctx = f"<{example.target.language}>" | |
| if example.has_custom("extra_prompt"): | |
| ctx = f"{ctx} {example.extra_prompt}" | |
| return prompt.encode_dialog( | |
| [ | |
| {"role": "user", "slots": {"message": example.source.text, "target_lang": ctx}}, | |
| {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.target.text}}, | |
| ] | |
| ) | |