fffiloni committed
Commit 20cfcc2 · 1 parent: 01d3b78

Delete utils

utils/Config.py DELETED
@@ -1,24 +0,0 @@
-from fvcore.common.config import CfgNode as _CfgNode
-
-class CfgNode(_CfgNode):
-    """
-    The same as `fvcore.common.config.CfgNode`, but different in:
-    1. Use unsafe yaml loading by default.
-       Note that this may lead to arbitrary code execution: you must not
-       load a config file from untrusted sources before manually inspecting
-       the content of the file.
-    2. Support config versioning.
-       When attempting to merge an old config, it will convert the old config automatically.
-    .. automethod:: clone
-    .. automethod:: freeze
-    .. automethod:: defrost
-    .. automethod:: is_frozen
-    .. automethod:: load_yaml_with_base
-    .. automethod:: merge_from_list
-    .. automethod:: merge_from_other_cfg
-    """
-
-    def merge_from_dict(self, dict):
-        pass
-
-node = CfgNode()
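
For reference, `CfgNode` here is a thin subclass of fvcore's config node, so the usual clone/freeze/merge workflow applies. A minimal sketch of how such a node is typically driven (the keys and values below are hypothetical, not taken from this repo):

    from fvcore.common.config import CfgNode

    # build a node from a nested dict, merge a dotted-key override, then
    # freeze it so later code cannot mutate the config by accident
    cfg = CfgNode({'MODEL': {'NAME': 'xdecoder'}, 'SOLVER': {'LR': 0.0001}})
    cfg.merge_from_list(['SOLVER.LR', 0.00005])  # dotted keys address nested fields
    cfg.freeze()
    print(cfg.SOLVER.LR)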
 
utils/__init__.py DELETED
File without changes
utils/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (130 Bytes)
 
utils/__pycache__/arguments.cpython-38.pyc DELETED
Binary file (3.98 kB)
 
utils/__pycache__/ddim.cpython-38.pyc DELETED
Binary file (6.17 kB)
 
utils/__pycache__/distributed.cpython-38.pyc DELETED
Binary file (4.9 kB)
 
utils/__pycache__/inpainting.cpython-38.pyc DELETED
Binary file (5.14 kB)
 
utils/__pycache__/misc.cpython-38.pyc DELETED
Binary file (3.71 kB)
 
utils/__pycache__/model.cpython-38.pyc DELETED
Binary file (817 Bytes)
 
utils/__pycache__/model_loading.cpython-38.pyc DELETED
Binary file (1.1 kB)
 
utils/__pycache__/readme.txt DELETED
File without changes
utils/__pycache__/util.cpython-38.pyc DELETED
Binary file (9.97 kB)
 
utils/__pycache__/visualizer.cpython-38.pyc DELETED
Binary file (43.1 kB)
 
utils/arguments.py DELETED
@@ -1,98 +0,0 @@
-import yaml
-import json
-import argparse
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-def load_config_dict_to_opt(opt, config_dict):
-    """
-    Load the key, value pairs from config_dict to opt, overriding existing values in opt
-    if there are any.
-    """
-    if not isinstance(config_dict, dict):
-        raise TypeError("Config must be a Python dictionary")
-    for k, v in config_dict.items():
-        k_parts = k.split('.')
-        pointer = opt
-        for k_part in k_parts[:-1]:
-            if k_part not in pointer:
-                pointer[k_part] = {}
-            pointer = pointer[k_part]
-            assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
-        ori_value = pointer.get(k_parts[-1])
-        pointer[k_parts[-1]] = v
-        if ori_value:
-            logger.warning(f"Overrode {k} from {ori_value} to {pointer[k_parts[-1]]}")
-
-
-def load_opt_from_config_files(conf_file):
-    """
-    Load opt from the config file.
-
-    Args:
-        conf_file: config file path
-
-    Returns:
-        dict: a dictionary of opt settings
-    """
-    opt = {}
-    with open(conf_file, encoding='utf-8') as f:
-        config_dict = yaml.safe_load(f)
-
-    load_config_dict_to_opt(opt, config_dict)
-
-    return opt
-
-
-def load_opt_command(args):
-    parser = argparse.ArgumentParser(description='MainzTrain: Pretrain or fine-tune models for NLP tasks.')
-    parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
-    parser.add_argument('--conf_files', required=True, help='Path(s) to the MainzTrain config file(s).')
-    parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in the command line.')
-    parser.add_argument('--overrides', help='arguments used to override the config file on the command line', nargs=argparse.REMAINDER)
-
-    cmdline_args = parser.parse_args() if not args else parser.parse_args(args)
-
-    opt = load_opt_from_config_files(cmdline_args.conf_files)
-
-    if cmdline_args.config_overrides:
-        config_overrides_string = ' '.join(cmdline_args.config_overrides)
-        logger.warning(f"Command line config overrides: {config_overrides_string}")
-        config_dict = json.loads(config_overrides_string)
-        load_config_dict_to_opt(opt, config_dict)
-
-    if cmdline_args.overrides:
-        assert len(cmdline_args.overrides) % 2 == 0, "overrides arguments are not paired; required: key value"
-        keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)]
-        vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)]
-        # map the literal strings 'false'/'False' to an empty (falsy) string
-        vals = [val.replace('false', '').replace('False', '') if len(val.replace(' ', '')) == 5 else val for val in vals]
-
-        # look up the type of each existing value so the string from the
-        # command line can be coerced to the same type
-        types = []
-        for key in keys:
-            key = key.split('.')
-            ele = opt.copy()
-            while len(key) > 0:
-                ele = ele[key.pop(0)]
-            types.append(type(ele))
-
-        config_dict = {x: z(y) for x, y, z in zip(keys, vals, types)}
-        load_config_dict_to_opt(opt, config_dict)
-
-    # combine cmdline_args into the opt dictionary
-    for key, val in cmdline_args.__dict__.items():
-        if val is not None:
-            opt[key] = val
-
-    return opt, cmdline_args
-
-
-def save_opt_to_json(opt, conf_file):
-    with open(conf_file, 'w', encoding='utf-8') as f:
-        json.dump(opt, f, indent=4)
-
-
-def save_opt_to_yaml(opt, conf_file):
-    with open(conf_file, 'w', encoding='utf-8') as f:
-        yaml.dump(opt, f)
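
The core of this module is the dotted-key override convention: a key like "SOLVER.LR" walks into nested dicts, creating intermediate levels as needed. A minimal sketch, assuming the functions above are importable (keys and values are hypothetical):

    opt = {'SOLVER': {'LR': 0.0001}}
    load_config_dict_to_opt(opt, {'SOLVER.LR': 0.00005, 'MODEL.NAME': 'xdecoder'})
    # opt is now {'SOLVER': {'LR': 5e-05}, 'MODEL': {'NAME': 'xdecoder'}}

From the command line this would look something like `python entry.py evaluate --conf_files config.yaml --overrides SOLVER.LR 0.00005` (the entry script name is a placeholder). Note that `--overrides` coerces each value to the type of the existing entry, so it can only override keys that already exist in the loaded config, while `--config_overrides` can introduce new ones.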
 
utils/ddim.py DELETED
@@ -1,203 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-import numpy as np
-from tqdm import tqdm
-from functools import partial
-
-from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
-
-
-class DDIMSampler(object):
-    def __init__(self, model, schedule="linear", **kwargs):
-        super().__init__()
-        self.model = model
-        self.ddpm_num_timesteps = model.num_timesteps
-        self.schedule = schedule
-
-    def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
-        setattr(self, name, attr)
-
-    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
-        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
-                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
-        alphas_cumprod = self.model.alphas_cumprod
-        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
-
-        self.register_buffer('betas', to_torch(self.model.betas))
-        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
-        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
-        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
-
-        # ddim sampling parameters
-        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
-                                                                                   ddim_timesteps=self.ddim_timesteps,
-                                                                                   eta=ddim_eta, verbose=verbose)
-        self.register_buffer('ddim_sigmas', ddim_sigmas)
-        self.register_buffer('ddim_alphas', ddim_alphas)
-        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
-        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
-        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
-            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
-                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
-        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
-
-    @torch.no_grad()
-    def sample(self,
-               S,
-               batch_size,
-               shape,
-               conditioning=None,
-               callback=None,
-               normals_sequence=None,
-               img_callback=None,
-               quantize_x0=False,
-               eta=0.,
-               mask=None,
-               x0=None,
-               temperature=1.,
-               noise_dropout=0.,
-               score_corrector=None,
-               corrector_kwargs=None,
-               verbose=True,
-               x_T=None,
-               log_every_t=100,
-               unconditional_guidance_scale=1.,
-               unconditional_conditioning=None,
-               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
-               **kwargs
-               ):
-        if conditioning is not None:
-            if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-                if cbs != batch_size:
-                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
-            else:
-                if conditioning.shape[0] != batch_size:
-                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
-        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
-        # sampling
-        C, H, W = shape
-        size = (batch_size, C, H, W)
-        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
-
-        samples, intermediates = self.ddim_sampling(conditioning, size,
-                                                    callback=callback,
-                                                    img_callback=img_callback,
-                                                    quantize_denoised=quantize_x0,
-                                                    mask=mask, x0=x0,
-                                                    ddim_use_original_steps=False,
-                                                    noise_dropout=noise_dropout,
-                                                    temperature=temperature,
-                                                    score_corrector=score_corrector,
-                                                    corrector_kwargs=corrector_kwargs,
-                                                    x_T=x_T,
-                                                    log_every_t=log_every_t,
-                                                    unconditional_guidance_scale=unconditional_guidance_scale,
-                                                    unconditional_conditioning=unconditional_conditioning,
-                                                    )
-        return samples, intermediates
-
-    @torch.no_grad()
-    def ddim_sampling(self, cond, shape,
-                      x_T=None, ddim_use_original_steps=False,
-                      callback=None, timesteps=None, quantize_denoised=False,
-                      mask=None, x0=None, img_callback=None, log_every_t=100,
-                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None):
-        device = self.model.betas.device
-        b = shape[0]
-        if x_T is None:
-            img = torch.randn(shape, device=device)
-        else:
-            img = x_T
-
-        if timesteps is None:
-            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
-        elif timesteps is not None and not ddim_use_original_steps:
-            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
-            timesteps = self.ddim_timesteps[:subset_end]
-
-        intermediates = {'x_inter': [img], 'pred_x0': [img]}
-        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
-        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
-        print(f"Running DDIM Sampling with {total_steps} timesteps")
-
-        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
-
-        for i, step in enumerate(iterator):
-            index = total_steps - i - 1
-            ts = torch.full((b,), step, device=device, dtype=torch.long)
-
-            if mask is not None:
-                assert x0 is not None
-                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
-                img = img_orig * mask + (1. - mask) * img
-
-            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
-                                      quantize_denoised=quantize_denoised, temperature=temperature,
-                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
-                                      corrector_kwargs=corrector_kwargs,
-                                      unconditional_guidance_scale=unconditional_guidance_scale,
-                                      unconditional_conditioning=unconditional_conditioning)
-            img, pred_x0 = outs
-            if callback: callback(i)
-            if img_callback: img_callback(pred_x0, i)
-
-            if index % log_every_t == 0 or index == total_steps - 1:
-                intermediates['x_inter'].append(img)
-                intermediates['pred_x0'].append(pred_x0)
-
-        return img, intermediates
-
-    @torch.no_grad()
-    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
-                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None):
-        b, *_, device = *x.shape, x.device
-
-        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
-            e_t = self.model.apply_model(x, t, c)
-        else:
-            x_in = torch.cat([x] * 2)
-            t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
-            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
-        if score_corrector is not None:
-            assert self.model.parameterization == "eps"
-            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
-        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
-        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
-        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
-        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
-        # select parameters corresponding to the currently considered timestep
-        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
-        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
-        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
-        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
-
-        # current prediction for x_0
-        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
-        if quantize_denoised:
-            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
-        # direction pointing to x_t
-        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
-        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
-        if noise_dropout > 0.:
-            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
-        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
-        return x_prev, pred_x0
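
p_sample_ddim implements the DDIM update from Song et al. (arXiv:2010.02502): first recover pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t), then form x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * noise, where eta = 0 makes sigma_t = 0 and the sampler deterministic. A hedged usage sketch, assuming `model` is a LatentDiffusion-style object exposing num_timesteps / alphas_cumprod / apply_model as used above, and `c`/`uc` are conditioning tensors from its cond stage:

    sampler = DDIMSampler(model)
    samples, intermediates = sampler.sample(
        S=50,                   # number of DDIM steps
        batch_size=4,
        shape=(4, 64, 64),      # latent (C, H, W), without the batch dim
        conditioning=c,
        unconditional_guidance_scale=7.5,
        unconditional_conditioning=uc,
        eta=0.0,                # eta=0: deterministic DDIM
    )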
 
utils/distributed.py DELETED
@@ -1,180 +0,0 @@
-import os
-import time
-import torch
-import pickle
-import torch.distributed as dist
-
-
-def init_distributed(opt):
-    opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available()
-    if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
-        # application was started without MPI
-        # default to single node with single process
-        opt['env_info'] = 'no MPI'
-        opt['world_size'] = 1
-        opt['local_size'] = 1
-        opt['rank'] = 0
-        opt['local_rank'] = 0
-        opt['master_address'] = '127.0.0.1'
-        opt['master_port'] = '8673'
-    else:
-        # application was started with MPI
-        # get MPI parameters
-        opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE'])
-        opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
-        opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK'])
-        opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
-
-    # set up device
-    if not opt['CUDA']:
-        assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend'
-        opt['device'] = torch.device("cpu")
-    else:
-        torch.cuda.set_device(opt['local_rank'])
-        opt['device'] = torch.device("cuda", opt['local_rank'])
-    return opt
-
-
-def is_main_process():
-    rank = 0
-    if 'OMPI_COMM_WORLD_SIZE' in os.environ:
-        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
-
-    return rank == 0
-
-
-def get_world_size():
-    if not dist.is_available():
-        return 1
-    if not dist.is_initialized():
-        return 1
-    return dist.get_world_size()
-
-
-def get_rank():
-    if not dist.is_available():
-        return 0
-    if not dist.is_initialized():
-        return 0
-    return dist.get_rank()
-
-
-def synchronize():
-    """
-    Helper function to synchronize (barrier) among all processes when
-    using distributed training
-    """
-    if not dist.is_available():
-        return
-    if not dist.is_initialized():
-        return
-    world_size = dist.get_world_size()
-    rank = dist.get_rank()
-    if world_size == 1:
-        return
-
-    def _send_and_wait(r):
-        if rank == r:
-            tensor = torch.tensor(0, device="cuda")
-        else:
-            tensor = torch.tensor(1, device="cuda")
-        dist.broadcast(tensor, r)
-        while tensor.item() == 1:
-            time.sleep(1)
-
-    _send_and_wait(0)
-    # now sync on the main process
-    _send_and_wait(1)
-
-
-def all_gather(data):
-    """
-    Run all_gather on arbitrary picklable data (not necessarily tensors)
-    Args:
-        data: any picklable object
-    Returns:
-        list[data]: list of data gathered from each rank
-    """
-    world_size = get_world_size()
-    if world_size == 1:
-        return [data]
-
-    # serialized to a Tensor
-    buffer = pickle.dumps(data)
-    storage = torch.ByteStorage.from_buffer(buffer)
-    tensor = torch.ByteTensor(storage).to("cuda")
-
-    # obtain Tensor size of each rank
-    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
-    size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
-    dist.all_gather(size_list, local_size)
-    size_list = [int(size.item()) for size in size_list]
-    max_size = max(size_list)
-
-    # receiving Tensor from all ranks
-    # we pad the tensor because torch all_gather does not support
-    # gathering tensors of different shapes
-    tensor_list = []
-    for _ in size_list:
-        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
-    if local_size != max_size:
-        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
-        tensor = torch.cat((tensor, padding), dim=0)
-    dist.all_gather(tensor_list, tensor)
-
-    data_list = []
-    for size, tensor in zip(size_list, tensor_list):
-        buffer = tensor.cpu().numpy().tobytes()[:size]
-        data_list.append(pickle.loads(buffer))
-
-    return data_list
-
-
-def reduce_dict(input_dict, average=True):
-    """
-    Args:
-        input_dict (dict): all the values will be reduced
-        average (bool): whether to do average or sum
-    Reduce the values in the dictionary from all processes so that process with rank
-    0 has the averaged results. Returns a dict with the same fields as
-    input_dict, after reduction.
-    """
-    world_size = get_world_size()
-    if world_size < 2:
-        return input_dict
-    with torch.no_grad():
-        names = []
-        values = []
-        # sort the keys so that they are consistent across processes
-        for k in sorted(input_dict.keys()):
-            names.append(k)
-            values.append(input_dict[k])
-        values = torch.stack(values, dim=0)
-        dist.reduce(values, dst=0)
-        if dist.get_rank() == 0 and average:
-            # only main process gets accumulated, so only divide by
-            # world_size in this case
-            values /= world_size
-        reduced_dict = {k: v for k, v in zip(names, values)}
-    return reduced_dict
-
-
-def broadcast_data(data):
-    if not torch.distributed.is_initialized():
-        return data
-    rank = dist.get_rank()
-    if rank == 0:
-        data_tensor = torch.tensor(data + [0], device="cuda")
-    else:
-        data_tensor = torch.tensor(data + [1], device="cuda")
-    torch.distributed.broadcast(data_tensor, 0)
-    while data_tensor.cpu().numpy()[-1] == 1:
-        time.sleep(1)
-
-    return data_tensor.cpu().numpy().tolist()[:-1]
-
-
-def reduce_sum(tensor):
-    if get_world_size() <= 1:
-        return tensor
-
-    tensor = tensor.clone()
-    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
-    return tensor
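
A hedged sketch of how these helpers are usually combined inside a training loop, assuming the process group has already been initialized (for example via init_distributed above) and the loss names are hypothetical:

    import torch

    losses = {'loss_mask': torch.tensor(0.7, device='cuda'),
              'loss_cls': torch.tensor(0.3, device='cuda')}
    averaged = reduce_dict(losses)               # rank 0 receives the cross-rank mean
    gathered = all_gather({'rank': get_rank()})  # one picklable payload per rank
    if is_main_process():
        print(averaged, len(gathered))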
 
utils/inpainting.py DELETED
@@ -1,177 +0,0 @@
-import sys
-import cv2
-import torch
-import numpy as np
-import gradio as gr
-from PIL import Image
-from omegaconf import OmegaConf
-from einops import repeat
-from imwatermark import WatermarkEncoder
-from pathlib import Path
-
-from .ddim import DDIMSampler
-from .util import instantiate_from_config
-
-
-torch.set_grad_enabled(False)
-
-
-def put_watermark(img, wm_encoder=None):
-    if wm_encoder is not None:
-        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-        img = wm_encoder.encode(img, 'dwtDct')
-        img = Image.fromarray(img[:, :, ::-1])
-    return img
-
-
-def initialize_model(config, ckpt):
-    config = OmegaConf.load(config)
-    model = instantiate_from_config(config.model)
-
-    model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
-
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.to(device)
-    sampler = DDIMSampler(model)
-
-    return sampler
-
-
-def make_batch_sd(
-        image,
-        mask,
-        txt,
-        device,
-        num_samples=1):
-    image = np.array(image.convert("RGB"))
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
-
-    mask = np.array(mask.convert("L"))
-    mask = mask.astype(np.float32) / 255.0
-    mask = mask[None, None]
-    mask[mask < 0.5] = 0
-    mask[mask >= 0.5] = 1
-    mask = torch.from_numpy(mask)
-
-    masked_image = image * (mask < 0.5)
-
-    batch = {
-        "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
-        "txt": num_samples * [txt],
-        "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
-        "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
-    }
-    return batch
-
-
-@torch.no_grad()
-def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = sampler.model
-
-    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
-    wm = "SDV2"
-    wm_encoder = WatermarkEncoder()
-    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
-
-    prng = np.random.RandomState(seed)
-    start_code = prng.randn(num_samples, 4, h // 8, w // 8)
-    start_code = torch.from_numpy(start_code).to(
-        device=device, dtype=torch.float32)
-
-    with torch.no_grad(), \
-            torch.autocast("cuda"):
-        batch = make_batch_sd(image, mask, txt=prompt,
-                              device=device, num_samples=num_samples)
-
-        c = model.cond_stage_model.encode(batch["txt"])
-
-        c_cat = list()
-        for ck in model.concat_keys:
-            cc = batch[ck].float()
-            if ck != model.masked_image_key:
-                bchw = [num_samples, 4, h // 8, w // 8]
-                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
-            else:
-                cc = model.get_first_stage_encoding(
-                    model.encode_first_stage(cc))
-            c_cat.append(cc)
-        c_cat = torch.cat(c_cat, dim=1)
-
-        # cond
-        cond = {"c_concat": [c_cat], "c_crossattn": [c]}
-
-        # uncond cond
-        uc_cross = model.get_unconditional_conditioning(num_samples, "")
-        uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
-
-        shape = [model.channels, h // 8, w // 8]
-        samples_cfg, intermediates = sampler.sample(
-            ddim_steps,
-            num_samples,
-            shape,
-            cond,
-            verbose=False,
-            eta=1.0,
-            unconditional_guidance_scale=scale,
-            unconditional_conditioning=uc_full,
-            x_T=start_code,
-        )
-        x_samples_ddim = model.decode_first_stage(samples_cfg)
-
-        result = torch.clamp((x_samples_ddim + 1.0) / 2.0,
-                             min=0.0, max=1.0)
-
-        result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
-    return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
-
-
-def pad_image(input_image):
-    pad_w, pad_h = np.max(((2, 2), np.ceil(
-        np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
-    im_padded = Image.fromarray(
-        np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
-    return im_padded
-
-
-def crop_image(input_image):
-    crop_w, crop_h = np.floor(np.array(input_image.size) / 64).astype(int) * 64
-    im_cropped = Image.fromarray(np.array(input_image)[:crop_h, :crop_w])
-    return im_cropped
-
-
-# sampler = initialize_model(sys.argv[1], sys.argv[2])
-@torch.no_grad()
-def predict(model, input_image, prompt, ddim_steps, num_samples, scale, seed):
-    """
-    Args:
-        model: a DDIMSampler as returned by `initialize_model`.
-        input_image (dict):
-            - image: PIL.Image. Input image.
-            - mask: PIL.Image. Mask image.
-        prompt (str): text prompt.
-        ddim_steps (int): typically 45.
-        num_samples (int): typically 4.
-        scale (float): guidance scale, typically 10.0.
-        seed (int): e.g. 1529160519.
-    """
-    init_image = input_image["image"].convert("RGB")
-    init_mask = input_image["mask"].convert("RGB")
-    image = pad_image(init_image)  # pad to an integer multiple of 64
-    mask = pad_image(init_mask)  # pad to an integer multiple of 64
-    width, height = image.size
-    print("Inpainting...", width, height)
-
-    result = inpaint(
-        sampler=model,
-        image=image,
-        mask=mask,
-        prompt=prompt,
-        seed=seed,
-        scale=scale,
-        ddim_steps=ddim_steps,
-        num_samples=num_samples,
-        h=height, w=width
-    )
-
-    return result
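
The module imports gradio but never wires it up here; `predict` has the shape of a Gradio callback (its `input_image` dict with "image" and "mask" keys matches what a sketch-tool image input returns). A hedged sketch of that wiring under the gradio 3.x API of this era; the config and checkpoint paths are placeholders:

    import gradio as gr

    sampler = initialize_model('configs/sd-inpainting.yaml',   # placeholder
                               'weights/sd-inpainting.ckpt')   # placeholder

    demo = gr.Interface(
        fn=lambda img, prompt: predict(sampler, img, prompt,
                                       ddim_steps=45, num_samples=4,
                                       scale=10.0, seed=1529160519),
        inputs=[gr.Image(source='upload', tool='sketch', type='pil'),
                gr.Textbox(label='prompt')],
        outputs=gr.Gallery(),
    )
    demo.launch()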
 
utils/misc.py DELETED
@@ -1,122 +0,0 @@
-import math
-import random
-import numpy as np
-
-def get_prompt_templates():
-    prompt_templates = [
-        '{}.',
-        'a photo of a {}.',
-        'a bad photo of a {}.',
-        'a photo of many {}.',
-        'a sculpture of a {}.',
-        'a photo of the hard to see {}.',
-        'a low resolution photo of the {}.',
-        'a rendering of a {}.',
-        'graffiti of a {}.',
-        'a bad photo of the {}.',
-        'a cropped photo of the {}.',
-        'a tattoo of a {}.',
-        'the embroidered {}.',
-        'a photo of a hard to see {}.',
-        'a bright photo of a {}.',
-        'a photo of a clean {}.',
-        'a photo of a dirty {}.',
-        'a dark photo of the {}.',
-        'a drawing of a {}.',
-        'a photo of my {}.',
-        'the plastic {}.',
-        'a photo of the cool {}.',
-        'a close-up photo of a {}.',
-        'a black and white photo of the {}.',
-        'a painting of the {}.',
-        'a painting of a {}.',
-        'a pixelated photo of the {}.',
-        'a sculpture of the {}.',
-        'a bright photo of the {}.',
-        'a cropped photo of a {}.',
-        'a plastic {}.',
-        'a photo of the dirty {}.',
-        'a jpeg corrupted photo of a {}.',
-        'a blurry photo of the {}.',
-        'a photo of the {}.',
-        'a good photo of the {}.',
-        'a rendering of the {}.',
-        'a {} in a video game.',
-        'a photo of one {}.',
-        'a doodle of a {}.',
-        'a close-up photo of the {}.',
-        'the origami {}.',
-        'the {} in a video game.',
-        'a sketch of a {}.',
-        'a doodle of the {}.',
-        'a origami {}.',
-        'a low resolution photo of a {}.',
-        'the toy {}.',
-        'a rendition of the {}.',
-        'a photo of the clean {}.',
-        'a photo of a large {}.',
-        'a rendition of a {}.',
-        'a photo of a nice {}.',
-        'a photo of a weird {}.',
-        'a blurry photo of a {}.',
-        'a cartoon {}.',
-        'art of a {}.',
-        'a sketch of the {}.',
-        'a embroidered {}.',
-        'a pixelated photo of a {}.',
-        'itap of the {}.',
-        'a jpeg corrupted photo of the {}.',
-        'a good photo of a {}.',
-        'a plushie {}.',
-        'a photo of the nice {}.',
-        'a photo of the small {}.',
-        'a photo of the weird {}.',
-        'the cartoon {}.',
-        'art of the {}.',
-        'a drawing of the {}.',
-        'a photo of the large {}.',
-        'a black and white photo of a {}.',
-        'the plushie {}.',
-        'a dark photo of a {}.',
-        'itap of a {}.',
-        'graffiti of the {}.',
-        'a toy {}.',
-        'itap of my {}.',
-        'a photo of a cool {}.',
-        'a photo of a small {}.',
-        'a tattoo of the {}.',
-    ]
-    return prompt_templates
-
-
-def prompt_engineering(classnames, topk=1, suffix='.'):
-    prompt_templates = get_prompt_templates()
-    temp_idx = np.random.randint(min(len(prompt_templates), topk))
-
-    if isinstance(classnames, list):
-        classname = random.choice(classnames)
-    else:
-        classname = classnames
-
-    return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' '))
-
-
-class AverageMeter(object):
-    """Computes and stores the average and current value."""
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
-
-    def update(self, val, n=1, decay=0):
-        self.val = val
-        if decay:
-            alpha = math.exp(-n / decay)  # exponential decay over `decay` updates
-            self.sum = alpha * self.sum + (1 - alpha) * val * n
-            self.count = alpha * self.count + (1 - alpha) * n
-        else:
-            self.sum += val * n
-            self.count += n
-        self.avg = self.sum / self.count
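
A short sketch of how the two helpers compose, assuming the definitions above are in scope (the loss values and class name are made up):

    meter = AverageMeter()
    for loss in [0.9, 0.7, 0.6]:
        meter.update(loss)
    print(f"avg loss: {meter.avg:.3f}")  # 0.733

    # topk > 1 samples uniformly among the first k templates; topk=1 always
    # picks '{}.', so larger topk trades determinism for prompt diversity
    print(prompt_engineering('traffic light', topk=3))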
 
utils/model.py DELETED
@@ -1,32 +0,0 @@
-import logging
-import os
-import time
-import pickle
-
-import torch
-import torch.distributed as dist
-
-from fvcore.nn import FlopCountAnalysis
-from fvcore.nn import flop_count_table
-from fvcore.nn import flop_count_str
-
-logger = logging.getLogger(__name__)
-
-
-NORM_MODULES = [
-    torch.nn.BatchNorm1d,
-    torch.nn.BatchNorm2d,
-    torch.nn.BatchNorm3d,
-    torch.nn.SyncBatchNorm,
-    # NaiveSyncBatchNorm inherits from BatchNorm2d
-    torch.nn.GroupNorm,
-    torch.nn.InstanceNorm1d,
-    torch.nn.InstanceNorm2d,
-    torch.nn.InstanceNorm3d,
-    torch.nn.LayerNorm,
-    torch.nn.LocalResponseNorm,
-]
-
-def register_norm_module(cls):
-    NORM_MODULES.append(cls)
-    return cls
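
Registries like NORM_MODULES are typically consulted elsewhere (for example to exclude normalization parameters from weight decay); the decorator lets custom layers opt in. A hedged sketch with a hypothetical custom norm:

    import torch

    @register_norm_module
    class RMSNorm(torch.nn.Module):  # hypothetical, not part of this repo
        def __init__(self, dim, eps=1e-6):
            super().__init__()
            self.eps = eps
            self.weight = torch.nn.Parameter(torch.ones(dim))

        def forward(self, x):
            # scale by the reciprocal root-mean-square of the last dim
            return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)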
 
utils/model_loading.py DELETED
@@ -1,42 +0,0 @@
-# --------------------------------------------------------
-# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-# Written by Xueyan Zou (xueyan@cs.wisc.edu)
-# --------------------------------------------------------
-
-import logging
-from utils.distributed import is_main_process
-
-logger = logging.getLogger(__name__)
-
-
-def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
-    model_keys = sorted(model_state_dict.keys())
-    ckpt_keys = sorted(ckpt_state_dict.keys())
-    result_dicts = {}
-    matched_log = []
-    unmatched_log = []
-    unloaded_log = []
-    for model_key in model_keys:
-        model_weight = model_state_dict[model_key]
-        if model_key in ckpt_keys:
-            ckpt_weight = ckpt_state_dict[model_key]
-            if model_weight.shape == ckpt_weight.shape:
-                result_dicts[model_key] = ckpt_weight
-                ckpt_keys.pop(ckpt_keys.index(model_key))
-                matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
-            else:
-                unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
-        else:
-            unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
-
-    if is_main_process():
-        for info in matched_log:
-            logger.info(info)
-        for info in unloaded_log:
-            logger.warning(info)
-        for key in ckpt_keys:
-            logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
-        for info in unmatched_log:
-            logger.warning(info)
-    return result_dicts
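
The function returns only the shape-compatible subset of checkpoint weights, logging matched, unmatched, unloaded, and unused keys separately. A hedged sketch of the intended call pattern; `build_model` and the checkpoint path are placeholders:

    import torch

    model = build_model()  # hypothetical model constructor
    ckpt = torch.load('weights/xdecoder.pt', map_location='cpu')  # placeholder path
    state = align_and_update_state_dicts(model.state_dict(), ckpt)
    model.load_state_dict(state, strict=False)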
 
utils/util.py DELETED
@@ -1,283 +0,0 @@
-# adapted from
-# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-# and
-# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-# and
-# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
-#
-# thanks!
-import importlib
-
-import os
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import repeat
-
-
-def instantiate_from_config(config):
-    if "target" not in config:
-        if config == '__is_first_stage__':
-            return None
-        elif config == "__is_unconditional__":
-            return None
-        raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
-def get_obj_from_str(string, reload=False):
-    module, cls = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
-
-
-def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
-    if schedule == "linear":
-        betas = (
-            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
-        )
-
-    elif schedule == "cosine":
-        timesteps = (
-            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
-        )
-        alphas = timesteps / (1 + cosine_s) * np.pi / 2
-        alphas = torch.cos(alphas).pow(2)
-        alphas = alphas / alphas[0]
-        betas = 1 - alphas[1:] / alphas[:-1]
-        betas = np.clip(betas, a_min=0, a_max=0.999)
-
-    elif schedule == "sqrt_linear":
-        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
-    elif schedule == "sqrt":
-        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
-    else:
-        raise ValueError(f"schedule '{schedule}' unknown.")
-    return betas.numpy()
-
-
-def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
-    if ddim_discr_method == 'uniform':
-        c = num_ddpm_timesteps // num_ddim_timesteps
-        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
-    elif ddim_discr_method == 'quad':
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
-    else:
-        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
-
-    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
-    # add one to get the final alpha values right (the ones from first scale to data during sampling)
-    steps_out = ddim_timesteps + 1
-    if verbose:
-        print(f'Selected timesteps for ddim sampler: {steps_out}')
-    return steps_out
-
-
-def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
-    # select alphas for computing the variance schedule
-    alphas = alphacums[ddim_timesteps]
-    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
-
-    # according to the formula provided in https://arxiv.org/abs/2010.02502
-    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
-    if verbose:
-        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
-        print(f'For the chosen value of eta, which is {eta}, '
-              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
-    return sigmas, alphas, alphas_prev
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function,
-    which defines the cumulative product of (1-beta) over time from t = [0,1].
-    :param num_diffusion_timesteps: the number of betas to produce.
-    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
-                      produces the cumulative product of (1-beta) up to that
-                      part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
-                     prevent singularities.
-    """
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return np.array(betas)
-
-
-def extract_into_tensor(a, t, x_shape):
-    b, *_ = t.shape
-    out = a.gather(-1, t)
-    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
-
-
-def checkpoint(func, inputs, params, flag):
-    """
-    Evaluate a function without caching intermediate activations, allowing for
-    reduced memory at the expense of extra compute in the backward pass.
-    :param func: the function to evaluate.
-    :param inputs: the argument sequence to pass to `func`.
-    :param params: a sequence of parameters `func` depends on but does not
-                   explicitly take as arguments.
-    :param flag: if False, disable gradient checkpointing.
-    """
-    if flag:
-        args = tuple(inputs) + tuple(params)
-        return CheckpointFunction.apply(func, len(inputs), *args)
-    else:
-        return func(*inputs)
-
-
-class CheckpointFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, run_function, length, *args):
-        ctx.run_function = run_function
-        ctx.input_tensors = list(args[:length])
-        ctx.input_params = list(args[length:])
-
-        with torch.no_grad():
-            output_tensors = ctx.run_function(*ctx.input_tensors)
-        return output_tensors
-
-    @staticmethod
-    def backward(ctx, *output_grads):
-        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad():
-            # Fixes a bug where the first op in run_function modifies the
-            # Tensor storage in place, which is not allowed for detach()'d
-            # Tensors.
-            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-            output_tensors = ctx.run_function(*shallow_copies)
-        input_grads = torch.autograd.grad(
-            output_tensors,
-            ctx.input_tensors + ctx.input_params,
-            output_grads,
-            allow_unused=True,
-        )
-        del ctx.input_tensors
-        del ctx.input_params
-        del output_tensors
-        return (None, None) + input_grads
-
-
-def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
-    """
-    Create sinusoidal timestep embeddings.
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    if not repeat_only:
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        ).to(device=timesteps.device)
-        args = timesteps[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    else:
-        embedding = repeat(timesteps, 'b -> b d', d=dim)
-    return embedding
-
-
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
-def scale_module(module, scale):
-    """
-    Scale the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().mul_(scale)
-    return module
-
-
-def mean_flat(tensor):
-    """
-    Take the mean over all non-batch dimensions.
-    """
-    return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def normalization(channels):
-    """
-    Make a standard normalization layer.
-    :param channels: number of input channels.
-    :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(32, channels)
-
-
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-    def forward(self, x):
-        return x * torch.sigmoid(x)
-
-
-class GroupNorm32(nn.GroupNorm):
-    def forward(self, x):
-        return super().forward(x.float()).type(x.dtype)
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def linear(*args, **kwargs):
-    """
-    Create a linear module.
-    """
-    return nn.Linear(*args, **kwargs)
-
-
-def avg_pool_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D average pooling module.
-    """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class HybridConditioner(nn.Module):
-
-    def __init__(self, c_concat_config, c_crossattn_config):
-        super().__init__()
-        self.concat_conditioner = instantiate_from_config(c_concat_config)
-        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
-
-    def forward(self, c_concat, c_crossattn):
-        c_concat = self.concat_conditioner(c_concat)
-        c_crossattn = self.crossattn_conditioner(c_crossattn)
-        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
-
-
-def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
-    return repeat_noise() if repeat else noise()
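
A short sketch of the two most commonly reused helpers here, assuming the definitions above are in scope (the dimensions are arbitrary):

    import torch

    # sinusoidal embeddings for a batch of 8 timesteps, 320-dim -> shape (8, 320)
    t = torch.randint(0, 1000, (8,))
    emb = timestep_embedding(t, dim=320)

    # instantiate_from_config builds any object from {'target': ..., 'params': ...}
    cfg = {'target': 'torch.nn.Linear',
           'params': {'in_features': 4, 'out_features': 8}}
    layer = instantiate_from_config(cfg)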
 
utils/visualizer.py DELETED
@@ -1,1278 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- import colorsys
3
- import logging
4
- import math
5
- import numpy as np
6
- from enum import Enum, unique
7
- import cv2
8
- import matplotlib as mpl
9
- import matplotlib.colors as mplc
10
- import matplotlib.figure as mplfigure
11
- import pycocotools.mask as mask_util
12
- import torch
13
- from matplotlib.backends.backend_agg import FigureCanvasAgg
14
- from PIL import Image
15
-
16
- from detectron2.data import MetadataCatalog
17
- from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
18
- from detectron2.utils.file_io import PathManager
19
-
20
- from detectron2.utils.colormap import random_color
21
-
22
- logger = logging.getLogger(__name__)
23
- __all__ = ["ColorMode", "VisImage", "Visualizer"]
24
-
25
-
26
- _SMALL_OBJECT_AREA_THRESH = 1000
27
- _LARGE_MASK_AREA_THRESH = 120000
28
- _OFF_WHITE = (1.0, 1.0, 240.0 / 255)
29
- _BLACK = (0, 0, 0)
30
- _RED = (1.0, 0, 0)
31
-
32
- _KEYPOINT_THRESHOLD = 0.05
33
-
34
-
35
- @unique
36
- class ColorMode(Enum):
37
- """
38
- Enum of different color modes to use for instance visualizations.
39
- """
40
-
41
- IMAGE = 0
42
- """
43
- Picks a random color for every instance and overlay segmentations with low opacity.
44
- """
45
- SEGMENTATION = 1
46
- """
47
- Let instances of the same category have similar colors
48
- (from metadata.thing_colors), and overlay them with
49
- high opacity. This provides more attention on the quality of segmentation.
50
- """
51
- IMAGE_BW = 2
52
- """
53
- Same as IMAGE, but convert all areas without masks to gray-scale.
54
- Only available for drawing per-instance mask predictions.
55
- """
56
-
57
-
58
- class GenericMask:
59
- """
60
- Attribute:
61
- polygons (list[ndarray]): list[ndarray]: polygons for this mask.
62
- Each ndarray has format [x, y, x, y, ...]
63
- mask (ndarray): a binary mask
64
- """
65
-
66
- def __init__(self, mask_or_polygons, height, width):
67
- self._mask = self._polygons = self._has_holes = None
68
- self.height = height
69
- self.width = width
70
-
71
- m = mask_or_polygons
72
- if isinstance(m, dict):
73
- # RLEs
74
- assert "counts" in m and "size" in m
75
- if isinstance(m["counts"], list): # uncompressed RLEs
76
- h, w = m["size"]
77
- assert h == height and w == width
78
- m = mask_util.frPyObjects(m, h, w)
79
- self._mask = mask_util.decode(m)[:, :]
80
- return
81
-
82
- if isinstance(m, list): # list[ndarray]
83
- self._polygons = [np.asarray(x).reshape(-1) for x in m]
84
- return
85
-
86
- if isinstance(m, np.ndarray): # assumed to be a binary mask
87
- assert m.shape[1] != 2, m.shape
88
- assert m.shape == (
89
- height,
90
- width,
91
- ), f"mask shape: {m.shape}, target dims: {height}, {width}"
92
- self._mask = m.astype("uint8")
93
- return
94
-
95
- raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))
96
-
97
- @property
98
- def mask(self):
99
- if self._mask is None:
100
- self._mask = self.polygons_to_mask(self._polygons)
101
- return self._mask
102
-
103
- @property
104
- def polygons(self):
105
- if self._polygons is None:
106
- self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
107
- return self._polygons
108
-
109
- @property
110
- def has_holes(self):
111
- if self._has_holes is None:
112
- if self._mask is not None:
113
- self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
114
- else:
115
- self._has_holes = False # if original format is polygon, does not have holes
116
- return self._has_holes
117
-
118
- def mask_to_polygons(self, mask):
119
- # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
120
- # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
121
- # Internal contours (holes) are placed in hierarchy-2.
122
- # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
123
- mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr
124
- res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
125
- hierarchy = res[-1]
126
- if hierarchy is None: # empty mask
127
- return [], False
128
- has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
129
- res = res[-2]
130
- res = [x.flatten() for x in res]
131
- # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
132
- # We add 0.5 to turn them into real-value coordinate space. A better solution
133
- # would be to first +0.5 and then dilate the returned polygon by 0.5.
134
- res = [x + 0.5 for x in res if len(x) >= 6]
135
- return res, has_holes
136
-
137
- def polygons_to_mask(self, polygons):
138
- rle = mask_util.frPyObjects(polygons, self.height, self.width)
139
- rle = mask_util.merge(rle)
140
- return mask_util.decode(rle)[:, :]
141
-
142
- def area(self):
143
- return self.mask.sum()
144
-
145
- def bbox(self):
146
- p = mask_util.frPyObjects(self.polygons, self.height, self.width)
147
- p = mask_util.merge(p)
148
- bbox = mask_util.toBbox(p)
149
- bbox[2] += bbox[0]
150
- bbox[3] += bbox[1]
151
- return bbox
152
-
153
-
154
- class _PanopticPrediction:
155
- """
156
- Unify different panoptic annotation/prediction formats
157
- """
158
-
159
- def __init__(self, panoptic_seg, segments_info, metadata=None):
160
- if segments_info is None:
161
- assert metadata is not None
162
- # If "segments_info" is None, we assume "panoptic_img" is a
163
- # H*W int32 image storing the panoptic_id in the format of
164
- # category_id * label_divisor + instance_id. We reserve -1 for
165
- # VOID label.
166
- label_divisor = metadata.label_divisor
167
- segments_info = []
168
- for panoptic_label in np.unique(panoptic_seg.numpy()):
169
- if panoptic_label == -1:
170
- # VOID region.
171
- continue
172
- pred_class = panoptic_label // label_divisor
173
- isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
174
- segments_info.append(
175
- {
176
- "id": int(panoptic_label),
177
- "category_id": int(pred_class),
178
- "isthing": bool(isthing),
179
- }
180
- )
181
- del metadata
182
-
183
- self._seg = panoptic_seg
184
-
185
- self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
186
- segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
187
- areas = areas.numpy()
188
- sorted_idxs = np.argsort(-areas)
189
- self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
190
- self._seg_ids = self._seg_ids.tolist()
191
- for sid, area in zip(self._seg_ids, self._seg_areas):
192
- if sid in self._sinfo:
193
- self._sinfo[sid]["area"] = float(area)
194
-
195
- def non_empty_mask(self):
196
- """
197
- Returns:
198
- (H, W) array, a mask for all pixels that have a prediction
199
- """
200
- empty_ids = []
201
- for id in self._seg_ids:
202
- if id not in self._sinfo:
203
- empty_ids.append(id)
204
- if len(empty_ids) == 0:
205
- return np.zeros(self._seg.shape, dtype=np.uint8)
206
- assert (
207
- len(empty_ids) == 1
208
- ), ">1 ids corresponds to no labels. This is currently not supported"
209
- return (self._seg != empty_ids[0]).numpy().astype(np.bool)
210
-
211
- def semantic_masks(self):
212
- for sid in self._seg_ids:
213
- sinfo = self._sinfo.get(sid)
214
- if sinfo is None or sinfo["isthing"]:
215
- # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
216
- continue
217
- yield (self._seg == sid).numpy().astype(np.bool), sinfo
218
-
219
- def instance_masks(self):
220
- for sid in self._seg_ids:
221
- sinfo = self._sinfo.get(sid)
222
- if sinfo is None or not sinfo["isthing"]:
223
- continue
224
- mask = (self._seg == sid).numpy().astype(np.bool)
225
- if mask.sum() > 0:
226
- yield mask, sinfo
227
-
228
-
229
- def _create_text_labels(classes, scores, class_names, is_crowd=None):
230
- """
231
- Args:
232
- classes (list[int] or None):
233
- scores (list[float] or None):
234
- class_names (list[str] or None):
235
- is_crowd (list[bool] or None):
236
-
237
- Returns:
238
- list[str] or None
239
- """
240
- labels = None
241
- if classes is not None:
242
- if class_names is not None and len(class_names) > 0:
243
- labels = [class_names[i] for i in classes]
244
- else:
245
- labels = [str(i) for i in classes]
246
- if scores is not None:
247
- if labels is None:
248
- labels = ["{:.0f}%".format(s * 100) for s in scores]
249
- else:
250
- labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
251
- if labels is not None and is_crowd is not None:
252
- labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
253
- return labels
254
-
255
-
256
- class VisImage:
257
- def __init__(self, img, scale=1.0):
258
- """
259
- Args:
260
- img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
261
- scale (float): scale the input image
262
- """
263
- self.img = img
264
- self.scale = scale
265
- self.width, self.height = img.shape[1], img.shape[0]
266
- self._setup_figure(img)
267
-
268
- def _setup_figure(self, img):
269
- """
270
- Args:
271
- Same as in :meth:`__init__()`.
272
-
273
- Returns:
274
- fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
275
- ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
276
- """
277
- fig = mplfigure.Figure(frameon=False)
278
- self.dpi = fig.get_dpi()
279
- # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
280
- # (https://github.com/matplotlib/matplotlib/issues/15363)
281
- fig.set_size_inches(
282
- (self.width * self.scale + 1e-2) / self.dpi,
283
- (self.height * self.scale + 1e-2) / self.dpi,
284
- )
285
- self.canvas = FigureCanvasAgg(fig)
286
- # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
287
- ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
288
- ax.axis("off")
289
- self.fig = fig
290
- self.ax = ax
291
- self.reset_image(img)
292
-
293
- def reset_image(self, img):
294
- """
295
- Args:
296
- img: same as in __init__
297
- """
298
- img = img.astype("uint8")
299
- self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
300
-
301
- def save(self, filepath):
302
- """
303
- Args:
304
- filepath (str): a string that contains the absolute path, including the file name, where
305
- the visualized image will be saved.
306
- """
307
- self.fig.savefig(filepath)
308
-
309
- def get_image(self):
310
- """
311
- Returns:
312
- ndarray:
313
- the visualized image of shape (H, W, 3) (RGB) in uint8 type.
314
- The shape is scaled w.r.t the input image using the given `scale` argument.
315
- """
316
- canvas = self.canvas
317
- s, (width, height) = canvas.print_to_buffer()
318
- # buf = io.BytesIO() # works for cairo backend
319
- # canvas.print_rgba(buf)
320
- # width, height = self.width, self.height
321
- # s = buf.getvalue()
322
-
323
- buffer = np.frombuffer(s, dtype="uint8")
324
-
325
- img_rgba = buffer.reshape(height, width, 4)
326
- rgb, alpha = np.split(img_rgba, [3], axis=2)
327
- return rgb.astype("uint8")
328
-
329
-
330
- class Visualizer:
-     """
-     Visualizer that draws data about detection/segmentation on images.
-
-     It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
-     that draw primitive objects to images, as well as high-level wrappers like
-     `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
-     that draw composite data in some pre-defined style.
-
-     Note that the exact visualization style for the high-level wrappers is subject to change.
-     Style such as color, opacity, label contents, visibility of labels, or even the visibility
-     of objects themselves (e.g. when the object is too small) may change according
-     to different heuristics, as long as the results still look visually reasonable.
-
-     To obtain a consistent style, you can implement custom drawing functions with the
-     abovementioned primitive methods instead. If you need more customized visualization
-     styles, you can process the data yourself following their format documented in
-     tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
-     intend to satisfy everyone's preference on drawing styles.
-
-     This visualizer focuses on high rendering quality rather than performance. It is not
-     designed to be used for real-time applications.
-     """
-
-     # TODO implement a fast, rasterized version using OpenCV
-
-     def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
-         """
-         Args:
-             img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
-                 the height and width of the image respectively. C is the number of
-                 color channels. The image is required to be in RGB format since that
-                 is a requirement of the Matplotlib library. The image is also expected
-                 to be in the range [0, 255].
-             metadata (Metadata): dataset metadata (e.g. class names and colors)
-             instance_mode (ColorMode): defines one of the pre-defined styles for drawing
-                 instances on an image.
-         """
-         self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
-         if metadata is None:
-             metadata = MetadataCatalog.get("__nonexist__")
-         self.metadata = metadata
-         self.output = VisImage(self.img, scale=scale)
-         self.cpu_device = torch.device("cpu")
-
-         # too small texts are useless, therefore clamp to 10 // scale
-         self._default_font_size = max(
-             np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
-         )
-         # NOTE: this fork then overrides the computed default with a fixed size
-         self._default_font_size = 18
-         self._instance_mode = instance_mode
-         self.keypoint_threshold = _KEYPOINT_THRESHOLD
-
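A hedged construction example (assumes detectron2 is installed, a COCO dataset is registered, and `img_rgb` is an RGB uint8 array):

    from detectron2.data import MetadataCatalog

    metadata = MetadataCatalog.get("coco_2017_val")
    v = Visualizer(img_rgb, metadata=metadata, scale=1.2,
                   instance_mode=ColorMode.SEGMENTATION)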
-     def draw_instance_predictions(self, predictions):
-         """
-         Draw instance-level prediction results on an image.
-
-         Args:
-             predictions (Instances): the output of an instance detection/segmentation
-                 model. The following fields will be used to draw:
-                 "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
-
-         Returns:
-             output (VisImage): image object with visualizations.
-         """
-         boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
-         scores = predictions.scores if predictions.has("scores") else None
-         classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
-         labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
-         keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
-
-         # NOTE: this fork hard-codes a score cutoff of 0.8, so `scores` must be
-         # present in `predictions`; keypoints are not filtered by this mask.
-         keep = (scores > 0.8).cpu()
-         boxes = boxes[keep]
-         scores = scores[keep]
-         classes = np.array(classes)
-         classes = classes[np.array(keep)]
-         labels = np.array(labels)
-         labels = labels[np.array(keep)]
-
-         if predictions.has("pred_masks"):
-             masks = np.asarray(predictions.pred_masks)
-             masks = masks[np.array(keep)]
-             masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
-         else:
-             masks = None
-
-         if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
-             colors = [
-                 self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
-             ]
-             alpha = 0.4
-         else:
-             colors = None
-             alpha = 0.4
-
-         if self._instance_mode == ColorMode.IMAGE_BW:
-             self.output.reset_image(
-                 self._create_grayscale_image(
-                     (predictions.pred_masks.any(dim=0) > 0).numpy()
-                     if predictions.has("pred_masks")
-                     else None
-                 )
-             )
-             alpha = 0.3
-
-         self.overlay_instances(
-             masks=masks,
-             boxes=boxes,
-             labels=labels,
-             keypoints=keypoints,
-             assigned_colors=colors,
-             alpha=alpha,
-         )
-         return self.output
-
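A sketch of the typical call site, assuming a detectron2 `DefaultPredictor` named `predictor` (illustrative names only):

    import cv2

    bgr = cv2.imread("input.jpg")
    outputs = predictor(bgr)                    # dict with an "instances" field
    v = Visualizer(bgr[:, :, ::-1], metadata)   # Visualizer expects RGB
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    cv2.imwrite("pred.png", out.get_image()[:, :, ::-1])  # back to BGR for OpenCV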
-     def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.7):
-         """
-         Draw semantic segmentation predictions/labels.
-
-         Args:
-             sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
-                 Each value is the integer label of the pixel.
-             area_threshold (int): segments smaller than `area_threshold` are not drawn.
-             alpha (float): the larger it is, the more opaque the segmentations are.
-
-         Returns:
-             output (VisImage): image object with visualizations.
-         """
-         if isinstance(sem_seg, torch.Tensor):
-             sem_seg = sem_seg.numpy()
-         labels, areas = np.unique(sem_seg, return_counts=True)
-         sorted_idxs = np.argsort(-areas).tolist()
-         labels = labels[sorted_idxs]
-         for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
-             try:
-                 mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
-             except (AttributeError, IndexError):
-                 mask_color = None
-
-             binary_mask = (sem_seg == label).astype(np.uint8)
-             text = self.metadata.stuff_classes[label]
-             self.draw_binary_mask(
-                 binary_mask,
-                 color=mask_color,
-                 edge_color=_OFF_WHITE,
-                 text=text,
-                 alpha=alpha,
-                 area_threshold=area_threshold,
-             )
-         return self.output
-
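Hypothetical usage with a model whose output dict carries a per-pixel class map (field name is an assumption):

    sem_seg = outputs["sem_seg"].argmax(dim=0).to("cpu")  # (H, W) integer labels
    v.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)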
-     def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):
-         """
-         Draw panoptic prediction annotations or results.
-
-         Args:
-             panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
-                 segment.
-             segments_info (list[dict] or None): Describes each segment in `panoptic_seg`.
-                 If it is a ``list[dict]``, each dict contains keys "id", "category_id".
-                 If None, category id of each pixel is computed by
-                 ``pixel // metadata.label_divisor``.
-             area_threshold (int): stuff segments smaller than `area_threshold` are not drawn.
-             alpha (float): the larger it is, the more opaque the segments are.
-
-         Returns:
-             output (VisImage): image object with visualizations.
-         """
-         pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
-
-         if self._instance_mode == ColorMode.IMAGE_BW:
-             self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
-
-         # draw mask for all semantic segments first, i.e. "stuff"
-         for mask, sinfo in pred.semantic_masks():
-             category_idx = sinfo["category_id"]
-             try:
-                 mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
-             except AttributeError:
-                 mask_color = None
-
-             text = self.metadata.stuff_classes[category_idx]
-             self.draw_binary_mask(
-                 mask,
-                 color=mask_color,
-                 edge_color=_OFF_WHITE,
-                 text=text,
-                 alpha=alpha,
-                 area_threshold=area_threshold,
-             )
-
-         # draw masks for all instances second
-         all_instances = list(pred.instance_masks())
-         if len(all_instances) == 0:
-             return self.output
-         masks, sinfo = list(zip(*all_instances))
-         category_ids = [x["category_id"] for x in sinfo]
-
-         try:
-             scores = [x["score"] for x in sinfo]
-         except KeyError:
-             scores = None
-         labels = _create_text_labels(
-             category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo]
-         )
-
-         try:
-             colors = [
-                 self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids
-             ]
-         except AttributeError:
-             colors = None
-         self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)
-
-         return self.output
-
-     draw_panoptic_seg_predictions = draw_panoptic_seg  # backward compatibility
-
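Hypothetical usage with a detectron2 panoptic model, whose "panoptic_seg" output is a (segment-id map, segments_info) pair:

    panoptic_seg, segments_info = outputs["panoptic_seg"]
    v.draw_panoptic_seg(panoptic_seg.to("cpu"), segments_info, alpha=0.7)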
-     def draw_dataset_dict(self, dic):
-         """
-         Draw annotations/segmentations in Detectron2 Dataset format.
-
-         Args:
-             dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
-
-         Returns:
-             output (VisImage): image object with visualizations.
-         """
-         annos = dic.get("annotations", None)
-         if annos:
-             if "segmentation" in annos[0]:
-                 masks = [x["segmentation"] for x in annos]
-             else:
-                 masks = None
-             if "keypoints" in annos[0]:
-                 keypts = [x["keypoints"] for x in annos]
-                 keypts = np.array(keypts).reshape(len(annos), -1, 3)
-             else:
-                 keypts = None
-
-             boxes = [
-                 BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
-                 if len(x["bbox"]) == 4
-                 else x["bbox"]
-                 for x in annos
-             ]
-
-             colors = None
-             category_ids = [x["category_id"] for x in annos]
-             if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
-                 colors = [
-                     self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
-                     for c in category_ids
-                 ]
-             names = self.metadata.get("thing_classes", None)
-             labels = _create_text_labels(
-                 category_ids,
-                 scores=None,
-                 class_names=names,
-                 is_crowd=[x.get("iscrowd", 0) for x in annos],
-             )
-             self.overlay_instances(
-                 labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
-             )
-
-         sem_seg = dic.get("sem_seg", None)
-         if sem_seg is None and "sem_seg_file_name" in dic:
-             with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
-                 sem_seg = Image.open(f)
-                 sem_seg = np.asarray(sem_seg, dtype="uint8")
-         if sem_seg is not None:
-             self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.4)
-
-         pan_seg = dic.get("pan_seg", None)
-         if pan_seg is None and "pan_seg_file_name" in dic:
-             with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
-                 pan_seg = Image.open(f)
-                 pan_seg = np.asarray(pan_seg)
-                 from panopticapi.utils import rgb2id
-
-                 pan_seg = rgb2id(pan_seg)
-         if pan_seg is not None:
-             segments_info = dic["segments_info"]
-             pan_seg = torch.tensor(pan_seg)
-             self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.7)
-         return self.output
-
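A sketch for visualizing ground truth of a registered dataset (assumes detectron2 and OpenCV; names are illustrative):

    import cv2
    from detectron2.data import DatasetCatalog

    d = DatasetCatalog.get("coco_2017_val")[0]
    img = cv2.imread(d["file_name"])[:, :, ::-1]   # BGR -> RGB
    Visualizer(img, metadata).draw_dataset_dict(d).save("gt.png")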
-     def overlay_instances(
-         self,
-         *,
-         boxes=None,
-         labels=None,
-         masks=None,
-         keypoints=None,
-         assigned_colors=None,
-         alpha=0.5,
-     ):
-         """
-         Args:
-             boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
-                 or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
-                 or a :class:`RotatedBoxes`,
-                 or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
-                 for the N objects in a single image,
-             labels (list[str]): the text to be displayed for each instance.
-             masks (masks-like object): Supported types are:
-
-                 * :class:`detectron2.structures.PolygonMasks`,
-                   :class:`detectron2.structures.BitMasks`.
-                 * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
-                   The first level of the list corresponds to individual instances. The second
-                   level to all the polygons that compose the instance, and the third level
-                   to the polygon coordinates. The third level should have the format of
-                   [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
-                 * list[ndarray]: each ndarray is a binary mask of shape (H, W).
-                 * list[dict]: each dict is a COCO-style RLE.
-             keypoints (Keypoints or array-like): an array-like object of shape (N, K, 3),
-                 where N is the number of instances and K is the number of keypoints.
-                 The last dimension corresponds to (x, y, visibility or score).
-             assigned_colors (list[matplotlib.colors]): a list of colors, where each color
-                 corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
-                 for full list of formats that the colors are accepted in.
-
-         Returns:
-             output (VisImage): image object with visualizations.
-         """
-         num_instances = 0
-         if boxes is not None:
-             boxes = self._convert_boxes(boxes)
-             num_instances = len(boxes)
-         if masks is not None:
-             masks = self._convert_masks(masks)
-             if num_instances:
-                 assert len(masks) == num_instances
-             else:
-                 num_instances = len(masks)
-         if keypoints is not None:
-             if num_instances:
-                 assert len(keypoints) == num_instances
-             else:
-                 num_instances = len(keypoints)
-             keypoints = self._convert_keypoints(keypoints)
-         if labels is not None:
-             assert len(labels) == num_instances
-         if assigned_colors is None:
-             assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
-         if num_instances == 0:
-             return self.output
-         if boxes is not None and boxes.shape[1] == 5:
-             return self.overlay_rotated_instances(
-                 boxes=boxes, labels=labels, assigned_colors=assigned_colors
-             )
-
-         # Display in largest to smallest order to reduce occlusion.
-         areas = None
-         if boxes is not None:
-             areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
-         elif masks is not None:
-             areas = np.asarray([x.area() for x in masks])
-
-         if areas is not None:
-             sorted_idxs = np.argsort(-areas).tolist()
-             # Re-order overlapped instances in descending order.
-             boxes = boxes[sorted_idxs] if boxes is not None else None
-             labels = [labels[k] for k in sorted_idxs] if labels is not None else None
-             masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
-             assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
-             keypoints = keypoints[sorted_idxs] if keypoints is not None else None
-
-         for i in range(num_instances):
-             color = assigned_colors[i]
-             if boxes is not None:
-                 self.draw_box(boxes[i], edge_color=color)
-
-             if masks is not None:
-                 for segment in masks[i].polygons:
-                     self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)
-
-             if labels is not None:
-                 # first get a box
-                 if boxes is not None:
-                     x0, y0, x1, y1 = boxes[i]
-                     text_pos = (x0, y0)  # if drawing boxes, put text on the box corner.
-                     horiz_align = "left"
-                 elif masks is not None:
-                     # skip small mask without polygon
-                     if len(masks[i].polygons) == 0:
-                         continue
-
-                     x0, y0, x1, y1 = masks[i].bbox()
-
-                     # draw text in the center (defined by median) when box is not drawn
-                     # median is less sensitive to outliers.
-                     text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
-                     horiz_align = "center"
-                 else:
-                     continue  # drawing the box confidence for keypoints isn't very useful.
-                 # for small objects, draw text at the side to avoid occlusion
-                 instance_area = (y1 - y0) * (x1 - x0)
-                 if (
-                     instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
-                     or y1 - y0 < 40 * self.output.scale
-                 ):
-                     if y1 >= self.output.height - 5:
-                         text_pos = (x1, y0)
-                     else:
-                         text_pos = (x0, y1)
-
-                 height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
-                 lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
-                 font_size = (
-                     np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
-                     * 0.5
-                     * self._default_font_size
-                 )
-                 self.draw_text(
-                     labels[i],
-                     text_pos,
-                     color=lighter_color,
-                     horizontal_alignment=horiz_align,
-                     font_size=font_size,
-                 )
-
-         # draw keypoints
-         if keypoints is not None:
-             for keypoints_per_instance in keypoints:
-                 self.draw_and_connect_keypoints(keypoints_per_instance)
-
-         return self.output
-
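Hypothetical direct use with raw arrays, bypassing the `Instances` wrappers:

    import numpy as np

    boxes = np.array([[20, 30, 120, 200], [60, 40, 300, 220]], dtype=float)  # XYXY_ABS
    v.overlay_instances(boxes=boxes, labels=["cat 91%", "dog 78%"], alpha=0.5)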
-     def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
-         """
-         Args:
-             boxes (ndarray): an Nx5 numpy array of
-                 (x_center, y_center, width, height, angle_degrees) format
-                 for the N objects in a single image.
-             labels (list[str]): the text to be displayed for each instance.
-             assigned_colors (list[matplotlib.colors]): a list of colors, where each color
-                 corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
-                 for full list of formats that the colors are accepted in.
-
-         Returns:
-             output (VisImage): image object with visualizations.
-         """
-         num_instances = len(boxes)
-
-         if assigned_colors is None:
-             assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
-         if num_instances == 0:
-             return self.output
-
-         # Display in largest to smallest order to reduce occlusion.
-         if boxes is not None:
-             areas = boxes[:, 2] * boxes[:, 3]
-
-         sorted_idxs = np.argsort(-areas).tolist()
-         # Re-order overlapped instances in descending order.
-         boxes = boxes[sorted_idxs]
-         labels = [labels[k] for k in sorted_idxs] if labels is not None else None
-         colors = [assigned_colors[idx] for idx in sorted_idxs]
-
-         for i in range(num_instances):
-             self.draw_rotated_box_with_label(
-                 boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
-             )
-
-         return self.output
-
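Hypothetical rotated-box input, one (cx, cy, w, h, angle-in-degrees-CCW) row per instance:

    rboxes = np.array([[150.0, 100.0, 80.0, 40.0, 30.0]])
    v.overlay_rotated_instances(boxes=rboxes, labels=["rotated"])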
-     def draw_and_connect_keypoints(self, keypoints):
-         """
-         Draws keypoints of an instance and follows the rules for keypoint connections
-         to draw lines between appropriate keypoints. This follows color heuristics for
-         line color.
-
-         Args:
-             keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
-                 and the last dimension corresponds to (x, y, probability).
-
-         Returns:
-             output (VisImage): image object with visualizations.
-         """
-         visible = {}
-         keypoint_names = self.metadata.get("keypoint_names")
-         for idx, keypoint in enumerate(keypoints):
-             # draw keypoint
-             x, y, prob = keypoint
-             if prob > self.keypoint_threshold:
-                 self.draw_circle((x, y), color=_RED)
-                 if keypoint_names:
-                     keypoint_name = keypoint_names[idx]
-                     visible[keypoint_name] = (x, y)
-
-         if self.metadata.get("keypoint_connection_rules"):
-             for kp0, kp1, color in self.metadata.keypoint_connection_rules:
-                 if kp0 in visible and kp1 in visible:
-                     x0, y0 = visible[kp0]
-                     x1, y1 = visible[kp1]
-                     color = tuple(x / 255.0 for x in color)
-                     self.draw_line([x0, x1], [y0, y1], color=color)
-
-         # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
-         # Note that this strategy is specific to person keypoints.
-         # For other keypoints, it should just do nothing.
-         try:
-             ls_x, ls_y = visible["left_shoulder"]
-             rs_x, rs_y = visible["right_shoulder"]
-             mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
-         except KeyError:
-             pass
-         else:
-             # draw line from nose to mid-shoulder
-             nose_x, nose_y = visible.get("nose", (None, None))
-             if nose_x is not None:
-                 self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)
-
-             try:
-                 # draw line from mid-shoulder to mid-hip
-                 lh_x, lh_y = visible["left_hip"]
-                 rh_x, rh_y = visible["right_hip"]
-             except KeyError:
-                 pass
-             else:
-                 mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
-                 self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
-         return self.output
-
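Hypothetical single-instance keypoints, shape (K, 3) with (x, y, score) rows; points below the threshold check above are skipped:

    import torch

    kpts = torch.tensor([[100.0, 80.0, 0.9],
                         [110.0, 78.0, 0.8],
                         [ 90.0, 78.0, 0.01]])
    v.draw_and_connect_keypoints(kpts)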
-     """
-     Primitive drawing functions:
-     """
-
-     def draw_text(
-         self,
-         text,
-         position,
-         *,
-         font_size=None,
-         color="g",
-         horizontal_alignment="center",
-         rotation=0,
-     ):
-         """
-         Args:
-             text (str): class label
-             position (tuple): a tuple of the x and y coordinates to place text on image.
-             font_size (int, optional): font size of the text. If not provided, a font size
-                 proportional to the image width is calculated and used.
-             color: color of the text. Refer to `matplotlib.colors` for full list
-                 of formats that are accepted.
-             horizontal_alignment (str): see `matplotlib.text.Text`
-             rotation: rotation angle in degrees CCW
-
-         Returns:
-             output (VisImage): image object with text drawn.
-         """
-         if not font_size:
-             font_size = self._default_font_size
-
-         # since the text background is dark, we don't want the text to be dark
-         color = np.maximum(list(mplc.to_rgb(color)), 0.2)
-         color[np.argmax(color)] = max(0.8, np.max(color))
-
-         x, y = position
-         self.output.ax.text(
-             x,
-             y,
-             text,
-             size=font_size * self.output.scale,
-             family="sans-serif",
-             bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
-             verticalalignment="top",
-             horizontalalignment=horizontal_alignment,
-             color=color,
-             zorder=10,
-             rotation=rotation,
-         )
-         return self.output
-
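Hypothetical direct placement of a label:

    v.draw_text("score: 0.97", (50, 40), color="w", horizontal_alignment="left")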
-     def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
-         """
-         Args:
-             box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
-                 are the coordinates of the box's top left corner. x1 and y1 are the
-                 coordinates of the box's bottom right corner.
-             alpha (float): blending coefficient. Smaller values lead to more transparent boxes.
-             edge_color: color of the outline of the box. Refer to `matplotlib.colors`
-                 for full list of formats that are accepted.
-             line_style (string): the string to use to create the outline of the boxes.
-
-         Returns:
-             output (VisImage): image object with box drawn.
-         """
-         x0, y0, x1, y1 = box_coord
-         width = x1 - x0
-         height = y1 - y0
-
-         linewidth = max(self._default_font_size / 4, 1)
-
-         self.output.ax.add_patch(
-             mpl.patches.Rectangle(
-                 (x0, y0),
-                 width,
-                 height,
-                 fill=False,
-                 edgecolor=edge_color,
-                 linewidth=linewidth * self.output.scale,
-                 alpha=alpha,
-                 linestyle=line_style,
-             )
-         )
-         return self.output
-
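Hypothetical region-of-interest outline:

    v.draw_box((40, 40, 200, 160), edge_color="r", line_style="--")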
-     def draw_rotated_box_with_label(
-         self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
-     ):
-         """
-         Draw a rotated box with label on its top-left corner.
-
-         Args:
-             rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
-                 where cnt_x and cnt_y are the center coordinates of the box.
-                 w and h are the width and height of the box. angle represents how
-                 many degrees the box is rotated CCW with regard to the 0-degree box.
-             alpha (float): blending coefficient. Smaller values lead to more transparent boxes.
-             edge_color: color of the outline of the box. Refer to `matplotlib.colors`
-                 for full list of formats that are accepted.
-             line_style (string): the string to use to create the outline of the boxes.
-             label (string): label for rotated box. It will not be rendered when set to None.
-
-         Returns:
-             output (VisImage): image object with box drawn.
-         """
-         cnt_x, cnt_y, w, h, angle = rotated_box
-         area = w * h
-         # use thinner lines when the box is small
-         linewidth = self._default_font_size / (
-             6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
-         )
-
-         theta = angle * math.pi / 180.0
-         c = math.cos(theta)
-         s = math.sin(theta)
-         rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
-         # x: left->right ; y: top->down
-         rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
-         for k in range(4):
-             j = (k + 1) % 4
-             self.draw_line(
-                 [rotated_rect[k][0], rotated_rect[j][0]],
-                 [rotated_rect[k][1], rotated_rect[j][1]],
-                 color=edge_color,
-                 linestyle="--" if k == 1 else line_style,
-                 linewidth=linewidth,
-             )
-
-         if label is not None:
-             text_pos = rotated_rect[1]  # topleft corner
-
-             height_ratio = h / np.sqrt(self.output.height * self.output.width)
-             label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
-             font_size = (
-                 np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
-             )
-             self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)
-
-         return self.output
-
-     def draw_circle(self, circle_coord, color, radius=3):
-         """
-         Args:
-             circle_coord (list(int) or tuple(int)): contains the x and y coordinates
-                 of the center of the circle.
-             color: color of the circle. Refer to `matplotlib.colors` for a full list of
-                 formats that are accepted.
-             radius (int): radius of the circle.
-
-         Returns:
-             output (VisImage): image object with circle drawn.
-         """
-         self.output.ax.add_patch(
-             mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
-         )
-         return self.output
-
-     def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
-         """
-         Args:
-             x_data (list[int]): a list containing x values of all the points being drawn.
-                 Length of list should match the length of y_data.
-             y_data (list[int]): a list containing y values of all the points being drawn.
-                 Length of list should match the length of x_data.
-             color: color of the line. Refer to `matplotlib.colors` for a full list of
-                 formats that are accepted.
-             linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
-                 for a full list of formats that are accepted.
-             linewidth (float or None): width of the line. When it's None,
-                 a default value will be computed and used.
-
-         Returns:
-             output (VisImage): image object with line drawn.
-         """
-         if linewidth is None:
-             linewidth = self._default_font_size / 3
-         linewidth = max(linewidth, 1)
-         self.output.ax.add_line(
-             mpl.lines.Line2D(
-                 x_data,
-                 y_data,
-                 linewidth=linewidth * self.output.scale,
-                 color=color,
-                 linestyle=linestyle,
-             )
-         )
-         return self.output
-
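The primitives compose on the same canvas; a hypothetical sketch:

    v.draw_circle((128, 96), color="lime", radius=4)
    v.draw_line([0, 256], [96, 96], color="yellow", linestyle=":")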
-     def draw_binary_mask(
-         self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.7, area_threshold=10
-     ):
-         """
-         Args:
-             binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
-                 W is the image width. Each value in the array is either a 0 or 1 value of uint8
-                 type.
-             color: color of the mask. Refer to `matplotlib.colors` for a full list of
-                 formats that are accepted. If None, will pick a random color.
-             edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
-                 full list of formats that are accepted.
-             text (str): if not None, it will be drawn on the object.
-             alpha (float): blending coefficient. Smaller values lead to more transparent masks.
-             area_threshold (float): a connected component smaller than this area will not be shown.
-
-         Returns:
-             output (VisImage): image object with mask drawn.
-         """
-         if color is None:
-             color = random_color(rgb=True, maximum=1)
-         color = mplc.to_rgb(color)
-
-         has_valid_segment = False
-         binary_mask = binary_mask.astype("uint8")  # opencv needs uint8
-         mask = GenericMask(binary_mask, self.output.height, self.output.width)
-         shape2d = (binary_mask.shape[0], binary_mask.shape[1])
-
-         if not mask.has_holes:
-             # draw polygons for regular masks
-             for segment in mask.polygons:
-                 area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
-                 if area < (area_threshold or 0):
-                     continue
-                 has_valid_segment = True
-                 segment = segment.reshape(-1, 2)
-                 self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
-         else:
-             # TODO: Use Path/PathPatch to draw vector graphics:
-             # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
-             rgba = np.zeros(shape2d + (4,), dtype="float32")
-             rgba[:, :, :3] = color
-             rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
-             has_valid_segment = True
-             self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
-
-         if text is not None and has_valid_segment:
-             lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
-             self._draw_text_in_mask(binary_mask, text, lighter_color)
-         return self.output
-
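Hypothetical thresholded-region overlay (`gray` stands in for a (H, W) grayscale array):

    mask = (gray > 128).astype("uint8")
    v.draw_binary_mask(mask, color="b", text="bright", alpha=0.5)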
-     def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
-         """
-         Args:
-             soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
-             color: color of the mask. Refer to `matplotlib.colors` for a full list of
-                 formats that are accepted. If None, will pick a random color.
-             text (str): if not None, it will be drawn on the object.
-             alpha (float): blending coefficient. Smaller values lead to more transparent masks.
-
-         Returns:
-             output (VisImage): image object with mask drawn.
-         """
-         if color is None:
-             color = random_color(rgb=True, maximum=1)
-         color = mplc.to_rgb(color)
-
-         shape2d = (soft_mask.shape[0], soft_mask.shape[1])
-         rgba = np.zeros(shape2d + (4,), dtype="float32")
-         rgba[:, :, :3] = color
-         rgba[:, :, 3] = soft_mask * alpha
-         self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
-
-         if text is not None:
-             lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
-             binary_mask = (soft_mask > 0.5).astype("uint8")
-             self._draw_text_in_mask(binary_mask, text, lighter_color)
-         return self.output
-
-     def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
-         """
-         Args:
-             segment: numpy array of shape Nx2, containing all the points in the polygon.
-             color: color of the polygon. Refer to `matplotlib.colors` for a full list of
-                 formats that are accepted.
-             edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
-                 full list of formats that are accepted. If not provided, a darker shade
-                 of the polygon color will be used instead.
-             alpha (float): blending coefficient. Smaller values lead to more transparent masks.
-
-         Returns:
-             output (VisImage): image object with polygon drawn.
-         """
-         if edge_color is None:
-             # make edge color darker than the polygon color
-             if alpha > 0.8:
-                 edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
-             else:
-                 edge_color = color
-         edge_color = mplc.to_rgb(edge_color) + (1,)
-
-         polygon = mpl.patches.Polygon(
-             segment,
-             fill=True,
-             facecolor=mplc.to_rgb(color) + (alpha,),
-             edgecolor=edge_color,
-             linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
-         )
-         self.output.ax.add_patch(polygon)
-         return self.output
-
-     """
-     Internal methods:
-     """
-
-     def _jitter(self, color):
-         """
-         Randomly modifies given color to produce a slightly different color than the color given.
-
-         Args:
-             color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
-                 picked. The values in the tuple are in the [0.0, 1.0] range.
-
-         Returns:
-             jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
-                 color after being jittered. The values in the tuple are in the [0.0, 1.0] range.
-         """
-         color = mplc.to_rgb(color)
-         # np.random.seed(0)
-         vec = np.random.rand(3)
-         # better to do it in another color space
-         vec = vec / np.linalg.norm(vec) * 0.5
-         res = np.clip(vec + color, 0, 1)
-         return tuple(res)
-
-     def _create_grayscale_image(self, mask=None):
-         """
-         Create a grayscale version of the original image.
-         The colors in the masked area, if given, will be kept.
-         """
-         img_bw = self.img.astype("f4").mean(axis=2)
-         img_bw = np.stack([img_bw] * 3, axis=2)
-         if mask is not None:
-             img_bw[mask] = self.img[mask]
-         return img_bw
-
-     def _change_color_brightness(self, color, brightness_factor):
-         """
-         Depending on the brightness_factor, gives a lighter or darker color, i.e. a color with
-         less or more lightness than the original color.
-
-         Args:
-             color: color of the polygon. Refer to `matplotlib.colors` for a full list of
-                 formats that are accepted.
-             brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
-                 0 will correspond to no change, a factor in [-1.0, 0) range will result in
-                 a darker color and a factor in (0, 1.0] range will result in a lighter color.
-
-         Returns:
-             modified_color (tuple[double]): a tuple containing the RGB values of the
-                 modified color. Each value in the tuple is in the [0.0, 1.0] range.
-         """
-         assert brightness_factor >= -1.0 and brightness_factor <= 1.0
-         color = mplc.to_rgb(color)
-         polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
-         modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
-         modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
-         modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
-         modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
-         return modified_color
-
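For example, the label-text color used throughout the class is a 70%-lighter shade of the instance color (hypothetical values):

    lighter = v._change_color_brightness((1.0, 0.0, 0.0), brightness_factor=0.7)
    # pure red -> a lighter red, still an RGB 3-tuple in [0.0, 1.0]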
-     def _convert_boxes(self, boxes):
-         """
-         Convert different formats of boxes to an NxB array, where B = 4 or 5 is the box dimension.
-         """
-         if isinstance(boxes, (Boxes, RotatedBoxes)):
-             return boxes.tensor.detach().numpy()
-         else:
-             return np.asarray(boxes)
-
-     def _convert_masks(self, masks_or_polygons):
-         """
-         Convert different formats of masks or polygons to a list of :class:`GenericMask`.
-
-         Returns:
-             list[GenericMask]:
-         """
-
-         m = masks_or_polygons
-         if isinstance(m, PolygonMasks):
-             m = m.polygons
-         if isinstance(m, BitMasks):
-             m = m.tensor.numpy()
-         if isinstance(m, torch.Tensor):
-             m = m.numpy()
-         ret = []
-         for x in m:
-             if isinstance(x, GenericMask):
-                 ret.append(x)
-             else:
-                 ret.append(GenericMask(x, self.output.height, self.output.width))
-         return ret
-
-
1248
- def _draw_text_in_mask(self, binary_mask, text, color):
1249
- """
1250
- Find proper places to draw text given a binary mask.
1251
- """
1252
- # TODO sometimes drawn on wrong objects. the heuristics here can improve.
1253
- _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
1254
- if stats[1:, -1].size == 0:
1255
- return
1256
- largest_component_id = np.argmax(stats[1:, -1]) + 1
1257
-
1258
- # draw text on the largest component, as well as other very large components.
1259
- for cid in range(1, _num_cc):
1260
- if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
1261
- # median is more stable than centroid
1262
- # center = centroids[largest_component_id]
1263
- center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
1264
- self.draw_text(text, center, color=color)
1265
-
1266
-     def _convert_keypoints(self, keypoints):
-         if isinstance(keypoints, Keypoints):
-             keypoints = keypoints.tensor
-         keypoints = np.asarray(keypoints)
-         return keypoints
-
-     def get_output(self):
-         """
-         Returns:
-             output (VisImage): the image output containing the visualizations added
-                 to the image.
-         """
-         return self.output
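A closing sketch of the draw-then-fetch pattern this class is built around (hypothetical names):

    v = Visualizer(img_rgb, metadata)
    v.draw_box((10, 10, 100, 100))
    result = v.get_output().get_image()  # (H, W, 3) uint8 RGB ndarray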