from functools import partial

import gradio as gr
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast
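

# Gradio demo for lightweight EfficientNet-B2 + Transformer audio captioning.
# Two checkpoints are available on the Hugging Face Hub: one trained on
# AudioCaps and one trained on Clotho.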
def load_model(model_name, device):
    if model_name == "AudioCaps":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-audiocaps-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/audiocaps-simple-tokenizer"
        )
    elif model_name == "Clotho":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-clotho-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/clotho-simple-tokenizer"
        )
    else:
        raise ValueError(f"Unknown model name: {model_name!r}")
    return model, tokenizer


def infer(file, runner):
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Integer PCM input: scale to [-1, 1] floats.
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Downmix multi-channel recordings to mono.
    if wav.ndim > 1:
        wav = wav.mean(1)
    wav = resample(wav, sr, runner.target_sr)
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        word_idx = runner.model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
    cap = runner.tokenizer.decode(word_idx, skip_special_tokens=True)
    return cap
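

# Note: with its default type="numpy", gr.Audio supplies the recording as a
# (sample_rate, numpy_array) tuple, which is why infer() unpacks it into
# sr and wav.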


class InferRunner:
    """Holds the active model/tokenizer pair and its target sample rate."""

    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.tokenizer = load_model(model_name, self.device)
        self.target_sr = self.model.config.sample_rate

    def change_model(self, model_name):
        self.model, self.tokenizer = load_model(model_name, self.device)
        self.target_sr = self.model.config.sample_rate


def change_model(radio):
    # Radio-button callback: swap the checkpoint held by the shared runner.
    global infer_runner
    infer_runner.change_model(radio)


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")

    with gr.Row():
        gr.Markdown("""
        [![arXiv](https://img.shields.io/badge/arXiv-2407.14329-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.14329)

        [![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
        """)

    with gr.Row():
        with gr.Column():
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model"
            )
            infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            # Swap checkpoints in place when the user picks the other dataset.
            radio.change(fn=change_model, inputs=[radio])
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
        btn.click(
            fn=partial(infer, runner=infer_runner),
            inputs=[file],
            outputs=output
        )
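
# Note: demo.launch(share=True) would expose a temporary public URL when
# running locally; the default launch() serves on localhost.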
demo.launch()