SofaStyler / styleTransfer.py
Sophie98
errrrrrrror
a907392
raw
history blame
4.35 kB
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import transformer as transformer
import StyTR as StyTR
import numpy as np
from collections import OrderedDict
import tensorflow_hub as hub
import tensorflow as tf
# Magenta arbitrary-image-stylization model, loaded from TF-Hub once at import
# time (module-level side effect) and reused by perform_style_transfer() below.
style_transfer_model = hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
############################################# TRANSFORMER ############################################
def test_transform(size, crop):
    """Build the preprocessing pipeline applied to an image before inference.

    Args:
        size: Target size handed to ``transforms.Resize`` (and to
            ``transforms.CenterCrop`` when cropping). ``0`` disables resizing.
        crop: Truthy to center-crop to ``size`` after resizing.

    Returns:
        A ``transforms.Compose`` that optionally resizes and crops, then
        converts the image to a tensor with ``transforms.ToTensor``.
    """
    steps = []
    if size != 0:
        steps.append(transforms.Resize(size))
        # Cropping only makes sense with a real target size: CenterCrop(0)
        # would produce an empty image, so it is skipped when size == 0.
        if crop:
            steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
def style_transform(h, w):
    """Build a pipeline that center-crops an image to ``(h, w)`` and tensorizes it.

    Args:
        h: Crop height passed to ``transforms.CenterCrop``.
        w: Crop width passed to ``transforms.CenterCrop``.

    Returns:
        A ``transforms.Compose`` of ``CenterCrop((h, w))`` followed by
        ``ToTensor``.
    """
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])
def content_transform():
    """Return a pipeline whose only step converts an image to a tensor."""
    to_tensor_only = [transforms.ToTensor()]
    return transforms.Compose(to_tensor_only)
def StyleTransformer(content_img: Image.Image, style_img: Image.Image,
                     vgg_path: str = 'vgg_normalised.pth',
                     decoder_path: str = 'decoder_iter_160000.pth',
                     Trans_path: str = 'transformer_iter_160000.pth',
                     embedding_path: str = 'embedding_iter_160000.pth',
                     content_size: int = 640, style_size: int = 640,
                     crop: bool = True) -> Image.Image:
    """Stylize ``content_img`` with ``style_img`` using the StyTR network.

    The four checkpoints are read from disk and the network is rebuilt on
    every call. Inference runs once under ``torch.no_grad()`` on CUDA when it
    is available, otherwise on the CPU.

    Args:
        content_img: PIL image providing the content; converted to RGB.
        style_img: PIL image providing the style; converted to RGB.
        vgg_path: Path to the VGG state dict.
        decoder_path: Path to the decoder state dict.
        Trans_path: Path to the transformer state dict.
        embedding_path: Path to the patch-embedding state dict.
        content_size: ``size`` passed to ``test_transform`` for the content
            image.
        style_size: ``size`` passed to ``test_transform`` for the style image.
        crop: ``crop`` flag passed to ``test_transform`` for both images.

    Returns:
        The first element of the network output, converted to a PIL image.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load every checkpoint onto the CPU first so that files saved from a GPU
    # run still open on CPU-only hosts; the assembled network is moved to
    # `device` in one step further down.
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path, map_location="cpu"))
    # Only the first 44 child layers of the VGG are kept as the encoder.
    # NOTE(review): presumably the cut-off expected by StyTR.StyTrans - confirm.
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    Trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()

    decoder.eval()
    Trans.eval()
    vgg.eval()

    # The checkpoint keys are used unchanged (no `module.` prefix stripping),
    # so the loaded state dicts can be applied directly.
    decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))
    Trans.load_state_dict(torch.load(Trans_path, map_location="cpu"))
    embedding.load_state_dict(torch.load(embedding_path, map_location="cpu"))

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    network.to(device)

    content_tf = test_transform(content_size, crop)
    style_tf = test_transform(style_size, crop)

    # Add the batch dimension the network is called with.
    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # NOTE(review): [0] presumably selects the single image of the batch -
    # confirm against StyTR.StyTrans.forward.
    # ToPILImage scales float tensors by 255 and casts to uint8 without
    # clamping, so out-of-range values would wrap around; clamp to [0, 1] first.
    output = output[0].cpu().clamp(0, 1)
    return transforms.ToPILImage()(output)
################################### TF-HUB MAGENTA STYLE TRANSFER ####################################
def perform_style_transfer(content_image, style_image):
    """Stylize ``content_image`` with the module-level TF-Hub model.

    Both inputs are converted to float32 tensors, given a leading batch axis
    and divided by 255 before being passed to ``style_transfer_model``.
    NOTE(review): this assumes array-like images with values in 0-255 -
    confirm against callers.

    Returns:
        A PIL image built from the first image of the model's first output,
        scaled by 255 and cast to uint8.
    """
    def _to_unit_batch(image):
        # float32 tensor -> add batch axis -> scale by 1/255
        as_tensor = tf.convert_to_tensor(image, np.float32)
        return as_tensor[tf.newaxis, ...] / 255.

    content_batch = _to_unit_batch(content_image)
    style_batch = _to_unit_batch(style_image)

    model_outputs = style_transfer_model(content_batch, style_batch)
    stylized_batch = model_outputs[0]
    pixels = np.uint8(stylized_batch[0] * 255)
    return Image.fromarray(pixels)
def create_styledSofa(sofa: Image, style: Image):
    """Return ``sofa`` restyled with ``style`` by delegating to StyleTransformer.

    Both arguments are forwarded positionally; all checkpoint paths keep
    StyleTransformer's defaults.
    """
    return StyleTransformer(sofa, style)
# image = Image.open('sofa_office.jpg')
# image.show()
# image = np.array(image)
# image,box = resize_sofa(image)
# image = image.crop(box)
# print(box)
# image.show()