unijoh commited on
Commit
2859a48
1 Parent(s): 2244bbb

Create asr.py

Browse files
Files changed (1) hide show
  1. asr.py +33 -0
asr.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
3
+ import torch
4
+
5
+ ASR_SAMPLING_RATE = 16_000
6
+ MODEL_ID = "facebook/mms-1b-all"
7
+
8
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
9
+ model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
10
+
11
+ def transcribe(audio_source=None, microphone=None, file_upload=None):
12
+ audio_fp = file_upload if "upload" in str(audio_source or "").lower() else microphone
13
+ if audio_fp is None:
14
+ return "ERROR: You have to either use the microphone or upload an audio file"
15
+
16
+ audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]
17
+ processor.tokenizer.set_target_lang("fao") # Set Faroese language
18
+ model.load_adapter("fao")
19
+
20
+ inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
21
+
22
+ # Set device
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ model.to(device)
25
+ inputs = inputs.to(device)
26
+
27
+ with torch.no_grad():
28
+ outputs = model(**inputs).logits
29
+
30
+ ids = torch.argmax(outputs, dim=-1)[0]
31
+ transcription = processor.decode(ids)
32
+
33
+ return transcription