# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
from dataclasses import dataclass | |
from typing import Any, Dict, List, Optional | |
import torch | |
@dataclass
class MultimodalSample:
    """A single sample holding text and optional audio-related fields.

    ``id``, ``lang`` and ``text`` are required. All remaining fields default
    to ``None`` and may be filled in later by the caller.
    """

    id: int
    lang: str
    text: str
    audio_local_path: Optional[str] = None
    # Never restored by ``from_json``; it is always ``None`` after loading.
    waveform: Optional[torch.Tensor] = None
    sampling_rate: Optional[int] = None
    units: Optional[List[int]] = None

    @classmethod
    def from_json(cls, js: Dict[str, Any]) -> "MultimodalSample":
        """Build a sample from a decoded JSON object.

        Args:
            js: Mapping with the required keys ``"id"``, ``"lang"`` and
                ``"text"``. The keys ``"audio_local_path"``,
                ``"sampling_rate"`` and ``"units"`` are optional and fall
                back to ``None`` when absent.

        Returns:
            A new instance whose ``waveform`` is always ``None``; any
            ``"waveform"`` entry in ``js`` is ignored.

        Raises:
            KeyError: If ``"id"``, ``"lang"`` or ``"text"`` is missing.
        """
        return cls(
            id=js["id"],
            lang=js["lang"],
            text=js["text"],
            audio_local_path=js.get("audio_local_path"),
            waveform=None,  # don't serialize
            sampling_rate=js.get("sampling_rate"),
            units=js.get("units"),
        )
@dataclass
class LangPairSample:
    """A pair of samples: one ``source`` and one ``target``."""

    source: MultimodalSample
    target: MultimodalSample

    @classmethod
    def from_json(cls, js: Dict[str, Any]) -> "LangPairSample":
        """Build a pair from a decoded JSON object.

        Args:
            js: Mapping with the keys ``"source"`` and ``"target"``, each
                holding a mapping accepted by ``MultimodalSample.from_json``.

        Returns:
            A new pair whose members are parsed from the two sub-mappings.

        Raises:
            KeyError: If ``"source"`` or ``"target"`` is missing, or if a
                sub-mapping lacks a key that ``MultimodalSample.from_json``
                requires.
        """
        return cls(
            source=MultimodalSample.from_json(js["source"]),
            target=MultimodalSample.from_json(js["target"]),
        )