File size: 646 Bytes
9fd08f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import os

import torch
import torch.nn as nn
from transformers import AutoConfig

from DeFTAN2 import DeFTAN2

class DeFTAN2Model(nn.Module):
    """Wrap a ``DeFTAN2`` network in an ``nn.Module`` with a checkpoint loader.

    The inner network is built from a config object and stored as
    ``self.model``; ``forward`` simply delegates to it.
    """

    # File name of the weights checkpoint expected inside the model directory.
    _WEIGHTS_NAME = "deftan2.bin"

    def __init__(self, config):
        """Build the wrapped ``DeFTAN2`` network from ``config``.

        Args:
            config: Configuration object passed straight to ``DeFTAN2``.
        """
        super().__init__()
        self.model = DeFTAN2(config)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

    @classmethod
    def from_pretrained(cls, model_path):
        """Create a model from a directory holding a config and weights.

        Args:
            model_path: Directory given to ``AutoConfig.from_pretrained`` and
                expected to contain the ``deftan2.bin`` weights file.

        Returns:
            A new instance with the checkpoint's state dict loaded. The
            weights are mapped to CPU; the caller moves the model to another
            device if needed.

        Errors raised by the config lookup, the checkpoint read, or
        ``load_state_dict`` propagate unchanged.
        """
        # Load the configuration (per the original note, from config.json).
        config = AutoConfig.from_pretrained(model_path)
        model = cls(config)

        weights_file = os.path.join(model_path, cls._WEIGHTS_NAME)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot execute arbitrary code.
        # The result is only ever used as a state dict, so nothing else is
        # needed from the file.
        state_dict = torch.load(
            weights_file, map_location="cpu", weights_only=True
        )
        # The state dict is loaded into this wrapper (not the bare inner
        # network), so its keys must match this module's own keys.
        model.load_state_dict(state_dict)
        return model