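"""Style-transfer helpers for sofa images.

Two pipelines are set up at import time:
  * StyleTransformer(): a transformer-based StyTr model (PyTorch);
  * perform_style_transfer(): Magenta's arbitrary-image-stylization
    model from TF Hub.
create_styledSofa() is the main entry point and uses the StyTr pipeline.
"""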
from collections import OrderedDict

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

import StyTR
import transformer

############################################# TRANSFORMER ############################################

# Pretrained StyTr weights, loaded from the working directory
vgg_path        = 'vgg_normalised.pth'
decoder_path    = 'decoder_iter_160000.pth'
Trans_path      = 'transformer_iter_160000.pth'
embedding_path  = 'embedding_iter_160000.pth'

def test_transform(size, crop):
    """Resize (and optionally center-crop) an image, then convert to a tensor."""
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list)

def style_transform(h, w):
    """Center-crop a style image to (h, w) and convert to a tensor."""
    transform_list = [
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(transform_list)

def content_transform():
    """Convert a content image to a tensor without resizing."""
    return transforms.Compose([transforms.ToTensor()])

# Advanced options
content_size = 640
style_size = 640
crop = True  # center-crop inputs to content_size/style_size

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vgg = StyTR.vgg
vgg.load_state_dict(torch.load(vgg_path, map_location=device))
vgg = nn.Sequential(*list(vgg.children())[:44])  # keep only the encoder layers used by StyTrans

decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()

decoder.eval()
Trans.eval()
vgg.eval()

def load_checkpoint(module, path):
    """Load a checkpoint into `module`, copying state-dict keys as-is."""
    state_dict = torch.load(path, map_location=device)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # namekey = k[7:]  # uncomment to strip the `module.` prefix left by nn.DataParallel
        namekey = k
        new_state_dict[namekey] = v
    module.load_state_dict(new_state_dict)

load_checkpoint(decoder, decoder_path)
load_checkpoint(Trans, Trans_path)
load_checkpoint(embedding, embedding_path)

network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
network.eval()
network.to(device)

# Preprocessing pipelines shared by every StyleTransformer() call
content_tf = test_transform(content_size, crop)
style_tf = test_transform(style_size, crop)

def StyleTransformer(content_img: Image.Image, style_img: Image.Image):
    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)
    output = output[0].squeeze()  # drop the batch dimension
    # Map [0, 1] floats to a uint8 HxWxC array for PIL
    output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    return Image.fromarray(output)
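
# Example usage (sketch; the file names are hypothetical placeholders):
#   result = StyleTransformer(Image.open("content.jpg"), Image.open("style.jpg"))
#   result.save("stylized.jpg")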
   
######################################### TF-HUB STYLIZATION #########################################

# Magenta's arbitrary-image-stylization model, loaded once at import time
style_transfer_model = hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")

def perform_style_transfer(content_image, style_image):
    # The hub model expects float32 batches scaled to [0, 1]
    content_image = tf.convert_to_tensor(content_image, tf.float32)[tf.newaxis, ...] / 255.
    style_image = tf.convert_to_tensor(style_image, tf.float32)[tf.newaxis, ...] / 255.
    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]  # the model returns a list; the first entry is the stylized batch
    return Image.fromarray(np.uint8(stylized_image[0] * 255))
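
# Example usage (sketch; expects H x W x 3 arrays; file names are hypothetical):
#   content = np.array(Image.open("content.jpg").convert("RGB"))
#   style = np.array(Image.open("style.jpg").convert("RGB"))
#   out = perform_style_transfer(content, style)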

################################################# MAIN ################################################


def create_styledSofa(sofa: Image.Image, style: Image.Image):
    styled_sofa = StyleTransformer(sofa, style)
    return styled_sofa
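

# Minimal smoke test (sketch): the image paths are hypothetical placeholders
# and the pretrained checkpoints above must be present.
if __name__ == "__main__":
    sofa = Image.open("sofa.jpg")
    style = Image.open("style.jpg")
    create_styledSofa(sofa, style).save("styled_sofa.jpg")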