unijoh commited on
Commit
d501976
1 Parent(s): 9d60183

Update asr.py

Browse files
Files changed (1) hide show
  1. asr.py +7 -11
asr.py CHANGED
@@ -5,20 +5,15 @@ import torch
5
  ASR_SAMPLING_RATE = 16_000
6
 
7
  MODEL_ID = "facebook/mms-1b-all"
 
8
  processor = AutoProcessor.from_pretrained(MODEL_ID)
9
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
10
 
11
- def transcribe(audio_source=None, microphone=None, file_upload=None):
12
- audio_fp = file_upload if file_upload else microphone
13
- if audio_fp is None:
14
  return "ERROR: You have to either use the microphone or upload an audio file"
15
 
16
- audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]
17
-
18
- # Set Faroese language
19
- processor.tokenizer.set_target_lang("fao")
20
- model.load_adapter("fao")
21
-
22
  inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
23
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -27,7 +22,8 @@ def transcribe(audio_source=None, microphone=None, file_upload=None):
27
 
28
  with torch.no_grad():
29
  outputs = model(**inputs).logits
30
- ids = torch.argmax(outputs, dim=-1)[0]
31
- transcription = processor.decode(ids)
32
 
 
 
 
33
  return transcription
 
5
  ASR_SAMPLING_RATE = 16_000
6
 
7
  MODEL_ID = "facebook/mms-1b-all"
8
+
9
  processor = AutoProcessor.from_pretrained(MODEL_ID)
10
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
11
 
12
+ def transcribe(audio):
13
+ if audio is None:
 
14
  return "ERROR: You have to either use the microphone or upload an audio file"
15
 
16
+ audio_samples = librosa.load(audio, sr=ASR_SAMPLING_RATE, mono=True)[0]
 
 
 
 
 
17
  inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
18
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
22
 
23
  with torch.no_grad():
24
  outputs = model(**inputs).logits
 
 
25
 
26
+ ids = torch.argmax(outputs, dim=-1)[0]
27
+ transcription = processor.decode(ids)
28
+
29
  return transcription