2ch committed
Commit cac9914
1 Parent(s): 0106545

update controlnet

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. annotator/annotator_path.py +0 -22
  2. annotator/binary/__init__.py +0 -14
  3. annotator/canny/__init__.py +0 -5
  4. annotator/clipvision/__init__.py +0 -123
  5. annotator/color/__init__.py +0 -20
  6. annotator/hed/__init__.py +0 -98
  7. annotator/keypose/__init__.py +0 -212
  8. annotator/keypose/faster_rcnn_r50_fpn_coco.py +0 -182
  9. annotator/keypose/hrnet_w48_coco_256x192.py +0 -169
  10. annotator/lama/__init__.py +0 -58
  11. annotator/lama/config.yaml +0 -157
  12. annotator/lama/saicinpainting/__init__.py +0 -0
  13. annotator/lama/saicinpainting/training/__init__.py +0 -0
  14. annotator/lama/saicinpainting/training/data/__init__.py +0 -0
  15. annotator/lama/saicinpainting/training/data/masks.py +0 -332
  16. annotator/lama/saicinpainting/training/losses/__init__.py +0 -0
  17. annotator/lama/saicinpainting/training/losses/adversarial.py +0 -177
  18. annotator/lama/saicinpainting/training/losses/constants.py +0 -152
  19. annotator/lama/saicinpainting/training/losses/distance_weighting.py +0 -126
  20. annotator/lama/saicinpainting/training/losses/feature_matching.py +0 -33
  21. annotator/lama/saicinpainting/training/losses/perceptual.py +0 -113
  22. annotator/lama/saicinpainting/training/losses/segmentation.py +0 -43
  23. annotator/lama/saicinpainting/training/losses/style_loss.py +0 -155
  24. annotator/lama/saicinpainting/training/modules/__init__.py +0 -31
  25. annotator/lama/saicinpainting/training/modules/base.py +0 -80
  26. annotator/lama/saicinpainting/training/modules/depthwise_sep_conv.py +0 -17
  27. annotator/lama/saicinpainting/training/modules/fake_fakes.py +0 -47
  28. annotator/lama/saicinpainting/training/modules/ffc.py +0 -485
  29. annotator/lama/saicinpainting/training/modules/multidilated_conv.py +0 -98
  30. annotator/lama/saicinpainting/training/modules/multiscale.py +0 -244
  31. annotator/lama/saicinpainting/training/modules/pix2pixhd.py +0 -669
  32. annotator/lama/saicinpainting/training/modules/spatial_transform.py +0 -49
  33. annotator/lama/saicinpainting/training/modules/squeeze_excitation.py +0 -20
  34. annotator/lama/saicinpainting/training/trainers/__init__.py +0 -29
  35. annotator/lama/saicinpainting/training/trainers/base.py +0 -293
  36. annotator/lama/saicinpainting/training/trainers/default.py +0 -175
  37. annotator/lama/saicinpainting/training/visualizers/__init__.py +0 -15
  38. annotator/lama/saicinpainting/training/visualizers/base.py +0 -73
  39. annotator/lama/saicinpainting/training/visualizers/colors.py +0 -76
  40. annotator/lama/saicinpainting/training/visualizers/directory.py +0 -36
  41. annotator/lama/saicinpainting/training/visualizers/noop.py +0 -9
  42. annotator/lama/saicinpainting/utils.py +0 -174
  43. annotator/leres/__init__.py +0 -113
  44. annotator/leres/leres/LICENSE +0 -23
  45. annotator/leres/leres/Resnet.py +0 -199
  46. annotator/leres/leres/Resnext_torch.py +0 -237
  47. annotator/leres/leres/depthmap.py +0 -546
  48. annotator/leres/leres/multi_depth_model_woauxi.py +0 -34
  49. annotator/leres/leres/net_tools.py +0 -54
  50. annotator/leres/leres/network_auxi.py +0 -417
annotator/annotator_path.py DELETED
@@ -1,22 +0,0 @@
1
- import os
2
- from modules import shared
3
-
4
- models_path = shared.opts.data.get('control_net_modules_path', None)
5
- if not models_path:
6
- models_path = getattr(shared.cmd_opts, 'controlnet_annotator_models_path', None)
7
- if not models_path:
8
- models_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'downloads')
9
-
10
- if not os.path.isabs(models_path):
11
- models_path = os.path.join(shared.data_path, models_path)
12
-
13
- clip_vision_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision')
14
- # clip vision is always inside controlnet "extensions\sd-webui-controlnet"
15
- # and any problem can be solved by removing controlnet and reinstall
16
-
17
- models_path = os.path.realpath(models_path)
18
- os.makedirs(models_path, exist_ok=True)
19
- print(f'ControlNet preprocessor location: {models_path}')
20
- # Make sure that the default location is inside controlnet "extensions\sd-webui-controlnet"
21
- # so that any problem can be solved by removing controlnet and reinstall
22
- # if users do not change configs on their own (otherwise users will know what is wrong)
 
annotator/binary/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- import cv2
2
-
3
-
4
- def apply_binary(img, bin_threshold):
5
- img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
6
-
7
- if bin_threshold == 0 or bin_threshold == 255:
8
- # Otsu's threshold
9
- otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
10
- print("Otsu threshold:", otsu_threshold)
11
- else:
12
- _, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)
13
-
14
- return cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB)
 
annotator/canny/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- import cv2
2
-
3
-
4
- def apply_canny(img, low_threshold, high_threshold):
5
- return cv2.Canny(img, low_threshold, high_threshold)
 
annotator/clipvision/__init__.py DELETED
@@ -1,123 +0,0 @@
1
- import os
2
- import torch
3
-
4
- from modules import devices
5
- from modules.modelloader import load_file_from_url
6
- from annotator.annotator_path import models_path
7
- from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
8
-
9
-
10
- config_clip_g = {
11
- "attention_dropout": 0.0,
12
- "dropout": 0.0,
13
- "hidden_act": "gelu",
14
- "hidden_size": 1664,
15
- "image_size": 224,
16
- "initializer_factor": 1.0,
17
- "initializer_range": 0.02,
18
- "intermediate_size": 8192,
19
- "layer_norm_eps": 1e-05,
20
- "model_type": "clip_vision_model",
21
- "num_attention_heads": 16,
22
- "num_channels": 3,
23
- "num_hidden_layers": 48,
24
- "patch_size": 14,
25
- "projection_dim": 1280,
26
- "torch_dtype": "float32"
27
- }
28
-
29
- config_clip_h = {
30
- "attention_dropout": 0.0,
31
- "dropout": 0.0,
32
- "hidden_act": "gelu",
33
- "hidden_size": 1280,
34
- "image_size": 224,
35
- "initializer_factor": 1.0,
36
- "initializer_range": 0.02,
37
- "intermediate_size": 5120,
38
- "layer_norm_eps": 1e-05,
39
- "model_type": "clip_vision_model",
40
- "num_attention_heads": 16,
41
- "num_channels": 3,
42
- "num_hidden_layers": 32,
43
- "patch_size": 14,
44
- "projection_dim": 1024,
45
- "torch_dtype": "float32"
46
- }
47
-
48
- config_clip_vitl = {
49
- "attention_dropout": 0.0,
50
- "dropout": 0.0,
51
- "hidden_act": "quick_gelu",
52
- "hidden_size": 1024,
53
- "image_size": 224,
54
- "initializer_factor": 1.0,
55
- "initializer_range": 0.02,
56
- "intermediate_size": 4096,
57
- "layer_norm_eps": 1e-05,
58
- "model_type": "clip_vision_model",
59
- "num_attention_heads": 16,
60
- "num_channels": 3,
61
- "num_hidden_layers": 24,
62
- "patch_size": 14,
63
- "projection_dim": 768,
64
- "torch_dtype": "float32"
65
- }
66
-
67
- configs = {
68
- 'clip_g': config_clip_g,
69
- 'clip_h': config_clip_h,
70
- 'clip_vitl': config_clip_vitl,
71
- }
72
-
73
- downloads = {
74
- 'clip_vitl': 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin',
75
- 'clip_g': 'https://huggingface.co/lllyasviel/Annotators/resolve/main/clip_g.pth',
76
- 'clip_h': 'https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/pytorch_model.bin'
77
- }
78
-
79
-
80
- class ClipVisionDetector:
81
- def __init__(self, config):
82
- assert config in downloads
83
- self.download_link = downloads[config]
84
- self.model_path = os.path.join(models_path, 'clip_vision')
85
- self.file_name = config + '.pth'
86
- self.config = configs[config]
87
- self.device = devices.get_device_for("controlnet")
88
- os.makedirs(self.model_path, exist_ok=True)
89
- file_path = os.path.join(self.model_path, self.file_name)
90
- if not os.path.exists(file_path):
91
- load_file_from_url(url=self.download_link, model_dir=self.model_path, file_name=self.file_name)
92
- config = CLIPVisionConfig(**self.config)
93
- self.model = CLIPVisionModelWithProjection(config)
94
- self.processor = CLIPImageProcessor(crop_size=224,
95
- do_center_crop=True,
96
- do_convert_rgb=True,
97
- do_normalize=True,
98
- do_resize=True,
99
- image_mean=[0.48145466, 0.4578275, 0.40821073],
100
- image_std=[0.26862954, 0.26130258, 0.27577711],
101
- resample=3,
102
- size=224)
103
-
104
- sd = torch.load(file_path, map_location=torch.device('cpu'))
105
- self.model.load_state_dict(sd, strict=False)
106
- del sd
107
-
108
- self.model.eval()
109
- self.model.cpu()
110
-
111
- def unload_model(self):
112
- if self.model is not None:
113
- self.model.to('meta')
114
-
115
- def __call__(self, input_image):
116
- with torch.no_grad():
117
- clip_vision_model = self.model.cpu()
118
- feat = self.processor(images=input_image, return_tensors="pt")
119
- feat['pixel_values'] = feat['pixel_values'].cpu()
120
- result = clip_vision_model(**feat, output_hidden_states=True)
121
- result['hidden_states'] = [v.to(devices.get_device_for("controlnet")) for v in result['hidden_states']]
122
- result = {k: v.to(devices.get_device_for("controlnet")) if isinstance(v, torch.Tensor) else v for k, v in result.items()}
123
- return result
 
annotator/color/__init__.py DELETED
@@ -1,20 +0,0 @@
1
- import cv2
2
-
3
- def cv2_resize_shortest_edge(image, size):
4
- h, w = image.shape[:2]
5
- if h < w:
6
- new_h = size
7
- new_w = int(round(w / h * size))
8
- else:
9
- new_w = size
10
- new_h = int(round(h / w * size))
11
- resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
12
- return resized_image
13
-
14
- def apply_color(img, res=512):
15
- img = cv2_resize_shortest_edge(img, res)
16
- h, w = img.shape[:2]
17
-
18
- input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
19
- input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
20
- return input_img_color
 
annotator/hed/__init__.py DELETED
@@ -1,98 +0,0 @@
1
- # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2
- # Please use this implementation in your products
3
- # This implementation may produce slightly different results from Saining Xie's official implementations,
4
- # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5
- # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6
- # and in this way it works better for gradio's RGB protocol
7
-
8
- import os
9
- import cv2
10
- import torch
11
- import numpy as np
12
-
13
- from einops import rearrange
14
- import os
15
- from modules import devices
16
- from annotator.annotator_path import models_path
17
- from annotator.util import safe_step, nms
18
-
19
-
20
- class DoubleConvBlock(torch.nn.Module):
21
- def __init__(self, input_channel, output_channel, layer_number):
22
- super().__init__()
23
- self.convs = torch.nn.Sequential()
24
- self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
25
- for i in range(1, layer_number):
26
- self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
27
- self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
28
-
29
- def __call__(self, x, down_sampling=False):
30
- h = x
31
- if down_sampling:
32
- h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
33
- for conv in self.convs:
34
- h = conv(h)
35
- h = torch.nn.functional.relu(h)
36
- return h, self.projection(h)
37
-
38
-
39
- class ControlNetHED_Apache2(torch.nn.Module):
40
- def __init__(self):
41
- super().__init__()
42
- self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
43
- self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
44
- self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
45
- self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
46
- self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
47
- self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
48
-
49
- def __call__(self, x):
50
- h = x - self.norm
51
- h, projection1 = self.block1(h)
52
- h, projection2 = self.block2(h, down_sampling=True)
53
- h, projection3 = self.block3(h, down_sampling=True)
54
- h, projection4 = self.block4(h, down_sampling=True)
55
- h, projection5 = self.block5(h, down_sampling=True)
56
- return projection1, projection2, projection3, projection4, projection5
57
-
58
-
59
- netNetwork = None
60
- remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
61
- modeldir = os.path.join(models_path, "hed")
62
- old_modeldir = os.path.dirname(os.path.realpath(__file__))
63
-
64
-
65
- def apply_hed(input_image, is_safe=False):
66
- global netNetwork
67
- if netNetwork is None:
68
- modelpath = os.path.join(modeldir, "ControlNetHED.pth")
69
- old_modelpath = os.path.join(old_modeldir, "ControlNetHED.pth")
70
- if os.path.exists(old_modelpath):
71
- modelpath = old_modelpath
72
- elif not os.path.exists(modelpath):
73
- from basicsr.utils.download_util import load_file_from_url
74
- load_file_from_url(remote_model_path, model_dir=modeldir)
75
- netNetwork = ControlNetHED_Apache2().to(devices.get_device_for("controlnet"))
76
- netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu'))
77
- netNetwork.to(devices.get_device_for("controlnet")).float().eval()
78
-
79
- assert input_image.ndim == 3
80
- H, W, C = input_image.shape
81
- with torch.no_grad():
82
- image_hed = torch.from_numpy(input_image.copy()).float().to(devices.get_device_for("controlnet"))
83
- image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
84
- edges = netNetwork(image_hed)
85
- edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
86
- edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
87
- edges = np.stack(edges, axis=2)
88
- edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
89
- if is_safe:
90
- edge = safe_step(edge)
91
- edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
92
- return edge
93
-
94
-
95
- def unload_hed_model():
96
- global netNetwork
97
- if netNetwork is not None:
98
- netNetwork.cpu()
 
annotator/keypose/__init__.py DELETED
@@ -1,212 +0,0 @@
1
- import numpy as np
2
- import cv2
3
- import torch
4
-
5
- import os
6
- from modules import devices
7
- from annotator.annotator_path import models_path
8
-
9
- import mmcv
10
- from mmdet.apis import inference_detector, init_detector
11
- from mmpose.apis import inference_top_down_pose_model
12
- from mmpose.apis import init_pose_model, process_mmdet_results, vis_pose_result
13
-
14
-
15
- def preprocessing(image, device):
16
- # Resize
17
- scale = 640 / max(image.shape[:2])
18
- image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
19
- raw_image = image.astype(np.uint8)
20
-
21
- # Subtract mean values
22
- image = image.astype(np.float32)
23
- image -= np.array(
24
- [
25
- float(104.008),
26
- float(116.669),
27
- float(122.675),
28
- ]
29
- )
30
-
31
- # Convert to torch.Tensor and add "batch" axis
32
- image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
33
- image = image.to(device)
34
-
35
- return image, raw_image
36
-
37
-
38
- def imshow_keypoints(img,
39
- pose_result,
40
- skeleton=None,
41
- kpt_score_thr=0.1,
42
- pose_kpt_color=None,
43
- pose_link_color=None,
44
- radius=4,
45
- thickness=1):
46
- """Draw keypoints and links on an image.
47
- Args:
48
- img (ndarry): The image to draw poses on.
49
- pose_result (list[kpts]): The poses to draw. Each element kpts is
50
- a set of K keypoints as an Kx3 numpy.ndarray, where each
51
- keypoint is represented as x, y, score.
52
- kpt_score_thr (float, optional): Minimum score of keypoints
53
- to be shown. Default: 0.3.
54
- pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
55
- the keypoint will not be drawn.
56
- pose_link_color (np.array[Mx3]): Color of M links. If None, the
57
- links will not be drawn.
58
- thickness (int): Thickness of lines.
59
- """
60
-
61
- img_h, img_w, _ = img.shape
62
- img = np.zeros(img.shape)
63
-
64
- for idx, kpts in enumerate(pose_result):
65
- if idx > 1:
66
- continue
67
- kpts = kpts['keypoints']
68
- # print(kpts)
69
- kpts = np.array(kpts, copy=False)
70
-
71
- # draw each point on image
72
- if pose_kpt_color is not None:
73
- assert len(pose_kpt_color) == len(kpts)
74
-
75
- for kid, kpt in enumerate(kpts):
76
- x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
77
-
78
- if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
79
- # skip the point that should not be drawn
80
- continue
81
-
82
- color = tuple(int(c) for c in pose_kpt_color[kid])
83
- cv2.circle(img, (int(x_coord), int(y_coord)),
84
- radius, color, -1)
85
-
86
- # draw links
87
- if skeleton is not None and pose_link_color is not None:
88
- assert len(pose_link_color) == len(skeleton)
89
-
90
- for sk_id, sk in enumerate(skeleton):
91
- pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
92
- pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
93
-
94
- if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
95
- or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
96
- or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
97
- # skip the link that should not be drawn
98
- continue
99
- color = tuple(int(c) for c in pose_link_color[sk_id])
100
- cv2.line(img, pos1, pos2, color, thickness=thickness)
101
-
102
- return img
103
-
104
-
105
- human_det, pose_model = None, None
106
- det_model_path = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
107
- pose_model_path = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
108
-
109
- modeldir = os.path.join(models_path, "keypose")
110
- old_modeldir = os.path.dirname(os.path.realpath(__file__))
111
-
112
- det_config = 'faster_rcnn_r50_fpn_coco.py'
113
- pose_config = 'hrnet_w48_coco_256x192.py'
114
-
115
- det_checkpoint = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
116
- pose_checkpoint = 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
117
- det_cat_id = 1
118
- bbox_thr = 0.2
119
-
120
- skeleton = [
121
- [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8],
122
- [7, 9], [8, 10],
123
- [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]
124
- ]
125
-
126
- pose_kpt_color = [
127
- [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
128
- [0, 255, 0],
129
- [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0],
130
- [255, 128, 0],
131
- [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]
132
- ]
133
-
134
- pose_link_color = [
135
- [0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
136
- [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
137
- [255, 128, 0],
138
- [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
139
- [51, 153, 255],
140
- [51, 153, 255], [51, 153, 255], [51, 153, 255]
141
- ]
142
-
143
- def find_download_model(checkpoint, remote_path):
144
- modelpath = os.path.join(modeldir, checkpoint)
145
- old_modelpath = os.path.join(old_modeldir, checkpoint)
146
-
147
- if os.path.exists(old_modelpath):
148
- modelpath = old_modelpath
149
- elif not os.path.exists(modelpath):
150
- from basicsr.utils.download_util import load_file_from_url
151
- load_file_from_url(remote_path, model_dir=modeldir)
152
-
153
- return modelpath
154
-
155
- def apply_keypose(input_image):
156
- global human_det, pose_model
157
- if netNetwork is None:
158
- det_model_local = find_download_model(det_checkpoint, det_model_path)
159
- hrnet_model_local = find_download_model(pose_checkpoint, pose_model_path)
160
- det_config_mmcv = mmcv.Config.fromfile(det_config)
161
- pose_config_mmcv = mmcv.Config.fromfile(pose_config)
162
- human_det = init_detector(det_config_mmcv, det_model_local, device=devices.get_device_for("controlnet"))
163
- pose_model = init_pose_model(pose_config_mmcv, hrnet_model_local, device=devices.get_device_for("controlnet"))
164
-
165
- assert input_image.ndim == 3
166
- input_image = input_image.copy()
167
- with torch.no_grad():
168
- image = torch.from_numpy(input_image).float().to(devices.get_device_for("controlnet"))
169
- image = image / 255.0
170
- mmdet_results = inference_detector(human_det, image)
171
-
172
- # keep the person class bounding boxes.
173
- person_results = process_mmdet_results(mmdet_results, det_cat_id)
174
-
175
- return_heatmap = False
176
- dataset = pose_model.cfg.data['test']['type']
177
-
178
- # e.g. use ('backbone', ) to return backbone feature
179
- output_layer_names = None
180
- pose_results, _ = inference_top_down_pose_model(
181
- pose_model,
182
- image,
183
- person_results,
184
- bbox_thr=bbox_thr,
185
- format='xyxy',
186
- dataset=dataset,
187
- dataset_info=None,
188
- return_heatmap=return_heatmap,
189
- outputs=output_layer_names
190
- )
191
-
192
- im_keypose_out = imshow_keypoints(
193
- image,
194
- pose_results,
195
- skeleton=skeleton,
196
- pose_kpt_color=pose_kpt_color,
197
- pose_link_color=pose_link_color,
198
- radius=2,
199
- thickness=2
200
- )
201
- im_keypose_out = im_keypose_out.astype(np.uint8)
202
-
203
- # image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
204
- # edge = netNetwork(image_hed)[0]
205
- # edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
206
- return im_keypose_out
207
-
208
-
209
- def unload_hed_model():
210
- global netNetwork
211
- if netNetwork is not None:
212
- netNetwork.cpu()
 
annotator/keypose/faster_rcnn_r50_fpn_coco.py DELETED
@@ -1,182 +0,0 @@
1
- checkpoint_config = dict(interval=1)
2
- # yapf:disable
3
- log_config = dict(
4
- interval=50,
5
- hooks=[
6
- dict(type='TextLoggerHook'),
7
- # dict(type='TensorboardLoggerHook')
8
- ])
9
- # yapf:enable
10
- dist_params = dict(backend='nccl')
11
- log_level = 'INFO'
12
- load_from = None
13
- resume_from = None
14
- workflow = [('train', 1)]
15
- # optimizer
16
- optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
17
- optimizer_config = dict(grad_clip=None)
18
- # learning policy
19
- lr_config = dict(
20
- policy='step',
21
- warmup='linear',
22
- warmup_iters=500,
23
- warmup_ratio=0.001,
24
- step=[8, 11])
25
- total_epochs = 12
26
-
27
- model = dict(
28
- type='FasterRCNN',
29
- pretrained='torchvision://resnet50',
30
- backbone=dict(
31
- type='ResNet',
32
- depth=50,
33
- num_stages=4,
34
- out_indices=(0, 1, 2, 3),
35
- frozen_stages=1,
36
- norm_cfg=dict(type='BN', requires_grad=True),
37
- norm_eval=True,
38
- style='pytorch'),
39
- neck=dict(
40
- type='FPN',
41
- in_channels=[256, 512, 1024, 2048],
42
- out_channels=256,
43
- num_outs=5),
44
- rpn_head=dict(
45
- type='RPNHead',
46
- in_channels=256,
47
- feat_channels=256,
48
- anchor_generator=dict(
49
- type='AnchorGenerator',
50
- scales=[8],
51
- ratios=[0.5, 1.0, 2.0],
52
- strides=[4, 8, 16, 32, 64]),
53
- bbox_coder=dict(
54
- type='DeltaXYWHBBoxCoder',
55
- target_means=[.0, .0, .0, .0],
56
- target_stds=[1.0, 1.0, 1.0, 1.0]),
57
- loss_cls=dict(
58
- type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
59
- loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
60
- roi_head=dict(
61
- type='StandardRoIHead',
62
- bbox_roi_extractor=dict(
63
- type='SingleRoIExtractor',
64
- roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
65
- out_channels=256,
66
- featmap_strides=[4, 8, 16, 32]),
67
- bbox_head=dict(
68
- type='Shared2FCBBoxHead',
69
- in_channels=256,
70
- fc_out_channels=1024,
71
- roi_feat_size=7,
72
- num_classes=80,
73
- bbox_coder=dict(
74
- type='DeltaXYWHBBoxCoder',
75
- target_means=[0., 0., 0., 0.],
76
- target_stds=[0.1, 0.1, 0.2, 0.2]),
77
- reg_class_agnostic=False,
78
- loss_cls=dict(
79
- type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
80
- loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
81
- # model training and testing settings
82
- train_cfg=dict(
83
- rpn=dict(
84
- assigner=dict(
85
- type='MaxIoUAssigner',
86
- pos_iou_thr=0.7,
87
- neg_iou_thr=0.3,
88
- min_pos_iou=0.3,
89
- match_low_quality=True,
90
- ignore_iof_thr=-1),
91
- sampler=dict(
92
- type='RandomSampler',
93
- num=256,
94
- pos_fraction=0.5,
95
- neg_pos_ub=-1,
96
- add_gt_as_proposals=False),
97
- allowed_border=-1,
98
- pos_weight=-1,
99
- debug=False),
100
- rpn_proposal=dict(
101
- nms_pre=2000,
102
- max_per_img=1000,
103
- nms=dict(type='nms', iou_threshold=0.7),
104
- min_bbox_size=0),
105
- rcnn=dict(
106
- assigner=dict(
107
- type='MaxIoUAssigner',
108
- pos_iou_thr=0.5,
109
- neg_iou_thr=0.5,
110
- min_pos_iou=0.5,
111
- match_low_quality=False,
112
- ignore_iof_thr=-1),
113
- sampler=dict(
114
- type='RandomSampler',
115
- num=512,
116
- pos_fraction=0.25,
117
- neg_pos_ub=-1,
118
- add_gt_as_proposals=True),
119
- pos_weight=-1,
120
- debug=False)),
121
- test_cfg=dict(
122
- rpn=dict(
123
- nms_pre=1000,
124
- max_per_img=1000,
125
- nms=dict(type='nms', iou_threshold=0.7),
126
- min_bbox_size=0),
127
- rcnn=dict(
128
- score_thr=0.05,
129
- nms=dict(type='nms', iou_threshold=0.5),
130
- max_per_img=100)
131
- # soft-nms is also supported for rcnn testing
132
- # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
133
- ))
134
-
135
- dataset_type = 'CocoDataset'
136
- data_root = 'data/coco'
137
- img_norm_cfg = dict(
138
- mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
139
- train_pipeline = [
140
- dict(type='LoadImageFromFile'),
141
- dict(type='LoadAnnotations', with_bbox=True),
142
- dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
143
- dict(type='RandomFlip', flip_ratio=0.5),
144
- dict(type='Normalize', **img_norm_cfg),
145
- dict(type='Pad', size_divisor=32),
146
- dict(type='DefaultFormatBundle'),
147
- dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
148
- ]
149
- test_pipeline = [
150
- dict(type='LoadImageFromFile'),
151
- dict(
152
- type='MultiScaleFlipAug',
153
- img_scale=(1333, 800),
154
- flip=False,
155
- transforms=[
156
- dict(type='Resize', keep_ratio=True),
157
- dict(type='RandomFlip'),
158
- dict(type='Normalize', **img_norm_cfg),
159
- dict(type='Pad', size_divisor=32),
160
- dict(type='DefaultFormatBundle'),
161
- dict(type='Collect', keys=['img']),
162
- ])
163
- ]
164
- data = dict(
165
- samples_per_gpu=2,
166
- workers_per_gpu=2,
167
- train=dict(
168
- type=dataset_type,
169
- ann_file=f'{data_root}/annotations/instances_train2017.json',
170
- img_prefix=f'{data_root}/train2017/',
171
- pipeline=train_pipeline),
172
- val=dict(
173
- type=dataset_type,
174
- ann_file=f'{data_root}/annotations/instances_val2017.json',
175
- img_prefix=f'{data_root}/val2017/',
176
- pipeline=test_pipeline),
177
- test=dict(
178
- type=dataset_type,
179
- ann_file=f'{data_root}/annotations/instances_val2017.json',
180
- img_prefix=f'{data_root}/val2017/',
181
- pipeline=test_pipeline))
182
- evaluation = dict(interval=1, metric='bbox')
 
annotator/keypose/hrnet_w48_coco_256x192.py DELETED
@@ -1,169 +0,0 @@
1
- # _base_ = [
2
- # '../../../../_base_/default_runtime.py',
3
- # '../../../../_base_/datasets/coco.py'
4
- # ]
5
- evaluation = dict(interval=10, metric='mAP', save_best='AP')
6
-
7
- optimizer = dict(
8
- type='Adam',
9
- lr=5e-4,
10
- )
11
- optimizer_config = dict(grad_clip=None)
12
- # learning policy
13
- lr_config = dict(
14
- policy='step',
15
- warmup='linear',
16
- warmup_iters=500,
17
- warmup_ratio=0.001,
18
- step=[170, 200])
19
- total_epochs = 210
20
- channel_cfg = dict(
21
- num_output_channels=17,
22
- dataset_joints=17,
23
- dataset_channel=[
24
- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
25
- ],
26
- inference_channel=[
27
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
28
- ])
29
-
30
- # model settings
31
- model = dict(
32
- type='TopDown',
33
- pretrained='https://download.openmmlab.com/mmpose/'
34
- 'pretrain_models/hrnet_w48-8ef0771d.pth',
35
- backbone=dict(
36
- type='HRNet',
37
- in_channels=3,
38
- extra=dict(
39
- stage1=dict(
40
- num_modules=1,
41
- num_branches=1,
42
- block='BOTTLENECK',
43
- num_blocks=(4, ),
44
- num_channels=(64, )),
45
- stage2=dict(
46
- num_modules=1,
47
- num_branches=2,
48
- block='BASIC',
49
- num_blocks=(4, 4),
50
- num_channels=(48, 96)),
51
- stage3=dict(
52
- num_modules=4,
53
- num_branches=3,
54
- block='BASIC',
55
- num_blocks=(4, 4, 4),
56
- num_channels=(48, 96, 192)),
57
- stage4=dict(
58
- num_modules=3,
59
- num_branches=4,
60
- block='BASIC',
61
- num_blocks=(4, 4, 4, 4),
62
- num_channels=(48, 96, 192, 384))),
63
- ),
64
- keypoint_head=dict(
65
- type='TopdownHeatmapSimpleHead',
66
- in_channels=48,
67
- out_channels=channel_cfg['num_output_channels'],
68
- num_deconv_layers=0,
69
- extra=dict(final_conv_kernel=1, ),
70
- loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
71
- train_cfg=dict(),
72
- test_cfg=dict(
73
- flip_test=True,
74
- post_process='default',
75
- shift_heatmap=True,
76
- modulate_kernel=11))
77
-
78
- data_cfg = dict(
79
- image_size=[192, 256],
80
- heatmap_size=[48, 64],
81
- num_output_channels=channel_cfg['num_output_channels'],
82
- num_joints=channel_cfg['dataset_joints'],
83
- dataset_channel=channel_cfg['dataset_channel'],
84
- inference_channel=channel_cfg['inference_channel'],
85
- soft_nms=False,
86
- nms_thr=1.0,
87
- oks_thr=0.9,
88
- vis_thr=0.2,
89
- use_gt_bbox=False,
90
- det_bbox_thr=0.0,
91
- bbox_file='data/coco/person_detection_results/'
92
- 'COCO_val2017_detections_AP_H_56_person.json',
93
- )
94
-
95
- train_pipeline = [
96
- dict(type='LoadImageFromFile'),
97
- dict(type='TopDownGetBboxCenterScale', padding=1.25),
98
- dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
99
- dict(type='TopDownRandomFlip', flip_prob=0.5),
100
- dict(
101
- type='TopDownHalfBodyTransform',
102
- num_joints_half_body=8,
103
- prob_half_body=0.3),
104
- dict(
105
- type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
106
- dict(type='TopDownAffine'),
107
- dict(type='ToTensor'),
108
- dict(
109
- type='NormalizeTensor',
110
- mean=[0.485, 0.456, 0.406],
111
- std=[0.229, 0.224, 0.225]),
112
- dict(type='TopDownGenerateTarget', sigma=2),
113
- dict(
114
- type='Collect',
115
- keys=['img', 'target', 'target_weight'],
116
- meta_keys=[
117
- 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
118
- 'rotation', 'bbox_score', 'flip_pairs'
119
- ]),
120
- ]
121
-
122
- val_pipeline = [
123
- dict(type='LoadImageFromFile'),
124
- dict(type='TopDownGetBboxCenterScale', padding=1.25),
125
- dict(type='TopDownAffine'),
126
- dict(type='ToTensor'),
127
- dict(
128
- type='NormalizeTensor',
129
- mean=[0.485, 0.456, 0.406],
130
- std=[0.229, 0.224, 0.225]),
131
- dict(
132
- type='Collect',
133
- keys=['img'],
134
- meta_keys=[
135
- 'image_file', 'center', 'scale', 'rotation', 'bbox_score',
136
- 'flip_pairs'
137
- ]),
138
- ]
139
-
140
- test_pipeline = val_pipeline
141
-
142
- data_root = 'data/coco'
143
- data = dict(
144
- samples_per_gpu=32,
145
- workers_per_gpu=2,
146
- val_dataloader=dict(samples_per_gpu=32),
147
- test_dataloader=dict(samples_per_gpu=32),
148
- train=dict(
149
- type='TopDownCocoDataset',
150
- ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
151
- img_prefix=f'{data_root}/train2017/',
152
- data_cfg=data_cfg,
153
- pipeline=train_pipeline,
154
- dataset_info={{_base_.dataset_info}}),
155
- val=dict(
156
- type='TopDownCocoDataset',
157
- ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
158
- img_prefix=f'{data_root}/val2017/',
159
- data_cfg=data_cfg,
160
- pipeline=val_pipeline,
161
- dataset_info={{_base_.dataset_info}}),
162
- test=dict(
163
- type='TopDownCocoDataset',
164
- ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
165
- img_prefix=f'{data_root}/val2017/',
166
- data_cfg=data_cfg,
167
- pipeline=test_pipeline,
168
- dataset_info={{_base_.dataset_info}}),
169
- )
 
 
annotator/lama/__init__.py DELETED
@@ -1,58 +0,0 @@
1
- # https://github.com/advimman/lama
2
-
3
- import yaml
4
- import torch
5
- from omegaconf import OmegaConf
6
- import numpy as np
7
-
8
- from einops import rearrange
9
- import os
10
- from modules import devices
11
- from annotator.annotator_path import models_path
12
- from annotator.lama.saicinpainting.training.trainers import load_checkpoint
13
-
14
-
15
- class LamaInpainting:
16
- model_dir = os.path.join(models_path, "lama")
17
-
18
- def __init__(self):
19
- self.model = None
20
- self.device = devices.get_device_for("controlnet")
21
-
22
- def load_model(self):
23
- remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetLama.pth"
24
- modelpath = os.path.join(self.model_dir, "ControlNetLama.pth")
25
- if not os.path.exists(modelpath):
26
- from basicsr.utils.download_util import load_file_from_url
27
- load_file_from_url(remote_model_path, model_dir=self.model_dir)
28
- config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
29
- cfg = yaml.safe_load(open(config_path, 'rt'))
30
- cfg = OmegaConf.create(cfg)
31
- cfg.training_model.predict_only = True
32
- cfg.visualizer.kind = 'noop'
33
- self.model = load_checkpoint(cfg, os.path.abspath(modelpath), strict=False, map_location='cpu')
34
- self.model = self.model.to(self.device)
35
- self.model.eval()
36
-
37
- def unload_model(self):
38
- if self.model is not None:
39
- self.model.cpu()
40
-
41
- def __call__(self, input_image):
42
- if self.model is None:
43
- self.load_model()
44
- self.model.to(self.device)
45
- color = np.ascontiguousarray(input_image[:, :, 0:3]).astype(np.float32) / 255.0
46
- mask = np.ascontiguousarray(input_image[:, :, 3:4]).astype(np.float32) / 255.0
47
- with torch.no_grad():
48
- color = torch.from_numpy(color).float().to(self.device)
49
- mask = torch.from_numpy(mask).float().to(self.device)
50
- mask = (mask > 0.5).float()
51
- color = color * (1 - mask)
52
- image_feed = torch.cat([color, mask], dim=2)
53
- image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
54
- result = self.model(image_feed)[0]
55
- result = rearrange(result, 'c h w -> h w c')
56
- result = result * mask + color * (1 - mask)
57
- result *= 255.0
58
- return result.detach().cpu().numpy().clip(0, 255).astype(np.uint8)
 
annotator/lama/config.yaml DELETED
@@ -1,157 +0,0 @@
1
- run_title: b18_ffc075_batch8x15
2
- training_model:
3
- kind: default
4
- visualize_each_iters: 1000
5
- concat_mask: true
6
- store_discr_outputs_for_vis: true
7
- losses:
8
- l1:
9
- weight_missing: 0
10
- weight_known: 10
11
- perceptual:
12
- weight: 0
13
- adversarial:
14
- kind: r1
15
- weight: 10
16
- gp_coef: 0.001
17
- mask_as_fake_target: true
18
- allow_scale_mask: true
19
- feature_matching:
20
- weight: 100
21
- resnet_pl:
22
- weight: 30
23
- weights_path: ${env:TORCH_HOME}
24
-
25
- optimizers:
26
- generator:
27
- kind: adam
28
- lr: 0.001
29
- discriminator:
30
- kind: adam
31
- lr: 0.0001
32
- visualizer:
33
- key_order:
34
- - image
35
- - predicted_image
36
- - discr_output_fake
37
- - discr_output_real
38
- - inpainted
39
- rescale_keys:
40
- - discr_output_fake
41
- - discr_output_real
42
- kind: directory
43
- outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
44
- location:
45
- data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
46
- out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
47
- tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
48
- data:
49
- batch_size: 15
50
- val_batch_size: 2
51
- num_workers: 3
52
- train:
53
- indir: ${location.data_root_dir}/train
54
- out_size: 256
55
- mask_gen_kwargs:
56
- irregular_proba: 1
57
- irregular_kwargs:
58
- max_angle: 4
59
- max_len: 200
60
- max_width: 100
61
- max_times: 5
62
- min_times: 1
63
- box_proba: 1
64
- box_kwargs:
65
- margin: 10
66
- bbox_min_size: 30
67
- bbox_max_size: 150
68
- max_times: 3
69
- min_times: 1
70
- segm_proba: 0
71
- segm_kwargs:
72
- confidence_threshold: 0.5
73
- max_object_area: 0.5
74
- min_mask_area: 0.07
75
- downsample_levels: 6
76
- num_variants_per_mask: 1
77
- rigidness_mode: 1
78
- max_foreground_coverage: 0.3
79
- max_foreground_intersection: 0.7
80
- max_mask_intersection: 0.1
81
- max_hidden_area: 0.1
82
- max_scale_change: 0.25
83
- horizontal_flip: true
84
- max_vertical_shift: 0.2
85
- position_shuffle: true
86
- transform_variant: distortions
87
- dataloader_kwargs:
88
- batch_size: ${data.batch_size}
89
- shuffle: true
90
- num_workers: ${data.num_workers}
91
- val:
92
- indir: ${location.data_root_dir}/val
93
- img_suffix: .png
94
- dataloader_kwargs:
95
- batch_size: ${data.val_batch_size}
96
- shuffle: false
97
- num_workers: ${data.num_workers}
98
- visual_test:
99
- indir: ${location.data_root_dir}/korean_test
100
- img_suffix: _input.png
101
- pad_out_to_modulo: 32
102
- dataloader_kwargs:
103
- batch_size: 1
104
- shuffle: false
105
- num_workers: ${data.num_workers}
106
- generator:
107
- kind: ffc_resnet
108
- input_nc: 4
109
- output_nc: 3
110
- ngf: 64
111
- n_downsampling: 3
112
- n_blocks: 18
113
- add_out_act: sigmoid
114
- init_conv_kwargs:
115
- ratio_gin: 0
116
- ratio_gout: 0
117
- enable_lfu: false
118
- downsample_conv_kwargs:
119
- ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
120
- ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
121
- enable_lfu: false
122
- resnet_conv_kwargs:
123
- ratio_gin: 0.75
124
- ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
125
- enable_lfu: false
126
- discriminator:
127
- kind: pix2pixhd_nlayer
128
- input_nc: 3
129
- ndf: 64
130
- n_layers: 4
131
- evaluator:
132
- kind: default
133
- inpainted_key: inpainted
134
- integral_kind: ssim_fid100_f1
135
- trainer:
136
- kwargs:
137
- gpus: -1
138
- accelerator: ddp
139
- max_epochs: 200
140
- gradient_clip_val: 1
141
- log_gpu_memory: None
142
- limit_train_batches: 25000
143
- val_check_interval: ${trainer.kwargs.limit_train_batches}
144
- log_every_n_steps: 1000
145
- precision: 32
146
- terminate_on_nan: false
147
- check_val_every_n_epoch: 1
148
- num_sanity_val_steps: 8
149
- limit_val_batches: 1000
150
- replace_sampler_ddp: false
151
- checkpoint_kwargs:
152
- verbose: true
153
- save_top_k: 5
154
- save_last: true
155
- period: 1
156
- monitor: val_ssim_fid100_f1_total_mean
157
- mode: max
 
annotator/lama/saicinpainting/__init__.py DELETED
File without changes
annotator/lama/saicinpainting/training/__init__.py DELETED
File without changes
annotator/lama/saicinpainting/training/data/__init__.py DELETED
File without changes
annotator/lama/saicinpainting/training/data/masks.py DELETED
@@ -1,332 +0,0 @@
1
- import math
2
- import random
3
- import hashlib
4
- import logging
5
- from enum import Enum
6
-
7
- import cv2
8
- import numpy as np
9
-
10
- # from annotator.lama.saicinpainting.evaluation.masks.mask import SegmentationMask
11
- from annotator.lama.saicinpainting.utils import LinearRamp
12
-
13
- LOGGER = logging.getLogger(__name__)
14
-
15
-
16
- class DrawMethod(Enum):
17
- LINE = 'line'
18
- CIRCLE = 'circle'
19
- SQUARE = 'square'
20
-
21
-
22
- def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
23
- draw_method=DrawMethod.LINE):
24
- draw_method = DrawMethod(draw_method)
25
-
26
- height, width = shape
27
- mask = np.zeros((height, width), np.float32)
28
- times = np.random.randint(min_times, max_times + 1)
29
- for i in range(times):
30
- start_x = np.random.randint(width)
31
- start_y = np.random.randint(height)
32
- for j in range(1 + np.random.randint(5)):
33
- angle = 0.01 + np.random.randint(max_angle)
34
- if i % 2 == 0:
35
- angle = 2 * 3.1415926 - angle
36
- length = 10 + np.random.randint(max_len)
37
- brush_w = 5 + np.random.randint(max_width)
38
- end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
39
- end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
40
- if draw_method == DrawMethod.LINE:
41
- cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
42
- elif draw_method == DrawMethod.CIRCLE:
43
- cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
44
- elif draw_method == DrawMethod.SQUARE:
45
- radius = brush_w // 2
46
- mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
47
- start_x, start_y = end_x, end_y
48
- return mask[None, ...]
49
-
50
-
51
- class RandomIrregularMaskGenerator:
52
- def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
53
- draw_method=DrawMethod.LINE):
54
- self.max_angle = max_angle
55
- self.max_len = max_len
56
- self.max_width = max_width
57
- self.min_times = min_times
58
- self.max_times = max_times
59
- self.draw_method = draw_method
60
- self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
61
-
62
- def __call__(self, img, iter_i=None, raw_image=None):
63
- coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
64
- cur_max_len = int(max(1, self.max_len * coef))
65
- cur_max_width = int(max(1, self.max_width * coef))
66
- cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
67
- return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
68
- max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
69
- draw_method=self.draw_method)
70
-
71
-
72
- def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
73
- height, width = shape
74
- mask = np.zeros((height, width), np.float32)
75
- bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
76
- times = np.random.randint(min_times, max_times + 1)
77
- for i in range(times):
78
- box_width = np.random.randint(bbox_min_size, bbox_max_size)
79
- box_height = np.random.randint(bbox_min_size, bbox_max_size)
80
- start_x = np.random.randint(margin, width - margin - box_width + 1)
81
- start_y = np.random.randint(margin, height - margin - box_height + 1)
82
- mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
83
- return mask[None, ...]
84
-
85
-
86
- class RandomRectangleMaskGenerator:
87
- def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
88
- self.margin = margin
89
- self.bbox_min_size = bbox_min_size
90
- self.bbox_max_size = bbox_max_size
91
- self.min_times = min_times
92
- self.max_times = max_times
93
- self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
94
-
95
- def __call__(self, img, iter_i=None, raw_image=None):
96
- coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
97
- cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
98
- cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
99
- return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
100
- bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
101
- max_times=cur_max_times)
102
-
103
-
104
- class RandomSegmentationMaskGenerator:
105
- def __init__(self, **kwargs):
106
- self.impl = None # will be instantiated in first call (effectively in subprocess)
107
- self.kwargs = kwargs
108
-
109
- def __call__(self, img, iter_i=None, raw_image=None):
110
- if self.impl is None:
111
- self.impl = SegmentationMask(**self.kwargs)
112
-
113
- masks = self.impl.get_masks(np.transpose(img, (1, 2, 0)))
114
- masks = [m for m in masks if len(np.unique(m)) > 1]
115
- return np.random.choice(masks)
116
-
117
-
118
- def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):
119
- height, width = shape
120
- mask = np.zeros((height, width), np.float32)
121
- step_x = np.random.randint(min_step, max_step + 1)
122
- width_x = np.random.randint(min_width, min(step_x, max_width + 1))
123
- offset_x = np.random.randint(0, step_x)
124
-
125
- step_y = np.random.randint(min_step, max_step + 1)
126
- width_y = np.random.randint(min_width, min(step_y, max_width + 1))
127
- offset_y = np.random.randint(0, step_y)
128
-
129
- for dy in range(width_y):
130
- mask[offset_y + dy::step_y] = 1
131
- for dx in range(width_x):
132
- mask[:, offset_x + dx::step_x] = 1
133
- return mask[None, ...]
134
-
135
-
136
- class RandomSuperresMaskGenerator:
137
- def __init__(self, **kwargs):
138
- self.kwargs = kwargs
139
-
140
- def __call__(self, img, iter_i=None):
141
- return make_random_superres_mask(img.shape[1:], **self.kwargs)
142
-
143
-
144
- class DumbAreaMaskGenerator:
145
- min_ratio = 0.1
146
- max_ratio = 0.35
147
- default_ratio = 0.225
148
-
149
- def __init__(self, is_training):
150
- #Parameters:
151
- # is_training(bool): If true - random rectangular mask, if false - central square mask
152
- self.is_training = is_training
153
-
154
- def _random_vector(self, dimension):
155
- if self.is_training:
156
- lower_limit = math.sqrt(self.min_ratio)
157
- upper_limit = math.sqrt(self.max_ratio)
158
- mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension)
159
- u = random.randint(0, dimension-mask_side-1)
160
- v = u+mask_side
161
- else:
162
- margin = (math.sqrt(self.default_ratio) / 2) * dimension
163
- u = round(dimension/2 - margin)
164
- v = round(dimension/2 + margin)
165
- return u, v
166
-
167
- def __call__(self, img, iter_i=None, raw_image=None):
168
- c, height, width = img.shape
169
- mask = np.zeros((height, width), np.float32)
170
- x1, x2 = self._random_vector(width)
171
- y1, y2 = self._random_vector(height)
172
- mask[x1:x2, y1:y2] = 1
173
- return mask[None, ...]
174
-
175
-
176
- class OutpaintingMaskGenerator:
177
- def __init__(self, min_padding_percent:float=0.04, max_padding_percent:int=0.25, left_padding_prob:float=0.5, top_padding_prob:float=0.5,
178
- right_padding_prob:float=0.5, bottom_padding_prob:float=0.5, is_fixed_randomness:bool=False):
179
- """
180
- is_fixed_randomness - get identical paddings for the same image if args are the same
181
- """
182
- self.min_padding_percent = min_padding_percent
183
- self.max_padding_percent = max_padding_percent
184
- self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob]
185
- self.is_fixed_randomness = is_fixed_randomness
186
-
187
- assert self.min_padding_percent <= self.max_padding_percent
188
- assert self.max_padding_percent > 0
189
- assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]"
190
- assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}"
191
- assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}"
192
- if len([x for x in self.probs if x > 0]) == 1:
193
- LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side")
194
-
195
- def apply_padding(self, mask, coord):
196
- mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h),
197
- int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1
198
- return mask
199
-
200
- def get_padding(self, size):
201
- n1 = int(self.min_padding_percent*size)
202
- n2 = int(self.max_padding_percent*size)
203
- return self.rnd.randint(n1, n2) / size
204
-
205
- @staticmethod
206
- def _img2rs(img):
207
- arr = np.ascontiguousarray(img.astype(np.uint8))
208
- str_hash = hashlib.sha1(arr).hexdigest()
209
- res = hash(str_hash)%(2**32)
210
- return res
211
-
212
- def __call__(self, img, iter_i=None, raw_image=None):
213
- c, self.img_h, self.img_w = img.shape
214
- mask = np.zeros((self.img_h, self.img_w), np.float32)
215
- at_least_one_mask_applied = False
216
-
217
- if self.is_fixed_randomness:
218
- assert raw_image is not None, f"Cant calculate hash on raw_image=None"
219
- rs = self._img2rs(raw_image)
220
- self.rnd = np.random.RandomState(rs)
221
- else:
222
- self.rnd = np.random
223
-
224
- coords = [[
225
- (0,0),
226
- (1,self.get_padding(size=self.img_h))
227
- ],
228
- [
229
- (0,0),
230
- (self.get_padding(size=self.img_w),1)
231
- ],
232
- [
233
- (0,1-self.get_padding(size=self.img_h)),
234
- (1,1)
235
- ],
236
- [
237
- (1-self.get_padding(size=self.img_w),0),
238
- (1,1)
239
- ]]
240
-
241
- for pp, coord in zip(self.probs, coords):
242
- if self.rnd.random() < pp:
243
- at_least_one_mask_applied = True
244
- mask = self.apply_padding(mask=mask, coord=coord)
245
-
246
- if not at_least_one_mask_applied:
247
- idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs))
248
- mask = self.apply_padding(mask=mask, coord=coords[idx])
249
- return mask[None, ...]
250
-
251
-
252
- class MixedMaskGenerator:
253
- def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
254
- box_proba=1/3, box_kwargs=None,
255
- segm_proba=1/3, segm_kwargs=None,
256
- squares_proba=0, squares_kwargs=None,
257
- superres_proba=0, superres_kwargs=None,
258
- outpainting_proba=0, outpainting_kwargs=None,
259
- invert_proba=0):
260
- self.probas = []
261
- self.gens = []
262
-
263
- if irregular_proba > 0:
264
- self.probas.append(irregular_proba)
265
- if irregular_kwargs is None:
266
- irregular_kwargs = {}
267
- else:
268
- irregular_kwargs = dict(irregular_kwargs)
269
- irregular_kwargs['draw_method'] = DrawMethod.LINE
270
- self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
271
-
272
- if box_proba > 0:
273
- self.probas.append(box_proba)
274
- if box_kwargs is None:
275
- box_kwargs = {}
276
- self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
277
-
278
- if segm_proba > 0:
279
- self.probas.append(segm_proba)
280
- if segm_kwargs is None:
281
- segm_kwargs = {}
282
- self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs))
283
-
284
- if squares_proba > 0:
285
- self.probas.append(squares_proba)
286
- if squares_kwargs is None:
287
- squares_kwargs = {}
288
- else:
289
- squares_kwargs = dict(squares_kwargs)
290
- squares_kwargs['draw_method'] = DrawMethod.SQUARE
291
- self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
292
-
293
- if superres_proba > 0:
294
- self.probas.append(superres_proba)
295
- if superres_kwargs is None:
296
- superres_kwargs = {}
297
- self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
298
-
299
- if outpainting_proba > 0:
300
- self.probas.append(outpainting_proba)
301
- if outpainting_kwargs is None:
302
- outpainting_kwargs = {}
303
- self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs))
304
-
305
- self.probas = np.array(self.probas, dtype='float32')
306
- self.probas /= self.probas.sum()
307
- self.invert_proba = invert_proba
308
-
309
- def __call__(self, img, iter_i=None, raw_image=None):
310
- kind = np.random.choice(len(self.probas), p=self.probas)
311
- gen = self.gens[kind]
312
- result = gen(img, iter_i=iter_i, raw_image=raw_image)
313
- if self.invert_proba > 0 and random.random() < self.invert_proba:
314
- result = 1 - result
315
- return result
316
-
317
-
318
- def get_mask_generator(kind, kwargs):
319
- if kind is None:
320
- kind = "mixed"
321
- if kwargs is None:
322
- kwargs = {}
323
-
324
- if kind == "mixed":
325
- cl = MixedMaskGenerator
326
- elif kind == "outpainting":
327
- cl = OutpaintingMaskGenerator
328
- elif kind == "dumb":
329
- cl = DumbAreaMaskGenerator
330
- else:
331
- raise NotImplementedError(f"No such generator kind = {kind}")
332
- return cl(**kwargs)
 
annotator/lama/saicinpainting/training/losses/__init__.py DELETED
File without changes
annotator/lama/saicinpainting/training/losses/adversarial.py DELETED
@@ -1,177 +0,0 @@
1
- from typing import Tuple, Dict, Optional
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
-
8
- class BaseAdversarialLoss:
9
- def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
10
- generator: nn.Module, discriminator: nn.Module):
11
- """
12
- Prepare for generator step
13
- :param real_batch: Tensor, a batch of real samples
14
- :param fake_batch: Tensor, a batch of samples produced by generator
15
- :param generator:
16
- :param discriminator:
17
- :return: None
18
- """
19
-
20
- def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
21
- generator: nn.Module, discriminator: nn.Module):
22
- """
23
- Prepare for discriminator step
24
- :param real_batch: Tensor, a batch of real samples
25
- :param fake_batch: Tensor, a batch of samples produced by generator
26
- :param generator:
27
- :param discriminator:
28
- :return: None
29
- """
30
-
31
- def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
32
- discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
33
- mask: Optional[torch.Tensor] = None) \
34
- -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
35
- """
36
- Calculate generator loss
37
- :param real_batch: Tensor, a batch of real samples
38
- :param fake_batch: Tensor, a batch of samples produced by generator
39
- :param discr_real_pred: Tensor, discriminator output for real_batch
40
- :param discr_fake_pred: Tensor, discriminator output for fake_batch
41
- :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
42
- :return: total generator loss along with some values that might be interesting to log
43
- """
44
- raise NotImplementedError()
45
-
46
- def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
47
- discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
48
- mask: Optional[torch.Tensor] = None) \
49
- -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
50
- """
51
- Calculate discriminator loss and call .backward() on it
52
- :param real_batch: Tensor, a batch of real samples
53
- :param fake_batch: Tensor, a batch of samples produced by generator
54
- :param discr_real_pred: Tensor, discriminator output for real_batch
55
- :param discr_fake_pred: Tensor, discriminator output for fake_batch
56
- :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
57
- :return: total discriminator loss along with some values that might be interesting to log
58
- """
59
- raise NotImplementedError()
60
-
61
- def interpolate_mask(self, mask, shape):
62
- assert mask is not None
63
- assert self.allow_scale_mask or shape == mask.shape[-2:]
64
- if shape != mask.shape[-2:] and self.allow_scale_mask:
65
- if self.mask_scale_mode == 'maxpool':
66
- mask = F.adaptive_max_pool2d(mask, shape)
67
- else:
68
- mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
69
- return mask
70
-
71
- def make_r1_gp(discr_real_pred, real_batch):
72
- if torch.is_grad_enabled():
73
- grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]
74
- grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
75
- else:
76
- grad_penalty = 0
77
- real_batch.requires_grad = False
78
-
79
- return grad_penalty
80
-
81
- class NonSaturatingWithR1(BaseAdversarialLoss):
82
- def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
83
- mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
84
- use_unmasked_for_gen=True, use_unmasked_for_discr=True):
85
- self.gp_coef = gp_coef
86
- self.weight = weight
87
- # use for discr => use for gen;
88
- # otherwise we teach only the discr to pay attention to very small difference
89
- assert use_unmasked_for_gen or (not use_unmasked_for_discr)
90
- # mask as target => use unmasked for discr:
91
- # if we don't care about unmasked regions at all
92
- # then it doesn't matter if the value of mask_as_fake_target is true or false
93
- assert use_unmasked_for_discr or (not mask_as_fake_target)
94
- self.use_unmasked_for_gen = use_unmasked_for_gen
95
- self.use_unmasked_for_discr = use_unmasked_for_discr
96
- self.mask_as_fake_target = mask_as_fake_target
97
- self.allow_scale_mask = allow_scale_mask
98
- self.mask_scale_mode = mask_scale_mode
99
- self.extra_mask_weight_for_gen = extra_mask_weight_for_gen
100
-
101
- def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
102
- discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
103
- mask=None) \
104
- -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
105
- fake_loss = F.softplus(-discr_fake_pred)
106
- if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
107
- not self.use_unmasked_for_gen: # == if masked region should be treated differently
108
- mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
109
- if not self.use_unmasked_for_gen:
110
- fake_loss = fake_loss * mask
111
- else:
112
- pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
113
- fake_loss = fake_loss * pixel_weights
114
-
115
- return fake_loss.mean() * self.weight, dict()
116
-
117
- def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
118
- generator: nn.Module, discriminator: nn.Module):
119
- real_batch.requires_grad = True
120
-
121
- def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
122
- discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
123
- mask=None) \
124
- -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
125
-
126
- real_loss = F.softplus(-discr_real_pred)
127
- grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
128
- fake_loss = F.softplus(discr_fake_pred)
129
-
130
- if not self.use_unmasked_for_discr or self.mask_as_fake_target:
131
- # == if masked region should be treated differently
132
- mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
133
- # use_unmasked_for_discr=False only makes sense for fakes;
134
- # for reals there is no difference between two regions
135
- fake_loss = fake_loss * mask
136
- if self.mask_as_fake_target:
137
- fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)
138
-
139
- sum_discr_loss = real_loss + grad_penalty + fake_loss
140
- metrics = dict(discr_real_out=discr_real_pred.mean(),
141
- discr_fake_out=discr_fake_pred.mean(),
142
- discr_real_gp=grad_penalty)
143
- return sum_discr_loss.mean(), metrics
144
-
145
- class BCELoss(BaseAdversarialLoss):
146
- def __init__(self, weight):
147
- self.weight = weight
148
- self.bce_loss = nn.BCEWithLogitsLoss()
149
-
150
- def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
151
- real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device)
152
- fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
153
- return fake_loss, dict()
154
-
155
- def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
156
- generator: nn.Module, discriminator: nn.Module):
157
- real_batch.requires_grad = True
158
-
159
- def discriminator_loss(self,
160
- mask: torch.Tensor,
161
- discr_real_pred: torch.Tensor,
162
- discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
163
-
164
- real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device)
165
- sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2
166
- metrics = dict(discr_real_out=discr_real_pred.mean(),
167
- discr_fake_out=discr_fake_pred.mean(),
168
- discr_real_gp=0)
169
- return sum_discr_loss, metrics
170
-
171
-
172
- def make_discrim_loss(kind, **kwargs):
173
- if kind == 'r1':
174
- return NonSaturatingWithR1(**kwargs)
175
- elif kind == 'bce':
176
- return BCELoss(**kwargs)
177
- raise ValueError(f'Unknown adversarial loss kind {kind}')
annotator/lama/saicinpainting/training/losses/constants.py DELETED
@@ -1,152 +0,0 @@
1
- weights = {"ade20k":
2
- [6.34517766497462,
3
- 9.328358208955224,
4
- 11.389521640091116,
5
- 16.10305958132045,
6
- 20.833333333333332,
7
- 22.22222222222222,
8
- 25.125628140703515,
9
- 43.29004329004329,
10
- 50.5050505050505,
11
- 54.6448087431694,
12
- 55.24861878453038,
13
- 60.24096385542168,
14
- 62.5,
15
- 66.2251655629139,
16
- 84.74576271186442,
17
- 90.90909090909092,
18
- 91.74311926605505,
19
- 96.15384615384616,
20
- 96.15384615384616,
21
- 97.08737864077669,
22
- 102.04081632653062,
23
- 135.13513513513513,
24
- 149.2537313432836,
25
- 153.84615384615384,
26
- 163.93442622950818,
27
- 166.66666666666666,
28
- 188.67924528301887,
29
- 192.30769230769232,
30
- 217.3913043478261,
31
- 227.27272727272725,
32
- 227.27272727272725,
33
- 227.27272727272725,
34
- 303.03030303030306,
35
- 322.5806451612903,
36
- 333.3333333333333,
37
- 370.3703703703703,
38
- 384.61538461538464,
39
- 416.6666666666667,
40
- 416.6666666666667,
41
- 434.7826086956522,
42
- 434.7826086956522,
43
- 454.5454545454545,
44
- 454.5454545454545,
45
- 500.0,
46
- 526.3157894736842,
47
- 526.3157894736842,
48
- 555.5555555555555,
49
- 555.5555555555555,
50
- 555.5555555555555,
51
- 555.5555555555555,
52
- 555.5555555555555,
53
- 555.5555555555555,
54
- 555.5555555555555,
55
- 588.2352941176471,
56
- 588.2352941176471,
57
- 588.2352941176471,
58
- 588.2352941176471,
59
- 588.2352941176471,
60
- 666.6666666666666,
61
- 666.6666666666666,
62
- 666.6666666666666,
63
- 666.6666666666666,
64
- 714.2857142857143,
65
- 714.2857142857143,
66
- 714.2857142857143,
67
- 714.2857142857143,
68
- 714.2857142857143,
69
- 769.2307692307693,
70
- 769.2307692307693,
71
- 769.2307692307693,
72
- 833.3333333333334,
73
- 833.3333333333334,
74
- 833.3333333333334,
75
- 833.3333333333334,
76
- 909.090909090909,
77
- 1000.0,
78
- 1111.111111111111,
79
- 1111.111111111111,
80
- 1111.111111111111,
81
- 1111.111111111111,
82
- 1111.111111111111,
83
- 1250.0,
84
- 1250.0,
85
- 1250.0,
86
- 1250.0,
87
- 1250.0,
88
- 1428.5714285714287,
89
- 1428.5714285714287,
90
- 1428.5714285714287,
91
- 1428.5714285714287,
92
- 1428.5714285714287,
93
- 1428.5714285714287,
94
- 1428.5714285714287,
95
- 1666.6666666666667,
96
- 1666.6666666666667,
97
- 1666.6666666666667,
98
- 1666.6666666666667,
99
- 1666.6666666666667,
100
- 1666.6666666666667,
101
- 1666.6666666666667,
102
- 1666.6666666666667,
103
- 1666.6666666666667,
104
- 1666.6666666666667,
105
- 1666.6666666666667,
106
- 2000.0,
107
- 2000.0,
108
- 2000.0,
109
- 2000.0,
110
- 2000.0,
111
- 2000.0,
112
- 2000.0,
113
- 2000.0,
114
- 2000.0,
115
- 2000.0,
116
- 2000.0,
117
- 2000.0,
118
- 2000.0,
119
- 2000.0,
120
- 2000.0,
121
- 2000.0,
122
- 2000.0,
123
- 2500.0,
124
- 2500.0,
125
- 2500.0,
126
- 2500.0,
127
- 2500.0,
128
- 2500.0,
129
- 2500.0,
130
- 2500.0,
131
- 2500.0,
132
- 2500.0,
133
- 2500.0,
134
- 2500.0,
135
- 2500.0,
136
- 3333.3333333333335,
137
- 3333.3333333333335,
138
- 3333.3333333333335,
139
- 3333.3333333333335,
140
- 3333.3333333333335,
141
- 3333.3333333333335,
142
- 3333.3333333333335,
143
- 3333.3333333333335,
144
- 3333.3333333333335,
145
- 3333.3333333333335,
146
- 3333.3333333333335,
147
- 3333.3333333333335,
148
- 3333.3333333333335,
149
- 5000.0,
150
- 5000.0,
151
- 5000.0]
152
- }
annotator/lama/saicinpainting/training/losses/distance_weighting.py DELETED
@@ -1,126 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torchvision
5
-
6
- from annotator.lama.saicinpainting.training.losses.perceptual import IMAGENET_STD, IMAGENET_MEAN
7
-
8
-
9
- def dummy_distance_weighter(real_img, pred_img, mask):
10
- return mask
11
-
12
-
13
- def get_gauss_kernel(kernel_size, width_factor=1):
14
- coords = torch.stack(torch.meshgrid(torch.arange(kernel_size),
15
- torch.arange(kernel_size)),
16
- dim=0).float()
17
- diff = torch.exp(-((coords - kernel_size // 2) ** 2).sum(0) / kernel_size / width_factor)
18
- diff /= diff.sum()
19
- return diff
20
-
21
-
22
- class BlurMask(nn.Module):
23
- def __init__(self, kernel_size=5, width_factor=1):
24
- super().__init__()
25
- self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode='replicate', bias=False)
26
- self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor))
27
-
28
- def forward(self, real_img, pred_img, mask):
29
- with torch.no_grad():
30
- result = self.filter(mask) * mask
31
- return result
32
-
33
-
34
- class EmulatedEDTMask(nn.Module):
35
- def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1):
36
- super().__init__()
37
- self.dilate_filter = nn.Conv2d(1, 1, dilate_kernel_size, padding=dilate_kernel_size// 2, padding_mode='replicate',
38
- bias=False)
39
- self.dilate_filter.weight.data.copy_(torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float))
40
- self.blur_filter = nn.Conv2d(1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode='replicate', bias=False)
41
- self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor))
42
-
43
- def forward(self, real_img, pred_img, mask):
44
- with torch.no_grad():
45
- known_mask = 1 - mask
46
- dilated_known_mask = (self.dilate_filter(known_mask) > 1).float()
47
- result = self.blur_filter(1 - dilated_known_mask) * mask
48
- return result
49
-
50
-
51
- class PropagatePerceptualSim(nn.Module):
52
- def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3):
53
- super().__init__()
54
- vgg = torchvision.models.vgg19(pretrained=True).features
55
- vgg_avg_pooling = []
56
-
57
- for weights in vgg.parameters():
58
- weights.requires_grad = False
59
-
60
- cur_level_i = 0
61
- for module in vgg.modules():
62
- if module.__class__.__name__ == 'Sequential':
63
- continue
64
- elif module.__class__.__name__ == 'MaxPool2d':
65
- vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
66
- else:
67
- vgg_avg_pooling.append(module)
68
- if module.__class__.__name__ == 'ReLU':
69
- cur_level_i += 1
70
- if cur_level_i == level:
71
- break
72
-
73
- self.features = nn.Sequential(*vgg_avg_pooling)
74
-
75
- self.max_iters = max_iters
76
- self.temperature = temperature
77
- self.do_erode = erode_mask_size > 0
78
- if self.do_erode:
79
- self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False)
80
- self.erode_mask.weight.data.fill_(1)
81
-
82
- def forward(self, real_img, pred_img, mask):
83
- with torch.no_grad():
84
- real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img)
85
- real_feats = self.features(real_img)
86
-
87
- vertical_sim = torch.exp(-(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True)
88
- / self.temperature)
89
- horizontal_sim = torch.exp(-(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True)
90
- / self.temperature)
91
-
92
- mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode='bilinear', align_corners=False)
93
- if self.do_erode:
94
- mask_scaled = (self.erode_mask(mask_scaled) > 1).float()
95
-
96
- cur_knowness = 1 - mask_scaled
97
-
98
- for iter_i in range(self.max_iters):
99
- new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode='replicate')
100
- new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode='replicate')
101
-
102
- new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode='replicate')
103
- new_right_knowness = F.pad(cur_knowness[:, :, :, 1:] * horizontal_sim, (0, 1, 0, 0), mode='replicate')
104
-
105
- new_knowness = torch.stack([new_top_knowness, new_bottom_knowness,
106
- new_left_knowness, new_right_knowness],
107
- dim=0).max(0).values
108
-
109
- cur_knowness = torch.max(cur_knowness, new_knowness)
110
-
111
- cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode='bilinear')
112
- result = torch.min(mask, 1 - cur_knowness)
113
-
114
- return result
115
-
116
-
117
- def make_mask_distance_weighter(kind='none', **kwargs):
118
- if kind == 'none':
119
- return dummy_distance_weighter
120
- if kind == 'blur':
121
- return BlurMask(**kwargs)
122
- if kind == 'edt':
123
- return EmulatedEDTMask(**kwargs)
124
- if kind == 'pps':
125
- return PropagatePerceptualSim(**kwargs)
126
- raise ValueError(f'Unknown mask distance weighter kind {kind}')
annotator/lama/saicinpainting/training/losses/feature_matching.py DELETED
@@ -1,33 +0,0 @@
1
- from typing import List
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
-
7
- def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
8
- per_pixel_l2 = F.mse_loss(pred, target, reduction='none')
9
- pixel_weights = mask * weight_missing + (1 - mask) * weight_known
10
- return (pixel_weights * per_pixel_l2).mean()
11
-
12
-
13
- def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
14
- per_pixel_l1 = F.l1_loss(pred, target, reduction='none')
15
- pixel_weights = mask * weight_missing + (1 - mask) * weight_known
16
- return (pixel_weights * per_pixel_l1).mean()
17
-
18
-
19
- def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None):
20
- if mask is None:
21
- res = torch.stack([F.mse_loss(fake_feat, target_feat)
22
- for fake_feat, target_feat in zip(fake_features, target_features)]).mean()
23
- else:
24
- res = 0
25
- norm = 0
26
- for fake_feat, target_feat in zip(fake_features, target_features):
27
- cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False)
28
- error_weights = 1 - cur_mask
29
- cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean()
30
- res = res + cur_val
31
- norm += 1
32
- res = res / norm
33
- return res
annotator/lama/saicinpainting/training/losses/perceptual.py DELETED
@@ -1,113 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torchvision
5
-
6
- # from models.ade20k import ModelBuilder
7
- from annotator.lama.saicinpainting.utils import check_and_warn_input_range
8
-
9
-
10
- IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
11
- IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]
12
-
13
-
14
- class PerceptualLoss(nn.Module):
15
- def __init__(self, normalize_inputs=True):
16
- super(PerceptualLoss, self).__init__()
17
-
18
- self.normalize_inputs = normalize_inputs
19
- self.mean_ = IMAGENET_MEAN
20
- self.std_ = IMAGENET_STD
21
-
22
- vgg = torchvision.models.vgg19(pretrained=True).features
23
- vgg_avg_pooling = []
24
-
25
- for weights in vgg.parameters():
26
- weights.requires_grad = False
27
-
28
- for module in vgg.modules():
29
- if module.__class__.__name__ == 'Sequential':
30
- continue
31
- elif module.__class__.__name__ == 'MaxPool2d':
32
- vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
33
- else:
34
- vgg_avg_pooling.append(module)
35
-
36
- self.vgg = nn.Sequential(*vgg_avg_pooling)
37
-
38
- def do_normalize_inputs(self, x):
39
- return (x - self.mean_.to(x.device)) / self.std_.to(x.device)
40
-
41
- def partial_losses(self, input, target, mask=None):
42
- check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')
43
-
44
- # we expect input and target to be in [0, 1] range
45
- losses = []
46
-
47
- if self.normalize_inputs:
48
- features_input = self.do_normalize_inputs(input)
49
- features_target = self.do_normalize_inputs(target)
50
- else:
51
- features_input = input
52
- features_target = target
53
-
54
- for layer in self.vgg[:30]:
55
-
56
- features_input = layer(features_input)
57
- features_target = layer(features_target)
58
-
59
- if layer.__class__.__name__ == 'ReLU':
60
- loss = F.mse_loss(features_input, features_target, reduction='none')
61
-
62
- if mask is not None:
63
- cur_mask = F.interpolate(mask, size=features_input.shape[-2:],
64
- mode='bilinear', align_corners=False)
65
- loss = loss * (1 - cur_mask)
66
-
67
- loss = loss.mean(dim=tuple(range(1, len(loss.shape))))
68
- losses.append(loss)
69
-
70
- return losses
71
-
72
- def forward(self, input, target, mask=None):
73
- losses = self.partial_losses(input, target, mask=mask)
74
- return torch.stack(losses).sum(dim=0)
75
-
76
- def get_global_features(self, input):
77
- check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')
78
-
79
- if self.normalize_inputs:
80
- features_input = self.do_normalize_inputs(input)
81
- else:
82
- features_input = input
83
-
84
- features_input = self.vgg(features_input)
85
- return features_input
86
-
87
-
88
- class ResNetPL(nn.Module):
89
- def __init__(self, weight=1,
90
- weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
91
- super().__init__()
92
- self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
93
- arch_encoder=arch_encoder,
94
- arch_decoder='ppm_deepsup',
95
- fc_dim=2048,
96
- segmentation=segmentation)
97
- self.impl.eval()
98
- for w in self.impl.parameters():
99
- w.requires_grad_(False)
100
-
101
- self.weight = weight
102
-
103
- def forward(self, pred, target):
104
- pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
105
- target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)
106
-
107
- pred_feats = self.impl(pred, return_feature_maps=True)
108
- target_feats = self.impl(target, return_feature_maps=True)
109
-
110
- result = torch.stack([F.mse_loss(cur_pred, cur_target)
111
- for cur_pred, cur_target
112
- in zip(pred_feats, target_feats)]).sum() * self.weight
113
- return result
annotator/lama/saicinpainting/training/losses/segmentation.py DELETED
@@ -1,43 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from .constants import weights as constant_weights
6
-
7
-
8
- class CrossEntropy2d(nn.Module):
9
- def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs):
10
- """
11
- weight (Tensor, optional): a manual rescaling weight given to each class.
12
- If given, has to be a Tensor of size "nclasses"
13
- """
14
- super(CrossEntropy2d, self).__init__()
15
- self.reduction = reduction
16
- self.ignore_label = ignore_label
17
- self.weights = weights
18
- if self.weights is not None:
19
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
- self.weights = torch.FloatTensor(constant_weights[weights]).to(device)
21
-
22
- def forward(self, predict, target):
23
- """
24
- Args:
25
- predict:(n, c, h, w)
26
- target:(n, 1, h, w)
27
- """
28
- target = target.long()
29
- assert not target.requires_grad
30
- assert predict.dim() == 4, "{0}".format(predict.size())
31
- assert target.dim() == 4, "{0}".format(target.size())
32
- assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
33
- assert target.size(1) == 1, "{0}".format(target.size(1))
34
- assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
35
- assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))
36
- target = target.squeeze(1)
37
- n, c, h, w = predict.size()
38
- target_mask = (target >= 0) * (target != self.ignore_label)
39
- target = target[target_mask]
40
- predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
41
- predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
42
- loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction)
43
- return loss
annotator/lama/saicinpainting/training/losses/style_loss.py DELETED
@@ -1,155 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torchvision.models as models
4
-
5
-
6
- class PerceptualLoss(nn.Module):
7
- r"""
8
- Perceptual loss, VGG-based
9
- https://arxiv.org/abs/1603.08155
10
- https://github.com/dxyang/StyleTransfer/blob/master/utils.py
11
- """
12
-
13
- def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
14
- super(PerceptualLoss, self).__init__()
15
- self.add_module('vgg', VGG19())
16
- self.criterion = torch.nn.L1Loss()
17
- self.weights = weights
18
-
19
- def __call__(self, x, y):
20
- # Compute features
21
- x_vgg, y_vgg = self.vgg(x), self.vgg(y)
22
-
23
- content_loss = 0.0
24
- content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
25
- content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
26
- content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
27
- content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
28
- content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
29
-
30
-
31
- return content_loss
32
-
33
-
34
- class VGG19(torch.nn.Module):
35
- def __init__(self):
36
- super(VGG19, self).__init__()
37
- features = models.vgg19(pretrained=True).features
38
- self.relu1_1 = torch.nn.Sequential()
39
- self.relu1_2 = torch.nn.Sequential()
40
-
41
- self.relu2_1 = torch.nn.Sequential()
42
- self.relu2_2 = torch.nn.Sequential()
43
-
44
- self.relu3_1 = torch.nn.Sequential()
45
- self.relu3_2 = torch.nn.Sequential()
46
- self.relu3_3 = torch.nn.Sequential()
47
- self.relu3_4 = torch.nn.Sequential()
48
-
49
- self.relu4_1 = torch.nn.Sequential()
50
- self.relu4_2 = torch.nn.Sequential()
51
- self.relu4_3 = torch.nn.Sequential()
52
- self.relu4_4 = torch.nn.Sequential()
53
-
54
- self.relu5_1 = torch.nn.Sequential()
55
- self.relu5_2 = torch.nn.Sequential()
56
- self.relu5_3 = torch.nn.Sequential()
57
- self.relu5_4 = torch.nn.Sequential()
58
-
59
- for x in range(2):
60
- self.relu1_1.add_module(str(x), features[x])
61
-
62
- for x in range(2, 4):
63
- self.relu1_2.add_module(str(x), features[x])
64
-
65
- for x in range(4, 7):
66
- self.relu2_1.add_module(str(x), features[x])
67
-
68
- for x in range(7, 9):
69
- self.relu2_2.add_module(str(x), features[x])
70
-
71
- for x in range(9, 12):
72
- self.relu3_1.add_module(str(x), features[x])
73
-
74
- for x in range(12, 14):
75
- self.relu3_2.add_module(str(x), features[x])
76
-
77
- for x in range(14, 16):
78
- self.relu3_3.add_module(str(x), features[x])
79
-
80
- for x in range(16, 18):
81
- self.relu3_4.add_module(str(x), features[x])
82
-
83
- for x in range(18, 21):
84
- self.relu4_1.add_module(str(x), features[x])
85
-
86
- for x in range(21, 23):
87
- self.relu4_2.add_module(str(x), features[x])
88
-
89
- for x in range(23, 25):
90
- self.relu4_3.add_module(str(x), features[x])
91
-
92
- for x in range(25, 27):
93
- self.relu4_4.add_module(str(x), features[x])
94
-
95
- for x in range(27, 30):
96
- self.relu5_1.add_module(str(x), features[x])
97
-
98
- for x in range(30, 32):
99
- self.relu5_2.add_module(str(x), features[x])
100
-
101
- for x in range(32, 34):
102
- self.relu5_3.add_module(str(x), features[x])
103
-
104
- for x in range(34, 36):
105
- self.relu5_4.add_module(str(x), features[x])
106
-
107
- # don't need the gradients, just want the features
108
- for param in self.parameters():
109
- param.requires_grad = False
110
-
111
- def forward(self, x):
112
- relu1_1 = self.relu1_1(x)
113
- relu1_2 = self.relu1_2(relu1_1)
114
-
115
- relu2_1 = self.relu2_1(relu1_2)
116
- relu2_2 = self.relu2_2(relu2_1)
117
-
118
- relu3_1 = self.relu3_1(relu2_2)
119
- relu3_2 = self.relu3_2(relu3_1)
120
- relu3_3 = self.relu3_3(relu3_2)
121
- relu3_4 = self.relu3_4(relu3_3)
122
-
123
- relu4_1 = self.relu4_1(relu3_4)
124
- relu4_2 = self.relu4_2(relu4_1)
125
- relu4_3 = self.relu4_3(relu4_2)
126
- relu4_4 = self.relu4_4(relu4_3)
127
-
128
- relu5_1 = self.relu5_1(relu4_4)
129
- relu5_2 = self.relu5_2(relu5_1)
130
- relu5_3 = self.relu5_3(relu5_2)
131
- relu5_4 = self.relu5_4(relu5_3)
132
-
133
- out = {
134
- 'relu1_1': relu1_1,
135
- 'relu1_2': relu1_2,
136
-
137
- 'relu2_1': relu2_1,
138
- 'relu2_2': relu2_2,
139
-
140
- 'relu3_1': relu3_1,
141
- 'relu3_2': relu3_2,
142
- 'relu3_3': relu3_3,
143
- 'relu3_4': relu3_4,
144
-
145
- 'relu4_1': relu4_1,
146
- 'relu4_2': relu4_2,
147
- 'relu4_3': relu4_3,
148
- 'relu4_4': relu4_4,
149
-
150
- 'relu5_1': relu5_1,
151
- 'relu5_2': relu5_2,
152
- 'relu5_3': relu5_3,
153
- 'relu5_4': relu5_4,
154
- }
155
- return out
annotator/lama/saicinpainting/training/modules/__init__.py DELETED
@@ -1,31 +0,0 @@
1
- import logging
2
-
3
- from annotator.lama.saicinpainting.training.modules.ffc import FFCResNetGenerator
4
- from annotator.lama.saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \
5
- NLayerDiscriminator, MultidilatedNLayerDiscriminator
6
-
7
- def make_generator(config, kind, **kwargs):
8
- logging.info(f'Make generator {kind}')
9
-
10
- if kind == 'pix2pixhd_multidilated':
11
- return MultiDilatedGlobalGenerator(**kwargs)
12
-
13
- if kind == 'pix2pixhd_global':
14
- return GlobalGenerator(**kwargs)
15
-
16
- if kind == 'ffc_resnet':
17
- return FFCResNetGenerator(**kwargs)
18
-
19
- raise ValueError(f'Unknown generator kind {kind}')
20
-
21
-
22
- def make_discriminator(kind, **kwargs):
23
- logging.info(f'Make discriminator {kind}')
24
-
25
- if kind == 'pix2pixhd_nlayer_multidilated':
26
- return MultidilatedNLayerDiscriminator(**kwargs)
27
-
28
- if kind == 'pix2pixhd_nlayer':
29
- return NLayerDiscriminator(**kwargs)
30
-
31
- raise ValueError(f'Unknown discriminator kind {kind}')
annotator/lama/saicinpainting/training/modules/base.py DELETED
@@ -1,80 +0,0 @@
1
- import abc
2
- from typing import Tuple, List
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
8
- from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
9
-
10
-
11
- class BaseDiscriminator(nn.Module):
12
- @abc.abstractmethod
13
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
14
- """
15
- Predict scores and get intermediate activations. Useful for feature matching loss
16
- :return tuple (scores, list of intermediate activations)
17
- """
18
- raise NotImplementedError()
19
-
20
-
21
- def get_conv_block_ctor(kind='default'):
22
- if not isinstance(kind, str):
23
- return kind
24
- if kind == 'default':
25
- return nn.Conv2d
26
- if kind == 'depthwise':
27
- return DepthWiseSeperableConv
28
- if kind == 'multidilated':
29
- return MultidilatedConv
30
- raise ValueError(f'Unknown convolutional block kind {kind}')
31
-
32
-
33
- def get_norm_layer(kind='bn'):
34
- if not isinstance(kind, str):
35
- return kind
36
- if kind == 'bn':
37
- return nn.BatchNorm2d
38
- if kind == 'in':
39
- return nn.InstanceNorm2d
40
- raise ValueError(f'Unknown norm block kind {kind}')
41
-
42
-
43
- def get_activation(kind='tanh'):
44
- if kind == 'tanh':
45
- return nn.Tanh()
46
- if kind == 'sigmoid':
47
- return nn.Sigmoid()
48
- if kind is False:
49
- return nn.Identity()
50
- raise ValueError(f'Unknown activation kind {kind}')
51
-
52
-
53
- class SimpleMultiStepGenerator(nn.Module):
54
- def __init__(self, steps: List[nn.Module]):
55
- super().__init__()
56
- self.steps = nn.ModuleList(steps)
57
-
58
- def forward(self, x):
59
- cur_in = x
60
- outs = []
61
- for step in self.steps:
62
- cur_out = step(cur_in)
63
- outs.append(cur_out)
64
- cur_in = torch.cat((cur_in, cur_out), dim=1)
65
- return torch.cat(outs[::-1], dim=1)
66
-
67
- def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
68
- if kind == 'convtranspose':
69
- return [nn.ConvTranspose2d(min(max_features, ngf * mult),
70
- min(max_features, int(ngf * mult / 2)),
71
- kernel_size=3, stride=2, padding=1, output_padding=1),
72
- norm_layer(min(max_features, int(ngf * mult / 2))), activation]
73
- elif kind == 'bilinear':
74
- return [nn.Upsample(scale_factor=2, mode='bilinear'),
75
- DepthWiseSeperableConv(min(max_features, ngf * mult),
76
- min(max_features, int(ngf * mult / 2)),
77
- kernel_size=3, stride=1, padding=1),
78
- norm_layer(min(max_features, int(ngf * mult / 2))), activation]
79
- else:
80
- raise Exception(f"Invalid deconv kind: {kind}")
annotator/lama/saicinpainting/training/modules/depthwise_sep_conv.py DELETED
@@ -1,17 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- class DepthWiseSeperableConv(nn.Module):
5
- def __init__(self, in_dim, out_dim, *args, **kwargs):
6
- super().__init__()
7
- if 'groups' in kwargs:
8
- # ignoring groups for Depthwise Sep Conv
9
- del kwargs['groups']
10
-
11
- self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)
12
- self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1)
13
-
14
- def forward(self, x):
15
- out = self.depthwise(x)
16
- out = self.pointwise(out)
17
- return out
annotator/lama/saicinpainting/training/modules/fake_fakes.py DELETED
@@ -1,47 +0,0 @@
1
- import torch
2
- from kornia import SamplePadding
3
- from kornia.augmentation import RandomAffine, CenterCrop
4
-
5
-
6
- class FakeFakesGenerator:
7
- def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
8
- self.grad_aug = RandomAffine(degrees=360,
9
- translate=0.2,
10
- padding_mode=SamplePadding.REFLECTION,
11
- keepdim=False,
12
- p=1)
13
- self.img_aug = RandomAffine(degrees=img_aug_degree,
14
- translate=img_aug_translate,
15
- padding_mode=SamplePadding.REFLECTION,
16
- keepdim=True,
17
- p=1)
18
- self.aug_proba = aug_proba
19
-
20
- def __call__(self, input_images, masks):
21
- blend_masks = self._fill_masks_with_gradient(masks)
22
- blend_target = self._make_blend_target(input_images)
23
- result = input_images * (1 - blend_masks) + blend_target * blend_masks
24
- return result, blend_masks
25
-
26
- def _make_blend_target(self, input_images):
27
- batch_size = input_images.shape[0]
28
- permuted = input_images[torch.randperm(batch_size)]
29
- augmented = self.img_aug(input_images)
30
- is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()
31
- result = augmented * is_aug + permuted * (1 - is_aug)
32
- return result
33
-
34
- def _fill_masks_with_gradient(self, masks):
35
- batch_size, _, height, width = masks.shape
36
- grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
37
- .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
38
- grad = self.grad_aug(grad)
39
- grad = CenterCrop((height, width))(grad)
40
- grad *= masks
41
-
42
- grad_for_min = grad + (1 - masks) * 10
43
- grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]
44
- grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6
45
- grad.clamp_(min=0, max=1)
46
-
47
- return grad
annotator/lama/saicinpainting/training/modules/ffc.py DELETED
@@ -1,485 +0,0 @@
1
- # Fast Fourier Convolution NeurIPS 2020
2
- # original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
3
- # paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
4
-
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- from annotator.lama.saicinpainting.training.modules.base import get_activation, BaseDiscriminator
11
- from annotator.lama.saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper
12
- from annotator.lama.saicinpainting.training.modules.squeeze_excitation import SELayer
13
- from annotator.lama.saicinpainting.utils import get_shape
14
-
15
-
16
- class FFCSE_block(nn.Module):
17
-
18
- def __init__(self, channels, ratio_g):
19
- super(FFCSE_block, self).__init__()
20
- in_cg = int(channels * ratio_g)
21
- in_cl = channels - in_cg
22
- r = 16
23
-
24
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
25
- self.conv1 = nn.Conv2d(channels, channels // r,
26
- kernel_size=1, bias=True)
27
- self.relu1 = nn.ReLU(inplace=True)
28
- self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
29
- channels // r, in_cl, kernel_size=1, bias=True)
30
- self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
31
- channels // r, in_cg, kernel_size=1, bias=True)
32
- self.sigmoid = nn.Sigmoid()
33
-
34
- def forward(self, x):
35
- x = x if type(x) is tuple else (x, 0)
36
- id_l, id_g = x
37
-
38
- x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
39
- x = self.avgpool(x)
40
- x = self.relu1(self.conv1(x))
41
-
42
- x_l = 0 if self.conv_a2l is None else id_l * \
43
- self.sigmoid(self.conv_a2l(x))
44
- x_g = 0 if self.conv_a2g is None else id_g * \
45
- self.sigmoid(self.conv_a2g(x))
46
- return x_l, x_g
47
-
48
-
49
- class FourierUnit(nn.Module):
50
-
51
- def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
52
- spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
53
- # bn_layer not used
54
- super(FourierUnit, self).__init__()
55
- self.groups = groups
56
-
57
- self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
58
- out_channels=out_channels * 2,
59
- kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
60
- self.bn = torch.nn.BatchNorm2d(out_channels * 2)
61
- self.relu = torch.nn.ReLU(inplace=True)
62
-
63
- # squeeze and excitation block
64
- self.use_se = use_se
65
- if use_se:
66
- if se_kwargs is None:
67
- se_kwargs = {}
68
- self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
69
-
70
- self.spatial_scale_factor = spatial_scale_factor
71
- self.spatial_scale_mode = spatial_scale_mode
72
- self.spectral_pos_encoding = spectral_pos_encoding
73
- self.ffc3d = ffc3d
74
- self.fft_norm = fft_norm
75
-
76
- def forward(self, x):
77
- batch = x.shape[0]
78
-
79
- if self.spatial_scale_factor is not None:
80
- orig_size = x.shape[-2:]
81
- x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
82
-
83
- r_size = x.size()
84
- # (batch, c, h, w/2+1, 2)
85
- fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
86
- ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
87
- ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
88
- ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
89
- ffted = ffted.view((batch, -1,) + ffted.size()[3:])
90
-
91
- if self.spectral_pos_encoding:
92
- height, width = ffted.shape[-2:]
93
- coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
94
- coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
95
- ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
96
-
97
- if self.use_se:
98
- ffted = self.se(ffted)
99
-
100
- ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
101
- ffted = self.relu(self.bn(ffted))
102
-
103
- ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
104
- 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
105
- ffted = torch.complex(ffted[..., 0], ffted[..., 1])
106
-
107
- ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
108
- output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
109
-
110
- if self.spatial_scale_factor is not None:
111
- output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
112
-
113
- return output
114
-
115
-
116
- class SeparableFourierUnit(nn.Module):
117
-
118
- def __init__(self, in_channels, out_channels, groups=1, kernel_size=3):
119
- # bn_layer not used
120
- super(SeparableFourierUnit, self).__init__()
121
- self.groups = groups
122
- row_out_channels = out_channels // 2
123
- col_out_channels = out_channels - row_out_channels
124
- self.row_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
125
- out_channels=row_out_channels * 2,
126
- kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed
127
- stride=1, padding=(kernel_size // 2, 0),
128
- padding_mode='reflect',
129
- groups=self.groups, bias=False)
130
- self.col_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
131
- out_channels=col_out_channels * 2,
132
- kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed
133
- stride=1, padding=(kernel_size // 2, 0),
134
- padding_mode='reflect',
135
- groups=self.groups, bias=False)
136
- self.row_bn = torch.nn.BatchNorm2d(row_out_channels * 2)
137
- self.col_bn = torch.nn.BatchNorm2d(col_out_channels * 2)
138
- self.relu = torch.nn.ReLU(inplace=True)
139
-
140
- def process_branch(self, x, conv, bn):
141
- batch = x.shape[0]
142
-
143
- r_size = x.size()
144
- # (batch, c, h, w/2+1, 2)
145
- ffted = torch.fft.rfft(x, norm="ortho")
146
- ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
147
- ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
148
- ffted = ffted.view((batch, -1,) + ffted.size()[3:])
149
-
150
- ffted = self.relu(bn(conv(ffted)))
151
-
152
- ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
153
- 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
154
- ffted = torch.complex(ffted[..., 0], ffted[..., 1])
155
-
156
- output = torch.fft.irfft(ffted, s=x.shape[-1:], norm="ortho")
157
- return output
158
-
159
-
160
- def forward(self, x):
161
- rowwise = self.process_branch(x, self.row_conv, self.row_bn)
162
- colwise = self.process_branch(x.permute(0, 1, 3, 2), self.col_conv, self.col_bn).permute(0, 1, 3, 2)
163
- out = torch.cat((rowwise, colwise), dim=1)
164
- return out
165
-
166
-
167
- class SpectralTransform(nn.Module):
168
-
169
- def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, separable_fu=False, **fu_kwargs):
170
- # bn_layer not used
171
- super(SpectralTransform, self).__init__()
172
- self.enable_lfu = enable_lfu
173
- if stride == 2:
174
- self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
175
- else:
176
- self.downsample = nn.Identity()
177
-
178
- self.stride = stride
179
- self.conv1 = nn.Sequential(
180
- nn.Conv2d(in_channels, out_channels //
181
- 2, kernel_size=1, groups=groups, bias=False),
182
- nn.BatchNorm2d(out_channels // 2),
183
- nn.ReLU(inplace=True)
184
- )
185
- fu_class = SeparableFourierUnit if separable_fu else FourierUnit
186
- self.fu = fu_class(
187
- out_channels // 2, out_channels // 2, groups, **fu_kwargs)
188
- if self.enable_lfu:
189
- self.lfu = fu_class(
190
- out_channels // 2, out_channels // 2, groups)
191
- self.conv2 = torch.nn.Conv2d(
192
- out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
193
-
194
- def forward(self, x):
195
-
196
- x = self.downsample(x)
197
- x = self.conv1(x)
198
- output = self.fu(x)
199
-
200
- if self.enable_lfu:
201
- n, c, h, w = x.shape
202
- split_no = 2
203
- split_s = h // split_no
204
- xs = torch.cat(torch.split(
205
- x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
206
- xs = torch.cat(torch.split(xs, split_s, dim=-1),
207
- dim=1).contiguous()
208
- xs = self.lfu(xs)
209
- xs = xs.repeat(1, 1, split_no, split_no).contiguous()
210
- else:
211
- xs = 0
212
-
213
- output = self.conv2(x + output + xs)
214
-
215
- return output
216
-
217
-
218
- class FFC(nn.Module):
219
-
220
- def __init__(self, in_channels, out_channels, kernel_size,
221
- ratio_gin, ratio_gout, stride=1, padding=0,
222
- dilation=1, groups=1, bias=False, enable_lfu=True,
223
- padding_type='reflect', gated=False, **spectral_kwargs):
224
- super(FFC, self).__init__()
225
-
226
- assert stride == 1 or stride == 2, "Stride should be 1 or 2."
227
- self.stride = stride
228
-
229
- in_cg = int(in_channels * ratio_gin)
230
- in_cl = in_channels - in_cg
231
- out_cg = int(out_channels * ratio_gout)
232
- out_cl = out_channels - out_cg
233
- #groups_g = 1 if groups == 1 else int(groups * ratio_gout)
234
- #groups_l = 1 if groups == 1 else groups - groups_g
235
-
236
- self.ratio_gin = ratio_gin
237
- self.ratio_gout = ratio_gout
238
- self.global_in_num = in_cg
239
-
240
- module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
241
- self.convl2l = module(in_cl, out_cl, kernel_size,
242
- stride, padding, dilation, groups, bias, padding_mode=padding_type)
243
- module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
244
- self.convl2g = module(in_cl, out_cg, kernel_size,
245
- stride, padding, dilation, groups, bias, padding_mode=padding_type)
246
- module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
247
- self.convg2l = module(in_cg, out_cl, kernel_size,
248
- stride, padding, dilation, groups, bias, padding_mode=padding_type)
249
- module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
250
- self.convg2g = module(
251
- in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
252
-
253
- self.gated = gated
254
- module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
255
- self.gate = module(in_channels, 2, 1)
256
-
257
- def forward(self, x):
258
- x_l, x_g = x if type(x) is tuple else (x, 0)
259
- out_xl, out_xg = 0, 0
260
-
261
- if self.gated:
262
- total_input_parts = [x_l]
263
- if torch.is_tensor(x_g):
264
- total_input_parts.append(x_g)
265
- total_input = torch.cat(total_input_parts, dim=1)
266
-
267
- gates = torch.sigmoid(self.gate(total_input))
268
- g2l_gate, l2g_gate = gates.chunk(2, dim=1)
269
- else:
270
- g2l_gate, l2g_gate = 1, 1
271
-
272
- if self.ratio_gout != 1:
273
- out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
274
- if self.ratio_gout != 0:
275
- out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
276
-
277
- return out_xl, out_xg
278
-
279
-
280
- class FFC_BN_ACT(nn.Module):
281
-
282
- def __init__(self, in_channels, out_channels,
283
- kernel_size, ratio_gin, ratio_gout,
284
- stride=1, padding=0, dilation=1, groups=1, bias=False,
285
- norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
286
- padding_type='reflect',
287
- enable_lfu=True, **kwargs):
288
- super(FFC_BN_ACT, self).__init__()
289
- self.ffc = FFC(in_channels, out_channels, kernel_size,
290
- ratio_gin, ratio_gout, stride, padding, dilation,
291
- groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
292
- lnorm = nn.Identity if ratio_gout == 1 else norm_layer
293
- gnorm = nn.Identity if ratio_gout == 0 else norm_layer
294
- global_channels = int(out_channels * ratio_gout)
295
- self.bn_l = lnorm(out_channels - global_channels)
296
- self.bn_g = gnorm(global_channels)
297
-
298
- lact = nn.Identity if ratio_gout == 1 else activation_layer
299
- gact = nn.Identity if ratio_gout == 0 else activation_layer
300
- self.act_l = lact(inplace=True)
301
- self.act_g = gact(inplace=True)
302
-
303
- def forward(self, x):
304
- x_l, x_g = self.ffc(x)
305
- x_l = self.act_l(self.bn_l(x_l))
306
- x_g = self.act_g(self.bn_g(x_g))
307
- return x_l, x_g
308
-
309
-
310
- class FFCResnetBlock(nn.Module):
311
- def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
312
- spatial_transform_kwargs=None, inline=False, **conv_kwargs):
313
- super().__init__()
314
- self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
315
- norm_layer=norm_layer,
316
- activation_layer=activation_layer,
317
- padding_type=padding_type,
318
- **conv_kwargs)
319
- self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
320
- norm_layer=norm_layer,
321
- activation_layer=activation_layer,
322
- padding_type=padding_type,
323
- **conv_kwargs)
324
- if spatial_transform_kwargs is not None:
325
- self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
326
- self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
327
- self.inline = inline
328
-
329
- def forward(self, x):
330
- if self.inline:
331
- x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
332
- else:
333
- x_l, x_g = x if type(x) is tuple else (x, 0)
334
-
335
- id_l, id_g = x_l, x_g
336
-
337
- x_l, x_g = self.conv1((x_l, x_g))
338
- x_l, x_g = self.conv2((x_l, x_g))
339
-
340
- x_l, x_g = id_l + x_l, id_g + x_g
341
- out = x_l, x_g
342
- if self.inline:
343
- out = torch.cat(out, dim=1)
344
- return out
345
-
346
-
347
- class ConcatTupleLayer(nn.Module):
348
- def forward(self, x):
349
- assert isinstance(x, tuple)
350
- x_l, x_g = x
351
- assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
352
- if not torch.is_tensor(x_g):
353
- return x_l
354
- return torch.cat(x, dim=1)
355
-
356
-
357
- class FFCResNetGenerator(nn.Module):
358
- def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
359
- padding_type='reflect', activation_layer=nn.ReLU,
360
- up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
361
- init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
362
- spatial_transform_layers=None, spatial_transform_kwargs={},
363
- add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
364
- assert (n_blocks >= 0)
365
- super().__init__()
366
-
367
- model = [nn.ReflectionPad2d(3),
368
- FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
369
- activation_layer=activation_layer, **init_conv_kwargs)]
370
-
371
- ### downsample
372
- for i in range(n_downsampling):
373
- mult = 2 ** i
374
- if i == n_downsampling - 1:
375
- cur_conv_kwargs = dict(downsample_conv_kwargs)
376
- cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
377
- else:
378
- cur_conv_kwargs = downsample_conv_kwargs
379
- model += [FFC_BN_ACT(min(max_features, ngf * mult),
380
- min(max_features, ngf * mult * 2),
381
- kernel_size=3, stride=2, padding=1,
382
- norm_layer=norm_layer,
383
- activation_layer=activation_layer,
384
- **cur_conv_kwargs)]
385
-
386
- mult = 2 ** n_downsampling
387
- feats_num_bottleneck = min(max_features, ngf * mult)
388
-
389
- ### resnet blocks
390
- for i in range(n_blocks):
391
- cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer,
392
- norm_layer=norm_layer, **resnet_conv_kwargs)
393
- if spatial_transform_layers is not None and i in spatial_transform_layers:
394
- cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs)
395
- model += [cur_resblock]
396
-
397
- model += [ConcatTupleLayer()]
398
-
399
- ### upsample
400
- for i in range(n_downsampling):
401
- mult = 2 ** (n_downsampling - i)
402
- model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
403
- min(max_features, int(ngf * mult / 2)),
404
- kernel_size=3, stride=2, padding=1, output_padding=1),
405
- up_norm_layer(min(max_features, int(ngf * mult / 2))),
406
- up_activation]
407
-
408
- if out_ffc:
409
- model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
410
- norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]
411
-
412
- model += [nn.ReflectionPad2d(3),
413
- nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
414
- if add_out_act:
415
- model.append(get_activation('tanh' if add_out_act is True else add_out_act))
416
- self.model = nn.Sequential(*model)
417
-
418
- def forward(self, input):
419
- return self.model(input)
420
-
421
-
422
- class FFCNLayerDiscriminator(BaseDiscriminator):
423
- def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, max_features=512,
424
- init_conv_kwargs={}, conv_kwargs={}):
425
- super().__init__()
426
- self.n_layers = n_layers
427
-
428
- def _act_ctor(inplace=True):
429
- return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
430
-
431
- kw = 3
432
- padw = int(np.ceil((kw-1.0)/2))
433
- sequence = [[FFC_BN_ACT(input_nc, ndf, kernel_size=kw, padding=padw, norm_layer=norm_layer,
434
- activation_layer=_act_ctor, **init_conv_kwargs)]]
435
-
436
- nf = ndf
437
- for n in range(1, n_layers):
438
- nf_prev = nf
439
- nf = min(nf * 2, max_features)
440
-
441
- cur_model = [
442
- FFC_BN_ACT(nf_prev, nf,
443
- kernel_size=kw, stride=2, padding=padw,
444
- norm_layer=norm_layer,
445
- activation_layer=_act_ctor,
446
- **conv_kwargs)
447
- ]
448
- sequence.append(cur_model)
449
-
450
- nf_prev = nf
451
- nf = min(nf * 2, 512)
452
-
453
- cur_model = [
454
- FFC_BN_ACT(nf_prev, nf,
455
- kernel_size=kw, stride=1, padding=padw,
456
- norm_layer=norm_layer,
457
- activation_layer=lambda *args, **kwargs: nn.LeakyReLU(*args, negative_slope=0.2, **kwargs),
458
- **conv_kwargs),
459
- ConcatTupleLayer()
460
- ]
461
- sequence.append(cur_model)
462
-
463
- sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
464
-
465
- for n in range(len(sequence)):
466
- setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
467
-
468
- def get_all_activations(self, x):
469
- res = [x]
470
- for n in range(self.n_layers + 2):
471
- model = getattr(self, 'model' + str(n))
472
- res.append(model(res[-1]))
473
- return res[1:]
474
-
475
- def forward(self, x):
476
- act = self.get_all_activations(x)
477
- feats = []
478
- for out in act[:-1]:
479
- if isinstance(out, tuple):
480
- if torch.is_tensor(out[1]):
481
- out = torch.cat(out, dim=1)
482
- else:
483
- out = out[0]
484
- feats.append(out)
485
- return act[-1], feats
annotator/lama/saicinpainting/training/modules/multidilated_conv.py DELETED
@@ -1,98 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import random
4
- from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
5
-
6
- class MultidilatedConv(nn.Module):
7
- def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
8
- shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
9
- super().__init__()
10
- convs = []
11
- self.equal_dim = equal_dim
12
- assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
13
- if comb_mode in ('cat_out', 'cat_both'):
14
- self.cat_out = True
15
- if equal_dim:
16
- assert out_dim % dilation_num == 0
17
- out_dims = [out_dim // dilation_num] * dilation_num
18
- self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
19
- else:
20
- out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
21
- out_dims.append(out_dim - sum(out_dims))
22
- index = []
23
- starts = [0] + out_dims[:-1]
24
- lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
25
- for i in range(out_dims[-1]):
26
- for j in range(dilation_num):
27
- index += list(range(starts[j], starts[j] + lengths[j]))
28
- starts[j] += lengths[j]
29
- self.index = index
30
- assert(len(index) == out_dim)
31
- self.out_dims = out_dims
32
- else:
33
- self.cat_out = False
34
- self.out_dims = [out_dim] * dilation_num
35
-
36
- if comb_mode in ('cat_in', 'cat_both'):
37
- if equal_dim:
38
- assert in_dim % dilation_num == 0
39
- in_dims = [in_dim // dilation_num] * dilation_num
40
- else:
41
- in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
42
- in_dims.append(in_dim - sum(in_dims))
43
- self.in_dims = in_dims
44
- self.cat_in = True
45
- else:
46
- self.cat_in = False
47
- self.in_dims = [in_dim] * dilation_num
48
-
49
- conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
50
- dilation = min_dilation
51
- for i in range(dilation_num):
52
- if isinstance(padding, int):
53
- cur_padding = padding * dilation
54
- else:
55
- cur_padding = padding[i]
56
- convs.append(conv_type(
57
- self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
58
- ))
59
- if i > 0 and shared_weights:
60
- convs[-1].weight = convs[0].weight
61
- convs[-1].bias = convs[0].bias
62
- dilation *= 2
63
- self.convs = nn.ModuleList(convs)
64
-
65
- self.shuffle_in_channels = shuffle_in_channels
66
- if self.shuffle_in_channels:
67
- # shuffle list as shuffling of tensors is nondeterministic
68
- in_channels_permute = list(range(in_dim))
69
- random.shuffle(in_channels_permute)
70
- # save as buffer so it is saved and loaded with checkpoint
71
- self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))
72
-
73
- def forward(self, x):
74
- if self.shuffle_in_channels:
75
- x = x[:, self.in_channels_permute]
76
-
77
- outs = []
78
- if self.cat_in:
79
- if self.equal_dim:
80
- x = x.chunk(len(self.convs), dim=1)
81
- else:
82
- new_x = []
83
- start = 0
84
- for dim in self.in_dims:
85
- new_x.append(x[:, start:start+dim])
86
- start += dim
87
- x = new_x
88
- for i, conv in enumerate(self.convs):
89
- if self.cat_in:
90
- input = x[i]
91
- else:
92
- input = x
93
- outs.append(conv(input))
94
- if self.cat_out:
95
- out = torch.cat(outs, dim=1)[:, self.index]
96
- else:
97
- out = sum(outs)
98
- return out
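A minimal usage sketch for MultidilatedConv as defined above (tensor shapes are assumptions): with the default comb_mode='sum' the module runs dilation_num parallel 3x3 branches at dilations 1, 2 and 4 and sums their outputs, and padding=1 scaled by each dilation keeps the spatial size unchanged.

import torch

conv = MultidilatedConv(in_dim=16, out_dim=32, kernel_size=3)  # dilation_num=3, comb_mode='sum'
x = torch.randn(2, 16, 64, 64)
y = conv(x)  # (2, 32, 64, 64): three dilated branches summed channel-wise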
annotator/lama/saicinpainting/training/modules/multiscale.py DELETED
@@ -1,244 +0,0 @@
1
- from typing import List, Tuple, Union, Optional
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
- from annotator.lama.saicinpainting.training.modules.base import get_conv_block_ctor, get_activation
8
- from annotator.lama.saicinpainting.training.modules.pix2pixhd import ResnetBlock
9
-
10
-
11
- class ResNetHead(nn.Module):
12
- def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
13
- padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)):
14
- assert (n_blocks >= 0)
15
- super(ResNetHead, self).__init__()
16
-
17
- conv_layer = get_conv_block_ctor(conv_kind)
18
-
19
- model = [nn.ReflectionPad2d(3),
20
- conv_layer(input_nc, ngf, kernel_size=7, padding=0),
21
- norm_layer(ngf),
22
- activation]
23
-
24
- ### downsample
25
- for i in range(n_downsampling):
26
- mult = 2 ** i
27
- model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
28
- norm_layer(ngf * mult * 2),
29
- activation]
30
-
31
- mult = 2 ** n_downsampling
32
-
33
- ### resnet blocks
34
- for i in range(n_blocks):
35
- model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
36
- conv_kind=conv_kind)]
37
-
38
- self.model = nn.Sequential(*model)
39
-
40
- def forward(self, input):
41
- return self.model(input)
42
-
43
-
44
- class ResNetTail(nn.Module):
45
- def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
46
- padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
47
- up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
48
- add_in_proj=None):
49
- assert (n_blocks >= 0)
50
- super(ResNetTail, self).__init__()
51
-
52
- mult = 2 ** n_downsampling
53
-
54
- model = []
55
-
56
- if add_in_proj is not None:
57
- model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1))
58
-
59
- ### resnet blocks
60
- for i in range(n_blocks):
61
- model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
62
- conv_kind=conv_kind)]
63
-
64
- ### upsample
65
- for i in range(n_downsampling):
66
- mult = 2 ** (n_downsampling - i)
67
- model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
68
- output_padding=1),
69
- up_norm_layer(int(ngf * mult / 2)),
70
- up_activation]
71
- self.model = nn.Sequential(*model)
72
-
73
- out_layers = []
74
- for _ in range(out_extra_layers_n):
75
- out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0),
76
- up_norm_layer(ngf),
77
- up_activation]
78
- out_layers += [nn.ReflectionPad2d(3),
79
- nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
80
-
81
- if add_out_act:
82
- out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act))
83
-
84
- self.out_proj = nn.Sequential(*out_layers)
85
-
86
- def forward(self, input, return_last_act=False):
87
- features = self.model(input)
88
- out = self.out_proj(features)
89
- if return_last_act:
90
- return out, features
91
- else:
92
- return out
93
-
94
-
95
- class MultiscaleResNet(nn.Module):
96
- def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3,
97
- norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
98
- up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
99
- out_cumulative=False, return_only_hr=False):
100
- super().__init__()
101
-
102
- self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling,
103
- n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type,
104
- conv_kind=conv_kind, activation=activation)
105
- for i in range(n_scales)])
106
- tail_in_feats = ngf * (2 ** n_downsampling) + ngf
107
- self.tails = nn.ModuleList([ResNetTail(output_nc,
108
- ngf=ngf, n_downsampling=n_downsampling,
109
- n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type,
110
- conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer,
111
- up_activation=up_activation, add_out_act=add_out_act,
112
- out_extra_layers_n=out_extra_layers_n,
113
- add_in_proj=None if (i == n_scales - 1) else tail_in_feats)
114
- for i in range(n_scales)])
115
-
116
- self.out_cumulative = out_cumulative
117
- self.return_only_hr = return_only_hr
118
-
119
- @property
120
- def num_scales(self):
121
- return len(self.heads)
122
-
123
- def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
124
- -> Union[torch.Tensor, List[torch.Tensor]]:
125
- """
126
- :param ms_inputs: List of inputs of different resolutions from HR to LR
127
- :param smallest_scales_num: int or None, number of smallest scales to take at input
128
- :return: Depending on return_only_hr:
129
- True: Only the most HR output
130
- False: List of outputs of different resolutions from HR to LR
131
- """
132
- if smallest_scales_num is None:
133
- assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num)
134
- smallest_scales_num = len(self.heads)
135
- else:
136
- assert smallest_scales_num == len(ms_inputs) <= len(self.heads), (len(self.heads), len(ms_inputs), smallest_scales_num)
137
-
138
- cur_heads = self.heads[-smallest_scales_num:]
139
- ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)]
140
-
141
- all_outputs = []
142
- prev_tail_features = None
143
- for i in range(len(ms_features)):
144
- scale_i = -i - 1
145
-
146
- cur_tail_input = ms_features[-i - 1]
147
- if prev_tail_features is not None:
148
- if prev_tail_features.shape != cur_tail_input.shape:
149
- prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:],
150
- mode='bilinear', align_corners=False)
151
- cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1)
152
-
153
- cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True)
154
-
155
- prev_tail_features = cur_tail_feats
156
- all_outputs.append(cur_out)
157
-
158
- if self.out_cumulative:
159
- all_outputs_cum = [all_outputs[0]]
160
- for i in range(1, len(ms_features)):
161
- cur_out = all_outputs[i]
162
- cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:],
163
- mode='bilinear', align_corners=False)
164
- all_outputs_cum.append(cur_out_cum)
165
- all_outputs = all_outputs_cum
166
-
167
- if self.return_only_hr:
168
- return all_outputs[-1]
169
- else:
170
- return all_outputs[::-1]
171
-
172
-
173
- class MultiscaleDiscriminatorSimple(nn.Module):
174
- def __init__(self, ms_impl):
175
- super().__init__()
176
- self.ms_impl = nn.ModuleList(ms_impl)
177
-
178
- @property
179
- def num_scales(self):
180
- return len(self.ms_impl)
181
-
182
- def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
183
- -> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
184
- """
185
- :param ms_inputs: List of inputs of different resolutions from HR to LR
186
- :param smallest_scales_num: int or None, number of smallest scales to take at input
187
- :return: List of pairs (prediction, features) for different resolutions from HR to LR
188
- """
189
- if smallest_scales_num is None:
190
- assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
191
- smallest_scales_num = len(self.ms_impl)
192
- else:
193
- assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \
194
- (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
195
-
196
- return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)]
197
-
198
-
199
- class SingleToMultiScaleInputMixin:
200
- def forward(self, x: torch.Tensor) -> List:
201
- orig_height, orig_width = x.shape[2:]
202
- factors = [2 ** i for i in range(self.num_scales)]
203
- ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False)
204
- for f in factors]
205
- return super().forward(ms_inputs)
206
-
207
-
208
- class GeneratorMultiToSingleOutputMixin:
209
- def forward(self, x):
210
- return super().forward(x)[0]
211
-
212
-
213
- class DiscriminatorMultiToSingleOutputMixin:
214
- def forward(self, x):
215
- out_feat_tuples = super().forward(x)
216
- return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist]
217
-
218
-
219
- class DiscriminatorMultiToSingleOutputStackedMixin:
220
- def __init__(self, *args, return_feats_only_levels=None, **kwargs):
221
- super().__init__(*args, **kwargs)
222
- self.return_feats_only_levels = return_feats_only_levels
223
-
224
- def forward(self, x):
225
- out_feat_tuples = super().forward(x)
226
- outs = [out for out, _ in out_feat_tuples]
227
- scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:],
228
- mode='bilinear', align_corners=False)
229
- for cur_out in outs[1:]]
230
- out = torch.cat(scaled_outs, dim=1)
231
- if self.return_feats_only_levels is not None:
232
- feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels]
233
- else:
234
- feat_lists = [flist for _, flist in out_feat_tuples]
235
- feats = [f for flist in feat_lists for f in flist]
236
- return out, feats
237
-
238
-
239
- class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple):
240
- pass
241
-
242
-
243
- class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet):
244
- pass
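A minimal sketch of driving MultiscaleResNet with a three-level pyramid (shapes are assumptions): each finer tail consumes its own head's features concatenated with the upsampled features of the next-coarser tail, and the outputs come back ordered from high to low resolution.

import torch

net = MultiscaleResNet(input_nc=3, output_nc=3, n_scales=3)
pyramid = [torch.randn(1, 3, 256, 256),  # HR
           torch.randn(1, 3, 128, 128),
           torch.randn(1, 3, 64, 64)]    # LR
outs = net(pyramid)  # list of 3 tensors, HR -> LR, same spatial sizes as the inputs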
annotator/lama/saicinpainting/training/modules/pix2pixhd.py DELETED
@@ -1,669 +0,0 @@
1
- # original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
2
- import collections
3
- from functools import partial
4
- import functools
5
- import logging
6
- from collections import defaultdict
7
-
8
- import numpy as np
9
- import torch.nn as nn
10
-
11
- from annotator.lama.saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
12
- from annotator.lama.saicinpainting.training.modules.ffc import FFCResnetBlock
13
- from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
14
-
15
- class DotDict(defaultdict):
16
- # https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
17
- """dot.notation access to dictionary attributes"""
18
- __getattr__ = defaultdict.get
19
- __setattr__ = defaultdict.__setitem__
20
- __delattr__ = defaultdict.__delitem__
21
-
22
- class Identity(nn.Module):
23
- def __init__(self):
24
- super().__init__()
25
-
26
- def forward(self, x):
27
- return x
28
-
29
-
30
- class ResnetBlock(nn.Module):
31
- def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
32
- dilation=1, in_dim=None, groups=1, second_dilation=None):
33
- super(ResnetBlock, self).__init__()
34
- self.in_dim = in_dim
35
- self.dim = dim
36
- if second_dilation is None:
37
- second_dilation = dilation
38
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
39
- conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
40
- second_dilation=second_dilation)
41
-
42
- if self.in_dim is not None:
43
- self.input_conv = nn.Conv2d(in_dim, dim, 1)
44
-
45
- self.out_channnels = dim
46
-
47
- def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
48
- dilation=1, in_dim=None, groups=1, second_dilation=1):
49
- conv_layer = get_conv_block_ctor(conv_kind)
50
-
51
- conv_block = []
52
- p = 0
53
- if padding_type == 'reflect':
54
- conv_block += [nn.ReflectionPad2d(dilation)]
55
- elif padding_type == 'replicate':
56
- conv_block += [nn.ReplicationPad2d(dilation)]
57
- elif padding_type == 'zero':
58
- p = dilation
59
- else:
60
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
61
-
62
- if in_dim is None:
63
- in_dim = dim
64
-
65
- conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
66
- norm_layer(dim),
67
- activation]
68
- if use_dropout:
69
- conv_block += [nn.Dropout(0.5)]
70
-
71
- p = 0
72
- if padding_type == 'reflect':
73
- conv_block += [nn.ReflectionPad2d(second_dilation)]
74
- elif padding_type == 'replicate':
75
- conv_block += [nn.ReplicationPad2d(second_dilation)]
76
- elif padding_type == 'zero':
77
- p = second_dilation
78
- else:
79
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
80
- conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups),
81
- norm_layer(dim)]
82
-
83
- return nn.Sequential(*conv_block)
84
-
85
- def forward(self, x):
86
- x_before = x
87
- if self.in_dim is not None:
88
- x = self.input_conv(x)
89
- out = x + self.conv_block(x_before)
90
- return out
91
-
92
- class ResnetBlock5x5(nn.Module):
93
- def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
94
- dilation=1, in_dim=None, groups=1, second_dilation=None):
95
- super(ResnetBlock5x5, self).__init__()
96
- self.in_dim = in_dim
97
- self.dim = dim
98
- if second_dilation is None:
99
- second_dilation = dilation
100
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
101
- conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
102
- second_dilation=second_dilation)
103
-
104
- if self.in_dim is not None:
105
- self.input_conv = nn.Conv2d(in_dim, dim, 1)
106
-
107
- self.out_channnels = dim
108
-
109
- def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
110
- dilation=1, in_dim=None, groups=1, second_dilation=1):
111
- conv_layer = get_conv_block_ctor(conv_kind)
112
-
113
- conv_block = []
114
- p = 0
115
- if padding_type == 'reflect':
116
- conv_block += [nn.ReflectionPad2d(dilation * 2)]
117
- elif padding_type == 'replicate':
118
- conv_block += [nn.ReplicationPad2d(dilation * 2)]
119
- elif padding_type == 'zero':
120
- p = dilation * 2
121
- else:
122
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
123
-
124
- if in_dim is None:
125
- in_dim = dim
126
-
127
- conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
128
- norm_layer(dim),
129
- activation]
130
- if use_dropout:
131
- conv_block += [nn.Dropout(0.5)]
132
-
133
- p = 0
134
- if padding_type == 'reflect':
135
- conv_block += [nn.ReflectionPad2d(second_dilation * 2)]
136
- elif padding_type == 'replicate':
137
- conv_block += [nn.ReplicationPad2d(second_dilation * 2)]
138
- elif padding_type == 'zero':
139
- p = second_dilation * 2
140
- else:
141
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
142
- conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
143
- norm_layer(dim)]
144
-
145
- return nn.Sequential(*conv_block)
146
-
147
- def forward(self, x):
148
- x_before = x
149
- if self.in_dim is not None:
150
- x = self.input_conv(x)
151
- out = x + self.conv_block(x_before)
152
- return out
153
-
154
-
155
- class MultidilatedResnetBlock(nn.Module):
156
- def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
157
- super().__init__()
158
- self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)
159
-
160
- def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
161
- conv_block = []
162
- conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
163
- norm_layer(dim),
164
- activation]
165
- if use_dropout:
166
- conv_block += [nn.Dropout(0.5)]
167
-
168
- conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
169
- norm_layer(dim)]
170
-
171
- return nn.Sequential(*conv_block)
172
-
173
- def forward(self, x):
174
- out = x + self.conv_block(x)
175
- return out
176
-
177
-
178
- class MultiDilatedGlobalGenerator(nn.Module):
179
- def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
180
- n_blocks=3, norm_layer=nn.BatchNorm2d,
181
- padding_type='reflect', conv_kind='default',
182
- deconv_kind='convtranspose', activation=nn.ReLU(True),
183
- up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
184
- add_out_act=True, max_features=1024, multidilation_kwargs={},
185
- ffc_positions=None, ffc_kwargs={}):
186
- assert (n_blocks >= 0)
187
- super().__init__()
188
-
189
- conv_layer = get_conv_block_ctor(conv_kind)
190
- resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
191
- norm_layer = get_norm_layer(norm_layer)
192
- if affine is not None:
193
- norm_layer = partial(norm_layer, affine=affine)
194
- up_norm_layer = get_norm_layer(up_norm_layer)
195
- if affine is not None:
196
- up_norm_layer = partial(up_norm_layer, affine=affine)
197
-
198
- model = [nn.ReflectionPad2d(3),
199
- conv_layer(input_nc, ngf, kernel_size=7, padding=0),
200
- norm_layer(ngf),
201
- activation]
202
-
203
- identity = Identity()
204
- ### downsample
205
- for i in range(n_downsampling):
206
- mult = 2 ** i
207
-
208
- model += [conv_layer(min(max_features, ngf * mult),
209
- min(max_features, ngf * mult * 2),
210
- kernel_size=3, stride=2, padding=1),
211
- norm_layer(min(max_features, ngf * mult * 2)),
212
- activation]
213
-
214
- mult = 2 ** n_downsampling
215
- feats_num_bottleneck = min(max_features, ngf * mult)
216
-
217
- ### resnet blocks
218
- for i in range(n_blocks):
219
- if ffc_positions is not None and i in ffc_positions:
220
- model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
221
- inline=True, **ffc_kwargs)]
222
- model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
223
- conv_layer=resnet_conv_layer, activation=activation,
224
- norm_layer=norm_layer)]
225
-
226
- ### upsample
227
- for i in range(n_downsampling):
228
- mult = 2 ** (n_downsampling - i)
229
- model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
230
- model += [nn.ReflectionPad2d(3),
231
- nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
232
- if add_out_act:
233
- model.append(get_activation('tanh' if add_out_act is True else add_out_act))
234
- self.model = nn.Sequential(*model)
235
-
236
- def forward(self, input):
237
- return self.model(input)
238
-
239
- class ConfigGlobalGenerator(nn.Module):
240
- def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
241
- n_blocks=3, norm_layer=nn.BatchNorm2d,
242
- padding_type='reflect', conv_kind='default',
243
- deconv_kind='convtranspose', activation=nn.ReLU(True),
244
- up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
245
- add_out_act=True, max_features=1024,
246
- manual_block_spec=[],
247
- resnet_block_kind='multidilatedresnetblock',
248
- resnet_conv_kind='multidilated',
249
- resnet_dilation=1,
250
- multidilation_kwargs={}):
251
- assert (n_blocks >= 0)
252
- super().__init__()
253
-
254
- conv_layer = get_conv_block_ctor(conv_kind)
255
- resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
256
- norm_layer = get_norm_layer(norm_layer)
257
- if affine is not None:
258
- norm_layer = partial(norm_layer, affine=affine)
259
- up_norm_layer = get_norm_layer(up_norm_layer)
260
- if affine is not None:
261
- up_norm_layer = partial(up_norm_layer, affine=affine)
262
-
263
- model = [nn.ReflectionPad2d(3),
264
- conv_layer(input_nc, ngf, kernel_size=7, padding=0),
265
- norm_layer(ngf),
266
- activation]
267
-
268
- identity = Identity()
269
-
270
- ### downsample
271
- for i in range(n_downsampling):
272
- mult = 2 ** i
273
- model += [conv_layer(min(max_features, ngf * mult),
274
- min(max_features, ngf * mult * 2),
275
- kernel_size=3, stride=2, padding=1),
276
- norm_layer(min(max_features, ngf * mult * 2)),
277
- activation]
278
-
279
- mult = 2 ** n_downsampling
280
- feats_num_bottleneck = min(max_features, ngf * mult)
281
-
282
- if len(manual_block_spec) == 0:
283
- manual_block_spec = [
284
- DotDict(lambda : None, {
285
- 'n_blocks': n_blocks,
286
- 'use_default': True})
287
- ]
288
-
289
- ### resnet blocks
290
- for block_spec in manual_block_spec:
291
- def make_and_add_blocks(model, block_spec):
292
- block_spec = DotDict(lambda : None, block_spec)
293
- if not block_spec.use_default:
294
- resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind), **block_spec.multidilation_kwargs)
295
- resnet_conv_kind = block_spec.resnet_conv_kind
296
- resnet_block_kind = block_spec.resnet_block_kind
297
- if block_spec.resnet_dilation is not None:
298
- resnet_dilation = block_spec.resnet_dilation
299
- for i in range(block_spec.n_blocks):
300
- if resnet_block_kind == "multidilatedresnetblock":
301
- model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
302
- conv_layer=resnet_conv_layer, activation=activation,
303
- norm_layer=norm_layer)]
304
- if resnet_block_kind == "resnetblock":
305
- model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
306
- conv_kind=resnet_conv_kind)]
307
- if resnet_block_kind == "resnetblock5x5":
308
- model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
309
- conv_kind=resnet_conv_kind)]
310
- if resnet_block_kind == "resnetblockdwdil":
311
- model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
312
- conv_kind=resnet_conv_kind, dilation=resnet_dilation, second_dilation=resnet_dilation)]
313
- make_and_add_blocks(model, block_spec)
314
-
315
- ### upsample
316
- for i in range(n_downsampling):
317
- mult = 2 ** (n_downsampling - i)
318
- model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
319
- model += [nn.ReflectionPad2d(3),
320
- nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
321
- if add_out_act:
322
- model.append(get_activation('tanh' if add_out_act is True else add_out_act))
323
- self.model = nn.Sequential(*model)
324
-
325
- def forward(self, input):
326
- return self.model(input)
327
-
328
-
329
- def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
330
- blocks = []
331
- for i in range(dilated_blocks_n):
332
- if dilation_block_kind == 'simple':
333
- blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1)))
334
- elif dilation_block_kind == 'multi':
335
- blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs))
336
- else:
337
- raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"')
338
- return blocks
339
-
340
-
341
- class GlobalGenerator(nn.Module):
342
- def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
343
- padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
344
- up_norm_layer=nn.BatchNorm2d, affine=None,
345
- up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
346
- dilated_blocks_n_middle=0,
347
- add_out_act=True,
348
- max_features=1024, is_resblock_depthwise=False,
349
- ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None,
350
- dilation_block_kind='simple', multidilation_kwargs={}):
351
- assert (n_blocks >= 0)
352
- super().__init__()
353
-
354
- conv_layer = get_conv_block_ctor(conv_kind)
355
- norm_layer = get_norm_layer(norm_layer)
356
- if affine is not None:
357
- norm_layer = partial(norm_layer, affine=affine)
358
- up_norm_layer = get_norm_layer(up_norm_layer)
359
- if affine is not None:
360
- up_norm_layer = partial(up_norm_layer, affine=affine)
361
-
362
- if ffc_positions is not None:
363
- ffc_positions = collections.Counter(ffc_positions)
364
-
365
- model = [nn.ReflectionPad2d(3),
366
- conv_layer(input_nc, ngf, kernel_size=7, padding=0),
367
- norm_layer(ngf),
368
- activation]
369
-
370
- identity = Identity()
371
- ### downsample
372
- for i in range(n_downsampling):
373
- mult = 2 ** i
374
-
375
- model += [conv_layer(min(max_features, ngf * mult),
376
- min(max_features, ngf * mult * 2),
377
- kernel_size=3, stride=2, padding=1),
378
- norm_layer(min(max_features, ngf * mult * 2)),
379
- activation]
380
-
381
- mult = 2 ** n_downsampling
382
- feats_num_bottleneck = min(max_features, ngf * mult)
383
-
384
- dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
385
- activation=activation, norm_layer=norm_layer)
386
- if dilation_block_kind == 'simple':
387
- dilated_block_kwargs['conv_kind'] = conv_kind
388
- elif dilation_block_kind == 'multi':
389
- dilated_block_kwargs['conv_layer'] = functools.partial(
390
- get_conv_block_ctor('multidilated'), **multidilation_kwargs)
391
-
392
- # dilated blocks at the start of the bottleneck sausage
393
- if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
394
- model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)
395
-
396
- # resnet blocks
397
- for i in range(n_blocks):
398
- # dilated blocks at the middle of the bottleneck sausage
399
- if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
400
- model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)
401
-
402
- if ffc_positions is not None and i in ffc_positions:
403
- for _ in range(ffc_positions[i]): # same position can occur more than once
404
- model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
405
- inline=True, **ffc_kwargs)]
406
-
407
- if is_resblock_depthwise:
408
- resblock_groups = feats_num_bottleneck
409
- else:
410
- resblock_groups = 1
411
-
412
- model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
413
- norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
414
- dilation=dilation, second_dilation=second_dilation)]
415
-
416
-
417
- # dilated blocks at the end of the bottleneck sausage
418
- if dilated_blocks_n is not None and dilated_blocks_n > 0:
419
- model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)
420
-
421
- # upsample
422
- for i in range(n_downsampling):
423
- mult = 2 ** (n_downsampling - i)
424
- model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
425
- min(max_features, int(ngf * mult / 2)),
426
- kernel_size=3, stride=2, padding=1, output_padding=1),
427
- up_norm_layer(min(max_features, int(ngf * mult / 2))),
428
- up_activation]
429
- model += [nn.ReflectionPad2d(3),
430
- nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
431
- if add_out_act:
432
- model.append(get_activation('tanh' if add_out_act is True else add_out_act))
433
- self.model = nn.Sequential(*model)
434
-
435
- def forward(self, input):
436
- return self.model(input)
437
-
438
-
439
- class GlobalGeneratorGated(GlobalGenerator):
440
- def __init__(self, *args, **kwargs):
441
- real_kwargs=dict(
442
- conv_kind='gated_bn_relu',
443
- activation=nn.Identity(),
444
- norm_layer=nn.Identity
445
- )
446
- real_kwargs.update(kwargs)
447
- super().__init__(*args, **real_kwargs)
448
-
449
-
450
- class GlobalGeneratorFromSuperChannels(nn.Module):
451
- def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
452
- super().__init__()
453
- self.n_downsampling = n_downsampling
454
- norm_layer = get_norm_layer(norm_layer)
455
- if type(norm_layer) == functools.partial:
456
- use_bias = (norm_layer.func == nn.InstanceNorm2d)
457
- else:
458
- use_bias = (norm_layer == nn.InstanceNorm2d)
459
-
460
- channels = self.convert_super_channels(super_channels)
461
- self.channels = channels
462
-
463
- model = [nn.ReflectionPad2d(3),
464
- nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
465
- norm_layer(channels[0]),
466
- nn.ReLU(True)]
467
-
468
- for i in range(n_downsampling): # add downsampling layers
469
- mult = 2 ** i
470
- model += [nn.Conv2d(channels[0+i], channels[1+i], kernel_size=3, stride=2, padding=1, bias=use_bias),
471
- norm_layer(channels[1+i]),
472
- nn.ReLU(True)]
473
-
474
- mult = 2 ** n_downsampling
475
-
476
- n_blocks1 = n_blocks // 3
477
- n_blocks2 = n_blocks1
478
- n_blocks3 = n_blocks - n_blocks1 - n_blocks2
479
-
480
- for i in range(n_blocks1):
481
- c = n_downsampling
482
- dim = channels[c]
483
- model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)]
484
-
485
- for i in range(n_blocks2):
486
- c = n_downsampling+1
487
- dim = channels[c]
488
- kwargs = {}
489
- if i == 0:
490
- kwargs = {"in_dim": channels[c-1]}
491
- model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
492
-
493
- for i in range(n_blocks3):
494
- c = n_downsampling+2
495
- dim = channels[c]
496
- kwargs = {}
497
- if i == 0:
498
- kwargs = {"in_dim": channels[c-1]}
499
- model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
500
-
501
- for i in range(n_downsampling): # add upsampling layers
502
- mult = 2 ** (n_downsampling - i)
503
- model += [nn.ConvTranspose2d(channels[n_downsampling+3+i],
504
- channels[n_downsampling+3+i+1],
505
- kernel_size=3, stride=2,
506
- padding=1, output_padding=1,
507
- bias=use_bias),
508
- norm_layer(channels[n_downsampling+3+i+1]),
509
- nn.ReLU(True)]
510
- model += [nn.ReflectionPad2d(3)]
511
- model += [nn.Conv2d(channels[2*n_downsampling+3], output_nc, kernel_size=7, padding=0)]
512
-
513
- if add_out_act:
514
- model.append(get_activation('tanh' if add_out_act is True else add_out_act))
515
- self.model = nn.Sequential(*model)
516
-
517
- def convert_super_channels(self, super_channels):
518
- n_downsampling = self.n_downsampling
519
- result = []
520
- cnt = 0
521
-
522
- if n_downsampling == 2:
523
- N1 = 10
524
- elif n_downsampling == 3:
525
- N1 = 13
526
- else:
527
- raise NotImplementedError
528
-
529
- for i in range(0, N1):
530
- if i in [1,4,7,10]:
531
- channel = super_channels[cnt] * (2 ** cnt)
532
- config = {'channel': channel}
533
- result.append(channel)
534
- logging.info(f"Downsample channels {result[-1]}")
535
- cnt += 1
536
-
537
- for i in range(3):
538
- for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)):
539
- if len(super_channels) == 6:
540
- channel = super_channels[3] * 4
541
- else:
542
- channel = super_channels[i + 3] * 4
543
- config = {'channel': channel}
544
- if counter == 0:
545
- result.append(channel)
546
- logging.info(f"Bottleneck channels {result[-1]}")
547
- cnt = 2
548
-
549
- for i in range(N1+9, N1+21):
550
- if i in [22, 25,28]:
551
- cnt -= 1
552
- if len(super_channels) == 6:
553
- channel = super_channels[5 - cnt] * (2 ** cnt)
554
- else:
555
- channel = super_channels[7 - cnt] * (2 ** cnt)
556
- result.append(int(channel))
557
- logging.info(f"Upsample channels {result[-1]}")
558
- return result
559
-
560
- def forward(self, input):
561
- return self.model(input)
562
-
563
-
564
- # Defines the PatchGAN discriminator with the specified arguments.
565
- class NLayerDiscriminator(BaseDiscriminator):
566
- def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,):
567
- super().__init__()
568
- self.n_layers = n_layers
569
-
570
- kw = 4
571
- padw = int(np.ceil((kw-1.0)/2))
572
- sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
573
- nn.LeakyReLU(0.2, True)]]
574
-
575
- nf = ndf
576
- for n in range(1, n_layers):
577
- nf_prev = nf
578
- nf = min(nf * 2, 512)
579
-
580
- cur_model = []
581
- cur_model += [
582
- nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
583
- norm_layer(nf),
584
- nn.LeakyReLU(0.2, True)
585
- ]
586
- sequence.append(cur_model)
587
-
588
- nf_prev = nf
589
- nf = min(nf * 2, 512)
590
-
591
- cur_model = []
592
- cur_model += [
593
- nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
594
- norm_layer(nf),
595
- nn.LeakyReLU(0.2, True)
596
- ]
597
- sequence.append(cur_model)
598
-
599
- sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
600
-
601
- for n in range(len(sequence)):
602
- setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
603
-
604
- def get_all_activations(self, x):
605
- res = [x]
606
- for n in range(self.n_layers + 2):
607
- model = getattr(self, 'model' + str(n))
608
- res.append(model(res[-1]))
609
- return res[1:]
610
-
611
- def forward(self, x):
612
- act = self.get_all_activations(x)
613
- return act[-1], act[:-1]
614
-
615
-
616
- class MultidilatedNLayerDiscriminator(BaseDiscriminator):
617
- def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
618
- super().__init__()
619
- self.n_layers = n_layers
620
-
621
- kw = 4
622
- padw = int(np.ceil((kw-1.0)/2))
623
- sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
624
- nn.LeakyReLU(0.2, True)]]
625
-
626
- nf = ndf
627
- for n in range(1, n_layers):
628
- nf_prev = nf
629
- nf = min(nf * 2, 512)
630
-
631
- cur_model = []
632
- cur_model += [
633
- MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
634
- norm_layer(nf),
635
- nn.LeakyReLU(0.2, True)
636
- ]
637
- sequence.append(cur_model)
638
-
639
- nf_prev = nf
640
- nf = min(nf * 2, 512)
641
-
642
- cur_model = []
643
- cur_model += [
644
- nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
645
- norm_layer(nf),
646
- nn.LeakyReLU(0.2, True)
647
- ]
648
- sequence.append(cur_model)
649
-
650
- sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
651
-
652
- for n in range(len(sequence)):
653
- setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
654
-
655
- def get_all_activations(self, x):
656
- res = [x]
657
- for n in range(self.n_layers + 2):
658
- model = getattr(self, 'model' + str(n))
659
- res.append(model(res[-1]))
660
- return res[1:]
661
-
662
- def forward(self, x):
663
- act = self.get_all_activations(x)
664
- return act[-1], act[:-1]
665
-
666
-
667
- class NLayerDiscriminatorAsGen(NLayerDiscriminator):
668
- def forward(self, x):
669
- return super().forward(x)[0]
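A minimal sketch for the PatchGAN discriminator above (input size is an assumption): forward returns the final patch logit map together with the intermediate activations, which is what the multiscale wrappers and feature-matching losses rely on.

import torch
import torch.nn as nn

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
x = torch.randn(1, 3, 256, 256)
patch_logits, feats = disc(x)  # patch_logits: (1, 1, H', W') logit map; feats: list of intermediate maps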
annotator/lama/saicinpainting/training/modules/spatial_transform.py DELETED
@@ -1,49 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from kornia.geometry.transform import rotate
5
-
6
-
7
- class LearnableSpatialTransformWrapper(nn.Module):
8
- def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
9
- super().__init__()
10
- self.impl = impl
11
- self.angle = torch.rand(1) * angle_init_range
12
- if train_angle:
13
- self.angle = nn.Parameter(self.angle, requires_grad=True)
14
- self.pad_coef = pad_coef
15
-
16
- def forward(self, x):
17
- if torch.is_tensor(x):
18
- return self.inverse_transform(self.impl(self.transform(x)), x)
19
- elif isinstance(x, tuple):
20
- x_trans = tuple(self.transform(elem) for elem in x)
21
- y_trans = self.impl(x_trans)
22
- return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x))
23
- else:
24
- raise ValueError(f'Unexpected input type {type(x)}')
25
-
26
- def transform(self, x):
27
- height, width = x.shape[2:]
28
- pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
29
- x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
30
- x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded))
31
- return x_padded_rotated
32
-
33
- def inverse_transform(self, y_padded_rotated, orig_x):
34
- height, width = orig_x.shape[2:]
35
- pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
36
-
37
- y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
38
- y_height, y_width = y_padded.shape[2:]
39
- y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
40
- return y
41
-
42
-
43
- if __name__ == '__main__':
44
- layer = LearnableSpatialTransformWrapper(nn.Identity())
45
- x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float()
46
- y = layer(x)
47
- assert x.shape == y.shape
48
- assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1])
49
- print('all ok')
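Beyond the identity self-test above, the wrapper is meant to enclose a real module; a minimal sketch (module choice and shape are assumptions, kornia required): the learned angle rotates the padded input before the wrapped module and is undone afterwards, so input and output shapes match.

import torch
import torch.nn as nn

wrapped = LearnableSpatialTransformWrapper(nn.Conv2d(3, 3, kernel_size=3, padding=1))
y = wrapped(torch.randn(1, 3, 64, 64))  # rotate -> conv -> rotate back -> crop, shape (1, 3, 64, 64)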
annotator/lama/saicinpainting/training/modules/squeeze_excitation.py DELETED
@@ -1,20 +0,0 @@
1
- import torch.nn as nn
2
-
3
-
4
- class SELayer(nn.Module):
5
- def __init__(self, channel, reduction=16):
6
- super(SELayer, self).__init__()
7
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
8
- self.fc = nn.Sequential(
9
- nn.Linear(channel, channel // reduction, bias=False),
10
- nn.ReLU(inplace=True),
11
- nn.Linear(channel // reduction, channel, bias=False),
12
- nn.Sigmoid()
13
- )
14
-
15
- def forward(self, x):
16
- b, c, _, _ = x.size()
17
- y = self.avg_pool(x).view(b, c)
18
- y = self.fc(y).view(b, c, 1, 1)
19
- res = x * y.expand_as(x)
20
- return res
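A minimal usage sketch for SELayer (shape is an assumption): global-average-pooled channel statistics pass through the bottleneck MLP and sigmoid, and the resulting per-channel gates rescale the input, so the output shape equals the input shape.

import torch

se = SELayer(channel=64, reduction=16)
x = torch.randn(2, 64, 32, 32)
y = se(x)  # (2, 64, 32, 32), channels reweighted by learned sigmoid gates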
annotator/lama/saicinpainting/training/trainers/__init__.py DELETED
@@ -1,29 +0,0 @@
1
- import logging
2
- import torch
3
- from annotator.lama.saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule
4
-
5
-
6
- def get_training_model_class(kind):
7
- if kind == 'default':
8
- return DefaultInpaintingTrainingModule
9
-
10
- raise ValueError(f'Unknown trainer module {kind}')
11
-
12
-
13
- def make_training_model(config):
14
- kind = config.training_model.kind
15
- kwargs = dict(config.training_model)
16
- kwargs.pop('kind')
17
- kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp'
18
-
19
- logging.info(f'Make training model {kind}')
20
-
21
- cls = get_training_model_class(kind)
22
- return cls(config, **kwargs)
23
-
24
-
25
- def load_checkpoint(train_config, path, map_location='cuda', strict=True):
26
- model = make_training_model(train_config).generator
27
- state = torch.load(path, map_location=map_location)
28
- model.load_state_dict(state, strict=strict)
29
- return model
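A minimal sketch of restoring a generator with load_checkpoint (paths and the checkpoint layout are assumptions; the function expects an OmegaConf-style training config and a state dict that matches the generator directly):

from omegaconf import OmegaConf

train_config = OmegaConf.load('path/to/lama_training_config.yaml')  # hypothetical path
generator = load_checkpoint(train_config, 'path/to/generator_weights.pth',
                            map_location='cpu', strict=False)
generator.eval()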
annotator/lama/saicinpainting/training/trainers/base.py DELETED
@@ -1,293 +0,0 @@
1
- import copy
2
- import logging
3
- from typing import Dict, Tuple
4
-
5
- import pandas as pd
6
- import pytorch_lightning as ptl
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- # from torch.utils.data import DistributedSampler
11
-
12
- # from annotator.lama.saicinpainting.evaluation import make_evaluator
13
- # from annotator.lama.saicinpainting.training.data.datasets import make_default_train_dataloader, make_default_val_dataloader
14
- # from annotator.lama.saicinpainting.training.losses.adversarial import make_discrim_loss
15
- # from annotator.lama.saicinpainting.training.losses.perceptual import PerceptualLoss, ResNetPL
16
- from annotator.lama.saicinpainting.training.modules import make_generator #, make_discriminator
17
- # from annotator.lama.saicinpainting.training.visualizers import make_visualizer
18
- from annotator.lama.saicinpainting.utils import add_prefix_to_keys, average_dicts, set_requires_grad, flatten_dict, \
19
- get_has_ddp_rank
20
-
21
- LOGGER = logging.getLogger(__name__)
22
-
23
-
24
- def make_optimizer(parameters, kind='adamw', **kwargs):
25
- if kind == 'adam':
26
- optimizer_class = torch.optim.Adam
27
- elif kind == 'adamw':
28
- optimizer_class = torch.optim.AdamW
29
- else:
30
- raise ValueError(f'Unknown optimizer kind {kind}')
31
- return optimizer_class(parameters, **kwargs)
32
-
33
-
34
- def update_running_average(result: nn.Module, new_iterate_model: nn.Module, decay=0.999):
35
- with torch.no_grad():
36
- res_params = dict(result.named_parameters())
37
- new_params = dict(new_iterate_model.named_parameters())
38
-
39
- for k in res_params.keys():
40
- res_params[k].data.mul_(decay).add_(new_params[k].data, alpha=1 - decay)
41
-
42
-
43
- def make_multiscale_noise(base_tensor, scales=6, scale_mode='bilinear'):
44
- batch_size, _, height, width = base_tensor.shape
45
- cur_height, cur_width = height, width
46
- result = []
47
- align_corners = False if scale_mode in ('bilinear', 'bicubic') else None
48
- for _ in range(scales):
49
- cur_sample = torch.randn(batch_size, 1, cur_height, cur_width, device=base_tensor.device)
50
- cur_sample_scaled = F.interpolate(cur_sample, size=(height, width), mode=scale_mode, align_corners=align_corners)
51
- result.append(cur_sample_scaled)
52
- cur_height //= 2
53
- cur_width //= 2
54
- return torch.cat(result, dim=1)
55
-
56
-
57
- class BaseInpaintingTrainingModule(ptl.LightningModule):
58
- def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100,
59
- average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000,
60
- average_generator_period=10, store_discr_outputs_for_vis=False,
61
- **kwargs):
62
- super().__init__(*args, **kwargs)
63
- LOGGER.info('BaseInpaintingTrainingModule init called')
64
-
65
- self.config = config
66
-
67
- self.generator = make_generator(config, **self.config.generator)
68
- self.use_ddp = use_ddp
69
-
70
- if not get_has_ddp_rank():
71
- LOGGER.info(f'Generator\n{self.generator}')
72
-
73
- # if not predict_only:
74
- # self.save_hyperparameters(self.config)
75
- # self.discriminator = make_discriminator(**self.config.discriminator)
76
- # self.adversarial_loss = make_discrim_loss(**self.config.losses.adversarial)
77
- # self.visualizer = make_visualizer(**self.config.visualizer)
78
- # self.val_evaluator = make_evaluator(**self.config.evaluator)
79
- # self.test_evaluator = make_evaluator(**self.config.evaluator)
80
- #
81
- # if not get_has_ddp_rank():
82
- # LOGGER.info(f'Discriminator\n{self.discriminator}')
83
- #
84
- # extra_val = self.config.data.get('extra_val', ())
85
- # if extra_val:
86
- # self.extra_val_titles = list(extra_val)
87
- # self.extra_evaluators = nn.ModuleDict({k: make_evaluator(**self.config.evaluator)
88
- # for k in extra_val})
89
- # else:
90
- # self.extra_evaluators = {}
91
- #
92
- # self.average_generator = average_generator
93
- # self.generator_avg_beta = generator_avg_beta
94
- # self.average_generator_start_step = average_generator_start_step
95
- # self.average_generator_period = average_generator_period
96
- # self.generator_average = None
97
- # self.last_generator_averaging_step = -1
98
- # self.store_discr_outputs_for_vis = store_discr_outputs_for_vis
99
- #
100
- # if self.config.losses.get("l1", {"weight_known": 0})['weight_known'] > 0:
101
- # self.loss_l1 = nn.L1Loss(reduction='none')
102
- #
103
- # if self.config.losses.get("mse", {"weight": 0})['weight'] > 0:
104
- # self.loss_mse = nn.MSELoss(reduction='none')
105
- #
106
- # if self.config.losses.perceptual.weight > 0:
107
- # self.loss_pl = PerceptualLoss()
108
- #
109
- # # if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0:
110
- # # self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl)
111
- # # else:
112
- # # self.loss_resnet_pl = None
113
- #
114
- # self.loss_resnet_pl = None
115
-
116
- self.visualize_each_iters = visualize_each_iters
117
- LOGGER.info('BaseInpaintingTrainingModule init done')
118
-
119
- def configure_optimizers(self):
120
- discriminator_params = list(self.discriminator.parameters())
121
- return [
122
- dict(optimizer=make_optimizer(self.generator.parameters(), **self.config.optimizers.generator)),
123
- dict(optimizer=make_optimizer(discriminator_params, **self.config.optimizers.discriminator)),
124
- ]
125
-
126
- def train_dataloader(self):
127
- kwargs = dict(self.config.data.train)
128
- if self.use_ddp:
129
- kwargs['ddp_kwargs'] = dict(num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
130
- rank=self.trainer.global_rank,
131
- shuffle=True)
132
- dataloader = make_default_train_dataloader(**self.config.data.train)
133
- return dataloader
134
-
135
- def val_dataloader(self):
136
- res = [make_default_val_dataloader(**self.config.data.val)]
137
-
138
- if self.config.data.visual_test is not None:
139
- res = res + [make_default_val_dataloader(**self.config.data.visual_test)]
140
- else:
141
- res = res + res
142
-
143
- extra_val = self.config.data.get('extra_val', ())
144
- if extra_val:
145
- res += [make_default_val_dataloader(**extra_val[k]) for k in self.extra_val_titles]
146
-
147
- return res
148
-
149
- def training_step(self, batch, batch_idx, optimizer_idx=None):
150
- self._is_training_step = True
151
- return self._do_step(batch, batch_idx, mode='train', optimizer_idx=optimizer_idx)
152
-
153
- def validation_step(self, batch, batch_idx, dataloader_idx):
154
- extra_val_key = None
155
- if dataloader_idx == 0:
156
- mode = 'val'
157
- elif dataloader_idx == 1:
158
- mode = 'test'
159
- else:
160
- mode = 'extra_val'
161
- extra_val_key = self.extra_val_titles[dataloader_idx - 2]
162
- self._is_training_step = False
163
- return self._do_step(batch, batch_idx, mode=mode, extra_val_key=extra_val_key)
164
-
165
- def training_step_end(self, batch_parts_outputs):
166
- if self.training and self.average_generator \
167
- and self.global_step >= self.average_generator_start_step \
168
- and self.global_step >= self.last_generator_averaging_step + self.average_generator_period:
169
- if self.generator_average is None:
170
- self.generator_average = copy.deepcopy(self.generator)
171
- else:
172
- update_running_average(self.generator_average, self.generator, decay=self.generator_avg_beta)
173
- self.last_generator_averaging_step = self.global_step
174
-
175
- full_loss = (batch_parts_outputs['loss'].mean()
176
- if torch.is_tensor(batch_parts_outputs['loss']) # loss is not tensor when no discriminator used
177
- else torch.tensor(batch_parts_outputs['loss']).float().requires_grad_(True))
178
- log_info = {k: v.mean() for k, v in batch_parts_outputs['log_info'].items()}
179
- self.log_dict(log_info, on_step=True, on_epoch=False)
180
- return full_loss
181
-
182
- def validation_epoch_end(self, outputs):
183
- outputs = [step_out for out_group in outputs for step_out in out_group]
184
- averaged_logs = average_dicts(step_out['log_info'] for step_out in outputs)
185
- self.log_dict({k: v.mean() for k, v in averaged_logs.items()})
186
-
187
- pd.set_option('display.max_columns', 500)
188
- pd.set_option('display.width', 1000)
189
-
190
- # standard validation
191
- val_evaluator_states = [s['val_evaluator_state'] for s in outputs if 'val_evaluator_state' in s]
192
- val_evaluator_res = self.val_evaluator.evaluation_end(states=val_evaluator_states)
193
- val_evaluator_res_df = pd.DataFrame(val_evaluator_res).stack(1).unstack(0)
194
- val_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
195
- LOGGER.info(f'Validation metrics after epoch #{self.current_epoch}, '
196
- f'total {self.global_step} iterations:\n{val_evaluator_res_df}')
197
-
198
- for k, v in flatten_dict(val_evaluator_res).items():
199
- self.log(f'val_{k}', v)
200
-
201
- # standard visual test
202
- test_evaluator_states = [s['test_evaluator_state'] for s in outputs
203
- if 'test_evaluator_state' in s]
204
- test_evaluator_res = self.test_evaluator.evaluation_end(states=test_evaluator_states)
205
- test_evaluator_res_df = pd.DataFrame(test_evaluator_res).stack(1).unstack(0)
206
- test_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
207
- LOGGER.info(f'Test metrics after epoch #{self.current_epoch}, '
208
- f'total {self.global_step} iterations:\n{test_evaluator_res_df}')
209
-
210
- for k, v in flatten_dict(test_evaluator_res).items():
211
- self.log(f'test_{k}', v)
212
-
213
- # extra validations
214
- if self.extra_evaluators:
215
- for cur_eval_title, cur_evaluator in self.extra_evaluators.items():
216
- cur_state_key = f'extra_val_{cur_eval_title}_evaluator_state'
217
- cur_states = [s[cur_state_key] for s in outputs if cur_state_key in s]
218
- cur_evaluator_res = cur_evaluator.evaluation_end(states=cur_states)
219
- cur_evaluator_res_df = pd.DataFrame(cur_evaluator_res).stack(1).unstack(0)
220
- cur_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
221
- LOGGER.info(f'Extra val {cur_eval_title} metrics after epoch #{self.current_epoch}, '
222
- f'total {self.global_step} iterations:\n{cur_evaluator_res_df}')
223
- for k, v in flatten_dict(cur_evaluator_res).items():
224
- self.log(f'extra_val_{cur_eval_title}_{k}', v)
225
-
226
- def _do_step(self, batch, batch_idx, mode='train', optimizer_idx=None, extra_val_key=None):
227
- if optimizer_idx == 0: # step for generator
228
- set_requires_grad(self.generator, True)
229
- set_requires_grad(self.discriminator, False)
230
- elif optimizer_idx == 1: # step for discriminator
231
- set_requires_grad(self.generator, False)
232
- set_requires_grad(self.discriminator, True)
233
-
234
- batch = self(batch)
235
-
236
- total_loss = 0
237
- metrics = {}
238
-
239
- if optimizer_idx is None or optimizer_idx == 0: # step for generator
240
- total_loss, metrics = self.generator_loss(batch)
241
-
242
- elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator
243
- if self.config.losses.adversarial.weight > 0:
244
- total_loss, metrics = self.discriminator_loss(batch)
245
-
246
- if self.get_ddp_rank() in (None, 0) and (batch_idx % self.visualize_each_iters == 0 or mode == 'test'):
247
- if self.config.losses.adversarial.weight > 0:
248
- if self.store_discr_outputs_for_vis:
249
- with torch.no_grad():
250
- self.store_discr_outputs(batch)
251
- vis_suffix = f'_{mode}'
252
- if mode == 'extra_val':
253
- vis_suffix += f'_{extra_val_key}'
254
- self.visualizer(self.current_epoch, batch_idx, batch, suffix=vis_suffix)
255
-
256
- metrics_prefix = f'{mode}_'
257
- if mode == 'extra_val':
258
- metrics_prefix += f'{extra_val_key}_'
259
- result = dict(loss=total_loss, log_info=add_prefix_to_keys(metrics, metrics_prefix))
260
- if mode == 'val':
261
- result['val_evaluator_state'] = self.val_evaluator.process_batch(batch)
262
- elif mode == 'test':
263
- result['test_evaluator_state'] = self.test_evaluator.process_batch(batch)
264
- elif mode == 'extra_val':
265
- result[f'extra_val_{extra_val_key}_evaluator_state'] = self.extra_evaluators[extra_val_key].process_batch(batch)
266
-
267
- return result
268
-
269
- def get_current_generator(self, no_average=False):
270
- if not no_average and not self.training and self.average_generator and self.generator_average is not None:
271
- return self.generator_average
272
- return self.generator
273
-
274
- def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
275
- """Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys"""
276
- raise NotImplementedError()
277
-
278
- def generator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
279
- raise NotImplementedError()
280
-
281
- def discriminator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
282
- raise NotImplementedError()
283
-
284
- def store_discr_outputs(self, batch):
285
- out_size = batch['image'].shape[2:]
286
- discr_real_out, _ = self.discriminator(batch['image'])
287
- discr_fake_out, _ = self.discriminator(batch['predicted_image'])
288
- batch['discr_output_real'] = F.interpolate(discr_real_out, size=out_size, mode='nearest')
289
- batch['discr_output_fake'] = F.interpolate(discr_fake_out, size=out_size, mode='nearest')
290
- batch['discr_output_diff'] = batch['discr_output_real'] - batch['discr_output_fake']
291
-
292
- def get_ddp_rank(self):
293
- return self.trainer.global_rank if (self.trainer.num_nodes * self.trainer.num_processes) > 1 else None
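Two helpers above are easy to exercise in isolation; a minimal sketch (shapes and hyperparameters are assumptions): make_optimizer dispatches between Adam and AdamW, and make_multiscale_noise stacks one resampled noise map per scale along the channel dimension.

import torch
import torch.nn as nn

gen = nn.Conv2d(3, 3, kernel_size=3, padding=1)
opt = make_optimizer(gen.parameters(), kind='adamw', lr=1e-3, weight_decay=1e-2)

noise = make_multiscale_noise(torch.randn(2, 3, 64, 64), scales=4)
print(noise.shape)  # torch.Size([2, 4, 64, 64]): one noise channel per scale, resampled to the base size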
annotator/lama/saicinpainting/training/trainers/default.py DELETED
@@ -1,175 +0,0 @@
1
- import logging
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from omegaconf import OmegaConf
6
-
7
- # from annotator.lama.saicinpainting.training.data.datasets import make_constant_area_crop_params
8
- from annotator.lama.saicinpainting.training.losses.distance_weighting import make_mask_distance_weighter
9
- from annotator.lama.saicinpainting.training.losses.feature_matching import feature_matching_loss, masked_l1_loss
10
- # from annotator.lama.saicinpainting.training.modules.fake_fakes import FakeFakesGenerator
11
- from annotator.lama.saicinpainting.training.trainers.base import BaseInpaintingTrainingModule, make_multiscale_noise
12
- from annotator.lama.saicinpainting.utils import add_prefix_to_keys, get_ramp
13
-
14
- LOGGER = logging.getLogger(__name__)
15
-
16
-
17
- def make_constant_area_crop_batch(batch, **kwargs):
18
- crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
19
- img_width=batch['image'].shape[3],
20
- **kwargs)
21
- batch['image'] = batch['image'][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width]
22
- batch['mask'] = batch['mask'][:, :, crop_y: crop_y + crop_height, crop_x: crop_x + crop_width]
23
- return batch
24
-
25
-
26
- class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
27
- def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
28
- add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
29
- distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
30
- fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
31
- **kwargs):
32
- super().__init__(*args, **kwargs)
33
- self.concat_mask = concat_mask
34
- self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
35
- self.image_to_discriminator = image_to_discriminator
36
- self.add_noise_kwargs = add_noise_kwargs
37
- self.noise_fill_hole = noise_fill_hole
38
- self.const_area_crop_kwargs = const_area_crop_kwargs
39
- self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
40
- if distance_weighter_kwargs is not None else None
41
- self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
42
-
43
- self.fake_fakes_proba = fake_fakes_proba
44
- if self.fake_fakes_proba > 1e-3:
45
- self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
46
-
47
- def forward(self, batch):
48
- if self.training and self.rescale_size_getter is not None:
49
- cur_size = self.rescale_size_getter(self.global_step)
50
- batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
51
- batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
52
-
53
- if self.training and self.const_area_crop_kwargs is not None:
54
- batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
55
-
56
- img = batch['image']
57
- mask = batch['mask']
58
-
59
- masked_img = img * (1 - mask)
60
-
61
- if self.add_noise_kwargs is not None:
62
- noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs)
63
- if self.noise_fill_hole:
64
- masked_img = masked_img + mask * noise[:, :masked_img.shape[1]]
65
- masked_img = torch.cat([masked_img, noise], dim=1)
66
-
67
- if self.concat_mask:
68
- masked_img = torch.cat([masked_img, mask], dim=1)
69
-
70
- batch['predicted_image'] = self.generator(masked_img)
71
- batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image']
72
-
73
- if self.fake_fakes_proba > 1e-3:
74
- if self.training and torch.rand(1).item() < self.fake_fakes_proba:
75
- batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask)
76
- batch['use_fake_fakes'] = True
77
- else:
78
- batch['fake_fakes'] = torch.zeros_like(img)
79
- batch['fake_fakes_masks'] = torch.zeros_like(mask)
80
- batch['use_fake_fakes'] = False
81
-
82
- batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \
83
- if self.refine_mask_for_losses is not None and self.training \
84
- else mask
85
-
86
- return batch
87
-
88
- def generator_loss(self, batch):
89
- img = batch['image']
90
- predicted_img = batch[self.image_to_discriminator]
91
- original_mask = batch['mask']
92
- supervised_mask = batch['mask_for_losses']
93
-
94
- # L1
95
- l1_value = masked_l1_loss(predicted_img, img, supervised_mask,
96
- self.config.losses.l1.weight_known,
97
- self.config.losses.l1.weight_missing)
98
-
99
- total_loss = l1_value
100
- metrics = dict(gen_l1=l1_value)
101
-
102
- # vgg-based perceptual loss
103
- if self.config.losses.perceptual.weight > 0:
104
- pl_value = self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight
105
- total_loss = total_loss + pl_value
106
- metrics['gen_pl'] = pl_value
107
-
108
- # discriminator
109
- # adversarial_loss calls backward by itself
110
- mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask
111
- self.adversarial_loss.pre_generator_step(real_batch=img, fake_batch=predicted_img,
112
- generator=self.generator, discriminator=self.discriminator)
113
- discr_real_pred, discr_real_features = self.discriminator(img)
114
- discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
115
- adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(real_batch=img,
116
- fake_batch=predicted_img,
117
- discr_real_pred=discr_real_pred,
118
- discr_fake_pred=discr_fake_pred,
119
- mask=mask_for_discr)
120
- total_loss = total_loss + adv_gen_loss
121
- metrics['gen_adv'] = adv_gen_loss
122
- metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
123
-
124
- # feature matching
125
- if self.config.losses.feature_matching.weight > 0:
126
- need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get('pass_mask', False)
127
- mask_for_fm = supervised_mask if need_mask_in_fm else None
128
- fm_value = feature_matching_loss(discr_fake_features, discr_real_features,
129
- mask=mask_for_fm) * self.config.losses.feature_matching.weight
130
- total_loss = total_loss + fm_value
131
- metrics['gen_fm'] = fm_value
132
-
133
- if self.loss_resnet_pl is not None:
134
- resnet_pl_value = self.loss_resnet_pl(predicted_img, img)
135
- total_loss = total_loss + resnet_pl_value
136
- metrics['gen_resnet_pl'] = resnet_pl_value
137
-
138
- return total_loss, metrics
139
-
140
- def discriminator_loss(self, batch):
141
- total_loss = 0
142
- metrics = {}
143
-
144
- predicted_img = batch[self.image_to_discriminator].detach()
145
- self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=predicted_img,
146
- generator=self.generator, discriminator=self.discriminator)
147
- discr_real_pred, discr_real_features = self.discriminator(batch['image'])
148
- discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
149
- adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(real_batch=batch['image'],
150
- fake_batch=predicted_img,
151
- discr_real_pred=discr_real_pred,
152
- discr_fake_pred=discr_fake_pred,
153
- mask=batch['mask'])
154
- total_loss = total_loss + adv_discr_loss
155
- metrics['discr_adv'] = adv_discr_loss
156
- metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
157
-
158
-
159
- if batch.get('use_fake_fakes', False):
160
- fake_fakes = batch['fake_fakes']
161
- self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=fake_fakes,
162
- generator=self.generator, discriminator=self.discriminator)
163
- discr_fake_fakes_pred, _ = self.discriminator(fake_fakes)
164
- fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss(
165
- real_batch=batch['image'],
166
- fake_batch=fake_fakes,
167
- discr_real_pred=discr_real_pred,
168
- discr_fake_pred=discr_fake_fakes_pred,
169
- mask=batch['mask']
170
- )
171
- total_loss = total_loss + fake_fakes_adv_discr_loss
172
- metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
173
- metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
174
-
175
- return total_loss, metrics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/lama/saicinpainting/training/visualizers/__init__.py DELETED
@@ -1,15 +0,0 @@
1
- import logging
2
-
3
- from annotator.lama.saicinpainting.training.visualizers.directory import DirectoryVisualizer
4
- from annotator.lama.saicinpainting.training.visualizers.noop import NoopVisualizer
5
-
6
-
7
- def make_visualizer(kind, **kwargs):
8
- logging.info(f'Make visualizer {kind}')
9
-
10
- if kind == 'directory':
11
- return DirectoryVisualizer(**kwargs)
12
- if kind == 'noop':
13
- return NoopVisualizer()
14
-
15
- raise ValueError(f'Unknown visualizer kind {kind}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/lama/saicinpainting/training/visualizers/base.py DELETED
@@ -1,73 +0,0 @@
1
- import abc
2
- from typing import Dict, List
3
-
4
- import numpy as np
5
- import torch
6
- from skimage import color
7
- from skimage.segmentation import mark_boundaries
8
-
9
- from . import colors
10
-
11
- COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation
12
-
13
-
14
- class BaseVisualizer:
15
- @abc.abstractmethod
16
- def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
17
- """
18
- Take a batch, make an image from it and visualize
19
- """
20
- raise NotImplementedError()
21
-
22
-
23
- def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str],
24
- last_without_mask=True, rescale_keys=None, mask_only_first=None,
25
- black_mask=False) -> np.ndarray:
26
- mask = images_dict['mask'] > 0.5
27
- result = []
28
- for i, k in enumerate(keys):
29
- img = images_dict[k]
30
- img = np.transpose(img, (1, 2, 0))
31
-
32
- if rescale_keys is not None and k in rescale_keys:
33
- img = img - img.min()
34
- img /= img.max() + 1e-5
35
- if len(img.shape) == 2:
36
- img = np.expand_dims(img, 2)
37
-
38
- if img.shape[2] == 1:
39
- img = np.repeat(img, 3, axis=2)
40
- elif (img.shape[2] > 3):
41
- img_classes = img.argmax(2)
42
- img = color.label2rgb(img_classes, colors=COLORS)
43
-
44
- if mask_only_first:
45
- need_mark_boundaries = i == 0
46
- else:
47
- need_mark_boundaries = i < len(keys) - 1 or not last_without_mask
48
-
49
- if need_mark_boundaries:
50
- if black_mask:
51
- img = img * (1 - mask[0][..., None])
52
- img = mark_boundaries(img,
53
- mask[0],
54
- color=(1., 0., 0.),
55
- outline_color=(1., 1., 1.),
56
- mode='thick')
57
- result.append(img)
58
- return np.concatenate(result, axis=1)
59
-
60
-
61
- def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10,
62
- last_without_mask=True, rescale_keys=None) -> np.ndarray:
63
- batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items()
64
- if k in keys or k == 'mask'}
65
-
66
- batch_size = next(iter(batch.values())).shape[0]
67
- items_to_vis = min(batch_size, max_items)
68
- result = []
69
- for i in range(items_to_vis):
70
- cur_dct = {k: tens[i] for k, tens in batch.items()}
71
- result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask,
72
- rescale_keys=rescale_keys))
73
- return np.concatenate(result, axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/lama/saicinpainting/training/visualizers/colors.py DELETED
@@ -1,76 +0,0 @@
1
- import random
2
- import colorsys
3
-
4
- import numpy as np
5
- import matplotlib
6
- matplotlib.use('agg')
7
- import matplotlib.pyplot as plt
8
- from matplotlib.colors import LinearSegmentedColormap
9
-
10
-
11
- def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False):
12
- # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
13
- """
14
- Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
15
- :param nlabels: Number of labels (size of colormap)
16
- :param type: 'bright' for strong colors, 'soft' for pastel colors
17
- :param first_color_black: Option to use first color as black, True or False
18
- :param last_color_black: Option to use last color as black, True or False
19
- :param verbose: Prints the number of labels and shows the colormap. True or False
20
- :return: colormap for matplotlib
21
- """
22
- if type not in ('bright', 'soft'):
23
- print ('Please choose "bright" or "soft" for type')
24
- return
25
-
26
- if verbose:
27
- print('Number of labels: ' + str(nlabels))
28
-
29
- # Generate color map for bright colors, based on hsv
30
- if type == 'bright':
31
- randHSVcolors = [(np.random.uniform(low=0.0, high=1),
32
- np.random.uniform(low=0.2, high=1),
33
- np.random.uniform(low=0.9, high=1)) for i in range(nlabels)]
34
-
35
- # Convert HSV list to RGB
36
- randRGBcolors = []
37
- for HSVcolor in randHSVcolors:
38
- randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]))
39
-
40
- if first_color_black:
41
- randRGBcolors[0] = [0, 0, 0]
42
-
43
- if last_color_black:
44
- randRGBcolors[-1] = [0, 0, 0]
45
-
46
- random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
47
-
48
- # Generate soft pastel colors, by limiting the RGB spectrum
49
- if type == 'soft':
50
- low = 0.6
51
- high = 0.95
52
- randRGBcolors = [(np.random.uniform(low=low, high=high),
53
- np.random.uniform(low=low, high=high),
54
- np.random.uniform(low=low, high=high)) for i in range(nlabels)]
55
-
56
- if first_color_black:
57
- randRGBcolors[0] = [0, 0, 0]
58
-
59
- if last_color_black:
60
- randRGBcolors[-1] = [0, 0, 0]
61
- random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
62
-
63
- # Display colorbar
64
- if verbose:
65
- from matplotlib import colors, colorbar
66
- from matplotlib import pyplot as plt
67
- fig, ax = plt.subplots(1, 1, figsize=(15, 0.5))
68
-
69
- bounds = np.linspace(0, nlabels, nlabels + 1)
70
- norm = colors.BoundaryNorm(bounds, nlabels)
71
-
72
- cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
73
- boundaries=bounds, format='%1i', orientation=u'horizontal')
74
-
75
- return randRGBcolors, random_colormap
76
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/lama/saicinpainting/training/visualizers/directory.py DELETED
@@ -1,36 +0,0 @@
1
- import os
2
-
3
- import cv2
4
- import numpy as np
5
-
6
- from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch
7
- from annotator.lama.saicinpainting.utils import check_and_warn_input_range
8
-
9
-
10
- class DirectoryVisualizer(BaseVisualizer):
11
- DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ')
12
-
13
- def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10,
14
- last_without_mask=True, rescale_keys=None):
15
- self.outdir = outdir
16
- os.makedirs(self.outdir, exist_ok=True)
17
- self.key_order = key_order
18
- self.max_items_in_batch = max_items_in_batch
19
- self.last_without_mask = last_without_mask
20
- self.rescale_keys = rescale_keys
21
-
22
- def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
23
- check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image')
24
- vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch,
25
- last_without_mask=self.last_without_mask,
26
- rescale_keys=self.rescale_keys)
27
-
28
- vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
29
-
30
- curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}')
31
- os.makedirs(curoutdir, exist_ok=True)
32
- rank_suffix = f'_r{rank}' if rank is not None else ''
33
- out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg')
34
-
35
- vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
36
- cv2.imwrite(out_fname, vis_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/lama/saicinpainting/training/visualizers/noop.py DELETED
@@ -1,9 +0,0 @@
1
- from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer
2
-
3
-
4
- class NoopVisualizer(BaseVisualizer):
5
- def __init__(self, *args, **kwargs):
6
- pass
7
-
8
- def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
9
- pass
 
 
 
 
 
 
 
 
 
 
annotator/lama/saicinpainting/utils.py DELETED
@@ -1,174 +0,0 @@
1
- import bisect
2
- import functools
3
- import logging
4
- import numbers
5
- import os
6
- import signal
7
- import sys
8
- import traceback
9
- import warnings
10
-
11
- import torch
12
- from pytorch_lightning import seed_everything
13
-
14
- LOGGER = logging.getLogger(__name__)
15
-
16
-
17
- def check_and_warn_input_range(tensor, min_value, max_value, name):
18
- actual_min = tensor.min()
19
- actual_max = tensor.max()
20
- if actual_min < min_value or actual_max > max_value:
21
- warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}")
22
-
23
-
24
- def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
25
- for k, v in cur_dict.items():
26
- target_key = prefix + k
27
- target[target_key] = target.get(target_key, default) + v
28
-
29
-
30
- def average_dicts(dict_list):
31
- result = {}
32
- norm = 1e-3
33
- for dct in dict_list:
34
- sum_dict_with_prefix(result, dct, '')
35
- norm += 1
36
- for k in list(result):
37
- result[k] /= norm
38
- return result
39
-
40
-
41
- def add_prefix_to_keys(dct, prefix):
42
- return {prefix + k: v for k, v in dct.items()}
43
-
44
-
45
- def set_requires_grad(module, value):
46
- for param in module.parameters():
47
- param.requires_grad = value
48
-
49
-
50
- def flatten_dict(dct):
51
- result = {}
52
- for k, v in dct.items():
53
- if isinstance(k, tuple):
54
- k = '_'.join(k)
55
- if isinstance(v, dict):
56
- for sub_k, sub_v in flatten_dict(v).items():
57
- result[f'{k}_{sub_k}'] = sub_v
58
- else:
59
- result[k] = v
60
- return result
61
-
62
-
63
- class LinearRamp:
64
- def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
65
- self.start_value = start_value
66
- self.end_value = end_value
67
- self.start_iter = start_iter
68
- self.end_iter = end_iter
69
-
70
- def __call__(self, i):
71
- if i < self.start_iter:
72
- return self.start_value
73
- if i >= self.end_iter:
74
- return self.end_value
75
- part = (i - self.start_iter) / (self.end_iter - self.start_iter)
76
- return self.start_value * (1 - part) + self.end_value * part
77
-
78
-
79
- class LadderRamp:
80
- def __init__(self, start_iters, values):
81
- self.start_iters = start_iters
82
- self.values = values
83
- assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))
84
-
85
- def __call__(self, i):
86
- segment_i = bisect.bisect_right(self.start_iters, i)
87
- return self.values[segment_i]
88
-
89
-
90
- def get_ramp(kind='ladder', **kwargs):
91
- if kind == 'linear':
92
- return LinearRamp(**kwargs)
93
- if kind == 'ladder':
94
- return LadderRamp(**kwargs)
95
- raise ValueError(f'Unexpected ramp kind: {kind}')
96
-
97
-
98
- def print_traceback_handler(sig, frame):
99
- LOGGER.warning(f'Received signal {sig}')
100
- bt = ''.join(traceback.format_stack())
101
- LOGGER.warning(f'Requested stack trace:\n{bt}')
102
-
103
-
104
- def register_debug_signal_handlers(sig=None, handler=print_traceback_handler):
105
- LOGGER.warning(f'Setting signal {sig} handler {handler}')
106
- signal.signal(sig, handler)
107
-
108
-
109
- def handle_deterministic_config(config):
110
- seed = dict(config).get('seed', None)
111
- if seed is None:
112
- return False
113
-
114
- seed_everything(seed)
115
- return True
116
-
117
-
118
- def get_shape(t):
119
- if torch.is_tensor(t):
120
- return tuple(t.shape)
121
- elif isinstance(t, dict):
122
- return {n: get_shape(q) for n, q in t.items()}
123
- elif isinstance(t, (list, tuple)):
124
- return [get_shape(q) for q in t]
125
- elif isinstance(t, numbers.Number):
126
- return type(t)
127
- else:
128
- raise ValueError('unexpected type {}'.format(type(t)))
129
-
130
-
131
- def get_has_ddp_rank():
132
- master_port = os.environ.get('MASTER_PORT', None)
133
- node_rank = os.environ.get('NODE_RANK', None)
134
- local_rank = os.environ.get('LOCAL_RANK', None)
135
- world_size = os.environ.get('WORLD_SIZE', None)
136
- has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None
137
- return has_rank
138
-
139
-
140
- def handle_ddp_subprocess():
141
- def main_decorator(main_func):
142
- @functools.wraps(main_func)
143
- def new_main(*args, **kwargs):
144
- # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
145
- parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
146
- has_parent = parent_cwd is not None
147
- has_rank = get_has_ddp_rank()
148
- assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
149
-
150
- if has_parent:
151
- # we are in the worker
152
- sys.argv.extend([
153
- f'hydra.run.dir={parent_cwd}',
154
- # 'hydra/hydra_logging=disabled',
155
- # 'hydra/job_logging=disabled'
156
- ])
157
- # do nothing if this is a top-level process
158
- # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization
159
-
160
- main_func(*args, **kwargs)
161
- return new_main
162
- return main_decorator
163
-
164
-
165
- def handle_ddp_parent_process():
166
- parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
167
- has_parent = parent_cwd is not None
168
- has_rank = get_has_ddp_rank()
169
- assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
170
-
171
- if parent_cwd is None:
172
- os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()
173
-
174
- return has_parent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/leres/__init__.py DELETED
@@ -1,113 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import torch
4
- import os
5
- from modules import devices, shared
6
- from annotator.annotator_path import models_path
7
- from torchvision.transforms import transforms
8
-
9
- # AdelaiDepth/LeReS imports
10
- from .leres.depthmap import estimateleres, estimateboost
11
- from .leres.multi_depth_model_woauxi import RelDepthModel
12
- from .leres.net_tools import strip_prefix_if_present
13
-
14
- # pix2pix/merge net imports
15
- from .pix2pix.options.test_options import TestOptions
16
- from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
17
-
18
- base_model_path = os.path.join(models_path, "leres")
19
- old_modeldir = os.path.dirname(os.path.realpath(__file__))
20
-
21
- remote_model_path_leres = "https://huggingface.co/lllyasviel/Annotators/resolve/main/res101.pth"
22
- remote_model_path_pix2pix = "https://huggingface.co/lllyasviel/Annotators/resolve/main/latest_net_G.pth"
23
-
24
- model = None
25
- pix2pixmodel = None
26
-
27
- def unload_leres_model():
28
- global model, pix2pixmodel
29
- if model is not None:
30
- model = model.cpu()
31
- if pix2pixmodel is not None:
32
- pix2pixmodel = pix2pixmodel.unload_network('G')
33
-
34
-
35
- def apply_leres(input_image, thr_a, thr_b, boost=False):
36
- global model, pix2pixmodel
37
- if model is None:
38
- model_path = os.path.join(base_model_path, "res101.pth")
39
- old_model_path = os.path.join(old_modeldir, "res101.pth")
40
-
41
- if os.path.exists(old_model_path):
42
- model_path = old_model_path
43
- elif not os.path.exists(model_path):
44
- from basicsr.utils.download_util import load_file_from_url
45
- load_file_from_url(remote_model_path_leres, model_dir=base_model_path)
46
-
47
- if torch.cuda.is_available():
48
- checkpoint = torch.load(model_path)
49
- else:
50
- checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
51
-
52
- model = RelDepthModel(backbone='resnext101')
53
- model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
54
- del checkpoint
55
-
56
- if boost and pix2pixmodel is None:
57
- pix2pixmodel_path = os.path.join(base_model_path, "latest_net_G.pth")
58
- if not os.path.exists(pix2pixmodel_path):
59
- from basicsr.utils.download_util import load_file_from_url
60
- load_file_from_url(remote_model_path_pix2pix, model_dir=base_model_path)
61
-
62
- opt = TestOptions().parse()
63
- if not torch.cuda.is_available():
64
- opt.gpu_ids = [] # cpu mode
65
- pix2pixmodel = Pix2Pix4DepthModel(opt)
66
- pix2pixmodel.save_dir = base_model_path
67
- pix2pixmodel.load_networks('latest')
68
- pix2pixmodel.eval()
69
-
70
- if devices.get_device_for("controlnet").type != 'mps':
71
- model = model.to(devices.get_device_for("controlnet"))
72
-
73
- assert input_image.ndim == 3
74
- height, width, dim = input_image.shape
75
-
76
- with torch.no_grad():
77
-
78
- if boost:
79
- depth = estimateboost(input_image, model, 0, pix2pixmodel, max(width, height))
80
- else:
81
- depth = estimateleres(input_image, model, width, height)
82
-
83
- numbytes=2
84
- depth_min = depth.min()
85
- depth_max = depth.max()
86
- max_val = (2**(8*numbytes))-1
87
-
88
- # check output before normalizing and mapping to 16 bit
89
- if depth_max - depth_min > np.finfo("float").eps:
90
- out = max_val * (depth - depth_min) / (depth_max - depth_min)
91
- else:
92
- out = np.zeros(depth.shape)
93
-
94
- # single channel, 16 bit image
95
- depth_image = out.astype("uint16")
96
-
97
- # convert to uint8
98
- depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0))
99
-
100
- # remove near
101
- if thr_a != 0:
102
- thr_a = ((thr_a/100)*255)
103
- depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]
104
-
105
- # invert image
106
- depth_image = cv2.bitwise_not(depth_image)
107
-
108
- # remove bg
109
- if thr_b != 0:
110
- thr_b = ((thr_b/100)*255)
111
- depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]
112
-
113
- return depth_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/leres/leres/LICENSE DELETED
@@ -1,23 +0,0 @@
1
- https://github.com/thygate/stable-diffusion-webui-depthmap-script
2
-
3
- MIT License
4
-
5
- Copyright (c) 2023 Bob Thiry
6
-
7
- Permission is hereby granted, free of charge, to any person obtaining a copy
8
- of this software and associated documentation files (the "Software"), to deal
9
- in the Software without restriction, including without limitation the rights
10
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
- copies of the Software, and to permit persons to whom the Software is
12
- furnished to do so, subject to the following conditions:
13
-
14
- The above copyright notice and this permission notice shall be included in all
15
- copies or substantial portions of the Software.
16
-
17
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
- SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/leres/leres/Resnet.py DELETED
@@ -1,199 +0,0 @@
1
- import torch.nn as nn
2
- import torch.nn as NN
3
-
4
- __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
5
- 'resnet152']
6
-
7
-
8
- model_urls = {
9
- 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10
- 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11
- 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12
- 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13
- 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14
- }
15
-
16
-
17
- def conv3x3(in_planes, out_planes, stride=1):
18
- """3x3 convolution with padding"""
19
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20
- padding=1, bias=False)
21
-
22
-
23
- class BasicBlock(nn.Module):
24
- expansion = 1
25
-
26
- def __init__(self, inplanes, planes, stride=1, downsample=None):
27
- super(BasicBlock, self).__init__()
28
- self.conv1 = conv3x3(inplanes, planes, stride)
29
- self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
30
- self.relu = nn.ReLU(inplace=True)
31
- self.conv2 = conv3x3(planes, planes)
32
- self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
33
- self.downsample = downsample
34
- self.stride = stride
35
-
36
- def forward(self, x):
37
- residual = x
38
-
39
- out = self.conv1(x)
40
- out = self.bn1(out)
41
- out = self.relu(out)
42
-
43
- out = self.conv2(out)
44
- out = self.bn2(out)
45
-
46
- if self.downsample is not None:
47
- residual = self.downsample(x)
48
-
49
- out += residual
50
- out = self.relu(out)
51
-
52
- return out
53
-
54
-
55
- class Bottleneck(nn.Module):
56
- expansion = 4
57
-
58
- def __init__(self, inplanes, planes, stride=1, downsample=None):
59
- super(Bottleneck, self).__init__()
60
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61
- self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
62
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63
- padding=1, bias=False)
64
- self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
65
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
66
- self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d
67
- self.relu = nn.ReLU(inplace=True)
68
- self.downsample = downsample
69
- self.stride = stride
70
-
71
- def forward(self, x):
72
- residual = x
73
-
74
- out = self.conv1(x)
75
- out = self.bn1(out)
76
- out = self.relu(out)
77
-
78
- out = self.conv2(out)
79
- out = self.bn2(out)
80
- out = self.relu(out)
81
-
82
- out = self.conv3(out)
83
- out = self.bn3(out)
84
-
85
- if self.downsample is not None:
86
- residual = self.downsample(x)
87
-
88
- out += residual
89
- out = self.relu(out)
90
-
91
- return out
92
-
93
-
94
- class ResNet(nn.Module):
95
-
96
- def __init__(self, block, layers, num_classes=1000):
97
- self.inplanes = 64
98
- super(ResNet, self).__init__()
99
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
100
- bias=False)
101
- self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d
102
- self.relu = nn.ReLU(inplace=True)
103
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
104
- self.layer1 = self._make_layer(block, 64, layers[0])
105
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
106
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
107
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
108
- #self.avgpool = nn.AvgPool2d(7, stride=1)
109
- #self.fc = nn.Linear(512 * block.expansion, num_classes)
110
-
111
- for m in self.modules():
112
- if isinstance(m, nn.Conv2d):
113
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
114
- elif isinstance(m, nn.BatchNorm2d):
115
- nn.init.constant_(m.weight, 1)
116
- nn.init.constant_(m.bias, 0)
117
-
118
- def _make_layer(self, block, planes, blocks, stride=1):
119
- downsample = None
120
- if stride != 1 or self.inplanes != planes * block.expansion:
121
- downsample = nn.Sequential(
122
- nn.Conv2d(self.inplanes, planes * block.expansion,
123
- kernel_size=1, stride=stride, bias=False),
124
- NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d
125
- )
126
-
127
- layers = []
128
- layers.append(block(self.inplanes, planes, stride, downsample))
129
- self.inplanes = planes * block.expansion
130
- for i in range(1, blocks):
131
- layers.append(block(self.inplanes, planes))
132
-
133
- return nn.Sequential(*layers)
134
-
135
- def forward(self, x):
136
- features = []
137
-
138
- x = self.conv1(x)
139
- x = self.bn1(x)
140
- x = self.relu(x)
141
- x = self.maxpool(x)
142
-
143
- x = self.layer1(x)
144
- features.append(x)
145
- x = self.layer2(x)
146
- features.append(x)
147
- x = self.layer3(x)
148
- features.append(x)
149
- x = self.layer4(x)
150
- features.append(x)
151
-
152
- return features
153
-
154
-
155
- def resnet18(pretrained=True, **kwargs):
156
- """Constructs a ResNet-18 model.
157
- Args:
158
- pretrained (bool): If True, returns a model pre-trained on ImageNet
159
- """
160
- model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
161
- return model
162
-
163
-
164
- def resnet34(pretrained=True, **kwargs):
165
- """Constructs a ResNet-34 model.
166
- Args:
167
- pretrained (bool): If True, returns a model pre-trained on ImageNet
168
- """
169
- model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
170
- return model
171
-
172
-
173
- def resnet50(pretrained=True, **kwargs):
174
- """Constructs a ResNet-50 model.
175
- Args:
176
- pretrained (bool): If True, returns a model pre-trained on ImageNet
177
- """
178
- model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
179
-
180
- return model
181
-
182
-
183
- def resnet101(pretrained=True, **kwargs):
184
- """Constructs a ResNet-101 model.
185
- Args:
186
- pretrained (bool): If True, returns a model pre-trained on ImageNet
187
- """
188
- model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
189
-
190
- return model
191
-
192
-
193
- def resnet152(pretrained=True, **kwargs):
194
- """Constructs a ResNet-152 model.
195
- Args:
196
- pretrained (bool): If True, returns a model pre-trained on ImageNet
197
- """
198
- model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
199
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/leres/leres/Resnext_torch.py DELETED
@@ -1,237 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
- import torch.nn as nn
4
-
5
- try:
6
- from urllib import urlretrieve
7
- except ImportError:
8
- from urllib.request import urlretrieve
9
-
10
- __all__ = ['resnext101_32x8d']
11
-
12
-
13
- model_urls = {
14
- 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
15
- 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
16
- }
17
-
18
-
19
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
20
- """3x3 convolution with padding"""
21
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22
- padding=dilation, groups=groups, bias=False, dilation=dilation)
23
-
24
-
25
- def conv1x1(in_planes, out_planes, stride=1):
26
- """1x1 convolution"""
27
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
28
-
29
-
30
- class BasicBlock(nn.Module):
31
- expansion = 1
32
-
33
- def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
34
- base_width=64, dilation=1, norm_layer=None):
35
- super(BasicBlock, self).__init__()
36
- if norm_layer is None:
37
- norm_layer = nn.BatchNorm2d
38
- if groups != 1 or base_width != 64:
39
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
40
- if dilation > 1:
41
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
42
- # Both self.conv1 and self.downsample layers downsample the input when stride != 1
43
- self.conv1 = conv3x3(inplanes, planes, stride)
44
- self.bn1 = norm_layer(planes)
45
- self.relu = nn.ReLU(inplace=True)
46
- self.conv2 = conv3x3(planes, planes)
47
- self.bn2 = norm_layer(planes)
48
- self.downsample = downsample
49
- self.stride = stride
50
-
51
- def forward(self, x):
52
- identity = x
53
-
54
- out = self.conv1(x)
55
- out = self.bn1(out)
56
- out = self.relu(out)
57
-
58
- out = self.conv2(out)
59
- out = self.bn2(out)
60
-
61
- if self.downsample is not None:
62
- identity = self.downsample(x)
63
-
64
- out += identity
65
- out = self.relu(out)
66
-
67
- return out
68
-
69
-
70
- class Bottleneck(nn.Module):
71
- # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
72
- # while original implementation places the stride at the first 1x1 convolution(self.conv1)
73
- # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
74
- # This variant is also known as ResNet V1.5 and improves accuracy according to
75
- # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
76
-
77
- expansion = 4
78
-
79
- def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
80
- base_width=64, dilation=1, norm_layer=None):
81
- super(Bottleneck, self).__init__()
82
- if norm_layer is None:
83
- norm_layer = nn.BatchNorm2d
84
- width = int(planes * (base_width / 64.)) * groups
85
- # Both self.conv2 and self.downsample layers downsample the input when stride != 1
86
- self.conv1 = conv1x1(inplanes, width)
87
- self.bn1 = norm_layer(width)
88
- self.conv2 = conv3x3(width, width, stride, groups, dilation)
89
- self.bn2 = norm_layer(width)
90
- self.conv3 = conv1x1(width, planes * self.expansion)
91
- self.bn3 = norm_layer(planes * self.expansion)
92
- self.relu = nn.ReLU(inplace=True)
93
- self.downsample = downsample
94
- self.stride = stride
95
-
96
- def forward(self, x):
97
- identity = x
98
-
99
- out = self.conv1(x)
100
- out = self.bn1(out)
101
- out = self.relu(out)
102
-
103
- out = self.conv2(out)
104
- out = self.bn2(out)
105
- out = self.relu(out)
106
-
107
- out = self.conv3(out)
108
- out = self.bn3(out)
109
-
110
- if self.downsample is not None:
111
- identity = self.downsample(x)
112
-
113
- out += identity
114
- out = self.relu(out)
115
-
116
- return out
117
-
118
-
119
- class ResNet(nn.Module):
120
-
121
- def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
122
- groups=1, width_per_group=64, replace_stride_with_dilation=None,
123
- norm_layer=None):
124
- super(ResNet, self).__init__()
125
- if norm_layer is None:
126
- norm_layer = nn.BatchNorm2d
127
- self._norm_layer = norm_layer
128
-
129
- self.inplanes = 64
130
- self.dilation = 1
131
- if replace_stride_with_dilation is None:
132
- # each element in the tuple indicates if we should replace
133
- # the 2x2 stride with a dilated convolution instead
134
- replace_stride_with_dilation = [False, False, False]
135
- if len(replace_stride_with_dilation) != 3:
136
- raise ValueError("replace_stride_with_dilation should be None "
137
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
138
- self.groups = groups
139
- self.base_width = width_per_group
140
- self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
141
- bias=False)
142
- self.bn1 = norm_layer(self.inplanes)
143
- self.relu = nn.ReLU(inplace=True)
144
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
145
- self.layer1 = self._make_layer(block, 64, layers[0])
146
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
147
- dilate=replace_stride_with_dilation[0])
148
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
149
- dilate=replace_stride_with_dilation[1])
150
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
151
- dilate=replace_stride_with_dilation[2])
152
- #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
153
- #self.fc = nn.Linear(512 * block.expansion, num_classes)
154
-
155
- for m in self.modules():
156
- if isinstance(m, nn.Conv2d):
157
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
158
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
159
- nn.init.constant_(m.weight, 1)
160
- nn.init.constant_(m.bias, 0)
161
-
162
- # Zero-initialize the last BN in each residual branch,
163
- # so that the residual branch starts with zeros, and each residual block behaves like an identity.
164
- # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
165
- if zero_init_residual:
166
- for m in self.modules():
167
- if isinstance(m, Bottleneck):
168
- nn.init.constant_(m.bn3.weight, 0)
169
- elif isinstance(m, BasicBlock):
170
- nn.init.constant_(m.bn2.weight, 0)
171
-
172
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
173
- norm_layer = self._norm_layer
174
- downsample = None
175
- previous_dilation = self.dilation
176
- if dilate:
177
- self.dilation *= stride
178
- stride = 1
179
- if stride != 1 or self.inplanes != planes * block.expansion:
180
- downsample = nn.Sequential(
181
- conv1x1(self.inplanes, planes * block.expansion, stride),
182
- norm_layer(planes * block.expansion),
183
- )
184
-
185
- layers = []
186
- layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
187
- self.base_width, previous_dilation, norm_layer))
188
- self.inplanes = planes * block.expansion
189
- for _ in range(1, blocks):
190
- layers.append(block(self.inplanes, planes, groups=self.groups,
191
- base_width=self.base_width, dilation=self.dilation,
192
- norm_layer=norm_layer))
193
-
194
- return nn.Sequential(*layers)
195
-
196
- def _forward_impl(self, x):
197
- # See note [TorchScript super()]
198
- features = []
199
- x = self.conv1(x)
200
- x = self.bn1(x)
201
- x = self.relu(x)
202
- x = self.maxpool(x)
203
-
204
- x = self.layer1(x)
205
- features.append(x)
206
-
207
- x = self.layer2(x)
208
- features.append(x)
209
-
210
- x = self.layer3(x)
211
- features.append(x)
212
-
213
- x = self.layer4(x)
214
- features.append(x)
215
-
216
- #x = self.avgpool(x)
217
- #x = torch.flatten(x, 1)
218
- #x = self.fc(x)
219
-
220
- return features
221
-
222
- def forward(self, x):
223
- return self._forward_impl(x)
224
-
225
-
226
-
227
- def resnext101_32x8d(pretrained=True, **kwargs):
228
- """Constructs a ResNet-152 model.
229
- Args:
230
- pretrained (bool): If True, returns a model pre-trained on ImageNet
231
- """
232
- kwargs['groups'] = 32
233
- kwargs['width_per_group'] = 8
234
-
235
- model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
236
- return model
237
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/leres/leres/depthmap.py DELETED
@@ -1,546 +0,0 @@
1
- # Author: thygate
2
- # https://github.com/thygate/stable-diffusion-webui-depthmap-script
3
-
4
- from modules import devices
5
- from modules.shared import opts
6
- from torchvision.transforms import transforms
7
- from operator import getitem
8
-
9
- import torch, gc
10
- import cv2
11
- import numpy as np
12
- import skimage.measure
13
-
14
- whole_size_threshold = 1600 # R_max from the paper
15
- pix2pixsize = 1024
16
-
17
- def scale_torch(img):
18
- """
19
- Scale the image and output it in torch.tensor.
20
- :param img: input rgb is in shape [H, W, C], input depth/disp is in shape [H, W]
21
- :param scale: the scale factor. float
22
- :return: img. [C, H, W]
23
- """
24
- if len(img.shape) == 2:
25
- img = img[np.newaxis, :, :]
26
- if img.shape[2] == 3:
27
- transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225) )])
28
- img = transform(img.astype(np.float32))
29
- else:
30
- img = img.astype(np.float32)
31
- img = torch.from_numpy(img)
32
- return img
33
-
34
- def estimateleres(img, model, w, h):
35
- # leres transform input
36
- rgb_c = img[:, :, ::-1].copy()
37
- A_resize = cv2.resize(rgb_c, (w, h))
38
- img_torch = scale_torch(A_resize)[None, :, :, :]
39
-
40
- # compute
41
- with torch.no_grad():
42
- img_torch = img_torch.to(devices.get_device_for("controlnet"))
43
- prediction = model.depth_model(img_torch)
44
-
45
- prediction = prediction.squeeze().cpu().numpy()
46
- prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
47
-
48
- return prediction
49
-
50
- def generatemask(size):
51
- # Generates a Guassian mask
52
- mask = np.zeros(size, dtype=np.float32)
53
- sigma = int(size[0]/16)
54
- k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1)
55
- mask[int(0.15*size[0]):size[0] - int(0.15*size[0]), int(0.15*size[1]): size[1] - int(0.15*size[1])] = 1
56
- mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
57
- mask = (mask - mask.min()) / (mask.max() - mask.min())
58
- mask = mask.astype(np.float32)
59
- return mask
60
-
61
- def resizewithpool(img, size):
62
- i_size = img.shape[0]
63
- n = int(np.floor(i_size/size))
64
-
65
- out = skimage.measure.block_reduce(img, (n, n), np.max)
66
- return out
67
-
68
- def rgb2gray(rgb):
69
- # Converts rgb to gray
70
- return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])
71
-
72
- def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000):
73
- # Returns the R_x resolution described in section 5 of the main paper.
74
-
75
- # Parameters:
76
- # img :input rgb image
77
- # basesize : size the dilation kernel which is equal to receptive field of the network.
78
- # confidence: value of x in R_x; allowed percentage of pixels that are not getting any contextual cue.
79
- # scale_threshold: maximum allowed upscaling on the input image ; it has been set to 3.
80
- # whole_size_threshold: maximum allowed resolution. (R_max from section 6 of the main paper)
81
-
82
- # Returns:
83
- # outputsize_scale*speed_scale :The computed R_x resolution
84
- # patch_scale: K parameter from section 6 of the paper
85
-
86
- # speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search
87
- speed_scale = 32
88
- image_dim = int(min(img.shape[0:2]))
89
-
90
- gray = rgb2gray(img)
91
- grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3))
92
- grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA)
93
-
94
- # thresholding the gradient map to generate the edge-map as a proxy of the contextual cues
95
- m = grad.min()
96
- M = grad.max()
97
- middle = m + (0.4 * (M - m))
98
- grad[grad < middle] = 0
99
- grad[grad >= middle] = 1
100
-
101
- # dilation kernel with size of the receptive field
102
- kernel = np.ones((int(basesize/speed_scale), int(basesize/speed_scale)), float)
103
- # dilation kernel with size of the a quarter of receptive field used to compute k
104
- # as described in section 6 of main paper
105
- kernel2 = np.ones((int(basesize / (4*speed_scale)), int(basesize / (4*speed_scale))), float)
106
-
107
- # Output resolution limit set by the whole_size_threshold and scale_threshold.
108
- threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2]))
109
-
110
- outputsize_scale = basesize / speed_scale
111
- for p_size in range(int(basesize/speed_scale), int(threshold/speed_scale), int(basesize / (2*speed_scale))):
112
- grad_resized = resizewithpool(grad, p_size)
113
- grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST)
114
- grad_resized[grad_resized >= 0.5] = 1
115
- grad_resized[grad_resized < 0.5] = 0
116
-
117
- dilated = cv2.dilate(grad_resized, kernel, iterations=1)
118
- meanvalue = (1-dilated).mean()
119
- if meanvalue > confidence:
120
- break
121
- else:
122
- outputsize_scale = p_size
123
-
124
- grad_region = cv2.dilate(grad_resized, kernel2, iterations=1)
125
- patch_scale = grad_region.mean()
126
-
127
- return int(outputsize_scale*speed_scale), patch_scale
128
-
129
- # Generate a double-input depth estimation
130
- def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel):
131
- # Generate the low resolution estimation
132
- estimate1 = singleestimate(img, size1, model, net_type)
133
- # Resize to the inference size of merge network.
134
- estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
135
-
136
- # Generate the high resolution estimation
137
- estimate2 = singleestimate(img, size2, model, net_type)
138
- # Resize to the inference size of merge network.
139
- estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
140
-
141
- # Inference on the merge model
142
- pix2pixmodel.set_input(estimate1, estimate2)
143
- pix2pixmodel.test()
144
- visuals = pix2pixmodel.get_current_visuals()
145
- prediction_mapped = visuals['fake_B']
146
- prediction_mapped = (prediction_mapped+1)/2
147
- prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / (
148
- torch.max(prediction_mapped) - torch.min(prediction_mapped))
149
- prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
150
-
151
- return prediction_mapped
152
-
153
- # Generate a single-input depth estimation
154
- def singleestimate(img, msize, model, net_type):
155
- # if net_type == 0:
156
- return estimateleres(img, model, msize, msize)
157
- # else:
158
- # return estimatemidasBoost(img, model, msize, msize)
159
-
160
- def applyGridpatch(blsize, stride, img, box):
161
- # Extract a simple grid patch.
162
- counter1 = 0
163
- patch_bound_list = {}
164
- for k in range(blsize, img.shape[1] - blsize, stride):
165
- for j in range(blsize, img.shape[0] - blsize, stride):
166
- patch_bound_list[str(counter1)] = {}
167
- patchbounds = [j - blsize, k - blsize, j - blsize + 2 * blsize, k - blsize + 2 * blsize]
168
- patch_bound = [box[0] + patchbounds[1], box[1] + patchbounds[0], patchbounds[3] - patchbounds[1],
169
- patchbounds[2] - patchbounds[0]]
170
- patch_bound_list[str(counter1)]['rect'] = patch_bound
171
- patch_bound_list[str(counter1)]['size'] = patch_bound[2]
172
- counter1 = counter1 + 1
173
- return patch_bound_list
174
-
175
- # Generating local patches to perform the local refinement described in section 6 of the main paper.
176
- def generatepatchs(img, base_size):
177
-
178
- # Compute the gradients as a proxy of the contextual cues.
179
- img_gray = rgb2gray(img)
180
- whole_grad = np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)) +\
181
- np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3))
182
-
183
- threshold = whole_grad[whole_grad > 0].mean()
184
- whole_grad[whole_grad < threshold] = 0
185
-
186
- # We use the integral image to speed-up the evaluation of the amount of gradients for each patch.
187
- gf = whole_grad.sum()/len(whole_grad.reshape(-1))
188
- grad_integral_image = cv2.integral(whole_grad)
189
-
190
- # Variables are selected such that the initial patch size would be the receptive field size
191
- # and the stride is set to 1/3 of the receptive field size.
192
- blsize = int(round(base_size/2))
193
- stride = int(round(blsize*0.75))
194
-
195
- # Get initial Grid
196
- patch_bound_list = applyGridpatch(blsize, stride, img, [0, 0, 0, 0])
197
-
198
- # Refine initial Grid of patches by discarding the flat (in terms of gradients of the rgb image) ones. Refine
199
- # each patch size to ensure that there will be enough depth cues for the network to generate a consistent depth map.
200
- print("Selecting patches ...")
201
- patch_bound_list = adaptiveselection(grad_integral_image, patch_bound_list, gf)
202
-
203
- # Sort the patch list to make sure the merging operation will be done with the correct order: starting from biggest
204
- # patch
205
- patchset = sorted(patch_bound_list.items(), key=lambda x: getitem(x[1], 'size'), reverse=True)
206
- return patchset
207
-
208
- def getGF_fromintegral(integralimage, rect):
209
- # Computes the gradient density of a given patch from the gradient integral image.
210
- x1 = rect[1]
211
- x2 = rect[1]+rect[3]
212
- y1 = rect[0]
213
- y2 = rect[0]+rect[2]
214
- value = integralimage[x2, y2]-integralimage[x1, y2]-integralimage[x2, y1]+integralimage[x1, y1]
215
- return value
216
-
217
- # Adaptively select patches
218
- def adaptiveselection(integral_grad, patch_bound_list, gf):
219
- patchlist = {}
220
- count = 0
221
- height, width = integral_grad.shape
222
-
223
- search_step = int(32/factor)
224
-
225
- # Go through all patches
226
- for c in range(len(patch_bound_list)):
227
- # Get patch
228
- bbox = patch_bound_list[str(c)]['rect']
229
-
230
- # Compute the amount of gradients present in the patch from the integral image.
231
- cgf = getGF_fromintegral(integral_grad, bbox)/(bbox[2]*bbox[3])
232
-
233
- # Check if patching is beneficial by comparing the gradient density of the patch to
234
- # the gradient density of the whole image
235
- if cgf >= gf:
236
- bbox_test = bbox.copy()
237
- patchlist[str(count)] = {}
238
-
239
- # Enlarge each patch until the gradient density of the patch is equal
240
- # to the whole image gradient density
241
- while True:
242
-
243
- bbox_test[0] = bbox_test[0] - int(search_step/2)
244
- bbox_test[1] = bbox_test[1] - int(search_step/2)
245
-
246
- bbox_test[2] = bbox_test[2] + search_step
247
- bbox_test[3] = bbox_test[3] + search_step
248
-
249
- # Check if we are still within the image
250
- if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \
251
- or bbox_test[0] + bbox_test[2] >= width:
252
- break
253
-
254
- # Compare gradient density
255
- cgf = getGF_fromintegral(integral_grad, bbox_test)/(bbox_test[2]*bbox_test[3])
256
- if cgf < gf:
257
- break
258
- bbox = bbox_test.copy()
259
-
260
- # Add patch to selected patches
261
- patchlist[str(count)]['rect'] = bbox
262
- patchlist[str(count)]['size'] = bbox[2]
263
- count = count + 1
264
-
265
- # Return selected patches
266
- return patchlist
267
-
268
- def impatch(image, rect):
269
- # Extract the given patch pixels from a given image.
270
- w1 = rect[0]
271
- h1 = rect[1]
272
- w2 = w1 + rect[2]
273
- h2 = h1 + rect[3]
274
- image_patch = image[h1:h2, w1:w2]
275
- return image_patch
276
-
277
- class ImageandPatchs:
278
- def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
279
- self.root_dir = root_dir
280
- self.patchsinfo = patchsinfo
281
- self.name = name
282
- self.patchs = patchsinfo
283
- self.scale = scale
284
-
285
- self.rgb_image = cv2.resize(rgb_image, (round(rgb_image.shape[1]*scale), round(rgb_image.shape[0]*scale)),
286
- interpolation=cv2.INTER_CUBIC)
287
-
288
- self.do_have_estimate = False
289
- self.estimation_updated_image = None
290
- self.estimation_base_image = None
291
-
292
- def __len__(self):
293
- return len(self.patchs)
294
-
295
- def set_base_estimate(self, est):
296
- self.estimation_base_image = est
297
- if self.estimation_updated_image is not None:
298
- self.do_have_estimate = True
299
-
300
- def set_updated_estimate(self, est):
301
- self.estimation_updated_image = est
302
- if self.estimation_base_image is not None:
303
- self.do_have_estimate = True
304
-
305
- def __getitem__(self, index):
306
- patch_id = int(self.patchs[index][0])
307
- rect = np.array(self.patchs[index][1]['rect'])
308
- msize = self.patchs[index][1]['size']
309
-
310
- ## applying scale to rect:
311
- rect = np.round(rect * self.scale)
312
- rect = rect.astype('int')
313
- msize = round(msize * self.scale)
314
-
315
- patch_rgb = impatch(self.rgb_image, rect)
316
- if self.do_have_estimate:
317
- patch_whole_estimate_base = impatch(self.estimation_base_image, rect)
318
- patch_whole_estimate_updated = impatch(self.estimation_updated_image, rect)
319
- return {'patch_rgb': patch_rgb, 'patch_whole_estimate_base': patch_whole_estimate_base,
320
- 'patch_whole_estimate_updated': patch_whole_estimate_updated, 'rect': rect,
321
- 'size': msize, 'id': patch_id}
322
- else:
323
- return {'patch_rgb': patch_rgb, 'rect': rect, 'size': msize, 'id': patch_id}
324
-
325
- def print_options(self, opt):
326
- """Print and save options
327
-
328
- It will print both current options and default values(if different).
329
- It will save options into a text file / [checkpoints_dir] / opt.txt
330
- """
331
- message = ''
332
- message += '----------------- Options ---------------\n'
333
- for k, v in sorted(vars(opt).items()):
334
- comment = ''
335
- default = self.parser.get_default(k)
336
- if v != default:
337
- comment = '\t[default: %s]' % str(default)
338
- message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
339
- message += '----------------- End -------------------'
340
- print(message)
341
-
342
- # save to the disk
343
- """
344
- expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
345
- util.mkdirs(expr_dir)
346
- file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
347
- with open(file_name, 'wt') as opt_file:
348
- opt_file.write(message)
349
- opt_file.write('\n')
350
- """
351
-
352
- def parse(self):
353
- """Parse our options, create checkpoints directory suffix, and set up gpu device."""
354
- opt = self.gather_options()
355
- opt.isTrain = self.isTrain # train or test
356
-
357
- # process opt.suffix
358
- if opt.suffix:
359
- suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
360
- opt.name = opt.name + suffix
361
-
362
- #self.print_options(opt)
363
-
364
- # set gpu ids
365
- str_ids = opt.gpu_ids.split(',')
366
- opt.gpu_ids = []
367
- for str_id in str_ids:
368
- id = int(str_id)
369
- if id >= 0:
370
- opt.gpu_ids.append(id)
371
- #if len(opt.gpu_ids) > 0:
372
- # torch.cuda.set_device(opt.gpu_ids[0])
373
-
374
- self.opt = opt
375
- return self.opt
376
-
377
-
378
- def estimateboost(img, model, model_type, pix2pixmodel, max_res=512):
379
- global whole_size_threshold
380
-
381
- # get settings
382
- if hasattr(opts, 'depthmap_script_boost_rmax'):
383
- whole_size_threshold = opts.depthmap_script_boost_rmax
384
-
385
- if model_type == 0: #leres
386
- net_receptive_field_size = 448
387
- patch_netsize = 2 * net_receptive_field_size
388
- elif model_type == 1: #dpt_beit_large_512
389
- net_receptive_field_size = 512
390
- patch_netsize = 2 * net_receptive_field_size
391
- else: #other midas
392
- net_receptive_field_size = 384
393
- patch_netsize = 2 * net_receptive_field_size
394
-
395
- gc.collect()
396
- devices.torch_gc()
397
-
398
- # Generate mask used to smoothly blend the local patch estimations to the base estimate.
399
- # It is arbitrarily large to avoid artifacts during rescaling for each crop.
400
- mask_org = generatemask((3000, 3000))
401
- mask = mask_org.copy()
402
-
403
- # Value x of R_x defined in the section 5 of the main paper.
404
- r_threshold_value = 0.2
405
- #if R0:
406
- # r_threshold_value = 0
407
-
408
- input_resolution = img.shape
409
- scale_threshold = 3 # Allows up-scaling with a scale up to 3
410
-
411
- # Find the best input resolution R-x. The resolution search is described in section 5 (double estimation) of the main paper and section B of the
412
- # supplementary material.
413
- whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value, scale_threshold, whole_size_threshold)
414
-
415
- # print('wholeImage being processed in :', whole_image_optimal_size)
416
-
417
- # Generate the base estimate using the double estimation.
418
- whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model, model_type, pix2pixmodel)
419
-
420
- # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select
421
- # small high-density regions of the image.
422
- global factor
423
- factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2)
424
- # print('Adjust factor is:', 1/factor)
425
-
426
- # Check if Local boosting is beneficial.
427
- if max_res < whole_image_optimal_size:
428
- # print("No Local boosting. Specified Max Res is smaller than R20, Returning doubleestimate result")
429
- return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
430
-
431
- # Compute the default target resolution.
432
- if img.shape[0] > img.shape[1]:
433
- a = 2 * whole_image_optimal_size
434
- b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0])
435
- else:
436
- a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1])
437
- b = 2 * whole_image_optimal_size
438
- b = int(round(b / factor))
439
- a = int(round(a / factor))
440
-
441
- """
442
- # recompute a, b and saturate to max res.
443
- if max(a,b) > max_res:
444
- print('Default Res is higher than max-res: Reducing final resolution')
445
- if img.shape[0] > img.shape[1]:
446
- a = max_res
447
- b = round(max_res * img.shape[1] / img.shape[0])
448
- else:
449
- a = round(max_res * img.shape[0] / img.shape[1])
450
- b = max_res
451
- b = int(b)
452
- a = int(a)
453
- """
454
-
455
- img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC)
456
-
457
- # Extract selected patches for local refinement
458
- base_size = net_receptive_field_size * 2
459
- patchset = generatepatchs(img, base_size)
460
-
461
- # print('Target resolution: ', img.shape)
462
-
463
- # Compute a scale in case the user requested to generate the results at the same resolution as the input.
464
- # Notice that our method output resolution is independent of the input resolution and this parameter will only
465
- # enable a scaling operation during the local patch merge implementation to generate results with the same resolution
466
- # as the input.
467
- """
468
- if output_resolution == 1:
469
- mergein_scale = input_resolution[0] / img.shape[0]
470
- print('Dynamically change merged-in resolution; scale:', mergein_scale)
471
- else:
472
- mergein_scale = 1
473
- """
474
- # always rescale to input res for now
475
- mergein_scale = input_resolution[0] / img.shape[0]
476
-
477
- imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale)
478
- whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale),
479
- round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC)
480
- imageandpatchs.set_base_estimate(whole_estimate_resized.copy())
481
- imageandpatchs.set_updated_estimate(whole_estimate_resized.copy())
482
-
483
- print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2])
484
- print('Patches to process: '+str(len(imageandpatchs)))
485
-
486
- # Enumerate through all patches, generate their estimations, and refine the base estimate.
487
- for patch_ind in range(len(imageandpatchs)):
488
-
489
- # Get patch information
490
- patch = imageandpatchs[patch_ind] # patch object
491
- patch_rgb = patch['patch_rgb'] # rgb patch
492
- patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base
493
- rect = patch['rect'] # patch size and location
494
- patch_id = patch['id'] # patch ID
495
- org_size = patch_whole_estimate_base.shape # the original size from the unscaled input
496
- print('\t Processing patch', patch_ind, '/', len(imageandpatchs)-1, '|', rect)
497
-
498
- # We apply double estimation for patches. The high resolution value is fixed to twice the receptive
499
- # field size of the network for patches to accelerate the process.
500
- patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model, model_type, pix2pixmodel)
501
- patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
502
- patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
503
-
504
- # Merging the patch estimation into the base estimate using our merge network:
505
- # We feed the patch estimation and the same region from the updated base estimate to the merge network
506
- # to generate the target estimate for the corresponding region.
507
- pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation)
508
-
509
- # Run merging network
510
- pix2pixmodel.test()
511
- visuals = pix2pixmodel.get_current_visuals()
512
-
513
- prediction_mapped = visuals['fake_B']
514
- prediction_mapped = (prediction_mapped+1)/2
515
- prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
516
-
517
- mapped = prediction_mapped
518
-
519
- # We use a simple linear polynomial to make sure the result of the merge network would match the values of
520
- # base estimate
521
- p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1)
522
- merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape)
523
-
524
- merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC)
525
-
526
- # Get patch size and location
527
- w1 = rect[0]
528
- h1 = rect[1]
529
- w2 = w1 + rect[2]
530
- h2 = h1 + rect[3]
531
-
532
- # To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size
533
- # and resize it to our needed size while merging the patches.
534
- if mask.shape != org_size:
535
- mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR)
536
-
537
- tobemergedto = imageandpatchs.estimation_updated_image
538
-
539
- # Update the whole estimation:
540
- # We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless
541
- # blending at the boundaries of the patch region.
542
- tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask)
543
- imageandpatchs.set_updated_estimate(tobemergedto)
544
-
545
- # output
546
- return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
 
annotator/leres/leres/multi_depth_model_woauxi.py DELETED
@@ -1,34 +0,0 @@
1
- from . import network_auxi as network
2
- from .net_tools import get_func
3
- import torch
4
- import torch.nn as nn
5
- from modules import devices
6
-
7
- class RelDepthModel(nn.Module):
8
- def __init__(self, backbone='resnet50'):
9
- super(RelDepthModel, self).__init__()
10
- if backbone == 'resnet50':
11
- encoder = 'resnet50_stride32'
12
- elif backbone == 'resnext101':
13
- encoder = 'resnext101_stride32x8d'
14
- self.depth_model = DepthModel(encoder)
15
-
16
- def inference(self, rgb):
17
- with torch.no_grad():
18
- input = rgb.to(self.depth_model.device)
19
- depth = self.depth_model(input)
20
- #pred_depth_out = depth - depth.min() + 0.01
21
- return depth #pred_depth_out
22
-
23
-
24
- class DepthModel(nn.Module):
25
- def __init__(self, encoder):
26
- super(DepthModel, self).__init__()
27
- backbone = network.__name__.split('.')[-1] + '.' + encoder
28
- self.encoder_modules = get_func(backbone)()
29
- self.decoder_modules = network.Decoder()
30
-
31
- def forward(self, x):
32
- lateral_out = self.encoder_modules(x)
33
- out_logit = self.decoder_modules(lateral_out)
34
- return out_logit
 
annotator/leres/leres/net_tools.py DELETED
@@ -1,54 +0,0 @@
1
- import importlib
2
- import torch
3
- import os
4
- from collections import OrderedDict
5
-
6
-
7
- def get_func(func_name):
8
- """Helper to return a function object by name. func_name must identify a
9
- function in this module or the path to a function relative to the base
10
- 'modeling' module.
11
- """
12
- if func_name == '':
13
- return None
14
- try:
15
- parts = func_name.split('.')
16
- # Refers to a function in this module
17
- if len(parts) == 1:
18
- return globals()[parts[0]]
19
- # Otherwise, assume we're referencing a module under modeling
20
- module_name = 'annotator.leres.leres.' + '.'.join(parts[:-1])
21
- module = importlib.import_module(module_name)
22
- return getattr(module, parts[-1])
23
- except Exception:
24
- print('Failed to find function: %s' % func_name)
25
- raise
26
-
27
- def load_ckpt(args, depth_model, shift_model, focal_model):
28
- """
29
- Load checkpoint.
30
- """
31
- if os.path.isfile(args.load_ckpt):
32
- print("loading checkpoint %s" % args.load_ckpt)
33
- checkpoint = torch.load(args.load_ckpt)
34
- if shift_model is not None:
35
- shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'),
36
- strict=True)
37
- if focal_model is not None:
38
- focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'),
39
- strict=True)
40
- depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."),
41
- strict=True)
42
- del checkpoint
43
- if torch.cuda.is_available():
44
- torch.cuda.empty_cache()
45
-
46
-
47
- def strip_prefix_if_present(state_dict, prefix):
48
- keys = sorted(state_dict.keys())
49
- if not all(key.startswith(prefix) for key in keys):
50
- return state_dict
51
- stripped_state_dict = OrderedDict()
52
- for key, value in state_dict.items():
53
- stripped_state_dict[key.replace(prefix, "")] = value
54
- return stripped_state_dict
 
annotator/leres/leres/network_auxi.py DELETED
@@ -1,417 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.init as init
4
-
5
- from . import Resnet, Resnext_torch
6
-
7
-
8
- def resnet50_stride32():
9
- return DepthNet(backbone='resnet', depth=50, upfactors=[2, 2, 2, 2])
10
-
11
- def resnext101_stride32x8d():
12
- return DepthNet(backbone='resnext101_32x8d', depth=101, upfactors=[2, 2, 2, 2])
13
-
14
-
15
- class Decoder(nn.Module):
16
- def __init__(self):
17
- super(Decoder, self).__init__()
18
- self.inchannels = [256, 512, 1024, 2048]
19
- self.midchannels = [256, 256, 256, 512]
20
- self.upfactors = [2,2,2,2]
21
- self.outchannels = 1
22
-
23
- self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3])
24
- self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True)
25
- self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True)
26
-
27
- self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels = self.midchannels[2], upfactor=self.upfactors[2])
28
- self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels = self.midchannels[1], upfactor=self.upfactors[1])
29
- self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels = self.midchannels[0], upfactor=self.upfactors[0])
30
-
31
- self.outconv = AO(inchannels=self.midchannels[0], outchannels=self.outchannels, upfactor=2)
32
- self._init_params()
33
-
34
- def _init_params(self):
35
- for m in self.modules():
36
- if isinstance(m, nn.Conv2d):
37
- init.normal_(m.weight, std=0.01)
38
- if m.bias is not None:
39
- init.constant_(m.bias, 0)
40
- elif isinstance(m, nn.ConvTranspose2d):
41
- init.normal_(m.weight, std=0.01)
42
- if m.bias is not None:
43
- init.constant_(m.bias, 0)
44
- elif isinstance(m, nn.BatchNorm2d): #NN.BatchNorm2d
45
- init.constant_(m.weight, 1)
46
- init.constant_(m.bias, 0)
47
- elif isinstance(m, nn.Linear):
48
- init.normal_(m.weight, std=0.01)
49
- if m.bias is not None:
50
- init.constant_(m.bias, 0)
51
-
52
- def forward(self, features):
53
- x_32x = self.conv(features[3]) # 1/32
54
- x_32 = self.conv1(x_32x)
55
- x_16 = self.upsample(x_32) # 1/16
56
-
57
- x_8 = self.ffm2(features[2], x_16) # 1/8
58
- x_4 = self.ffm1(features[1], x_8) # 1/4
59
- x_2 = self.ffm0(features[0], x_4) # 1/2
60
- #-----------------------------------------
61
- x = self.outconv(x_2) # original size
62
- return x
63
-
64
- class DepthNet(nn.Module):
65
- __factory = {
66
- 18: Resnet.resnet18,
67
- 34: Resnet.resnet34,
68
- 50: Resnet.resnet50,
69
- 101: Resnet.resnet101,
70
- 152: Resnet.resnet152
71
- }
72
- def __init__(self,
73
- backbone='resnet',
74
- depth=50,
75
- upfactors=[2, 2, 2, 2]):
76
- super(DepthNet, self).__init__()
77
- self.backbone = backbone
78
- self.depth = depth
79
- self.pretrained = False
80
- self.inchannels = [256, 512, 1024, 2048]
81
- self.midchannels = [256, 256, 256, 512]
82
- self.upfactors = upfactors
83
- self.outchannels = 1
84
-
85
- # Build model
86
- if self.backbone == 'resnet':
87
- if self.depth not in DepthNet.__factory:
88
- raise KeyError("Unsupported depth:", self.depth)
89
- self.encoder = DepthNet.__factory[depth](pretrained=self.pretrained)
90
- elif self.backbone == 'resnext101_32x8d':
91
- self.encoder = Resnext_torch.resnext101_32x8d(pretrained=self.pretrained)
92
- else:
93
- self.encoder = Resnext_torch.resnext101(pretrained=self.pretrained)
94
-
95
- def forward(self, x):
96
- x = self.encoder(x) # 1/32, 1/16, 1/8, 1/4
97
- return x
98
-
99
-
100
- class FTB(nn.Module):
101
- def __init__(self, inchannels, midchannels=512):
102
- super(FTB, self).__init__()
103
- self.in1 = inchannels
104
- self.mid = midchannels
105
- self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1,
106
- bias=True)
107
- # NN.BatchNorm2d
108
- self.conv_branch = nn.Sequential(nn.ReLU(inplace=True), \
109
- nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
110
- padding=1, stride=1, bias=True), \
111
- nn.BatchNorm2d(num_features=self.mid), \
112
- nn.ReLU(inplace=True), \
113
- nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
114
- padding=1, stride=1, bias=True))
115
- self.relu = nn.ReLU(inplace=True)
116
-
117
- self.init_params()
118
-
119
- def forward(self, x):
120
- x = self.conv1(x)
121
- x = x + self.conv_branch(x)
122
- x = self.relu(x)
123
-
124
- return x
125
-
126
- def init_params(self):
127
- for m in self.modules():
128
- if isinstance(m, nn.Conv2d):
129
- init.normal_(m.weight, std=0.01)
130
- if m.bias is not None:
131
- init.constant_(m.bias, 0)
132
- elif isinstance(m, nn.ConvTranspose2d):
133
- # init.kaiming_normal_(m.weight, mode='fan_out')
134
- init.normal_(m.weight, std=0.01)
135
- # init.xavier_normal_(m.weight)
136
- if m.bias is not None:
137
- init.constant_(m.bias, 0)
138
- elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
139
- init.constant_(m.weight, 1)
140
- init.constant_(m.bias, 0)
141
- elif isinstance(m, nn.Linear):
142
- init.normal_(m.weight, std=0.01)
143
- if m.bias is not None:
144
- init.constant_(m.bias, 0)
145
-
146
-
147
- class ATA(nn.Module):
148
- def __init__(self, inchannels, reduction=8):
149
- super(ATA, self).__init__()
150
- self.inchannels = inchannels
151
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
152
- self.fc = nn.Sequential(nn.Linear(self.inchannels * 2, self.inchannels // reduction),
153
- nn.ReLU(inplace=True),
154
- nn.Linear(self.inchannels // reduction, self.inchannels),
155
- nn.Sigmoid())
156
- self.init_params()
157
-
158
- def forward(self, low_x, high_x):
159
- n, c, _, _ = low_x.size()
160
- x = torch.cat([low_x, high_x], 1)
161
- x = self.avg_pool(x)
162
- x = x.view(n, -1)
163
- x = self.fc(x).view(n, c, 1, 1)
164
- x = low_x * x + high_x
165
-
166
- return x
167
-
168
- def init_params(self):
169
- for m in self.modules():
170
- if isinstance(m, nn.Conv2d):
171
- # init.kaiming_normal_(m.weight, mode='fan_out')
172
- # init.normal(m.weight, std=0.01)
173
- init.xavier_normal_(m.weight)
174
- if m.bias is not None:
175
- init.constant_(m.bias, 0)
176
- elif isinstance(m, nn.ConvTranspose2d):
177
- # init.kaiming_normal_(m.weight, mode='fan_out')
178
- # init.normal_(m.weight, std=0.01)
179
- init.xavier_normal_(m.weight)
180
- if m.bias is not None:
181
- init.constant_(m.bias, 0)
182
- elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
183
- init.constant_(m.weight, 1)
184
- init.constant_(m.bias, 0)
185
- elif isinstance(m, nn.Linear):
186
- init.normal_(m.weight, std=0.01)
187
- if m.bias is not None:
188
- init.constant_(m.bias, 0)
189
-
190
-
191
- class FFM(nn.Module):
192
- def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
193
- super(FFM, self).__init__()
194
- self.inchannels = inchannels
195
- self.midchannels = midchannels
196
- self.outchannels = outchannels
197
- self.upfactor = upfactor
198
-
199
- self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
200
- # self.ata = ATA(inchannels = self.midchannels)
201
- self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
202
-
203
- self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
204
-
205
- self.init_params()
206
-
207
- def forward(self, low_x, high_x):
208
- x = self.ftb1(low_x)
209
- x = x + high_x
210
- x = self.ftb2(x)
211
- x = self.upsample(x)
212
-
213
- return x
214
-
215
- def init_params(self):
216
- for m in self.modules():
217
- if isinstance(m, nn.Conv2d):
218
- # init.kaiming_normal_(m.weight, mode='fan_out')
219
- init.normal_(m.weight, std=0.01)
220
- # init.xavier_normal_(m.weight)
221
- if m.bias is not None:
222
- init.constant_(m.bias, 0)
223
- elif isinstance(m, nn.ConvTranspose2d):
224
- # init.kaiming_normal_(m.weight, mode='fan_out')
225
- init.normal_(m.weight, std=0.01)
226
- # init.xavier_normal_(m.weight)
227
- if m.bias is not None:
228
- init.constant_(m.bias, 0)
229
- elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d
230
- init.constant_(m.weight, 1)
231
- init.constant_(m.bias, 0)
232
- elif isinstance(m, nn.Linear):
233
- init.normal_(m.weight, std=0.01)
234
- if m.bias is not None:
235
- init.constant_(m.bias, 0)
236
-
237
-
238
- class AO(nn.Module):
239
- # Adaptive output module
240
- def __init__(self, inchannels, outchannels, upfactor=2):
241
- super(AO, self).__init__()
242
- self.inchannels = inchannels
243
- self.outchannels = outchannels
244
- self.upfactor = upfactor
245
-
246
- self.adapt_conv = nn.Sequential(
247
- nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels // 2, kernel_size=3, padding=1,
248
- stride=1, bias=True), \
249
- nn.BatchNorm2d(num_features=self.inchannels // 2), \
250
- nn.ReLU(inplace=True), \
251
- nn.Conv2d(in_channels=self.inchannels // 2, out_channels=self.outchannels, kernel_size=3, padding=1,
252
- stride=1, bias=True), \
253
- nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True))
254
-
255
- self.init_params()
256
-
257
- def forward(self, x):
258
- x = self.adapt_conv(x)
259
- return x
260
-
261
- def init_params(self):
262
- for m in self.modules():
263
- if isinstance(m, nn.Conv2d):
264
- # init.kaiming_normal_(m.weight, mode='fan_out')
265
- init.normal_(m.weight, std=0.01)
266
- # init.xavier_normal_(m.weight)
267
- if m.bias is not None:
268
- init.constant_(m.bias, 0)
269
- elif isinstance(m, nn.ConvTranspose2d):
270
- # init.kaiming_normal_(m.weight, mode='fan_out')
271
- init.normal_(m.weight, std=0.01)
272
- # init.xavier_normal_(m.weight)
273
- if m.bias is not None:
274
- init.constant_(m.bias, 0)
275
- elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d
276
- init.constant_(m.weight, 1)
277
- init.constant_(m.bias, 0)
278
- elif isinstance(m, nn.Linear):
279
- init.normal_(m.weight, std=0.01)
280
- if m.bias is not None:
281
- init.constant_(m.bias, 0)
282
-
283
-
284
-
285
- # ==============================================================================================================
286
-
287
-
288
- class ResidualConv(nn.Module):
289
- def __init__(self, inchannels):
290
- super(ResidualConv, self).__init__()
291
- # NN.BatchNorm2d
292
- self.conv = nn.Sequential(
293
- # nn.BatchNorm2d(num_features=inchannels),
294
- nn.ReLU(inplace=False),
295
- # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
296
- # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
297
- nn.Conv2d(in_channels=inchannels, out_channels=inchannels / 2, kernel_size=3, padding=1, stride=1,
298
- bias=False),
299
- nn.BatchNorm2d(num_features=inchannels / 2),
300
- nn.ReLU(inplace=False),
301
- nn.Conv2d(in_channels=inchannels / 2, out_channels=inchannels, kernel_size=3, padding=1, stride=1,
302
- bias=False)
303
- )
304
- self.init_params()
305
-
306
- def forward(self, x):
307
- x = self.conv(x) + x
308
- return x
309
-
310
- def init_params(self):
311
- for m in self.modules():
312
- if isinstance(m, nn.Conv2d):
313
- # init.kaiming_normal_(m.weight, mode='fan_out')
314
- init.normal_(m.weight, std=0.01)
315
- # init.xavier_normal_(m.weight)
316
- if m.bias is not None:
317
- init.constant_(m.bias, 0)
318
- elif isinstance(m, nn.ConvTranspose2d):
319
- # init.kaiming_normal_(m.weight, mode='fan_out')
320
- init.normal_(m.weight, std=0.01)
321
- # init.xavier_normal_(m.weight)
322
- if m.bias is not None:
323
- init.constant_(m.bias, 0)
324
- elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
325
- init.constant_(m.weight, 1)
326
- init.constant_(m.bias, 0)
327
- elif isinstance(m, nn.Linear):
328
- init.normal_(m.weight, std=0.01)
329
- if m.bias is not None:
330
- init.constant_(m.bias, 0)
331
-
332
-
333
- class FeatureFusion(nn.Module):
334
- def __init__(self, inchannels, outchannels):
335
- super(FeatureFusion, self).__init__()
336
- self.conv = ResidualConv(inchannels=inchannels)
337
- # NN.BatchNorm2d
338
- self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
339
- nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,
340
- stride=2, padding=1, output_padding=1),
341
- nn.BatchNorm2d(num_features=outchannels),
342
- nn.ReLU(inplace=True))
343
-
344
- def forward(self, lowfeat, highfeat):
345
- return self.up(highfeat + self.conv(lowfeat))
346
-
347
- def init_params(self):
348
- for m in self.modules():
349
- if isinstance(m, nn.Conv2d):
350
- # init.kaiming_normal_(m.weight, mode='fan_out')
351
- init.normal_(m.weight, std=0.01)
352
- # init.xavier_normal_(m.weight)
353
- if m.bias is not None:
354
- init.constant_(m.bias, 0)
355
- elif isinstance(m, nn.ConvTranspose2d):
356
- # init.kaiming_normal_(m.weight, mode='fan_out')
357
- init.normal_(m.weight, std=0.01)
358
- # init.xavier_normal_(m.weight)
359
- if m.bias is not None:
360
- init.constant_(m.bias, 0)
361
- elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
362
- init.constant_(m.weight, 1)
363
- init.constant_(m.bias, 0)
364
- elif isinstance(m, nn.Linear):
365
- init.normal_(m.weight, std=0.01)
366
- if m.bias is not None:
367
- init.constant_(m.bias, 0)
368
-
369
-
370
- class SenceUnderstand(nn.Module):
371
- def __init__(self, channels):
372
- super(SenceUnderstand, self).__init__()
373
- self.channels = channels
374
- self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
375
- nn.ReLU(inplace=True))
376
- self.pool = nn.AdaptiveAvgPool2d(8)
377
- self.fc = nn.Sequential(nn.Linear(512 * 8 * 8, self.channels),
378
- nn.ReLU(inplace=True))
379
- self.conv2 = nn.Sequential(
380
- nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
381
- nn.ReLU(inplace=True))
382
- self.initial_params()
383
-
384
- def forward(self, x):
385
- n, c, h, w = x.size()
386
- x = self.conv1(x)
387
- x = self.pool(x)
388
- x = x.view(n, -1)
389
- x = self.fc(x)
390
- x = x.view(n, self.channels, 1, 1)
391
- x = self.conv2(x)
392
- x = x.repeat(1, 1, h, w)
393
- return x
394
-
395
- def initial_params(self, dev=0.01):
396
- for m in self.modules():
397
- if isinstance(m, nn.Conv2d):
398
- # print torch.sum(m.weight)
399
- m.weight.data.normal_(0, dev)
400
- if m.bias is not None:
401
- m.bias.data.fill_(0)
402
- elif isinstance(m, nn.ConvTranspose2d):
403
- # print torch.sum(m.weight)
404
- m.weight.data.normal_(0, dev)
405
- if m.bias is not None:
406
- m.bias.data.fill_(0)
407
- elif isinstance(m, nn.Linear):
408
- m.weight.data.normal_(0, dev)
409
-
410
-
411
- if __name__ == '__main__':
412
- net = DepthNet(backbone='resnet', depth=50)
413
- print(net)
414
- inputs = torch.ones(4,3,128,128)
415
- out = net(inputs)
416
- print([f.size() for f in out])
417
-