SofaStyler / styleTransfer.py
Sophie98
delete some debug code
9be4acc
raw
history blame
No virus
4.08 kB
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import transformer as transformer
import StyTR as StyTR
import numpy as np
from collections import OrderedDict
import tensorflow_hub as hub
import tensorflow as tf
from torchvision.utils import save_image
# =============================== TRANSFORMER (StyTR) ===============================
# Checkpoint files for the StyTR pipeline below. These are bare relative paths, so
# torch.load resolves them against the process's current working directory.
vgg_path = "vgg_normalised.pth"  # weights for StyTR.vgg
# The "iter_160000" suffix suggests a 160000-iteration training snapshot (unverified).
decoder_path = "decoder_iter_160000.pth"  # weights for StyTR.decoder
Trans_path = "transformer_iter_160000.pth"  # weights for transformer.Transformer()
embedding_path = "embedding_iter_160000.pth"  # weights for StyTR.PatchEmbed()
def test_transform(size, crop):
    """Build the preprocessing pipeline used for content and style images.

    Args:
        size: Target passed to ``transforms.Resize``; ``0`` skips the resize.
        crop: When truthy, append ``transforms.CenterCrop(size)`` after the
            (optional) resize.

    Returns:
        A ``transforms.Compose`` that always finishes with ``ToTensor()``.
    """
    pipeline = [transforms.Resize(size)] if size != 0 else []
    if crop:
        pipeline.append(transforms.CenterCrop(size))
    pipeline.append(transforms.ToTensor())
    return transforms.Compose(pipeline)
def style_transform(h, w):
    """Build a pipeline that center-crops to ``(h, w)`` and converts to a tensor.

    Args:
        h: Output height handed to ``transforms.CenterCrop``.
        w: Output width handed to ``transforms.CenterCrop``.

    Returns:
        A ``transforms.Compose`` of ``CenterCrop((h, w))`` then ``ToTensor()``.
    """
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])
def content_transform():
    """Return a pipeline that applies ``transforms.ToTensor()`` and nothing else."""
    return transforms.Compose([transforms.ToTensor()])
# Advanced options
content_size = 640  # resize/crop size for content images (passed to test_transform)
style_size = 640  # resize/crop size for style images (passed to test_transform)
# A real boolean: test_transform only checks its truthiness. The previous value,
# the string 'store_true', was an argparse action name that was truthy by accident.
crop = True
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the StyTR sub-networks.
vgg = StyTR.vgg
# map_location="cpu": a checkpoint saved from a GPU run cannot be deserialized on a
# CPU-only host without it. The assembled network is moved to ``device`` afterwards.
vgg.load_state_dict(torch.load(vgg_path, map_location="cpu"))
# Keep only the first 44 child layers of the VGG.
# NOTE(review): presumably this truncates at relu5_1 as in the reference StyTR code —
# confirm against the definition of StyTR.vgg.
vgg = nn.Sequential(*list(vgg.children())[:44])
decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()
# Inference only: switch the modules to eval mode.
decoder.eval()
Trans.eval()
vgg.eval()
def _load_checkpoint(path):
    """Load a ``state_dict`` saved with ``torch.save`` from *path* onto the CPU.

    When every key starts with ``module.`` (the prefix ``nn.DataParallel`` adds
    when a wrapped model is saved), the prefix is stripped so the weights fit an
    unwrapped module. Otherwise the keys are returned unchanged.

    Args:
        path: Filesystem path of the checkpoint file.

    Returns:
        A new ``OrderedDict`` mapping parameter names to tensors.
    """
    # Load onto the CPU so GPU-saved checkpoints also open on CPU-only hosts.
    state_dict = torch.load(path, map_location="cpu")
    prefix = "module."
    # Strip only when the prefix is universal, so a model that merely has a
    # submodule called "module" is never mangled.
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return OrderedDict(
            (key[len(prefix):], value) for key, value in state_dict.items()
        )
    return OrderedDict(state_dict)


decoder.load_state_dict(_load_checkpoint(decoder_path))
Trans.load_state_dict(_load_checkpoint(Trans_path))
embedding.load_state_dict(_load_checkpoint(embedding_path))
# Preprocessing pipelines for the content and style inputs.
content_tf = test_transform(content_size, crop)
style_tf = test_transform(style_size, crop)

# Assemble the full StyTR model from the loaded sub-networks, switch it to
# inference mode, and place it on the selected device.
network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
network.eval()
network.to(device)
def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """Stylize *content_img* with *style_img* using the module-level StyTR network.

    Both images are converted to RGB and preprocessed with ``content_tf`` /
    ``style_tf`` (built by ``test_transform``). They are batched, moved to
    ``device``, and run through ``network`` with gradients disabled. The result
    is converted back to an 8-bit PIL image.

    Args:
        content_img: The content image.
        style_img: The style reference image.

    Returns:
        The stylized image as a ``PIL.Image.Image``.
    """
    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # NOTE(review): assumes the network returns a sequence whose first element is
    # the stylized batch; squeeze() drops the batch axis so permute sees 3 dims.
    image = output[0].cpu().squeeze()
    # Scale (assumed [0, 1]) floats to [0, 255], round half-up via +0.5 and the
    # truncating uint8 cast, clamp, and reorder channels last for PIL.
    image = (
        image.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )
    return Image.fromarray(image)
# ================== MAGENTA ARBITRARY IMAGE STYLIZATION (TF Hub) ==================
# hub.load fetches this model from the URL below when the module is imported.
style_transfer_model = hub.load(
    "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
)
def perform_style_transfer(content_image, style_image):
    """Stylize *content_image* with *style_image* using the TF Hub Magenta model.

    Args:
        content_image: Image data convertible to a float32 tensor with pixel
            values in 0-255. NOTE(review): presumably an (H, W, 3) array or PIL
            image — confirm what callers pass.
        style_image: Same format as *content_image*.

    Returns:
        The stylized image as a ``PIL.Image.Image``.
    """
    # Add a batch axis and rescale 0-255 pixel values to [0, 1].
    content = tf.convert_to_tensor(content_image, np.float32)[tf.newaxis, ...] / 255.
    style = tf.convert_to_tensor(style_image, np.float32)[tf.newaxis, ...] / 255.
    outputs = style_transfer_model(content, style)
    stylized_batch = np.asarray(outputs[0])
    # Quantize like StyleTransformer: scale to [0, 255], round half-up (+0.5 then
    # truncating cast), and clamp so out-of-range floats cannot wrap around in uint8.
    pixels = np.clip(stylized_batch[0] * 255.0 + 0.5, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
################################################# MAIN ################################################
def create_styledSofa(sofa: Image.Image, style: Image.Image) -> Image.Image:
    """Apply the style of *style* to the *sofa* image.

    Delegates to the StyTR-based ``StyleTransformer``. The TF Hub
    ``perform_style_transfer`` path is not used here.

    Args:
        sofa: The content image to restyle.
        style: The style reference image.

    Returns:
        The stylized image returned by ``StyleTransformer``.
    """
    return StyleTransformer(sofa, style)