File size: 448 Bytes
34cfbe5
 
 
 
 
 
 
 
 
 
 
 
f5d45bf
 
8204068
f5d45bf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

from transformers import PreTrainedModel
from .mgpt_config import MGPTConfig
from .gpt_model import GPT

class MusicModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``GPT``.

    The wrapper exists so that a ``GPT`` built from an ``MGPTConfig`` can be
    used through the ``transformers`` model API (``config_class`` lets
    ``from_pretrained`` build the right config type). All real work is
    delegated to the wrapped ``GPT`` instance stored in ``self.model``.
    """

    # Tells transformers which config class to instantiate/expect for this model.
    config_class = MGPTConfig

    def __init__(self, config):
        """Build the wrapper and the underlying ``GPT`` from ``config``.

        Args:
            config: An ``MGPTConfig`` (per ``config_class``), passed both to the
                ``PreTrainedModel`` base and to ``GPT``.
        """
        super(MusicModel, self).__init__(config)
        # NOTE(review): if GPT is an nn.Module, its parameters are registered
        # under the "model." prefix; keep this attribute name stable so that
        # existing checkpoints keep loading.
        self.model = GPT(config)

    def forward(self, *args, **kwargs):
        """Run the wrapped ``GPT`` with the arguments passed through unchanged.

        Returns:
            Whatever ``GPT.__call__`` returns; the wrapper adds no processing.
        """
        backbone = self.model
        return backbone(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate generation to ``GPT.generate``.

        This shadows any ``generate`` inherited from the transformers base
        classes, so decoding is governed entirely by the project's ``GPT``.

        Returns:
            Whatever ``GPT.generate`` returns.
        """
        backbone = self.model
        return backbone.generate(*args, **kwargs)