# SofaStyler/styleTransfer.py
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import tensorflow as tf
import tensorflow_hub as hub

import transformer
import StyTR
############################################# TRANSFORMER ############################################
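# Pre-trained weight files for the StyTR-2 pipeline (VGG encoder, decoder,
# transformer and patch embedding). The paths below are resolved relative to
# the working directory, so the checkpoints are assumed to sit next to this script.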
vgg_path = 'vgg_normalised.pth'
decoder_path = 'decoder_iter_160000.pth'
Trans_path = 'transformer_iter_160000.pth'
embedding_path = 'embedding_iter_160000.pth'
def style_transform(h, w):
    """Center-crop the style image to (h, w) and convert it to a tensor."""
    transform_list = [
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(transform_list)
def content_transform():
    """Convert the content image to a tensor without resizing or cropping."""
    return transforms.Compose([transforms.ToTensor()])
# Advanced options
content_size = 640
style_size = 640
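# Build the StyTR-2 components once at import time. The VGG weights are loaded
# here; the decoder, transformer and embedding weights are loaded inside
# StyleTransformer() before each stylization.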
vgg = StyTR.vgg
vgg.load_state_dict(torch.load(vgg_path, map_location="cpu"))  # load on CPU; moved to GPU later if available
vgg = nn.Sequential(*list(vgg.children())[:44])  # keep only the encoder layers StyTR-2 uses
decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()
def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """Stylize content_img with style_img using the StyTR-2 network."""
    decoder.eval()
    Trans.eval()
    vgg.eval()

    # Load the pre-trained weights on CPU; the whole network is moved to the GPU below.
    decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))
    Trans.load_state_dict(torch.load(Trans_path, map_location="cpu"))
    embedding.load_state_dict(torch.load(embedding_path, map_location="cpu"))

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()

    content_tf = content_transform()
    style_tf = style_transform(style_size, style_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network.to(device)

    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # Extract the stylized image and convert it to an 8-bit RGB array.
    output = output[0].cpu().squeeze()
    output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(torch.uint8).numpy()
    return Image.fromarray(output)
############################################## STYLE-GAN #############################################
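# Despite the name, the "StyleGAN" path below uses Magenta's
# arbitrary-image-stylization model from TensorFlow Hub (not a GAN); it is
# kept as an alternative to the transformer pipeline above.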
style_transfer_model = hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
def StyleGAN(content_image, style_image):
    """Stylize content_image with the TF Hub arbitrary-image-stylization model."""
    # np.asarray lets this accept PIL images as well as numpy arrays.
    content_image = tf.convert_to_tensor(np.asarray(content_image), np.float32)[tf.newaxis, ...] / 255.0
    style_image = tf.convert_to_tensor(np.asarray(style_image), np.float32)[tf.newaxis, ...] / 255.0
    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]
    return Image.fromarray(np.uint8(stylized_image[0] * 255))
################################################# MAIN ################################################
def create_styledSofa(sofa: Image.Image, style: Image.Image) -> Image.Image:
    """Apply the style image to the sofa image using the StyTR-2 transformer."""
    # styled_sofa = StyleGAN(sofa, style)  # alternative: TF Hub stylization model
    styled_sofa = StyleTransformer(sofa, style)
    return styled_sofa
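

# Minimal usage sketch (not part of the original module): the file names
# "my_sofa.jpg" and "my_style.jpg" are placeholders for local test images.
if __name__ == "__main__":
    sofa_img = Image.open("my_sofa.jpg")
    style_img = Image.open("my_style.jpg")
    result = create_styledSofa(sofa_img, style_img)
    result.save("styled_sofa.png")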