
Usage

import torch
from transformers import AutoModel, PreTrainedTokenizerFast
import torchaudio


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModel.from_pretrained(
    "wsntxxn/cnn14rnn-tempgru-audiocaps-captioning",
    trust_remote_code=True
).to(device)
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "wsntxxn/audiocaps-simple-tokenizer"
)

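wav, sr = torchaudio.load("/path/to/file.wav")  # shape: (channels, samples)
wav = torchaudio.functional.resample(wav, sr, model.config.sample_rate)
if wav.size(0) > 1:
    # Downmix multi-channel audio to mono, keeping shape (1, samples)
    wav = wav.mean(0).unsqueeze(0)
wav = wav.to(device)  # put the input on the same device as the model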

with torch.no_grad():
    word_idxs = model(
        audio=wav,
        audio_length=[wav.size(1)],
    )

caption = tokenizer.decode(word_idxs[0], skip_special_tokens=True)
print(caption)

Without a temporal tag, the model describes temporal relationships as specifically as possible.
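Because `audio_length` takes a list, the model may be able to caption several clips in one call. The sketch below zero-pads a list of mono clips into a single `(batch, samples)` tensor and collects their true lengths. It assumes, without confirmation from this card, that the remote model code accepts zero-padded batches in this layout. `pad_waveforms` is a hypothetical helper, not part of the model's API.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_waveforms(waveforms):
    """Zero-pad a list of mono waveforms into one batch tensor (hypothetical helper).

    Each element should be mono with shape (1, num_samples) or (num_samples,),
    e.g. the output of the preprocessing above. Returns a (batch, max_samples)
    tensor plus the list of true lengths, in the same list-of-ints form the
    example passes as `audio_length`.
    """
    flat = [w.reshape(-1) for w in waveforms]  # drop the channel dim of mono clips
    lengths = [w.size(0) for w in flat]
    batch = pad_sequence(flat, batch_first=True, padding_value=0.0)
    return batch, lengths


# Dummy signals for illustration; replace with real, resampled mono clips.
clips = [torch.randn(1, 16000), torch.randn(1, 8000)]
audio, audio_length = pad_waveforms(clips)
print(audio.shape, audio_length)  # torch.Size([2, 16000]) [16000, 8000]

# Assumption: the remote model code handles padded batches like this.
# with torch.no_grad():
#     word_idxs = model(audio=audio.to(device), audio_length=audio_length)
# captions = tokenizer.batch_decode(word_idxs, skip_special_tokens=True)
```

Passing the true lengths alongside the padded tensor is what would let the model ignore the zero padding; whether it actually does so depends on the remote code.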

You can also manually assign a temporal tag to control how specifically temporal relationships are described:

with torch.no_grad():
    word_idxs = model(
        audio=wav,
        audio_length=[wav.size(1)],
        temporal_tag=[2],  # describe events as "sequential" if sequential events exist; otherwise use the most complex relationship present
    )

The temporal tag is defined as:

| Temporal tag | Definition |
|---|---|
| 0 | Only 1 event |
| 1 | Simultaneous events |
| 2 | Sequential events |
| 3 | More complex events |
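The tag table can be mirrored in code to catch invalid values before calling the model. This is a minimal sketch: `TEMPORAL_TAGS` and `temporal_tags_for_batch` are hypothetical names, and it assumes, based only on the single-clip example passing `[2]`, that `temporal_tag` expects one integer per clip in the batch.

```python
# Hypothetical mirror of the temporal tag table; not part of the model's API.
TEMPORAL_TAGS = {
    0: "Only 1 event",
    1: "Simultaneous events",
    2: "Sequential events",
    3: "More complex events",
}


def temporal_tags_for_batch(tag, batch_size):
    """Return [tag] * batch_size after checking the tag is a known value.

    Assumes `temporal_tag` takes one integer per clip, as in the
    single-clip example that passes `temporal_tag=[2]`.
    """
    if tag not in TEMPORAL_TAGS:
        raise ValueError(f"temporal tag must be one of {sorted(TEMPORAL_TAGS)}, got {tag!r}")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    return [tag] * batch_size


print(temporal_tags_for_batch(2, 3))  # [2, 2, 2]
```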

Citation

If you find the model useful, please cite this paper:

@inproceedings{xie2023enhance,
    author = {Zeyu Xie and Xuenan Xu and Mengyue Wu and Kai Yu},
    title = {Enhance Temporal Relations in Audio Captioning with Sound Event Detection},
    year = 2023,
    booktitle = {Proc. INTERSPEECH},
    pages = {4179--4183},
}
Model size: 101M params (Safetensors, F32)