Spaces: LanguageBind (Running)
LanguageBind committed
Commit 7443476 • Parent(s): 1f7a501
Update languagebind/audio/processing_audio.py
languagebind/audio/processing_audio.py
CHANGED
@@ -29,28 +29,31 @@ def float32_to_int16_torch(x):
 DEFAULT_AUDIO_FRAME_SHIFT_MS = 10
 
 class AudioTransform:
-    def __init__(self,
-        self.sample_rate =
-        self.num_mel_bins =
-        self.target_length =
-        self.audio_mean =
-        self.audio_std =
+    def __init__(self, args):
+        self.sample_rate = args.audio_sample_rate
+        self.num_mel_bins = args.num_mel_bins
+        self.target_length = args.target_length
+        self.audio_mean = args.audio_mean
+        self.audio_std = args.audio_std
+        self.mean = []
+        self.std = []
         # mean=-4.2677393
         # std=4.5689974
-        self.norm = transforms.Normalize(mean=self.audio_mean, std=self.audio_std)
+        # self.norm = transforms.Normalize(mean=self.audio_mean, std=self.audio_std)
+
 
     def __call__(self, audio_data_and_origin_sr):
         audio_data, origin_sr = audio_data_and_origin_sr
         if self.sample_rate != origin_sr:
             # print(audio_data.shape, origin_sr)
             audio_data = torchaudio.functional.resample(audio_data, orig_freq=origin_sr, new_freq=self.sample_rate)
-        waveform_melspec = self.waveform2melspec(audio_data
-        return
+        waveform_melspec = self.waveform2melspec(audio_data)
+        return waveform_melspec
+
 
     def waveform2melspec(self, audio_data):
-
-        if
-            mel = self.get_mel(audio_data)
+        mel = self.get_mel(audio_data)
+        if mel.shape[0] > self.target_length:
             # split to three parts
             chunk_frames = self.target_length
             total_frames = mel.shape[0]
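The new constructor reads every preprocessing knob from `args` and comments out the `transforms.Normalize` step, so `__call__` now returns the fused mel spectrogram directly (normalization moves into `waveform2melspec`, see the second hunk). A minimal driver sketch, assuming `AudioTransform` from this file is in scope; the attribute names match the diff, but the sample rate, mel-bin count, and target length below are illustrative placeholders, and only the mean/std values come from the inline comments:

from types import SimpleNamespace

import torchaudio

# Hypothetical config object: names follow the diff, numbers are assumptions
# except audio_mean/audio_std, taken from the commented-out values above.
args = SimpleNamespace(
    audio_sample_rate=16000,
    num_mel_bins=112,
    target_length=1036,
    audio_mean=-4.2677393,
    audio_std=4.5689974,
)
transform = AudioTransform(args)

waveform, sr = torchaudio.load("clip.wav")  # hypothetical mono file, shape (1, n)
mel_fusion = transform((waveform, sr))      # resample -> fbank -> 3-chunk fusion
print(mel_fusion.shape)                     # [3, num_mel_bins, target_length]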
@@ -64,60 +67,38 @@
         if len(ranges[2]) == 0:  # if the audio is too short, we just use the first chunk
             ranges[2] = [0]
         # randomly choose index for each part
-
-
-
-            idx_front = ranges[0][0]  # fixed
-            idx_middle = ranges[1][0]
-            idx_back = ranges[2][0]
+            idx_front = np.random.choice(ranges[0])
+            idx_middle = np.random.choice(ranges[1])
+            idx_back = np.random.choice(ranges[2])
+            # idx_front = ranges[0][0]  # fixed
+            # idx_middle = ranges[1][0]
+            # idx_back = ranges[2][0]
             # select mel
             mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
             mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
             mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
+            # print(total_frames, idx_front, idx_front + chunk_frames, idx_middle, idx_middle + chunk_frames, idx_back, idx_back + chunk_frames)
             # stack
             mel_fusion = torch.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
-        elif
-            n_repeat = int(
-
-
-                audio_data,
-                (0, max_len - len(audio_data)),
-                mode="constant",
-                value=0,
-            )
-            mel = self.get_mel(audio_data)
+        elif mel.shape[0] < self.target_length:  # padding if too short
+            n_repeat = int(self.target_length / mel.shape[0]) + 1
+            # print(self.target_length, mel.shape[0], n_repeat)
+            mel = mel.repeat(n_repeat, 1)[:self.target_length, :]
             mel_fusion = torch.stack([mel, mel, mel], dim=0)
         else:  # if equal
-            mel = self.get_mel(audio_data)
             mel_fusion = torch.stack([mel, mel, mel], dim=0)
-
-        # twice check
-        p = self.target_length - mel_fusion.shape[1]
-
-        # if abs(p) / self.target_length > 0.2:
-        #     logging.warning(
-        #         "Large gap between audio n_frames(%d) and "
-        #         "target_length (%d). Is the audio_target_length "
-        #         "setting correct?",
-        #         mel_fusion.shape[1],
-        #         self.target_length,
-        #     )
-
-        # cut and pad
-        if p > 0:
-            m = torch.nn.ZeroPad2d((0, 0, 0, p))
-            mel_fusion = m(mel_fusion)
-        elif p < 0:
-            mel_fusion = mel_fusion[:, 0: self.target_length, :]
-
         mel_fusion = mel_fusion.transpose(1, 2)  # [3, target_length, mel_bins] -> [3, mel_bins, target_length]
+
+        # self.mean.append(mel_fusion.mean())
+        # self.std.append(mel_fusion.std())
+        mel_fusion = (mel_fusion - self.audio_mean) / (self.audio_std * 2)
         return mel_fusion
 
     def get_mel(self, audio_data):
         # mel shape: (n_mels, T)
         audio_data -= audio_data.mean()
         mel = torchaudio.compliance.kaldi.fbank(
-            audio_data
+            audio_data,
             htk_compat=True,
             sample_frequency=self.sample_rate,
             use_energy=False,
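This hunk carries the behavioral changes: chunk start indices are now sampled with `np.random.choice` instead of being fixed to the first offset of each range, short clips are tiled at the mel level rather than zero-padded at the waveform level (which also removes the old "twice check" cut-and-pad pass), and normalization happens inline as `(mel_fusion - mean) / (std * 2)`. A standalone sketch of the new chunking rules on toy tensors; note the `np.array_split` line is an assumption, since the range construction sits in context lines the hunk does not show:

import numpy as np
import torch

target_length = 8
mel = torch.randn(19, 4)  # toy (T, n_mels) fbank, longer than target_length

# Assumed range construction: split the valid start offsets into three parts.
ranges = np.array_split(list(range(0, mel.shape[0] - target_length + 1)), 3)
idx_front = np.random.choice(ranges[0])   # was ranges[0][0] before this commit
idx_middle = np.random.choice(ranges[1])
idx_back = np.random.choice(ranges[2])
mel_fusion = torch.stack([
    mel[idx_front:idx_front + target_length, :],
    mel[idx_middle:idx_middle + target_length, :],
    mel[idx_back:idx_back + target_length, :],
], dim=0)                                 # -> [3, target_length, n_mels]

short = torch.randn(3, 4)                 # shorter than target_length
n_repeat = int(target_length / short.shape[0]) + 1
short = short.repeat(n_repeat, 1)[:target_length, :]  # tile rows, then crop

Dividing by `self.audio_std * 2` halves the effective scale relative to standard normalization; the same 2×std convention appears in CLAP-style fbank pipelines, which this code resembles.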
@@ -129,6 +110,7 @@ class AudioTransform:
         )
         return mel  # (T, n_mels)
 
+
 def get_audio_transform(config):
     config = config.vision_config
     return AudioTransform(config)
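For reference, a sketch of the `torchaudio.compliance.kaldi.fbank` call that `get_mel` wraps. Only `audio_data`, `htk_compat`, `sample_frequency`, and `use_energy` are visible in the hunk; the remaining arguments below are assumptions consistent with the class attributes and `DEFAULT_AUDIO_FRAME_SHIFT_MS`:

import torch
import torchaudio

waveform = torch.randn(1, 16000)       # one second of fake mono audio at 16 kHz
waveform = waveform - waveform.mean()  # DC removal, as in get_mel
mel = torchaudio.compliance.kaldi.fbank(
    waveform,
    htk_compat=True,
    sample_frequency=16000,
    use_energy=False,
    window_type="hanning",  # assumption; hidden in the unshown diff context
    num_mel_bins=112,       # assumption; presumably self.num_mel_bins
    frame_shift=10,         # ms, matching DEFAULT_AUDIO_FRAME_SHIFT_MS
)
print(mel.shape)  # (T, num_mel_bins); about 98 frames for 1 s at a 10 ms shift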