# Source: Hugging Face Space file view (Space status: Running).
# File size: 1,430 bytes. Blame commits: 3341213, 912a3d8, cc9a06d, 60eed5d.
import os, glob, sys
import pickle
import streamlit as st
import torch
import matplotlib.pyplot as plt
import numpy as np
# Make the `stylegan3` directory importable. This is a relative path, so it is
# resolved against the current working directory, not this file's location.
# NOTE(review): presumably needed so that pickle.load in load_default_gen can
# resolve the classes referenced by the network snapshot — confirm.
sys.path.append('stylegan3')
class SampleFromGAN:
    """Draw a batch of random images from a generator.

    Each call samples a fresh latent batch ``z ~ N(0, I)`` of shape ``z_shp``,
    runs it through ``G`` with ``c=None`` and returns channel 0 of the output.
    """

    def __init__(self, G, z_shp, in_gpu=False) -> None:
        """
        Args:
            G: generator callable invoked as ``G(z, c=None)``; presumably the
                ``G_ema`` network returned by ``load_default_gen`` (TODO confirm).
            z_shp: shape of the latent batch, ``[#images, z_dim]``.
            in_gpu: if True, the sampled latents are moved to CUDA before the
                forward pass (``G`` must then be on the GPU as well).
        """
        self.G = G
        self.in_gpu = in_gpu
        self.z_shp = z_shp  # [#images, z_dim]

    def __call__(self):
        """Sample one latent batch and return ``G(z, c=None)[:, 0, ...]``.

        The forward pass runs under ``torch.no_grad()``. This is inference
        only, so no autograd graph is built, and the returned tensor does not
        require grad (which would otherwise block ``.numpy()`` conversion
        downstream). The result stays on whatever device ``G`` produced it on.
        """
        # Latents are always drawn with the CPU RNG, then copied to the GPU.
        z = torch.randn(self.z_shp)
        if self.in_gpu:
            z = z.cuda()
        with torch.no_grad():
            ims = self.G(z, c=None)
        # Keep only index 0 of dim 1. This assumes a (batch, channels, ...)
        # layout — TODO confirm for the loaded network.
        return ims[:, 0, ...]
class Plot:
    """Render the first image of a generated batch to the Streamlit page."""

    def __init__(self, im_gen) -> None:
        """
        Args:
            im_gen: zero-argument callable returning a batch of images that
                can be indexed as ``ims[0, ...]`` (e.g. a ``SampleFromGAN``).

        Raises:
            TypeError: if ``im_gen`` is not callable.
        """
        # Validate before storing. Raise instead of assert so the check
        # survives ``python -O``.
        if not callable(im_gen):
            raise TypeError(f'im_gen must be callable, got {type(im_gen).__name__}')
        self.im_gen = im_gen

    @staticmethod
    def _to_image(im):
        """Return ``im`` in a form ``imshow`` accepts (torch tensors -> CPU numpy)."""
        if torch.is_tensor(im):
            # Tensor -> numpy conversion fails for CUDA tensors and for tensors
            # that require grad, so detach and copy to host memory first.
            im = im.detach().cpu().numpy()
        return im

    def __call__(self):
        """Generate a batch, draw its first image in grayscale and show it."""
        ims = self.im_gen()
        im = self._to_image(ims[0, ...])  # plot first image only
        fig, ax = plt.subplots(1, figsize=(12, 12))
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        ax.imshow(im, cmap='gray')
        ax.axis('tight')
        ax.axis('off')
        st.pyplot(fig)
        # Every call creates a new figure. Close it once rendered so pyplot
        # does not keep all figures alive across Streamlit reruns.
        plt.close(fig)
def load_default_gen(in_gpu=False, fname_pkl=None):
    """Load the ``'G_ema'`` entry from a pickled network snapshot.

    Args:
        in_gpu: if True, move the loaded object to CUDA with ``.cuda()``.
        fname_pkl: path to the snapshot pickle. Defaults to
            ``./model_weights/network-snapshot-005000.pkl``, relative to the
            current working directory.

    Returns:
        The object stored under the ``'G_ema'`` key of the unpickled mapping.

    Raises:
        FileNotFoundError: if the snapshot file does not exist.
        KeyError: if the unpickled mapping has no ``'G_ema'`` entry.
    """
    if fname_pkl is None:
        path_ckpt = "./model_weights"
        fname_pkl = os.path.join(path_ckpt, 'network-snapshot-005000.pkl')
        what = 'default network snapshot'
    else:
        what = 'network snapshot'
    if not os.path.isfile(fname_pkl):
        raise FileNotFoundError(f'Could not find the {what} at {fname_pkl}. Quitting.')
    # SECURITY: pickle.load can execute arbitrary code. Only load snapshots
    # from trusted sources, never a file supplied by an app user.
    with open(fname_pkl, 'rb') as f:
        # NOTE(review): presumably the generator nn.Module (it is called as
        # G(z, c=None) elsewhere, which a ModuleDict could not be) — confirm.
        G = pickle.load(f)['G_ema']
    if in_gpu:
        G = G.cuda()
    return G
|