boomb0om committed
Commit a86fc61
1 Parent(s): eac7b6f

Update model card files
README.md CHANGED
@@ -1,36 +1,25 @@
- # Real-ESRGAN
- PyTorch implementation of a Real-ESRGAN model trained on custom dataset. This model shows better results on faces compared to the original version. It is also easier to integrate this model into your projects.
-
- You can try it in [google colab](https://colab.research.google.com/drive/1yO6deHTscL7FBcB6_SRzbxRr1nVtuZYE?usp=sharing)
-
- - Paper: [Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data](https://arxiv.org/abs/2107.10833)
- - [Official github](https://github.com/xinntao/Real-ESRGAN)
-
- ### Installation
-
  ---

- 1. Clone repo
-
- ```bash
- git clone https://github.com/sberbank-ai/Real-ESRGAN
- cd Real-ESRGAN
- ```
-
- 2. Install requirements
-
- ```bash
- pip install -r requirements.txt
- ```

- 3. Download [pretrained weights](https://drive.google.com/drive/folders/16PlVKhTNkSyWFx52RPb2hXPIQveNGbxS) and put them into `weights/` folder

- ### Usage

- ---

- Basic example:

  ```python
  import torch
  from PIL import Image
@@ -48,5 +37,4 @@ image = Image.open(path_to_image).convert('RGB')
  sr_image = model.predict(image)

  sr_image.save('results/sr_image.png')
- ```
-
1
+ ---
2
+ language:
3
+ - ru
4
+ - en
5
+ tags:
6
+ - PyTorch
7
+ thumbnail: "https://github.com/sberbank-ai/Real-ESRGAN"
 
 
 
8
  ---
9
 
10
+ # Real-ESRGAN
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ PyTorch implementation of a Real-ESRGAN model trained on custom dataset. This model shows better results on faces compared to the original version. It is also easier to integrate this model into your projects.
13
 
14
+ Real-ESRGAN is an upgraded ESRGAN trained with pure synthetic data is capable of enhancing details while removing annoying artifacts for common real-world images.
15
 
16
+ - [Paper](https://arxiv.org/abs/2107.10833)
17
+ - [Official github](https://github.com/xinntao/Real-ESRGAN)
18
+ - [Ours github](https://github.com/sberbank-ai/Real-ESRGAN)
19
 
20
+ ## Usage
21
 
22
+ Code for using model you can obtain in our [repo](https://github.com/sberbank-ai/Real-ESRGAN).
23
  ```python
24
  import torch
25
  from PIL import Image
37
  sr_image = model.predict(image)
38
 
39
  sr_image.save('results/sr_image.png')
40
+ ```
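The hunk above elides the middle of the usage snippet (new lines 26–36), which is where the model is built and the weights are loaded. For reference, a minimal end-to-end sketch assuming the `RealESRGAN` class from the `realesrgan.py` module deleted below, with the weight files at the repo root where this commit moves them (the import path is an assumption):

```python
import torch
from PIL import Image

from realesrgan import RealESRGAN  # assumed module path; the class is defined in realesrgan.py below

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# scale=2, 4, 8 pair with RealESRGAN_x2.pth, RealESRGAN_x4.pth, RealESRGAN_x8.pth
model = RealESRGAN(device, scale=4)
model.load_weights('RealESRGAN_x4.pth')

path_to_image = 'inputs/lr_image.png'
image = Image.open(path_to_image).convert('RGB')

sr_image = model.predict(image)
sr_image.save('results/sr_image.png')
```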
 
weights/RealESRGAN_x2.pth → RealESRGAN_x2.pth RENAMED
File without changes
weights/RealESRGAN_x4.pth → RealESRGAN_x4.pth RENAMED
File without changes
weights/RealESRGAN_x8.pth → RealESRGAN_x8.pth RENAMED
File without changes
arch_util.py DELETED
@@ -1,197 +0,0 @@
- import math
- import torch
- from torch import nn as nn
- from torch.nn import functional as F
- from torch.nn import init as init
- from torch.nn.modules.batchnorm import _BatchNorm
-
- @torch.no_grad()
- def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
-     """Initialize network weights.
-
-     Args:
-         module_list (list[nn.Module] | nn.Module): Modules to be initialized.
-         scale (float): Scale initialized weights, especially for residual
-             blocks. Default: 1.
-         bias_fill (float): The value to fill bias. Default: 0.
-         kwargs (dict): Other arguments for initialization function.
-     """
-     if not isinstance(module_list, list):
-         module_list = [module_list]
-     for module in module_list:
-         for m in module.modules():
-             if isinstance(m, nn.Conv2d):
-                 init.kaiming_normal_(m.weight, **kwargs)
-                 m.weight.data *= scale
-                 if m.bias is not None:
-                     m.bias.data.fill_(bias_fill)
-             elif isinstance(m, nn.Linear):
-                 init.kaiming_normal_(m.weight, **kwargs)
-                 m.weight.data *= scale
-                 if m.bias is not None:
-                     m.bias.data.fill_(bias_fill)
-             elif isinstance(m, _BatchNorm):
-                 init.constant_(m.weight, 1)
-                 if m.bias is not None:
-                     m.bias.data.fill_(bias_fill)
-
-
- def make_layer(basic_block, num_basic_block, **kwarg):
-     """Make layers by stacking the same blocks.
-
-     Args:
-         basic_block (nn.module): nn.module class for basic block.
-         num_basic_block (int): number of blocks.
-
-     Returns:
-         nn.Sequential: Stacked blocks in nn.Sequential.
-     """
-     layers = []
-     for _ in range(num_basic_block):
-         layers.append(basic_block(**kwarg))
-     return nn.Sequential(*layers)
-
-
- class ResidualBlockNoBN(nn.Module):
-     """Residual block without BN.
-
-     It has a style of:
-         ---Conv-ReLU-Conv-+-
-          |________________|
-
-     Args:
-         num_feat (int): Channel number of intermediate features.
-             Default: 64.
-         res_scale (float): Residual scale. Default: 1.
-         pytorch_init (bool): If set to True, use pytorch default init,
-             otherwise, use default_init_weights. Default: False.
-     """
-
-     def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
-         super(ResidualBlockNoBN, self).__init__()
-         self.res_scale = res_scale
-         self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
-         self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
-         self.relu = nn.ReLU(inplace=True)
-
-         if not pytorch_init:
-             default_init_weights([self.conv1, self.conv2], 0.1)
-
-     def forward(self, x):
-         identity = x
-         out = self.conv2(self.relu(self.conv1(x)))
-         return identity + out * self.res_scale
-
-
- class Upsample(nn.Sequential):
-     """Upsample module.
-
-     Args:
-         scale (int): Scale factor. Supported scales: 2^n and 3.
-         num_feat (int): Channel number of intermediate features.
-     """
-
-     def __init__(self, scale, num_feat):
-         m = []
-         if (scale & (scale - 1)) == 0:  # scale = 2^n
-             for _ in range(int(math.log(scale, 2))):
-                 m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
-                 m.append(nn.PixelShuffle(2))
-         elif scale == 3:
-             m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
-             m.append(nn.PixelShuffle(3))
-         else:
-             raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
-         super(Upsample, self).__init__(*m)
-
-
- def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
-     """Warp an image or feature map with optical flow.
-
-     Args:
-         x (Tensor): Tensor with size (n, c, h, w).
-         flow (Tensor): Tensor with size (n, h, w, 2), normal value.
-         interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
-         padding_mode (str): 'zeros' or 'border' or 'reflection'.
-             Default: 'zeros'.
-         align_corners (bool): Before pytorch 1.3, the default value is
-             align_corners=True. After pytorch 1.3, the default value is
-             align_corners=False. Here, we use True as the default.
-
-     Returns:
-         Tensor: Warped image or feature map.
-     """
-     assert x.size()[-2:] == flow.size()[1:3]
-     _, _, h, w = x.size()
-     # create mesh grid
-     grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
-     grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
-     grid.requires_grad = False
-
-     vgrid = grid + flow
-     # scale grid to [-1,1]
-     vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
-     vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
-     vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
-     output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
-
-     # TODO, what if align_corners=False
-     return output
-
-
- def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
-     """Resize a flow according to ratio or shape.
-
-     Args:
-         flow (Tensor): Precomputed flow. shape [N, 2, H, W].
-         size_type (str): 'ratio' or 'shape'.
-         sizes (list[int | float]): the ratio for resizing or the final output
-             shape.
-             1) The order of ratio should be [ratio_h, ratio_w]. For
-                 downsampling, the ratio should be smaller than 1.0 (i.e., ratio
-                 < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
-                 ratio > 1.0).
-             2) The order of output_size should be [out_h, out_w].
-         interp_mode (str): The mode of interpolation for resizing.
-             Default: 'bilinear'.
-         align_corners (bool): Whether align corners. Default: False.
-
-     Returns:
-         Tensor: Resized flow.
-     """
-     _, _, flow_h, flow_w = flow.size()
-     if size_type == 'ratio':
-         output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
-     elif size_type == 'shape':
-         output_h, output_w = sizes[0], sizes[1]
-     else:
-         raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
-
-     input_flow = flow.clone()
-     ratio_h = output_h / flow_h
-     ratio_w = output_w / flow_w
-     input_flow[:, 0, :, :] *= ratio_w
-     input_flow[:, 1, :, :] *= ratio_h
-     resized_flow = F.interpolate(
-         input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
-     return resized_flow
-
-
- # TODO: may write a cpp file
- def pixel_unshuffle(x, scale):
-     """ Pixel unshuffle.
-
-     Args:
-         x (Tensor): Input feature with shape (b, c, hh, hw).
-         scale (int): Downsample ratio.
-
-     Returns:
-         Tensor: the pixel unshuffled feature.
-     """
-     b, c, hh, hw = x.size()
-     out_channel = c * (scale**2)
-     assert hh % scale == 0 and hw % scale == 0
-     h = hh // scale
-     w = hw // scale
-     x_view = x.view(b, c, h, scale, w, scale)
-     return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
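The `pixel_unshuffle` helper above performs the same rearrangement that PyTorch later shipped as `torch.nn.functional.pixel_unshuffle` (available since torch 1.8), and `nn.PixelShuffle` inverts it exactly. A quick round-trip sanity check, written against the built-in op so it runs without the deleted module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)

# Trade spatial size for channels: (1, 3, 8, 8) -> (1, 12, 4, 4)
down = F.pixel_unshuffle(x, downscale_factor=2)
assert down.shape == (1, 12, 4, 4)

# PixelShuffle is the exact inverse rearrangement
restored = nn.PixelShuffle(2)(down)
assert torch.equal(restored, x)
```

This is the trick `RRDBNet` (deleted below) uses for its x1/x2 variants: the input is unshuffled first, so the trunk always runs at the reduced spatial resolution.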
inputs/lr_image.png DELETED
Binary file (263 kB)
realesrgan.py DELETED
@@ -1,55 +0,0 @@
- import torch
- from torch.nn import functional as F
- from PIL import Image
- import numpy as np
- import cv2
-
- from rrdbnet_arch import RRDBNet
- from utils_sr import *
-
-
- class RealESRGAN:
-     def __init__(self, device, scale=4):
-         self.device = device
-         self.scale = scale
-         self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
-
-     def load_weights(self, model_path):
-         loadnet = torch.load(model_path)
-         if 'params' in loadnet:
-             self.model.load_state_dict(loadnet['params'], strict=True)
-         elif 'params_ema' in loadnet:
-             self.model.load_state_dict(loadnet['params_ema'], strict=True)
-         else:
-             self.model.load_state_dict(loadnet, strict=True)
-         self.model.eval()
-         self.model.to(self.device)
-
-     def predict(self, lr_image, batch_size=4, patches_size=192,
-                 padding=24, pad_size=15):
-         scale = self.scale
-         device = self.device
-         lr_image = np.array(lr_image)
-         lr_image = pad_reflect(lr_image, pad_size)
-
-         patches, p_shape = split_image_into_overlapping_patches(lr_image, patch_size=patches_size,
-                                                                 padding_size=padding)
-         img = torch.FloatTensor(patches/255).permute((0,3,1,2)).to(device).detach()
-
-         with torch.no_grad():
-             res = self.model(img[0:batch_size])
-             for i in range(batch_size, img.shape[0], batch_size):
-                 res = torch.cat((res, self.model(img[i:i+batch_size])), 0)
-
-         sr_image = res.permute((0,2,3,1)).cpu().clamp_(0, 1)
-         np_sr_image = sr_image.numpy()
-
-         padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
-         scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
-         np_sr_image = stich_together(np_sr_image, padded_image_shape=padded_size_scaled,
-                                      target_shape=scaled_image_shape, padding_size=padding * scale)
-         sr_img = (np_sr_image*255).astype(np.uint8)
-         sr_img = unpad_image(sr_img, pad_size*scale)
-         sr_img = Image.fromarray(sr_img)
-
-         return sr_img
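One implementation note on `predict()` above: calling `torch.cat` inside the loop re-allocates the accumulated tensor on every iteration. Collecting the per-batch outputs and concatenating once is equivalent; a sketch of the same batching pattern as a standalone helper (the name and signature are mine, not the repo's):

```python
import torch

def batched_forward(model, img, batch_size=4):
    """Run `model` over `img` in chunks of `batch_size`, concatenating once at the end."""
    with torch.no_grad():
        outputs = [model(img[i:i + batch_size])
                   for i in range(0, img.shape[0], batch_size)]
    return torch.cat(outputs, 0)
```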
requirements.txt DELETED
@@ -1,6 +0,0 @@
- numpy
- opencv-python
- Pillow
- torch>=1.7
- torchvision>=0.8.0
- tqdm
results/sr_image.png DELETED
Binary file (3.67 MB)
rrdbnet_arch.py DELETED
@@ -1,121 +0,0 @@
- import torch
- from torch import nn as nn
- from torch.nn import functional as F
-
- from arch_util import default_init_weights, make_layer, pixel_unshuffle
-
-
- class ResidualDenseBlock(nn.Module):
-     """Residual Dense Block.
-
-     Used in RRDB block in ESRGAN.
-
-     Args:
-         num_feat (int): Channel number of intermediate features.
-         num_grow_ch (int): Channels for each growth.
-     """
-
-     def __init__(self, num_feat=64, num_grow_ch=32):
-         super(ResidualDenseBlock, self).__init__()
-         self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
-         self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
-         self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
-         self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
-         self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
-
-         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-         # initialization
-         default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
-     def forward(self, x):
-         x1 = self.lrelu(self.conv1(x))
-         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
-         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
-         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
-         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-         # Empirically, we use 0.2 to scale the residual for better performance
-         return x5 * 0.2 + x
-
-
- class RRDB(nn.Module):
-     """Residual in Residual Dense Block.
-
-     Used in RRDB-Net in ESRGAN.
-
-     Args:
-         num_feat (int): Channel number of intermediate features.
-         num_grow_ch (int): Channels for each growth.
-     """
-
-     def __init__(self, num_feat, num_grow_ch=32):
-         super(RRDB, self).__init__()
-         self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
-         self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
-         self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
-
-     def forward(self, x):
-         out = self.rdb1(x)
-         out = self.rdb2(out)
-         out = self.rdb3(out)
-         # Empirically, we use 0.2 to scale the residual for better performance
-         return out * 0.2 + x
-
-
- class RRDBNet(nn.Module):
-     """Networks consisting of Residual in Residual Dense Block, which is used
-     in ESRGAN.
-
-     ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
-
-     We extend ESRGAN for scale x2 and scale x1.
-     Note: This is one option for scale 1, scale 2 in RRDBNet.
-     We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
-     and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
-
-     Args:
-         num_in_ch (int): Channel number of inputs.
-         num_out_ch (int): Channel number of outputs.
-         num_feat (int): Channel number of intermediate features.
-             Default: 64.
-         num_block (int): Block number in the trunk network. Default: 23.
-         num_grow_ch (int): Channels for each growth. Default: 32.
-     """
-
-     def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
-         super(RRDBNet, self).__init__()
-         self.scale = scale
-         if scale == 2:
-             num_in_ch = num_in_ch * 4
-         elif scale == 1:
-             num_in_ch = num_in_ch * 16
-         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
-         self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
-         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
-         # upsample
-         self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
-         self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
-         if scale == 8:
-             self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
-         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
-         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
-         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
-     def forward(self, x):
-         if self.scale == 2:
-             feat = pixel_unshuffle(x, scale=2)
-         elif self.scale == 1:
-             feat = pixel_unshuffle(x, scale=4)
-         else:
-             feat = x
-         feat = self.conv_first(feat)
-         body_feat = self.conv_body(self.body(feat))
-         feat = feat + body_feat
-         # upsample
-         feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
-         feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
-         if self.scale == 8:
-             feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
-         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
-         return out
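A shape check for the network above: with `scale=4` the input goes straight to `conv_first` and the two nearest-neighbour upsampling stages yield a x4 output, while `scale=2` (resp. `scale=1`) pixel-unshuffles the input first, so height and width must be divisible by 2 (resp. 4). A minimal sketch, assuming `rrdbnet_arch.py` is still importable:

```python
import torch
from rrdbnet_arch import RRDBNet  # the module deleted above

net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
net.eval()

with torch.no_grad():
    y = net(torch.randn(1, 3, 64, 64))

# Two x2 nearest-neighbour upsamples after the trunk: 64 -> 256
assert y.shape == (1, 3, 256, 256)
```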
utils_sr.py DELETED
@@ -1,141 +0,0 @@
- import numpy as np
- import torch
- from PIL import Image
- import os
- import io
- import imageio
-
- def pad_reflect(image, pad_size):
-     imsize = image.shape
-     height, width = imsize[:2]
-     new_img = np.zeros([height+pad_size*2, width+pad_size*2, imsize[2]]).astype(np.uint8)
-     new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
-
-     new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)  # top
-     new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)  # bottom
-     new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size*2, :], axis=1)  # left
-     new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size*2:-pad_size, :], axis=1)  # right
-
-     return new_img
-
- def unpad_image(image, pad_size):
-     return image[pad_size:-pad_size, pad_size:-pad_size, :]
-
-
- def jpegBlur(im, q):
-     buf = io.BytesIO()
-     imageio.imwrite(buf, im, format='jpg', quality=q)
-     s = buf.getbuffer()
-     return imageio.imread(s, format='jpg')
-
-
- def process_array(image_array, expand=True):
-     """ Process a 3-dimensional array into a scaled, 4 dimensional batch of size 1. """
-
-     image_batch = image_array / 255.0
-     if expand:
-         image_batch = np.expand_dims(image_batch, axis=0)
-     return image_batch
-
-
- def process_output(output_tensor):
-     """ Transforms the 4-dimensional output tensor into a suitable image format. """
-
-     sr_img = output_tensor.clip(0, 1) * 255
-     sr_img = np.uint8(sr_img)
-     return sr_img
-
-
- def pad_patch(image_patch, padding_size, channel_last=True):
-     """ Pads image_patch with padding_size edge values. """
-
-     if channel_last:
-         return np.pad(
-             image_patch,
-             ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
-             'edge',
-         )
-     else:
-         return np.pad(
-             image_patch,
-             ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
-             'edge',
-         )
-
-
- def unpad_patches(image_patches, padding_size):
-     return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
-
-
- def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
-     """ Splits the image into partially overlapping patches.
-     The patches overlap by padding_size pixels.
-     Pads the image twice:
-         - first to have a size multiple of the patch size,
-         - then to have equal padding at the borders.
-     Args:
-         image_array: numpy array of the input image.
-         patch_size: size of the patches from the original image (without padding).
-         padding_size: size of the overlapping area.
-     """
-
-     xmax, ymax, _ = image_array.shape
-     x_remainder = xmax % patch_size
-     y_remainder = ymax % patch_size
-
-     # modulo here is to avoid extending of patch_size instead of 0
-     x_extend = (patch_size - x_remainder) % patch_size
-     y_extend = (patch_size - y_remainder) % patch_size
-
-     # make sure the image is divisible into regular patches
-     extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
-
-     # add padding around the image to simplify computations
-     padded_image = pad_patch(extended_image, padding_size, channel_last=True)
-
-     xmax, ymax, _ = padded_image.shape
-     patches = []
-
-     x_lefts = range(padding_size, xmax - padding_size, patch_size)
-     y_tops = range(padding_size, ymax - padding_size, patch_size)
-
-     for x in x_lefts:
-         for y in y_tops:
-             x_left = x - padding_size
-             y_top = y - padding_size
-             x_right = x + patch_size + padding_size
-             y_bottom = y + patch_size + padding_size
-             patch = padded_image[x_left:x_right, y_top:y_bottom, :]
-             patches.append(patch)
-
-     return np.array(patches), padded_image.shape
-
-
- def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
-     """ Reconstruct the image from overlapping patches.
-     After scaling, shapes and padding should be scaled too.
-     Args:
-         patches: patches obtained with split_image_into_overlapping_patches
-         padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
-         target_shape: shape of the final image
-         padding_size: size of the overlapping area.
-     """
-
-     xmax, ymax, _ = padded_image_shape
-     patches = unpad_patches(patches, padding_size)
-     patch_size = patches.shape[1]
-     n_patches_per_row = ymax // patch_size
-
-     complete_image = np.zeros((xmax, ymax, 3))
-
-     row = -1
-     col = 0
-     for i in range(len(patches)):
-         if i % n_patches_per_row == 0:
-             row += 1
-             col = 0
-         complete_image[
-             row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
-         ] = patches[i]
-         col += 1
-     return complete_image[0: target_shape[0], 0: target_shape[1], :]
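`split_image_into_overlapping_patches` and `stich_together` are inverses as long as the patches are left untouched, which makes the pair easy to test. A round-trip sketch, assuming `utils_sr.py` is importable (the 100x140 size is arbitrary, chosen so both the extension padding and the border padding kick in):

```python
import numpy as np
from utils_sr import split_image_into_overlapping_patches, stich_together

img = np.random.randint(0, 256, size=(100, 140, 3), dtype=np.uint8)

patches, padded_shape = split_image_into_overlapping_patches(
    img, patch_size=64, padding_size=2)
# 100x140 is edge-padded to 128x192 (multiples of 64), then bordered by 2px:
# 2 x 3 patches of 68x68 each
assert patches.shape == (6, 68, 68, 3)

# With an identity "model", stitching must reproduce the input exactly
out = stich_together(patches, padded_image_shape=padded_shape,
                     target_shape=img.shape, padding_size=2)
assert np.array_equal(out.astype(np.uint8), img)
```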
weights/README.md DELETED
@@ -1,4 +0,0 @@
- # Pretrained weights
- Download pretrained weights here:
- https://drive.google.com/drive/folders/16PlVKhTNkSyWFx52RPb2hXPIQveNGbxS
-