Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from models import create_model | |
from options.test_options import TestOptions | |
from PIL import Image | |
from torchvision import transforms | |
import os | |
# Inference configuration for the pix2pix generator.
# TestOptions is parsed first (use_cmd_line=False presumably keeps it from
# reading sys.argv — confirm in options/), then the settings below are
# written onto the resulting options object one attribute at a time.
opt = TestOptions().parse(use_cmd_line=False)

_OPTION_OVERRIDES = {
    'model': 'pix2pix',
    'netG': 'unet_256',
    'dataset_mode': 'single',
    'norm': 'batch',
    'no_dropout': True,
    'init_type': 'normal',
    'init_gain': 0.02,
    # This is just a placeholder since it's required
    'dataroot': './dummy_path',
    'checkpoints_dir': './checkpoints',
    'name': 'artgan_pix2pix',
}
for _option_name, _option_value in _OPTION_OVERRIDES.items():
    setattr(opt, _option_name, _option_value)
# Build the model described by `opt`, run its setup, and switch it to
# evaluation mode before any inference happens.
model = create_model(opt)
model.setup(opt)
model.eval()

# Generator checkpoint location: <checkpoints_dir>/<name>/latest_net_G.pth.
# Derived from the options (rather than hard-coded) so it cannot drift from
# opt.checkpoints_dir / opt.name when either is changed.
model_path = os.path.join(opt.checkpoints_dir, opt.name, 'latest_net_G.pth')

# Fail fast with an actionable message when the checkpoint is missing.
if not os.path.isfile(model_path):
    raise ValueError(f"No file found at {model_path}. Please check the path.")

# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# host; load_state_dict then copies the tensors into the generator.
model.netG.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
print("Weights loaded successfully from:", model_path)
# Preprocessing applied to every uploaded image:
#   1. scale so the shorter side is 290 px,
#   2. keep the central 256x256 region,
#   3. convert the PIL image to a float tensor (ToTensor maps 8-bit pixel
#      values into [0, 1]).
transform = transforms.Compose(
    [transforms.Resize(290), transforms.CenterCrop(256), transforms.ToTensor()]
)
def generate_art(input_image):
    """Run the pix2pix generator on an uploaded image and return the result.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The generated image as a PIL image, or ``None`` if no image was given.
    """
    if input_image is None:
        # Nothing uploaded: leave the output component empty instead of crashing.
        return None

    # Drop alpha / expand grayscale so the tensor has exactly 3 channels.
    # NOTE(review): assumes the generator takes 3 input channels (pix2pix
    # default input_nc=3) — confirm against opt.input_nc.
    rgb_image = input_image.convert("RGB")

    # `transform` yields values in [0, 1]; pix2pix is normally trained on
    # inputs normalized to [-1, 1] (Normalize(mean=0.5, std=0.5)), so rescale.
    # NOTE(review): confirm this matches the preprocessing used in training.
    input_tensor = transform(rgb_image).unsqueeze(0) * 2.0 - 1.0

    # Run on whichever device the generator's weights were placed on.
    device = next(model.netG.parameters()).device
    with torch.no_grad():
        output = model.netG(input_tensor.to(device))

    # The generator output is assumed tanh-activated ([-1, 1]). Map it back
    # to [0, 1] and clamp: ToPILImage scales floats by 255 and casts to bytes,
    # so out-of-range values would otherwise wrap into garbage colours.
    image_tensor = ((output[0] + 1.0) / 2.0).clamp(0.0, 1.0).cpu()
    return transforms.ToPILImage()(image_tensor)
# Gradio UI: one uploaded PIL image in, one generated PIL image out.
# The interface is bound to a name first and then launched.
demo = gr.Interface(
    fn=generate_art,
    inputs=gr.Image(label="Upload 5x5 vector map", type="pil"),
    outputs=gr.Image(type="pil"),
    title="ArtGAN Generator",
)
demo.launch()