import torch
from transformers import PreTrainedModel

from .LightweightGANConfig import LightweightGANConfig
from .deploy import Generator


class LightweightGANModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``.deploy.Generator``.

    The generator is constructed from the hyper-parameters stored on a
    ``LightweightGANConfig`` and is exposed as ``self.model``.
    """

    config_class = LightweightGANConfig

    def __init__(self, config):
        """Build the wrapped generator from *config*.

        Args:
            config: a ``LightweightGANConfig``. Each ``Generator`` keyword
                argument below is read from the config attribute of the same
                name.
        """
        super().__init__(config)
        self.model = Generator(
            image_size=config.image_size,
            latent_dim=config.latent_dim,
            fmap_max=config.fmap_max,
            fmap_inverse_coef=config.fmap_inverse_coef,
            transparent=config.transparent,
            greyscale=config.greyscale,
            attn_res_layers=config.attn_res_layers,
            freq_chan_attn=config.freq_chan_attn,
            syncbatchnorm=config.syncbatchnorm,
            antialias=config.antialias,
        )

    def forward(self, tensor):
        """Run the wrapped generator on *tensor* and return its output.

        NOTE(review): *tensor* is presumably a batch of latent vectors of
        size ``config.latent_dim``. Confirm against ``Generator.forward``.
        """
        return self.model(tensor)

    def load_params(self, pt_file, map_location="cpu"):
        """Load a generator ``state_dict`` from *pt_file* into ``self.model``.

        Args:
            pt_file: path or file-like object accepted by ``torch.load``.
            map_location: forwarded to ``torch.load``. Defaults to ``"cpu"``
                so a checkpoint saved from GPU tensors can still be read on a
                CPU-only machine. ``load_state_dict`` then copies the values
                into the generator's existing parameters. Pass ``None`` to
                restore tensors to the devices recorded in the file (the
                previous behaviour).
        """
        # SECURITY: torch.load unpickles the file, so only load checkpoints
        # from trusted sources. On torch >= 1.13, weights_only=True restricts
        # what can be deserialized.
        state_dict = torch.load(pt_file, map_location=map_location)
        self.model.load_state_dict(state_dict)