Aboubacar OUATTARA - kaira committed on
Commit
54440ac
1 Parent(s): 7dc5f48

initial commit

Browse files
Files changed (4) hide show
  1. app.py +50 -0
  2. bambara_utils.py +46 -0
  3. packages.txt +1 -0
  4. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import torch
3
+ from transformers import pipeline
4
+ import gradio as gr
5
+ from bambara_utils import BambaraWhisperTokenizer
6
+
7
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint identifier of the fine-tuned model and the language given to the tokenizer.
model_checkpoint = "oza75/whisper-bambara-asr-001"
language = "bambara"

# Build the Bambara-aware tokenizer first, then hand it to the pipeline so the
# pipeline uses it instead of loading its own tokenizer for the checkpoint.
tokenizer = BambaraWhisperTokenizer.from_pretrained(
    model_checkpoint,
    language=language,
    device=device,
)
pipe = pipeline(model=model_checkpoint, tokenizer=tokenizer, device=device)
17
+
18
+
19
@spaces.GPU()
def transcribe(audio):
    """
    Transcribe an audio file into text using the module-level ASR pipeline.

    Args:
        audio: The path to the audio file to transcribe. May be ``None`` or
            empty when the user submits the form without recording or
            uploading anything.

    Returns:
        A string representing the transcribed text.

    Raises:
        gr.Error: If no audio was provided, so the UI shows a readable
            message instead of an opaque pipeline failure.
    """
    # Guard against an empty submission before touching the pipeline.
    if not audio:
        raise gr.Error("No audio was provided. Please record or upload an audio clip first.")

    # The pipeline returns a dict; only its "text" entry is surfaced.
    return pipe(audio)["text"]
33
+
34
+
35
def main():
    """Build the Gradio transcription demo and start serving it."""
    # The audio component passes a file path to `transcribe`.
    audio_input = gr.Audio(type="filepath")

    demo = gr.Interface(
        fn=transcribe,
        inputs=audio_input,
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Realtime demo for Bambara speech recognition based on a fine-tuning of the Whisper model.",
    )

    # Serve without requesting a public share link.
    demo.launch(share=False)


if __name__ == "__main__":
    main()
bambara_utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from tokenizers import AddedToken
4
+ from transformers import WhisperTokenizer, WhisperProcessor
5
+ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, TASK_IDS
6
+
7
# Whisper's language-name -> code table, extended with an entry for Bambara.
CUSTOM_TO_LANGUAGE_CODE = dict(TO_LANGUAGE_CODE, bambara="bm")
8
+
9
+
10
class BambaraWhisperTokenizer(WhisperTokenizer):
    """Whisper tokenizer that also accepts Bambara ("bambara" / "bm") as a language.

    On construction the ``<|bm|>`` token is added to the vocabulary, and
    ``prefix_tokens`` resolves languages through ``CUSTOM_TO_LANGUAGE_CODE``
    so that the Bambara entry is recognised.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register the Bambara language marker as a special, non-normalized token.
        bambara_token = AddedToken(
            content="<|bm|>",
            lstrip=False,
            rstrip=False,
            normalized=False,
            special=True,
        )
        self.add_tokens(bambara_token)

    @property
    def prefix_tokens(self) -> List[int]:
        """Return the token ids that prefix every decoded sequence.

        The sequence is: start-of-transcript, then the language token (when a
        language is set), then the task token (when a task is set), then the
        no-timestamps token (when timestamps are not predicted).

        Raises:
            ValueError: If the configured language or task is not supported.
        """
        language_code = None
        if self.language is not None:
            # The lowercased value is stored back on the instance.
            self.language = self.language.lower()
            requested = self.language
            is_known_name = requested in CUSTOM_TO_LANGUAGE_CODE
            if not is_known_name and requested not in CUSTOM_TO_LANGUAGE_CODE.values():
                # Two-letter inputs are treated as codes when listing the alternatives.
                if len(requested) == 2:
                    supported = list(CUSTOM_TO_LANGUAGE_CODE.values())
                else:
                    supported = list(CUSTOM_TO_LANGUAGE_CODE.keys())
                raise ValueError(
                    f"Unsupported language: {requested}. Language should be one of: {supported}."
                )
            # A language name maps to its code; a bare code is used as-is.
            language_code = CUSTOM_TO_LANGUAGE_CODE[requested] if is_known_name else requested

        if self.task is not None and self.task not in TASK_IDS:
            raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")

        prefix = [self.convert_tokens_to_ids("<|startoftranscript|>")]
        if self.language is not None:
            prefix.append(self.convert_tokens_to_ids(f"<|{language_code}|>"))
        if self.task is not None:
            task_token = "<|transcribe|>" if self.task == "transcribe" else "<|translate|>"
            prefix.append(self.convert_tokens_to_ids(task_token))
        if not self.predict_timestamps:
            prefix.append(self.convert_tokens_to_ids("<|notimestamps|>"))
        return prefix
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ datasets[audio]
2
+ transformers
3
+ accelerate
4
+ evaluate
5
+ jiwer
6
+ tensorboard
7
+ gradio
8
+ spaces