Add-Vishnu commited on
Commit
7788a27
·
1 Parent(s): bed76c6

Create asr.py

Browse files
Files changed (1) hide show
  1. asr.py +57 -0
asr.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
2
+ import torch
3
+ from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
4
+ import time
5
+ import gradio as gr
6
+ import librosa
7
+
8
+ model_id = "facebook/mms-1b-all"
9
+ processor = AutoProcessor.from_pretrained(model_id)
10
+ model = Wav2Vec2ForCTC.from_pretrained(model_id)
11
+
12
+ model_id_lid = "facebook/mms-lid-126"
13
+ processor_lid = AutoFeatureExtractor.from_pretrained(model_id_lid)
14
+ model_lid = Wav2Vec2ForSequenceClassification.from_pretrained(model_id_lid)
15
+
16
+ def transcribe(audio):
17
+ audio = librosa.load(audio, sr=16_000, mono=True)[0]
18
+ inputs = processor(audio, sampling_rate=16_000,return_tensors="pt")
19
+ with torch.no_grad():
20
+ tr_start_time = time.time()
21
+ outputs = model(**inputs).logits
22
+ tr_end_time = time.time()
23
+ ids = torch.argmax(outputs, dim=-1)[0]
24
+ transcription = processor.decode(ids)
25
+ return transcription,(tr_end_time-tr_start_time)
26
+
27
+
28
+ def detect_language(audio):
29
+ audio = librosa.load(audio, sr=16_000, mono=True)[0]
30
+ # print(audio)
31
+ inputs_lid = processor_lid(audio, sampling_rate=16_000, return_tensors="pt")
32
+ with torch.no_grad():
33
+ start_time_lid = time.time()
34
+ outputs_lid = model_lid(**inputs_lid).logits
35
+ end_time = time.time()
36
+ # print(end_time-start_time," sec")
37
+ lang_id = torch.argmax(outputs_lid, dim=-1)[0].item()
38
+ detected_lang = model_lid.config.id2label[lang_id]
39
+ print(detected_lang)
40
+ return detected_lang, (end_time_lid-start_time_lid)
41
+
42
+
43
+ def transcribe_lang(audio,lang):
44
+ audio = librosa.load(audio, sr=16_000, mono=True)[0]
45
+ processor.tokenizer.set_target_lang(lang)
46
+ model.load_adapter(lang)
47
+ print(lang)
48
+ inputs = processor(audio, sampling_rate=16_000,return_tensors="pt")
49
+ with torch.no_grad():
50
+ tr_start_time = time.time()
51
+ outputs = model(**inputs).logits
52
+ tr_end_time = time.time()
53
+ ids = torch.argmax(outputs, dim=-1)[0]
54
+ transcription = processor.decode(ids)
55
+ return transcription,(tr_end_time-tr_start_time)
56
+
57
+