File size: 8,761 Bytes
f329f75
 
0d9ff36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f329f75
0d9ff36
 
 
 
 
 
f329f75
0d9ff36
f329f75
0d9ff36
f329f75
0d9ff36
f329f75
0d9ff36
 
f329f75
0d9ff36
f329f75
c5d4931
0d9ff36
 
 
 
 
 
 
 
 
 
63d0469
 
0d9ff36
 
f329f75
0d9ff36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fa1508
0d9ff36
 
 
 
 
 
 
8fa1508
f329f75
63d0469
 
f329f75
 
63d0469
f329f75
63d0469
 
0d9ff36
f329f75
 
 
 
0d9ff36
 
63d0469
f329f75
 
 
 
 
0d9ff36
 
63d0469
f329f75
63d0469
 
0d9ff36
 
 
63d0469
 
 
 
 
 
 
 
 
0d9ff36
 
f329f75
 
 
c5d4931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d9ff36
 
c5d4931
0d9ff36
c5d4931
 
 
 
 
 
 
f329f75
 
0d9ff36
f329f75
0d9ff36
 
 
 
 
 
 
 
 
 
f329f75
 
0d9ff36
63d0469
0d9ff36
 
 
 
 
 
 
f329f75
 
 
0d9ff36
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import gradio as gr

import src.generate as generate
import src.process as process
import src.tts as tts


# ------------------- UI printing functions -------------------
def clear_all():
    """Return blank values for every output wired to the Clear button.

    Positional order of the returned tuple: target, user_transcript,
    score_html, result_html, diff_html, tts_text, clone_status (all empty
    strings), followed by tts_audio (None).
    """
    blank_text_fields = ("",) * 7
    return (*blank_text_fields, None)


def make_result_html(pass_threshold, passed, ratio):
    """Build the pass/fail summary and the similarity score label.

    Args:
        pass_threshold: Minimum similarity required to pass, as a fraction
            (e.g. 0.85 is displayed as 85%).
        passed: Truthy when the reading met the threshold.
        ratio: Measured similarity, as a fraction.

    Returns:
        A ``(summary, score)`` tuple of display strings.
    """
    # round() rather than int(): 0.57 * 100 evaluates to 56.99999999999999 in
    # binary floating point, so truncating would display "56%" for a 57%
    # threshold.
    threshold_pct = round(pass_threshold * 100)
    if passed:
        summary = f"✅ Correct (≥ {threshold_pct}%)"
    else:
        summary = f"❌ Not a match (need ≥ {threshold_pct}%)"
    score = f"Similarity: {ratio * 100:.1f}%"
    return summary, score


def make_alignment_html(ref_tokens, hyp_tokens, alignments):
    """Render a word-level diff between the target and the recognized speech.

    Args:
        ref_tokens: Sequence of target-sentence tokens (strings).
        hyp_tokens: Sequence of tokens recognized from the user's audio.
        alignments: Iterable of ``(op, i1, i2, j1, j2)`` spans. ``i1:i2``
            slices ``ref_tokens`` and ``j1:j2`` slices ``hyp_tokens``. ``op``
            is one of ``"equal"``, ``"delete"``, ``"insert"`` or
            ``"replace"``; spans with any other ``op`` are skipped.

    Returns:
        A ``<div>`` HTML string: matching words as plain text, reference-only
        words in a red struck-through span, and recognized-only words in a
        green span.
    """
    # Function-scope import under its own name so the block does not depend
    # on a module-level ``html`` binding.
    from html import escape

    removed_open = ' <span style="background:#ffe0e0;text-decoration:line-through;">'
    added_open = ' <span style="background:#e0ffe0;">'
    span_close = "</span>"

    pieces = []
    for op, i1, i2, j1, j2 in alignments:
        # Token text is spliced into markup, so escape it; quote=False is
        # enough because the text never lands inside an attribute value.
        ref_string = escape(" ".join(ref_tokens[i1:i2]), quote=False)
        hyp_string = escape(" ".join(hyp_tokens[j1:j2]), quote=False)
        if op == "equal":
            pieces.append(" " + ref_string)
        elif op == "delete":
            pieces.append(removed_open + ref_string + span_close)
        elif op == "insert":
            pieces.append(added_open + hyp_string + span_close)
        elif op == "replace":
            # Show what was expected first, then what was actually heard.
            pieces.append(removed_open + ref_string + span_close)
            pieces.append(added_open + hyp_string + span_close)

    body = "".join(pieces).strip()
    return '<div style="line-height:1.6;font-size:1rem;">' + body + "</div>"


