# OAGen_Linear / utils.py
# Hosting-page metadata captured with the file (not part of the module):
#   author: firatozdemir; last commit: "minor update" (912a3d8); size: 1.43 kB
import os, glob, sys
import pickle
import streamlit as st
import torch
import matplotlib.pyplot as plt
import numpy as np
sys.path.append('stylegan3')
class SampleFromGAN:
    """Callable that samples random latents and returns generated images.

    Every call draws a fresh standard-normal latent batch of shape ``z_shp``,
    passes it to the generator as ``G(z, c=None)`` and returns channel 0 of
    each generated image.
    """

    def __init__(self, G, z_shp, in_gpu=False) -> None:
        """Store the generator and the sampling configuration.

        Args:
            G: Generator, invoked as ``G(z, c=None)``.
            z_shp: Shape of the latent batch, ``[#images, z_dim]``.
            in_gpu: If true, the latents are moved to CUDA before calling
                ``G``. ``G`` itself is not moved by this class.
        """
        self.G = G
        self.in_gpu = in_gpu
        self.z_shp = z_shp  # [#images, z_dim]

    def __call__(self):
        """Generate one batch and return the first channel of every image.

        Returns:
            The generator output indexed as ``[:, 0, ...]``, i.e. with the
            second (channel) axis removed.
        """
        # Sampling is inference only: without no_grad a generator with
        # trainable parameters would build an autograd graph on every call
        # and return tensors that cannot be converted to NumPy for display.
        with torch.no_grad():
            z = torch.randn(self.z_shp)
            if self.in_gpu:
                z = z.cuda()
            ims = self.G(z, c=None)
        # Keep only channel 0 of each generated image.
        return ims[:, 0, ...]
class Plot:
    """Callable that renders the first image of a generated batch in Streamlit."""

    def __init__(self, im_gen) -> None:
        """Store the image source.

        Args:
            im_gen: Zero-argument callable returning a batch of images that
                can be indexed as ``ims[0, ...]``.

        Raises:
            AssertionError: If ``im_gen`` is not callable.
        """
        # Explicit raise (rather than ``assert``) so the check also runs
        # under ``python -O``; the exception type is unchanged.
        if not callable(im_gen):
            raise AssertionError('im_gen must be callable')
        self.im_gen = im_gen

    @staticmethod
    def _to_numpy(im):
        """Return *im* in a form matplotlib can draw.

        Torch tensors are detached, copied to host memory and converted to a
        NumPy array; matplotlib cannot consume CUDA tensors or tensors that
        require grad. Any other object is returned unchanged.
        """
        if isinstance(im, torch.Tensor):
            return im.detach().cpu().numpy()
        return im

    def __call__(self):
        """Generate a batch and draw its first image as a borderless gray plot."""
        ims = self.im_gen()
        # plot first image
        im = self._to_numpy(ims[0, ...])
        fig, ax = plt.subplots(1, figsize=(12, 12))
        try:
            # Let the axes fill the whole figure (no margins).
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
            ax.imshow(im, cmap='gray')
            ax.axis('tight')
            ax.axis('off')
            st.pyplot(fig)
        finally:
            # pyplot keeps every figure it creates registered until it is
            # closed; release it so repeated calls do not accumulate figures.
            plt.close(fig)
def load_default_gen(in_gpu=False, fname_pkl=None):
    """Load the ``G_ema`` entry from a pickled network snapshot.

    Args:
        in_gpu: If true, ``.cuda()`` is called on the loaded object and the
            result is returned.
        fname_pkl: Path of the snapshot pickle. When ``None``, the default
            snapshot under ``./model_weights`` is used.

    Returns:
        The object stored under the ``'G_ema'`` key of the unpickled
        snapshot (presumably a torch module -- confirm against the
        checkpoint format).

    Raises:
        AssertionError: If no file exists at the resolved snapshot path.
    """
    snapshot = fname_pkl
    if snapshot is None:
        snapshot = os.path.join("./model_weights", 'network-snapshot-005000.pkl')

    if os.path.isfile(snapshot):
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only load snapshots that come from a trusted source.
        with open(snapshot, 'rb') as fh:
            payload = pickle.load(fh)
        generator = payload['G_ema']
        return generator.cuda() if in_gpu else generator

    raise AssertionError(f'Could not find the default network snapshot at {snapshot}. Quitting.')