|
import torch
|
|
from transformers import PreTrainedModel
|
|
from .LightweightGANConfig import LightweightGANConfig
|
|
from .deploy import Generator
|
|
|
|
|
|
class LightweightGANModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the lightweight-GAN ``Generator``.

    Every generator hyper-parameter is read from a ``LightweightGANConfig``;
    the wrapped network is stored as ``self.model`` and ``forward`` simply
    delegates to it.
    """

    config_class = LightweightGANConfig

    def __init__(self, config):
        """Build the wrapped ``Generator`` from *config*.

        Args:
            config: A ``LightweightGANConfig`` supplying each constructor
                argument forwarded to ``Generator`` below.
        """
        super().__init__(config)
        self.model = Generator(
            image_size=config.image_size,
            latent_dim=config.latent_dim,
            fmap_max=config.fmap_max,
            fmap_inverse_coef=config.fmap_inverse_coef,
            transparent=config.transparent,
            greyscale=config.greyscale,
            attn_res_layers=config.attn_res_layers,
            freq_chan_attn=config.freq_chan_attn,
            syncbatchnorm=config.syncbatchnorm,
            antialias=config.antialias,
        )

    def forward(self, tensor):
        """Run the wrapped generator on *tensor* and return its output unchanged.

        NOTE(review): *tensor* is presumably a batch of latent vectors of
        width ``config.latent_dim`` — confirm against ``Generator.forward``.
        """
        return self.model(tensor)

    def load_params(self, pt_file, map_location="cpu"):
        """Load a raw ``Generator`` state dict from *pt_file* into ``self.model``.

        Args:
            pt_file: Path or file-like object accepted by ``torch.load``; it
                must deserialize to a state dict whose keys match ``Generator``.
            map_location: Where ``torch.load`` materializes the loaded tensors.
                Defaults to ``"cpu"`` so a checkpoint saved on a GPU can be
                loaded on a CPU-only host. ``load_state_dict`` copies the
                values into the module's existing parameters, so the model
                stays on whatever device it is already on.

        Raises:
            RuntimeError: Raised by ``load_state_dict`` when the checkpoint's
                keys or tensor shapes do not match the generator.
        """
        # SECURITY NOTE(review): torch.load unpickles the file, which can run
        # arbitrary code. Only load trusted checkpoints, or pass
        # weights_only=True once the minimum supported torch version allows it.
        state_dict = torch.load(pt_file, map_location=map_location)
        self.model.load_state_dict(state_dict)
|
|
|