from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
import transformer
import StyTR
from collections import OrderedDict
import tensorflow_hub as tfhub
import tensorflow as tf
import os
import cv2
import paddlehub as phub


############################################# TRANSFORMER ############################################

def style_transform(h, w):
    """Build the preprocessing pipeline applied to the style image.

    Args:
        h: Height of the centre crop, in pixels.
        w: Width of the centre crop, in pixels.

    Returns:
        A ``transforms.Compose`` that centre-crops its input to ``(h, w)``
        and then converts it to a tensor.
    """
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])


def content_transform():
    """Build the preprocessing pipeline applied to the content image.

    Returns:
        A ``transforms.Compose`` that only converts its input to a tensor;
        the content image is neither resized nor cropped.
    """
    return transforms.Compose([transforms.ToTensor()])


def StyleTransformer(content_img: Image.Image, style_img: Image.Image):
    """Stylize ``content_img`` with ``style_img`` using the StyTR network.

    The four checkpoint files named below are read from the current working
    directory on every call.

    Args:
        content_img: PIL image supplying the content; converted to RGB and
            used at its own resolution.
        style_img: PIL image supplying the style; converted to RGB and
            centre-cropped to 640x640.

    Returns:
        The stylized result as a ``PIL.Image.Image``.
    """
    vgg_path = 'vgg_normalised.pth'
    decoder_path = 'decoder_iter_160000.pth'
    Trans_path = 'transformer_iter_160000.pth'
    embedding_path = 'embedding_iter_160000.pth'

    # Advanced options
    style_size = 640

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _load_weights(path):
        # Deserialize onto the CPU so that checkpoints saved from a CUDA
        # device still load on CPU-only hosts; the assembled network is
        # moved to ``device`` afterwards.
        return torch.load(path, map_location="cpu")

    vgg = StyTR.vgg
    vgg.load_state_dict(_load_weights(vgg_path))
    # Only the first 44 child layers of the VGG are kept as the encoder.
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    Trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()

    decoder.eval()
    Trans.eval()
    vgg.eval()

    decoder.load_state_dict(_load_weights(decoder_path))
    Trans.load_state_dict(_load_weights(Trans_path))
    embedding.load_state_dict(_load_weights(embedding_path))

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    network.to(device)

    content_tf = content_transform()
    style_tf = style_transform(style_size, style_size)

    content = content_tf(content_img.convert("RGB"))
    style = style_tf(style_img.convert("RGB"))

    # Add the leading batch dimension the network is called with.
    style = style.to(device).unsqueeze(0)
    content = content.to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # NOTE(review): presumably the first element of the network's return
    # value is the stylized image batch -- confirm against StyTR.StyTrans.
    output = output[0].cpu().squeeze()

    # Scale to 0..255 with round-to-nearest (+0.5 before the uint8 cast) and
    # move channels last so PIL can consume the array.
    output = (
        output.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to('cpu', torch.uint8)
        .numpy()
    )

    return Image.fromarray(output)


############################################## STYLE-GAN #############################################

def StyleGAN(content_image, style_image):
    """Stylize ``content_image`` with the TF-Hub arbitrary-stylization model.

    The model is loaded through ``tensorflow_hub`` on every call.

    Args:
        content_image: Image data convertible to a float32 tensor; values
            are divided by 255 before being passed to the model.
        style_image: Image data convertible to a float32 tensor; values are
            divided by 255 before being passed to the model.

    Returns:
        The stylized result as a ``PIL.Image.Image``.
    """
    style_transfer_model = tfhub.load(
        "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
    )

    # Add a batch axis and rescale both inputs by 1/255.
    content_image = tf.convert_to_tensor(content_image, np.float32)[tf.newaxis, ...] / 255.
    style_image = tf.convert_to_tensor(style_image, np.float32)[tf.newaxis, ...] / 255.

    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]

    return Image.fromarray(np.uint8(stylized_image[0] * 255))


########################################### STYLE PROJECTION ##########################################

def styleProjection(content_image, style_image):
    """Stylize ``content_image`` with PaddleHub's ``stylepro_artistic`` module.

    Args:
        content_image: PIL image supplying the content.
        style_image: PIL image supplying the style.

    Returns:
        The stylized result as an RGB ``PIL.Image.Image``.
    """
    stylepro_artistic = phub.Module(name="stylepro_artistic")

    # The channel axis is reversed on the way in and again on the way out
    # (RGB <-> BGR) -- presumably the module works with OpenCV-style BGR
    # arrays; verify against the PaddleHub module documentation.
    result = stylepro_artistic.style_transfer(
        images=[{
            'content': np.array(content_image.convert('RGB'))[:, :, ::-1],
            'styles': [np.array(style_image.convert('RGB'))[:, :, ::-1]]
        }])

    return Image.fromarray(np.uint8(result[0]['data'])[:, :, ::-1]).convert('RGB')


def create_styledSofa(content_image, style_image):
    """Return ``content_image`` stylized with ``style_image``.

    Delegates to ``StyleTransformer``; arguments and return value are
    passed through unchanged.
    """
    output = StyleTransformer(content_image, style_image)
    return output