def make_html(sentence_match):
    """Render the display fragments for a sentence comparison.

    Reads ``target_tokens``, ``user_tokens``, ``alignments``,
    ``pass_threshold``, ``passed`` and ``ratio`` from ``sentence_match``.

    Returns:
        ``(score_html, result_html, diff_html)`` in that order.
    """
    alignment_markup = make_alignment_html(
        sentence_match.target_tokens,
        sentence_match.user_tokens,
        sentence_match.alignments,
    )
    # make_result_html yields (summary, score); the caller wants score first.
    summary_text, score_text = make_result_html(
        sentence_match.pass_threshold,
        sentence_match.passed,
        sentence_match.ratio,
    )
    return score_text, summary_text, alignment_markup


# ------------------- Core Check (English-only) -------------------
def get_user_transcript(
    audio_path: "str | None",
    target_sentence: str,
    model_id: str,
    device_pref: str,
) -> "tuple[str, str]":
    """Validate the inputs and run ASR on the recorded audio.

    Args:
        audio_path: Path of the recorded clip, or ``None`` when nothing has
            been recorded (the audio component is configured with
            ``type="filepath"``).
        target_sentence: Sentence the user was asked to read.
        model_id: ASR model identifier, forwarded to ``process.run_asr``.
        device_pref: Device preference, forwarded to ``process.run_asr``.

    Returns:
        ``(error_message, transcript)``. Exactly one is meaningful: on
        failure the transcript is ``""``; on success the error message is
        ``""``.
    """
    if not target_sentence:
        return "Please generate a sentence first.", ""
    if audio_path is None:
        return "Please start, record, then stop the audio recording before trying to transcribe.", ""

    user_transcript = process.run_asr(audio_path, model_id, device_pref)
    # run_asr reports failure by returning the exception object instead of
    # raising it, so it has to be detected by type here.
    if isinstance(user_transcript, Exception):
        return f"Transcription failed: {user_transcript}", ""
    return "", user_transcript


def transcribe_check(audio_path, target_sentence, model_id, device_pref, pass_threshold):
    """Transcribe the recording, score it against the target, and render it.

    Returns:
        ``(user_transcript, score_html, result_html, diff_html)``. When
        validation or transcription fails, the message is placed in
        ``result_html`` and the score and diff are empty strings.
    """
    problem, heard = get_user_transcript(audio_path, target_sentence, model_id, device_pref)
    if problem:
        # Surface the error in the result slot; nothing to score or diff.
        return heard, "", problem, ""

    comparison = process.SentenceMatcher(target_sentence, heard, pass_threshold)
    score_out, result_out, diff_out = make_html(comparison)
    return heard, score_out, result_out, diff_out


# ------------------- Voice cloning gate -------------------
def clone_if_pass(
    audio_path,              # ref voice (the same recorded clip)
    target_sentence,         # sentence user was supposed to say
    user_transcript,         # what ASR heard
    tts_text,                # what we want to synthesize (in cloned voice)
    pass_threshold,          # must meet or exceed this
    tts_model_id,            # e.g., "coqui/XTTS-v2"
    tts_language,            # e.g., "en"
):
    """Clone the user's voice only if they read the target sentence correctly.

    The match is recomputed from ``target_sentence`` and ``user_transcript``
    so the gate does not rely on previously rendered UI state.

    Returns:
        ``(audio, status)``. ``audio`` is ``None`` whenever cloning is
        refused or fails, otherwise a ``(sample_rate, waveform)`` tuple.
        ``status`` is a human-readable message in every case.
    """
    # Basic validations
    if audio_path is None:
        return None, "Record audio first (reference voice is required)."
    if not target_sentence:
        return None, "Generate a target sentence first."
    if not user_transcript:
        return None, "Transcribe first to verify the sentence."
    if not tts_text:
        return None, "Enter the sentence to synthesize."

    # Recompute pass/fail to avoid relying on UI state
    sm = process.SentenceMatcher(target_sentence, user_transcript, pass_threshold)
    if not sm.passed:
        # round() rather than int(): 0.57 * 100 is 56.99999999999999 in
        # floating point and truncation would report the threshold as 56%.
        threshold_pct = round(pass_threshold * 100)
        return None, (
            f"❌ Cloning blocked: your reading did not reach the threshold "
            f"({sm.ratio*100:.1f}% < {threshold_pct}%)."
        )

    # Run zero-shot cloning
    out = tts.run_tts_clone(audio_path, tts_text, model_id=tts_model_id, language=tts_language)
    # run_tts_clone signals failure by returning the exception object.
    if isinstance(out, Exception):
        return None, f"Voice cloning failed: {out}"
    sr, wav = out
    # Gradio Audio can take a tuple (sr, np.array)
    return (sr, wav), f"✅ Cloned and synthesized with {tts_model_id} ({tts_language})."


