File size: 3,378 Bytes
6065472
 
 
 
dd3d338
6065472
 
dd3d338
6065472
dd3d338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6065472
 
 
487e498
6065472
 
 
 
 
 
 
 
487e498
6065472
dd3d338
6065472
dd3d338
 
 
 
 
6065472
 
 
 
 
 
 
 
487e498
6065472
487e498
 
dd3d338
 
6065472
487e498
dd3d338
 
6065472
 
487e498
 
 
6065472
487e498
 
f729a94
 
 
 
 
 
 
 
 
487e498
 
 
 
 
 
6065472
487e498
 
 
 
 
 
 
 
 
 
 
 
 
 
6065472
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from functools import partial
import gradio as gr
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast


def load_model(model_name,
               device):
    """Load a pretrained captioning model and its tokenizer onto *device*.

    Args:
        model_name: Which checkpoint to fetch — ``"AudioCaps"`` or
            ``"Clotho"`` (the two choices offered by the UI radio button).
        device: ``torch.device`` the model is moved to.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If *model_name* is not one of the supported names.
            (Previously an unknown name fell through both branches and
            crashed with ``UnboundLocalError`` at the return statement.)
    """
    if model_name == "AudioCaps":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-audiocaps-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/audiocaps-simple-tokenizer"
        )
    elif model_name == "Clotho":
        model = AutoModel.from_pretrained(
            "wsntxxn/effb2-trm-clotho-captioning",
            trust_remote_code=True
        ).to(device)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "wsntxxn/clotho-simple-tokenizer"
        )
    else:
        raise ValueError(f"Unknown model name: {model_name!r}")
    return model, tokenizer


def infer(file, runner):
    """Generate a caption for one audio clip.

    Args:
        file: ``(sample_rate, waveform)`` tuple as produced by ``gr.Audio``;
            the waveform is an array of shape ``(samples,)`` or
            ``(samples, channels)`` — assumed int16/int32 PCM or float,
            per the dtype handling below.
        runner: Object exposing ``model``, ``tokenizer`` and ``target_sr``
            (see ``InferRunner``).

    Returns:
        The decoded caption string.
    """
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Rescale integer PCM samples into [-1, 1).
    if wav.dtype == torch.short:    # int16 PCM
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:    # int32 PCM
        wav = wav / 2 ** 31
    # Convert to float32 *before* downmix/resample: torch.mean rejects
    # integer tensors, so an unexpected PCM width (anything other than
    # int16/int32 handled above) would previously crash here.
    wav = wav.float()
    if wav.ndim > 1:
        # Downmix multichannel audio to mono by averaging channels.
        wav = wav.mean(1)
    wav = resample(wav, sr, runner.target_sr)
    wav_len = len(wav)
    wav = wav.unsqueeze(0)  # add batch dimension -> (1, samples)
    with torch.no_grad():
        word_idx = runner.model(
            audio=wav,
            audio_length=[wav_len]
        )[0]
        cap = runner.tokenizer.decode(word_idx, skip_special_tokens=True)
    return cap

# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)

class InferRunner:
    """Holds the currently selected captioning model, tokenizer and device.

    The UI keeps one shared instance and mutates it in place via
    ``change_model`` when the user switches checkpoints.
    """

    def __init__(self, model_name):
        # Prefer GPU when available; load_model moves the model there.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Delegate to change_model so the loading logic lives in one
        # place (previously these two lines were duplicated verbatim).
        self.change_model(model_name)

    def change_model(self, model_name):
        """Swap in the model/tokenizer for *model_name* ("AudioCaps"/"Clotho")."""
        self.model, self.tokenizer = load_model(model_name, self.device)
        # Sampling rate the model expects; infer() resamples input to it.
        self.target_sr = self.model.config.sample_rate


def change_model(radio):
    """Gradio change-callback: reload the shared runner with the selected model.

    Args:
        radio: The radio button's current value ("AudioCaps" or "Clotho").
    """
    # No `global` declaration needed: we only call a method on the
    # module-level runner, we never rebind the name.
    infer_runner.change_model(radio)


# --- Gradio UI wiring (module-level script) ---
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")

    with gr.Row():
        gr.Markdown("""
            [![arXiv](https://img.shields.io/badge/arXiv-2407.14329-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.14329)
            
            [![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
        """)
    with gr.Row():
        with gr.Column():
            # Model selector; defaults to the AudioCaps checkpoint.
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model"
            )
            # Shared runner, mutated in place by change_model(); the
            # partial bound to the button below therefore always sees
            # the currently selected model.
            infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            # Switching the radio swaps the model inside infer_runner.
            radio.change(fn=change_model, inputs=[radio,],)
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
        # Run button: caption the uploaded/recorded audio.
        btn.click(
            fn=partial(infer,
                       runner=infer_runner),
            inputs=[file,],
            outputs=output
        )
    
    # NOTE(review): launch() is called inside the Blocks context manager;
    # this works, but it is conventionally placed after the `with` block —
    # confirm intended.
    demo.launch()