File size: 336 Bytes
34cfbe5
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

from transformers import PreTrainedModel
from .mgpt_config import MGPTConfig
from .gpt_model import GPT

class MusicModel(PreTrainedModel):
    """Wrap the project's ``GPT`` network as a ``transformers`` ``PreTrainedModel``.

    The wrapper adds no computation of its own. It builds a ``GPT`` from the
    configuration it is given and forwards calls straight to it.
    """

    # Configuration class paired with this model. transformers reads this
    # attribute when loading and instantiating configs for the model.
    config_class = MGPTConfig

    def __init__(self, config):
        """Initialise the base class, then build the wrapped ``GPT`` from *config*.

        Args:
            config: The model configuration. Presumably an ``MGPTConfig``
                instance (per ``config_class``); this is not checked here.
        """
        # The base initialiser must run before any attribute is assigned.
        super().__init__(config)
        network = GPT(config)
        self.model = network

    def forward(self, inputs):
        """Pass *inputs* to the wrapped ``GPT`` and return its result as-is."""
        network = self.model
        return network(inputs)