# Spaces: Sleeping
import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
def normalize():
    """Build the per-channel normalization transform applied to input images.

    Returns:
        A ``transforms.Normalize`` configured with the fixed three-channel
        mean and standard deviation listed below, i.e. each channel is
        mapped to ``(x - mean) / std``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=channel_mean, std=channel_std)
def denormalize():
    """Build the transform that undoes ``normalize()``.

    ``Normalize`` computes ``out = (x - mean) / std``. Recovering
    ``x = out * std + mean`` can itself be written as a ``Normalize`` whose
    mean is ``-mean / std`` and whose std is ``1 / std``.

    Returns:
        A ``transforms.Normalize`` configured with those inverse statistics.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    inverse_mean = []
    inverse_std = []
    for m, s in zip(channel_mean, channel_std):
        inverse_mean.append(-m / s)
        inverse_std.append(1 / s)

    return transforms.Normalize(mean=inverse_mean, std=inverse_std)
def transformer(imsize=None, cropsize=None):
    """Compose the image preprocessing pipeline.

    Args:
        imsize: When truthy, the pipeline starts with
            ``transforms.Resize(imsize)``.
        cropsize: When truthy, ``transforms.RandomCrop(cropsize)`` is applied
            after the optional resize.

    Returns:
        A ``transforms.Compose`` that always ends with ``ToTensor`` followed
        by the transform returned by ``normalize()``.
    """
    resize_step = [transforms.Resize(imsize)] if imsize else []
    crop_step = [transforms.RandomCrop(cropsize)] if cropsize else []
    return transforms.Compose(
        [*resize_step, *crop_step, transforms.ToTensor(), normalize()]
    )
def tensor_to_img(tensor):
    """Convert a normalized image tensor into a PIL image.

    Args:
        tensor: Image tensor whose values were normalized with
            ``normalize()``. Singleton dimensions are squeezed away before
            denormalizing.

    Returns:
        The PIL image produced by ``to_pil_image`` after denormalizing,
        arranging the result with ``make_grid`` and clamping it to [0, 1].
    """
    # ``tensor.device`` is a ``torch.device`` (e.g. ``cuda:0``), so the former
    # ``tensor.device == "cuda"`` test never matched. Move to host memory
    # unconditionally instead: ``.cpu()`` returns the tensor itself when it
    # already lives on the CPU, and ``.detach()`` drops any autograd history.
    tensor = tensor.detach().cpu()

    denormalizer = denormalize()
    grid = torchvision.utils.make_grid(denormalizer(tensor.squeeze()))
    # Denormalized values can fall slightly outside [0, 1]; clamp before the
    # conversion to an image.
    return transforms.functional.to_pil_image(grid.clamp_(0., 1.))
def style_transfer(content_img, style_strength, style_img_1=None, iw_1=0., style_img_2=None, iw_2=0., style_img_3=None, iw_3=0., preserve_color=None):
    """Stylize a content image with up to three weighted style images.

    Args:
        content_img: Content image accepted by the ``transformer`` pipeline
            (the UI passes PIL images).
        style_strength: Forwarded unchanged to ``model``.
        style_img_1, style_img_2, style_img_3: Optional style images; slots
            left as ``None`` are ignored.
        iw_1, iw_2, iw_3: Interpolation weight of the matching style slot.
            Weights of the provided styles are rescaled to sum to 1; if they
            are all zero the provided styles are weighted equally.
        preserve_color: UI choice; ``"None"`` maps to ``None`` and
            ``"Whitening"`` maps to ``"batch_wct"``. Any other value is
            forwarded unchanged.

    Returns:
        The stylized result converted to an image by ``tensor_to_img``.

    Raises:
        ValueError: If the content image or every style image is missing.
    """
    if content_img is None:
        raise ValueError("A content image is required.")

    # Pair every style image with its own weight *before* dropping empty
    # slots, so the weights handed to the model stay aligned with ``styles``.
    slots = ((style_img_1, iw_1), (style_img_2, iw_2), (style_img_3, iw_3))
    provided = [(img, weight) for img, weight in slots if img is not None]
    if not provided:
        raise ValueError("At least one style image is required.")

    transform = transformer(imsize=512)
    content = transform(content_img).unsqueeze(0)
    styles = [transform(img).unsqueeze(0) for img, _ in provided]

    weights = [weight for _, weight in provided]
    total = sum(weights)
    if total > 0:
        interpolation_weights = [weight / total for weight in weights]
    else:
        # All sliders at zero: avoid dividing by zero and blend evenly.
        interpolation_weights = [1 / len(weights)] * len(weights)

    # Translate the dropdown label into the mode string passed to the model.
    # ("Histogram matching" -> "histogram_matching" is currently disabled.)
    color_modes = {"None": None, "Whitening": "batch_wct"}
    preserve_color = color_modes.get(preserve_color, preserve_color)

    # NOTE(review): ``model`` is not defined in this section of the file; it
    # is expected to exist at module level -- confirm.
    with torch.no_grad():
        stylized_img = model(content, styles, interpolation_weights, preserve_color, style_strength)
    return tensor_to_img(stylized_img)
# --- Gradio user interface -------------------------------------------------
title = "Artistic Style Transfer"

# Input widgets. Each style image is followed by the slider that sets its
# interpolation weight.
content_img = gr.components.Image(label="Content image", type="pil")

style_img_1 = gr.components.Image(label="Style images", type="pil")
iw_1 = gr.components.Slider(0., 1., label="Style 1 interpolation")

style_img_2 = gr.components.Image(label="Style images", type="pil")
iw_2 = gr.components.Slider(0., 1., label="Style 2 interpolation")

style_img_3 = gr.components.Image(label="Style images", type="pil")
iw_3 = gr.components.Slider(0., 1., label="Style 3 interpolation")

style_strength = gr.components.Slider(0., 1., label="Adjust style strength")
preserve_color = gr.components.Dropdown(
    ["None", "Whitening"],
    label="Choose color preserving mode",
)

# The order of ``inputs`` must match the positional parameters of
# ``style_transfer``.
interface = gr.Interface(
    fn=style_transfer,
    inputs=[
        content_img,
        style_strength,
        style_img_1,
        iw_1,
        style_img_2,
        iw_2,
        style_img_3,
        iw_3,
        preserve_color,
    ],
    outputs=gr.components.Image(),
    title=title,
    description=None,
)

# Enable request queuing, then start the app with a public share link.
interface.queue()
interface.launch(share=True)