from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
import StyleTransfer.transformer as transformer
import StyleTransfer.StyTR as StyTR
import tensorflow_hub as tfhub
import tensorflow as tf
import paddlehub as phub
import os

############################################# TRANSFORMER ############################################

def style_transform(h: int, w: int) -> transforms.Compose:
    """Center-crop the style image to (h, w) and convert it to a tensor."""
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])

def content_transform() -> transforms.Compose:
    """Convert the content image to a tensor, keeping its original size."""
    return transforms.Compose([
        transforms.ToTensor(),
    ])

def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    vgg_path        = 'StyleTransfer/models/vgg_normalised.pth'
    decoder_path    = 'StyleTransfer/models/decoder_iter_160000.pth'
    Trans_path      = 'StyleTransfer/models/transformer_iter_160000.pth'
    embedding_path  = 'StyleTransfer/models/embedding_iter_160000.pth'
    # Advanced options
    style_size = 640

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # VGG encoder: load the normalised weights, then keep only the layers StyTrans uses.
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path, map_location=device))
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    Trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()

    decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    Trans.load_state_dict(torch.load(Trans_path, map_location=device))
    embedding.load_state_dict(torch.load(embedding_path, map_location=device))

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    network.to(device)

    content_tf = content_transform()
    style_tf = style_transform(style_size, style_size)

    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # Take the stylized tensor from the network output and convert it to a PIL image.
    output = output[0].cpu().squeeze()
    output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(torch.uint8).numpy()
    return Image.fromarray(output)
   
############################################## STYLE-FAST #############################################

# Load the Magenta arbitrary-image-stylization model once, at import time.
style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")

def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    # The model expects float32 tensors in [0, 1] with a leading batch dimension.
    content_image = tf.convert_to_tensor(np.array(content_image), np.float32)[tf.newaxis, ...] / 255.
    style_image = tf.convert_to_tensor(np.array(style_image), np.float32)[tf.newaxis, ...] / 255.
    stylized_image = style_transfer_model(content_image, style_image)[0]
    return Image.fromarray(np.uint8(stylized_image[0] * 255))

########################################### STYLE PROJECTION ##########################################

# Install and load the PaddleHub style-projection module once, at import time.
os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = phub.Module(name="stylepro_artistic")

def StyleProjection(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    # PaddleHub expects BGR arrays, hence the [:, :, ::-1] channel flips on input and output.
    result = stylepro_artistic.style_transfer(
        images=[{
            'content': np.array(content_image.convert('RGB'))[:, :, ::-1],
            'styles': [np.array(style_image.convert('RGB'))[:, :, ::-1]],
        }],
        alpha=0.8)
    return Image.fromarray(np.uint8(result[0]['data'])[:, :, ::-1]).convert('RGB')

def create_styledSofa(content_image: Image.Image, style_image: Image.Image, choice: str) -> Image.Image:
    """Dispatch to the selected style-transfer backend; fall back to the unmodified content image."""
    if choice == "Style Transformer":
        output = StyleTransformer(content_image, style_image)
    elif choice == "Style FAST":
        output = StyleFAST(content_image, style_image)
    elif choice == "Style Projection":
        output = StyleProjection(content_image, style_image)
    else:
        output = content_image
    return output
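
# Minimal usage sketch for running this module directly. The image paths below
# ('sofa.jpg', 'style.jpg', 'styled_sofa.png') are placeholders for illustration,
# not files shipped with this repo.
if __name__ == "__main__":
    content = Image.open("sofa.jpg")
    style = Image.open("style.jpg")
    styled = create_styledSofa(content, style, choice="Style FAST")
    styled.save("styled_sofa.png")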