Spaces:
Running
Running
firatozdemir
commited on
Commit
•
691e8f7
1
Parent(s):
53681fc
init commit
Browse files- .gitattributes +1 -0
- README.md +3 -3
- app.py +52 -0
- requirements.txt +3 -0
.gitattributes
CHANGED
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
+
model_weights/** filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: OAGen
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
title: OAGen
|
3 |
+
emoji: 💻
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: red
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob, sys
|
2 |
+
import pickle
|
3 |
+
import streamlit as st
|
4 |
+
import torch
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
sys.path.append('stylegan3')
|
8 |
+
|
9 |
+
class SampleFromGAN:
|
10 |
+
def __init__(self, G, z_shp, in_gpu=False) -> None:
|
11 |
+
self.G = G
|
12 |
+
self.in_gpu = in_gpu
|
13 |
+
self.z_shp = z_shp #[#images, z_dim]
|
14 |
+
def __call__(self,):
|
15 |
+
z = torch.randn(self.z_shp)
|
16 |
+
if self.in_gpu:
|
17 |
+
z = z.cuda()
|
18 |
+
ims = G(z, c=None)
|
19 |
+
ims = ims[:,0,...]
|
20 |
+
return ims
|
21 |
+
|
22 |
+
class Plot:
|
23 |
+
def __init__(self, im_gen) -> None:
|
24 |
+
self.im_gen = im_gen
|
25 |
+
assert callable(im_gen)
|
26 |
+
def __call__(self):
|
27 |
+
ims = self.im_gen()
|
28 |
+
# plot first image
|
29 |
+
im = ims[0,...]
|
30 |
+
fig, ax = plt.subplots(1, figsize=(12,12))
|
31 |
+
fig.subplots_adjust(left=0,right=1,bottom=0,top=1)
|
32 |
+
ax.imshow(im, cmap='gray')
|
33 |
+
ax.axis('tight')
|
34 |
+
ax.axis('off')
|
35 |
+
st.pyplot(fig)
|
36 |
+
|
37 |
+
|
38 |
+
# path_ckpt = "/home/firat/saved_models/DLBIRHOUI/stylegan/00001-stylegan2-SWFD-linear_BP-linear_BP-gpus2-batch32-gamma8.2-augnoaug"
|
39 |
+
path_ckpt = "./model_weights"
|
40 |
+
fname_pkl = os.path.join(path_ckpt, 'network-snapshot-005000.pkl')
|
41 |
+
in_gpu = False
|
42 |
+
num_images = 1
|
43 |
+
with open(fname_pkl, 'rb') as f:
|
44 |
+
G = pickle.load(f)['G_ema'] # torch.nn.ModuleDict
|
45 |
+
if in_gpu:
|
46 |
+
G = G.cuda()
|
47 |
+
|
48 |
+
sampler = SampleFromGAN(G=G, z_shp=[num_images, G.z_dim], in_gpu=in_gpu)
|
49 |
+
|
50 |
+
button_on_click = Plot(im_gen=sampler)
|
51 |
+
button_gen_clicked = st.button(label='Generate an image', key='n', on_click=button_on_click)
|
52 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
torchaudio
|