# Source: Hugging Face Space by wsntxxn — last change "Update arxiv and code link" (commit f729a94)
from pathlib import Path
import argparse
from functools import partial
import gradio as gr
import torch
from torchaudio.functional import resample
import utils.train_util as train_util
def load_model(cfg,
               ckpt_path,
               device):
    """Build the captioning model and its tokenizer from a config and checkpoint.

    Args:
        cfg: parsed configuration dict with "model" and "tokenizer" sections.
        ckpt_path: path to the checkpoint file; loaded onto CPU first, then the
            model is moved to `device`.
        device: torch device the model should live on.

    Returns:
        Tuple ``(model, tokenizer)`` with the model in eval mode on `device`
        and the tokenizer's special-token indices registered on the model.
    """
    checkpoint = torch.load(ckpt_path, "cpu")
    model = train_util.init_model_from_config(cfg["model"])
    train_util.load_pretrained_model(model, checkpoint)
    model = model.to(device).eval()
    tokenizer = train_util.init_obj_from_dict(cfg["tokenizer"])
    # Some tokenizers ship pre-loaded; otherwise restore state from the checkpoint.
    if not tokenizer.loaded:
        tokenizer.load_state_dict(checkpoint["tokenizer"])
    model.set_index(tokenizer.bos, tokenizer.eos, tokenizer.pad)
    return model, tokenizer
def infer(file, runner):
    """Caption one audio clip with beam search and return the decoded string.

    Args:
        file: ``(sample_rate, waveform)`` pair as produced by a Gradio Audio
            component; waveform may be int16/int32 PCM or float, mono or
            multi-channel.
        runner: object exposing ``model``, ``tokenizer``, ``device`` and
            ``target_sr``.

    Returns:
        The first decoded caption for the clip.
    """
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Normalize integer PCM into [-1, 1); float input is assumed pre-scaled.
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Downmix multi-channel audio to mono by averaging across channels.
    if wav.ndim > 1:
        wav = wav.mean(1)
    wav = resample(wav, sr, runner.target_sr)
    num_samples = len(wav)
    batch = wav.float().unsqueeze(0).to(runner.device)
    model_input = dict(
        mode="inference",
        wav=batch,
        wav_len=[num_samples],
        specaug=False,
        sample_method="beam",
        beam_size=3,
    )
    with torch.no_grad():
        model_output = runner.model(model_input)
    token_ids = model_output["seq"].cpu().numpy()
    return runner.tokenizer.decode(token_ids)[0]
# def input_toggle(input_type):
# if input_type == "file":
# return gr.update(visible=True), gr.update(visible=False)
# elif input_type == "mic":
# return gr.update(visible=False), gr.update(visible=True)
class InferRunner:
    """Holds the currently selected captioning model plus its tokenizer,
    device and target sample rate. A single instance is shared by the UI and
    mutated in place when the user switches models."""

    def __init__(self, model_name):
        """Pick a device and load the initial model.

        Args:
            model_name: checkpoint directory name (case-insensitive), looked
                up under ``./checkpoints/``.
        """
        # Prefer GPU when available; the model is moved there in change_model.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Delegate to change_model so the loading logic lives in one place
        # (previously duplicated verbatim between __init__ and change_model).
        self.change_model(model_name)

    def change_model(self, model_name):
        """(Re)load model, tokenizer and target sample rate for `model_name`.

        Expects ``./checkpoints/<model_name lowercased>/config.yaml`` and
        ``ckpt.pth`` to exist.
        """
        exp_dir = Path(f"./checkpoints/{model_name.lower()}")
        cfg = train_util.load_config(exp_dir / "config.yaml")
        self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
        self.target_sr = cfg["target_sr"]
def change_model(radio):
    """Gradio callback: switch the shared runner to the selected model.

    Args:
        radio: the selected model name from the radio component.
    """
    # Mutates the module-level runner in place; nothing is reassigned, so a
    # `global` declaration is unnecessary.
    infer_runner.change_model(radio)
# --- Gradio UI wiring (module-level script): builds the page and launches it. ---
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")
    with gr.Row():
        # Badges linking to the paper and the training code repository.
        gr.Markdown("""
[![arXiv](https://img.shields.io/badge/arXiv-2407.14329-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.14329)
[![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
""")
    with gr.Row():
        with gr.Column():
            # Model selector; its default value determines the initially
            # loaded checkpoint below.
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model"
            )
            # Shared runner instance. change_model mutates it in place, so the
            # `partial` bound below keeps seeing the newly selected model.
            infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            # Swap the loaded model whenever the radio selection changes.
            radio.change(fn=change_model, inputs=[radio,],)
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
        btn.click(
            fn=partial(infer,
                       runner=infer_runner),
            inputs=[file,],
            outputs=output
        )
demo.launch()