# SofaStyler / styleTransfer.py
# Last commit: 40845c9 "fix error" (Sophie98); file size 3.47 kB.
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import transformer as transformer
import StyTR as StyTR
import numpy as np
from collections import OrderedDict
def test_transform(size, crop):
    """Return a torchvision pipeline that optionally resizes and center-crops
    an image before converting it to a tensor.

    Args:
        size: Value forwarded to ``transforms.Resize`` and
            ``transforms.CenterCrop``; ``0`` disables the resize step.
        crop: When truthy, a ``transforms.CenterCrop(size)`` step is added.

    Returns:
        A ``transforms.Compose`` of the selected steps followed by
        ``transforms.ToTensor()``.
    """
    resize_steps = [transforms.Resize(size)] if size != 0 else []
    crop_steps = [transforms.CenterCrop(size)] if crop else []
    return transforms.Compose(resize_steps + crop_steps + [transforms.ToTensor()])
def style_transform(h, w):
    """Build a pipeline that center-crops an image to ``(h, w)`` and converts
    it to a tensor.

    Args:
        h: Crop height.
        w: Crop width.

    Returns:
        A ``transforms.Compose`` of ``CenterCrop((h, w))`` then ``ToTensor()``.
    """
    # The previous version also computed max(h, w) into locals that were never
    # used; that dead code has been dropped.
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])
def content_transform():
    """Return a pipeline that only converts an image to a tensor (no resizing
    or cropping)."""
    return transforms.Compose([transforms.ToTensor()])
def _load_weights(module, path):
    """Load the state dict saved at ``path`` into ``module``.

    Tensors are mapped to the CPU so checkpoints saved on a GPU also load on
    CPU-only machines; ``load_state_dict`` copies them into the module's
    existing parameters, whichever device those live on.
    """
    # The original loader rebuilt the dict key-by-key with a commented-out
    # option to strip a leading "module." prefix; keys are used as-is here.
    module.load_state_dict(torch.load(path, map_location="cpu"))


def StyleTransformer(content_img: Image.Image, style_img: Image.Image,
                     vgg_path: str = 'vgg_normalised.pth',
                     decoder_path: str = 'decoder_iter_160000.pth',
                     Trans_path: str = 'transformer_iter_160000.pth',
                     embedding_path: str = 'embedding_iter_160000.pth',
                     content_size: int = 640, style_size: int = 640,
                     crop: bool = True) -> Image.Image:
    """Stylize ``content_img`` with the style of ``style_img`` using StyTR.

    All four sub-networks are (re)loaded from their checkpoint files on every
    call, and inference runs on CUDA when it is available, otherwise on CPU.

    Args:
        content_img: Image whose content is kept; converted to RGB first.
        style_img: Image providing the style; converted to RGB first.
        vgg_path: Checkpoint for the VGG encoder (``StyTR.vgg``).
        decoder_path: Checkpoint for ``StyTR.decoder``.
        Trans_path: Checkpoint for ``transformer.Transformer``.
        embedding_path: Checkpoint for ``StyTR.PatchEmbed``.
        content_size: Size passed to ``test_transform`` for the content image.
        style_size: Size passed to ``test_transform`` for the style image.
        crop: Whether ``test_transform`` center-crops to the given size.
            (Previously hard-coded as the truthy string ``'store_true'``.)

    Returns:
        The stylized result as a PIL image.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE: StyTR.vgg / StyTR.decoder are module-level objects, so their
    # weights are overwritten in place on every call.
    vgg = StyTR.vgg
    _load_weights(vgg, vgg_path)
    # Keep only the first 44 child layers of the encoder.
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    Trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()
    _load_weights(decoder, decoder_path)
    _load_weights(Trans, Trans_path)
    _load_weights(embedding, embedding_path)

    for module in (vgg, decoder, Trans, embedding):
        module.eval()

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    network.to(device)

    content_tf = test_transform(content_size, crop)
    style_tf = test_transform(style_size, crop)
    # ToTensor yields a (C, H, W) tensor; unsqueeze adds the batch dimension.
    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # Assumes the network returns a (1, C, H, W) batch so output[0] is a single
    # image — TODO confirm against StyTR.StyTrans.forward.
    # Clamp first: to_pil_image converts floats via mul(255).byte(), which does
    # not saturate, so out-of-range values would otherwise be corrupted.
    output = output[0].cpu().clamp(0, 1)
    # ToPILImage is a transform class: instantiate it, then call it on the
    # tensor. (ToPILImage(output) would return the transform, not an image.)
    return transforms.ToPILImage()(output)
def create_styledSofa(sofa: Image.Image, style: Image.Image):
    """Apply the style of ``style`` to the ``sofa`` image.

    Thin wrapper around ``StyleTransformer`` that uses its default checkpoint
    paths and settings.

    Args:
        sofa: Content image (the sofa photo).
        style: Style reference image.

    Returns:
        The result of ``StyleTransformer(sofa, style)``.
    """
    return StyleTransformer(sofa, style)
# image = Image.open('sofa_office.jpg')
# image.show()
# image = np.array(image)
# image,box = resize_sofa(image)
# image = image.crop(box)
# print(box)
# image.show()