|
import torch |
|
from torch import nn |
|
import gradio as gr |
|
|
|
|
|
class Generator(nn.Module):
    """DCGAN-style generator that turns a latent tensor into an ``nc``-channel image.

    Three hidden stages of ConvTranspose2d -> BatchNorm2d -> in-place ReLU are
    followed by a final ConvTranspose2d and a tanh, so outputs lie in [-1, 1].
    Assuming a latent input of shape (N, nz, 1, 1) — TODO confirm with callers —
    the spatial size grows 1 -> 3 -> 5 -> 12 -> 24, giving (N, nc, 24, 24).

    Args:
        nc: Number of output channels.
        nz: Number of channels of the latent input.
        ngf: Base feature width; the hidden stages use 4x, 2x and 1x this value.
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        super().__init__()

        # (in_channels, out_channels, kernel_size, stride, padding) for each
        # hidden stage, in network order.
        hidden_specs = [
            (nz, ngf * 4, 3, 1, 0),
            (ngf * 4, ngf * 2, 3, 2, 1),
            (ngf * 2, ngf, 4, 2, 0),
        ]

        layers = []
        for in_ch, out_ch, kernel, stride, pad in hidden_specs:
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Output stage: project to ``nc`` channels and squash into [-1, 1].
        layers += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        # Same module order as a flat nn.Sequential listing, so the
        # ``network.<index>`` state-dict keys are unchanged.
        self.network = nn.Sequential(*layers)

    def forward(self, input):
        """Pass ``input`` through the generator stack and return the result."""
        return self.network(input)
|
|
|
|
|
def predict(body, hair, top, bottom):
    """Return a key formed by concatenating the four input values.

    Each argument is converted with ``str`` and the pieces are joined with no
    separator, e.g. ``predict("human", 1, 2, 3) == "human123"``.

    Args:
        body: Body selection value.
        hair: Hair selection value.
        top: Top selection value.
        bottom: Bottom selection value.

    Returns:
        The concatenated string.
    """
    parts = (body, hair, top, bottom)
    return "".join(map(str, parts))
|
|
|
|
|
# Live Gradio demo: a body-type radio plus three integer sliders are passed to
# ``predict`` positionally, and re-evaluated on every change (live=True).
# NOTE(review): ``predict`` currently returns a plain string (e.g. "human000"),
# but the output component is "image"; Gradio will presumably treat that string
# as an image path/URL and fail. Confirm whether ``predict`` is meant to return
# an image (e.g. rendered by ``Generator``) or a path to an existing file.
demo = gr.Interface(
    predict,
    inputs=[
        gr.Radio(["human", "alien"], label="Body"),
        # ``value=`` sets the initial slider position; the pre-3.0 ``default=``
        # keyword is deprecated/ignored in Gradio 3 and rejected by later versions.
        gr.Slider(0, 5, label="Hair", step=1, value=0),
        gr.Slider(0, 3, label="Top", step=1, value=0),
        gr.Slider(0, 4, label="Bottom", step=1, value=0),
    ],
    outputs="image",
    live=True,
)
# Launched at import time, as in the original script.
demo.launch()
|
|