nguyenvulebinh commited on
Commit
d5f441b
1 Parent(s): e30eabe

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +74 -0
README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ language:
4
+ - vi
5
+ ---
6
+
7
+ ### Vietnamese ASR sequence-to-sequence model. This model supports output normalizing text, labeling timestamps, and segmenting multiple speakers.
8
+
9
+
10
+ ```python
11
+ # !pip install transformers, sentencepiece
12
+
13
+ from transformers import SpeechEncoderDecoderModel
14
+ from transformers import AutoFeatureExtractor, AutoTokenizer, GenerationConfig
15
+ import torchaudio
16
+ import torch
17
+
18
+ model_path = 'nguyenvulebinh/wav2vec2-bartpho'
19
+ model = SpeechEncoderDecoderModel.from_pretrained(model_path).eval()
20
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
21
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
22
+ if torch.cuda.is_available():
23
+ model = model.cuda()
24
+
25
+
26
+ def decode_tokens(token_ids, skip_special_tokens=True, time_precision=0.02):
27
+ timestamp_begin = tokenizer.vocab_size
28
+ outputs = [[]]
29
+ for token in token_ids:
30
+ if token >= timestamp_begin:
31
+ timestamp = f" |{(token - timestamp_begin) * time_precision:.2f}| "
32
+ outputs.append(timestamp)
33
+ outputs.append([])
34
+ else:
35
+ outputs[-1].append(token)
36
+ outputs = [
37
+ s if isinstance(s, str) else tokenizer.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
38
+ ]
39
+ return "".join(outputs).replace("< |", "<|").replace("| >", "|>")
40
+
41
+ def decode_wav(audio_wavs, asr_model, prefix=""):
42
+ device = next(asr_model.parameters()).device
43
+ input_values = feature_extractor.pad(
44
+ [{"input_values": feature} for feature in audio_wavs],
45
+ padding=True,
46
+ max_length=None,
47
+ pad_to_multiple_of=None,
48
+ return_tensors="pt",
49
+ )
50
+
51
+ output_beam_ids = asr_model.generate(
52
+ input_values['input_values'].to(device),
53
+ attention_mask=input_values['attention_mask'].to(device),
54
+ decoder_input_ids=tokenizer.batch_encode_plus([prefix] * len(audio_wavs), return_tensors="pt")['input_ids'][..., :-1].to(device),
55
+ generation_config=GenerationConfig(decoder_start_token_id=tokenizer.bos_token_id),
56
+ max_length=250,
57
+ num_beams=25,
58
+ no_repeat_ngram_size=4,
59
+ num_return_sequences=1,
60
+ early_stopping=True,
61
+ return_dict_in_generate=True,
62
+ output_scores=True,
63
+ )
64
+
65
+ output_text = [decode_tokens(sequence) for sequence in output_beam_ids.sequences]
66
+
67
+ return output_text
68
+
69
+ print(decode_wav([torchaudio.load('sample_news.wav')[0].squeeze()], model))
70
+
71
+ # <|0.00| Gia đình cho biết, nhiều lần đã từng gọi điện báo chính quyền và lực lượng an ninh địa phương nhưng đều không có tác dụng |7.00|>
72
+ # <|8.14| Không ai giúp đỡ được mình một chút nào cả, nên là lúc đó là lúc tuyệt vọng nhất, nó tra tấn mình cực kỳ khổ, gây cái tâm lý ức chế rất là nhiều, rất là lớn |19.02|>']
73
+
74
+ ```