SerdarHelli's picture
Update app.py
83b0577
raw
history blame
2.26 kB
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 26 21:02:31 2022
@author: pc
"""
import pickle
import numpy as np
import torch
import gradio as gr
import sys
import subprocess
import os
from typing import Tuple
import PIL.Image
from huggingface_hub import hf_hub_download
# Make the StyleGAN3 repository importable. The pickled generator loaded below
# presumably references modules from this repo, which is why it is put on
# sys.path before unpickling. Clone only when missing so restarts do not
# re-run (and fail on) an existing checkout; a failed clone raises here
# instead of surfacing later as an obscure unpickling error.
if not os.path.isdir("stylegan3"):
    subprocess.run(["git", "clone", "https://github.com/NVlabs/stylegan3"], check=True)
sys.path.append("stylegan3")
# Markdown shown under the demo title. The link target must not be quoted,
# otherwise Markdown does not treat it as a URL.
DESCRIPTION = '''This model generates healthy MR Brain Images.
[Example](https://huggingface.co/spaces/SerdarHelli/Brain-MR-Image-Generation-GAN/blob/main/ex.png)
'''
# Download the checkpoint (or reuse the cached copy) and load the generator.
# hf_hub_download returns the local path of the cached file; use it instead of
# a hard-coded cache filename, whose on-disk naming is an internal detail of
# huggingface_hub and differs between its versions.
network_pkl = hf_hub_download(
    "SerdarHelli/Brain-MRI-GAN",
    filename="brainmrigan.pkl",
    cache_dir="./model",
    revision="main",
)
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# This is only acceptable because the checkpoint comes from a trusted repo.
with open(network_pkl, 'rb') as f:
    # 'G_ema' — presumably the moving-average generator weights (StyleGAN
    # checkpoint convention); confirm against the training setup.
    G = pickle.load(f)['G_ema']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G.eval()  # inference mode
G.to(device)
def predict(Seed,noise_mode,truncation_psi):
    """Generate one brain MR image from the module-level generator ``G``.

    Args:
        Seed: Seed for the latent vector. It is cast to ``int`` because a UI
            slider may deliver it as a float, which ``np.random.RandomState``
            rejects.
        noise_mode: Noise mode forwarded to the generator (the UI offers
            'const', 'random' and 'none').
        truncation_psi: Truncation psi forwarded to the generator.

    Returns:
        A 512x512 ``PIL.Image`` built from channel 0 of the generator output.
    """
    # The same seed always yields the same latent vector, so outputs are reproducible.
    z = torch.from_numpy(np.random.RandomState(int(Seed)).randn(1, G.z_dim)).to(device)
    # All-zero class label; presumably the model is unconditional (G.c_dim == 0) — confirm.
    label = torch.zeros([1, G.c_dim], device=device)
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
    # Move axis 1 (channels) last, then map the output range (assumed [-1, 1])
    # onto uint8 [0, 255].
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # Keep only channel 0 of the first image and upscale it for display.
    return PIL.Image.fromarray(img[0].cpu().numpy()[:, :, 0]).resize((512, 512))
# Noise modes offered in the UI; the selected one is passed straight to predict().
noises=['const', 'random', 'none']
interface = gr.Interface(
    fn=predict,
    title="Brain MR Image Generation with StyleGAN-2",
    description=DESCRIPTION,
    article="Author: S.Serdar Helli and Burhan Arat",
    inputs=[
        # step=1: without an explicit step the slider may derive a coarse step
        # from the range, leaving most seeds unreachable.
        gr.inputs.Slider(minimum=0, maximum=2**16, step=1, label='Seed'),
        gr.inputs.Radio(choices=noises, default='const', label='Noise Mode'),
        gr.inputs.Slider(0, 2, step=0.05, default=1, label='Truncation psi'),
    ],
    # predict() returns a PIL image, so declare the output type accordingly.
    outputs=gr.outputs.Image(type="pil", label="Output"),
)
interface.launch(debug=True)