OAGen_Linear / utils.py
firatozdemir's picture
updated scripts
1e6244a
raw
history blame
1.46 kB
import os, glob, sys
import pickle
# import streamlit as st
import torch
# import matplotlib.pyplot as plt
import numpy as np
sys.path.append('stylegan3')
class SampleFromGAN:
    """Callable that draws random latents and runs them through a generator.

    Each call samples fresh standard-normal noise of shape ``z_shp``, feeds it
    to ``G`` (invoked as ``G(z, c=None)``), and returns index 0 along the
    second axis of the generator output.
    """

    def __init__(self, G, z_shp, in_gpu=False) -> None:
        """Store the generator and sampling configuration.

        Args:
            G: Generator; must be callable as ``G(z, c=None)``.
            z_shp: Shape handed to ``torch.randn``, i.e. [#images, z_dim].
            in_gpu: When true, the sampled latents are moved to CUDA before
                being passed to ``G``.
        """
        self.z_shp = z_shp  # [#images, z_dim]
        self.in_gpu = in_gpu
        self.G = G

    def __call__(self):
        """Sample one batch of latents and return the generated images.

        Returns:
            The generator output indexed with ``[:, 0, ...]``.
        """
        # Noise is drawn on the CPU first and only then moved, so the CPU
        # random generator is the one consumed regardless of ``in_gpu``.
        latents = torch.randn(self.z_shp)
        latents = latents.cuda() if self.in_gpu else latents
        generated = self.G(latents, c=None)
        # Keep only index 0 of axis 1 (presumably the first image channel --
        # TODO confirm against the generator's output layout).
        return generated[:, 0, ...]
# class Plot:
# def __init__(self, im_gen) -> None:
# self.im_gen = im_gen
# assert callable(im_gen)
# def __call__(self):
# ims = self.im_gen()
# # plot first image
# im = ims[0,...]
# fig, ax = plt.subplots(1, figsize=(12,12))
# fig.subplots_adjust(left=0,right=1,bottom=0,top=1)
# ax.imshow(im, cmap='gray')
# ax.axis('tight')
# ax.axis('off')
# st.pyplot(fig)
def load_default_gen(in_gpu=False, fname_pkl=None):
    """Load the ``'G_ema'`` generator from a pickled network snapshot.

    Args:
        in_gpu: When true, ``.cuda()`` is called on the loaded generator
            before it is returned.
        fname_pkl: Path to the snapshot pickle. When ``None``, falls back to
            ``./model_weights/network-snapshot-005000.pkl``.

    Returns:
        The object stored under the ``'G_ema'`` key of the unpickled snapshot
        (a torch.nn.ModuleDict according to the original author's note).

    Raises:
        AssertionError: If no file exists at the resolved path.
    """
    if fname_pkl is None:
        fname_pkl = os.path.join("./model_weights", 'network-snapshot-005000.pkl')

    snapshot_missing = not os.path.isfile(fname_pkl)
    if snapshot_missing:
        raise AssertionError(f'Could not find the default network snapshot at {fname_pkl}. Quitting.')

    # NOTE(review): pickle executes arbitrary code while loading -- only point
    # this at snapshot files from a trusted source.
    with open(fname_pkl, 'rb') as handle:
        snapshot = pickle.load(handle)
    generator = snapshot['G_ema']

    return generator.cuda() if in_gpu else generator