Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from models import create_model | |
from options.test_options import TestOptions | |
from PIL import Image | |
from torchvision import transforms | |
# Inference-time options for the pix2pix generator.
# NOTE(review): TestOptions().parse(use_cmd_line=False) is not the upstream
# pix2pix signature; presumably a local extension that skips argv parsing.
opt = TestOptions().parse(use_cmd_line=False)
opt.model = 'pix2pix'
opt.netG = 'unet_256'
opt.dataset_mode = 'single'
opt.norm = 'batch'
opt.no_dropout = True
opt.init_type = 'normal'
opt.init_gain = 0.02
opt.dataroot = './dummy_path'  # placeholder; this script never builds a dataset from it
opt.checkpoints_dir = './checkpoints'
opt.name = 'artgan_pix2pix'
# get_transform() resizes to load_size x load_size, then crops crop_size.
# NOTE(review): with params=None that crop is a RandomCrop, so the crop
# position still varies between requests; use a deterministic crop if
# reproducible outputs matter.
opt.preprocess = 'resize_and_crop'
opt.load_size = 290
opt.crop_size = 256
# Disable random horizontal flipping at inference. With no_flip=False,
# get_transform() appends RandomHorizontalFlip, so every request had a 50%
# chance of being processed (and returned) mirrored. Upstream pix2pix
# test.py forces no_flip = True for the same reason.
opt.no_flip = True
# Load model
def _load_pix2pix(options):
    """Create the model described by *options*, run its setup, and switch it to eval mode.

    NOTE(review): in upstream pix2pix, ``setup()`` is what loads checkpoint
    weights from ``checkpoints_dir/name`` — confirm for this fork.
    """
    net = create_model(options)
    net.setup(options)
    net.eval()
    return net


model = _load_pix2pix(opt)
# Preprocessing pipeline adapted from the project's base_dataset module.
# The original copy referenced `__crop` / `__flip` helpers that were never
# defined in this file, so passing `params` raised NameError; they are
# provided here as `_crop` / `_flip`.


def _crop(img, pos, size):
    """Return the ``size`` x ``size`` region of *img* whose top-left corner is *pos*.

    *img* is returned unchanged when it is not larger than ``size`` in either
    dimension.
    """
    width, height = img.size
    x1, y1 = pos
    if width > size or height > size:
        return img.crop((x1, y1, x1 + size, y1 + size))
    return img


def _flip(img, flip):
    """Mirror *img* left-to-right when *flip* is truthy; otherwise return it unchanged."""
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build the torchvision preprocessing pipeline described by *opt*.

    Args:
        opt: options object; reads ``preprocess``, ``load_size``,
            ``crop_size`` and ``no_flip``.
        params: optional dict with ``'crop_pos'`` (x, y) and ``'flip'`` (bool)
            for a fixed crop/flip. When ``None``, cropping and flipping are
            random (``RandomCrop`` / ``RandomHorizontalFlip``).
        grayscale: convert to a single channel first and normalize one channel.
        method: interpolation passed to ``transforms.Resize``.
        convert: append ``ToTensor`` plus ``Normalize(mean=0.5, std=0.5)``,
            which maps pixel values from [0, 1] to [-1, 1].

    Returns:
        A ``transforms.Compose`` applying the selected steps in order.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, method))
    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(
                transforms.Lambda(lambda img: _crop(img, params['crop_pos'], opt.crop_size))
            )
    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: _flip(img, params['flip'])))
    if convert:
        transform_list.append(transforms.ToTensor())
        channels = 1 if grayscale else 3
        transform_list.append(transforms.Normalize((0.5,) * channels, (0.5,) * channels))
    return transforms.Compose(transform_list)
def generate_art(input_image):
    """Run the pix2pix generator on one uploaded image.

    Args:
        input_image: PIL image from the Gradio input component, or ``None``
            when the user submits without uploading anything.

    Returns:
        The generator output converted to a PIL image.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message to the user).
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    transform = get_transform(opt)
    # The transform normalizes exactly three channels, so RGBA or grayscale
    # uploads would fail inside Normalize; force RGB first.
    input_tensor = transform(input_image.convert('RGB')).unsqueeze(0)
    # Feed the input to whatever device the generator's weights live on, so
    # this works whether the model was set up on CPU or GPU.
    device = next(model.netG.parameters()).device
    with torch.no_grad():
        output = model.netG(input_tensor.to(device))
    # Undo the Normalize(0.5, 0.5) used on the input: map [-1, 1] back to
    # [0, 1] and clamp. ToPILImage computes mul(255).byte() on float tensors,
    # so raw negative values would wrap around and corrupt the colors.
    # assumes netG output lies in [-1, 1] (tanh in upstream pix2pix) — confirm
    image_tensor = ((output[0].detach().cpu() + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(image_tensor)
# Gradio UI: a single uploaded image in, the generated image out.
demo = gr.Interface(
    fn=generate_art,
    inputs=gr.Image(label="Upload 5x5 vector map", type="pil"),
    outputs=gr.Image(type="pil"),
    title="ArtGAN Generator",
)
demo.launch()