File size: 2,312 Bytes
3d8e29a
 
 
 
 
9f6e901
3d8e29a
 
 
 
 
 
 
 
 
 
9f6e901
3d8e29a
 
9f6e901
 
 
 
3d8e29a
 
 
 
 
 
9f6e901
 
3d8e29a
9f6e901
 
3d8e29a
 
9f6e901
3d8e29a
9f6e901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d8e29a
 
 
9f6e901
3d8e29a
9f6e901
3d8e29a
 
 
9f6e901
3d8e29a
 
 
 
 
 
 
 
 
9f6e901
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import torch
from models import create_model
from options.test_options import TestOptions
from PIL import Image
from torchvision import transforms

# Set options for running the pix2pix generator at inference time.
# NOTE(review): the architecture-related values (netG, norm, no_dropout) presumably
# must match what the checkpoint under checkpoints_dir/name was trained with — confirm.
opt = TestOptions().parse(use_cmd_line=False)
opt.model = 'pix2pix'
opt.netG = 'unet_256'
opt.dataset_mode = 'single'
opt.norm = 'batch'
opt.no_dropout = True
opt.init_type = 'normal'
opt.init_gain = 0.02
opt.dataroot = './dummy_path'  # placeholder: nothing in this file reads a dataset
opt.checkpoints_dir = './checkpoints'
opt.name = 'artgan_pix2pix'
opt.preprocess = 'resize_and_crop'
opt.load_size = 290
opt.crop_size = 256
# Disable random horizontal flipping at inference (pix2pix's test script does the
# same); otherwise any get_transform(opt) pipeline mirrors inputs at random.
opt.no_flip = True

# Load model. model.setup is expected to load the saved generator weights
# (pix2pix convention for non-training options); eval() switches the networks
# to inference mode.
model = create_model(opt)
model.setup(opt)
model.eval()

# Transform pipeline adapted from pix2pix's base_dataset module.
def _crop(img, pos, size):
    """Crop a ``size`` x ``size`` window whose top-left corner is ``pos`` (x, y).

    The image is returned unchanged when it is not larger than ``size`` in
    either dimension.
    """
    ow, oh = img.size
    x1, y1 = pos
    if ow > size or oh > size:
        return img.crop((x1, y1, x1 + size, y1 + size))
    return img


def _flip(img, flip):
    """Mirror ``img`` left-to-right when ``flip`` is truthy."""
    if flip:
        return transforms.functional.hflip(img)
    return img


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build the torchvision preprocessing pipeline described by ``opt``.

    Args:
        opt: Options object; reads ``preprocess``, ``load_size``, ``crop_size``
            and ``no_flip``.
        params: Optional dict with ``'crop_pos'`` (an (x, y) pair) and ``'flip'``
            (bool) for a fixed crop/flip. When ``None``, cropping and flipping are
            random (RandomCrop / RandomHorizontalFlip).
        grayscale: Convert to a single channel and normalize with 1-channel stats.
        method: Interpolation used by the resize step.
        convert: Append ToTensor + Normalize, which maps pixel values to [-1, 1].

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, method))
    if 'crop' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            # Previously called an undefined `__crop` (NameError whenever params was given).
            transform_list.append(transforms.Lambda(lambda img: _crop(img, params['crop_pos'], opt.crop_size)))
    if not opt.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: _flip(img, params['flip'])))
    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def generate_art(input_image):
    """Run one uploaded image through the pix2pix generator.

    Args:
        input_image: PIL image from the Gradio upload widget, or ``None`` when
            the user submits without uploading anything.

    Returns:
        The generated PIL image, ``opt.crop_size`` pixels square.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # Deterministic preprocessing: resize straight to the network input size.
    # get_transform(opt) with no params would add RandomCrop/RandomHorizontalFlip,
    # making the output differ between calls for the same upload. This is
    # equivalent to get_transform with load_size == crop_size and no flipping.
    transform = transforms.Compose([
        transforms.Resize((opt.crop_size, opt.crop_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Uploads can be RGBA, palette or grayscale; the 3-channel Normalize needs RGB.
    rgb_image = input_image.convert("RGB")

    # Place the input on whichever device holds the generator's weights.
    device = next(model.netG.parameters()).device
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model.netG(input_tensor)

    # Assumes the generator outputs the same [-1, 1] range as its normalized
    # inputs (pix2pix convention) — map back to [0, 1] first, because ToPILImage
    # scales float tensors by 255 and casts to uint8, corrupting negative values.
    image_tensor = ((output[0].cpu() + 1) / 2).clamp(0, 1)
    return transforms.ToPILImage()(image_tensor)

# Define the Gradio Interface. Bind it to a module-level `demo` so the module can
# be imported (e.g. by tests, or Gradio's reload mode, which looks up `demo`)
# without immediately starting a server.
demo = gr.Interface(
    generate_art,
    inputs=gr.Image(label="Upload 5x5 vector map", type="pil"),
    outputs=gr.Image(type="pil"),
    title="ArtGAN Generator",
)

if __name__ == "__main__":
    demo.launch()