SofaStyler / StyleTransfer /styleTransfer.py
Sophie98
projection only works with queue
7ea2648
from PIL import Image
import numpy as np
import torch
torch.cuda.is_available()
import torch.nn as nn
from torchvision import transforms
import StyleTransfer.transformer as transformer
import StyleTransfer.StyTR as StyTR
from collections import OrderedDict
import tensorflow_hub as tfhub
import tensorflow as tf
import paddlehub as phub
import os
############################################# TRANSFORMER ############################################
def style_transform(h: int, w: int) -> transforms.Compose:
    """Build the preprocessing pipeline applied to a style image.

    The PIL image is center-cropped to ``(h, w)`` and then converted to a
    float tensor by ``ToTensor``. Per the torchvision docs, ``CenterCrop``
    zero-pads an image that is smaller than the requested size.

    Args:
        h: Output height in pixels.
        w: Output width in pixels.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a tensor.
    """
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])
def content_transform() -> transforms.Compose:
    """Return the preprocessing pipeline for a content image.

    The content image keeps its original size; it is only converted from a
    PIL image to a tensor via ``ToTensor``.
    """
    return transforms.Compose([transforms.ToTensor()])
# Assembled StyTR networks keyed by device string, so the checkpoints are read
# from disk once per process instead of on every StyleTransformer call.
_stytr_network_cache = {}


def _build_stytr_network() -> nn.Module:
    """Assemble the StyTR network from its checkpoint files (weights on CPU)."""
    vgg_path = 'StyleTransfer/models/vgg_normalised.pth'
    decoder_path = 'StyleTransfer/models/decoder_iter_160000.pth'
    trans_path = 'StyleTransfer/models/transformer_iter_160000.pth'
    embedding_path = 'StyleTransfer/models/embedding_iter_160000.pth'

    # map_location="cpu" lets checkpoints that were saved from a GPU run load on
    # CPU-only hosts; the finished network is moved to the target device later.
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path, map_location="cpu"))
    # Only the first 44 child layers of the VGG encoder are kept.
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))

    trans = transformer.Transformer()
    trans.load_state_dict(torch.load(trans_path, map_location="cpu"))

    embedding = StyTR.PatchEmbed()
    embedding.load_state_dict(torch.load(embedding_path, map_location="cpu"))

    # Put every component in inference mode explicitly, in addition to the
    # wrapper, in case StyTrans does not register them all as submodules.
    for module in (vgg, decoder, trans, embedding):
        module.eval()

    network = StyTR.StyTrans(vgg, decoder, embedding, trans)
    network.eval()
    return network


def _get_stytr_network(device: torch.device) -> nn.Module:
    """Return the StyTR network on ``device``, building it on first use."""
    key = str(device)
    network = _stytr_network_cache.get(key)
    if network is None:
        network = _build_stytr_network()
        network.to(device)
        _stytr_network_cache[key] = network
    return network


def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """Stylize ``content_img`` with ``style_img`` using the StyTR transformer model.

    Both images are converted to RGB. The style image is center-cropped to
    640x640; the content image is used at its original size. Large content
    images may therefore be slow or exhaust memory.

    Args:
        content_img: Image whose structure is preserved.
        style_img: Image whose style is transferred.

    Returns:
        The stylized image as an RGB PIL image.
    """
    style_size = 640
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = _get_stytr_network(device)

    content = content_transform()(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_transform(style_size, style_size)(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)
    # NOTE(review): output[0] is assumed to be the single stylized CHW image
    # (as in the upstream StyTR-2 test script) -- confirm against StyTrans.forward.
    output = output[0].cpu().squeeze()
    # Scale [0, 1] floats to [0, 255], round via +0.5, clamp, and reorder CHW -> HWC.
    pixels = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    return Image.fromarray(pixels)
############################################## STYLE-FAST #############################################
# Magenta "arbitrary image stylization" model from TF Hub. It is loaded once at
# import time and shared by every StyleFAST call.
_FAST_STYLE_MODEL_HANDLE = "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
style_transfer_model = tfhub.load(_FAST_STYLE_MODEL_HANDLE)
def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """Stylize ``content_image`` with the Magenta TF-Hub fast style-transfer model.

    Both images are converted to RGB, scaled to float32 in [0, 1] and given a
    leading batch axis before being passed to ``style_transfer_model``.

    Args:
        content_image: Image whose structure is preserved.
        style_image: Image whose style is transferred.

    Returns:
        The stylized image as a PIL image.
    """
    # Convert to RGB like the other back-ends in this module, so RGBA or
    # greyscale uploads still produce 3-channel input for the model.
    content = tf.convert_to_tensor(np.array(content_image.convert("RGB")), np.float32)[tf.newaxis, ...] / 255.
    style = tf.convert_to_tensor(np.array(style_image.convert("RGB")), np.float32)[tf.newaxis, ...] / 255.
    outputs = style_transfer_model(content, style)
    stylized_batch = outputs[0]
    # Clip before the uint8 cast so any value slightly outside [0, 1] saturates
    # instead of wrapping around.
    pixels = np.clip(np.asarray(stylized_batch[0]) * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
########################################### STYLE PROJECTION ##########################################
# Install the PaddleHub "stylepro_artistic" module (pinned to 1.0.1) through the
# `hub` CLI, then load it once at import time for StyleProjection.
# The install command's exit status is not checked.
_STYLEPRO_INSTALL_COMMAND = "hub install stylepro_artistic==1.0.1"
os.system(_STYLEPRO_INSTALL_COMMAND)
stylepro_artistic = phub.Module(name="stylepro_artistic")
def StyleProjection(content_image: Image.Image, style_image: Image.Image, alpha: float = 0.8) -> Image.Image:
    """Stylize ``content_image`` with PaddleHub's ``stylepro_artistic`` module.

    Args:
        content_image: Image whose structure is preserved.
        style_image: Image whose style is transferred.
        alpha: Stylization strength forwarded to ``style_transfer``
            (defaults to 0.8, the previously hard-coded value).

    Returns:
        The stylized image as an RGB PIL image.
    """
    # The module is fed BGR arrays (OpenCV channel order), so the RGB channel
    # axis is reversed on the way in and again on the way out.
    content_bgr = np.array(content_image.convert('RGB'))[:, :, ::-1]
    style_bgr = np.array(style_image.convert('RGB'))[:, :, ::-1]
    result = stylepro_artistic.style_transfer(
        images=[{'content': content_bgr, 'styles': [style_bgr]}],
        alpha=alpha)
    # Clip before the uint8 cast so out-of-range values saturate rather than wrap.
    stylized_rgb = np.clip(result[0]['data'], 0, 255).astype(np.uint8)[:, :, ::-1]
    return Image.fromarray(stylized_rgb).convert('RGB')
def create_styledSofa(content_image: Image.Image, style_image: Image.Image, choice: str) -> Image.Image:
    """Dispatch to the style-transfer back-end named by ``choice``.

    Recognised choices are "Style Transformer", "Style FAST" and
    "Style Projection". Any other value returns ``content_image`` unchanged.
    """
    if choice == "Style Transformer":
        return StyleTransformer(content_image, style_image)
    if choice == "Style FAST":
        return StyleFAST(content_image, style_image)
    if choice == "Style Projection":
        return StyleProjection(content_image, style_image)
    # Unknown choice: fall back to the unstyled content image.
    return content_image