MarcusSu1216 commited on
Commit
14ddef7
1 Parent(s): b5bb8c2

Delete modules/crepe.py

Browse files
Files changed (1) hide show
  1. modules/crepe.py +0 -327
modules/crepe.py DELETED
@@ -1,327 +0,0 @@
1
- from typing import Optional,Union
2
- try:
3
- from typing import Literal
4
- except Exception as e:
5
- from typing_extensions import Literal
6
- import numpy as np
7
- import torch
8
- import torchcrepe
9
- from torch import nn
10
- from torch.nn import functional as F
11
- import scipy
12
-
13
- #from:https://github.com/fishaudio/fish-diffusion
14
-
15
- def repeat_expand(
16
- content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
17
- ):
18
- """Repeat content to target length.
19
- This is a wrapper of torch.nn.functional.interpolate.
20
-
21
- Args:
22
- content (torch.Tensor): tensor
23
- target_len (int): target length
24
- mode (str, optional): interpolation mode. Defaults to "nearest".
25
-
26
- Returns:
27
- torch.Tensor: tensor
28
- """
29
-
30
- ndim = content.ndim
31
-
32
- if content.ndim == 1:
33
- content = content[None, None]
34
- elif content.ndim == 2:
35
- content = content[None]
36
-
37
- assert content.ndim == 3
38
-
39
- is_np = isinstance(content, np.ndarray)
40
- if is_np:
41
- content = torch.from_numpy(content)
42
-
43
- results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
44
-
45
- if is_np:
46
- results = results.numpy()
47
-
48
- if ndim == 1:
49
- return results[0, 0]
50
- elif ndim == 2:
51
- return results[0]
52
-
53
-
54
- class BasePitchExtractor:
55
- def __init__(
56
- self,
57
- hop_length: int = 512,
58
- f0_min: float = 50.0,
59
- f0_max: float = 1100.0,
60
- keep_zeros: bool = True,
61
- ):
62
- """Base pitch extractor.
63
-
64
- Args:
65
- hop_length (int, optional): Hop length. Defaults to 512.
66
- f0_min (float, optional): Minimum f0. Defaults to 50.0.
67
- f0_max (float, optional): Maximum f0. Defaults to 1100.0.
68
- keep_zeros (bool, optional): Whether keep zeros in pitch. Defaults to True.
69
- """
70
-
71
- self.hop_length = hop_length
72
- self.f0_min = f0_min
73
- self.f0_max = f0_max
74
- self.keep_zeros = keep_zeros
75
-
76
- def __call__(self, x, sampling_rate=44100, pad_to=None):
77
- raise NotImplementedError("BasePitchExtractor is not callable.")
78
-
79
- def post_process(self, x, sampling_rate, f0, pad_to):
80
- if isinstance(f0, np.ndarray):
81
- f0 = torch.from_numpy(f0).float().to(x.device)
82
-
83
- if pad_to is None:
84
- return f0
85
-
86
- f0 = repeat_expand(f0, pad_to)
87
-
88
- if self.keep_zeros:
89
- return f0
90
-
91
- vuv_vector = torch.zeros_like(f0)
92
- vuv_vector[f0 > 0.0] = 1.0
93
- vuv_vector[f0 <= 0.0] = 0.0
94
-
95
- # 去掉0频率, 并线性插值
96
- nzindex = torch.nonzero(f0).squeeze()
97
- f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
98
- time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
99
- time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
100
-
101
- if f0.shape[0] <= 0:
102
- return torch.zeros(pad_to, dtype=torch.float, device=x.device),torch.zeros(pad_to, dtype=torch.float, device=x.device)
103
-
104
- if f0.shape[0] == 1:
105
- return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],torch.ones(pad_to, dtype=torch.float, device=x.device)
106
-
107
- # 大概可以用 torch 重写?
108
- f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
109
- vuv_vector = vuv_vector.cpu().numpy()
110
- vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
111
-
112
- return f0,vuv_vector
113
-
114
-
115
- class MaskedAvgPool1d(nn.Module):
116
- def __init__(
117
- self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0
118
- ):
119
- """An implementation of mean pooling that supports masked values.
120
-
121
- Args:
122
- kernel_size (int): The size of the median pooling window.
123
- stride (int, optional): The stride of the median pooling window. Defaults to None.
124
- padding (int, optional): The padding of the median pooling window. Defaults to 0.
125
- """
126
-
127
- super(MaskedAvgPool1d, self).__init__()
128
- self.kernel_size = kernel_size
129
- self.stride = stride or kernel_size
130
- self.padding = padding
131
-
132
- def forward(self, x, mask=None):
133
- ndim = x.dim()
134
- if ndim == 2:
135
- x = x.unsqueeze(1)
136
-
137
- assert (
138
- x.dim() == 3
139
- ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)"
140
-
141
- # Apply the mask by setting masked elements to zero, or make NaNs zero
142
- if mask is None:
143
- mask = ~torch.isnan(x)
144
-
145
- # Ensure mask has the same shape as the input tensor
146
- assert x.shape == mask.shape, "Input tensor and mask must have the same shape"
147
-
148
- masked_x = torch.where(mask, x, torch.zeros_like(x))
149
- # Create a ones kernel with the same number of channels as the input tensor
150
- ones_kernel = torch.ones(x.size(1), 1, self.kernel_size, device=x.device)
151
-
152
- # Perform sum pooling
153
- sum_pooled = nn.functional.conv1d(
154
- masked_x,
155
- ones_kernel,
156
- stride=self.stride,
157
- padding=self.padding,
158
- groups=x.size(1),
159
- )
160
-
161
- # Count the non-masked (valid) elements in each pooling window
162
- valid_count = nn.functional.conv1d(
163
- mask.float(),
164
- ones_kernel,
165
- stride=self.stride,
166
- padding=self.padding,
167
- groups=x.size(1),
168
- )
169
- valid_count = valid_count.clamp(min=1) # Avoid division by zero
170
-
171
- # Perform masked average pooling
172
- avg_pooled = sum_pooled / valid_count
173
-
174
- # Fill zero values with NaNs
175
- avg_pooled[avg_pooled == 0] = float("nan")
176
-
177
- if ndim == 2:
178
- return avg_pooled.squeeze(1)
179
-
180
- return avg_pooled
181
-
182
-
183
- class MaskedMedianPool1d(nn.Module):
184
- def __init__(
185
- self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0
186
- ):
187
- """An implementation of median pooling that supports masked values.
188
-
189
- This implementation is inspired by the median pooling implementation in
190
- https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598
191
-
192
- Args:
193
- kernel_size (int): The size of the median pooling window.
194
- stride (int, optional): The stride of the median pooling window. Defaults to None.
195
- padding (int, optional): The padding of the median pooling window. Defaults to 0.
196
- """
197
-
198
- super(MaskedMedianPool1d, self).__init__()
199
- self.kernel_size = kernel_size
200
- self.stride = stride or kernel_size
201
- self.padding = padding
202
-
203
- def forward(self, x, mask=None):
204
- ndim = x.dim()
205
- if ndim == 2:
206
- x = x.unsqueeze(1)
207
-
208
- assert (
209
- x.dim() == 3
210
- ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)"
211
-
212
- if mask is None:
213
- mask = ~torch.isnan(x)
214
-
215
- assert x.shape == mask.shape, "Input tensor and mask must have the same shape"
216
-
217
- masked_x = torch.where(mask, x, torch.zeros_like(x))
218
-
219
- x = F.pad(masked_x, (self.padding, self.padding), mode="reflect")
220
- mask = F.pad(
221
- mask.float(), (self.padding, self.padding), mode="constant", value=0
222
- )
223
-
224
- x = x.unfold(2, self.kernel_size, self.stride)
225
- mask = mask.unfold(2, self.kernel_size, self.stride)
226
-
227
- x = x.contiguous().view(x.size()[:3] + (-1,))
228
- mask = mask.contiguous().view(mask.size()[:3] + (-1,)).to(x.device)
229
-
230
- # Combine the mask with the input tensor
231
- #x_masked = torch.where(mask.bool(), x, torch.fill_(torch.zeros_like(x),float("inf")))
232
- x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")]).to(x.device))
233
-
234
- # Sort the masked tensor along the last dimension
235
- x_sorted, _ = torch.sort(x_masked, dim=-1)
236
-
237
- # Compute the count of non-masked (valid) values
238
- valid_count = mask.sum(dim=-1)
239
-
240
- # Calculate the index of the median value for each pooling window
241
- median_idx = (torch.div((valid_count - 1), 2, rounding_mode='trunc')).clamp(min=0)
242
-
243
- # Gather the median values using the calculated indices
244
- median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1)
245
-
246
- # Fill infinite values with NaNs
247
- median_pooled[torch.isinf(median_pooled)] = float("nan")
248
-
249
- if ndim == 2:
250
- return median_pooled.squeeze(1)
251
-
252
- return median_pooled
253
-
254
-
255
- class CrepePitchExtractor(BasePitchExtractor):
256
- def __init__(
257
- self,
258
- hop_length: int = 512,
259
- f0_min: float = 50.0,
260
- f0_max: float = 1100.0,
261
- threshold: float = 0.05,
262
- keep_zeros: bool = False,
263
- device = None,
264
- model: Literal["full", "tiny"] = "full",
265
- use_fast_filters: bool = True,
266
- ):
267
- super().__init__(hop_length, f0_min, f0_max, keep_zeros)
268
-
269
- self.threshold = threshold
270
- self.model = model
271
- self.use_fast_filters = use_fast_filters
272
- self.hop_length = hop_length
273
- if device is None:
274
- self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
275
- else:
276
- self.dev = torch.device(device)
277
- if self.use_fast_filters:
278
- self.median_filter = MaskedMedianPool1d(3, 1, 1).to(device)
279
- self.mean_filter = MaskedAvgPool1d(3, 1, 1).to(device)
280
-
281
- def __call__(self, x, sampling_rate=44100, pad_to=None):
282
- """Extract pitch using crepe.
283
-
284
-
285
- Args:
286
- x (torch.Tensor): Audio signal, shape (1, T).
287
- sampling_rate (int, optional): Sampling rate. Defaults to 44100.
288
- pad_to (int, optional): Pad to length. Defaults to None.
289
-
290
- Returns:
291
- torch.Tensor: Pitch, shape (T // hop_length,).
292
- """
293
-
294
- assert x.ndim == 2, f"Expected 2D tensor, got {x.ndim}D tensor."
295
- assert x.shape[0] == 1, f"Expected 1 channel, got {x.shape[0]} channels."
296
-
297
- x = x.to(self.dev)
298
- f0, pd = torchcrepe.predict(
299
- x,
300
- sampling_rate,
301
- self.hop_length,
302
- self.f0_min,
303
- self.f0_max,
304
- pad=True,
305
- model=self.model,
306
- batch_size=1024,
307
- device=x.device,
308
- return_periodicity=True,
309
- )
310
-
311
- # Filter, remove silence, set uv threshold, refer to the original warehouse readme
312
- if self.use_fast_filters:
313
- pd = self.median_filter(pd)
314
- else:
315
- pd = torchcrepe.filter.median(pd, 3)
316
-
317
- pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, 512)
318
- f0 = torchcrepe.threshold.At(self.threshold)(f0, pd)
319
-
320
- if self.use_fast_filters:
321
- f0 = self.mean_filter(f0)
322
- else:
323
- f0 = torchcrepe.filter.mean(f0, 3)
324
-
325
- f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)[0]
326
-
327
- return self.post_process(x, sampling_rate, f0, pad_to)