# SpeechLine / app.py
# w11wo, commit 798ba23: "Increased Number of Components"
import os
import re
import shutil
from pathlib import Path

import gradio as gr
import pandas as pd
from datasets import Audio, Dataset

from speechline.segmenters import SilenceSegmenter, WordOverlapSegmenter
from speechline.transcribers import Wav2Vec2Transcriber
from speechline.utils.tokenizer import WordTokenizer
# How many (transcript table, audio player) output pairs the UI pre-builds;
# run() always returns exactly this many pairs, hiding the unused ones.
MAX_SEGMENTS: int = 100

# Scratch directory that receives the exported segment .tsv/.wav files.
# run() deletes it (if present) before writing a fresh set of segments.
OUTPUT_DIR: str = "tmp"
def segmentation_interface(choice: str):
    """Toggle which segmentation-parameter widget is visible.

    Used as the handler of the "Segmentation Method" radio's ``change`` event,
    whose outputs are ``[silence-duration slider, ground-truth textbox]``.

    Args:
        choice: The selected segmentation method, ``"Silence Gap"`` or
            ``"Word Overlap"``.

    Returns:
        A pair of ``gr.update`` objects. The slider is visible only for
        ``"Silence Gap"`` and the ground-truth textbox only for
        ``"Word Overlap"``. Any other value hides both, so the handler always
        yields two updates instead of implicitly returning ``None``, which
        Gradio cannot map onto the two output components.
    """
    show_silence_slider = choice == "Silence Gap"
    show_ground_truth = choice == "Word Overlap"
    return gr.update(visible=show_silence_slider), gr.update(visible=show_ground_truth)
def _natural_sort_key(path):
    """Return a sort key that compares runs of digits in *path* numerically.

    ``re.split`` with a capturing group alternates non-digit text (even
    indices) and digit runs (odd indices), so every key holds the same type
    at each position and e.g. ``"seg-2"`` sorts before ``"seg-10"``.
    """
    parts = re.split(r"(\d+)", str(path))
    return [int(part) if index % 2 else part for index, part in enumerate(parts)]


def _collect_segments(output_dir):
    """Pair every exported ``.wav`` segment with its same-stem ``.tsv`` file.

    Returns a list of ``(tsv_path or None, wav_path)`` tuples in natural
    (numeric-aware) path order. ``None`` marks a segment with no transcript.
    """
    wav_paths = sorted(
        (path for path in Path(output_dir).rglob("*") if path.suffix == ".wav"),
        key=_natural_sort_key,
    )
    pairs = []
    for wav_path in wav_paths:
        tsv_path = wav_path.with_suffix(".tsv")
        pairs.append((tsv_path if tsv_path.is_file() else None, wav_path))
    return pairs


