File size: 1,802 Bytes
901a2ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883a623
 
 
901a2ca
 
 
 
 
 
 
e015b78
 
901a2ca
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
import torch
from models import create_model
from options.test_options import TestOptions
from PIL import Image
from torchvision import transforms
import os

# Inference configuration: start from the parsed test options (without
# reading the command line) and override the fields this app depends on.
opt = TestOptions().parse(use_cmd_line=False)
_option_overrides = {
    'model': 'pix2pix',
    'netG': 'unet_256',
    'dataset_mode': 'single',
    'norm': 'batch',
    'no_dropout': True,
    'init_type': 'normal',
    'init_gain': 0.02,
    # Placeholder only: the option is required, but no dataset is read here.
    'dataroot': './dummy_path',
    'checkpoints_dir': './checkpoints',
    'name': 'artgan_pix2pix',
}
for _option_name, _option_value in _option_overrides.items():
    setattr(opt, _option_name, _option_value)

# Build the model from the options above, run its setup step and switch it
# to evaluation mode before any weights are loaded or images are processed.
model = create_model(opt)
model.setup(opt)
model.eval()

# Location of the trained generator weights.
model_path = './checkpoints/artgan_pix2pix/latest_net_G.pth'

# Fail fast with a clear message when the checkpoint is missing.
if not os.path.isfile(model_path):
    raise ValueError(f"No file found at {model_path}. Please check the path.")

# When the framework wraps the generator in DataParallel (GPU hosts), the
# wrapper's state-dict keys carry a "module." prefix that a plain checkpoint
# does not have, so load into the wrapped network instead.
_generator = model.netG
if isinstance(_generator, torch.nn.DataParallel):
    _generator = _generator.module

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True on PyTorch versions that
# support it).
# NOTE(review): model.setup(opt) may already have loaded these weights in
# test mode, which would make this explicit load redundant — confirm.
_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
_generator.load_state_dict(_state_dict)
print("Weights loaded successfully from:", model_path)

# Preprocessing pipeline applied to each uploaded image before inference.
_preprocess_steps = [
    transforms.Resize(290),      # scale so the shorter side is 290 px
    transforms.CenterCrop(256),  # keep the central 256x256 region
    transforms.ToTensor(),       # PIL image -> float tensor
]
transform = transforms.Compose(_preprocess_steps)

def generate_art(input_image: "Image.Image") -> "Image.Image":
    """Run the generator on an uploaded image and return the generated image.

    Args:
        input_image: The PIL image supplied by the Gradio input component,
            or ``None`` when the user submitted without uploading a file.

    Returns:
        The generator output converted to a PIL image.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of a transform crash.
        raise gr.Error("Please upload an image first.")

    # Guard against RGBA / palette / grayscale uploads; this is a no-op for
    # images that are already RGB.
    # NOTE(review): assumes the generator takes 3 input channels — confirm.
    rgb_image = input_image.convert("RGB")

    # The module-level `transform` ends at ToTensor, i.e. values in [0, 1].
    # Rescale to [-1, 1], the range the upstream pix2pix data pipeline
    # normalizes to (mean 0.5, std 0.5).
    # NOTE(review): confirm the checkpoint was trained with that default.
    input_tensor = transform(rgb_image).unsqueeze(0) * 2.0 - 1.0

    # Run on whatever device the generator's parameters live on.
    device = next(model.netG.parameters()).device
    with torch.no_grad():
        output = model.netG(input_tensor.to(device))

    # Diagnostic: raw output range of the generator for this request.
    print(output[0].min(), output[0].max(), output[0].mean())

    # The stock pix2pix U-Net output is tanh-bounded to [-1, 1]; ToPILImage
    # expects floats in [0, 1], so map the range back and clamp any stray
    # values before conversion.
    image_tensor = ((output[0] + 1.0) / 2.0).clamp(0.0, 1.0).cpu()
    return transforms.ToPILImage()(image_tensor)

# Build the Gradio UI: one PIL image in, one PIL image out, served by
# generate_art, then start the app.
_image_input = gr.Image(label="Upload 5x5 vector map", type="pil")
_image_output = gr.Image(type="pil")
demo = gr.Interface(
    fn=generate_art,
    inputs=_image_input,
    outputs=_image_output,
    title="ArtGAN Generator",
)
demo.launch()