# SofaStyler / styleTransfer.py
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
import tensorflow as tf
import tensorflow_hub as tfhub
import paddlehub as phub

# Local modules for the transformer-based style transfer (StyTR)
import transformer
import StyTR
############################################# TRANSFORMER ############################################
def style_transform(h, w):
    """Center-crop to (h, w) and convert a PIL image to a float tensor in [0, 1]."""
    transform_list = [
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(transform_list)


def content_transform():
    """Convert a PIL image to a float tensor in [0, 1] without resizing or cropping."""
    return transforms.Compose([transforms.ToTensor()])
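

# Illustrative sketch only (not used by the app): shows the tensor shapes the two
# transforms above produce for an arbitrary 800x600 RGB test image.
def _demo_transform_shapes():
    probe = Image.new("RGB", (800, 600), color=(128, 64, 32))  # hypothetical test image
    cropped = style_transform(512, 512)(probe)   # torch.Size([3, 512, 512]), values in [0, 1]
    full = content_transform()(probe)            # torch.Size([3, 600, 800]); no resizing or cropping
    return cropped.shape, full.shape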
def StyleTransformer(content_img: Image.Image, style_img: Image.Image):
    """Transformer-based style transfer (StyTR) using locally stored checkpoints."""
    vgg_path = 'vgg_normalised.pth'
    decoder_path = 'decoder_iter_160000.pth'
    Trans_path = 'transformer_iter_160000.pth'
    embedding_path = 'embedding_iter_160000.pth'
    # Advanced options
    style_size = 640  # style images are center-cropped to style_size x style_size

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the sub-networks and load their pre-trained weights
    # (map_location keeps the checkpoints loadable on CPU-only machines).
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path, map_location=device))
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))

    Trans = transformer.Transformer()
    Trans.load_state_dict(torch.load(Trans_path, map_location=device))

    embedding = StyTR.PatchEmbed()
    embedding.load_state_dict(torch.load(embedding_path, map_location=device))

    vgg.eval()
    decoder.eval()
    Trans.eval()

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    network.to(device)

    # Prepare inputs: the content keeps its original resolution, the style is cropped.
    content_tf = content_transform()
    style_tf = style_transform(style_size, style_size)
    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # The first element of the network output is the stylized image tensor.
    output = output[0].cpu().squeeze()
    output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(torch.uint8).numpy()
    return Image.fromarray(output)
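

# Usage sketch for the transformer backend (file names below are hypothetical
# placeholders, not assets guaranteed to exist in this Space):
def _demo_style_transformer():
    sofa = Image.open("sofa.jpg")          # hypothetical content image
    fabric = Image.open("fabric.jpg")      # hypothetical style image
    return StyleTransformer(sofa, fabric)  # returns a stylized PIL.Image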
############################################## STYLE-GAN #############################################
def StyleGAN(content_image, style_image):
    """Fast arbitrary image stylization via the Magenta model on TensorFlow Hub."""
    style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
    # The model expects batched float images scaled to [0, 1].
    content_image = tf.convert_to_tensor(np.array(content_image), np.float32)[tf.newaxis, ...] / 255.
    style_image = tf.convert_to_tensor(np.array(style_image), np.float32)[tf.newaxis, ...] / 255.
    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]
    return Image.fromarray(np.uint8(stylized_image[0] * 255))
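

# Illustrative sketch only: the TF Hub model above takes batched float images in
# [0, 1]; this feeds it random tensors purely to show the expected shapes. The
# stylized output is expected to match the content image's spatial size.
def _demo_hub_model_shapes():
    model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
    content = tf.random.uniform((1, 384, 512, 3))  # batch, height, width, channels
    style = tf.random.uniform((1, 256, 256, 3))
    stylized = model(content, style)[0]
    return stylized.shape  # expected: (1, 384, 512, 3)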
########################################### STYLE PROJECTION ##########################################
def styleProjection(content_image, style_image):
    """Style projection via PaddleHub's stylepro_artistic module."""
    stylepro_artistic = phub.Module(name="stylepro_artistic")
    # The module works on cv2-style BGR arrays, hence the channel reversal on input and output.
    result = stylepro_artistic.style_transfer(
        images=[{
            'content': np.array(content_image.convert('RGB'))[:, :, ::-1],
            'styles': [np.array(style_image.convert('RGB'))[:, :, ::-1]],
        }])
    return Image.fromarray(np.uint8(result[0]['data'])[:, :, ::-1]).convert('RGB')
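

# Tiny illustrative check (not used by the app): the channel reversals above assume
# the PaddleHub module works on BGR arrays; reversing the last axis twice must give
# back the original RGB data.
def _demo_bgr_roundtrip():
    rgb = np.array(Image.new("RGB", (4, 4), color=(10, 20, 30)))
    bgr = rgb[:, :, ::-1]
    assert (bgr[:, :, ::-1] == rgb).all()
    return bgr[0, 0].tolist()  # [30, 20, 10]: channels swapped as expected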
def create_styledSofa(content_image, style_image):
    """Default entry point: stylize the sofa image with the transformer backend."""
    return StyleTransformer(content_image, style_image)
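

# Minimal end-to-end sketch, guarded so it only runs when this file is executed
# directly (paths are hypothetical placeholders; in the Space these functions are
# presumably imported elsewhere rather than run from here):
if __name__ == "__main__":
    content = Image.open("examples/sofa.jpg")   # hypothetical path
    style = Image.open("examples/style.jpg")    # hypothetical path
    create_styledSofa(content, style).save("sofa_transformer.png")
    StyleGAN(content, style).save("sofa_hub.png")
    styleProjection(content, style).save("sofa_paddle.png")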