# MonsterForge-small / MonsterForgeModel.py
# Author: michaelriedl
# Commit 10394e1: "Switched to relative import" (988 bytes)
import torch
from transformers import PreTrainedModel
from .MonsterForgeSmallConfig import MonsterForgeSmallConfig
from .LightweightGAN import Generator
class MonsterForgeModel(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper around a LightweightGAN ``Generator``.

    All generator hyper-parameters are read from a ``MonsterForgeSmallConfig``
    (declared via ``config_class``), so the model plugs into the standard
    ``PreTrainedModel`` config/loading machinery.
    """

    config_class = MonsterForgeSmallConfig

    def __init__(self, config):
        """Build the wrapped generator from ``config``.

        Args:
            config: A ``MonsterForgeSmallConfig``. Each generator-related field
                below is forwarded unchanged to ``Generator``.
        """
        super().__init__(config)
        self.model = Generator(
            image_size=config.image_size,
            latent_dim=config.latent_dim,
            fmap_max=config.fmap_max,
            fmap_inverse_coef=config.fmap_inverse_coef,
            transparent=config.transparent,
            greyscale=config.greyscale,
            attn_res_layers=config.attn_res_layers,
            freq_chan_attn=config.freq_chan_attn,
            syncbatchnorm=config.syncbatchnorm,
            antialias=config.antialias,
        )

    def forward(self, tensor):
        """Run the wrapped generator on ``tensor`` and return its output.

        ``tensor`` is passed straight to ``Generator``; presumably a batch of
        latent vectors of size ``config.latent_dim`` -- TODO confirm against
        ``LightweightGAN.Generator.forward``.
        """
        return self.model(tensor)

    def load_params(self, pt_file, *, map_location="cpu", strict=True):
        """Load raw ``Generator`` weights from a ``torch.save`` checkpoint.

        Args:
            pt_file: Path or file-like object accepted by ``torch.load``. It
                must hold a state dict for the wrapped ``Generator`` itself
                (i.e. keys without a ``model.`` prefix), because it is loaded
                into ``self.model`` rather than ``self``.
            map_location: Where ``torch.load`` materialises the tensors.
                Defaults to ``"cpu"`` so that checkpoints saved on a GPU also
                load on CPU-only machines. ``load_state_dict`` then copies the
                values onto whatever device the parameters already live on.
            strict: Forwarded to ``load_state_dict``. When True, missing or
                unexpected keys raise ``RuntimeError``.

        Returns:
            The ``(missing_keys, unexpected_keys)`` named tuple returned by
            ``load_state_dict``.
        """
        # SECURITY NOTE(review): torch.load unpickles the file, which can run
        # arbitrary code. Only load trusted checkpoints; on torch >= 1.13
        # consider passing weights_only=True.
        state_dict = torch.load(pt_file, map_location=map_location)
        return self.model.load_state_dict(state_dict, strict=strict)