TranSVAE / app.py
ldkong's picture
Update app.py
d63c6cd
raw
history blame
No virus
1.58 kB
from torch import nn
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to small multi-channel images.

    The network is four transposed-convolution stages. The first three are each
    followed by BatchNorm and ReLU; the last feeds a ``Tanh``, so every output
    value lies in ``[-1, 1]``. For a 1x1 spatial input the stages produce 3x3,
    5x5, 12x12 and finally 24x24 feature maps.

    Parameter meanings follow the PyTorch DCGAN tutorial:
    https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs

    Args:
        nc: Channels in the generated image (default 4).
        nz: Channels of the latent input, i.e. the noise length (default 100).
        ngf: Base width of the generator's feature maps (default 64).
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        super().__init__()

        def upsample(in_channels, out_channels, kernel, stride, padding):
            # ConvTranspose2d -> BatchNorm2d -> in-place ReLU. The conv bias is
            # disabled; a constant bias would be cancelled by the BatchNorm
            # mean subtraction anyway.
            return [
                nn.ConvTranspose2d(
                    in_channels, out_channels, kernel, stride, padding, bias=False
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            ]

        # The flat layer order determines the "network.<i>" state_dict keys;
        # it must stay exactly like this for pretrained checkpoints to load.
        self.network = nn.Sequential(
            *upsample(nz, ngf * 4, 3, 1, 0),
            *upsample(ngf * 4, ngf * 2, 3, 2, 1),
            *upsample(ngf * 2, ngf, 4, 2, 0),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Return the images generated from the latent batch ``input``."""
        return self.network(input)
from huggingface_hub import hf_hub_download
import torch
# Fetch the pretrained generator checkpoint from the Hugging Face Hub repo
# 'nateraw/cryptopunks-gan' and restore it into a fresh Generator.
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
model = Generator()
# Tensors are mapped to the CPU because the model's parameters live there.
# Note that map_location='cuda' alone would NOT move the model to a GPU:
# load_state_dict copies values into the existing (CPU) parameters.
# NOTE(review): torch.load unpickles the downloaded file, which can run
# arbitrary code if the checkpoint is tampered with; consider
# weights_only=True if the installed torch version supports it.
# NOTE(review): the model is never switched to eval(), so its BatchNorm layers
# run in training mode at inference (per-batch statistics, running stats
# updated on each call). Confirm this is intended; eval() would change the
# generated images.
model.load_state_dict(
    torch.load(weights_path, map_location=torch.device('cpu'))
)
from torchvision.utils import save_image
def predict(seed, num_punks=4):
    """Generate a grid of punks from ``seed`` and save it as a PNG.

    Args:
        seed: Seed for the latent noise. It is truncated with ``int()``,
            because UI sliders may deliver it as a float.
        num_punks: Number of images to generate (default 4).

    Returns:
        The path ``'punks.png'``, relative to the current working directory,
        where the image grid was written.
    """
    # A private RNG leaves the process-global torch seed untouched, so
    # concurrent requests cannot reseed each other's noise mid-sample. On CPU
    # it yields the same samples as torch.manual_seed(seed) + torch.randn.
    rng = torch.Generator().manual_seed(int(seed))
    z = torch.randn(num_punks, 100, 1, 1, generator=rng)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        punks = model(z)
    # normalize=True rescales the batch by its min/max before writing.
    # NOTE(review): every call overwrites the same file, so overlapping
    # requests may read each other's image; consider a per-request filename.
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'
import gradio as gr
# Web UI: one integer "Seed" slider in, one image out. predict() receives the
# slider value and returns the path of the generated image file.
demo = gr.Interface(
    predict,
    inputs=[
        # The initial position is set with `value=`. The legacy `default=`
        # keyword is not a gr.Slider parameter: older Gradio versions drop it
        # with a warning and newer ones reject it, so 42 was never applied.
        # step=1 makes every integer seed in the range selectable.
        gr.Slider(0, 1000, value=42, step=1, label='Seed'),
    ],
    outputs="image",
)
demo.launch()