import torch
from miditok import TokSequence
from transformers import AutoModel, OPTConfig, OPTForCausalLM


class OPTForMusicGeneration(OPTForCausalLM):
    """OPT causal language model with a helper that decodes generated ids via a tokenizer."""

    def generate_music(self, tokenizer, **kwargs):
        """Generate a token sequence from the BOS token and decode it with *tokenizer*.

        Args:
            tokenizer: Callable that accepts a ``TokSequence`` and returns the
                decoded result (called as ``tokenizer(tok_sequence)``).
            **kwargs: Forwarded unchanged to ``self.generate``.

        Returns:
            Whatever ``tokenizer`` returns for the first generated sequence.

        Raises:
            ValueError: If ``self.config.bos_token_id`` is ``None``, since the
                prompt consists solely of that token.
        """
        bos_token_id = self.config.bos_token_id
        if bos_token_id is None:
            raise ValueError(
                "config.bos_token_id must be set to seed generation"
            )

        # Single-sequence prompt of shape (1, 1) holding only the BOS token,
        # created on the same device as the model.
        prompt_ids = torch.tensor([[bos_token_id]], device=self.device)
        generated_ids = self.generate(prompt_ids, **kwargs)

        # Only the first sequence of the generated batch is decoded.
        # NOTE(review): newer miditok releases appear to name this flag
        # `are_ids_encoded` -- confirm against the pinned miditok version.
        generated_ts = TokSequence(
            ids=generated_ids.tolist()[0],
            ids_bpe_encoded=True,
        )
        return tokenizer(generated_ts)


# Register the class so it can be resolved through the AutoModel auto class.
OPTForMusicGeneration.register_for_auto_class("AutoModel")