# ------------------- UI -------------------
# Layout and event wiring for the app. Components are created in display
# order; the event handlers at the bottom connect them to the functions above.
with gr.Blocks(title="Say the Sentence (English)") as demo:
    # Intro text. The two trailing spaces after steps 1-3 are Markdown hard
    # line breaks and are intentional.
    gr.Markdown(
        """
        # 🎤 Say the Sentence (English)
        1) Generate a sentence.  
        2) Record yourself reading it.  
        3) Transcribe & check your accuracy.  
        4) If matched, clone your voice to speak any sentence you enter.
        """
    )

    with gr.Row():
        # Read-only: filled by the Generate button, never typed by the user.
        target = gr.Textbox(label="Target sentence", interactive=False,
                            placeholder="Click 'Generate sentence'")

    with gr.Row():
        btn_gen = gr.Button("🎲 Generate sentence", variant="primary")
        btn_clear = gr.Button("🧹 Clear")

    with gr.Row():
        # type="filepath" means handlers receive a path string (or None).
        audio = gr.Audio(sources=["microphone"], type="filepath",
                         label="Record your voice")

    with gr.Accordion("Advanced settings", open=False):
        model_id = gr.Dropdown(
            choices=[
                "openai/whisper-tiny.en",
                "openai/whisper-base.en",
                "distil-whisper/distil-small.en",
            ],
            value="openai/whisper-tiny.en",
            label="ASR model (English only)",
        )
        device_pref = gr.Radio(
            choices=["auto", "cpu", "cuda"],
            value="auto",
            label="Device preference"
        )
        pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01,
                                   label="Match threshold")

    with gr.Row():
        btn_check = gr.Button("✅ Transcribe & Check", variant="primary")
    with gr.Row():
        user_transcript = gr.Textbox(label="Transcription", interactive=False)
    with gr.Row():
        score_html = gr.Label(label="Score")
        result_html = gr.Label(label="Result")
    diff_html = gr.HTML(
        label="Word-level diff (red = expected but missing / green = extra or replacement)")

    # NOTE(review): the leading glyph of this heading is mis-encoded and its
    # final byte is missing, so the intended emoji cannot be determined from
    # this file — TODO restore it from the original source.
    gr.Markdown("## πŸ” Voice cloning (gated)")
    with gr.Row():
        tts_text = gr.Textbox(
            label="Text to synthesize (voice clone)",
            placeholder="Type the sentence you want the cloned voice to say",
        )
    with gr.Row():
        tts_model_id = gr.Dropdown(
            choices=[
                "coqui/XTTS-v2",
                # add others if you like, e.g., "myshell-ai/MeloTTS"
            ],
            value="coqui/XTTS-v2",
            label="TTS (voice cloning) model",
        )
        tts_language = gr.Dropdown(
            choices=["en", "de", "fr", "es", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh"],
            value="en",
            label="Language",
        )

    with gr.Row():
        # NOTE(review): same mis-encoded glyph as the heading above — TODO
        # restore the intended emoji.
        btn_clone = gr.Button("πŸ” Clone voice (if passed)", variant="secondary")
    with gr.Row():
        tts_audio = gr.Audio(label="Cloned speech output", interactive=False)
        clone_status = gr.Label(label="Cloning status")

    # -------- Events --------
    btn_gen.click(fn=generate.gen_sentence_set, outputs=target)

    # The outputs list must stay positionally aligned with the tuple returned
    # by clear_all (seven text-like fields, then the audio player).
    btn_clear.click(
        fn=clear_all,
        outputs=[target, user_transcript, score_html, result_html, diff_html, tts_text, clone_status, tts_audio]
    )

    btn_check.click(
        fn=transcribe_check,
        inputs=[audio, target, model_id, device_pref, pass_threshold],
        outputs=[user_transcript, score_html, result_html, diff_html]
    )

    # The same recorded clip serves as the reference voice for cloning.
    btn_clone.click(
        fn=clone_if_pass,
        inputs=[audio, target, user_transcript, tts_text, pass_threshold, tts_model_id, tts_language],
        outputs=[tts_audio, clone_status],
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()