Irpan commited on
Commit
d29fa84
1 Parent(s): 619a599
Files changed (4) hide show
  1. app.py +3 -13
  2. asr.py +44 -0
  3. tts.py +19 -8
  4. util.py +70 -3
app.py CHANGED
@@ -1,17 +1,7 @@
1
  import gradio as gr
2
  import util
3
  import tts
4
-
5
- # Functions
6
- def check_pronunciation(input_text, script, user_audio):
7
- # Placeholder logic for pronunciation checking
8
- transcript_ugArab_box = "Automatic transcription of your audio (Arabic)..."
9
- transcript_ugLatn_box = "Automatic transcription of your audio (Latin)..."
10
- correct_pronunciation = "Correct pronunciation in IPA"
11
- user_pronunciation = "User pronunciation in IPA"
12
- pronunciation_match = "Matching segments in green, mismatched in red"
13
- pronunciation_score = 85.7 # Replace with actual score calculation
14
- return transcript_ugArab_box, transcript_ugLatn_box, correct_pronunciation, user_pronunciation, pronunciation_match, pronunciation_score
15
 
16
  # Front-End
17
  with gr.Blocks() as app:
@@ -101,13 +91,13 @@ with gr.Blocks() as app:
101
  )
102
 
103
  tts_btn.click(
104
- tts.generate_example_pronunciation,
105
  inputs=[input_text, script_choice],
106
  outputs=[example_audio]
107
  )
108
 
109
  check_btn.click(
110
- check_pronunciation,
111
  inputs=[input_text, script_choice, user_audio],
112
  outputs=[transcript_ugArab_box, transcript_ugLatn_box, correct_pronunciation_box, user_pronunciation_box, match_box, score_box]
113
  )
 
1
  import gradio as gr
2
  import util
3
  import tts
4
+ import asr
 
 
 
 
 
 
 
 
 
 
5
 
6
  # Front-End
7
  with gr.Blocks() as app:
 
91
  )
92
 
93
  tts_btn.click(
94
+ tts.generate_audio,
95
  inputs=[input_text, script_choice],
96
  outputs=[example_audio]
97
  )
98
 
99
  check_btn.click(
100
+ asr.check_pronunciation,
101
  inputs=[input_text, script_choice, user_audio],
102
  outputs=[transcript_ugArab_box, transcript_ugLatn_box, correct_pronunciation_box, user_pronunciation_box, match_box, score_box]
103
  )
asr.py CHANGED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
2
+ import torch
3
+ from umsc import UgMultiScriptConverter
4
+ import util
5
+
6
+ # Model ID and setup
7
+ model_id = 'ixxan/wav2vec2-large-mms-1b-uyghur-latin'
8
+ asr_model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="uig-script_latin")
9
+ asr_processor = Wav2Vec2Processor.from_pretrained(model_id)
10
+ asr_processor.tokenizer.set_target_lang("uig-script_latin")
11
+
12
+ # Automatically allocate the device
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ asr_model = asr_model.to(device)
15
+
16
+ def asr(user_audio):
17
+ # Load and resample user audio
18
+ audio_input, sampling_rate = util.load_and_resample_audio(user_audio, target_rate=16000)
19
+
20
+ # Process audio through ASR model
21
+ inputs = asr_processor(audio_input.squeeze(), sampling_rate=sampling_rate, return_tensors="pt", padding=True)
22
+ inputs = {key: val.to(device) for key, val in inputs.items()}
23
+ with torch.no_grad():
24
+ logits = asr_model(**inputs).logits
25
+ predicted_ids = torch.argmax(logits, dim=-1)
26
+ transcript = asr_processor.batch_decode(predicted_ids)[0]
27
+ return transcript
28
+
29
+
30
+ def check_pronunciation(input_text, script, user_audio):
31
+ # Transcripts from user input audio
32
+ transcript_ugLatn_box = asr(user_audio)
33
+ ug_latn_to_arab = UgMultiScriptConverter('ULS', 'UAS')
34
+ transcript_ugArab_box = ug_latn_to_arab(transcript_ugLatn_box)
35
+
36
+ # Get IPA and Pronunciation Feedback
37
+ if script == 'Uyghur Latin':
38
+ input_text = ug_latn_to_arab(input_text) # make sure input text is arabic script to IPA conversion
39
+ correct_pronunciation, user_pronunciation, pronunciation_match, pronunciation_score = util.calculate_pronunciation_accuracy(
40
+ reference_text = input_text,
41
+ output_text = transcript_ugArab_box,
42
+ language_code='uig-Arab')
43
+
44
+ return transcript_ugArab_box, transcript_ugLatn_box, correct_pronunciation, user_pronunciation, pronunciation_match, pronunciation_score
tts.py CHANGED
@@ -2,20 +2,31 @@ from transformers import VitsModel, AutoTokenizer
2
  import torch
