# GL-LCM / codes / model.py
# Uploaded by diaoquesang ("Upload 29 files"), commit 6434535 (verified).
from modules.unet import UNetModel
from generative.networks.nets import VQVAE
from config import config
# UNet that runs at the reduced resolution config.image_size // config.r.
# NOTE(review): presumably config.r is the VQVAE's total downsampling factor,
# so this net works in the autoencoder's latent space — confirm in config.
myUnet = UNetModel(
    # Floor division keeps this an integer size. True division would give a
    # float such as 64.0, and a fractional value for non-divisible sizes.
    # Strided convolutions also floor, so `//` matches the real reduced size.
    image_size=config.image_size // config.r,
    model_channels=128,
    # NOTE(review): 8 == 2 * myVQGANModel's embedding_dim (4). This is presumably
    # a noisy latent concatenated with a conditioning latent; confirm against
    # the training/sampling code.
    in_channels=8,
    out_channels=8,
    num_res_blocks=8,
    num_heads=8,
    # NOTE(review): in guided-diffusion-style UNets these values are matched
    # against the downsample rate, not the pixel resolution. Confirm against
    # modules.unet.
    attention_resolutions=(64, 32, 16, 8),
    # NOTE(review): -1 presumably means "derive from num_heads" (guided-diffusion
    # convention) — confirm in modules.unet.
    num_heads_upsample=-1,
    num_head_channels=-1,
    resblock_updown=True,
    channel_mult=(1, 1, 2, 2, 4, 4),
    use_scale_shift_norm=True,
    use_new_attention_order=True,
)
# 2-D vector-quantised autoencoder.
# - Input and output are single-channel.
# - Three encoder/decoder stages with widths 128/256/512.
# - The codebook has 1024 entries, each a 4-dim vector.
# NOTE(review): named "VQGAN" although the class is VQVAE. Presumably the
# adversarial part lives in the training code — confirm.
myVQGANModel = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    # Codebook size and dimensionality of each embedding vector.
    num_embeddings=1024,
    embedding_dim=4,
    # Encoder/decoder widths: one entry per stage.
    num_channels=(128, 256, 512),
    num_res_channels=512,
    num_res_layers=2,
    # Every stage uses the same convolution settings, so the per-stage tuple is
    # repeated once per entry of num_channels.
    # NOTE(review): the MONAI GenerativeModels docs describe these fields:
    # - downsampling: (stride, kernel_size, dilation, padding)
    # - upsampling: the same, plus output_padding
    # Confirm for the installed version.
    downsample_parameters=((2, 4, 1, 1),) * 3,
    upsample_parameters=((2, 4, 1, 1, 0),) * 3,
)
if __name__ == "__main__":
    # Print the total element count over all parameter tensors of each network.
    # Each label keeps the "Number of model parameters" prefix and adds the
    # network name, so the two lines can be told apart.
    print(
        "Number of model parameters (UNet):",
        sum(p.numel() for p in myUnet.parameters()),
    )
    print(
        "Number of model parameters (VQGAN):",
        sum(p.numel() for p in myVQGANModel.parameters()),
    )