SerdarHelli's picture
Update app.py
4245274
raw
history blame
3.62 kB
import os
import pickle
import sys
import subprocess
import tarfile
import subprocess
import imageio
import numpy as np
import scipy.interpolate
import torch
from tqdm import tqdm
import gradio as gr
from huggingface_hub import hf_hub_download
def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into one grid image, row-major.

    Args:
        img: Tensor of shape (batch, channels, height, width).
        grid_w: Number of grid columns; defaults to ``batch // grid_h``.
        grid_h: Number of grid rows.
        float_to_uint8: Map values with ``x * 127.5 + 128``, clamp to [0, 255]
            and cast to uint8 (so inputs in [-1, 1] land in [0, 255]).
        chw_to_hwc: Return channels-last (H, W, C) instead of (C, H, W).
        to_numpy: Move the result to the CPU and return a NumPy array.

    Returns:
        The tiled grid, shaped (grid_h*height, grid_w*width, channels) when
        ``chw_to_hwc`` is true, else (channels, grid_h*height, grid_w*width).

    Raises:
        ValueError: If ``grid_w * grid_h`` differs from the batch size.
    """
    batch_size, channels, img_h, img_w = img.shape
    if grid_w is None:
        grid_w = batch_size // grid_h
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if batch_size != grid_w * grid_h:
        raise ValueError(f'Cannot lay out {batch_size} images in a {grid_w}x{grid_h} grid')
    if float_to_uint8:
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # (rows, cols, C, H, W) -> (C, rows, H, cols, W) so that merging
    # (rows, H) and (cols, W) places each image in its grid cell.
    img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
    img = img.permute(2, 0, 3, 1, 4)
    img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
    if chw_to_hwc:
        img = img.permute(1, 2, 0)
    if to_numpy:
        img = img.cpu().numpy()
    return img
# Load the pretrained generator once, at import time, so it is shared by all calls.
# The path is relative, i.e. resolved against the current working directory.
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted snapshot.
network_pkl = 'braingan-400.pkl'
with open(network_pkl, 'rb') as snapshot_file:
    snapshot = pickle.load(snapshot_file)
G = snapshot['G_ema']
del snapshot  # keep only the generator, not the rest of the pickled dict
# Grid layouts offered in the UI: choice -> (grid_w, grid_h, number of seeds).
# Both layouts use two seeds per cell, so each cell interpolates between two keyframes.
_GRID_LAYOUTS = {
    '4x2': (4, 2, 16),
    '2x1': (2, 1, 4),
}


def _make_interpolators(ws, num_keyframes, wraps, kind):
    """Return one interp1d per grid cell, in row-major (yi, xi) order.

    ``ws`` is the keyframe tensor reshaped by ``predict`` to
    ``(grid_h, grid_w, num_keyframes, ...)``.
    """
    grid_h, grid_w = ws.shape[0], ws.shape[1]
    # Keyframe k sits at x == k. Repeating the keyframes ``wraps`` times on each
    # side lets x in [0, num_keyframes] interpolate smoothly, and x ==
    # num_keyframes coincides with keyframe 0 again, so the clip loops.
    x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
    interps = []
    for yi in range(grid_h):
        for xi in range(grid_w):
            y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
            interps.append(scipy.interpolate.interp1d(x, y, kind=kind, axis=0))
    return interps


def _render_video(mp4, interps, grid_w, grid_h, num_frames, w_frames, device):
    """Synthesize ``num_frames`` grid frames and encode them into ``mp4`` at 60 fps."""
    video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264')
    try:
        for frame_idx in tqdm(range(num_frames)):
            t = frame_idx / w_frames
            # Evaluate every cell at time t and synthesize the whole grid in one
            # batched call (row-major order, matching layout_grid).
            w = torch.from_numpy(np.stack([interp(t) for interp in interps])).to(device)
            imgs = G.synthesis(ws=w, noise_mode='const')
            video_out.append_data(layout_grid(imgs, grid_w=grid_w, grid_h=grid_h))
    finally:
        # Always release the writer, even if synthesis raises (e.g. out of memory).
        video_out.close()


def predict(Seed, choices):
    """Render a looping latent-interpolation video of a grid of generated images.

    Args:
        Seed: Exclusive upper bound of the seed range. The seeds used are
            ``Seed - n, ..., Seed - 1``, with ``n`` fixed by the layout.
        choices: Grid layout: ``'4x2'`` (16 seeds) or ``'2x1'`` (4 seeds).

    Returns:
        Path of the rendered MP4 (``'ex.mp4'``). Both layouts yield two
        keyframes per cell, i.e. 2 * 240 frames at 60 fps.

    Raises:
        ValueError: If ``choices`` is not a supported layout.
    """
    if choices not in _GRID_LAYOUTS:
        raise ValueError(f'Unsupported grid layout: {choices!r}')
    grid_w, grid_h, num_seeds = _GRID_LAYOUTS[choices]
    cells = grid_w * grid_h

    # int() in case the UI slider delivers a float value.
    s1 = int(Seed)
    seeds = list(range(s1 - num_seeds, s1))

    # NOTE(review): fixed output path; concurrent requests would share this file.
    mp4 = 'ex.mp4'
    w_frames = 60 * 4  # frames rendered between consecutive keyframes
    kind = 'cubic'
    wraps = 2
    psi = 1
    device = torch.device('cuda')

    if len(seeds) % cells != 0:
        raise ValueError('Number of input seeds must be divisible by grid W*H')
    num_keyframes = len(seeds) // cells
    all_seeds = np.asarray(seeds, dtype=np.int64)

    G.eval()
    G.to(device)
    # Inference only: no autograd graph is needed for mapping or synthesis.
    with torch.no_grad():
        zs = torch.from_numpy(
            np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
        ).to(device)
        ws = G.mapping(z=zs, c=None, truncation_psi=psi)
        G.synthesis(ws[:1])  # warm up
        ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
        interps = _make_interpolators(ws, num_keyframes, wraps, kind)
        _render_video(mp4, interps, grid_w, grid_h, num_keyframes * w_frames, w_frames, device)
    return mp4
# Gradio UI: a seed slider plus a grid-layout radio, rendering predict()'s video.
choices = ['4x2', '2x1']

seed_input = gr.inputs.Slider(minimum=16, maximum=2**10, label='Seed')
grid_input = gr.inputs.Radio(choices=choices, default='4x2', label='Image Grid')
video_output = gr.outputs.Video(label='Video')

interface = gr.Interface(
    fn=predict,
    title="Brain MR Image Generation with StyleGAN-2",
    description="",
    article="Author: S.Serdar Helli",
    inputs=[seed_input, grid_input],
    outputs=video_output,
)
interface.launch(debug=True)