victan committed
Commit
2cf8f5b
1 Parent(s): 02c689c

Upload seamless_communication/datasets/huggingface.py with huggingface_hub

seamless_communication/datasets/huggingface.py ADDED
@@ -0,0 +1,137 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+
+ import logging
+ import os
+ from abc import ABC, abstractmethod
+ from typing import Dict, Iterable, Optional
+
+ import numpy as np
+ import torch
+ from datasets import load_dataset
+
+ from .datatypes import LangPairSample, MultimodalSample
+
+ logger = logging.getLogger(__name__)
+
+
+ class SpeechTokenizer(ABC):
+     @abstractmethod
+     def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
+         ...
+
+
+ class Speech2SpeechFleursDatasetBuilder:
+     """Assembles a speech-to-speech dataset from google/fleurs on HuggingFace."""
+
+     HF_FLEURS_DATASET_NAME = "google/fleurs"
+
+     def __init__(
+         self,
+         source_lang: str,
+         target_lang: str,
+         split: str = "test",
+         skip_source_audio: bool = True,
+         skip_target_audio: bool = True,
+         audio_dtype: torch.dtype = torch.float32,
+         dataset_cache_dir: Optional[str] = None,
+         speech_tokenizer: Optional[SpeechTokenizer] = None,
+     ):
+         self.source_lang = source_lang
+         self.target_lang = target_lang
+         self.split = split
+         self.dataset_cache_dir = dataset_cache_dir
+         self.audio_dtype = audio_dtype
+         self.skip_source_audio = skip_source_audio
+         self.skip_target_audio = skip_target_audio
+         self.speech_tokenizer = speech_tokenizer
+
+     def _prepare_sample(
+         self,
+         sample_id: int,
+         lang: str,
+         text: str,
+         audio_local_path: Optional[str] = None,
+         waveform_npy: Optional[np.ndarray] = None,
+         sampling_rate: Optional[int] = None,
+     ) -> MultimodalSample:
+         # Skip audio if the caller opted out for this side, or if no
+         # waveform was provided.
+         should_skip_audio = (
+             (lang == self.target_lang and self.skip_target_audio)
+             or (lang == self.source_lang and self.skip_source_audio)
+             or waveform_npy is None
+         )
+         if not should_skip_audio:
+             waveform = torch.from_numpy(waveform_npy).to(self.audio_dtype)
+         else:
+             waveform = None
+         if self.speech_tokenizer is not None and not should_skip_audio:
+             assert waveform is not None
+             assert sampling_rate is not None
+             units_tensor = self.speech_tokenizer.encode(
+                 waveform, sampling_rate
+             ).reshape(-1)
+             units = units_tensor.tolist()
+         else:
+             units = None
+         return MultimodalSample(
+             id=sample_id,
+             lang=lang,
+             text=text.strip(),
+             audio_local_path=audio_local_path,
+             waveform=waveform,
+             sampling_rate=sampling_rate,
+             units=units,
+         )
+
+     def iterate_lang_audio_samples(self, lang: str) -> Iterable[MultimodalSample]:
+         ds = load_dataset(
+             self.HF_FLEURS_DATASET_NAME,
+             lang,
+             split=self.split,
+             cache_dir=self.dataset_cache_dir,
+             streaming=False,
+         )
+         for item in ds:
+             audio_path = os.path.join(
+                 os.path.dirname(item["path"]), item["audio"]["path"]
+             )
+             (sample_id, audio_local_path, waveform, sampling_rate, text) = (
+                 item["id"],
+                 audio_path,
+                 item["audio"]["array"],
+                 item["audio"]["sampling_rate"],
+                 item["transcription"],
+             )
+             yield self._prepare_sample(
+                 sample_id=sample_id,
+                 audio_local_path=audio_local_path,
+                 waveform_npy=waveform,
+                 sampling_rate=sampling_rate,
+                 text=text,
+                 lang=lang,
+             )
+
+     def __iter__(self) -> Iterable[LangPairSample]:
+         # Load all target-language samples keyed by id, then stream the
+         # source side and pair samples that share an id.
+         logger.info(f"Loading {self.target_lang} samples")
+         target_samples: Dict[int, MultimodalSample] = {}
+         for idx, sample in enumerate(
+             self.iterate_lang_audio_samples(lang=self.target_lang)
+         ):
+             if idx and idx % 100 == 0:
+                 logger.info(f"..loaded {idx} target samples")
+             target_samples[sample.id] = sample
+
+         logger.info(f"Loading {self.source_lang} samples")
+         for idx, sample in enumerate(
+             self.iterate_lang_audio_samples(lang=self.source_lang)
+         ):
+             if idx and idx % 100 == 0:
+                 logger.info(f"..loaded {idx} source samples")
+             if sample.id in target_samples:
+                 yield LangPairSample(source=sample, target=target_samples[sample.id])
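
For reference, here is a minimal usage sketch (not part of the commit) of how this builder might be driven. The "en_us" and "fr_fr" config names follow google/fleurs naming, and IdentityTokenizer is a hypothetical stand-in for a real unit tokenizer; only SpeechTokenizer, Speech2SpeechFleursDatasetBuilder, and the sample types come from the file above.

    # Hypothetical usage sketch; IdentityTokenizer is illustrative and not
    # part of seamless_communication. FLEURS config names are assumed.
    import torch

    from seamless_communication.datasets.huggingface import (
        SpeechTokenizer,
        Speech2SpeechFleursDatasetBuilder,
    )

    class IdentityTokenizer(SpeechTokenizer):
        # Placeholder: coarsely quantizes the waveform into integer ids.
        def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
            return (wav * 100).to(torch.int64)

    builder = Speech2SpeechFleursDatasetBuilder(
        source_lang="en_us",        # assumed FLEURS config name
        target_lang="fr_fr",        # assumed FLEURS config name
        split="test",
        skip_source_audio=False,    # keep waveforms so the tokenizer can run
        skip_target_audio=False,
        speech_tokenizer=IdentityTokenizer(),
    )

    for pair in builder:
        # LangPairSample pairs source/target MultimodalSamples by FLEURS id.
        print(pair.source.id, pair.source.text, "->", pair.target.text)
        break

Note that __iter__ materializes every target-language sample in memory before streaming the source side, so memory use scales with the size of the target split; pairing is done by FLEURS sample id.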