import amfm_decompy.basic_tools as basic
import amfm_decompy.pYAAPT as pYAAPT
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import dataclasses
import parselmouth
from transformers import PretrainedConfig, FeatureExtractionMixin
from datasets import Dataset
from scipy.signal import medfilt
import scipy.interpolate as scipy_interp

@dataclass
class SpeakerStats:
    f0_mean: float
    f0_std: float
    intensity_mean: float
    intensity_std: float
    
    @classmethod
    def from_features(cls, f0_values: List[np.ndarray], intensity_values: List[np.ndarray]):
        """Calculate stats from a list of features"""
        # Convert lists to numpy arrays
        f0_arrays = [np.array(f0) for f0 in f0_values]
        intensity_arrays = [np.array(i) for i in intensity_values]
        
        # Now we can use numpy operations
        f0_concat = np.concatenate([f0[f0 != 0] for f0 in f0_arrays])
        intensity_concat = np.concatenate(intensity_arrays)
        
        
        return cls(
            f0_mean=float(np.mean(f0_concat)),
            f0_std=float(np.std(f0_concat)),
            intensity_mean=float(np.mean(intensity_concat)),
            intensity_std=float(np.std(intensity_concat))
        )
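
# Illustrative sketch (assumption, not part of the original file): computing statistics
# for one speaker from per-utterance feature arrays. `f0_list` and `intensity_list`
# are hypothetical names.
#
#   f0_list = [np.array([0.0, 110.0, 112.0, 0.0]), np.array([108.0, 109.0])]
#   intensity_list = [np.array([55.0, 60.0, 62.0]), np.array([58.0, 61.0])]
#   stats = SpeakerStats.from_features(f0_list, intensity_list)  # zero (unvoiced) F0 frames are ignored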

class ProsodyConfig(PretrainedConfig):
    """Configuration class for prosody preprocessing"""
    model_type = "prosody_preprocessor"
    
    def __init__(
        self,
        sampling_rate: int = 16000,
        frame_length: float = 20.0,  # in ms
        frame_space: float = 5.0,   # in ms
        torch_dtype: str = "float32",  # Add default torch_dtype
        **kwargs
    ):
        super().__init__(torch_dtype=torch_dtype, **kwargs)  # Pass torch_dtype to parent
        self.sampling_rate = sampling_rate
        self.frame_length = frame_length
        self.frame_space = frame_space



class ProsodyPreprocessor(FeatureExtractionMixin):
    config_class = ProsodyConfig
    
    def __init__(self,
            sampling_rate: int = 16000,
            frame_length: float = 20.0,  # in ms
            frame_space: float = 5.0,    # in ms
            torch_dtype: str = "float32",
            config: Optional[ProsodyConfig] = None,
            **kwargs):
        super().__init__()
        # Build a config from the explicit arguments when none is supplied
        self.config = config or ProsodyConfig(
            sampling_rate=sampling_rate,
            frame_length=frame_length,
            frame_space=frame_space,
            torch_dtype=torch_dtype,
        )
        self.speaker_stats: Dict[str, SpeakerStats] = {}
        self.sampling_rate = sampling_rate
        self.frame_length = frame_length
        self.frame_space = frame_space
        
    def extract_features(self, audio):
        """Extract F0 and intensity features for a single utterance."""

        # Patch pYAAPT's interpolation with the fixed version defined at the bottom
        # of this file (the original mishandles frames before/after the voiced region).
        pYAAPT.PitchObj.interpolate = interpolate

        audio = torch.Tensor(audio)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        f0, f0_interp = self._get_f0(audio)
        f0 = f0[0, 0, :]
        f0_interpolated = f0_interp[0, 0, :]

        # Trim the leading frames so the F0 track lines up with the intensity track below
        f0 = f0[6:]
        f0_interpolated = f0_interpolated[6:]

        sound = parselmouth.Sound(audio.numpy(), sampling_frequency=self.sampling_rate, start_time=0)

        # Extract intensity at 200 Hz (5 ms steps, matching frame_space)
        intensity = sound.to_intensity(time_step=1/200.0)
        intensity_values = intensity.values.T.flatten()

        # Truncate all tracks to the same length
        min_len = min(len(f0), len(intensity_values))
        f0 = f0[:min_len]
        f0_interpolated = f0_interpolated[:min_len]
        intensity_values = intensity_values[:min_len]

        # Floor very quiet (near-silent) frames at 20 dB
        intensity_values[intensity_values < 20] = 20
        
        return {
            "f0": f0,
            "f0_interp": f0_interpolated,
            "intensity": intensity_values,
        }
        
    def collect_stats(self, dataset: Dataset, num_proc: int = 4, batch_size: int = 32) -> Tuple[Dataset, Dict[str, SpeakerStats]]:
        """First pass: extract features with dataset.map and collect per-speaker statistics.

        Returns the mapped features dataset together with the speaker stats.
        """

        def extract_features_batch(examples):
            features_list = []
            for audio in examples['audio']:
                features = self.extract_features(audio)
                features_list.append(features)
            
            return {
                'f0': [f['f0'] for f in features_list],
                'intensity': [f['intensity'] for f in features_list],
                'speaker_id': examples['speaker_id']
            }
        
        features_dataset = dataset.map(
            extract_features_batch,
            batched=True,
            batch_size=batch_size,
            num_proc=num_proc,
            # load_from_cache_file=False
            remove_columns=dataset.column_names
        )
        
   
        # Group the extracted features by speaker
        speaker_features = {}
        for item in features_dataset:
            speaker_id = item['speaker_id']
            if speaker_id not in speaker_features:
                speaker_features[speaker_id] = {'f0': [], 'intensity': []}
            
            speaker_features[speaker_id]['f0'].append(item['f0'])
            speaker_features[speaker_id]['intensity'].append(item['intensity'])
        
        self.speaker_stats = {
            spk: SpeakerStats.from_features(
                feats['f0'],
                feats['intensity']
            )
            for spk, feats in speaker_features.items()
        }
        
        return features_dataset, self.speaker_stats
    
    def save_stats(self, path: str):
        """Save speaker stats to file"""
        stats_dict = {
            spk: dataclasses.asdict(stats)
            for spk, stats in self.speaker_stats.items()
        }
        torch.save(stats_dict, path)
    
    @classmethod
    def load_stats(cls, path: str) -> Dict[str, SpeakerStats]:
        """Load speaker stats from file"""
        stats_dict = torch.load(path)
        return {
            spk: SpeakerStats(**stats)
            for spk, stats in stats_dict.items()
        }

    def _get_f0(self, audio: torch.Tensor):
        """Extract F0 using YAAPT."""
        to_pad = int(self.frame_length / 1000 * self.sampling_rate) // 2
        
        f0s = []
        f0s_interp = []
        
        for y in audio.numpy().astype(np.float64):
            y_pad = np.pad(y.squeeze(), (to_pad, to_pad), "constant", constant_values=0)
            signal = basic.SignalObj(y_pad, self.sampling_rate)
            pitch = pYAAPT.yaapt(
                signal,
                frame_length=self.frame_length,
                frame_space=self.frame_space,
                nccf_thresh1=0.25,
                tda_frame_length=25.0
            )
            f0s_interp.append(pitch.samp_interp[None, None, :])
            f0s.append(pitch.samp_values[None, None, :])
        
        f0 = np.vstack(f0s)
        f0_interp = np.vstack(f0s_interp)
        
        # Zero out implausible values: treat F0 above 500 Hz or below 0 Hz as unvoiced
        f0[f0 > 500] = 0
        f0_interp[f0_interp > 500] = 0
        f0[f0 < 0] = 0
        f0_interp[f0_interp < 0] = 0
        
        return f0, f0_interp

    # def save_pretrained(self, save_directory: str, **kwargs):
    #     """Save the preprocessor configuration."""
    #     self.config.save_pretrained(save_directory)
    #
    # def _load_pretrained_model(self, **kwargs):
    #     """Override _load_pretrained_model to load speaker stats"""
    #     # self.speaker_stats = {
    #     #     spk: SpeakerStats(**stats)
    #     #     for spk, stats in state_dict.items()
    #     # }


def interpolate(self):
    """Patched replacement for pYAAPT.PitchObj.interpolate.

    Fills unvoiced (zero) frames of the raw F0 track by interpolation and fixes
    the original's handling of the frames before the first and after the last
    voiced frame.
    """
    pitch = np.zeros(self.nframes)
    pitch[:] = self.samp_values
    pitch2 = medfilt(self.samp_values, self.SMOOTH_FACTOR)

    # The corresponding part of the original pYAAPT code mishandles the
    # extrapolated points before the first voiced frame and after the last
    # voiced frame; the small modifications below fix that behaviour.
    edges = self.edges_finder(pitch)
    first_sample = pitch[0]
    last_sample = pitch[-1]

    if len(np.nonzero(pitch2)[0]) < 2:
        pitch[pitch == 0] = self.PTCH_TYP
    else:
        nz_pitch = pitch2[pitch2 > 0]
        pitch2 = scipy_interp.pchip(np.nonzero(pitch2)[0],
                                    nz_pitch)(range(self.nframes))
        pitch[pitch == 0] = pitch2[pitch == 0]
    if self.SMOOTH > 0:
        pitch = medfilt(pitch, self.SMOOTH_FACTOR)
    try:
        if first_sample == 0:
            # Guard: edges[0] == 0 previously caused the whole F0 contour to be flattened
            if edges[0] == 0:
                edges[0] = 1
            pitch[:edges[0] - 1] = pitch[edges[0]]
        if last_sample == 0:
            pitch[edges[-1] + 1:] = pitch[edges[-1]]
    except IndexError:
        # No voiced edges were found; leave the contour unchanged
        pass
    self.samp_interp = pitch
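

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). It builds a tiny
# in-memory `datasets.Dataset` with synthetic audio and runs the two-pass
# statistics collection. The dataset contents, speaker ids and output path
# below are all made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sr = 16000
    t = np.arange(sr) / sr                             # one second of audio
    demo_audio = 0.1 * np.sin(2 * np.pi * 150.0 * t)   # a 150 Hz tone stands in for speech

    demo_dataset = Dataset.from_dict({
        "audio": [demo_audio.tolist(), demo_audio.tolist()],
        "speaker_id": ["spk_0", "spk_0"],
    })

    preprocessor = ProsodyPreprocessor(sampling_rate=sr)
    features_dataset, speaker_stats = preprocessor.collect_stats(
        demo_dataset, num_proc=1, batch_size=2
    )
    print(speaker_stats)

    preprocessor.save_stats("speaker_stats.pt")  # hypothetical output path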