Spaces:
Running
Running
import os, glob, sys | |
import pickle | |
import streamlit as st | |
import torch | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# Make the 'stylegan3' source directory importable. The path is relative, so it
# is resolved against the current working directory at import time.
# Presumably needed so the StyleGAN checkpoint pickle can be unpickled — confirm.
sys.path.extend(['stylegan3'])
class SampleFromGAN:
    """Draw random latent vectors and decode them with a generator.

    Args:
        G: generator callable, invoked as ``G(z, c=None)``; presumably a
            StyleGAN ``G_ema`` network (TODO confirm).
        z_shp: shape of the latent batch, ``[num_images, z_dim]``.
        in_gpu: if True, the latents are moved to CUDA with ``.cuda()``
            before being fed to ``G``.
    """

    def __init__(self, G, z_shp, in_gpu=False) -> None:
        self.G = G
        self.in_gpu = in_gpu
        self.z_shp = z_shp  # [num_images, z_dim]

    def __call__(self):
        """Sample one latent batch and return channel 0 of every output image.

        Returns:
            ``G(z, c=None)[:, 0, ...]`` as a CPU tensor that does not track
            gradients, so it can be handed directly to matplotlib.
        """
        z = torch.randn(self.z_shp)
        if self.in_gpu:
            z = z.cuda()
        # Pure inference: no_grad avoids building an autograd graph, which
        # also keeps the result convertible to NumPy.
        with torch.no_grad():
            # Use the generator stored on this instance (the original read a
            # module-level global named G here).
            ims = self.G(z, c=None)
        # Keep only channel 0; assumes output is (batch, channels, H, W) —
        # TODO confirm. Move to CPU so consumers work even when in_gpu=True.
        return ims[:, 0, ...].cpu()
class Plot:
    """Render the first image produced by an image generator into Streamlit.

    Instances are zero-argument callables, so they can be used directly as a
    widget callback (e.g. a button's ``on_click``).

    Args:
        im_gen: zero-argument callable returning a batch of images indexable
            as ``ims[0, ...]``.

    Raises:
        TypeError: if ``im_gen`` is not callable.
    """

    def __init__(self, im_gen) -> None:
        # Validate before storing; an explicit raise is not stripped by -O.
        if not callable(im_gen):
            raise TypeError(
                f"im_gen must be callable, got {type(im_gen).__name__}"
            )
        self.im_gen = im_gen

    def __call__(self):
        """Generate a batch, draw its first image in grayscale and display it."""
        ims = self.im_gen()
        im = ims[0, ...]  # plot only the first image of the batch
        if isinstance(im, torch.Tensor):
            # matplotlib needs host memory without autograd history.
            im = im.detach().cpu().numpy()
        fig, ax = plt.subplots(1, figsize=(12, 12))
        try:
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)  # fill the canvas
            ax.imshow(im, cmap='gray')
            ax.axis('tight')
            ax.axis('off')
            st.pyplot(fig)
        finally:
            # A new figure is created per call; close it so pyplot does not
            # keep accumulating open figures.
            plt.close(fig)
# Location of the StyleGAN snapshot whose 'G_ema' entry is used for sampling.
path_ckpt = "./model_weights"
fname_pkl = os.path.join(path_ckpt, 'network-snapshot-005000.pkl')
in_gpu = False  # True moves the generator (and latents) to CUDA
num_images = 1  # latent batch size per click; only the first image is plotted

# Streamlit re-executes this whole script on every interaction, including each
# button click. When st.cache_resource exists (Streamlit >= 1.18) the network is
# loaded once and reused; otherwise fall back to loading on every rerun.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_generator(pkl_path, use_gpu):
    """Load the 'G_ema' network from a StyleGAN snapshot pickle.

    SECURITY: pickle.load can execute arbitrary code from the file; only point
    this at checkpoints you trust.

    Args:
        pkl_path: path to the snapshot ``.pkl`` file.
        use_gpu: if True, the loaded network is moved to CUDA with ``.cuda()``.

    Returns:
        The object stored under the ``'G_ema'`` key (original note said
        torch.nn.ModuleDict; presumably a torch.nn.Module generator — confirm).
    """
    with open(pkl_path, 'rb') as f:
        generator = pickle.load(f)['G_ema']
    if use_gpu:
        generator = generator.cuda()
    return generator


G = _load_generator(fname_pkl, in_gpu)
sampler = SampleFromGAN(G=G, z_shp=[num_images, G.z_dim], in_gpu=in_gpu)
button_on_click = Plot(im_gen=sampler)
button_gen_clicked = st.button(label='Generate an image', key='n', on_click=button_on_click)