# TranSVAE / app.py
# Author: ldkong — last commit 136c53f ("Update app.py")
import gradio as gr
import torch
from torch import nn
import imageio
class Generator(nn.Module):
    """DCGAN-style generator: a stack of transposed convolutions ending in Tanh.

    The argument names follow the PyTorch DCGAN tutorial, see
    https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs

    Args:
        nc: number of channels produced by the final layer.
        nz: number of channels expected in the input (latent) tensor.
        ngf: base feature-map width; hidden layers use 4x, 2x and 1x of it.

    For an input shaped ``(N, nz, 1, 1)`` (assumed, as in the tutorial) the
    spatial size grows 1 -> 3 -> 5 -> 12 -> 24, giving ``(N, nc, 24, 24)``
    with values in [-1, 1] because of the final Tanh.
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) per layer.
        specs = [
            (nz, ngf * 4, 3, 1, 0),
            (ngf * 4, ngf * 2, 3, 2, 1),
            (ngf * 2, ngf, 4, 2, 0),
            (ngf, nc, 4, 2, 1),
        ]
        last = len(specs) - 1
        blocks = []
        for idx, (c_in, c_out, kernel, stride, pad) in enumerate(specs):
            blocks.append(nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False))
            if idx < last:
                # Hidden layers: normalize, then in-place ReLU.
                blocks.extend([nn.BatchNorm2d(c_out), nn.ReLU(True)])
            else:
                # Output layer squashes values into [-1, 1].
                blocks.append(nn.Tanh())
        # Kept as a single Sequential named `network` so state_dict keys
        # ("network.<i>.weight", ...) match previously saved checkpoints.
        self.network = nn.Sequential(*blocks)

    def forward(self, input):
        """Return the generated image tensor for the given latent tensor."""
        return self.network(input)
def display_gif(file_name, num_frames=8, gif_filename='avatar.gif'):
    """Assemble numbered PNG frames into an animated GIF.

    Reads ``{file_name}_{i}.png`` for ``i`` in ``range(num_frames)`` and writes
    the frames, in that order, to ``gif_filename``.

    Args:
        file_name: path prefix of the frame images (without the ``_<i>.png`` suffix).
        num_frames: number of frames to read; defaults to the previously
            hard-coded 8.
        gif_filename: where to write the GIF; defaults to 'avatar.gif'.

    Returns:
        The path the GIF was written to. (Previously the function returned the
        result of ``imageio.mimsave``, which is not the output path.)
    """
    # NOTE(review): run() builds prefixes that already end in '_', so passing
    # such a prefix here would look for '..._' + '_0.png' (double underscore).
    # Confirm which naming scheme the frame files actually use.
    images = [imageio.imread(f'{file_name}_{frame}.png') for frame in range(num_frames)]
    imageio.mimsave(gif_filename, images)
    return gif_filename
def display_image(file_name, frame=0, out_filename='image.png'):
    """Copy one sprite frame to ``out_filename`` so the UI can serve it.

    Reads ``{file_name}{frame}.png`` and re-writes it to ``out_filename``. No
    separator is inserted between the prefix and the frame index; run()
    passes a prefix that already ends in '_'.

    Args:
        file_name: path prefix of the frame images.
        frame: index of the frame to copy; defaults to the first frame (0).
        out_filename: destination path; defaults to 'image.png'.

    Returns:
        The destination path (the function previously returned None).
    """
    image_filename = f'{file_name}{frame}.png'
    # Logs which frame file is being read.
    print(image_filename)
    image = imageio.imread(image_filename)
    imageio.imwrite(out_filename, image)
    return out_filename
# UI label -> single-character code used in the sprite frame filenames.
_BODY_CODES = {"human": '0', "alien": '1'}
_HAIR_CODES = {"green": '0', "yellow": '2', "rose": '4', "red": '7', "wine": '8'}
_TOP_CODES = {"brown": '0', "blue": '1', "white": '2'}
_BOTTOM_CODES = {"white": '0', "golden": '1', "red": '2', "silver": '3'}


def run(action, body, hair, top, bottom):
    """Render the sprite selected in the UI and return the image paths to show.

    Each attribute label is translated to its filename code via the tables
    above; a label not found in a table is used verbatim. The frame prefix is
    ``./Sprite/frames/domain_1/<action>/front_<body><bottom><top><hair>_``.
    The code order in the filename differs from the parameter order.

    Returns:
        ``('image.png', 'image.png')``. The same frame is currently shown as
        both the source and the target avatar.
    """
    body = _BODY_CODES.get(body, body)
    hair = _HAIR_CODES.get(hair, hair)
    top = _TOP_CODES.get(top, top)
    bottom = _BOTTOM_CODES.get(bottom, bottom)
    name = f'./Sprite/frames/domain_1/{action}/front_{body}{bottom}{top}{hair}_'
    # Copies frame 0 of the selected sprite to 'image.png'.
    display_image(name)
    return 'image.png', 'image.png'
# Radio selectors, in the same order as run()'s parameters:
# action, body, hair, top, bottom.
_inputs = [
    gr.Radio(choices=["shoot", "slash", "spellcard", "thrust", "walk"], value="shoot"),
    gr.Radio(choices=["human", "alien"], value="human"),
    gr.Radio(choices=["green", "yellow", "rose", "red", "wine"], value="green"),
    gr.Radio(choices=["brown", "blue", "white"], value="brown"),
    gr.Radio(choices=["white", "golden", "red", "silver"], value="white"),
]
# One image per value returned by run().
# NOTE(review): newer Gradio releases expect type="filepath" rather than
# "file" for Image; confirm against the Gradio version this Space pins.
_outputs = [
    gr.components.Image(type="file", label="Avatar (Source)"),
    gr.components.Image(type="file", label="Avatar (Target)"),
]
demo = gr.Interface(
    run,
    inputs=_inputs,
    outputs=_outputs,
    live=False,
    title="TransferVAE",
)
demo.launch()