firatozdemir commited on
Commit
691e8f7
1 Parent(s): 53681fc

init commit

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. README.md +3 -3
  3. app.py +52 -0
  4. requirements.txt +3 -0
.gitattributes CHANGED
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ model_weights/** filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: OAGen
3
- emoji: 😻
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
  app_file: app.py
 
1
  ---
2
  title: OAGen
3
+ emoji: 💻
4
+ colorFrom: gray
5
+ colorTo: red
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob, sys
2
+ import pickle
3
+ import streamlit as st
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ sys.path.append('stylegan3')
8
+
9
+ class SampleFromGAN:
10
+ def __init__(self, G, z_shp, in_gpu=False) -> None:
11
+ self.G = G
12
+ self.in_gpu = in_gpu
13
+ self.z_shp = z_shp #[#images, z_dim]
14
+ def __call__(self,):
15
+ z = torch.randn(self.z_shp)
16
+ if self.in_gpu:
17
+ z = z.cuda()
18
+ ims = G(z, c=None)
19
+ ims = ims[:,0,...]
20
+ return ims
21
+
22
+ class Plot:
23
+ def __init__(self, im_gen) -> None:
24
+ self.im_gen = im_gen
25
+ assert callable(im_gen)
26
+ def __call__(self):
27
+ ims = self.im_gen()
28
+ # plot first image
29
+ im = ims[0,...]
30
+ fig, ax = plt.subplots(1, figsize=(12,12))
31
+ fig.subplots_adjust(left=0,right=1,bottom=0,top=1)
32
+ ax.imshow(im, cmap='gray')
33
+ ax.axis('tight')
34
+ ax.axis('off')
35
+ st.pyplot(fig)
36
+
37
+
38
+ # path_ckpt = "/home/firat/saved_models/DLBIRHOUI/stylegan/00001-stylegan2-SWFD-linear_BP-linear_BP-gpus2-batch32-gamma8.2-augnoaug"
39
+ path_ckpt = "./model_weights"
40
+ fname_pkl = os.path.join(path_ckpt, 'network-snapshot-005000.pkl')
41
+ in_gpu = False
42
+ num_images = 1
43
+ with open(fname_pkl, 'rb') as f:
44
+ G = pickle.load(f)['G_ema'] # torch.nn.ModuleDict
45
+ if in_gpu:
46
+ G = G.cuda()
47
+
48
+ sampler = SampleFromGAN(G=G, z_shp=[num_images, G.z_dim], in_gpu=in_gpu)
49
+
50
+ button_on_click = Plot(im_gen=sampler)
51
+ button_gen_clicked = st.button(label='Generate an image', key='n', on_click=button_on_click)
52
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ torchaudio