SofaStyler / StyleTransfer /styleTransfer.py
Sophie98
fix error with style projection
587b848
raw
history blame
4.4 kB
import functools
import os
from collections import OrderedDict

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
import tensorflow_hub as tfhub
import tensorflow as tf
import paddlehub as phub

import StyleTransfer.transformer as transformer
import StyleTransfer.StyTR as StyTR
############################################# TRANSFORMER ############################################
def style_transform(h: int, w: int) -> transforms.Compose:
    """Build the preprocessing pipeline for style images.

    The image is center-cropped to ``(h, w)`` and then converted to a tensor
    with ``transforms.ToTensor``.

    Args:
        h: Crop height in pixels.
        w: Crop width in pixels.

    Returns:
        A ``transforms.Compose`` applying the crop followed by ``ToTensor``.
    """
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])
def content_transform() -> transforms.Compose:
    """Build the preprocessing pipeline for content images.

    Only ``transforms.ToTensor`` is applied; the content image keeps its
    original size (no crop or resize).
    """
    return transforms.Compose([transforms.ToTensor()])
@functools.lru_cache(maxsize=1)
def _load_style_transformer(device: str) -> nn.Module:
    """Assemble the StyTR network from its checkpoints and move it to *device*.

    Memoised with a single entry so the checkpoints are read from disk and the
    network is built once per process instead of on every StyleTransformer call.

    Args:
        device: torch device string, ``"cuda"`` or ``"cpu"``.

    Returns:
        The ``StyTR.StyTrans`` network, in eval mode, on *device*.
    """
    vgg_path = 'StyleTransfer/models/vgg_normalised.pth'
    decoder_path = 'StyleTransfer/models/decoder_iter_160000.pth'
    trans_path = 'StyleTransfer/models/transformer_iter_160000.pth'
    embedding_path = 'StyleTransfer/models/embedding_iter_160000.pth'

    def _load(path):
        # Deserialize onto the CPU first so checkpoints saved from a GPU run also
        # load on CPU-only hosts; the assembled network is moved to `device` below.
        return torch.load(path, map_location="cpu")

    # StyTR.vgg / StyTR.decoder are objects owned by the StyTR module; their
    # weights are loaded in place.
    vgg = StyTR.vgg
    vgg.load_state_dict(_load(vgg_path))
    # Keep only the first 44 child layers of the VGG encoder.
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    decoder.load_state_dict(_load(decoder_path))

    trans = transformer.Transformer()
    trans.load_state_dict(_load(trans_path))

    embedding = StyTR.PatchEmbed()
    embedding.load_state_dict(_load(embedding_path))

    for module in (vgg, decoder, trans):
        module.eval()

    network = StyTR.StyTrans(vgg, decoder, embedding, trans)
    network.eval()
    network.to(torch.device(device))
    return network


def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """Stylize *content_img* with the StyTR transformer model.

    Both images are converted to RGB. The content image is used at its own
    resolution; the style image is center-cropped to 640x640. Inference runs on
    CUDA when available, otherwise on the CPU.

    Args:
        content_img: Image to stylize.
        style_img: Image providing the style.

    Returns:
        The stylized result as a PIL image.
    """
    style_size = 640
    device = "cuda" if torch.cuda.is_available() else "cpu"
    network = _load_style_transformer(device)

    content = content_transform()(content_img.convert("RGB"))
    style = style_transform(style_size, style_size)(style_img.convert("RGB"))
    # Add a leading batch dimension and move the inputs next to the model.
    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    output = output[0].cpu().squeeze()
    # Scale [0, 1] floats to [0, 255], round (+0.5), clamp, and go CHW -> HWC uint8.
    output = (
        output.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to('cpu', torch.uint8)
        .numpy()
    )
    return Image.fromarray(output)
############################################## STYLE-FAST #############################################
# Magenta "arbitrary-image-stylization-v1-256" model, loaded from TF-Hub once at
# import time and shared by every StyleFAST call.
style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """Stylize *content_image* with the TF-Hub Magenta arbitrary-stylization model.

    Both images are converted to RGB, as the other back-ends in this module do,
    so RGBA or greyscale inputs do not reach the model with the wrong channel
    count. They are then scaled to float32 in [0, 1] and given a batch axis.

    Args:
        content_image: Image to stylize.
        style_image: Image providing the style.

    Returns:
        The stylized image as a PIL image.
    """
    content = tf.convert_to_tensor(np.array(content_image.convert("RGB")), np.float32)[tf.newaxis, ...] / 255.
    style = tf.convert_to_tensor(np.array(style_image.convert("RGB")), np.float32)[tf.newaxis, ...] / 255.
    outputs = style_transfer_model(content, style)
    # First output, first (only) batch element.
    stylized = np.asarray(outputs[0][0])
    # Clip before the uint8 cast so values slightly outside [0, 1] saturate
    # instead of wrapping around.
    return Image.fromarray(np.uint8(np.clip(stylized * 255, 0, 255)))
########################################### STYLE PROJECTION ##########################################
# Install the PaddleHub stylepro_artistic module at the pinned version, then
# load it once at import time for StyleProjection. PaddleHub's command-line
# executable is `hub` (`phub` is only this file's Python import alias).
# os.system keeps the install best-effort: a failure is reported by the shell
# but does not abort the import.
os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = phub.Module(name="stylepro_artistic")
def StyleProjection(content_image: Image.Image, style_image: Image.Image, alpha: float = 0.8) -> Image.Image:
    """Stylize *content_image* with PaddleHub's stylepro_artistic module.

    The RGB images are passed to the module with their channel order reversed
    (RGB -> BGR), and the module's output is reversed back before being wrapped
    in a PIL image.

    Args:
        content_image: Image to stylize.
        style_image: Image providing the style.
        alpha: Forwarded to ``stylepro_artistic.style_transfer`` (default 0.8,
            the previously hard-coded value).

    Returns:
        The stylized image as an RGB PIL image.
    """
    content_bgr = np.array(content_image.convert('RGB'))[:, :, ::-1]
    style_bgr = np.array(style_image.convert('RGB'))[:, :, ::-1]
    result = stylepro_artistic.style_transfer(
        images=[{'content': content_bgr, 'styles': [style_bgr]}],
        alpha=alpha,
    )
    # Clip before the uint8 cast so any out-of-range values saturate rather than wrap.
    stylized_bgr = np.clip(result[0]['data'], 0, 255).astype(np.uint8)
    return Image.fromarray(stylized_bgr[:, :, ::-1]).convert('RGB')
def create_styledSofa(content_image: Image.Image, style_image: Image.Image, choice: str) -> Image.Image:
    """Dispatch to the style-transfer back-end named by *choice*.

    Recognised choices are "Style Transformer", "Style FAST" and
    "Style Projection". Any other value returns *content_image* unchanged.
    """
    if choice == "Style Transformer":
        return StyleTransformer(content_image, style_image)
    if choice == "Style FAST":
        return StyleFAST(content_image, style_image)
    if choice == "Style Projection":
        return StyleProjection(content_image, style_image)
    # Unknown choice: fall back to the untouched content image.
    return content_image