alefiury committed on
Commit
7a28165
1 Parent(s): b3117af

Update README.md


Update of the example inference code

Files changed (1)
README.md +29 -25
README.md CHANGED
@@ -26,6 +26,8 @@ It achieves the following results on the evaluation set:
 
 ```python
 import os
+import random
+from glob import glob
 from typing import List, Optional, Union, Dict
 
 import tqdm
@@ -42,7 +44,6 @@ from transformers import (
     Wav2Vec2Processor
 )
 
-
 class CustomDataset(torch.utils.data.Dataset):
     def __init__(
         self,
@@ -68,19 +69,19 @@ class CustomDataset(torch.utils.data.Dataset):
             filepath = self.dataset[index]
         else:
             filepath = os.path.join(self.basedir, self.dataset[index])
-
+
         speech_array, sr = torchaudio.load(filepath)
-
+
         if speech_array.shape[0] > 1:
             speech_array = torch.mean(speech_array, dim=0, keepdim=True)
-
+
         if sr != self.sampling_rate:
             transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
             speech_array = transform(speech_array)
             sr = self.sampling_rate
-
+
         len_audio = speech_array.shape[1]
-
+
         # Pad or truncate the audio to match the desired length
         if len_audio < self.max_audio_len * self.sampling_rate:
             # Pad the audio if it's shorter than the desired length
@@ -89,9 +90,9 @@ class CustomDataset(torch.utils.data.Dataset):
         else:
             # Truncate the audio if it's longer than the desired length
             speech_array = speech_array[:, :self.max_audio_len * self.sampling_rate]
-
+
         speech_array = speech_array.squeeze().numpy()
-
+
         return {"input_values": speech_array, "attention_mask": None}
 
 
@@ -99,34 +100,37 @@ class CollateFunc:
     def __init__(
         self,
         processor: Wav2Vec2Processor,
-        max_length: Optional[int] = None,
         padding: Union[bool, str] = True,
         pad_to_multiple_of: Optional[int] = None,
+        return_attention_mask: bool = True,
         sampling_rate: int = 16000,
+        max_length: Optional[int] = None,
     ):
-        self.padding = padding
-        self.processor = processor
-        self.max_length = max_length
         self.sampling_rate = sampling_rate
+        self.processor = processor
+        self.padding = padding
         self.pad_to_multiple_of = pad_to_multiple_of
+        self.return_attention_mask = return_attention_mask
+        self.max_length = max_length
 
-    def __call__(self, batch: List):
-        input_features = []
-
-        for audio in batch:
-            input_tensor = self.processor(audio, sampling_rate=self.sampling_rate).input_values
-            input_tensor = np.squeeze(input_tensor)
-            input_features.append({"input_values": input_tensor})
+    def __call__(self, batch: List[Dict[str, np.ndarray]]):
+        # Extract input_values from the batch
+        input_values = [item["input_values"] for item in batch]
 
-        batch = self.processor.pad(
-            input_features,
+        batch = self.processor(
+            input_values,
+            sampling_rate=self.sampling_rate,
+            return_tensors="pt",
             padding=self.padding,
             max_length=self.max_length,
             pad_to_multiple_of=self.pad_to_multiple_of,
-            return_tensors="pt",
+            return_attention_mask=self.return_attention_mask
         )
 
-        return batch
+        return {
+            "input_values": batch.input_values,
+            "attention_mask": batch.attention_mask if self.return_attention_mask else None
+        }
 
 
 def predict(test_dataloader, model, device: torch.device):
@@ -175,15 +179,15 @@ def get_gender(model_name_or_path: str, audio_paths: List[str], label2id: Dict,
         batch_size=16,
         collate_fn=data_collator,
         shuffle=False,
-        num_workers=10
+        num_workers=2
     )
 
     preds = predict(test_dataloader=test_dataloader, model=model, device=device)
 
     return preds
 
-
 model_name_or_path = "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech"
+
 audio_paths = [] # Must be a list with absolute paths of the audios that will be used in inference
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
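The dataset's pad-or-truncate step can be sanity-checked in isolation. A minimal sketch, assuming `max_audio_len` is in seconds (as the multiplication by `sampling_rate` implies); the zero-padding call is an assumption, since the padding branch itself falls outside this diff:

```python
import torch

sampling_rate = 16000
max_audio_len = 5  # seconds; hypothetical value for illustration
target_len = max_audio_len * sampling_rate

# Dummy 7-second stereo clip standing in for torchaudio.load() output
speech_array = torch.randn(2, 7 * sampling_rate)

# Downmix to mono, as in the dataset code above
if speech_array.shape[0] > 1:
    speech_array = torch.mean(speech_array, dim=0, keepdim=True)

len_audio = speech_array.shape[1]
if len_audio < target_len:
    # Zero-pad on the right (assumed; the README's padding call is outside this hunk)
    speech_array = torch.nn.functional.pad(speech_array, (0, target_len - len_audio))
else:
    # Keep only the first max_audio_len seconds
    speech_array = speech_array[:, :target_len]

assert speech_array.shape == (1, target_len)
```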
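The collator rewrite is the core of this commit: instead of running the processor once per clip and then calling `processor.pad` on the pre-extracted features, `__call__` now feeds the raw batch to the processor in a single pass and requests the attention mask directly. A minimal sketch of the new behavior on dummy waveforms; loading the processor from this repo assumes it ships a preprocessor config, as the README's code implies:

```python
import numpy as np
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained(
    "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech"
)

# Two dummy mono clips of different lengths
raw = [
    np.random.randn(16000).astype(np.float32),
    np.random.randn(24000).astype(np.float32),
]

# One processor call on the whole batch: padding=True pads to the
# longest clip and return_attention_mask marks the padded positions
enc = processor(
    raw,
    sampling_rate=16000,
    return_tensors="pt",
    padding=True,
    return_attention_mask=True,
)
print(enc.input_values.shape)    # torch.Size([2, 24000])
print(enc.attention_mask.shape)  # torch.Size([2, 24000])
```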
 
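Finally, an end-to-end usage sketch of the updated example, assuming the full README script above is in scope. The `label2id`/`id2label` mapping and the parameters beyond `label2id` are assumptions for illustration, since `get_gender`'s signature is truncated in the last hunk header; check the model card for the actual mapping:

```python
# Assumes the README script above (CustomDataset, CollateFunc, predict,
# get_gender) has been executed in the current session.
label2id = {"female": 0, "male": 1}  # assumed mapping; verify against the model card
id2label = {0: "female", 1: "male"}

audio_paths = [
    "/abs/path/to/clip1.wav",  # hypothetical files
    "/abs/path/to/clip2.wav",
]

# Parameters beyond label2id are assumptions: get_gender's full
# signature is cut off in the @@ header of the last hunk.
preds = get_gender(
    model_name_or_path="alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech",
    audio_paths=audio_paths,
    label2id=label2id,
    id2label=id2label,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
print(preds)
```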