File size: 2,220 Bytes
be484a4
 
 
 
 
 
 
 
 
 
fcae87c
be484a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcae87c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torch.nn as nn
from enum import Enum

import base64
import json
from io import BytesIO
from PIL import Image
import requests
import re
from copy import deepcopy

class ImageType(Enum):
    """Tag for a sample: one of four image quadrants, or ``FAKE``.

    The ``REAL_*`` members select the quadrant returned by ``crop_image_part``;
    ``FAKE`` does not correspond to any quadrant (``crop_image_part`` rejects it).
    """

    # Quadrants, numbered clockwise starting at the top-left corner.
    REAL_UP_L = 0
    REAL_UP_R = 1
    REAL_DOWN_R = 2
    REAL_DOWN_L = 3

    # Not a quadrant; presumably marks generated images — confirm with callers.
    FAKE = 4


def crop_image_part(image: torch.Tensor,
                    part: ImageType) -> torch.Tensor:
    """Return one quadrant of ``image``.

    Args:
        image: Tensor indexed as ``image[:, :, row, col]``; presumably
            ``(N, C, H, W)`` — TODO confirm with callers.
        part: Which quadrant to return. Only the four ``REAL_*`` members
            are quadrants.

    Returns:
        A slice of ``image`` covering the requested quadrant. Height and
        width are halved independently, so non-square inputs are split at
        their own midpoints. For odd sizes the extra row/column belongs to
        the lower/right half.

    Raises:
        ValueError: if ``part`` is not one of the four ``REAL_*`` quadrants.
    """
    # Halve each spatial axis separately. Using the height for both axes
    # would split non-square images at the wrong column.
    half_h = image.shape[2] // 2
    half_w = image.shape[3] // 2

    top, bottom = slice(None, half_h), slice(half_h, None)
    left, right = slice(None, half_w), slice(half_w, None)

    if part == ImageType.REAL_UP_L:
        return image[:, :, top, left]
    if part == ImageType.REAL_UP_R:
        return image[:, :, top, right]
    if part == ImageType.REAL_DOWN_L:
        return image[:, :, bottom, left]
    if part == ImageType.REAL_DOWN_R:
        return image[:, :, bottom, right]
    raise ValueError('invalid part')


def init_weights(module: nn.Module):
    """Re-initialise ``module`` in place if it is a Conv2d or BatchNorm2d.

    - ``nn.Conv2d``: weight ~ Normal(0.0, 0.02); the bias is left untouched.
    - ``nn.BatchNorm2d``: weight ~ Normal(1.0, 0.02), bias set to 0. Both
      are skipped when absent (``affine=False`` layers store ``None``).

    Any other module type is left unchanged. Suited to ``model.apply(init_weights)``.
    """
    if isinstance(module, nn.Conv2d):
        torch.nn.init.normal_(module.weight, 0.0, 0.02)

    if isinstance(module, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) registers weight/bias as None; touching
        # them unguarded raised AttributeError.
        if module.weight is not None:
            torch.nn.init.normal_(module.weight, 1.0, 0.02)
        if module.bias is not None:
            module.bias.data.fill_(0)

def load_image_from_local(image_path, image_resize=None):
    """Open the image at ``image_path`` and optionally resize it.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or a file object).
        image_resize: If a ``tuple`` (PIL's ``resize`` expects
            ``(width, height)``), the image is resized to it. Any other value,
            including a list, is ignored.

    Returns:
        A decoded ``PIL.Image.Image``.

    Raises:
        Whatever ``Image.open`` / ``Image.load`` raise for a missing or
        undecodable file (e.g. ``FileNotFoundError``,
        ``PIL.UnidentifiedImageError``).
    """
    image = Image.open(image_path)
    # Image.open is lazy and holds the file descriptor until pixel data is
    # read. Decoding now lets Pillow close the handle for single-frame files,
    # instead of leaking it until garbage collection when no resize happens.
    image.load()

    if isinstance(image_resize, tuple):
        image = image.resize(image_resize)
    return image

def _prepare_image(image, rgba_mode, image_resize):
    """Apply the optional RGBA conversion, then the optional resize."""
    if rgba_mode:
        image = image.convert("RGBA")
    if isinstance(image_resize, tuple):
        image = image.resize(image_resize)
    return image


def load_image_from_url(image_url, rgba_mode=False, image_resize=None,
                        default_image=None, timeout=30):
    """Download an image, falling back to a local file on any failure.

    Args:
        image_url: URL to fetch with ``requests.get``.
        rgba_mode: Convert the result to ``"RGBA"`` mode. This applies to the
            fallback image too.
        image_resize: If a ``tuple`` (PIL ``(width, height)``), resize the
            result to it; other values are ignored.
        default_image: Local path loaded via ``load_image_from_local`` when the
            download or decoding fails. Falsy means "no fallback".
        timeout: Seconds passed to ``requests.get``, so a stalled server cannot
            block the caller forever.

    Returns:
        A ``PIL.Image.Image``, or ``None`` if the download failed and no
        ``default_image`` was given.

    Raises:
        Errors from loading ``default_image`` itself are not caught.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        # Use the decoded body (``content``) rather than ``raw``. ``raw``
        # bypasses Content-Encoding decoding, so gzip-served images would not open.
        image = _prepare_image(Image.open(BytesIO(response.content)),
                               rgba_mode, image_resize)
    except Exception:
        # Deliberately best-effort: any network, HTTP, decoding or conversion
        # failure falls back to the default image below.
        image = None
        if default_image:
            image = _prepare_image(load_image_from_local(default_image),
                                   rgba_mode, image_resize)

    return image

def image_to_base64(image_array):
    """Encode an image as a PNG ``data:`` URI.

    Args:
        image_array: Despite the name, any object with a PIL-style
            ``save(fp, format=...)`` method, e.g. a ``PIL.Image.Image``.

    Returns:
        ``"data:image/png;base64,<payload>"``. Per RFC 2397 the payload
        follows the comma directly. No whitespace may appear there, or strict
        data-URI parsers reject it.
    """
    buffered = BytesIO()
    image_array.save(buffered, format="PNG")
    image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{image_b64}"


def copy_G_params(model):
    """Snapshot every parameter of ``model`` as an independent tensor.

    Returns a list ordered like ``model.parameters()``. The tensors are deep
    copies of each parameter's ``.data``, so they are detached from autograd
    and share no storage with the model. Later in-place updates to the model
    do not change them. ``load_params`` can copy such a list back.
    """
    return deepcopy([param.data for param in model.parameters()])


def load_params(model, new_param):
    """Overwrite ``model``'s parameters in place with the tensors in ``new_param``.

    ``new_param`` is paired positionally with ``model.parameters()`` (e.g. a
    list produced by ``copy_G_params``). Pairing stops at the shorter of the
    two sequences, so extra entries on either side are silently ignored.
    """
    targets = model.parameters()
    for target, source in zip(targets, new_param):
        # Writing through ``.data`` keeps the copy out of autograd tracking.
        target.data.copy_(source)