"""Gradio demo serving a trained pix2pix ("ArtGAN") generator.

Builds the generator through the project's ``create_model`` factory, loads
``latest_net_G.pth`` from the checkpoint directory and exposes a single
image-to-image function through a Gradio interface.
"""

import os

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from models import create_model
from options.test_options import TestOptions

# --- Options -----------------------------------------------------------------
opt = TestOptions().parse(use_cmd_line=False)
opt.model = 'pix2pix'
opt.netG = 'unet_256'
opt.dataset_mode = 'single'
opt.norm = 'batch'
opt.no_dropout = True
opt.init_type = 'normal'
opt.init_gain = 0.02
opt.dataroot = './dummy_path'  # Placeholder only: the options parser requires it.
opt.checkpoints_dir = './checkpoints'
opt.name = 'artgan_pix2pix'

# Generator weights, derived from the options so the two cannot drift apart
# (resolves to ./checkpoints/artgan_pix2pix/latest_net_G.pth).
model_path = os.path.join(opt.checkpoints_dir, opt.name, 'latest_net_G.pth')

# Fail fast, before building the model, so a missing checkpoint produces this
# clear message rather than an error from deeper inside model setup.
if not os.path.isfile(model_path):
    raise ValueError(f"No file found at {model_path}. Please check the path.")

# --- Model -------------------------------------------------------------------
model = create_model(opt)
model.setup(opt)
model.eval()

# If netG is wrapped in DataParallel, load into the inner module; keys are
# normalized below so checkpoints saved with or without a "module." prefix
# both load.
_generator = (
    model.netG.module
    if isinstance(model.netG, torch.nn.DataParallel)
    else model.netG
)
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
_state_dict = {
    (key[len('module.'):] if key.startswith('module.') else key): value
    for key, value in _state_dict.items()
}
# load_state_dict copies the CPU tensors into the parameters in place, on
# whatever device the generator already lives on.
_generator.load_state_dict(_state_dict)
print("Weights loaded successfully from:", model_path)

# Inputs must be sent to the same device as the generator's parameters.
_device = next(_generator.parameters()).device

# --- Preprocessing -----------------------------------------------------------
transform = transforms.Compose([
    transforms.Resize(290),
    transforms.CenterCrop(256),
    transforms.ToTensor(),  # -> float tensor in [0, 1]
    # The upstream pix2pix data pipeline scales inputs to [-1, 1]; assumes
    # artgan_pix2pix was trained the same way with 3 input channels -- confirm.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def generate_art(input_image):
    """Translate an uploaded image with the pix2pix generator.

    Args:
        input_image: PIL image from the Gradio upload, or None when the user
            submits without uploading anything.

    Returns:
        The generated image as a PIL image, or None if there was no input.
    """
    if input_image is None:
        return None
    # Force 3 channels so RGBA / grayscale uploads match the 3-channel Normalize.
    input_tensor = transform(input_image.convert('RGB')).unsqueeze(0).to(_device)
    with torch.no_grad():
        output = model.netG(input_tensor)
    # The generator's tanh head (upstream UNet) yields values in [-1, 1].
    # ToPILImage multiplies float tensors by 255 and casts to uint8, so map to
    # [0, 1] first; otherwise negative values are not converted correctly.
    image = (output[0].detach().cpu().clamp(-1.0, 1.0) + 1.0) / 2.0
    return transforms.ToPILImage()(image)


# --- Gradio interface --------------------------------------------------------
demo = gr.Interface(
    generate_art,
    inputs=gr.Image(label="Upload 5x5 vector map", type="pil"),
    outputs=gr.Image(type="pil"),
    title="ArtGAN Generator",
)
demo.launch()