3
  from umsc import UgMultiScriptConverter
4
  import scipy.io.wavfile
5
- import os
6
 
7
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-uig-script_arabic")
8
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-uig-script_arabic")
 
 
9
 
10
- def generate_example_pronunciation(input_text, script):
11
- # Convert text to uyghur_arabic
 
 
 
 
 
 
 
12
  ug_latn_to_arab = UgMultiScriptConverter('ULS', 'UAS')
13
- if not script == "Uyghur Arabic":
14
  input_text = ug_latn_to_arab(input_text)
15
 
16
- tts_inputs = tts_tokenizer(input_text, return_tensors="pt")
 
 
 
17
  with torch.no_grad():
18
- tts_output = tts_model(**tts_inputs).waveform
19
 
20
  # Save to a temporary file
21
  output_path = "tts_output.wav"
 
2
  import torch
3
  from umsc import UgMultiScriptConverter
4
  import scipy.io.wavfile
 
5
 
6
+ # Model ID and setup
7
+ model_id = "facebook/mms-tts-uig-script_arabic"
8
+ tts_tokenizer = AutoTokenizer.from_pretrained(model_id)
9
+ tts_model = VitsModel.from_pretrained(model_id)
10
 
11
+ # Automatically allocate the device
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ tts_model = tts_model.to(device)
14
+
15
+ def generate_audio(input_text, script):
16
+ """
17
+ Generate audio for the given input text and script
18
+ """
19
+ # Convert text to Uyghur Arabic if needed
20
  ug_latn_to_arab = UgMultiScriptConverter('ULS', 'UAS')
21
+ if script != "Uyghur Arabic":
22
  input_text = ug_latn_to_arab(input_text)
23
 
24
+ # Tokenize and move inputs to the same device as the model
25
+ tts_inputs = tts_tokenizer(input_text, return_tensors="pt").to(device)
26
+
27
+ # Perform inference
28
  with torch.no_grad():
29
+ tts_output = tts_model(**tts_inputs).waveform.cpu() # Move output back to CPU for saving
30
 
31
  # Save to a temporary file
32
  output_path = "tts_output.wav"
util.py CHANGED
@@ -1,16 +1,21 @@
1
  import random
2
  from umsc import UgMultiScriptConverter
 
 
 
 
3
 
4
  # Lists of Uyghur short and long texts
5
  short_texts = [
6
  "سالام", "رەھمەت", "ياخشىمۇسىز"
7
  ]
8
  long_texts = [
9
- "مەكتەپكە بارغاندا تېخىمۇ بىلىملىك بولۇپ قېلىمەن.",
10
  "يېزا مەنزىرىسى ھەقىقەتەن گۈزەل.",
11
- "پېقىرلارغا ياردەم قىلىش مەنەم پەرزەندە."
12
  ]
13
 
 
14
  def generate_short_text(script_choice):
15
  """Generate a random Uyghur short text based on the type."""
16
  ug_arab_to_latn = UgMultiScriptConverter('UAS', 'ULS')
@@ -27,4 +32,66 @@ def generate_long_text(script_choice):
27
  text = random.choice(long_texts)
28
  if script_choice == "Uyghur Latin":
29
  return ug_arab_to_latn(text)
30
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import random
2
  from umsc import UgMultiScriptConverter
3
+ import torchaudio
4
+ import string
5
+ import epitran
6
+ from difflib import SequenceMatcher
7
 
