import os
import sys
sys.path.append(os.path.abspath(os.path.pardir))  # make the nst package importable when running from its subdirectory
from argparse import ArgumentParser
from typing import List, Union

import torch
from torch import nn
import torch.optim as optim

from torchvision import transforms
from torchvision.models import vgg19

from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

from nst.models.vgg19 import VGG19
from nst.losses import ContentLoss, StyleLoss

def main() -> None:
    # command line args 
    parser = ArgumentParser()
    parser.add_argument("--use_gpu", default=True, type=bool)
    parser.add_argument("--content_dir", default="../images/content/dancing.jpg", type=str)
    parser.add_argument("--style_dir", default="../images/style/picasso.jpg", type=str)
    parser.add_argument("--input_image", default="content", type=str)
    parser.add_argument("--output_dir", default="../result/result.jpg", type=str)
    parser.add_argument("--iterations", default=100, type=int)
    parser.add_argument("--alpha", default=1, type=int)
    parser.add_argument("--beta", default=1000000, type=int)
    parser.add_argument("--style_layer_weight", default=1.0, type=float)
    args = parser.parse_args()

    device = torch.device("cuda") if (torch.cuda.is_available() and args.use_gpu) else torch.device("cpu")
    print(f"training on device {device}")

    # content and style images
    content = image_loader(args.content_dir, device)
    style = image_loader(args.style_dir, device)

    # input image
    if args.input_image == "content":
        x = content.clone()
    elif args.input_image == "style":
        x = style.clone()
    else:
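        # randn samples are not confined to [0, 1]; the training closure
        # clamps the image into range on every step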
        x = torch.randn(content.data.size(), device=device)

    # ImageNet normalization statistics expected by VGG19
    mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    # vgg19 model
    model = VGG19(mean=mean, std=std).to(device=device)
    model = load_vgg19_weights(model, device)
    # L-BFGS optimizer, as in the paper
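    # (L-BFGS optimizes the pixels of x directly: the input image is the only
    #  parameter being trained, and the loss is re-evaluated via a closure on
    #  every step; see train() below)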
    optimizer = optim.LBFGS([x.requires_grad_()])

    # computing content and style representations (the targets are fixed,
    # so no gradients are needed here)
    with torch.no_grad():
        content_outputs = model(content)
        style_outputs = model(style)

    # defining content and style losses; indexing follows Gatys et al.:
    # outputs["conv4"][1] is the second conv activation of block 4 (conv4_2)
    # for content, and outputs[f"conv{i}"][0] is conv{i}_1 for style
    content_loss = ContentLoss(content_outputs["conv4"][1], device)
    style_losses = []
    for i in range(1, 6):
        style_losses.append(StyleLoss(style_outputs[f"conv{i}"][0], device))

    # run style transfer
    output = train(model, optimizer, content_loss, style_losses, x,
                   iterations=args.iterations, alpha=args.alpha, beta=args.beta, 
                   style_weight=args.style_layer_weight)
    output = output.detach().to("cpu")

    # save result (permute CHW -> HWC for matplotlib)
    plt.imsave(args.output_dir, output[0].permute(1, 2, 0).numpy())

def image_loader(path: str, device: torch.device = torch.device("cpu")) -> torch.Tensor:
    """
    Loads and resizes the image.

    Args:
        path (str): Path to the image.
        device (torch.device): Device to load the image onto.

    Returns:
        img (torch.Tensor): Loaded image as torch.Tensor.
    """
    transform = transforms.Compose([
                    transforms.Resize((512, 512)),
                    transforms.ToTensor(),
                ])
    img = Image.open(path).convert("RGB")  # drop alpha channels / promote grayscale
    img = transform(img)
    img = img.unsqueeze(0).to(device=device)
    return img

def load_vgg19_weights(model: nn.Module, device: torch.device) -> nn.Module:
    """
    Loads VGG19 pretrained weights from ImageNet for style transfer.
    
    Args:
        model (nn.Module): VGG19 feature module with randomized weights.
        device (torch.device): The device to load the model in. 

    Returns:
        model (nn.Module): VGG19 module with pretrained ImageNet weights loaded.
    """
    # newer torchvision versions prefer vgg19(weights=VGG19_Weights.IMAGENET1K_V1);
    # pretrained=True is kept here for compatibility with older versions
    pretrained_model = vgg19(pretrained=True).features.to(device).eval()

    # torchvision's vgg19().features is an nn.Sequential, so its state_dict
    # keys are numeric indices ("0.weight", "2.weight", ...); map them onto
    # the named blocks of the local VGG19 wrapper
    matching_keys = {
        "conv1.conv1.weight": "0.weight",
        "conv1.conv1.bias": "0.bias",
        "conv1.conv2.weight": "2.weight",
        "conv1.conv2.bias": "2.bias",

        "conv2.conv1.weight": "5.weight",
        "conv2.conv1.bias": "5.bias",
        "conv2.conv2.weight": "7.weight",
        "conv2.conv2.bias": "7.bias",

        "conv3.conv1.weight": "10.weight",
        "conv3.conv1.bias": "10.bias",
        "conv3.conv2.weight": "12.weight",
        "conv3.conv2.bias": "12.bias",
        "conv3.conv3.weight": "14.weight",
        "conv3.conv3.bias": "14.bias",
        "conv3.conv4.weight": "16.weight",
        "conv3.conv4.bias": "16.bias",

        "conv4.conv1.weight": "19.weight",
        "conv4.conv1.bias": "19.bias",
        "conv4.conv2.weight": "21.weight",
        "conv4.conv2.bias": "21.bias",
        "conv4.conv3.weight": "23.weight",
        "conv4.conv3.bias": "23.bias",
        "conv4.conv4.weight": "25.weight",
        "conv4.conv4.bias": "25.bias",

        "conv5.conv1.weight": "28.weight",
        "conv5.conv1.bias": "28.bias",
        "conv5.conv2.weight": "30.weight",
        "conv5.conv2.bias": "30.bias",
        "conv5.conv3.weight": "32.weight",
        "conv5.conv3.bias": "32.bias",
        "conv5.conv4.weight": "34.weight",
        "conv5.conv4.bias": "34.bias",
    }

    pretrained_dict = pretrained_model.state_dict()
    model_dict = model.state_dict()

    for key, value in matching_keys.items():
        model_dict[key] = pretrained_dict[value]
    
    model.load_state_dict(model_dict)

    return model


def train(model: nn.Module, optimizer: optim.Optimizer, content_loss: ContentLoss, style_losses: List[StyleLoss],
          x: torch.Tensor, iterations: int = 100, alpha: int = 1, beta: int = 1000000, style_weight: Union[int, float] = 1.0) -> torch.Tensor:
    """
    Train the neural style transfer algorithm.

    Args:
        model (nn.Module): The VGG19 feature extractor for training the style transfer algorithm.
        optimizer (optim.Optimizer): The optimizer to use (L-BFGS in this script).
        content_loss (ContentLoss): The content loss to preserve the content representation during style transfer.
        style_losses (List[StyleLoss]): A list of style loss objects to preserve the style representation across 
                                        different layers during style transfer.
        x (torch.Tensor): The input image for style transfer.
        iterations (int): Number of iterations to run.
        alpha (int): The weight given to content loss while computing the total loss.
        beta (int): The weight given to style loss while computing the total loss.
        style_weight (Union[int, float]): The weight given to each layer's style loss when computing the total style loss.

    Returns:
        x (torch.Tensor): The input image with the content and style transferred.
    """

    with tqdm(range(iterations)) as pbar:  # separate name so the iterations argument is not shadowed
        for iteration in pbar:
            pbar.set_description(f"Iteration: {iteration}")

            # L-BFGS may call this closure several times per optimizer step to
            # re-evaluate the loss and its gradients
            def closure():
                optimizer.zero_grad()

                # correcting to 0-1 range
                x.data.clamp_(0, 1)
                outputs = model(x)

                # input content and style representations
                content_feature_maps = outputs["conv4"][1]
                style_feature_maps = []
                for i in range(1, 6):
                    style_feature_maps.append(outputs[f"conv{i}"][0])

                # input content and style losses
                total_content_loss = content_loss(content_feature_maps)

                total_style_loss = 0
                for feature_map, style_loss in zip(style_feature_maps, style_losses):
                    total_style_loss += (style_weight * style_loss(feature_map))

                # total loss: L_total = alpha * L_content + beta * L_style
                loss = (alpha * total_content_loss) + (beta * total_style_loss)
                loss.backward()

                pbar.set_postfix({
                    "content loss": total_content_loss.item(),
                    "style loss": total_style_loss.item(),
                    "total loss": loss.item()
                })

                return loss

            optimizer.step(closure)

    # final correction
    x.data.clamp_(0, 1)
    return x
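

# For reference, a minimal sketch of the Gram-matrix statistic that StyleLoss
# is assumed to compute (Gatys et al., 2016). This helper is illustrative only
# and is not called by the training loop above; the gram_matrix_sketch name is
# ours, not part of the nst package.
def gram_matrix_sketch(features: torch.Tensor) -> torch.Tensor:
    """Correlations between feature channels, normalized by element count."""
    b, c, h, w = features.size()
    flat = features.view(b * c, h * w)  # one row per channel (per batch item)
    gram = flat @ flat.t()              # pairwise channel inner products
    return gram / (b * c * h * w)       # normalize so loss scales stay comparable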


if __name__ == "__main__":
    main()
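
# Example invocation (the script name and relative paths are whatever your
# checkout uses; these are the defaults defined above):
#   python train.py --content_dir ../images/content/dancing.jpg \
#                   --style_dir ../images/style/picasso.jpg \
#                   --output_dir ../result/result.jpg --iterations 100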