LanguageBind committed
Commit 7443476
1 Parent(s): 1f7a501

Update languagebind/audio/processing_audio.py

languagebind/audio/processing_audio.py CHANGED
@@ -29,28 +29,31 @@ def float32_to_int16_torch(x):
 DEFAULT_AUDIO_FRAME_SHIFT_MS = 10
 
 class AudioTransform:
-    def __init__(self, config):
-        self.sample_rate = config.audio_sample_rate
-        self.num_mel_bins = config.num_mel_bins
-        self.target_length = config.target_length
-        self.audio_mean = config.audio_mean
-        self.audio_std = config.audio_std
+    def __init__(self, args):
+        self.sample_rate = args.audio_sample_rate
+        self.num_mel_bins = args.num_mel_bins
+        self.target_length = args.target_length
+        self.audio_mean = args.audio_mean
+        self.audio_std = args.audio_std
+        self.mean = []
+        self.std = []
         # mean=-4.2677393
         # std=4.5689974
-        self.norm = transforms.Normalize(mean=self.audio_mean, std=self.audio_std)
+        # self.norm = transforms.Normalize(mean=self.audio_mean, std=self.audio_std)
+
 
     def __call__(self, audio_data_and_origin_sr):
         audio_data, origin_sr = audio_data_and_origin_sr
         if self.sample_rate != origin_sr:
             # print(audio_data.shape, origin_sr)
             audio_data = torchaudio.functional.resample(audio_data, orig_freq=origin_sr, new_freq=self.sample_rate)
-        waveform_melspec = self.waveform2melspec(audio_data[0])
-        return self.norm(waveform_melspec)
+        waveform_melspec = self.waveform2melspec(audio_data)
+        return waveform_melspec
+
 
     def waveform2melspec(self, audio_data):
-        max_len = self.target_length * self.sample_rate // 100
-        if audio_data.shape[-1] > max_len:
-            mel = self.get_mel(audio_data)
+        mel = self.get_mel(audio_data)
+        if mel.shape[0] > self.target_length:
             # split to three parts
             chunk_frames = self.target_length
             total_frames = mel.shape[0]
@@ -64,60 +67,38 @@ class AudioTransform:
             if len(ranges[2]) == 0:  # if the audio is too short, we just use the first chunk
                 ranges[2] = [0]
             # randomly choose index for each part
-            # idx_front = np.random.choice(ranges[0])
-            # idx_middle = np.random.choice(ranges[1])
-            # idx_back = np.random.choice(ranges[2])
-            idx_front = ranges[0][0]  # fixed
-            idx_middle = ranges[1][0]
-            idx_back = ranges[2][0]
+            idx_front = np.random.choice(ranges[0])
+            idx_middle = np.random.choice(ranges[1])
+            idx_back = np.random.choice(ranges[2])
+            # idx_front = ranges[0][0]  # fixed
+            # idx_middle = ranges[1][0]
+            # idx_back = ranges[2][0]
             # select mel
             mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
             mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
             mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
+            # print(total_frames, idx_front, idx_front + chunk_frames, idx_middle, idx_middle + chunk_frames, idx_back, idx_back + chunk_frames)
             # stack
             mel_fusion = torch.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
-        elif audio_data.shape[-1] < max_len:  # padding if too short
-            n_repeat = int(max_len / len(audio_data))
-            audio_data = audio_data.repeat(n_repeat)
-            audio_data = F.pad(
-                audio_data,
-                (0, max_len - len(audio_data)),
-                mode="constant",
-                value=0,
-            )
-            mel = self.get_mel(audio_data)
+        elif mel.shape[0] < self.target_length:  # padding if too short
+            n_repeat = int(self.target_length / mel.shape[0]) + 1
+            # print(self.target_length, mel.shape[0], n_repeat)
+            mel = mel.repeat(n_repeat, 1)[:self.target_length, :]
             mel_fusion = torch.stack([mel, mel, mel], dim=0)
         else:  # if equal
-            mel = self.get_mel(audio_data)
             mel_fusion = torch.stack([mel, mel, mel], dim=0)
-
-        # twice check
-        p = self.target_length - mel_fusion.shape[1]
-
-        # if abs(p) / self.target_length > 0.2:
-        #     logging.warning(
-        #         "Large gap between audio n_frames(%d) and "
-        #         "target_length (%d). Is the audio_target_length "
-        #         "setting correct?",
-        #         mel_fusion.shape[1],
-        #         self.target_length,
-        #     )
-
-        # cut and pad
-        if p > 0:
-            m = torch.nn.ZeroPad2d((0, 0, 0, p))
-            mel_fusion = m(mel_fusion)
-        elif p < 0:
-            mel_fusion = mel_fusion[:, 0: self.target_length, :]
-
         mel_fusion = mel_fusion.transpose(1, 2)  # [3, target_length, mel_bins] -> [3, mel_bins, target_length]
+
+        # self.mean.append(mel_fusion.mean())
+        # self.std.append(mel_fusion.std())
+        mel_fusion = (mel_fusion - self.audio_mean) / (self.audio_std * 2)
         return mel_fusion
 
     def get_mel(self, audio_data):
         # mel shape: (n_mels, T)
         audio_data -= audio_data.mean()
         mel = torchaudio.compliance.kaldi.fbank(
-            audio_data.unsqueeze(0),
+            audio_data,
             htk_compat=True,
             sample_frequency=self.sample_rate,
             use_energy=False,
@@ -129,6 +110,7 @@ class AudioTransform:
         )
         return mel  # (T, n_mels)
 
+
 def get_audio_transform(config):
     config = config.vision_config
     return AudioTransform(config)
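
For reference, below is a minimal sketch (not the repository code) of what the updated pipeline computes end to end: resample, Kaldi-compatible fbank, chunk or tile to `target_length` frames in mel space, stack three chunks, and normalize with `(x - mean) / (2 * std)`. The helper name `preprocess_audio` is hypothetical, the config fields mirror the ones read in `__init__`, and the `fbank` keyword arguments and the `ranges` construction that the hunks elide are assumptions.

```python
# Hedged sketch of the post-commit preprocessing flow; names flagged below are assumptions.
import numpy as np
import torch
import torchaudio

DEFAULT_AUDIO_FRAME_SHIFT_MS = 10


def preprocess_audio(waveform, origin_sr, cfg):  # hypothetical helper, not in the repo
    """waveform: (1, num_samples) float tensor; cfg mirrors the fields read in __init__."""
    # Resample to the configured rate, as in __call__.
    if origin_sr != cfg.audio_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=origin_sr, new_freq=cfg.audio_sample_rate
        )

    # Kaldi-compatible log-mel filterbank, as in get_mel -> (T, num_mel_bins).
    waveform = waveform - waveform.mean()
    mel = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=cfg.audio_sample_rate,
        use_energy=False,
        # the remaining kwargs are elided in the diff hunks; these are assumptions
        window_type="hanning",
        num_mel_bins=cfg.num_mel_bins,
        dither=0.0,
        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
    )

    # Fit to target_length frames: the commit moves this decision into mel-frame space.
    total_frames, chunk_frames = mel.shape[0], cfg.target_length
    if total_frames > chunk_frames:
        # One random start index per third of the clip (the ranges construction is
        # outside the shown hunks; np.array_split is an approximation of it).
        ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
        ranges = [r if len(r) > 0 else [0] for r in ranges]
        idxs = [int(np.random.choice(r)) for r in ranges]
        mel_fusion = torch.stack([mel[i:i + chunk_frames, :] for i in idxs], dim=0)
    else:
        if total_frames < chunk_frames:
            # Tile the spectrogram instead of zero-padding the raw waveform.
            n_repeat = chunk_frames // total_frames + 1
            mel = mel.repeat(n_repeat, 1)[:chunk_frames, :]
        mel_fusion = torch.stack([mel, mel, mel], dim=0)

    # (3, target_length, mel_bins) -> (3, mel_bins, target_length), then normalize.
    mel_fusion = mel_fusion.transpose(1, 2)
    return (mel_fusion - cfg.audio_mean) / (cfg.audio_std * 2)
```

Compared with the pre-commit version, which chunked or zero-padded the raw waveform and applied `transforms.Normalize`, every branch now ends with the same explicit `(x - mean) / (2 * std)` normalization, so the separate "twice check" cut-and-pad step is no longer needed.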