8
  # Lists of Uyghur short and long texts
9
  short_texts = [
10
  "سالام", "رەھمەت", "ياخشىمۇسىز"
11
  ]
12
  long_texts = [
13
+ "مەكتەپكە بارغاندا تېخىمۇ بىلىملىك بولۇمەن.",
14
  "يېزا مەنزىرىسى ھەقىقەتەن گۈزەل.",
15
+ "بىزنىڭ ئۆيدەپ تۆت تەكچە تۆتىلىسى تەكتەكچە"
16
  ]
17
 
18
+ # Front-End Utils
19
  def generate_short_text(script_choice):
20
  """Generate a random Uyghur short text based on the type."""
21
  ug_arab_to_latn = UgMultiScriptConverter('UAS', 'ULS')
 
32
  text = random.choice(long_texts)
33
  if script_choice == "Uyghur Latin":
34
  return ug_arab_to_latn(text)
35
+ return text
36
+
37
+ # ASR Utils
38
+ def load_and_resample_audio(file_path, target_rate):
39
+ """Load audio and resample based on target sample rate"""
40
+ audio_input, sampling_rate = torchaudio.load(file_path)
41
+ if sampling_rate != target_rate:
42
+ resampler = torchaudio.transforms.Resample(sampling_rate, target_rate)
43
+ audio_input = resampler(audio_input)
44
+ return audio_input, target_rate
45
+
46
+ def calculate_pronunciation_accuracy(reference_text, output_text, language_code='uig-Arab'):
47
+ """
48
+ Calculate pronunciation accuracy between reference and ASR output text using Epitran.
49
+
50
+ Args:
51
+ reference_text (str): The ground truth text in Uyghur (Arabic script).
52
+ output_text (str): The ASR output text in Uyghur (Arabic script).
53
+ language_code (str): Epitran language code (default is 'uig-Arab' for Uyghur).
54
+
55
+ Returns:
56
+ float: Pronunciation accuracy as a percentage.
57
+ str: IPA transliteration of the reference text.
58
+ str: IPA transliteration of the output text.
59
+ """
60
+ # Initialize Epitran for Uyghur (Arabic script)
61
+ ipa_converter = epitran.Epitran(language_code)
62
+
63
+ # Remove punctuation from both texts
64
+ reference_text_clean = remove_punctuation(reference_text)
65
+ output_text_clean = remove_punctuation(output_text)
66
+
67
+ # Transliterate both texts to IPA
68
+ reference_ipa = ipa_converter.transliterate(reference_text_clean)
69
+ output_ipa = ipa_converter.transliterate(output_text_clean)
70
+
71
+ # Calculate pronunciation accuracy using SequenceMatcher
72
+ matcher = SequenceMatcher(None, reference_ipa, output_ipa)
73
+ match_ratio = matcher.ratio() # This is the fraction of matching characters
74
+
75
+ # Convert to percentage
76
+ pronunciation_accuracy = match_ratio * 100
77
+
78
+ # Generate HTML for comparison
79
+ comparison_html = ""
80
+ for opcode, i1, i2, j1, j2 in matcher.get_opcodes():
81
+ ref_segment = reference_ipa[i1:i2]
82
+ out_segment = output_ipa[j1:j2]
83
+
84
+ if opcode == 'equal': # Matching characters
85
+ comparison_html += f'<span style="color: green">{ref_segment}</span>'
86
+ elif opcode == 'replace': # Mismatched characters
87
+ comparison_html += f'<span style="color: red">{ref_segment}</span>'
88
+ elif opcode == 'delete': # Characters in reference but not in output
89
+ comparison_html += f'<span style="color: red">{ref_segment}</span>'
90
+ elif opcode == 'insert': # Characters in output but not in reference
91
+ comparison_html += f'<span style="color: red">{out_segment}</span>'
92
+
93
+ return reference_ipa, output_ipa, comparison_html, pronunciation_accuracy
94
+
95
+ def remove_punctuation(text):
96
+ """Helper function to remove punctuation from text."""
97
+ return text.translate(str.maketrans('', '', string.punctuation))