def run(audio_path, model, segmentation_type, silence_duration, ground_truth):
    """Transcribe an audio file, segment it, and build the UI output updates.

    The audio is decoded at the transcriber's sampling rate and transcribed
    with word offsets. It is then split by the chosen segmenter, which writes
    per-segment ``.tsv`` transcripts and ``.wav`` clips into ``OUTPUT_DIR``.

    Args:
        audio_path: Path of the uploaded audio file (the ``gr.Audio`` input
            uses ``type="filepath"``).
        model: Checkpoint name passed to ``Wav2Vec2Transcriber``.
        segmentation_type: ``"Silence Gap"`` or ``"Word Overlap"``.
        silence_duration: Slider value, forwarded as ``silence_duration`` to
            ``chunk_audio_segments``.
        ground_truth: Free text. It is tokenized with ``WordTokenizer`` and
            forwarded as ``ground_truth``. Both this and ``silence_duration``
            are always passed; presumably each segmenter only uses its own
            parameter (TODO confirm against speechline).

    Returns:
        A flat list of exactly ``2 * MAX_SEGMENTS`` Gradio updates, alternating
        (transcript ``Dataframe``, segment ``Audio``), with unused pairs hidden.

    Raises:
        ValueError: If ``segmentation_type`` is not a known method.
    """
    segmenter_classes = {
        "Silence Gap": SilenceSegmenter,
        "Word Overlap": WordOverlapSegmenter,
    }
    if segmentation_type not in segmenter_classes:
        # Fail fast with a clear message rather than an UnboundLocalError below.
        raise ValueError(f"Unknown segmentation method: {segmentation_type!r}")

    transcriber = Wav2Vec2Transcriber(model)
    dataset = Dataset.from_dict({"audio": [audio_path]})
    dataset = dataset.cast_column(
        "audio", Audio(sampling_rate=transcriber.sampling_rate)
    )
    output_offsets = transcriber.predict(dataset, output_offsets=True)

    segmenter = segmenter_classes[segmentation_type]()
    tokenizer = WordTokenizer()

    # Start from an empty directory so only this request's segments are listed.
    # NOTE(review): OUTPUT_DIR is a single shared path; if requests are ever
    # handled concurrently they can clobber each other's files.
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    segmenter.chunk_audio_segments(
        audio_path,
        OUTPUT_DIR,
        output_offsets[0],
        minimum_chunk_duration=0,
        silence_duration=silence_duration,
        ground_truth=tokenizer(ground_truth),
    )

    # The UI has exactly MAX_SEGMENTS (table, audio) component pairs, so any
    # further segments cannot be shown. Trimming keeps the returned value
    # count equal to the component count.
    segments = _collect_segments(OUTPUT_DIR)[:MAX_SEGMENTS]

    outputs = []
    for tsv_path, wav_path in segments:
        if tsv_path is None:
            # Keep the (table, audio) alignment even when a transcript is missing.
            outputs.append(gr.Dataframe.update(visible=False))
        else:
            transcript = pd.read_csv(
                tsv_path, sep="\t", names=["start_offset", "end_offset", "text"]
            )
            outputs.append(gr.Dataframe.update(value=transcript, visible=True))
        outputs.append(gr.Audio.update(value=str(wav_path), visible=True))

    # Hide the remaining pre-built components left over from earlier runs.
    for _ in range(MAX_SEGMENTS - len(segments)):
        outputs += [gr.Dataframe.update(visible=False), gr.Audio.update(visible=False)]
    return outputs
# Build the demo UI: inputs on the left, MAX_SEGMENTS pre-built
# (transcript table, audio player) output pairs on the right.
with gr.Blocks() as demo:
    # Blank lines around the markdown are required: without them the heading
    # and links stay inside the <center> HTML block and are not parsed as
    # markdown. The emoji is the intended U+1F399 U+FE0F, repaired from
    # mis-decoded UTF-8.
    gr.Markdown(
        """
<center>

# 🎙️ SpeechLine Demo

[Repository](https://github.com/bookbot-kids/speechline) | [Documentation](https://bookbot-kids.github.io/speechline/)

</center>
"""
    )
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath")
            model = gr.Dropdown(
                choices=[
                    "facebook/wav2vec2-base-960h",
                ],
                value="facebook/wav2vec2-base-960h",
                label="Transcriber Model",
            )
            segmenter = gr.Radio(
                choices=["Silence Gap", "Word Overlap"],
                value="Silence Gap",
                label="Segmentation Method",
            )
            # Initial visibility matches the default "Silence Gap" choice;
            # segmentation_interface toggles these two when the radio changes.
            sil = gr.Slider(
                0, 1, value=0.1, step=0.1, label="Silence Duration", visible=True
            )
            gt = gr.Textbox(
                label="Ground Truth",
                placeholder="Enter Ground Truth Text",
                interactive=True,
                visible=False,
            )
            segmenter.change(
                fn=segmentation_interface, inputs=segmenter, outputs=[sil, gt]
            )
            # Order must match run()'s positional parameters.
            inputs = [audio, model, segmenter, sil, gt]
            transcribe_btn = gr.Button("Transcribe")
        with gr.Column():
            # Only the first pair starts visible; run() shows or hides each
            # pair and always returns exactly 2 * MAX_SEGMENTS updates.
            outputs = [
                gr.Dataframe(
                    visible=True, headers=["start_offset", "end_offset", "text"]
                ),
                gr.Audio(visible=True),
            ]
            for _ in range(MAX_SEGMENTS - 1):
                outputs += [gr.Dataframe(visible=False), gr.Audio(visible=False)]
    transcribe_btn.click(fn=run, inputs=inputs, outputs=outputs)

demo.launch()