wanchichen commited on
Commit
9632a6b
1 Parent(s): 00344c3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +37 -0
README.md CHANGED
@@ -158,6 +158,43 @@ language:
158
 
159
  XEUS tops the [ML-SUPERB]() multilingual speech recognition leaderboard, outperforming [MMS](), [w2v-BERT 2.0](), and [XLS-R](). XEUS also sets a new state-of-the-art on 4 tasks in the monolingual [SUPERB]() benchmark.
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  ## Results
162
 
163
  ![image/png](https://cdn-uploads.huggingface.co/production/uploads/630438615c70c21d0eae6613/RCAWBxSuDLXJ5zdj-OBdn.png)
 
158
 
159
  XEUS tops the [ML-SUPERB]() multilingual speech recognition leaderboard, outperforming [MMS](), [w2v-BERT 2.0](), and [XLS-R](). XEUS also sets a new state-of-the-art on 4 tasks in the monolingual [SUPERB]() benchmark.
160
 
161
+ ## Requirements
162
+
163
+ The code for XEUS is still in progress of being merged into the main ESPnet repo. It can instead be used from the following fork:
164
+
165
+ ```
166
+ pip install -e git+git://github.com/wanchichen/espnet.git@ssl
167
+ ```
168
+
169
+ XEUS supports [Flash Attention], which can be installed as follows:
170
+
171
+ ```
172
+ pip install flash-attn --no-build-isolation
173
+ ```
174
+
175
+ ## Usage
176
+
177
+ ```
178
+ from torch.nn.utils.rnn import pad_sequence
179
+ from espnet2.tasks.ssl import SSLTask
180
+ import soundfile as sf
181
+
182
+ device = "cuda" if torch.cuda.is_available() else "cpu"
183
+
184
+ xeus_model, xeus_train_args = SSLTask.build_model_from_file(
185
+ config = None,
186
+ ckpt = '/path/to/checkpoint/here/checkpoint.pth',
187
+ device,
188
+ )
189
+
190
+ wavs, sampling_rate = sf.read('/path/to/audio.wav') # sampling rate should be 16000
191
+ wav_lengths = torch.LongTensor([len(wav) for wav in [wavs]]).to(device)
192
+ wavs = pad_sequence([wavs], batch_first=True).to(device)
193
+
194
+ # we recommend use_mask=True during fine-tuning
195
+ feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1] # take the output of the last layer
196
+ ```
197
+
198
  ## Results
199
 
200
  ![image/png](https://cdn-uploads.huggingface.co/production/uploads/630438615c70c21d0eae6613/RCAWBxSuDLXJ5zdj-OBdn.png)