"""Gradio demo serving a pix2pix ("ArtGAN") generator.

Loads the checkpoint named by ``opt.name`` through the project's ``models`` /
``options`` packages and exposes the generator network as a web UI that turns
an uploaded image into a generated one.
"""

import gradio as gr
import torch
from models import create_model
from options.test_options import TestOptions
from PIL import Image
from torchvision import transforms

# Set options
opt = TestOptions().parse(use_cmd_line=False)
opt.model = 'pix2pix'
opt.netG = 'unet_256'
opt.dataset_mode = 'single'
opt.norm = 'batch'
opt.no_dropout = True
opt.init_type = 'normal'
opt.init_gain = 0.02
opt.dataroot = './dummy_path'
opt.checkpoints_dir = './checkpoints'
opt.name = 'artgan_pix2pix'
opt.preprocess = 'resize_and_crop'
opt.load_size = 290
opt.crop_size = 256
opt.no_flip = False

# Load model
model = create_model(opt)
model.setup(opt)
model.eval()


def _crop(img, pos, size):
    """Crop a ``size`` x ``size`` window whose top-left corner is ``pos``.

    The image is returned unchanged when it is not larger than the window.
    """
    ow, oh = img.size
    x1, y1 = pos
    if ow > size or oh > size:
        return img.crop((x1, y1, x1 + size, y1 + size))
    return img


def _flip(img, flip):
    """Mirror ``img`` left-to-right when ``flip`` is truthy."""
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


# Get Transform function from base_dataset
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build the image preprocessing pipeline.

    Args:
        opt: options object; reads ``preprocess``, ``load_size``, ``crop_size``
            and ``no_flip``.
        params: optional dict with ``'crop_pos'`` (x, y) and ``'flip'`` (bool).
            When ``None``, the crop and flip are random, which suits training
            but not reproducible inference.
        grayscale: convert to a single channel first.
        method: PIL resampling filter passed to ``transforms.Resize``.
        convert: append ``ToTensor`` and a normalisation that maps to [-1, 1].

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, method))
    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(
                transforms.Lambda(lambda img: _crop(img, params['crop_pos'], opt.crop_size)))
    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: _flip(img, params['flip'])))
    if convert:
        transform_list.append(transforms.ToTensor())
        if grayscale:
            transform_list.append(transforms.Normalize((0.5,), (0.5,)))
        else:
            transform_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(transform_list)


def _inference_transform(opt):
    """Deterministic variant of ``get_transform`` for serving: centre crop, never flip."""
    # 'resize' scales to a load_size x load_size square, so the centre-crop
    # offset is identical on both axes. NOTE(review): this assumes 'resize' is
    # part of opt.preprocess (true for 'resize_and_crop').
    offset = max(opt.load_size - opt.crop_size, 0) // 2
    return get_transform(opt, params={'crop_pos': (offset, offset), 'flip': False})


def _tensor_to_image(tensor):
    """Convert a (C, H, W) generator output tensor to a PIL image.

    Assumes the generator emits values in [-1, 1], mirroring the input
    normalisation (TODO confirm against netG's final activation). Values are
    clamped, then rescaled to the [0, 1] range ``ToPILImage`` expects for floats.
    """
    image = (tensor.detach().float().cpu().clamp(-1.0, 1.0) + 1.0) / 2.0
    return transforms.ToPILImage()(image)


# Built once: the options are fixed after start-up.
_INFER_TRANSFORM = _inference_transform(opt)


def generate_art(input_image):
    """Run the generator on one uploaded image.

    Args:
        input_image: PIL image from the Gradio upload widget, or ``None`` when
            nothing was uploaded.

    Returns:
        The generated PIL image, or ``None`` if no input was given.
    """
    if input_image is None:
        return None
    # The pipeline normalises three channels, so drop alpha / expand grayscale.
    input_tensor = _INFER_TRANSFORM(input_image.convert('RGB')).unsqueeze(0)
    # Put the input on whatever device setup() placed the generator on.
    device = next(model.netG.parameters()).device
    with torch.no_grad():
        output = model.netG(input_tensor.to(device))
    return _tensor_to_image(output[0])


# Define the Gradio Interface
demo = gr.Interface(
    generate_art,
    inputs=gr.Image(label="Upload 5x5 vector map", type="pil"),
    outputs=gr.Image(type="pil"),
    title="ArtGAN Generator",
)

if __name__ == "__main__":
    demo.launch()