"""Sofa style-transfer helpers.

Two backends are defined here:

* ``StyleTransformer`` -- a PyTorch network assembled from the local ``StyTR``
  and ``transformer`` modules plus the checkpoint files named below.
* ``StyleGAN`` -- the TF-Hub "arbitrary-image-stylization-v1-256" model.

``create_styledSofa`` is the entry point; it currently uses ``StyleTransformer``.

NOTE: importing this module loads every PyTorch checkpoint (paths are relative
to the current working directory) and loads the TF-Hub model, so the import is
slow and has file/network requirements even if only one backend is used.
"""
from collections import OrderedDict

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

import StyTR
import transformer

############################################# TRANSFORMER ############################################

vgg_path = 'vgg_normalised.pth'
decoder_path = 'decoder_iter_160000.pth'
Trans_path = 'transformer_iter_160000.pth'
embedding_path = 'embedding_iter_160000.pth'


def style_transform(h, w):
    """Return a transform that center-crops an image to ``(h, w)`` and converts it to a tensor."""
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])


def content_transform():
    """Return a transform that converts an image to a tensor without resizing it."""
    return transforms.Compose([transforms.ToTensor()])


# Advanced options
content_size = 640  # NOTE(review): not read anywhere in this file; content images keep their own size.
style_size = 640    # style images are center-cropped to style_size x style_size

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_state_dict(path):
    """Load a checkpoint file onto the CPU; the assembled network is moved to ``device`` later."""
    # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only hosts.
    return torch.load(path, map_location="cpu")


vgg = StyTR.vgg
vgg.load_state_dict(_load_state_dict(vgg_path))
# Keep only the first 44 child layers of the VGG model as the encoder.
vgg = nn.Sequential(*list(vgg.children())[:44])

decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()

decoder.eval()
Trans.eval()
vgg.eval()

decoder.load_state_dict(_load_state_dict(decoder_path))
Trans.load_state_dict(_load_state_dict(Trans_path))
embedding.load_state_dict(_load_state_dict(embedding_path))

network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
network.eval()

content_tf = content_transform()
style_tf = style_transform(style_size, style_size)


def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """Stylize ``content_img`` with ``style_img`` using the StyTR network.

    Both inputs are converted to RGB. The content image keeps its own size;
    the style image is center-cropped to ``style_size`` x ``style_size``.

    Returns:
        The stylized result as an 8-bit PIL image.
    """
    network.to(device)
    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # First element of the network output with size-1 dims dropped; presumably
    # a (C, H, W) image tensor -- TODO confirm against StyTR.StyTrans.forward.
    image = output[0].cpu().squeeze()
    # Map [0, 1] floats to rounded uint8 (same recipe as torchvision.utils.save_image).
    image = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(torch.uint8).numpy()
    return Image.fromarray(image)


############################################## STYLE-GAN #############################################

style_transfer_model = hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")


def _to_float_batch(img):
    """Convert a PIL image or array-like image in [0, 255] to a float32 batch of one, scaled to [0, 1]."""
    if isinstance(img, Image.Image):
        # Drop alpha/palette modes so the model always receives 3 channels,
        # matching the RGB conversion done in StyleTransformer.
        img = img.convert("RGB")
    # Convert with NumPy first so uint8 input is cast explicitly to float32.
    array = np.asarray(img, dtype=np.float32) / 255.0
    return tf.convert_to_tensor(array)[tf.newaxis, ...]


def StyleGAN(content_image, style_image):
    """Stylize ``content_image`` with ``style_image`` using the TF-Hub model.

    Args:
        content_image: PIL image or array-like image with values in [0, 255].
        style_image: PIL image or array-like image with values in [0, 255].

    Returns:
        The stylized result as an 8-bit PIL image.
    """
    content = _to_float_batch(content_image)
    style = _to_float_batch(style_image)
    output = style_transfer_model(content, style)
    stylized = np.asarray(output[0][0])
    # Clip before casting: a bare uint8 cast wraps out-of-range values, and
    # rounding avoids the downward bias of truncation.
    return Image.fromarray((np.clip(stylized, 0.0, 1.0) * 255).round().astype(np.uint8))


################################################# MAIN ################################################

def create_styledSofa(sofa: Image.Image, style: Image.Image) -> Image.Image:
    """Apply ``style`` to the ``sofa`` image and return the stylized image.

    Uses the ``StyleTransformer`` backend. ``StyleGAN(sofa, style)`` takes the
    same arguments if the TF-Hub model is preferred instead.
    """
    return StyleTransformer(sofa, style)