akera committed on
Commit
d4afb45
·
verified ·
1 Parent(s): 361f06d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -28
app.py CHANGED
@@ -5,52 +5,73 @@ import librosa
5
  import json
6
  import os
7
  import huggingface_hub
 
8
 
9
  # with open('ISO_codes.json', 'r') as file:
10
  # iso_codes = json.load(file)
11
 
12
- languages = ["lug", "ach", "nyn", "teo"]
13
  auth_token = os.environ.get("HF_TOKEN")
14
 
15
- from huggingface_hub import login
16
- login(token=auth_token)
17
 
18
- model_id = "Sunbird/sunbird-mms"
19
- model = Wav2Vec2ForCTC.from_pretrained(model_id, use_auth_token=auth_token)
20
- processor = Wav2Vec2Processor.from_pretrained(model_id, use_auth_token=auth_token)
21
 
 
22
 
23
- def transcribe(audio_file_mic=None, audio_file_upload=None, language="Luganda (lug)"):
24
- if audio_file_mic:
25
- audio_file = audio_file_mic
26
- elif audio_file_upload:
27
- audio_file = audio_file_upload
28
- else:
29
- return "Please upload an audio file or record one"
30
 
31
- # Make sure audio is 16kHz
32
- speech, sample_rate = librosa.load(audio_file)
33
- if sample_rate != 16000:
34
- speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
35
 
36
- # Keep the same model in memory and simply switch out the language adapters by calling load_adapter() for the model and set_target_lang() for the tokenizer
37
- language_code = language
38
- processor.tokenizer.set_target_lang(language_code)
39
- model.load_adapter(language_code)
40
 
41
- inputs = processor(speech, sampling_rate=16_000, return_tensors="pt")
 
 
 
 
 
 
 
42
 
43
- with torch.no_grad():
44
- outputs = model(**inputs).logits
 
 
45
 
46
- ids = torch.argmax(outputs, dim=-1)[0]
47
- transcription = processor.decode(ids)
48
- return transcription
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
  description = '''ASR with salt-mms'''
52
 
53
- iface = gr.Interface(fn=transcribe,
54
  inputs=[
55
  gr.Audio(source="microphone", type="filepath", label="Record Audio"),
56
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),
 
5
  import json
6
  import os
7
  import huggingface_hub
8
+ from transformers import pipeline
9
 
10
  # with open('ISO_codes.json', 'r') as file:
11
  # iso_codes = json.load(file)
12
 
13
+ # languages = ["lug", "ach", "nyn", "teo"]
14
  auth_token = os.environ.get("HF_TOKEN")
15
 
 
 
16
 
17
# Mapping from UI display name to MMS ISO-639-3 adapter code.
target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}

# Display names offered to the user in the interface.
languages = list(target_lang_options.keys())

# BUG FIX: the original read `target_lang` here without ever assigning it,
# which raises NameError as soon as the module is imported. Default to
# Luganda (the Space's primary language); the per-request UI selection
# should override this inside the handler.
target_lang = "Luganda"
target_lang_code = target_lang_options[target_lang]

# English is served by Meta's multilingual MMS checkpoint; the Ugandan
# languages use Sunbird's fine-tuned model with per-language adapters.
if target_lang_code == "eng":
    model_id = "facebook/mms-1b-all"
else:
    model_id = "Sunbird/sunbird-mms"
27
 
 
 
 
 
28
 
29
# Transcribe audio with the MMS CTC pipeline, swapping in the language adapter.
def transcribe_audio(input_file, target_lang_code,
                     device, model_id=model_id,
                     chunk_length_s=10, stride_length_s=(4, 2), return_timestamps="word"):
    """Transcribe an audio input in the requested target language.

    Parameters:
        input_file: audio to transcribe — either a filesystem path string
            (what Gradio's ``type="filepath"`` components pass) or a
            file-like object with ``.read()``.
        target_lang_code: ISO-639-3 code selecting the tokenizer target
            language and the model adapter (e.g. "lug", "ach").
        device: device spec forwarded to the transformers pipeline.
        model_id: Hugging Face model repo to load (defaults to the
            module-level selection).
        chunk_length_s / stride_length_s: chunked long-form decoding
            parameters forwarded to the pipeline.
        return_timestamps: timestamp granularity ("word" by default).

    Returns:
        The pipeline output dict (text plus timestamps).
    """
    # BUG FIX: the original passed undefined `hf_auth_token`; the module
    # reads the token from the HF_TOKEN env var into `auth_token`.
    pipe = pipeline(model=model_id, device=device, token=auth_token)
    pipe.tokenizer.set_target_lang(target_lang_code)
    pipe.model.load_adapter(target_lang_code)

    # BUG FIX: Gradio `type="filepath"` delivers a path string, which the
    # pipeline accepts directly; only call .read() on file-like objects.
    audio_data = input_file.read() if hasattr(input_file, "read") else input_file
    output = pipe(audio_data, chunk_length_s=chunk_length_s,
                  stride_length_s=stride_length_s,
                  return_timestamps=return_timestamps)
    return output
42
 
43
+
44
+ # def transcribe(audio_file_mic=None, audio_file_upload=None, language="Luganda (lug)"):
45
+ # if audio_file_mic:
46
+ # audio_file = audio_file_mic
47
+ # elif audio_file_upload:
48
+ # audio_file = audio_file_upload
49
+ # else:
50
+ # return "Please upload an audio file or record one"
51
+
52
+ # # Make sure audio is 16kHz
53
+ # speech, sample_rate = librosa.load(audio_file)
54
+ # if sample_rate != 16000:
55
+ # speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
56
+
57
+ # # Keep the same model in memory and simply switch out the language adapters by calling load_adapter() for the model and set_target_lang() for the tokenizer
58
+ # language_code = language
59
+ # processor.tokenizer.set_target_lang(language_code)
60
+ # model.load_adapter(language_code)
61
+
62
+ # inputs = processor(speech, sampling_rate=16_000, return_tensors="pt")
63
+
64
+ # with torch.no_grad():
65
+ # outputs = model(**inputs).logits
66
+
67
+ # ids = torch.argmax(outputs, dim=-1)[0]
68
+ # transcription = processor.decode(ids)
69
+ # return transcription
70
 
71
 
72
  description = '''ASR with salt-mms'''
73
 
74
+ iface = gr.Interface(fn=transcribe_audio,
75
  inputs=[
76
  gr.Audio(source="microphone", type="filepath", label="Record Audio"),
77
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),