from functools import partial
import gradio as gr
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast


def load_model(model_name, device):
    # Load the selected captioning model and its matching tokenizer from the
    # Hugging Face Hub.
    if model_name == "AudioCaps":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-audiocaps-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/audiocaps-simple-tokenizer"
        )
    elif model_name == "Clotho":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-clotho-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/clotho-simple-tokenizer"
        )
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    return model, tokenizer


def infer(file, runner):
    # gr.Audio yields a (sample_rate, numpy waveform) tuple.
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Normalize integer PCM to floats in [-1, 1].
    if wav.dtype == torch.short:  # int16
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:  # int32
        wav = wav / 2 ** 31
    # Downmix multi-channel audio to mono.
    if wav.ndim > 1:
        wav = wav.mean(1)
    # Resample to the rate the model was trained on.
    wav = resample(wav, sr, runner.target_sr)
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        word_idx = runner.model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
    cap = runner.tokenizer.decode(word_idx, skip_special_tokens=True)
    return cap
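

# Standalone usage sketch (not executed by the app): one way infer() could be
# called outside Gradio, reusing the same model call as above. Assumes
# torchaudio can decode the file; "example.wav" is a hypothetical placeholder.
#
#     import torchaudio
#     runner = InferRunner("AudioCaps")
#     waveform, sample_rate = torchaudio.load("example.wav")  # float32, (channels, samples)
#     caption = infer((sample_rate, waveform.mean(0).numpy()), runner)
#     print(caption)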


# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)


class InferRunner:
    # Holds the active model/tokenizer so Gradio callbacks can swap models.
    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.tokenizer = load_model(model_name, self.device)
        self.target_sr = self.model.config.sample_rate

    def change_model(self, model_name):
        self.model, self.tokenizer = load_model(model_name, self.device)
        self.target_sr = self.model.config.sample_rate


def change_model(radio):
    # Callback for the model-selection radio button.
    global infer_runner
    infer_runner.change_model(radio)


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")
    with gr.Row():
        gr.Markdown("""
        [![arXiv](https://img.shields.io/badge/arXiv-2407.14329-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.14329)
        [![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
        """)
    with gr.Row():
        with gr.Column():
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model"
            )
            infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            radio.change(fn=change_model, inputs=[radio])
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
            btn.click(
                fn=partial(infer, runner=infer_runner),
                inputs=[file],
                outputs=output
            )

demo.launch()
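
# Note: demo.launch() serves on localhost by default; passing share=True
# (demo.launch(share=True)) creates a temporary public link if needed.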