akera commited on
Commit
a39ac0d
1 Parent(s): 6d352f5

added formatting

Browse files
Files changed (1) hide show
  1. app.py +32 -28
app.py CHANGED
@@ -1,56 +1,60 @@
1
  import gradio as gr
2
- from transformers import Wav2Vec2ForCTC, AutoProcessor, Wav2Vec2Processor
3
  import torch
4
  import librosa
5
  import json
6
  import os
7
- import huggingface_hub
8
- from transformers import pipeline
9
-
10
-
11
- auth_token = os.environ.get("HF_TOKEN")
12
-
13
-
14
- target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}
15
 
16
- languages = list(target_lang_options.keys())
17
-
18
-
19
- # Transcribe audio using custom model
20
- def transcribe_audio(input_file, language,chunk_length_s=10,
21
- stride_length_s=(4, 2), return_timestamps="word"):
 
 
 
 
 
 
 
 
22
 
 
 
 
 
23
 
 
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
25
  target_lang_code = target_lang_options[language]
26
 
27
  # Determine the model_id based on the language
28
  if target_lang_code == "eng":
29
  model_id = "facebook/mms-1b-all"
30
  else:
31
- model_id = "Sunbird/sunbird-mms"
32
 
 
33
  pipe = pipeline(model=model_id, device=device, token=auth_token)
34
- pipe.tokenizer.set_target_lang(target_lang_code)
35
-
36
- pipe.model.load_adapter(target_lang_code)
37
 
38
- # Read audio file
39
- # audio_data = input_file
40
  output = pipe(input_file, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, return_timestamps=return_timestamps)
41
- return output
42
-
43
 
 
44
  description = '''ASR with salt-mms'''
45
-
46
  iface = gr.Interface(fn=transcribe_audio,
47
  inputs=[
48
  gr.Audio(source="upload", type="filepath", label="upload file to transcribe"),
49
- gr.Dropdown(choices=languages, label="Language", value="English")
50
- ],
51
  outputs=gr.Textbox(label="Transcription"),
52
  description=description
53
  )
54
 
55
-
56
- iface.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
  import torch
4
  import librosa
5
  import json
6
  import os
 
 
 
 
 
 
 
 
7
 
8
+ # Assuming other necessary imports and setup are already done
9
+
10
+ # Helper function to format and group word timestamps
11
+ def format_and_group_timestamps(chunks, interval=5.0):
12
+ grouped = {}
13
+ transcript = ""
14
+ for chunk in chunks:
15
+ start, end = chunk['timestamp']
16
+ word = chunk['text']
17
+ transcript += f"{word} "
18
+ interval_start = int(start // interval) * interval
19
+ if interval_start not in grouped:
20
+ grouped[interval_start] = []
21
+ grouped[interval_start].append((start, end, word))
22
 
23
+ formatted_output = f"Transcript: {transcript.strip()}'\n\n-------\n\nword-stamped transcripts (every 5 seconds):\n\n"
24
+ for interval_start, words in grouped.items():
25
+ formatted_output += f"({interval_start}, {interval_start + interval}) -- {' '.join([w[2] for w in words])}\n"
26
+ return formatted_output
27
 
28
+ # Modified transcribe_audio function
29
+ def transcribe_audio(input_file, language, chunk_length_s=10, stride_length_s=(4, 2), return_timestamps="word"):
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}
32
  target_lang_code = target_lang_options[language]
33
 
34
  # Determine the model_id based on the language
35
  if target_lang_code == "eng":
36
  model_id = "facebook/mms-1b-all"
37
  else:
38
+ model_id = "custom_model_id_for_other_languages" # Placeholder for actual model IDs
39
 
40
+ auth_token = os.environ.get("HF_TOKEN")
41
  pipe = pipeline(model=model_id, device=device, token=auth_token)
42
+ # Assuming necessary setup for tokenizer and loading adapter
 
 
43
 
 
 
44
  output = pipe(input_file, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, return_timestamps=return_timestamps)
45
+ formatted_output = format_and_group_timestamps(output['chunks'])
46
+ return formatted_output
47
 
48
+ # Interface setup remains the same
49
  description = '''ASR with salt-mms'''
 
50
  iface = gr.Interface(fn=transcribe_audio,
51
  inputs=[
52
  gr.Audio(source="upload", type="filepath", label="upload file to transcribe"),
53
+ gr.Dropdown(choices=list(target_lang_options.keys()), label="Language", value="English")
54
+ ],
55
  outputs=gr.Textbox(label="Transcription"),
56
  description=description
57
  )
58
 
59
+ # Launch the interface
60
+ iface.launch()