import gradio as gr
import torch
from torch import nn
from torchvision.transforms.functional import to_pil_image
import torchvision.utils as vutils
# DCGAN-style generator: maps a 128-dim latent vector to a 64x64 RGB image.
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # latent (128, 1, 1) -> (512, 4, 4)
            nn.ConvTranspose2d(128, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # (512, 4, 4) -> (256, 8, 8)
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # (256, 8, 8) -> (128, 16, 16)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (128, 16, 16) -> (64, 32, 32)
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # (64, 32, 32) -> (3, 64, 64); Tanh squashes pixels to [-1, 1]
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)
# Instantiate the generator and load the pretrained weights onto the CPU.
model = Generator(ngpu=0)
model.load_state_dict(torch.load('car_gen.pth', map_location='cpu'))
def generate(button):
    """Sample 32 latent vectors and return a grid of generated cars as a PIL image."""
    model.eval()
    noise = torch.randn(32, 128, 1, 1)
    with torch.inference_mode():
        predictions = model(noise).cpu()
    # Arrange the 32 samples in a 4x8 grid and rescale pixel values to [0, 1].
    generated_grid = vutils.make_grid(predictions, nrow=8, padding=2, normalize=True)
    return to_pil_image(generated_grid)
Interface = gr.Interface(
    title='CarGAN',
    fn=generate,
    inputs=gr.Button(value='Generate', size='lg'),
    outputs=gr.Image(type='pil')
)
Interface.launch()