"""Style-transfer backends and a dispatcher that selects one by name.

Backends: StyTR (PyTorch), Magenta arbitrary image stylization (TF Hub) and
PaddleHub's ``stylepro_artistic`` module.
"""
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
import StyleTransfer.transformer as transformer
import StyleTransfer.StyTR as StyTR
from collections import OrderedDict
import tensorflow_hub as tfhub
import tensorflow as tf
import paddlehub as phub
import functools
import os

############################################# TRANSFORMER ############################################

# StyTR checkpoint locations, relative to the process working directory.
_VGG_PATH = 'StyleTransfer/models/vgg_normalised.pth'
_DECODER_PATH = 'StyleTransfer/models/decoder_iter_160000.pth'
_TRANS_PATH = 'StyleTransfer/models/transformer_iter_160000.pth'
_EMBEDDING_PATH = 'StyleTransfer/models/embedding_iter_160000.pth'


def style_transform(h: int, w: int) -> transforms.Compose:
    """Return the style-image preprocessing pipeline.

    Center-crops the image to ``(h, w)`` and converts it to a tensor with
    ``transforms.ToTensor``.
    """
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])


def content_transform() -> transforms.Compose:
    """Return the content-image preprocessing pipeline (tensor conversion only)."""
    return transforms.Compose([transforms.ToTensor()])


@functools.lru_cache(maxsize=None)
def _load_stytr_network(device: torch.device) -> nn.Module:
    """Build the StyTR network, load its checkpoints and move it to *device*.

    Cached per device so the weights are read from disk once rather than on
    every ``StyleTransformer`` call.
    """
    # map_location lets checkpoints that were saved on a GPU load on CPU-only hosts.
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(_VGG_PATH, map_location=device))
    # Keep only the first 44 child layers of the VGG encoder.
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()

    decoder.eval()
    trans.eval()
    vgg.eval()

    decoder.load_state_dict(torch.load(_DECODER_PATH, map_location=device))
    trans.load_state_dict(torch.load(_TRANS_PATH, map_location=device))
    embedding.load_state_dict(torch.load(_EMBEDDING_PATH, map_location=device))

    network = StyTR.StyTrans(vgg, decoder, embedding, trans)
    network.eval()
    network.to(device)
    return network


def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """Stylize *content_img* with *style_img* using the StyTR model.

    Both images are converted to RGB. The style image is center-cropped to
    640x640; the content image is used at its original size. Runs on CUDA
    when available, otherwise on CPU.

    Returns:
        The stylized image as an 8-bit PIL image.
    """
    style_size = 640

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = _load_stytr_network(device)

    content = content_transform()(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_transform(style_size, style_size)(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)
    # NOTE(review): assumes output[0] holds the stylized image with values in
    # [0, 1] (it is scaled by 255 below) — confirm against StyTR.StyTrans.
    output = output[0].cpu().squeeze()
    pixels = (
        output.mul(255).add_(0.5).clamp_(0, 255)
        .permute(1, 2, 0)
        .to('cpu', torch.uint8)
        .numpy()
    )
    return Image.fromarray(pixels)

############################################## STYLE-FAST #############################################

style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")


def _to_float_batch(image: Image.Image) -> tf.Tensor:
    """Convert a PIL image to a float32 batch of one RGB image scaled to [0, 1]."""
    # Converting to RGB keeps RGBA / grayscale inputs from producing a
    # channel-count or rank mismatch in the model.
    array = np.asarray(image.convert('RGB'), dtype=np.float32)
    return tf.convert_to_tensor(array)[tf.newaxis, ...] / 255.


def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """Stylize *content_image* with *style_image* using the Magenta TF Hub model.

    Returns:
        The stylized image as an 8-bit PIL image.
    """
    outputs = style_transfer_model(_to_float_batch(content_image), _to_float_batch(style_image))
    stylized = np.asarray(outputs[0][0])
    # Clip before the uint8 cast so out-of-range floats cannot wrap around.
    pixels = np.clip(stylized * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)

########################################### STYLE PROJECTION ##########################################

# Install the pinned PaddleHub module before loading it.
os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = phub.Module(name="stylepro_artistic")


def StyleProjection(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """Stylize *content_image* with *style_image* using PaddleHub ``stylepro_artistic``.

    Uses a blend factor ``alpha=0.8``.

    Returns:
        The stylized image as an RGB PIL image.
    """
    # Channels are reversed (RGB <-> BGR) around the call, presumably because
    # stylepro_artistic expects OpenCV-style BGR arrays — confirm in module docs.
    content_bgr = np.array(content_image.convert('RGB'))[:, :, ::-1]
    style_bgr = np.array(style_image.convert('RGB'))[:, :, ::-1]
    result = stylepro_artistic.style_transfer(
        images=[{'content': content_bgr, 'styles': [style_bgr]}],
        alpha=0.8,
    )
    return Image.fromarray(np.uint8(result[0]['data'])[:, :, ::-1]).convert('RGB')


def create_styledSofa(content_image: Image.Image, style_image: Image.Image, choice: str) -> Image.Image:
    """Stylize *content_image* with the backend named by *choice*.

    Args:
        content_image: Image to stylize.
        style_image: Image supplying the style.
        choice: One of ``"Style Transformer"``, ``"Style FAST"`` or
            ``"Style Projection"``.

    Returns:
        The stylized image, or *content_image* unchanged for any other *choice*.
    """
    backends = {
        "Style Transformer": StyleTransformer,
        "Style FAST": StyleFAST,
        "Style Projection": StyleProjection,
    }
    backend = backends.get(choice)
    if backend is None:
        return content_image
    return backend(content_image, style_image)