"""Streamlit helpers to sample images from a StyleGAN generator and display them."""

import glob
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch

# Make the local 'stylegan3' checkout importable; presumably needed so that
# pickle.load can resolve the classes referenced by the network snapshot.
sys.path.append('stylegan3')


class SampleFromGAN:
    """Callable that draws a batch of random images from a generator.

    Args:
        G: Generator, invoked as ``G(z, c=None)``.
        z_shp: Shape passed to ``torch.randn`` for the latent batch,
            i.e. ``[#images, z_dim]``.
        in_gpu: If True, the latent batch is moved to the GPU with ``.cuda()``
            before calling ``G``.
    """

    def __init__(self, G, z_shp, in_gpu=False) -> None:
        self.G = G
        self.in_gpu = in_gpu
        self.z_shp = z_shp  # [#images, z_dim]

    def __call__(self):
        """Sample fresh latents, run the generator and return channel 0 of its output.

        Returns:
            ``G(z, c=None)[:, 0, ...]``. This assumes the generator output is laid
            out as (N, C, H, W) — TODO confirm — so it keeps the first channel of
            every image. The tensor is on the GPU when ``in_gpu`` is True.
        """
        z = torch.randn(self.z_shp)
        if self.in_gpu:
            z = z.cuda()
        ims = self.G(z, c=None)
        return ims[:, 0, ...]


class Plot:
    """Render the first image produced by an image generator into Streamlit.

    Args:
        im_gen: Zero-argument callable returning a batch of images indexable
            as ``ims[0, ...]`` (for example a ``SampleFromGAN`` instance).

    Raises:
        AssertionError: If ``im_gen`` is not callable.
    """

    def __init__(self, im_gen) -> None:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``;
        # AssertionError is kept so callers see the same exception type as before.
        if not callable(im_gen):
            raise AssertionError(f'im_gen must be callable, got {type(im_gen).__name__}.')
        self.im_gen = im_gen

    @staticmethod
    def _to_numpy(im):
        """Return ``im`` in a form Matplotlib's ``imshow`` can consume.

        Torch tensors are detached and copied to host memory first: converting a
        CUDA tensor, or one that requires grad, straight to numpy raises.
        Non-tensor inputs are returned unchanged.
        """
        if isinstance(im, torch.Tensor):
            return im.detach().cpu().numpy()
        return im

    def __call__(self):
        """Generate a batch of images and draw the first one with ``st.pyplot``."""
        ims = self.im_gen()
        im = self._to_numpy(ims[0, ...])  # plot first image

        fig, ax = plt.subplots(1, figsize=(12, 12))
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        ax.imshow(im, cmap='gray')
        ax.axis('tight')
        ax.axis('off')
        st.pyplot(fig)
        # pyplot keeps every figure created via plt.subplots alive until it is
        # closed; without this each Streamlit rerun would leak a figure.
        plt.close(fig)


def load_default_gen(in_gpu=False, fname_pkl=None):
    """Load the ``'G_ema'`` entry from a network snapshot pickle.

    Args:
        in_gpu: If True, move the loaded generator to the GPU with ``.cuda()``.
        fname_pkl: Path to the snapshot pickle. Defaults to
            ``./model_weights/network-snapshot-005000.pkl``.

    Returns:
        The object stored under the ``'G_ema'`` key of the unpickled snapshot.

    Raises:
        AssertionError: If the snapshot file does not exist.
    """
    if fname_pkl is None:
        path_ckpt = "./model_weights"
        fname_pkl = os.path.join(path_ckpt, 'network-snapshot-005000.pkl')
    if not os.path.isfile(fname_pkl):
        raise AssertionError(f'Could not find the default network snapshot at {fname_pkl}. Quitting.')
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load snapshots from trusted sources.
    with open(fname_pkl, 'rb') as f:
        # Presumably a torch.nn.Module: callers invoke it as G(z, c=None) and .cuda() it.
        G = pickle.load(f)['G_ema']
    if in_gpu:
        G = G.cuda()
    return G