Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,654 Bytes
c962a1e 26065ca 391bf1a 35cff45 6cd713c 0f5d4d0 391bf1a c962a1e 1f8a9e2 da4f293 0c021ae 7f93b4d c962a1e 26065ca 61d0002 c962a1e 391bf1a da4f293 12b8205 da4f293 391bf1a 4dcbad1 0d0c66a 391bf1a c962a1e e4b0eea 6cd713c 9aa9482 5bd4675 c962a1e bda6501 c962a1e cc828b8 c962a1e 9aa9482 391bf1a bda6501 69bc1a2 685aafd 3b2410c 69bc1a2 9aa9482 c962a1e 391bf1a 26065ca 391bf1a fc18a2b 1578762 fc18a2b c962a1e 26065ca 391bf1a fc18a2b 1578762 fc18a2b c962a1e 26065ca 391bf1a 26065ca 364e345 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import os
from math import floor
from typing import Optional
import numpy as np
import spaces
import torch
import gradio as gr
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
# configuration
MODEL_NAME = "kotoba-tech/kotoba-whisper-v2.0"  # Hub repo id passed to `pipeline(model=...)`
BATCH_SIZE = 16  # forwarded as the pipeline's `batch_size`
CHUNK_LENGTH_S = 15  # forwarded as the pipeline's `chunk_length_s` (seconds)
EXAMPLE = "./sample_diarization_japanese.mp3"  # NOTE(review): not referenced in the visible code

# device setting: choose precision, device string and extra model kwargs as one unit
if not torch.cuda.is_available():
    # CPU fallback: full precision, default attention implementation.
    torch_dtype, device, model_kwargs = torch.float32, "cpu", {}
else:
    # GPU: bfloat16 weights and PyTorch scaled-dot-product attention.
    torch_dtype, device, model_kwargs = torch.bfloat16, "cuda", {'attn_implementation': 'sdpa'}
# Build the ASR pipeline once at import time; `get_prediction` and `transcribe`
# reuse this single instance for every request.
# NOTE: trust_remote_code=True lets transformers execute Python code shipped in
# the model repository, so MODEL_NAME must point at a trusted repo.
pipe = pipeline(
    model=MODEL_NAME,
    trust_remote_code=True,
    # where and in what precision the model runs
    device=device,
    torch_dtype=torch_dtype,
    model_kwargs=model_kwargs,
    # long-form handling: split audio into windows and batch them
    chunk_length_s=CHUNK_LENGTH_S,
    batch_size=BATCH_SIZE,
)
def format_time(start: Optional[float], end: Optional[float]):
    """Render a chunk's ``(start, end)`` timestamps as ``"[HH:MM:SS.mmm-> HH:MM:SS.mmm]:"``.

    A ``None`` bound is rendered as the literal ``"complete "``.

    Args:
        start: Chunk start time in seconds, or ``None``.
        end: Chunk end time in seconds, or ``None``.

    Returns:
        The bracketed time-range label used as a line prefix in the transcript.
    """
    def _format_time(seconds: Optional[float]):
        if seconds is None:
            return "complete "
        # Round once to whole milliseconds, then split with integer divmod.
        # This keeps minutes within 0-59 for timestamps >= 1 hour, and keeps
        # the millisecond field at 3 digits (e.g. 5.9996 s -> 00:00:06.000).
        total_ms = int(round(seconds * 1000))
        total_s, m_seconds = divmod(total_ms, 1000)
        total_min, secs = divmod(total_s, 60)
        hours, minutes = divmod(total_min, 60)
        return f'{hours:02}:{minutes:02}:{secs:02}.{m_seconds:03}'
    return f"[{_format_time(start)}-> {_format_time(end)}]:"
@spaces.GPU
def get_prediction(inputs, prompt: Optional[str]):
    """Run the Japanese transcription pipeline and render its timestamped chunks.

    Args:
        inputs: Audio payload handed straight to ``pipe``; ``transcribe`` passes a
            dict with ``"array"`` and ``"sampling_rate"`` keys.
        prompt: Optional text prompt. When non-empty, it is tokenized with
            ``pipe.tokenizer.get_prompt_ids``, moved to ``device``, and passed to
            generation as ``prompt_ids``.

    Returns:
        ``(text, text_timestamped)``: the chunk texts concatenated with no
        separator, and one ``"<format_time label> <text>"`` line per chunk.
    """
    generation_options = {"language": "ja", "task": "transcribe"}
    if prompt:
        prompt_ids = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt')
        generation_options['prompt_ids'] = prompt_ids.to(device)
    result = pipe(inputs, return_timestamps=True, generate_kwargs=generation_options)

    plain_parts = []
    stamped_lines = []
    for segment in result['chunks']:
        plain_parts.append(segment['text'])
        stamped_lines.append(f"{format_time(*segment['timestamp'])} {segment['text']}")
    return "".join(plain_parts), "\n".join(stamped_lines)
def transcribe(inputs: Optional[str], prompt):
    """Load an audio file chosen in the UI and transcribe it.

    Args:
        inputs: Filesystem path produced by ``gr.Audio(type="filepath")``, or
            ``None`` when the user submitted without any audio.
        prompt: Optional text prompt forwarded unchanged to ``get_prediction``.

    Returns:
        The ``(text, text_timestamped)`` pair returned by ``get_prediction``.

    Raises:
        gr.Error: If no audio file was submitted.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    sampling_rate = pipe.feature_extractor.sampling_rate
    with open(inputs, "rb") as f:
        audio_bytes = f.read()
    # Decode via ffmpeg and resample to the rate the feature extractor expects.
    audio = ffmpeg_read(audio_bytes, sampling_rate)
    # Surround the clip with 0.5 s of silence on each side (presumably to give
    # the model some context at the edges -- TODO confirm intent).
    # np.zeros defaults to float64; matching the decoded dtype keeps
    # np.concatenate from upcasting the whole recording to float64.
    pad = np.zeros(int(sampling_rate * 0.5), dtype=audio.dtype)
    audio = np.concatenate([pad, audio, pad])
    return get_prediction({"array": audio, "sampling_rate": sampling_rate}, prompt)
demo = gr.Blocks()

description = (f"Transcribe long-form microphone or audio inputs with the click of a button! Demo uses Kotoba-Whisper "
               f"checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio"
               f" files of arbitrary length.")
title = f"Transcribe Audio with {os.path.basename(MODEL_NAME)}"


def _build_interface(audio_input):
    """Wrap ``transcribe`` in a ``gr.Interface`` fed by *audio_input* plus a prompt box.

    Both tabs share every setting except the audio component, so building them
    here keeps the microphone and file-upload tabs from drifting apart.
    """
    return gr.Interface(
        fn=transcribe,
        inputs=[
            audio_input,
            gr.Textbox(lines=1, placeholder="Prompt"),
        ],
        # plain transcript, then the timestamped transcript
        outputs=["text", "text"],
        title=title,
        description=description,
        allow_flagging="never",
    )


mf_transcribe = _build_interface(gr.Audio(sources="microphone", type="filepath"))
file_transcribe = _build_interface(gr.Audio(sources="upload", type="filepath", label="Audio file"))

with demo:
    gr.TabbedInterface([mf_transcribe, file_transcribe], ["Microphone", "Audio file"])

# Enable the request queue with the API page hidden, then start the server.
demo.queue(api_open=False, default_concurrency_limit=40).launch(show_api=False, show_error=True)
|