YANGYYYY committed
Commit 922e494
1 parent: 9791f04

Upload 8 files

utils/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from .common import *
+ from .image_processing import *
+
+ class DefaultArgs:
+     dataset = 'Hayao'
+     data_dir = '/content'
+     epochs = 10
+     batch_size = 1
+     checkpoint_dir = '/content/checkpoints'
+     save_image_dir = '/content/images'
+     display_image = True
+     save_interval = 2
+     debug_samples = 0
+     lr_g = 0.001
+     lr_d = 0.002
+     wadvg = 300.0
+     wadvd = 300.0
+     wcon = 1.5
+     wgra = 3
+     wcol = 10
+     use_sn = False
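
DefaultArgs above is a plain namespace of training defaults. A minimal usage sketch, assuming a training entry point exists (the train call below is hypothetical, not part of this commit):

from utils import DefaultArgs

args = DefaultArgs()
args.dataset = 'Shinkai'  # switch the style dataset
args.epochs = 100
# train(args)  # hypothetical entry point consuming these args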
utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (739 Bytes).
 
utils/__pycache__/common.cpython-39.pyc ADDED
Binary file (4.29 kB).
 
utils/__pycache__/image_processing.cpython-39.pyc ADDED
Binary file (2.83 kB).
 
utils/common.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import gc
+ import os
+ import torch.nn as nn
+ import urllib.request
+ import cv2
+ from tqdm import tqdm
+
+ HTTP_PREFIXES = [
+     'http',
+     'data:image/jpeg',
+ ]
+
+
+ RELEASED_WEIGHTS = {
+     "hayao:v2": (
+         # Trained with the Google Landmark micro dataset as the real-photo training data
+         "v2",
+         "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.1/GeneratorV2_gldv2_Hayao.pt"
+     ),
+     "hayao:v1": (
+         "v1",
+         "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
+     ),
+     "hayao": (
+         "v1",
+         "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
+     ),
+     "shinkai:v1": (
+         "v1",
+         "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
+     ),
+     "shinkai": (
+         "v1",
+         "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
+     ),
+ }
+
+ def is_image_file(path):
+     _, ext = os.path.splitext(path)
+     return ext.lower() in (".png", ".jpg", ".jpeg")
+
+
+ def read_image(path):
+     """
+     Read an image from the given path (local file or URL) as RGB.
+     """
+     if any(path.startswith(p) for p in HTTP_PREFIXES):
+         urllib.request.urlretrieve(path, "temp.jpg")
+         path = "temp.jpg"
+
+     return cv2.imread(path)[:, :, ::-1]  # BGR -> RGB
+
+
+ def save_checkpoint(model, path, optimizer=None, epoch=None):
+     checkpoint = {
+         'model_state_dict': model.state_dict(),
+         'epoch': epoch,
+     }
+     if optimizer is not None:
+         checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+     torch.save(checkpoint, path)
+
+ def maybe_remove_module(state_dict):
+     # Remove the 'module.' prefix that DDP training prepends to state_dict keys
+     # https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
+     new_state_dict = {}
+     module_str = 'module.'
+     for k, v in state_dict.items():
+         if k.startswith(module_str):
+             k = k[len(module_str):]
+         new_state_dict[k] = v
+     return new_state_dict
+
+
+ def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
+     state_dict = load_state_dict(path, map_location)
+     model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
+     model.load_state_dict(
+         model_state_dict,
+         strict=True
+     )
+     if 'optimizer_state_dict' in state_dict:
+         if optimizer is not None:
+             optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+         if strip_optimizer:
+             del state_dict["optimizer_state_dict"]
+             torch.save(state_dict, path)
+             print(f"Optimizer stripped and saved to {path}")
+
+     epoch = state_dict.get('epoch', 0)
+     return epoch
+
+
+ def load_state_dict(weight, map_location) -> dict:
+     if weight.lower() in RELEASED_WEIGHTS:
+         weight = _download_weight(weight.lower())
+
+     if map_location is None:
+         # Auto-select device
+         map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
+     state_dict = torch.load(weight, map_location=map_location)
+
+     return state_dict
+
+
+ def initialize_weights(net):
+     for m in net.modules():
+         try:
+             if isinstance(m, nn.Conv2d):
+                 # m.weight.data.normal_(0, 0.02)
+                 torch.nn.init.xavier_uniform_(m.weight)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.ConvTranspose2d):
+                 # m.weight.data.normal_(0, 0.02)
+                 torch.nn.init.xavier_uniform_(m.weight)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.Linear):
+                 # m.weight.data.normal_(0, 0.02)
+                 torch.nn.init.xavier_uniform_(m.weight)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+         except Exception as e:
+             # Skip layers without weights/biases (e.g. bias=False convs)
+             # print(f'Skip layer {m}, {e}')
+             pass
+
+
+ def set_lr(optimizer, lr):
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+
+ class DownloadProgressBar(tqdm):
+     '''
+     https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
+     '''
+     def update_to(self, b=1, bsize=1, tsize=None):
+         if tsize is not None:
+             self.total = tsize
+         self.update(b * bsize - self.n)
+
+
+ def _download_weight(weight):
+     '''
+     Download a released weight file and cache it locally.
+     '''
+     os.makedirs('.cache', exist_ok=True)
+     url = RELEASED_WEIGHTS[weight][1]
+     filename = os.path.basename(url)
+     save_path = f'.cache/{filename}'
+
+     if os.path.isfile(save_path):
+         return save_path
+
+     desc = f'Downloading {url} to {save_path}'
+     with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
+         urllib.request.urlretrieve(url, save_path, reporthook=t.update_to)
+
+     return save_path
+
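A minimal sketch of how these checkpoint helpers chain together; the Generator import path is an assumption for illustration, not part of this commit:

from models.anime_gan import Generator  # assumed model location, adjust to the repo
from utils.common import load_checkpoint, save_checkpoint

model = Generator()
# "hayao:v2" resolves via RELEASED_WEIGHTS and is cached under .cache/ on first use
start_epoch = load_checkpoint(model, "hayao:v2")
# ... training loop ...
save_checkpoint(model, "/content/checkpoints/generator.pth", epoch=start_epoch + 1)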
utils/fast_numpyio.py ADDED
@@ -0,0 +1,43 @@
+ # code from https://github.com/divideconcept/fastnumpyio/blob/main/fastnumpyio.py
+
+ import sys
+ import numpy as np
+ import numpy.lib.format
+ import struct
+
+ def save(file, array):
+     magic_string = b"\x93NUMPY\x01\x00v\x00"
+     header = bytes(("{'descr': '" + array.dtype.descr[0][1] + "', 'fortran_order': False, 'shape': " + str(array.shape) + ", }").ljust(127 - len(magic_string)) + "\n", 'utf-8')
+     if type(file) == str:
+         file = open(file, "wb")
+     file.write(magic_string)
+     file.write(header)
+     file.write(array.data)
+
+ def pack(array):
+     size = len(array.shape)
+     return bytes(array.dtype.byteorder.replace('=', '<' if sys.byteorder == 'little' else '>') + array.dtype.kind, 'utf-8') + array.dtype.itemsize.to_bytes(1, byteorder='little') + struct.pack(f'<B{size}I', size, *array.shape) + array.data
+
+ def load(file):
+     if type(file) == str:
+         file = open(file, "rb")
+     header = file.read(128)
+     if not header:
+         return None
+     descr = str(header[19:25], 'utf-8').replace("'", "").replace(" ", "")
+     shape = tuple(int(num) for num in str(header[60:120], 'utf-8').replace(', }', '').replace('(', '').replace(')', '').split(','))
+     datasize = numpy.lib.format.descr_to_dtype(descr).itemsize
+     for dimension in shape:
+         datasize *= dimension
+     return np.ndarray(shape, dtype=descr, buffer=file.read(datasize))
+
+ def unpack(data):
+     dtype = str(data[:2], 'utf-8')
+     dtype += str(data[2])  # itemsize byte, e.g. '<f' + '4' -> '<f4'
+     size = data[3]
+     shape = struct.unpack_from(f'<{size}I', data, 4)
+     datasize = data[2]
+     for dimension in shape:
+         datasize *= dimension
+     return np.ndarray(shape, dtype=dtype, buffer=data[4 + size * 4:4 + size * 4 + datasize])
+
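A round-trip sketch for the helpers above: save/load mirror the .npy format on disk, while pack/unpack use a compact custom byte layout suited to in-memory caching.

import numpy as np
from utils import fast_numpyio

arr = np.random.rand(4, 8).astype(np.float32)

fast_numpyio.save("arr.npy", arr)
assert np.array_equal(fast_numpyio.load("arr.npy"), arr)

packed = fast_numpyio.pack(arr)
assert np.array_equal(fast_numpyio.unpack(packed), arr)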
utils/image_processing.py ADDED
@@ -0,0 +1,135 @@
+ import torch
+ import cv2
+ import os
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def gram(input):
+     """
+     Calculate the Gram matrix
+
+     https://pytorch.org/tutorials/advanced/neural_style_tutorial.html#style-loss
+     """
+     b, c, w, h = input.size()
+
+     x = input.contiguous().view(b * c, w * h)
+
+     # Workaround: torch.mm can generate inf values under mixed precision,
+     # so clamp the result to stay within the fp16 range.
+     # https://discuss.pytorch.org/t/gram-matrix-in-mixed-precision/166800/2
+     G = torch.mm(x, x.T)
+     G = torch.clamp(G, -64990.0, 64990.0)
+     # Normalize by the total number of elements
+     result = G.div(b * c * w * h)
+     return result
+
+
+ def divisible(dim):
+     '''
+     Make width and height divisible by 32
+     '''
+     width, height = dim
+     return width - (width % 32), height - (height % 32)
+
+
+ def resize_image(image, width=None, height=None, inter=cv2.INTER_AREA):
+     dim = None
+     h, w = image.shape[:2]
+
+     if width and height:
+         return cv2.resize(image, divisible((width, height)), interpolation=inter)
+
+     if width is None and height is None:
+         return cv2.resize(image, divisible((w, h)), interpolation=inter)
+
+     if width is None:
+         r = height / float(h)
+         dim = (int(w * r), height)
+     else:
+         r = width / float(w)
+         dim = (width, int(h * r))
+
+     return cv2.resize(image, divisible(dim), interpolation=inter)
+
+
+ def normalize_input(images):
+     '''
+     [0, 255] -> [-1, 1]
+     '''
+     return images / 127.5 - 1.0
+
+
+ def denormalize_input(images, dtype=None):
+     '''
+     [-1, 1] -> [0, 255]
+     '''
+     images = images * 127.5 + 127.5
+
+     if dtype is not None:
+         if isinstance(images, torch.Tensor):
+             images = images.type(dtype)
+         else:
+             # numpy.ndarray
+             images = images.astype(dtype)
+
+     return images
+
+
+ def preprocess_images(images):
+     '''
+     Preprocess images for inference
+
+     @Arguments:
+         - images: np.ndarray
+
+     @Returns:
+         - images: torch.Tensor
+     '''
+     images = images.astype(np.float32)
+
+     # Normalize to [-1, 1]
+     images = normalize_input(images)
+     images = torch.from_numpy(images)
+
+     # Add batch dim
+     if len(images.shape) == 3:
+         images = images.unsqueeze(0)
+
+     # Channels first
+     images = images.permute(0, 3, 1, 2)
+
+     return images
+
+ def compute_data_mean(data_folder):
+     if not os.path.exists(data_folder):
+         raise FileNotFoundError(f'Folder {data_folder} does not exist')
+
+     image_files = os.listdir(data_folder)
+     total = np.zeros(3)
+
+     print(f"Computing mean (R, G, B) over {len(image_files)} images")
+
+     for img_file in tqdm(image_files):
+         path = os.path.join(data_folder, img_file)
+         image = cv2.imread(path)
+         total += image.mean(axis=(0, 1))
+
+     channel_mean = total / len(image_files)
+     mean = np.mean(channel_mean)
+
+     return mean - channel_mean[..., ::-1]  # Convert to BGR for training
+
+
+ if __name__ == '__main__':
+     t = torch.rand(2, 14, 32, 32)
+
+     with torch.autocast("cpu"):
+         print(gram(t))
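
An end-to-end inference sketch using these helpers; the generator forward pass is assumed and left commented:

from utils.common import read_image
from utils.image_processing import resize_image, preprocess_images, denormalize_input

img = read_image("example.jpg")     # RGB uint8, HWC
img = resize_image(img, width=512)  # dims rounded down to multiples of 32
batch = preprocess_images(img)      # float32 NCHW in [-1, 1]
# out = net(batch)                                 # hypothetical generator
# out = denormalize_input(out, dtype=torch.uint8)  # back to [0, 255]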
utils/logger.py ADDED
@@ -0,0 +1,24 @@
+ import logging
+
+
+ def get_logger(path, *args, **kwargs):
+     # Configure the root logger to write to both a file at `path` and the console.
+     logging.basicConfig(
+         format='%(asctime)s %(message)s',
+         datefmt='%m/%d/%Y %I:%M:%S %p',
+         handlers=[
+             logging.FileHandler(path),
+             logging.StreamHandler()
+         ],
+         level=logging.DEBUG
+     )
+     return logging
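
Usage sketch: get_logger configures the root logger and returns the logging module itself, so call sites log through the returned handle.

from utils.logger import get_logger

logger = get_logger('/content/train.log')
logger.info('epoch %d | loss_g %.4f | loss_d %.4f', 1, 0.123, 0.456)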