"""Streamlit demo: sample images from a pickled StyleGAN generator and plot one."""
import glob
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch

# Presumably the pickled generator references classes from the StyleGAN3
# repository, so that checkout must be importable before unpickling.
# NOTE(review): relative to the current working directory — confirm.
sys.path.append('stylegan3')


class SampleFromGAN:
    """Callable that draws random latents and runs them through a generator.

    Args:
        G: generator; called as ``G(z, c=None)``.
        z_shp: shape of the latent batch, ``[#images, z_dim]``.
        in_gpu: if True, the latents are moved to CUDA with ``.cuda()``.
    """

    def __init__(self, G, z_shp, in_gpu=False) -> None:
        self.G = G
        self.in_gpu = in_gpu
        self.z_shp = z_shp  # [#images, z_dim]

    def __call__(self):
        """Generate a batch and return channel 0 of each image.

        Returns:
            The generator output indexed as ``ims[:, 0, ...]`` (axis 1 assumed
            to be channels — TODO confirm). Computed without autograd and left
            on the device the generator produced it on.
        """
        z = torch.randn(self.z_shp)
        if self.in_gpu:
            z = z.cuda()
        # Inference only: without no_grad the output requires grad, which also
        # prevents converting it to NumPy for plotting.
        with torch.no_grad():
            # Use the generator this sampler was constructed with, not a global.
            ims = self.G(z, c=None)
        return ims[:, 0, ...]


class Plot:
    """Callable that renders the first generated image into the Streamlit page.

    Args:
        im_gen: zero-argument callable returning a batch indexable as
            ``ims[0, ...]`` (torch tensor or array-like).
    """

    def __init__(self, im_gen) -> None:
        assert callable(im_gen)
        self.im_gen = im_gen

    def __call__(self):
        ims = self.im_gen()
        im = ims[0, ...]  # plot only the first image of the batch
        if isinstance(im, torch.Tensor):
            # matplotlib cannot convert grad-tracking or CUDA tensors itself.
            im = im.detach().cpu().numpy()

        fig, ax = plt.subplots(1, figsize=(12, 12))
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)  # no margins
        ax.imshow(im, cmap='gray')
        ax.axis('tight')
        ax.axis('off')
        st.pyplot(fig)
        # Release the figure; pyplot otherwise keeps every created figure alive.
        plt.close(fig)


path_ckpt = "./model_weights"
fname_pkl = os.path.join(path_ckpt, 'network-snapshot-005000.pkl')
in_gpu = False
num_images = 1

# SECURITY: pickle.load can execute arbitrary code — only load trusted files.
# NOTE(review): Streamlit re-executes this script on every interaction, so the
# checkpoint is reloaded on each click; consider caching the loaded model
# (e.g. st.cache_resource, if the installed Streamlit version provides it).
with open(fname_pkl, 'rb') as f:
    # Presumably a torch.nn.Module: it is called and exposes `.z_dim`.
    G = pickle.load(f)['G_ema']
if in_gpu:
    G = G.cuda()

sampler = SampleFromGAN(G=G, z_shp=[num_images, G.z_dim], in_gpu=in_gpu)
button_on_click = Plot(im_gen=sampler)
# Plot is invoked as the button's on_click callback.
button_gen_clicked = st.button(label='Generate an image', key='n', on_click=button_on_click)