File size: 973 Bytes
8f71eda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from transformers import PreTrainedModel
from .LightweightGANConfig import LightweightGANConfig
from .deploy import Generator


class LightweightGANModel(PreTrainedModel):
    """`PreTrainedModel` wrapper around the `Generator` from `.deploy`.

    The wrapped generator is built from the fields of a
    `LightweightGANConfig` and stored as `self.model`; `forward` simply
    delegates to it.
    """

    # Tells the `PreTrainedModel` machinery which config class pairs with
    # this model.
    config_class = LightweightGANConfig

    def __init__(self, config):
        """Build the wrapped `Generator` from *config*.

        Args:
            config: Configuration object providing `image_size`,
                `latent_dim`, `fmap_max`, `fmap_inverse_coef`,
                `transparent`, `greyscale`, `attn_res_layers`,
                `freq_chan_attn`, `syncbatchnorm` and `antialias`; each is
                forwarded unchanged to `Generator` as a keyword argument.
        """
        super().__init__(config)
        self.model = Generator(
            image_size=config.image_size,
            latent_dim=config.latent_dim,
            fmap_max=config.fmap_max,
            fmap_inverse_coef=config.fmap_inverse_coef,
            transparent=config.transparent,
            greyscale=config.greyscale,
            attn_res_layers=config.attn_res_layers,
            freq_chan_attn=config.freq_chan_attn,
            syncbatchnorm=config.syncbatchnorm,
            antialias=config.antialias,
        )

    def forward(self, tensor):
        """Run the wrapped generator on *tensor* and return its output.

        Args:
            tensor: Input passed straight through to `Generator`.
                NOTE(review): presumably a batch of latent vectors of size
                `config.latent_dim` -- confirm against `Generator.forward`.

        Returns:
            Whatever the wrapped `Generator` returns for *tensor*.
        """
        return self.model(tensor)

    def load_params(self, pt_file):
        """Load generator weights from a file written by `torch.save`.

        Args:
            pt_file: Path or file-like object accepted by `torch.load`,
                expected to contain a state dict compatible with the
                wrapped `Generator`.

        Raises:
            RuntimeError: Propagated from `load_state_dict` when the keys
                or shapes in the file do not match the generator.
        """
        # SECURITY NOTE: `torch.load` is pickle-based; only pass checkpoint
        # files that come from a trusted source.
        #
        # Deserialize onto the CPU so that a checkpoint saved from CUDA
        # tensors can still be read on a machine without a GPU.
        # `load_state_dict` copies the values into the existing parameters,
        # so the device this module currently lives on is left unchanged.
        state_dict = torch.load(pt_file, map_location="cpu")
        self.model.load_state_dict(state_dict)