Edit model card

Model Card for MusiLingo-long-v1

Model Details

Model Description

The model consists of a music encoder (MERT-v1-330M), a natural-language decoder (vicuna-7b-delta-v0), and a linear projection layer between the two.

This checkpoint of MusiLingo was developed on the long-answer split of the MusicInstruct (MI) dataset (MI-long). It can answer long-form instructions about raw music audio, such as questions about the subjective feelings a piece evokes. You can use the MI dataset with the following demo.

Model Sources [optional]

Getting Started

from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList



class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence in the batch ends with a stop sequence.

    Args:
        stops: iterable of 1-D token-id tensors. Generation stops as soon as the
            tail of ``input_ids[0]`` equals any one of them. Defaults to no stops.
        encounters: accepted for backward compatibility; stored but not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Use None instead of a mutable default list shared across instances.
        self.stops = stops if stops is not None else []
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when ``input_ids[0]`` ends with any configured stop sequence."""
        generated = input_ids[0]
        for stop in self.stops:
            n = len(stop)
            # An empty stop, or one longer than the generated tail, cannot match.
            # Slicing in those cases would yield a mis-sized tensor that either
            # broadcasts to a bogus result or raises a shape error.
            if n == 0 or n > generated.shape[-1]:
                continue
            # Compare on input_ids' device so CPU-built stop tensors also work.
            if torch.all(stop.to(generated.device) == generated[-n:]).item():
                return True
        return False

def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
        repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000,
        device="cuda"):
    """Generate a text answer for a batch of audio, optionally with instructions.

    ``self`` is the inner MusiLingo model (``model_long.model`` in the demo).
    It must provide ``encode_audio``, ``instruction_prompt_wrap``,
    ``prompt_template``, ``llama_tokenizer`` and ``llama_model``.

    Args:
        samples: batch dict. ``samples["audio"]`` is the audio tensor. If
            ``"instruction_input"`` is present, each instruction is prefixed with
            the ``<Audio><AudioHere></Audio>`` placeholder and formatted with
            ``self.prompt_template``.
        stopping: stopping criteria forwarded to ``llama_model.generate``.
        max_new_tokens, num_beams, min_length, top_p, repetition_penalty,
        length_penalty, temperature: forwarded to ``llama_model.generate``.
            Sampling is always enabled (``do_sample=True``).
        max_length: unused; kept only for backward compatibility.
        device: device the audio tensor is moved to before encoding.

    Returns:
        The decoded text of the first generated sequence. A leading ``<unk>``
        (id 0) and/or ``<s>`` (id 1) is dropped, everything from the first
        ``'###'`` is removed, and anything up to the last ``'Assistant:'`` is
        stripped.
    """
    audio = samples["audio"].to(device)
    audio_embeds, atts_audio = self.encode_audio(audio)
    if 'instruction_input' in samples:  # instruction dataset
        instruction_prompt = [
            self.prompt_template.format('<Audio><AudioHere></Audio> ' + instruction)
            for instruction in samples['instruction_input']
        ]
        audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
    self.llama_tokenizer.padding_side = "right"

    # Prepend a BOS embedding. Build it on the embeddings' device rather than a
    # hard-coded 'cuda', so the concatenation below always lines up.
    batch_size = audio_embeds.shape[0]
    bos = torch.full((batch_size, 1), self.llama_tokenizer.bos_token_id,
                     dtype=torch.long, device=audio_embeds.device)
    bos_embeds = self.llama_model.model.embed_tokens(bos)
    atts_bos = atts_audio[:, :1]
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
    attention_mask = torch.cat([atts_bos, atts_audio], dim=1)

    outputs = self.llama_model.generate(
        inputs_embeds=inputs_embeds,
        # Previously computed but never passed; without it any padded
        # positions in atts_audio would be attended to.
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    output_token = outputs[0]
    # The model may emit <unk> (0) and/or <s> (1) first; strip them. The length
    # guards avoid an IndexError when the generation is empty.
    if len(output_token) > 0 and output_token[0] == 0:
        output_token = output_token[1:]
    if len(output_token) > 0 and output_token[0] == 1:
        output_token = output_token[1:]
    output_text = self.llama_tokenizer.decode(output_token)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text

# Feature extractor for the MERT-v1-330M audio encoder used by MusiLingo.
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)

# NOTE(review): CMIDataset (the MusicInstruct dataset loader) is not defined in
# this snippet; it must be imported from the MusiLingo code repository.
ds = CMIDataset(processor, 'path/to/MI_dataset', 'test', question_type='long')
dl = DataLoader(
    ds,
    batch_size=1,  # answer() only decodes the first sequence of each batch
    num_workers=0,
    pin_memory=True,
    shuffle=False,
    drop_last=True,
    collate_fn=ds.collater,
)

# Stop generating once the model emits one of these token-id sequences.
# Presumably both encode the '###' separator that answer() strips; TODO confirm.
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                torch.tensor([2277, 29937]).cuda()])])

from transformers import AutoModel
# The model repo ships custom modelling code, so trust_remote_code is required.
model_long = AutoModel.from_pretrained("m-a-p/MusiLingo-long-v1", trust_remote_code=True)
# answer() moves its inputs to CUDA, so the model must be on CUDA as well.
model_long.to("cuda")

# Inference only: disable autograd so the audio-encoding pass keeps no activations.
with torch.no_grad():
    for sample in tqdm(dl):
        ans = answer(model_long.model, sample, stopping, length_penalty=100, temperature=0.1)
        txt = sample['text_input'][0]
        print(txt)
        print(ans)

Citing This Work

If you find the work useful for your research, please consider citing it using the following BibTeX entry:

@inproceedings{deng2024musilingo,
  title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
  author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
  booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
  year={2024},
  organization={Association for Computational Linguistics}
}
Downloads last month
38
Safetensors
Model size
7.06B params
Tensor type
F32
·
FP16
·
Inference API
Inference API (serverless) does not yet support model repos that contain custom code.

Collection including m-a-p/MusiLingo-long-v1