import os

import numpy as np
import paddlehub as phub
import tensorflow as tf
import tensorflow_hub as tfhub
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

import StyleTransfer.StyTR as StyTR
import StyleTransfer.transformer as transformer

############################################# TRANSFORMER ############################################

def style_transform(h: int, w: int) -> transforms.Compose:
    """Center-crop to (h, w) and convert the PIL image to a tensor."""
    transform_list = [
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(transform_list)

def content_transform() -> transforms.Compose:
    """Convert the PIL image to a tensor without resizing or cropping."""
    return transforms.Compose([transforms.ToTensor()])

def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """Stylize content_img with style_img using the StyTR transformer network."""
    vgg_path = 'StyleTransfer/models/vgg_normalised.pth'
    decoder_path = 'StyleTransfer/models/decoder_iter_160000.pth'
    Trans_path = 'StyleTransfer/models/transformer_iter_160000.pth'
    embedding_path = 'StyleTransfer/models/embedding_iter_160000.pth'

    # Advanced options
    style_size = 640

    # VGG encoder: load the pretrained weights, then keep only the layers the
    # StyTR encoder actually uses. map_location keeps the load device-agnostic;
    # the assembled network is moved to the right device below.
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path, map_location='cpu'))
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    Trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()

    decoder.eval()
    Trans.eval()
    vgg.eval()

    decoder.load_state_dict(torch.load(decoder_path, map_location='cpu'))
    Trans.load_state_dict(torch.load(Trans_path, map_location='cpu'))
    embedding.load_state_dict(torch.load(embedding_path, map_location='cpu'))

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()

    content_tf = content_transform()
    style_tf = style_transform(style_size, style_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network.to(device)

    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # The network returns the stylized image first; convert the CHW float
    # tensor in [0, 1] back to an HWC uint8 array.
    output = output[0].cpu().squeeze()
    output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(torch.uint8).numpy()
    return Image.fromarray(output)

############################################## STYLE-FAST #############################################

# Magenta's arbitrary-image-stylization model, loaded once at import time.
style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")

def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """Stylize with the fast arbitrary-style TF Hub model."""
    # Scale pixel values to [0, 1] and add a batch dimension.
    content_image = tf.convert_to_tensor(np.array(content_image), np.float32)[tf.newaxis, ...] / 255.
    style_image = tf.convert_to_tensor(np.array(style_image), np.float32)[tf.newaxis, ...] / 255.
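    # The hub module returns a sequence whose first element is the batch of
    # stylized images, with values still in [0, 1].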
    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]
    return Image.fromarray(np.uint8(stylized_image[0] * 255))

########################################### STYLE PROJECTION ##########################################

os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = phub.Module(name="stylepro_artistic")

def StyleProjection(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """Stylize with PaddleHub's stylepro_artistic module."""
    # PaddleHub expects BGR arrays, so reverse the channel order on the way in
    # and again on the way out.
    result = stylepro_artistic.style_transfer(
        images=[{
            'content': np.array(content_image.convert('RGB'))[:, :, ::-1],
            'styles': [np.array(style_image.convert('RGB'))[:, :, ::-1]]}],
        alpha=0.8)
    return Image.fromarray(np.uint8(result[0]['data'])[:, :, ::-1]).convert('RGB')

def create_styledSofa(content_image: Image.Image, style_image: Image.Image, choice: str) -> Image.Image:
    """Dispatch to the selected style-transfer back end; any other choice
    returns the content image unmodified."""
    if choice == "Style Transformer":
        output = StyleTransformer(content_image, style_image)
    elif choice == "Style FAST":
        output = StyleFAST(content_image, style_image)
    elif choice == "Style Projection":
        output = StyleProjection(content_image, style_image)
    else:
        output = content_image
    return output
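
################################################ USAGE ################################################

# Minimal usage sketch, not part of the module itself: the image paths below
# are hypothetical placeholders; any pair of RGB images will do.
if __name__ == "__main__":
    content = Image.open("examples/sofa.jpg")   # hypothetical content image
    style = Image.open("examples/style.jpg")    # hypothetical style image
    styled = create_styledSofa(content, style, "Style FAST")
    styled.save("examples/styled_sofa.jpg")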