SerdarHelli commited on
Commit
e2cf2b0
·
1 Parent(s): 3d781fd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import imageio
3
+ import numpy as np
4
+ import scipy.interpolate
5
+ import torch
6
+ from tqdm import tqdm
7
+ import gradio as gr
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
11
+ batch_size, channels, img_h, img_w = img.shape
12
+ if grid_w is None:
13
+ grid_w = batch_size // grid_h
14
+ assert batch_size == grid_w * grid_h
15
+ if float_to_uint8:
16
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
17
+ img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
18
+ img = img.permute(2, 0, 3, 1, 4)
19
+ img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
20
+ if chw_to_hwc:
21
+ img = img.permute(1, 2, 0)
22
+ if to_numpy:
23
+ img = img.cpu().numpy()
24
+ return img
25
+
26
+
27
+
28
+
29
+ network_pkl=hf_hub_download('SerdarHelli/BrainMRIGAN/braingan-400.pkl')
30
+ with open(network_pkl, 'rb') as f:
31
+ G = pickle.load(f)['G_ema']
32
+
33
+ def predict(Seed,choices):
34
+ device = torch.device('cuda')
35
+ G.eval()
36
+ G.to(device)
37
+ shuffle_seed=None
38
+ w_frames=60*4
39
+ kind='cubic'
40
+ num_keyframes=None
41
+ wraps=2
42
+ psi=1
43
+ device=torch.device('cuda')
44
+
45
+
46
+ if choices=='4x2':
47
+ grid_w = 4
48
+ grid_h = 2
49
+ s1=Seed
50
+ seeds=(np.arange(s1-16,s1)).tolist()
51
+ if choices=='2x1':
52
+ grid_w = 2
53
+ grid_h = 1
54
+ s1=Seed
55
+ seeds=(np.arange(s1-4,s1)).tolist()
56
+
57
+
58
+ mp4='ex.mp4'
59
+ truncation_psi=1
60
+ num_keyframes=None
61
+
62
+
63
+ if num_keyframes is None:
64
+ if len(seeds) % (grid_w*grid_h) != 0:
65
+ raise ValueError('Number of input seeds must be divisible by grid W*H')
66
+ num_keyframes = len(seeds) // (grid_w*grid_h)
67
+
68
+ all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
69
+ for idx in range(num_keyframes*grid_h*grid_w):
70
+ all_seeds[idx] = seeds[idx % len(seeds)]
71
+
72
+ if shuffle_seed is not None:
73
+ rng = np.random.RandomState(seed=shuffle_seed)
74
+ rng.shuffle(all_seeds)
75
+
76
+ zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)
77
+ ws = G.mapping(z=zs, c=None, truncation_psi=psi)
78
+ _ = G.synthesis(ws[:1]) # warm up
79
+ ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
80
+
81
+ # Interpolation.
82
+ grid = []
83
+ for yi in range(grid_h):
84
+ row = []
85
+ for xi in range(grid_w):
86
+ x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
87
+ y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
88
+ interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
89
+ row.append(interp)
90
+ grid.append(row)
91
+
92
+ # Render video.
93
+ video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264')
94
+ for frame_idx in tqdm(range(num_keyframes * w_frames)):
95
+ imgs = []
96
+ for yi in range(grid_h):
97
+ for xi in range(grid_w):
98
+ interp = grid[yi][xi]
99
+ w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
100
+ img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
101
+ imgs.append(img)
102
+ video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
103
+ video_out.close()
104
+ return 'ex.mp4'
105
+
106
+
107
+
108
+ choices=['4x2','2x1']
109
+ interface=gr.Interface(fn=predict, title="Brain MR Image Generation with StyleGAN-2",
110
+ description = "",
111
+ article = "Author: S.Serdar Helli",
112
+ inputs=[gr.inputs.Slider( minimum=16, maximum=2**10,label='Seed'),gr.inputs.Radio( choices=choices, default='4x2',label='Image Grid')],
113
+ outputs=gr.outputs.Video(label='Video'))
114
+
115
+
116
+ interface.launch(debug=True)