# from https://huggingface.co/spaces/hysts/StyleGAN3/blob/main/model.py import pathlib import pickle import sys import numpy as np import torch import torch.nn as nn from huggingface_hub import hf_hub_download import torch import torchvision.utils as vutils import matplotlib.pyplot as plt from io import BytesIO from PIL import Image current_dir = pathlib.Path(__file__).parent submodule_dir = current_dir / "stylegan3" sys.path.insert(0, submodule_dir.as_posix()) user = "ellemac" dcgan_z_dim = 100 dcgan_gen_feats = 64 ngf = 64 dcgan_img_size = 64 nc = 3 # class Generator(nn.Module): # def __init__(self, ngpu, nz): # super(Generator, self).__init__() # self.ngpu = ngpu # self.main = nn.Sequential( # # input is Z, going into a convolution # nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False), # nn.BatchNorm2d(ngf * 8), # nn.LeakyReLU(0.2, inplace=True), # # state size. (ngf*8) x 4 x 4 # nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), # nn.BatchNorm2d(ngf * 4), # nn.LeakyReLU(0.2, inplace=True), # # state size. (ngf*4) x 8 x 8 # nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), # nn.BatchNorm2d(ngf * 2), # nn.LeakyReLU(0.2, inplace=True), # # state size. (ngf*2) x 16 x 16 # nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), # nn.BatchNorm2d(ngf), # nn.LeakyReLU(0.2, inplace=True), # # state size. (ngf) x 32 x 32 # nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), # nn.Tanh() # # state size. (nc) x 64 x 64 # ) # def forward(self, input): # return self.main(input) class Generator(nn.Module): def __init__(self, n_gen_feats, n_gpu, z_dim, n_channels): super(Generator, self).__init__() self.n_gpu = n_gpu self.main = nn.Sequential( # input is Z, going into a convolution nn.ConvTranspose2d(z_dim, n_gen_feats * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(n_gen_feats * 8), nn.LeakyReLU(0.2, inplace=True), # state size. (n_gen_feats*8) x 4 x 4 nn.ConvTranspose2d(n_gen_feats * 8, n_gen_feats * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(n_gen_feats * 4), nn.LeakyReLU(0.2, inplace=True), # state size. (n_gen_feats*4) x 8 x 8 nn.ConvTranspose2d(n_gen_feats * 4, n_gen_feats * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(n_gen_feats * 2), nn.LeakyReLU(0.2, inplace=True), # state size. (n_gen_feats*2) x 16 x 16 nn.ConvTranspose2d(n_gen_feats * 2, n_gen_feats, 4, 2, 1, bias=False), nn.BatchNorm2d(n_gen_feats), nn.LeakyReLU(0.2, inplace=True), # state size. (n_gen_feats) x 32 x 32 nn.ConvTranspose2d(n_gen_feats, n_channels, 4, 2, 1, bias=False), nn.Tanh() # state size. (n_channels) x 64 x 64 ) def forward(self, input): return self.main(input) class Model: MODEL_DICT = { "stylegan3-abstract": {"name": "abstract-560eps.pkl", "repo": "avantStyleGAN3"}, "stylegan3-high-fidelity": {"name": "high-fidelity-1120eps.pkl", "repo": "avantStyleGAN3"}, "ada-dcgan": {"name": "gen_6kepoch.pt", "repo": "avantGAN"}, } def __init__(self): self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self._download_all_models() self.model_name = "ada-dcgan" #stylegan3-abstract" self.model = self._load_model(self.model_name) def _load_model(self, model_name: str) -> nn.Module: file_name = self.MODEL_DICT[model_name]["name"] repo = self.MODEL_DICT[model_name]["repo"] path = hf_hub_download(f"{user}/{repo}", file_name) # model repo-type if "stylegan" in model_name: with open(path, "rb") as f: model = pickle.load(f)["G_ema"] else: # todo (elle): don't hardcode the config model = Generator(dcgan_gen_feats, 1, dcgan_z_dim, 3) # model = Generator(0, 100) model.load_state_dict(torch.load(path, map_location=self.device)) model.eval() model.to(self.device) return model def set_model(self, model_name: str) -> None: if model_name == self.model_name: return self.model_name = model_name self.model = self._load_model(model_name) def _download_all_models(self): for name in self.MODEL_DICT.keys(): self._load_model(name) @staticmethod def make_transform(translate: tuple[float, float] = (0,0), angle: float = 0) -> np.ndarray: mat = np.eye(3) sin = np.sin(angle / 360 * np.pi * 2) cos = np.cos(angle / 360 * np.pi * 2) mat[0][0] = cos mat[0][1] = sin mat[0][2] = translate[0] mat[1][0] = -sin mat[1][1] = cos mat[1][2] = translate[1] return mat def generate_z(self, seed: int) -> torch.Tensor: seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max)) z = np.random.RandomState(seed).randn(1, self.model.z_dim) return torch.from_numpy(z).float().to(self.device) def postprocess(self, tensor: torch.Tensor) -> np.ndarray: tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) return tensor.cpu().numpy() def set_transform(self, tx: float = 0, ty: float = 0, angle: float = 0) -> None: mat = self.make_transform((tx, ty), angle) mat = np.linalg.inv(mat) self.model.synthesis.input.transform.copy_(torch.from_numpy(mat)) @torch.inference_mode() def generate(self, z: torch.Tensor, label: torch.Tensor, truncation_psi: float) -> torch.Tensor: return self.model(z, label, truncation_psi=truncation_psi) def generate_image(self, seed: int, truncation_psi: float = 0, tx: float = 0, ty: float = 0, angle: float = 0) -> np.ndarray: self.set_transform(tx, ty, angle) z = self.generate_z(seed) label = torch.zeros([1, self.model.c_dim], device=self.device) out = self.generate(z, label, truncation_psi) out = self.postprocess(out) return out[0] def dcgan_generate_image(self, seed: int) -> np.ndarray: torch.manual_seed(seed) if self.device == 'cuda': torch.cuda.manual_seed(seed) with torch.no_grad(): n_images = 1 z = torch.randn(n_images, dcgan_z_dim, 1, 1, device=self.device) fake_images = self.model(z.to(self.device)).cpu() fake_images = fake_images.view(fake_images.size(0), 3, dcgan_img_size, dcgan_img_size) # Create a grid of images grid = vutils.make_grid(fake_images, normalize=True) # Plot the grid and save it to a buffer fig, ax = plt.subplots() ax.imshow(grid.permute(1, 2, 0)) # Convert from CHW to HWC for imshow plt.axis('off') # Save the plot to a buffer buf = BytesIO() plt.savefig(buf, format='png') buf.seek(0) # Load the buffer into a PIL Image img = Image.open(buf) return img def set_model_and_generate_image( self, model_name: str, seed: int, truncation_psi: float = 0, tx: float = 0, ty: float = 0, angle: float = 0 ) -> np.ndarray: self.set_model(model_name) if "stylegan3" in model_name: return self.generate_image(seed, truncation_psi, tx, ty, angle) else: return self.dcgan_generate_image(seed)