"""Inference helpers for StyTR-based image style transfer.

Loads pretrained VGG / decoder / transformer / embedding weights and exposes
`StyleTransformer`, which styles one PIL image with another.
"""
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

import transformer as transformer
import StyTR as StyTR


def test_transform(size, crop):
    """Build the eval-time pipeline: optional Resize, optional CenterCrop, ToTensor.

    Args:
        size: target edge length; 0 disables resizing.
        crop: truthy to center-crop to ``size`` x ``size``.
    Returns:
        A ``transforms.Compose`` ready to apply to a PIL image.
    """
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list)


def style_transform(h, w):
    """Center-crop to (h, w) then convert to a tensor."""
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])


def content_transform():
    """Plain PIL -> tensor conversion, no resizing or cropping."""
    return transforms.Compose([transforms.ToTensor()])


def _load_checkpoint(module, path, device):
    """Load a state dict from ``path`` onto ``device`` and apply it to ``module``.

    ``map_location`` keeps this working on CPU-only hosts even when the
    checkpoint was saved from a GPU. Keys are copied verbatim; if a checkpoint
    ever carries a ``module.`` prefix (DataParallel), strip it here with
    ``k[7:]`` instead.
    """
    state_dict = torch.load(path, map_location=device)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k] = v
    module.load_state_dict(new_state_dict)


def StyleTransformer(content_img: Image, style_img: Image,
                     vgg_path: str = 'vgg_normalised.pth',
                     decoder_path: str = 'decoder_iter_160000.pth',
                     Trans_path: str = 'transformer_iter_160000.pth',
                     embedding_path: str = 'embedding_iter_160000.pth'):
    """Apply the style of ``style_img`` to ``content_img`` with a StyTR network.

    Args:
        content_img: PIL image providing the content.
        style_img: PIL image providing the style.
        vgg_path / decoder_path / Trans_path / embedding_path: checkpoint files.
    Returns:
        The stylized result as a PIL image.
    """
    # Advanced options: both inputs are resized and center-cropped to 640x640.
    content_size = 640
    style_size = 640
    crop = True  # was the argparse literal 'store_true'; truthiness is the same

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # VGG encoder: load full weights, then keep only the first 44 layers
    # (features up to relu5_1, as StyTR expects).
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path, map_location=device))
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    Trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()

    decoder.eval()
    Trans.eval()
    vgg.eval()

    _load_checkpoint(decoder, decoder_path, device)
    _load_checkpoint(Trans, Trans_path, device)
    _load_checkpoint(embedding, embedding_path, device)

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    network.to(device)

    content_tf = test_transform(content_size, crop)
    style_tf = test_transform(style_size, crop)

    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # network(...) returns a tuple; the stylized tensor is the first element.
    output = output[0].cpu()
    # BUG FIX: the original called transforms.ToPILImage(output), which merely
    # CONSTRUCTS a transform (interpreting the tensor as the `mode` argument)
    # and returned that object instead of an image. Instantiate, then apply.
    output = transforms.ToPILImage()(output)
    return output


def create_styledSofa(sofa: Image, style: Image):
    """Convenience wrapper: style a sofa photo with the default checkpoints."""
    return StyleTransformer(sofa, style)