Ubuntu committed on
Commit 2400474 · 1 Parent(s): 98f0a8b

Update README

Files changed (1)
  1. README.md +48 -0
README.md CHANGED
@@ -1,3 +1,51 @@
  ---
+ language: ja
  license: apache-2.0
  ---
+
+ # Fine-tuned XLSR-53 large model for speaker diarization in Japanese phone calls
+
+ This model is [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) fine-tuned on Japanese phone-call data from [CallHome](https://media.talkbank.org/ca/CallHome/jpn/).
+
+ ## Usage
+ The model can be used directly for inference as follows.
+
+ ```python
+ import numpy as np
+ import torch
+ from pydub import AudioSegment
+
+ from transformers import Wav2Vec2ForAudioFrameClassification, Wav2Vec2FeatureExtractor
+
+
+ def _make_timegrid(sound_duration: float, total_len: int):  # start/end time (sec) of each output frame
+     start_timegrid = np.linspace(0, sound_duration, total_len + 1)
+     dt = start_timegrid[1] - start_timegrid[0]
+     end_timegrid = start_timegrid + dt
+     return start_timegrid[:total_len], end_timegrid[:total_len]
+
+ feature_extractor = Wav2Vec2FeatureExtractor(
+     feature_size=1,
+     sampling_rate=16_000,
+     padding_value=0.0,
+     do_normalize=True,
+     return_attention_mask=True,
+ )
+ model = Wav2Vec2ForAudioFrameClassification.from_pretrained("Ivydata/wav2vec2-large-speech-diarization-jp")
+ filepath = "/path/to/file.wav"
+ sound = AudioSegment.from_file(filepath)
+ sound = sound.set_frame_rate(16_000).set_channels(1)  # 16 kHz mono; stereo samples would come back interleaved
+ sound_duration = sound.duration_seconds
+
+ feature = feature_extractor(np.array(sound.get_array_of_samples()), sampling_rate=16_000).input_values[0]
+ input_values = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)
+
+ with torch.no_grad():
+     logits = model(input_values).logits  # shape: (batch, frames, num_labels)
+ pred = logits.argmax(dim=-1).squeeze(0)  # one label index per frame
+ start_timegrid, end_timegrid = _make_timegrid(sound_duration, len(pred))
+
+ print("sec speaker_label")
+ for p, start_time in zip(pred, start_timegrid):
+     print(f"{start_time:.4f} {p.item()}")  # .item() prints the label as a plain int
+ ```
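
The README snippet prints one label per frame and computes `end_timegrid` without using it. As a possible follow-up, the sketch below collapses consecutive frames with the same label into `(start_sec, end_sec, label)` segments. The helper name `frames_to_segments` and the toy inputs are hypothetical and not part of this commit. What each label index means is not documented in the README, so the output is only labelled "speaker N" for illustration.

```python
import numpy as np


def frames_to_segments(pred, start_timegrid, end_timegrid):
    """Collapse per-frame labels into (start_sec, end_sec, label) runs.

    Hypothetical helper, not part of the README in this commit.
    """
    labels = np.asarray(pred)
    if labels.size == 0:
        return []
    # A new run starts wherever the label differs from the previous frame.
    change = np.flatnonzero(labels[1:] != labels[:-1]) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [labels.size])) - 1
    return [
        (float(start_timegrid[s]), float(end_timegrid[e]), int(labels[s]))
        for s, e in zip(starts, ends)
    ]


# Toy input: six frames of 0.1 s each. With the README snippet, pass
# pred.numpy(), start_timegrid and end_timegrid instead.
toy_pred = [0, 0, 1, 1, 1, 0]
toy_start = np.arange(6) * 0.1
toy_end = toy_start + 0.1
segments = frames_to_segments(toy_pred, toy_start, toy_end)
for start, end, label in segments:
    print(f"{start:.2f}-{end:.2f} speaker {label}")
```

Merging runs this way keeps the frame-level resolution of the model output while giving a listing that is easier to read for long calls. Very short segments may be noise and could be filtered by duration, but any threshold would be a choice the README does not specify.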