import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from collections import OrderedDict
import tensorflow_hub as hub
import tensorflow as tf

import transformer as transformer
import StyTR as StyTR

# Pretrained TF Hub model for fast arbitrary image stylization (Magenta).
style_transfer_model = hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")


############################################# TRANSFORMER ############################################
def test_transform(size, crop):
    # Resize (and optionally center-crop) the image, then convert it to a tensor.
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list)

def style_transform(h, w):
    # Center-crop the style image to (h, w), then convert it to a tensor.
    transform_list = [
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(transform_list)

def content_transform():
    # Convert the content image to a tensor without resizing.
    transform_list = [transforms.ToTensor()]
    return transforms.Compose(transform_list)

def StyleTransformer(content_img: Image, style_img: Image,
                vgg_path: str = 'vgg_normalised.pth', decoder_path: str = 'decoder_iter_160000.pth',
                Trans_path: str = 'transformer_iter_160000.pth', embedding_path: str = 'embedding_iter_160000.pth'):
    # Advanced options
    content_size = 640
    style_size = 640
    crop = True  # always center-crop to content_size/style_size

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # VGG encoder: load the pretrained weights and keep only the layers StyTR uses as feature extractor.
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path, map_location=device))
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    Trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()

    decoder.eval()
    Trans.eval()
    vgg.eval()

    # Load the decoder, transformer and patch-embedding checkpoints.
    # Keys are copied as-is; stripping a `module.` (DataParallel) prefix is not needed here.
    for module, ckpt_path in ((decoder, decoder_path), (Trans, Trans_path), (embedding, embedding_path)):
        new_state_dict = OrderedDict()
        state_dict = torch.load(ckpt_path, map_location=device)
        for k, v in state_dict.items():
            new_state_dict[k] = v
        module.load_state_dict(new_state_dict)

    network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
    network.eval()
    network.to(device)

    # Preprocess both images: resize, center-crop and convert to C x H x W tensors.
    content_tf = test_transform(content_size, crop)
    style_tf = test_transform(style_size, crop)

    content = content_tf(content_img.convert("RGB"))
    style = style_tf(style_img.convert("RGB"))

    # Add the batch dimension and move to the target device.
    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)

    # Run inference; the network returns a tuple whose first element is the stylized image tensor.
    with torch.no_grad():
        output = network(content, style)
    output = output[0].cpu()

    # Drop the batch dimension and convert back to a PIL image.
    torch2PIL = transforms.ToPILImage()
    output = torch2PIL(output.squeeze())
    return output
   
########################################## TF-HUB STYLE TRANSFER #####################################

def perform_style_transfer(content_image, style_image):
    # Inputs are H x W x 3 arrays in [0, 255]; scale to [0, 1] floats and add a batch dimension.
    content_image = tf.convert_to_tensor(content_image, np.float32)[tf.newaxis, ...] / 255.
    style_image = tf.convert_to_tensor(style_image, np.float32)[tf.newaxis, ...] / 255.
    # The hub model returns a list whose first element is the batch of stylized images.
    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]
    return Image.fromarray(np.uint8(stylized_image[0] * 255))

def create_styledSofa(sofa: Image, style: Image):
    # Convenience wrapper: style a sofa image with the transformer-based model.
    styled_sofa = StyleTransformer(sofa, style)
    return styled_sofa

# image = Image.open('sofa_office.jpg') 
# image.show() 
# image = np.array(image)
# image,box = resize_sofa(image)  
# image = image.crop(box)
# print(box)
# image.show()
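
# Minimal usage sketch: 'sofa.jpg' and 'style.jpg' are hypothetical file names, and the
# pretrained checkpoints in StyleTransformer's defaults are assumed to be present locally.
if __name__ == "__main__":
    sofa = Image.open('sofa.jpg')
    style = Image.open('style.jpg')
    # Transformer-based pipeline: PIL images in, PIL image out.
    styled = create_styledSofa(sofa, style)
    styled.save('sofa_styled.jpg')
    # TF Hub pipeline: numpy arrays in, PIL image out.
    styled_hub = perform_style_transfer(np.array(sofa), np.array(style))
    styled_hub.save('sofa_styled_hub.jpg')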