Edit model card

Model Card for MusiLingo-short-v1

Model Details

Model Description

The model consists of a music encoder (MERT-v1-330M), a natural-language decoder (vicuna-7b-delta-v0), and a linear projection layer between the two.

This MusiLingo checkpoint was trained on the short-answer subset of the MusicInstruct (MI) dataset (MI-short). It answers short instructions about raw music audio, such as questions about tempo, emotion, genre, or tags. You can use the MI dataset with the demo below.

Model Sources [optional]

Getting Started

from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList



class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence ends with any given token-id pattern.

    Args:
        stops: iterable of 1-D token-id tensors. Generation stops when the
            tail of ``input_ids[0]`` equals one of them. Defaults to no
            stop patterns.
        encounters: accepted for backward compatibility; not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid the shared-mutable-default pitfall of ``stops=[]`` and copy the
        # caller's sequence so later mutations of it do not leak in here.
        self.stops = list(stops) if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Return True if ``input_ids[0]`` currently ends with any stop pattern."""
        # Only the first sequence in the batch is inspected.
        generated = input_ids[0]
        for stop in self.stops:
            n = len(stop)
            # An empty pattern, or a sequence shorter than the pattern, cannot
            # match. Without this guard the comparison broadcasts against a
            # shorter (possibly empty) slice, and ``torch.all`` of an empty
            # result is True, so a spurious stop could be reported.
            if n == 0 or len(generated) < n:
                continue
            if torch.all(stop == generated[-n:]).item():
                return True
        return False

def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
        repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000):
    """Generate a text answer for a batch of music audio (plus optional instructions).

    This is a free function that receives the MusiLingo model as ``self``.
    It uses these members of that object: ``encode_audio``,
    ``instruction_prompt_wrap``, ``prompt_template``, ``llama_tokenizer``
    and ``llama_model``.

    Args:
        self: the MusiLingo model object providing the members listed above.
        samples: batch mapping. ``samples["audio"]`` is moved to CUDA and
            encoded. If the key ``"instruction_input"`` is present, each of
            its instructions is wrapped in ``self.prompt_template`` after an
            ``<Audio><AudioHere></Audio>`` placeholder.
        stopping: stopping criteria forwarded to ``llama_model.generate``.
        max_new_tokens, num_beams, min_length, top_p, repetition_penalty,
        length_penalty, temperature: forwarded to ``llama_model.generate``.
            Sampling (``do_sample=True``) is always enabled.
        max_length: unused; kept so existing callers passing it still work.

    Returns:
        The decoded text of the first generated sequence. Any text from the
        first ``'###'`` onward is removed, everything up to the last
        ``'Assistant:'`` is dropped, and surrounding whitespace is stripped.
    """
    audio = samples["audio"].cuda()
    audio_embeds, atts_audio = self.encode_audio(audio)
    if 'instruction_input' in samples:  # instruction dataset
        instruction_prompt = [
            self.prompt_template.format('<Audio><AudioHere></Audio> ' + instruction)
            for instruction in samples['instruction_input']
        ]
        audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
    self.llama_tokenizer.padding_side = "right"

    # Prepend one BOS-token embedding per sample. Build it on the device the
    # audio embeddings already live on, instead of a hard-coded 'cuda'.
    batch_size = audio_embeds.shape[0]
    bos = torch.full(
        (batch_size, 1),
        self.llama_tokenizer.bos_token_id,
        dtype=torch.long,
        device=audio_embeds.device,
    )
    bos_embeds = self.llama_model.model.embed_tokens(bos)
    atts_bos = atts_audio[:, :1]
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
    attention_mask = torch.cat([atts_bos, atts_audio], dim=1)

    outputs = self.llama_model.generate(
        inputs_embeds=inputs_embeds,
        # The mask was previously computed but never passed, so any masked
        # positions in atts_audio were ignored during generation.
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )

    output_token = outputs[0]
    # The model might emit an unknown token <unk> (id 0) and/or a start token
    # <s> (id 1) at the beginning; drop them. Length checks avoid an
    # IndexError on empty or one-token outputs.
    if len(output_token) > 0 and output_token[0] == 0:
        output_token = output_token[1:]
    if len(output_token) > 0 and output_token[0] == 1:
        output_token = output_token[1:]
    output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text

# Demo: answer every question in the MI-short test split and print the
# reference text next to the model's answer.
#
# NOTE: CMIDataset is not defined or imported in this card. It comes from the
# MusiLingo code repository and must be imported from there before running.
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
ds = CMIDataset(processor, 'path/to/MI_dataset', 'test', question_type='short')
dl = DataLoader(
    ds,
    batch_size=1,
    num_workers=0,
    pin_memory=True,
    shuffle=False,
    drop_last=True,
    collate_fn=ds.collater,
)

# Token-id patterns that end generation. These are presumably the LLaMA
# tokenizations of the '###' separator that answer() strips from its output.
stopping = StoppingCriteriaList([StoppingCriteriaSub([
    torch.tensor([835]).cuda(),
    torch.tensor([2277, 29937]).cuda(),
])])

from transformers import AutoModel

# The model repository contains custom modelling code, so trust_remote_code
# is required. answer() places its inputs on CUDA, so the model must be there too.
model_short = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1", trust_remote_code=True).cuda()

# Inference only: avoid building autograd graphs (e.g. in encode_audio).
with torch.no_grad():
    for sample in tqdm(dl):
        # Assumes the remote-code wrapper exposes the underlying MusiLingo
        # model as `.model` (the object answer() expects) -- verify.
        ans = answer(model_short.model, sample, stopping, length_penalty=100, temperature=0.1)
        txt = sample['text_input'][0]
        print(txt)
        print(ans)

Citing This Work

If you find the work useful for your research, please consider citing it using the following BibTeX entry:

@inproceedings{deng2024musilingo,
  title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
  author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
  booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
  year={2024},
  organization={Association for Computational Linguistics}
}
Downloads last month
20
Safetensors
Model size
7.06B params
Tensor type
F32
·
FP16
·
Inference API
Inference API (serverless) does not yet support model repos that contain custom code.

Collection including m-a-p/MusiLingo-short-v1