Spaces:
Running
Running
import os, glob, sys | |
import pickle | |
# import streamlit as st | |
import torch | |
# import matplotlib.pyplot as plt | |
import numpy as np | |
sys.path.append('stylegan3') | |
class SampleFromGAN:
    """Callable that draws a batch of random latents and renders them with a generator.

    Each call samples ``z ~ N(0, I)`` with shape ``z_shp``, feeds it to ``G``
    with no conditioning label (``c=None``), and returns channel 0 of the
    generator output.
    """

    def __init__(self, G, z_shp, in_gpu=False) -> None:
        """Store the generator and sampling configuration.

        Args:
            G: Generator callable invoked as ``G(z, c=None)``. Its output is
                indexed as ``out[:, 0, ...]``, so it needs at least two
                dimensions (presumably ``[N, C, H, W]`` -- TODO confirm).
            z_shp: Shape passed to ``torch.randn`` for the latent batch,
                e.g. ``(num_images, z_dim)``.
            in_gpu: If True, latents are moved to the current CUDA device
                before calling ``G``.
        """
        self.G = G
        self.in_gpu = in_gpu
        self.z_shp = z_shp  # [#images, z_dim]

    def __call__(self):
        """Sample one latent batch and return ``G(z, c=None)[:, 0, ...]``.

        The generator runs under ``torch.no_grad()``. This is inference only,
        so no autograd graph is recorded. The returned tensor therefore has
        ``requires_grad=False`` and can be handed to NumPy/plotting code
        directly (after ``.cpu()`` when ``in_gpu`` is True).
        """
        # Latents come from the CPU RNG and are then copied to the GPU, so a
        # given torch.manual_seed yields the same z regardless of in_gpu.
        z = torch.randn(self.z_shp)
        if self.in_gpu:
            z = z.cuda()
        with torch.no_grad():
            ims = self.G(z, c=None)
        # Keep only the first channel of every image.
        return ims[:, 0, ...]
# class Plot: | |
# def __init__(self, im_gen) -> None: | |
# self.im_gen = im_gen | |
# assert callable(im_gen) | |
# def __call__(self): | |
# ims = self.im_gen() | |
# # plot first image | |
# im = ims[0,...] | |
# fig, ax = plt.subplots(1, figsize=(12,12)) | |
# fig.subplots_adjust(left=0,right=1,bottom=0,top=1) | |
# ax.imshow(im, cmap='gray') | |
# ax.axis('tight') | |
# ax.axis('off') | |
# st.pyplot(fig) | |
def load_default_gen(in_gpu=False, fname_pkl=None):
    """Load the ``'G_ema'`` entry from a pickled network snapshot.

    Args:
        in_gpu: If True, move the loaded object to the current CUDA device
            via ``.cuda()``.
        fname_pkl: Path to the snapshot pickle. When None, falls back to
            ``./model_weights/network-snapshot-005000.pkl``, which is resolved
            relative to the current working directory, not to this file.

    Returns:
        The object stored under the ``'G_ema'`` key of the unpickled dict.
        NOTE(review): presumably the StyleGAN3 EMA generator
        (a ``torch.nn.Module``) -- confirm against the training code.

    Raises:
        AssertionError: If the snapshot file does not exist. This is kept as
            AssertionError, rather than FileNotFoundError, for backward
            compatibility with existing callers.
    """
    is_default = fname_pkl is None
    if is_default:
        path_ckpt = "./model_weights"
        fname_pkl = os.path.join(path_ckpt, 'network-snapshot-005000.pkl')
    # Validate every path, not only the default one. A missing file then
    # always raises the same exception type, with a message naming the path
    # that was actually tried.
    if not os.path.isfile(fname_pkl):
        if is_default:
            raise AssertionError(f'Could not find the default network snapshot at {fname_pkl}. Quitting.')
        raise AssertionError(f'Could not find the network snapshot at {fname_pkl}. Quitting.')
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load snapshots from trusted sources.
    # NOTE(review): the module-level sys.path.append('stylegan3') presumably
    # exists so that classes referenced by this pickle can be imported.
    with open(fname_pkl, 'rb') as f:
        G = pickle.load(f)['G_ema']
    if in_gpu:
        G = G.cuda()
    return G