File size: 1,251 Bytes
02c689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch


@dataclass
class MultimodalSample:
    """A single sample: an id, a language tag, text, and optional audio fields.

    Only ``id``, ``lang`` and ``text`` are required; every other field
    defaults to ``None``.
    """

    id: int
    lang: str
    text: str
    audio_local_path: Optional[str] = None
    waveform: Optional[torch.Tensor] = None
    sampling_rate: Optional[int] = None
    units: Optional[List[int]] = None

    @classmethod
    def from_json(cls, js: Dict[str, Any]) -> "MultimodalSample":
        """Build a sample from a decoded JSON object.

        Args:
            js: Mapping that must contain the keys ``"id"``, ``"lang"`` and
                ``"text"``. The keys ``"audio_local_path"``,
                ``"sampling_rate"`` and ``"units"`` are optional and fall
                back to ``None`` when absent.

        Returns:
            A new ``MultimodalSample``. ``waveform`` is always ``None``: it
            is never read from ``js``, even if a ``"waveform"`` key exists.

        Raises:
            KeyError: If one of the required keys is missing from ``js``.
        """
        # Required keys use subscripting so a missing one raises KeyError.
        required = [js[key] for key in ("id", "lang", "text")]
        # Optional keys silently default to None.
        optional = {
            key: js.get(key)
            for key in ("audio_local_path", "sampling_rate", "units")
        }
        # The waveform is not part of the serialized form.
        return cls(*required, waveform=None, **optional)


@dataclass
class LangPairSample:
    """A pair of samples: one ``source`` and one ``target``."""

    source: MultimodalSample
    target: MultimodalSample

    @classmethod
    def from_json(cls, js: Dict[str, Any]) -> "LangPairSample":
        """Build a pair from a decoded JSON object.

        Args:
            js: Mapping with the keys ``"source"`` and ``"target"``; each
                value is passed to ``MultimodalSample.from_json``.

        Returns:
            A new ``LangPairSample`` holding the two parsed samples.

        Raises:
            KeyError: If ``"source"`` or ``"target"`` is missing from ``js``,
                or if a nested object lacks a key that
                ``MultimodalSample.from_json`` requires.
        """
        parse_sample = MultimodalSample.from_json
        # Source is looked up and parsed before target is touched.
        src_sample = parse_sample(js["source"])
        tgt_sample = parse_sample(js["target"])
        return cls(source=src_sample, target=tgt_sample)