File size: 557 Bytes
e07199b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import OPTForCausalLM, OPTConfig, AutoModel
import torch

from miditok import TokSequence


class OPTForMusicGeneration(OPTForCausalLM):
    """OPT causal language model with a helper for unconditional music generation.

    Generation is seeded with a single BOS token and the generated token ids
    are decoded with a miditok tokenizer.
    """

    def generate_music(self, tokenizer, **kwargs):
        """Generate one piece of music starting from the BOS token.

        Args:
            tokenizer: miditok tokenizer used to decode the generated ids; it
                is called with a ``TokSequence`` and its result is returned.
            **kwargs: forwarded to ``self.generate`` (e.g. ``max_new_tokens``,
                ``do_sample``, ``top_p``). An all-ones ``attention_mask`` is
                supplied automatically unless the caller passes one.

        Returns:
            The result of ``tokenizer(TokSequence(...))`` for the first
            generated sequence (presumably a decoded score / MIDI object; the
            exact type depends on the miditok version).

        Raises:
            ValueError: if ``self.config.bos_token_id`` is not set.
        """
        bos_token_id = self.config.bos_token_id
        if bos_token_id is None:
            raise ValueError(
                "config.bos_token_id must be set to seed music generation"
            )

        # Shape (1, 1): a batch of one sequence containing only BOS.
        input_ids = torch.tensor(
            [[bos_token_id]], dtype=torch.long, device=self.device
        )
        # Pass an explicit all-ones mask. Without one, generate() may infer a
        # mask from pad_token_id, which would hide the lone BOS prompt token
        # whenever pad_token_id == bos_token_id.
        kwargs.setdefault("attention_mask", torch.ones_like(input_ids))

        output_ids = self.generate(input_ids, **kwargs)

        # Only the first returned sequence is decoded, even if
        # num_return_sequences > 1 was requested via kwargs.
        # NOTE(review): `ids_bpe_encoded` is the miditok 2.x field name; newer
        # miditok releases appear to rename it — confirm against the pinned version.
        generated_ts = TokSequence(
            ids=output_ids[0].tolist(), ids_bpe_encoded=True
        )
        return tokenizer(generated_ts)


# Register this custom model class with the AutoModel auto class so it can be
# referenced from the saved config when loading with trust_remote_code.
OPTForMusicGeneration.register_for_auto_class(auto_class="AutoModel")