from collections import OrderedDict

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

import transformer
import StyTR

# Magenta arbitrary-image-stylization model, loaded once from TF Hub at import time.
style_transfer_model = hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")

############################################# TRANSFORMER ############################################

def test_transform(size, crop):
    """Resize to `size` (and optionally center-crop) before converting to a tensor."""
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list)

def style_transform(h, w):
    """Center-crop to (h, w) and convert to a tensor."""
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])

def content_transform():
    """Convert a PIL image to a tensor without resizing or cropping."""
    return transforms.Compose([transforms.ToTensor()])

def _load_checkpoint(module, path):
    """Load a state dict into `module`, keeping key names as-is.

    If a checkpoint was saved from nn.DataParallel, strip the `module.` prefix
    with `namekey = k[7:]` instead of using the key directly.
    """
    new_state_dict = OrderedDict()
    state_dict = torch.load(path)
    for k, v in state_dict.items():
        namekey = k
        new_state_dict[namekey] = v
    module.load_state_dict(new_state_dict)

def StyleTransformer(content_img: Image.Image,
                     style_img: Image.Image,
                     vgg_path: str = 'vgg_normalised.pth',
                     decoder_path: str = 'decoder_iter_160000.pth',
                     Trans_path: str = 'transformer_iter_160000.pth',
                     embedding_path: str = 'embedding_iter_160000.pth'):
    """Stylize `content_img` with `style_img` using StyTR-2; returns a PIL image."""
    # Advanced options
    content_size = 640
    style_size = 640
    crop = True  # the original value 'store_true' was an argparse leftover and always truthy
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the encoder: the pretrained VGG truncated to its first 44 layers.
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path))
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    Trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()

    decoder.eval()
    Trans.eval()
    vgg.eval()

    _load_checkpoint(decoder, decoder_path)
    _load_checkpoint(Trans, Trans_path)
    _load_checkpoint(embedding, embedding_path)

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    network.to(device)

    content_tf = test_transform(content_size, crop)
    style_tf = test_transform(style_size, crop)

    content = content_tf(content_img.convert("RGB"))
    style = style_tf(style_img.convert("RGB"))

    # Add the batch dimension and move both tensors to the target device.
    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)
    output = output[0].cpu()  # the network returns a tuple; the image tensor is first

    torch2PIL = transforms.ToPILImage()
    return torch2PIL(output.squeeze())
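# Example usage of StyleTransformer (a minimal sketch, not part of the pipeline):
# assumes the four default .pth checkpoints sit in the working directory, and
# 'content.jpg' / 'style.jpg' are hypothetical placeholder paths.
#
# content = Image.open('content.jpg')
# style = Image.open('style.jpg')
# stylized = StyleTransformer(content, style)
# stylized.save('stylized.jpg')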
############################################ TF-HUB MAGENTA ###########################################

def perform_style_transfer(content_image, style_image):
    """Fast stylization via the TF Hub Magenta model.

    Expects `content_image` and `style_image` as HxWx3 uint8 numpy arrays;
    returns the stylized result as a PIL image.
    """
    # Scale to [0, 1] floats and add a batch dimension.
    content_image = tf.convert_to_tensor(content_image, np.float32)[tf.newaxis, ...] / 255.
    style_image = tf.convert_to_tensor(style_image, np.float32)[tf.newaxis, ...] / 255.
    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]
    return Image.fromarray(np.uint8(stylized_image[0] * 255))

def create_styledSofa(sofa: Image.Image, style: Image.Image):
    """Apply StyTR-2 style transfer to a sofa image."""
    styled_sofa = StyleTransformer(sofa, style)
    return styled_sofa

# image = Image.open('sofa_office.jpg')
# image.show()
# image = np.array(image)
# image, box = resize_sofa(image)
# image = image.crop(box)
# print(box)
# image.show()
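# Demo driver (a sketch, not from the original file): 'sofa.jpg' and 'style.jpg'
# are hypothetical paths; swap in real images to compare the two back ends.
if __name__ == '__main__':
    sofa_img = Image.open('sofa.jpg').convert('RGB')
    style_img = Image.open('style.jpg').convert('RGB')
    # Fast TF Hub path: numpy arrays in, PIL image out.
    fast = perform_style_transfer(np.array(sofa_img), np.array(style_img))
    fast.show()
    # StyTR-2 path: PIL images in and out; needs the .pth checkpoints on disk.
    slow = create_styledSofa(sofa_img, style_img)
    slow.show()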