"""Expose the project's ``GPT`` network as a Hugging Face ``PreTrainedModel``."""

from transformers import PreTrainedModel

from .mgpt_config import MGPTConfig
from .gpt_model import GPT


class MusicModel(PreTrainedModel):
    """Thin ``PreTrainedModel`` adapter that delegates all work to ``GPT``.

    Subclassing ``PreTrainedModel`` lets the wrapped network use the
    ``transformers`` save/load machinery. The actual computation is
    performed by the inner ``self.model`` instance.
    """

    # Tells transformers which configuration class to use when this model is
    # instantiated via ``from_pretrained``.
    config_class = MGPTConfig

    def __init__(self, config: MGPTConfig):
        """Build the wrapper and the underlying ``GPT`` network.

        Args:
            config: Model configuration. It is passed both to the
                ``PreTrainedModel`` base (which stores it as ``self.config``)
                and to the ``GPT`` constructor.
        """
        super().__init__(config)
        # The real network; registered as a submodule so its parameters are
        # part of this model's state_dict.
        self.model = GPT(config)

    def forward(self, *args, **kwargs):
        """Run the wrapped ``GPT`` model.

        All positional and keyword arguments are forwarded unchanged, and
        whatever ``GPT`` returns is returned as-is.
        """
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate generation to ``GPT.generate``.

        All arguments are forwarded unchanged and the result is returned
        as-is.

        NOTE(review): this replaces any ``generate`` inherited through
        ``PreTrainedModel`` (e.g. ``GenerationMixin.generate``), so the
        standard transformers generation kwargs only work if ``GPT.generate``
        accepts them. Confirm against ``GPT.generate``'s signature.
        """
        return self.model.generate(*args, **kwargs)