File size: 973 Bytes
8f71eda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
from transformers import PreTrainedModel
from .LightweightGANConfig import LightweightGANConfig
from .deploy import Generator
class LightweightGANModel(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper around the lightweight-GAN ``Generator``.

    The generator is built from the hyper-parameters stored on a
    ``LightweightGANConfig`` and kept as ``self.model``. ``forward`` simply
    delegates to it.
    """

    config_class = LightweightGANConfig

    def __init__(self, config):
        """Build the wrapped ``Generator`` from *config*.

        Args:
            config: A ``LightweightGANConfig``. Each ``Generator`` constructor
                argument is read from the config attribute of the same name.
        """
        super().__init__(config)
        self.model = Generator(
            image_size=config.image_size,
            latent_dim=config.latent_dim,
            fmap_max=config.fmap_max,
            fmap_inverse_coef=config.fmap_inverse_coef,
            transparent=config.transparent,
            greyscale=config.greyscale,
            attn_res_layers=config.attn_res_layers,
            freq_chan_attn=config.freq_chan_attn,
            syncbatchnorm=config.syncbatchnorm,
            antialias=config.antialias,
        )

    def forward(self, tensor):
        """Run the wrapped generator on *tensor* and return its output.

        NOTE(review): *tensor* is presumably a batch of latent vectors of size
        ``config.latent_dim`` — confirm against ``Generator.forward``.
        """
        return self.model(tensor)

    def load_params(self, pt_file, *, map_location="cpu", strict=True):
        """Load a ``Generator`` state dict from *pt_file* into ``self.model``.

        The keys in the file must match ``self.model``'s own state-dict keys.
        They must not carry the ``model.`` prefix used by this wrapper's
        state dict.

        Args:
            pt_file: Path or file-like object accepted by ``torch.load``.
            map_location: Where ``torch.load`` places the deserialized
                tensors. Defaults to ``"cpu"`` so checkpoints saved on a GPU
                can be read on CPU-only machines. ``load_state_dict`` then
                copies the values into the existing parameters, which keep
                their current device.
            strict: Forwarded to ``load_state_dict``. When true, missing or
                unexpected keys raise ``RuntimeError``.

        Returns:
            The ``(missing_keys, unexpected_keys)`` named tuple returned by
            ``load_state_dict``.
        """
        # SECURITY: torch.load unpickles the file and can execute arbitrary
        # code; only load trusted checkpoints (on torch >= 1.13 consider
        # weights_only=True when the source is not fully trusted).
        state_dict = torch.load(pt_file, map_location=map_location)
        return self.model.load_state_dict(state_dict, strict=strict)
|