davidpiscasio's picture
Updated app.py
6d6e88e
raw
history blame
No virus
3.68 kB
from options.test_options import TestOptions
from models import create_model
import torch
import numpy as np
import gradio as gr
from einops import rearrange
import torchvision
import torchvision.transforms as transforms
def tensor2im(input_image, imtype=np.uint8):
    """Convert a model output into a displayable (H, W, C) numpy image.

    Args:
        input_image: a ``torch.Tensor`` (only element ``[0]`` of its leading
            batch dimension is used and it is treated as channel-first), a
            numpy array (only cast to ``imtype``), or any other object
            (returned unchanged).
        imtype: numpy dtype of the returned array.

    Returns:
        For tensors: a (H, W, C) array whose values are mapped from the
        assumed [-1, 1] range to [0, 255] and clipped to that range;
        single-channel images are tiled to 3 channels.
        For numpy arrays: the array cast to ``imtype``.
        For anything else: ``input_image`` itself.
    """
    if isinstance(input_image, np.ndarray):
        # Already an image array: only the dtype is changed.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        # Unsupported type: hand it back untouched.
        return input_image

    # First item of the batch, as float32 on the CPU, detached from autograd.
    image_numpy = input_image.detach()[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:  # grayscale -> RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    # Post-processing: transpose CHW -> HWC and map [-1, 1] to [0, 255].
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    # Saturate out-of-range values; otherwise casting to an integer dtype such
    # as uint8 would wrap them around (e.g. 510 -> 254) and corrupt pixels.
    image_numpy = np.clip(image_numpy, 0.0, 255.0)
    return image_numpy.astype(imtype)
# Dropdown label -> prefix of the pretrained checkpoint name.
_TRANSLATION_TO_MODEL = {
    'Orange to Apple': 'orange2apple',
    'Horse to Zebra': 'horse2zebra',
    'Image to Van Gogh': 'style_vangogh',
    'Image to Monet': 'style_monet',
}


def get_model(translation):
    """Return the model name prefix for a translation label from the UI.

    Args:
        translation: one of the dropdown labels, e.g. ``'Horse to Zebra'``.

    Returns:
        The model name prefix, e.g. ``'horse2zebra'``.

    Raises:
        ValueError: if ``translation`` is not a supported label (including
            ``None`` when nothing was selected), instead of silently
            returning ``None`` and failing later on string concatenation.
    """
    model_name = _TRANSLATION_TO_MODEL.get(translation)
    if model_name is None:
        raise ValueError(
            f"Unsupported translation {translation!r}; "
            f"expected one of {list(_TRANSLATION_TO_MODEL)}"
        )
    return model_name
def _load_model(translation):
    """Build the test-mode model for ``translation`` and load its pretrained weights."""
    opt = TestOptions().parse()
    opt.name = get_model(translation) + '_pretrained'
    opt.model = 'test'
    opt.no_dropout = True
    opt.num_threads = 0
    opt.batch_size = 1
    opt.no_flip = True
    model = create_model(opt)
    model.setup(opt)
    model.eval()
    return model


def _preprocess(image):
    """Turn an (H, W, C) image array into a (1, C, 256, 256) float tensor in [-1, 1]."""
    # The pretrained CycleGAN models were trained on inputs normalized with
    # mean = std = 0.5 (mapping [0, 1] to [-1, 1]); tensor2im inverts exactly
    # that mapping. ImageNet statistics would feed a shifted distribution.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    tensor = torch.from_numpy(image)  # numpy -> PyTorch tensor
    tensor = rearrange(tensor, "h w c -> c h w")  # PyTorch is channel first
    tensor = transforms.Resize(256)(tensor)  # shorter side -> 256
    tensor = transforms.CenterCrop(256)(tensor).float() / 255.
    tensor = normalize(tensor)
    # Insert a batch dimension of 1, as required by the model.
    return rearrange(tensor, "c h w -> 1 c h w")


def unpaired_img2img(translation, image):
    """Translate ``image`` with the pretrained CycleGAN chosen by ``translation``.

    Args:
        translation: dropdown label, resolved with ``get_model``.
        image: image array from the Gradio input, indexed as (H, W, C);
            presumably uint8 RGB (scaled by 1/255 below) — TODO confirm.

    Returns:
        The translated image as a numpy array produced by ``tensor2im``.

    Raises:
        RuntimeError: if the model reports no visual outputs.
    """
    model = _load_model(translation)
    model.set_input(_preprocess(image))
    model.test()
    visuals = model.get_current_visuals()  # get image results
    if not visuals:
        raise RuntimeError('The model returned no visual outputs.')
    # Use the last visual, as the original loop did (presumably the
    # generated image — confirm against the model's visual_names).
    im_data = list(visuals.values())[-1]
    return tensor2im(im_data)
# Gradio UI: a translation dropdown plus an image upload, showing the
# translated image produced by unpaired_img2img.
demo = gr.Interface(
    fn=unpaired_img2img,
    inputs=[
        gr.inputs.Dropdown(['Horse to Zebra', 'Orange to Apple', 'Image to Van Gogh', 'Image to Monet']),
        gr.inputs.Image(shape=(256, 256)),
    ],
    outputs=gr.outputs.Image(type="numpy"),
    title="Unpaired Image to Image Translation",
    examples=[
        ['Horse to Zebra', "examples/horse1.jpg"],
        ['Horse to Zebra', "examples/horse3.jpg"],
        ['Orange to Apple', "examples/orange1.jpg"],
        ['Orange to Apple', "examples/orange2.jpg"],
        ['Image to Van Gogh', "examples/img1.jpg"],
        ['Image to Van Gogh', "examples/img2.jpg"],
        ['Image to Monet', "examples/img1.jpg"],
        ['Image to Monet', "examples/img2.jpg"],
    ],
    description="This is a PyTorch implementation of the unpaired image-to-image translation using a pretrained CycleGAN model. Kindly select first the type of translation you wish to see using the dropdown menu. Then, upload the image you wish to translate and click on the 'Submit' button.",
    article="To know more about Unpaired Image to Image Translation and CycleGAN, you may access their <a href = https://paperswithcode.com/paper/unpaired-image-to-image-translation-using>Papers with Code</a> page.",
    allow_flagging="never",
)

# Start the web server only when run as a script (`python app.py`), so that
# importing this module does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch(inbrowser=True)