# MonsterForge-medium / LightweightGANModel.py
# Uploaded by michaelriedl — commit 8f71eda ("Initial dump"), 973 bytes.
import torch
from transformers import PreTrainedModel
from .LightweightGANConfig import LightweightGANConfig
from .deploy import Generator
class LightweightGANModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the Lightweight GAN ``Generator``.

    The generator is built entirely from fields of a ``LightweightGANConfig``
    and is exposed as ``self.model``; ``forward`` simply delegates to it.
    """

    config_class = LightweightGANConfig

    def __init__(self, config):
        """Build the wrapped generator from ``config``.

        Args:
            config: A ``LightweightGANConfig``. Each ``Generator`` keyword
                argument below is read from the config attribute of the
                same name.
        """
        super().__init__(config)
        self.model = Generator(
            image_size=config.image_size,
            latent_dim=config.latent_dim,
            fmap_max=config.fmap_max,
            fmap_inverse_coef=config.fmap_inverse_coef,
            transparent=config.transparent,
            greyscale=config.greyscale,
            attn_res_layers=config.attn_res_layers,
            freq_chan_attn=config.freq_chan_attn,
            syncbatchnorm=config.syncbatchnorm,
            antialias=config.antialias,
        )

    def forward(self, tensor):
        """Run the wrapped generator on ``tensor`` and return its output.

        NOTE(review): ``tensor`` is presumably a batch of latent vectors of
        size ``config.latent_dim`` — confirm against ``deploy.Generator``.
        """
        return self.model(tensor)

    def load_params(self, pt_file, map_location="cpu", strict=True):
        """Load generator weights from a state dict saved with ``torch.save``.

        Args:
            pt_file: Path or file-like object accepted by ``torch.load``,
                holding a state dict for ``self.model``.
            map_location: Device that ``torch.load`` deserializes tensors onto
                before they are copied into the model. Defaults to ``"cpu"``
                so checkpoints saved from a GPU also load on CPU-only
                machines; ``load_state_dict`` then copies the values into the
                model's own parameters.
            strict: Forwarded to ``load_state_dict``; when True, the keys in
                the file must match the model's keys exactly.

        Returns:
            The ``(missing_keys, unexpected_keys)`` result of
            ``load_state_dict``.
        """
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code — only load checkpoints from trusted sources.
        state_dict = torch.load(pt_file, map_location=map_location)
        return self.model.load_state_dict(state_dict, strict=strict)