OAGen_Linear / app.py
firatozdemir's picture
init commit
691e8f7
raw
history blame
1.52 kB
import os, glob, sys
import pickle
import streamlit as st
import torch
import matplotlib.pyplot as plt
import numpy as np
# Make the bundled ``stylegan3`` directory importable. Presumably unpickling the
# checkpoint needs modules from it — confirm against the repository layout.
# The entry is resolved relative to this file (not the current working directory),
# so the app still works when launched from another directory.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'stylegan3'))
class SampleFromGAN:
    """Callable that draws a batch of images from a generator using Gaussian latents.

    Args:
        G: generator, called as ``G(z, c=None)``; its output is indexed as
            ``out[:, 0, ...]`` (the first entry along dim 1 is kept).
        z_shp: latent shape passed to ``torch.randn``, i.e. ``[num_images, z_dim]``.
        in_gpu: if True, the latents are moved to CUDA before calling ``G``.
    """

    def __init__(self, G, z_shp, in_gpu=False) -> None:
        self.G = G
        self.in_gpu = in_gpu
        self.z_shp = z_shp  # [num_images, z_dim]

    def __call__(self):
        """Sample fresh latents, run the generator, and return ``out[:, 0, ...]``."""
        z = torch.randn(self.z_shp)
        if self.in_gpu:
            z = z.cuda()
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            # Use the generator stored on this instance, not a module-level global.
            ims = self.G(z, c=None)
        return ims[:, 0, ...]
class Plot:
    """Streamlit callback that renders the first image produced by an image generator.

    Args:
        im_gen: zero-argument callable returning a batch of images that can be
            indexed as ``ims[0, ...]`` to obtain a single 2-D image.

    Raises:
        TypeError: if ``im_gen`` is not callable.
    """

    def __init__(self, im_gen) -> None:
        if not callable(im_gen):
            raise TypeError(f"im_gen must be callable, got {type(im_gen).__name__}")
        self.im_gen = im_gen

    def __call__(self):
        """Generate a batch, then draw its first image in grayscale on the Streamlit page."""
        ims = self.im_gen()
        im = ims[0, ...]  # plot only the first image of the batch
        if isinstance(im, torch.Tensor):
            # matplotlib needs a host array: detach/cpu make GPU or grad-tracking tensors plottable.
            im = im.detach().cpu().numpy()
        fig, ax = plt.subplots(1, figsize=(12, 12))
        try:
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
            ax.imshow(im, cmap='gray')
            ax.axis('tight')
            ax.axis('off')
            st.pyplot(fig)
        finally:
            # st.pyplot does not clear an explicitly passed figure; close it so
            # figures do not accumulate in pyplot's global state across clicks.
            plt.close(fig)
# Checkpoint directory, resolved next to this script so the app does not depend
# on the current working directory.
path_ckpt = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_weights")
fname_pkl = os.path.join(path_ckpt, 'network-snapshot-005000.pkl')
in_gpu = False  # set True to run the generator on CUDA
num_images = 1  # batch size drawn per click; only the first image is plotted

# NOTE(security): pickle.load can execute arbitrary code while loading.
# Only ship or load checkpoints from a trusted source.
with open(fname_pkl, 'rb') as f:
    # Generator module: exposes .z_dim and is called as G(z, c=None).
    G = pickle.load(f)['G_ema']
if in_gpu:
    G = G.cuda()

sampler = SampleFromGAN(G=G, z_shp=[num_images, G.z_dim], in_gpu=in_gpu)
button_on_click = Plot(im_gen=sampler)
button_gen_clicked = st.button(label='Generate an image', key='n', on_click=button_on_click)