#!/usr/bin/env python
# coding=utf-8
# Copyright (c) 2022 PyTorch contributors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions.

import torch.nn as nn

from huggan.pytorch.huggan_mixin import HugGANModelHubMixin


class Generator(nn.Module, HugGANModelHubMixin):
    """Transposed-convolution generator that turns a latent tensor into an image.

    The network is a single ``nn.Sequential`` stored in ``self.model``: four
    ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stages followed by a final
    ``ConvTranspose2d -> Tanh`` stage.

    Args:
        num_channels: Number of channels produced by the last layer.
        latent_dim: Number of channels of the latent input ``Z``.
        hidden_size: Base feature-map width; the stages use 8x, 4x, 2x and 1x
            this value.
    """

    def __init__(self, num_channels=3, latent_dim=100, hidden_size=64):
        super().__init__()

        # Output widths of the four normalised stages, widest first.
        stage_widths = (
            hidden_size * 8,
            hidden_size * 4,
            hidden_size * 2,
            hidden_size,
        )

        layers = []
        in_channels = latent_dim
        for stage, out_channels in enumerate(stage_widths):
            # The first stage takes Z to a 4 x 4 map (stride 1, no padding);
            # every later stage doubles the spatial size (stride 2, padding 1).
            # State sizes after each stage, for a 1 x 1 input:
            #   (hidden_size*8) x 4 x 4, (hidden_size*4) x 8 x 8,
            #   (hidden_size*2) x 16 x 16, (hidden_size) x 32 x 32
            if stage == 0:
                stride, padding = 1, 0
            else:
                stride, padding = 2, 1
            layers += [
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=4,
                    stride=stride,
                    padding=padding,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            in_channels = out_channels

        # Output stage: state size (num_channels) x 64 x 64, squashed by Tanh.
        layers += [
            nn.ConvTranspose2d(
                in_channels,
                num_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, noise):
        """Run ``noise`` through the layer stack and return the generated pixel values."""
        return self.model(noise)


class Discriminator(nn.Module):
    """Strided-convolution discriminator that scores an image.

    The network is a single ``nn.Sequential`` stored in ``self.model``: an
    un-normalised ``Conv2d -> LeakyReLU`` input stage, three
    ``Conv2d -> BatchNorm2d -> LeakyReLU`` stages, and a final
    ``Conv2d -> Sigmoid`` stage with a single output channel.

    Args:
        num_channels: Number of channels of the input image.
        hidden_size: Base feature-map width; the stages use 1x, 2x, 4x and 8x
            this value.
    """

    def __init__(self, num_channels=3, hidden_size=64):
        super().__init__()

        # Input is (num_channels) x 64 x 64; the first stage has no batch norm.
        layers = [
            nn.Conv2d(
                num_channels,
                hidden_size,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]

        # State size entering the loop: (hidden_size) x 32 x 32. Each stage
        # halves the spatial size and widens the channels:
        #   (hidden_size*2) x 16 x 16, (hidden_size*4) x 8 x 8,
        #   (hidden_size*8) x 4 x 4
        in_channels = hidden_size
        for multiplier in (2, 4, 8):
            out_channels = hidden_size * multiplier
            layers += [
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ]
            in_channels = out_channels

        # Collapse the 4 x 4 map to one value per sample and apply Sigmoid.
        layers += [
            nn.Conv2d(
                in_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sigmoid(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, pixel_values):
        """Run ``pixel_values`` through the layer stack and return the Sigmoid output."""
        return self.model(pixel_values)