nikunjkdtechnoland
committed on
Commit 4b98c85 · Parent(s): e041d7d
some more add more files
- iopaint/file_manager/utils.py +65 -0
- iopaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py +81 -0
- iopaint/model/anytext/ldm/modules/diffusionmodules/util.py +271 -0
- iopaint/model/anytext/ldm/util.py +197 -0
- iopaint/model/anytext/utils.py +151 -0
- iopaint/model/original_sd_configs/v1-inference.yaml +70 -0
- iopaint/model/original_sd_configs/v2-inference-v.yaml +68 -0
- iopaint/model/utils.py +1033 -0
- iopaint/model/zits.py +476 -0
- iopaint/plugins/segment_anything/modeling/tiny_vit_sam.py +822 -0
- iopaint/plugins/segment_anything/modeling/transformer.py +240 -0
- iopaint/plugins/segment_anything/utils/transforms.py +112 -0
- iopaint/tests/test_sdxl.py +172 -0
- iopaint/tests/utils.py +77 -0
- iopaint/web_config.py +307 -0
- pretrained-model/version.txt +1 -0
- pretrained-model/version_diffusers_cache.txt +1 -0
- utils/tools.py +505 -0
iopaint/file_manager/utils.py
ADDED
@@ -0,0 +1,65 @@
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
import hashlib
from pathlib import Path

from typing import Union


def generate_filename(directory: Path, original_filename, *options) -> str:
    text = str(directory.absolute()) + original_filename
    for v in options:
        text += "%s" % v
    md5_hash = hashlib.md5()
    md5_hash.update(text.encode("utf-8"))
    return md5_hash.hexdigest() + ".jpg"


def parse_size(size):
    if isinstance(size, int):
        # If the size parameter is a single number, assume square aspect.
        return [size, size]

    if isinstance(size, (tuple, list)):
        if len(size) == 1:
            # If a single-value tuple/list is provided, expand it to two elements.
            return size + type(size)(size)
        return size

    try:
        thumbnail_size = [int(x) for x in size.lower().split("x", 1)]
    except ValueError:
        raise ValueError(  # pylint: disable=raise-missing-from
            "Bad thumbnail size format. Valid format is INTxINT."
        )

    if len(thumbnail_size) == 1:
        # If the size parameter only contains a single integer, assume square aspect.
        thumbnail_size.append(thumbnail_size[0])

    return thumbnail_size


def aspect_to_string(size):
    if isinstance(size, str):
        return size

    return "x".join(map(str, size))


IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}


def glob_img(p: Union[Path, str], recursive: bool = False):
    p = Path(p)
    if p.is_file() and p.suffix in IMG_SUFFIX:
        yield p
    else:
        if recursive:
            files = Path(p).glob("**/*.*")
        else:
            files = Path(p).glob("*.*")

        for it in files:
            if it.suffix not in IMG_SUFFIX:
                continue
            yield it
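
For orientation, a minimal usage sketch of the helpers above; the thumbnail directory and file names are hypothetical, and it assumes the iopaint package from this commit is importable.

    from pathlib import Path

    from iopaint.file_manager.utils import generate_filename, parse_size, glob_img

    # Hypothetical thumbnail directory; any folder of images works.
    thumb_dir = Path("/tmp/thumbnails")

    # Deterministic cache name: md5 over the absolute directory, the source
    # filename, and any extra options (here an illustrative "256x256" size tag).
    name = generate_filename(thumb_dir, "cat.png", "256x256")
    print(name)  # e.g. "3f2a...e9.jpg"

    # parse_size accepts an int, a 1- or 2-element tuple/list, or an "INTxINT" string.
    print(parse_size(128))        # [128, 128]
    print(parse_size((200,)))     # (200, 200)
    print(parse_size("640x480"))  # [640, 480]

    # glob_img lazily yields image paths whose suffix is in IMG_SUFFIX.
    for p in glob_img(thumb_dir, recursive=True):
        print(p)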
iopaint/model/anytext/ldm/modules/diffusionmodules/upscaling.py
ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import numpy as np
from functools import partial

from iopaint.model.anytext.ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from iopaint.model.anytext.ldm.util import default


class AbstractLowScaleModel(nn.Module):
    # for concatenating a downsampled image to the latent representation
    def __init__(self, noise_schedule_config=None):
        super(AbstractLowScaleModel, self).__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x):
        return x, None

    def decode(self, x):
        return x


class SimpleImageConcat(AbstractLowScaleModel):
    # no noise level conditioning
    def __init__(self):
        super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        # fix to constant noise level
        return x, torch.zeros(x.shape[0], device=x.device).long()


class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        z = self.q_sample(x, noise_level)
        return z, noise_level
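
A small sketch of how the noise-augmentation module above could be exercised; the schedule config keys mirror register_schedule(), and the tensor shapes are purely illustrative.

    import torch

    from iopaint.model.anytext.ldm.modules.diffusionmodules.upscaling import (
        ImageConcatWithNoiseAugmentation,
    )

    # Register a default linear beta schedule; the dict is expanded into
    # register_schedule(beta_schedule=..., timesteps=...).
    aug = ImageConcatWithNoiseAugmentation(
        noise_schedule_config={"beta_schedule": "linear", "timesteps": 1000},
        max_noise_level=250,
    )

    x_low = torch.randn(2, 4, 32, 32)   # stand-in for a downscaled latent
    z, noise_level = aug(x_low)         # random level in [0, 250) per sample
    print(z.shape, noise_level)         # torch.Size([2, 4, 32, 32]) tensor([...])

    # An explicit per-sample level can also be supplied.
    fixed = torch.full((2,), 100, dtype=torch.long)
    z_fixed, _ = aug(x_low, noise_level=fixed)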
iopaint/model/anytext/ldm/modules/diffusionmodules/util.py
ADDED
@@ -0,0 +1,271 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!


import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat

from iopaint.model.anytext.ldm.util import instantiate_from_config


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas.to(torch.float32), alphas.to(torch.float32), alphas_prev.astype(np.float32)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        # return super().forward(x.float()).type(x.dtype)
        return super().forward(x).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()
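
For orientation, a minimal sketch exercising a few of the helpers above (timestep embeddings, zero-initialized convs, gradient checkpointing); the module sizes are made up and it assumes a reasonably recent PyTorch install.

    import torch
    import torch.nn as nn

    from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
        checkpoint,
        timestep_embedding,
        zero_module,
        conv_nd,
    )

    # Sinusoidal embeddings for a batch of 4 timesteps, 320 channels wide.
    t = torch.tensor([0, 10, 500, 999])
    emb = timestep_embedding(t, dim=320)
    print(emb.shape)  # torch.Size([4, 320])

    # A zero-initialized 2D conv, as used for the final projection of ResBlocks.
    proj = zero_module(conv_nd(2, 64, 64, 3, padding=1))

    # Gradient checkpointing: the block is recomputed during the backward pass.
    block = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.SiLU())
    x = torch.randn(1, 64, 32, 32, requires_grad=True)
    y = checkpoint(lambda inp: block(inp), (x,), tuple(block.parameters()), flag=True)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([1, 64, 32, 32])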
iopaint/model/anytext/ldm/util.py
ADDED
@@ -0,0 +1,197 @@
import importlib

import torch
from torch import optim
import numpy as np

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont


def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype('font/Arial_Unicode.ttf', size=size)
        nc = int(32 * (wh[0] / 256))
        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts


def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config, **kwargs):
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


class AdamWwithEMAandWings(optim.Optimizer):
    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,  # ema decay to match previous code
                 ema_power=1., param_names=()):
        """AdamW that saves EMA versions of the parameters."""
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
                        ema_power=ema_power, param_names=param_names)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            state_sums = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            ema_decay = group['ema_decay']
            ema_power = group['ema_power']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values
                    state['param_exp_avg'] = p.detach().float().clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                ema_params_with_grad.append(state['param_exp_avg'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            optim._functional.adamw(params_with_grad,
                                    grads,
                                    exp_avgs,
                                    exp_avg_sqs,
                                    max_exp_avg_sqs,
                                    state_steps,
                                    amsgrad=amsgrad,
                                    beta1=beta1,
                                    beta2=beta2,
                                    lr=group['lr'],
                                    weight_decay=group['weight_decay'],
                                    eps=group['eps'],
                                    maximize=False)

            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss
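
A short sketch of the config-driven instantiation helpers above, using torch.nn.Linear as a stand-in target; this is the same mechanism that consumes the YAML configs added later in this commit.

    from iopaint.model.anytext.ldm.util import (
        count_params,
        default,
        get_obj_from_str,
        instantiate_from_config,
    )

    # `default` returns the value if present, otherwise the fallback
    # (callables are invoked lazily).
    assert default(None, 3) == 3
    assert default(5, lambda: 3) == 5

    # Any importable class can be resolved by its dotted path ...
    Linear = get_obj_from_str("torch.nn.Linear")

    # ... or built straight from a {"target": ..., "params": {...}} dict.
    layer = instantiate_from_config(
        {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
    )
    print(count_params(layer, verbose=True))  # 36 parameters (8*4 weights + 4 biases)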
iopaint/model/anytext/utils.py
ADDED
@@ -0,0 +1,151 @@
import os
import datetime
import cv2
import numpy as np
from PIL import Image, ImageDraw


def save_images(img_list, folder):
    if not os.path.exists(folder):
        os.makedirs(folder)
    now = datetime.datetime.now()
    date_str = now.strftime("%Y-%m-%d")
    folder_path = os.path.join(folder, date_str)
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    time_str = now.strftime("%H_%M_%S")
    for idx, img in enumerate(img_list):
        image_number = idx + 1
        filename = f"{time_str}_{image_number}.jpg"
        save_path = os.path.join(folder_path, filename)
        cv2.imwrite(save_path, img[..., ::-1])


def check_channels(image):
    channels = image.shape[2] if len(image.shape) == 3 else 1
    if channels == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif channels > 3:
        image = image[:, :, :3]
    return image


def resize_image(img, max_length=768):
    height, width = img.shape[:2]
    max_dimension = max(height, width)

    if max_dimension > max_length:
        scale_factor = max_length / max_dimension
        new_width = int(round(width * scale_factor))
        new_height = int(round(height * scale_factor))
        new_size = (new_width, new_height)
        img = cv2.resize(img, new_size)
    height, width = img.shape[:2]
    img = cv2.resize(img, (width - (width % 64), height - (height % 64)))
    return img


def insert_spaces(string, nSpace):
    if nSpace == 0:
        return string
    new_string = ""
    for char in string:
        new_string += char + " " * nSpace
    return new_string[:-nSpace]


def draw_glyph(font, text):
    g_size = 50
    W, H = (512, 80)
    new_font = font.font_variant(size=g_size)
    img = Image.new(mode="1", size=(W, H), color=0)
    draw = ImageDraw.Draw(img)
    left, top, right, bottom = new_font.getbbox(text)
    text_width = max(right - left, 5)
    text_height = max(bottom - top, 5)
    ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)
    new_font = font.font_variant(size=int(g_size * ratio))

    text_width, text_height = new_font.getsize(text)
    offset_x, offset_y = new_font.getoffset(text)
    x = (img.width - text_width) // 2
    y = (img.height - text_height) // 2 - offset_y // 2
    draw.text((x, y), text, font=new_font, fill="white")
    img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
    return img


def draw_glyph2(
    font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True
):
    enlarge_polygon = polygon * scale
    rect = cv2.minAreaRect(enlarge_polygon)
    box = cv2.boxPoints(rect)
    box = np.int0(box)
    w, h = rect[1]
    angle = rect[2]
    if angle < -45:
        angle += 90
    angle = -angle
    if w < h:
        angle += 90

    vert = False
    if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
        _w = max(box[:, 0]) - min(box[:, 0])
        _h = max(box[:, 1]) - min(box[:, 1])
        if _h >= _w:
            vert = True
            angle = 0

    img = np.zeros((height * scale, width * scale, 3), np.uint8)
    img = Image.fromarray(img)

    # infer font size
    image4ratio = Image.new("RGB", img.size, "white")
    draw = ImageDraw.Draw(image4ratio)
    _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
    text_w = min(w, h) * (_tw / _th)
    if text_w <= max(w, h):
        # add space
        if len(text) > 1 and not vert and add_space:
            for i in range(1, 100):
                text_space = insert_spaces(text, i)
                _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
                if min(w, h) * (_tw2 / _th2) > max(w, h):
                    break
            text = insert_spaces(text, i - 1)
        font_size = min(w, h) * 0.80
    else:
        shrink = 0.75 if vert else 0.85
        font_size = min(w, h) / (text_w / max(w, h)) * shrink
    new_font = font.font_variant(size=int(font_size))

    left, top, right, bottom = new_font.getbbox(text)
    text_width = right - left
    text_height = bottom - top

    layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(layer)
    if not vert:
        draw.text(
            (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
            text,
            font=new_font,
            fill=(255, 255, 255, 255),
        )
    else:
        x_s = min(box[:, 0]) + _w // 2 - text_height // 2
        y_s = min(box[:, 1])
        for c in text:
            draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
            _, _t, _, _b = new_font.getbbox(c)
            y_s += _b

    rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))

    x_offset = int((img.width - rotated_layer.width) / 2)
    y_offset = int((img.height - rotated_layer.height) / 2)
    img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
    img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
    return img
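
A minimal sketch of the image helpers above; the array sizes are hypothetical and chosen only to show the rescaling behaviour.

    import numpy as np

    from iopaint.model.anytext.utils import check_channels, insert_spaces, resize_image

    # A fake 900x1500 RGBA image: extra channels are dropped; a grayscale input
    # would instead be expanded to 3 channels.
    img = np.zeros((900, 1500, 4), dtype=np.uint8)
    img = check_channels(img)
    print(img.shape)  # (900, 1500, 3)

    # The longest side is capped at max_length, then both sides are snapped down
    # to multiples of 64 (the diffusion UNet expects 64-aligned resolutions).
    img = resize_image(img, max_length=768)
    print(img.shape)  # (448, 768, 3)

    # insert_spaces pads glyph text so it stretches across a wide text region.
    print(insert_spaces("AnyText", 2))  # "A  n  y  T  e  x  t"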
iopaint/model/original_sd_configs/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
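
A small sketch of how the fields in this config line up, assuming the file is read from the repository root with plain PyYAML; the nested params blocks are exactly what instantiate_from_config (ldm/util.py above) expands into constructor arguments.

    import yaml

    with open("iopaint/model/original_sd_configs/v1-inference.yaml") as f:
        cfg = yaml.safe_load(f)

    unet = cfg["model"]["params"]["unet_config"]["params"]
    print(cfg["model"]["target"])                  # ldm.models.diffusion.ddpm.LatentDiffusion
    print(unet["context_dim"])                     # 768 -> CLIP ViT-L/14 text embeddings
    print(cfg["model"]["params"]["scale_factor"])  # 0.18215 latent scaling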
iopaint/model/original_sd_configs/v2-inference-v.yaml
ADDED
@@ -0,0 +1,68 @@
model:
  base_learning_rate: 1.0e-4
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False # we set this to false because this is an inference only config

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
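
For reference, the main semantic differences from the v1 config are the OpenCLIP text encoder (FrozenOpenCLIPEmbedder, context_dim 1024, as in SD 2.x) and parameterization: "v", which means the UNet is trained to predict the velocity target rather than the noise (Salimans & Ho, 2022). With \bar{\alpha}_t the cumulative alpha product used in the schedules above:

    v_t = \sqrt{\bar{\alpha}_t}\,\epsilon - \sqrt{1 - \bar{\alpha}_t}\,x_0,
    \qquad
    x_0 = \sqrt{\bar{\alpha}_t}\,x_t - \sqrt{1 - \bar{\alpha}_t}\,v_t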
iopaint/model/utils.py
ADDED
|
@@ -0,0 +1,1033 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import math
|
| 3 |
+
import random
|
| 4 |
+
import traceback
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import collections
|
| 10 |
+
from itertools import repeat
|
| 11 |
+
|
| 12 |
+
from diffusers import (
|
| 13 |
+
DDIMScheduler,
|
| 14 |
+
PNDMScheduler,
|
| 15 |
+
LMSDiscreteScheduler,
|
| 16 |
+
EulerDiscreteScheduler,
|
| 17 |
+
EulerAncestralDiscreteScheduler,
|
| 18 |
+
DPMSolverMultistepScheduler,
|
| 19 |
+
UniPCMultistepScheduler,
|
| 20 |
+
LCMScheduler,
|
| 21 |
+
DPMSolverSinglestepScheduler,
|
| 22 |
+
KDPM2DiscreteScheduler,
|
| 23 |
+
KDPM2AncestralDiscreteScheduler,
|
| 24 |
+
HeunDiscreteScheduler,
|
| 25 |
+
)
|
| 26 |
+
from loguru import logger
|
| 27 |
+
|
| 28 |
+
from iopaint.schema import SDSampler
|
| 29 |
+
from torch import conv2d, conv_transpose2d
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def make_beta_schedule(
|
| 33 |
+
device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
| 34 |
+
):
|
| 35 |
+
if schedule == "linear":
|
| 36 |
+
betas = (
|
| 37 |
+
torch.linspace(
|
| 38 |
+
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
|
| 39 |
+
)
|
| 40 |
+
** 2
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
elif schedule == "cosine":
|
| 44 |
+
timesteps = (
|
| 45 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
| 46 |
+
).to(device)
|
| 47 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
| 48 |
+
alphas = torch.cos(alphas).pow(2).to(device)
|
| 49 |
+
alphas = alphas / alphas[0]
|
| 50 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
| 51 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
| 52 |
+
|
| 53 |
+
elif schedule == "sqrt_linear":
|
| 54 |
+
betas = torch.linspace(
|
| 55 |
+
linear_start, linear_end, n_timestep, dtype=torch.float64
|
| 56 |
+
)
|
| 57 |
+
elif schedule == "sqrt":
|
| 58 |
+
betas = (
|
| 59 |
+
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
| 60 |
+
** 0.5
|
| 61 |
+
)
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
| 64 |
+
return betas.numpy()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
| 68 |
+
# select alphas for computing the variance schedule
|
| 69 |
+
alphas = alphacums[ddim_timesteps]
|
| 70 |
+
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
| 71 |
+
|
| 72 |
+
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
| 73 |
+
sigmas = eta * np.sqrt(
|
| 74 |
+
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
|
| 75 |
+
)
|
| 76 |
+
if verbose:
|
| 77 |
+
print(
|
| 78 |
+
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
|
| 79 |
+
)
|
| 80 |
+
print(
|
| 81 |
+
f"For the chosen value of eta, which is {eta}, "
|
| 82 |
+
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
|
| 83 |
+
)
|
| 84 |
+
return sigmas, alphas, alphas_prev
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def make_ddim_timesteps(
|
| 88 |
+
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
|
| 89 |
+
):
|
| 90 |
+
if ddim_discr_method == "uniform":
|
| 91 |
+
c = num_ddpm_timesteps // num_ddim_timesteps
|
| 92 |
+
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
| 93 |
+
elif ddim_discr_method == "quad":
|
| 94 |
+
ddim_timesteps = (
|
| 95 |
+
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
|
| 96 |
+
).astype(int)
|
| 97 |
+
else:
|
| 98 |
+
raise NotImplementedError(
|
| 99 |
+
f'There is no ddim discretization method called "{ddim_discr_method}"'
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
| 103 |
+
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
| 104 |
+
steps_out = ddim_timesteps + 1
|
| 105 |
+
if verbose:
|
| 106 |
+
print(f"Selected timesteps for ddim sampler: {steps_out}")
|
| 107 |
+
return steps_out
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def noise_like(shape, device, repeat=False):
|
| 111 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
|
| 112 |
+
shape[0], *((1,) * (len(shape) - 1))
|
| 113 |
+
)
|
| 114 |
+
noise = lambda: torch.randn(shape, device=device)
|
| 115 |
+
return repeat_noise() if repeat else noise()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False):
|
| 119 |
+
"""
|
| 120 |
+
Create sinusoidal timestep embeddings.
|
| 121 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
| 122 |
+
These may be fractional.
|
| 123 |
+
:param dim: the dimension of the output.
|
| 124 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 125 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
| 126 |
+
"""
|
| 127 |
+
half = dim // 2
|
| 128 |
+
freqs = torch.exp(
|
| 129 |
+
-math.log(max_period)
|
| 130 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
| 131 |
+
/ half
|
| 132 |
+
).to(device=device)
|
| 133 |
+
|
| 134 |
+
args = timesteps[:, None].float() * freqs[None]
|
| 135 |
+
|
| 136 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 137 |
+
if dim % 2:
|
| 138 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 139 |
+
return embedding
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
###### MAT and FcF #######
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def normalize_2nd_moment(x, dim=1):
|
| 146 |
+
return (
|
| 147 |
+
x * (x.square().mean(dim=dim, keepdim=True) + torch.finfo(x.dtype).eps).rsqrt()
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class EasyDict(dict):
|
| 152 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
| 153 |
+
|
| 154 |
+
def __getattr__(self, name: str) -> Any:
|
| 155 |
+
try:
|
| 156 |
+
return self[name]
|
| 157 |
+
except KeyError:
|
| 158 |
+
raise AttributeError(name)
|
| 159 |
+
|
| 160 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 161 |
+
self[name] = value
|
| 162 |
+
|
| 163 |
+
def __delattr__(self, name: str) -> None:
|
| 164 |
+
del self[name]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
|
| 168 |
+
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops."""
|
| 169 |
+
assert isinstance(x, torch.Tensor)
|
| 170 |
+
assert clamp is None or clamp >= 0
|
| 171 |
+
spec = activation_funcs[act]
|
| 172 |
+
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
| 173 |
+
gain = float(gain if gain is not None else spec.def_gain)
|
| 174 |
+
clamp = float(clamp if clamp is not None else -1)
|
| 175 |
+
|
| 176 |
+
# Add bias.
|
| 177 |
+
if b is not None:
|
| 178 |
+
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
| 179 |
+
assert 0 <= dim < x.ndim
|
| 180 |
+
assert b.shape[0] == x.shape[dim]
|
| 181 |
+
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
| 182 |
+
|
| 183 |
+
# Evaluate activation function.
|
| 184 |
+
alpha = float(alpha)
|
| 185 |
+
x = spec.func(x, alpha=alpha)
|
| 186 |
+
|
| 187 |
+
# Scale by gain.
|
| 188 |
+
gain = float(gain)
|
| 189 |
+
if gain != 1:
|
| 190 |
+
x = x * gain
|
| 191 |
+
|
| 192 |
+
# Clamp.
|
| 193 |
+
if clamp >= 0:
|
| 194 |
+
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
| 195 |
+
return x
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def bias_act(
|
| 199 |
+
x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
|
| 200 |
+
):
|
| 201 |
+
r"""Fused bias and activation function.
|
| 202 |
+
|
| 203 |
+
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
| 204 |
+
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
| 205 |
+
the fused op is considerably more efficient than performing the same calculation
|
| 206 |
+
using standard PyTorch ops. It supports first and second order gradients,
|
| 207 |
+
but not third order gradients.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
x: Input activation tensor. Can be of any shape.
|
| 211 |
+
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
| 212 |
+
as `x`. The shape must be known, and it must match the dimension of `x`
|
| 213 |
+
corresponding to `dim`.
|
| 214 |
+
dim: The dimension in `x` corresponding to the elements of `b`.
|
| 215 |
+
The value of `dim` is ignored if `b` is not specified.
|
| 216 |
+
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
| 217 |
+
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
| 218 |
+
See `activation_funcs` for a full list. `None` is not allowed.
|
| 219 |
+
alpha: Shape parameter for the activation function, or `None` to use the default.
|
| 220 |
+
gain: Scaling factor for the output tensor, or `None` to use default.
|
| 221 |
+
See `activation_funcs` for the default scaling of each activation function.
|
| 222 |
+
If unsure, consider specifying 1.
|
| 223 |
+
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
| 224 |
+
the clamping (default).
|
| 225 |
+
impl: Name of the implementation to use. Can be `"ref"` (default) or `"cuda"`; this copy always uses the reference implementation.
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
Tensor of the same shape and datatype as `x`.
|
| 229 |
+
"""
|
| 230 |
+
assert isinstance(x, torch.Tensor)
|
| 231 |
+
assert impl in ["ref", "cuda"]
|
| 232 |
+
return _bias_act_ref(
|
| 233 |
+
x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def _get_filter_size(f):
|
| 238 |
+
if f is None:
|
| 239 |
+
return 1, 1
|
| 240 |
+
|
| 241 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
| 242 |
+
fw = f.shape[-1]
|
| 243 |
+
fh = f.shape[0]
|
| 244 |
+
|
| 245 |
+
fw = int(fw)
|
| 246 |
+
fh = int(fh)
|
| 247 |
+
assert fw >= 1 and fh >= 1
|
| 248 |
+
return fw, fh
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _get_weight_shape(w):
|
| 252 |
+
shape = [int(sz) for sz in w.shape]
|
| 253 |
+
return shape
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _parse_scaling(scaling):
|
| 257 |
+
if isinstance(scaling, int):
|
| 258 |
+
scaling = [scaling, scaling]
|
| 259 |
+
assert isinstance(scaling, (list, tuple))
|
| 260 |
+
assert all(isinstance(x, int) for x in scaling)
|
| 261 |
+
sx, sy = scaling
|
| 262 |
+
assert sx >= 1 and sy >= 1
|
| 263 |
+
return sx, sy
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _parse_padding(padding):
|
| 267 |
+
if isinstance(padding, int):
|
| 268 |
+
padding = [padding, padding]
|
| 269 |
+
assert isinstance(padding, (list, tuple))
|
| 270 |
+
assert all(isinstance(x, int) for x in padding)
|
| 271 |
+
if len(padding) == 2:
|
| 272 |
+
padx, pady = padding
|
| 273 |
+
padding = [padx, padx, pady, pady]
|
| 274 |
+
padx0, padx1, pady0, pady1 = padding
|
| 275 |
+
return padx0, padx1, pady0, pady1
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def setup_filter(
|
| 279 |
+
f,
|
| 280 |
+
device=torch.device("cpu"),
|
| 281 |
+
normalize=True,
|
| 282 |
+
flip_filter=False,
|
| 283 |
+
gain=1,
|
| 284 |
+
separable=None,
|
| 285 |
+
):
|
| 286 |
+
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
f: Torch tensor, numpy array, or python list of the shape
|
| 290 |
+
`[filter_height, filter_width]` (non-separable),
|
| 291 |
+
`[filter_taps]` (separable),
|
| 292 |
+
`[]` (impulse), or
|
| 293 |
+
`None` (identity).
|
| 294 |
+
device: Result device (default: cpu).
|
| 295 |
+
normalize: Normalize the filter so that it retains the magnitude
|
| 296 |
+
for constant input signal (DC)? (default: True).
|
| 297 |
+
flip_filter: Flip the filter? (default: False).
|
| 298 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
| 299 |
+
separable: Return a separable filter? (default: select automatically).
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
Float32 tensor of the shape
|
| 303 |
+
`[filter_height, filter_width]` (non-separable) or
|
| 304 |
+
`[filter_taps]` (separable).
|
| 305 |
+
"""
|
| 306 |
+
# Validate.
|
| 307 |
+
if f is None:
|
| 308 |
+
f = 1
|
| 309 |
+
f = torch.as_tensor(f, dtype=torch.float32)
|
| 310 |
+
assert f.ndim in [0, 1, 2]
|
| 311 |
+
assert f.numel() > 0
|
| 312 |
+
if f.ndim == 0:
|
| 313 |
+
f = f[np.newaxis]
|
| 314 |
+
|
| 315 |
+
# Separable?
|
| 316 |
+
if separable is None:
|
| 317 |
+
separable = f.ndim == 1 and f.numel() >= 8
|
| 318 |
+
if f.ndim == 1 and not separable:
|
| 319 |
+
f = f.ger(f)
|
| 320 |
+
assert f.ndim == (1 if separable else 2)
|
| 321 |
+
|
| 322 |
+
# Apply normalize, flip, gain, and device.
|
| 323 |
+
if normalize:
|
| 324 |
+
f /= f.sum()
|
| 325 |
+
if flip_filter:
|
| 326 |
+
f = f.flip(list(range(f.ndim)))
|
| 327 |
+
f = f * (gain ** (f.ndim / 2))
|
| 328 |
+
f = f.to(device=device)
|
| 329 |
+
return f
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
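# --- Illustrative usage sketch (not part of the original commit) ---
# setup_filter() above turns the common [1, 3, 3, 1] taps into a non-separable
# 4x4 FIR kernel (outer product of the taps) normalized to unit DC gain.
import torch

f = setup_filter([1, 3, 3, 1])
assert f.shape == (4, 4)
assert abs(float(f.sum()) - 1.0) < 1e-6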
def _ntuple(n):
|
| 333 |
+
def parse(x):
|
| 334 |
+
if isinstance(x, collections.abc.Iterable):
|
| 335 |
+
return x
|
| 336 |
+
return tuple(repeat(x, n))
|
| 337 |
+
|
| 338 |
+
return parse
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
to_2tuple = _ntuple(2)
|
| 342 |
+
|
| 343 |
+
activation_funcs = {
|
| 344 |
+
"linear": EasyDict(
|
| 345 |
+
func=lambda x, **_: x,
|
| 346 |
+
def_alpha=0,
|
| 347 |
+
def_gain=1,
|
| 348 |
+
cuda_idx=1,
|
| 349 |
+
ref="",
|
| 350 |
+
has_2nd_grad=False,
|
| 351 |
+
),
|
| 352 |
+
"relu": EasyDict(
|
| 353 |
+
func=lambda x, **_: torch.nn.functional.relu(x),
|
| 354 |
+
def_alpha=0,
|
| 355 |
+
def_gain=np.sqrt(2),
|
| 356 |
+
cuda_idx=2,
|
| 357 |
+
ref="y",
|
| 358 |
+
has_2nd_grad=False,
|
| 359 |
+
),
|
| 360 |
+
"lrelu": EasyDict(
|
| 361 |
+
func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
|
| 362 |
+
def_alpha=0.2,
|
| 363 |
+
def_gain=np.sqrt(2),
|
| 364 |
+
cuda_idx=3,
|
| 365 |
+
ref="y",
|
| 366 |
+
has_2nd_grad=False,
|
| 367 |
+
),
|
| 368 |
+
"tanh": EasyDict(
|
| 369 |
+
func=lambda x, **_: torch.tanh(x),
|
| 370 |
+
def_alpha=0,
|
| 371 |
+
def_gain=1,
|
| 372 |
+
cuda_idx=4,
|
| 373 |
+
ref="y",
|
| 374 |
+
has_2nd_grad=True,
|
| 375 |
+
),
|
| 376 |
+
"sigmoid": EasyDict(
|
| 377 |
+
func=lambda x, **_: torch.sigmoid(x),
|
| 378 |
+
def_alpha=0,
|
| 379 |
+
def_gain=1,
|
| 380 |
+
cuda_idx=5,
|
| 381 |
+
ref="y",
|
| 382 |
+
has_2nd_grad=True,
|
| 383 |
+
),
|
| 384 |
+
"elu": EasyDict(
|
| 385 |
+
func=lambda x, **_: torch.nn.functional.elu(x),
|
| 386 |
+
def_alpha=0,
|
| 387 |
+
def_gain=1,
|
| 388 |
+
cuda_idx=6,
|
| 389 |
+
ref="y",
|
| 390 |
+
has_2nd_grad=True,
|
| 391 |
+
),
|
| 392 |
+
"selu": EasyDict(
|
| 393 |
+
func=lambda x, **_: torch.nn.functional.selu(x),
|
| 394 |
+
def_alpha=0,
|
| 395 |
+
def_gain=1,
|
| 396 |
+
cuda_idx=7,
|
| 397 |
+
ref="y",
|
| 398 |
+
has_2nd_grad=True,
|
| 399 |
+
),
|
| 400 |
+
"softplus": EasyDict(
|
| 401 |
+
func=lambda x, **_: torch.nn.functional.softplus(x),
|
| 402 |
+
def_alpha=0,
|
| 403 |
+
def_gain=1,
|
| 404 |
+
cuda_idx=8,
|
| 405 |
+
ref="y",
|
| 406 |
+
has_2nd_grad=True,
|
| 407 |
+
),
|
| 408 |
+
"swish": EasyDict(
|
| 409 |
+
func=lambda x, **_: torch.sigmoid(x) * x,
|
| 410 |
+
def_alpha=0,
|
| 411 |
+
def_gain=np.sqrt(2),
|
| 412 |
+
cuda_idx=9,
|
| 413 |
+
ref="x",
|
| 414 |
+
has_2nd_grad=True,
|
| 415 |
+
),
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
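# --- Illustrative usage sketch (not part of the original commit) ---
# bias_act() above adds the bias, applies the activation looked up in
# activation_funcs, and scales by that activation's default gain (sqrt(2) for
# lrelu). The manual equivalent is shown in the trailing comment.
import numpy as np
import torch

x = torch.randn(2, 8, 4, 4)
b = torch.randn(8)
y = bias_act(x, b, dim=1, act="lrelu")
# manual equivalent:
# torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2) * np.sqrt(2)
assert y.shape == x.shape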
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
|
| 420 |
+
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
| 421 |
+
|
| 422 |
+
Performs the following sequence of operations for each channel:
|
| 423 |
+
|
| 424 |
+
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
| 425 |
+
|
| 426 |
+
2. Pad the image with the specified number of zeros on each side (`padding`).
|
| 427 |
+
Negative padding corresponds to cropping the image.
|
| 428 |
+
|
| 429 |
+
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
| 430 |
+
so that the footprint of all output pixels lies within the input image.
|
| 431 |
+
|
| 432 |
+
4. Downsample the image by keeping every Nth pixel (`down`).
|
| 433 |
+
|
| 434 |
+
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
| 435 |
+
The fused op is considerably more efficient than performing the same calculation
|
| 436 |
+
using standard PyTorch ops. It supports gradients of arbitrary order.
|
| 437 |
+
|
| 438 |
+
Args:
|
| 439 |
+
x: Float32/float64/float16 input tensor of the shape
|
| 440 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
| 441 |
+
f: Float32 FIR filter of the shape
|
| 442 |
+
`[filter_height, filter_width]` (non-separable),
|
| 443 |
+
`[filter_taps]` (separable), or
|
| 444 |
+
`None` (identity).
|
| 445 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
| 446 |
+
`[x, y]` (default: 1).
|
| 447 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
| 448 |
+
`[x, y]` (default: 1).
|
| 449 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
| 450 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
| 451 |
+
(default: 0).
|
| 452 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
| 453 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
| 454 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
| 455 |
+
|
| 456 |
+
Returns:
|
| 457 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
| 458 |
+
"""
|
| 459 |
+
# assert isinstance(x, torch.Tensor)
|
| 460 |
+
# assert impl in ['ref', 'cuda']
|
| 461 |
+
return _upfirdn2d_ref(
|
| 462 |
+
x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
| 467 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
|
| 468 |
+
# Validate arguments.
|
| 469 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
| 470 |
+
if f is None:
|
| 471 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
| 472 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
| 473 |
+
assert not f.requires_grad
|
| 474 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
| 475 |
+
# upx, upy = _parse_scaling(up)
|
| 476 |
+
# downx, downy = _parse_scaling(down)
|
| 477 |
+
|
| 478 |
+
upx, upy = up, up
|
| 479 |
+
downx, downy = down, down
|
| 480 |
+
|
| 481 |
+
# padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
| 482 |
+
padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3]
|
| 483 |
+
|
| 484 |
+
# Upsample by inserting zeros.
|
| 485 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
| 486 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
| 487 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
| 488 |
+
|
| 489 |
+
# Pad or crop.
|
| 490 |
+
x = torch.nn.functional.pad(
|
| 491 |
+
x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
|
| 492 |
+
)
|
| 493 |
+
x = x[
|
| 494 |
+
:,
|
| 495 |
+
:,
|
| 496 |
+
max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
|
| 497 |
+
max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
|
| 498 |
+
]
|
| 499 |
+
|
| 500 |
+
# Setup filter.
|
| 501 |
+
f = f * (gain ** (f.ndim / 2))
|
| 502 |
+
f = f.to(x.dtype)
|
| 503 |
+
if not flip_filter:
|
| 504 |
+
f = f.flip(list(range(f.ndim)))
|
| 505 |
+
|
| 506 |
+
# Convolve with the filter.
|
| 507 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
| 508 |
+
if f.ndim == 4:
|
| 509 |
+
x = conv2d(input=x, weight=f, groups=num_channels)
|
| 510 |
+
else:
|
| 511 |
+
x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
| 512 |
+
x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
| 513 |
+
|
| 514 |
+
# Downsample by throwing away pixels.
|
| 515 |
+
x = x[:, :, ::downy, ::downx]
|
| 516 |
+
return x
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
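# --- Illustrative shape check (not part of the original commit) ---
# Note this trimmed-down upfirdn2d() expects `padding` as a 4-element
# [x0, x1, y0, y1] list and integer up/down factors; it always takes the
# reference path and relies on the conv2d wrapper imported earlier in this module.
import torch

x = torch.randn(1, 3, 16, 16)
f = setup_filter([1, 3, 3, 1])                   # 4x4 FIR kernel
y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1])  # same padding upsample2d would compute
assert y.shape == (1, 3, 32, 32)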
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
|
| 520 |
+
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
| 521 |
+
|
| 522 |
+
By default, the result is padded so that its shape is a fraction of the input.
|
| 523 |
+
User-specified padding is applied on top of that, with negative values
|
| 524 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
| 525 |
+
|
| 526 |
+
Args:
|
| 527 |
+
x: Float32/float64/float16 input tensor of the shape
|
| 528 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
| 529 |
+
f: Float32 FIR filter of the shape
|
| 530 |
+
`[filter_height, filter_width]` (non-separable),
|
| 531 |
+
`[filter_taps]` (separable), or
|
| 532 |
+
`None` (identity).
|
| 533 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
| 534 |
+
`[x, y]` (default: 1).
|
| 535 |
+
padding: Padding with respect to the input. Can be a single number or a
|
| 536 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
| 537 |
+
(default: 0).
|
| 538 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
| 539 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
| 540 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
| 541 |
+
|
| 542 |
+
Returns:
|
| 543 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
| 544 |
+
"""
|
| 545 |
+
downx, downy = _parse_scaling(down)
|
| 546 |
+
# padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
| 547 |
+
padx0, padx1, pady0, pady1 = padding, padding, padding, padding
|
| 548 |
+
|
| 549 |
+
fw, fh = _get_filter_size(f)
|
| 550 |
+
p = [
|
| 551 |
+
padx0 + (fw - downx + 1) // 2,
|
| 552 |
+
padx1 + (fw - downx) // 2,
|
| 553 |
+
pady0 + (fh - downy + 1) // 2,
|
| 554 |
+
pady1 + (fh - downy) // 2,
|
| 555 |
+
]
|
| 556 |
+
return upfirdn2d(
|
| 557 |
+
x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
|
| 562 |
+
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
| 563 |
+
|
| 564 |
+
By default, the result is padded so that its shape is a multiple of the input.
|
| 565 |
+
User-specified padding is applied on top of that, with negative values
|
| 566 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
x: Float32/float64/float16 input tensor of the shape
|
| 570 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
| 571 |
+
f: Float32 FIR filter of the shape
|
| 572 |
+
`[filter_height, filter_width]` (non-separable),
|
| 573 |
+
`[filter_taps]` (separable), or
|
| 574 |
+
`None` (identity).
|
| 575 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
| 576 |
+
`[x, y]` (default: 1).
|
| 577 |
+
padding: Padding with respect to the output. Can be a single number or a
|
| 578 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
| 579 |
+
(default: 0).
|
| 580 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
| 581 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
| 582 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
| 583 |
+
|
| 584 |
+
Returns:
|
| 585 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
| 586 |
+
"""
|
| 587 |
+
upx, upy = _parse_scaling(up)
|
| 588 |
+
# upx, upy = up, up
|
| 589 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
| 590 |
+
# padx0, padx1, pady0, pady1 = padding, padding, padding, padding
|
| 591 |
+
fw, fh = _get_filter_size(f)
|
| 592 |
+
p = [
|
| 593 |
+
padx0 + (fw + upx - 1) // 2,
|
| 594 |
+
padx1 + (fw - upx) // 2,
|
| 595 |
+
pady0 + (fh + upy - 1) // 2,
|
| 596 |
+
pady1 + (fh - upy) // 2,
|
| 597 |
+
]
|
| 598 |
+
return upfirdn2d(
|
| 599 |
+
x,
|
| 600 |
+
f,
|
| 601 |
+
up=up,
|
| 602 |
+
padding=p,
|
| 603 |
+
flip_filter=flip_filter,
|
| 604 |
+
gain=gain * upx * upy,
|
| 605 |
+
impl=impl,
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
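# --- Illustrative shape check (not part of the original commit) ---
# upsample2d()/downsample2d() above wrap upfirdn2d() with the padding needed to
# keep the output an exact multiple (or fraction) of the input resolution.
import torch

x = torch.randn(1, 3, 32, 32)
f = setup_filter([1, 3, 3, 1])
hi = upsample2d(x, f, up=2)        # -> [1, 3, 64, 64]
lo = downsample2d(hi, f, down=2)   # -> [1, 3, 32, 32]
assert hi.shape == (1, 3, 64, 64) and lo.shape == (1, 3, 32, 32)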
class MinibatchStdLayer(torch.nn.Module):
|
| 610 |
+
def __init__(self, group_size, num_channels=1):
|
| 611 |
+
super().__init__()
|
| 612 |
+
self.group_size = group_size
|
| 613 |
+
self.num_channels = num_channels
|
| 614 |
+
|
| 615 |
+
def forward(self, x):
|
| 616 |
+
N, C, H, W = x.shape
|
| 617 |
+
G = (
|
| 618 |
+
torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N))
|
| 619 |
+
if self.group_size is not None
|
| 620 |
+
else N
|
| 621 |
+
)
|
| 622 |
+
F = self.num_channels
|
| 623 |
+
c = C // F
|
| 624 |
+
|
| 625 |
+
y = x.reshape(
|
| 626 |
+
G, -1, F, c, H, W
|
| 627 |
+
) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
|
| 628 |
+
y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
|
| 629 |
+
y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
|
| 630 |
+
y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
|
| 631 |
+
y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.
|
| 632 |
+
y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
|
| 633 |
+
y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
|
| 634 |
+
x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
|
| 635 |
+
return x
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
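# --- Illustrative shape check (not part of the original commit) ---
# MinibatchStdLayer above appends `num_channels` extra feature maps holding the
# per-group standard deviation statistic, so C grows from 16 to 17 here.
import torch

layer = MinibatchStdLayer(group_size=4, num_channels=1)
x = torch.randn(8, 16, 4, 4)
y = layer(x)
assert y.shape == (8, 17, 4, 4)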
class FullyConnectedLayer(torch.nn.Module):
|
| 639 |
+
def __init__(
|
| 640 |
+
self,
|
| 641 |
+
in_features, # Number of input features.
|
| 642 |
+
out_features, # Number of output features.
|
| 643 |
+
bias=True, # Apply additive bias before the activation function?
|
| 644 |
+
activation="linear", # Activation function: 'relu', 'lrelu', etc.
|
| 645 |
+
lr_multiplier=1, # Learning rate multiplier.
|
| 646 |
+
bias_init=0, # Initial value for the additive bias.
|
| 647 |
+
):
|
| 648 |
+
super().__init__()
|
| 649 |
+
self.weight = torch.nn.Parameter(
|
| 650 |
+
torch.randn([out_features, in_features]) / lr_multiplier
|
| 651 |
+
)
|
| 652 |
+
self.bias = (
|
| 653 |
+
torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
|
| 654 |
+
if bias
|
| 655 |
+
else None
|
| 656 |
+
)
|
| 657 |
+
self.activation = activation
|
| 658 |
+
|
| 659 |
+
self.weight_gain = lr_multiplier / np.sqrt(in_features)
|
| 660 |
+
self.bias_gain = lr_multiplier
|
| 661 |
+
|
| 662 |
+
def forward(self, x):
|
| 663 |
+
w = self.weight * self.weight_gain
|
| 664 |
+
b = self.bias
|
| 665 |
+
if b is not None and self.bias_gain != 1:
|
| 666 |
+
b = b * self.bias_gain
|
| 667 |
+
|
| 668 |
+
if self.activation == "linear" and b is not None:
|
| 669 |
+
# out = torch.addmm(b.unsqueeze(0), x, w.t())
|
| 670 |
+
x = x.matmul(w.t())
|
| 671 |
+
out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
|
| 672 |
+
else:
|
| 673 |
+
x = x.matmul(w.t())
|
| 674 |
+
out = bias_act(x, b, act=self.activation, dim=x.ndim - 1)
|
| 675 |
+
return out
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
def _conv2d_wrapper(
|
| 679 |
+
x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
|
| 680 |
+
):
|
| 681 |
+
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
|
| 682 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
| 683 |
+
|
| 684 |
+
# Flip weight if requested.
|
| 685 |
+
if (
|
| 686 |
+
not flip_weight
|
| 687 |
+
): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
| 688 |
+
w = w.flip([2, 3])
|
| 689 |
+
|
| 690 |
+
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
|
| 691 |
+
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
|
| 692 |
+
if (
|
| 693 |
+
kw == 1
|
| 694 |
+
and kh == 1
|
| 695 |
+
and stride == 1
|
| 696 |
+
and padding in [0, [0, 0], (0, 0)]
|
| 697 |
+
and not transpose
|
| 698 |
+
):
|
| 699 |
+
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
|
| 700 |
+
if out_channels <= 4 and groups == 1:
|
| 701 |
+
in_shape = x.shape
|
| 702 |
+
x = w.squeeze(3).squeeze(2) @ x.reshape(
|
| 703 |
+
[in_shape[0], in_channels_per_group, -1]
|
| 704 |
+
)
|
| 705 |
+
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
|
| 706 |
+
else:
|
| 707 |
+
x = x.to(memory_format=torch.contiguous_format)
|
| 708 |
+
w = w.to(memory_format=torch.contiguous_format)
|
| 709 |
+
x = conv2d(x, w, groups=groups)
|
| 710 |
+
return x.to(memory_format=torch.channels_last)
|
| 711 |
+
|
| 712 |
+
# Otherwise => execute using conv2d_gradfix.
|
| 713 |
+
op = conv_transpose2d if transpose else conv2d
|
| 714 |
+
return op(x, w, stride=stride, padding=padding, groups=groups)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
def conv2d_resample(
|
| 718 |
+
x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
|
| 719 |
+
):
|
| 720 |
+
r"""2D convolution with optional up/downsampling.
|
| 721 |
+
|
| 722 |
+
Padding is performed only once at the beginning, not between the operations.
|
| 723 |
+
|
| 724 |
+
Args:
|
| 725 |
+
x: Input tensor of shape
|
| 726 |
+
`[batch_size, in_channels, in_height, in_width]`.
|
| 727 |
+
w: Weight tensor of shape
|
| 728 |
+
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
| 729 |
+
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
| 730 |
+
calling setup_filter(). None = identity (default).
|
| 731 |
+
up: Integer upsampling factor (default: 1).
|
| 732 |
+
down: Integer downsampling factor (default: 1).
|
| 733 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
| 734 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
| 735 |
+
(default: 0).
|
| 736 |
+
groups: Split input channels into N groups (default: 1).
|
| 737 |
+
flip_weight: False = convolution, True = correlation (default: True).
|
| 738 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
| 739 |
+
|
| 740 |
+
Returns:
|
| 741 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
| 742 |
+
"""
|
| 743 |
+
# Validate arguments.
|
| 744 |
+
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
| 745 |
+
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
| 746 |
+
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2])
|
| 747 |
+
assert isinstance(up, int) and (up >= 1)
|
| 748 |
+
assert isinstance(down, int) and (down >= 1)
|
| 749 |
+
# assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
|
| 750 |
+
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
| 751 |
+
fw, fh = _get_filter_size(f)
|
| 752 |
+
# px0, px1, py0, py1 = _parse_padding(padding)
|
| 753 |
+
px0, px1, py0, py1 = padding, padding, padding, padding
|
| 754 |
+
|
| 755 |
+
# Adjust padding to account for up/downsampling.
|
| 756 |
+
if up > 1:
|
| 757 |
+
px0 += (fw + up - 1) // 2
|
| 758 |
+
px1 += (fw - up) // 2
|
| 759 |
+
py0 += (fh + up - 1) // 2
|
| 760 |
+
py1 += (fh - up) // 2
|
| 761 |
+
if down > 1:
|
| 762 |
+
px0 += (fw - down + 1) // 2
|
| 763 |
+
px1 += (fw - down) // 2
|
| 764 |
+
py0 += (fh - down + 1) // 2
|
| 765 |
+
py1 += (fh - down) // 2
|
| 766 |
+
|
| 767 |
+
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
| 768 |
+
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
| 769 |
+
x = upfirdn2d(
|
| 770 |
+
x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
|
| 771 |
+
)
|
| 772 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
| 773 |
+
return x
|
| 774 |
+
|
| 775 |
+
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
| 776 |
+
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
| 777 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
| 778 |
+
x = upfirdn2d(
|
| 779 |
+
x=x,
|
| 780 |
+
f=f,
|
| 781 |
+
up=up,
|
| 782 |
+
padding=[px0, px1, py0, py1],
|
| 783 |
+
gain=up**2,
|
| 784 |
+
flip_filter=flip_filter,
|
| 785 |
+
)
|
| 786 |
+
return x
|
| 787 |
+
|
| 788 |
+
# Fast path: downsampling only => use strided convolution.
|
| 789 |
+
if down > 1 and up == 1:
|
| 790 |
+
x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
|
| 791 |
+
x = _conv2d_wrapper(
|
| 792 |
+
x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
|
| 793 |
+
)
|
| 794 |
+
return x
|
| 795 |
+
|
| 796 |
+
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
| 797 |
+
if up > 1:
|
| 798 |
+
if groups == 1:
|
| 799 |
+
w = w.transpose(0, 1)
|
| 800 |
+
else:
|
| 801 |
+
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
| 802 |
+
w = w.transpose(1, 2)
|
| 803 |
+
w = w.reshape(
|
| 804 |
+
groups * in_channels_per_group, out_channels // groups, kh, kw
|
| 805 |
+
)
|
| 806 |
+
px0 -= kw - 1
|
| 807 |
+
px1 -= kw - up
|
| 808 |
+
py0 -= kh - 1
|
| 809 |
+
py1 -= kh - up
|
| 810 |
+
pxt = max(min(-px0, -px1), 0)
|
| 811 |
+
pyt = max(min(-py0, -py1), 0)
|
| 812 |
+
x = _conv2d_wrapper(
|
| 813 |
+
x=x,
|
| 814 |
+
w=w,
|
| 815 |
+
stride=up,
|
| 816 |
+
padding=[pyt, pxt],
|
| 817 |
+
groups=groups,
|
| 818 |
+
transpose=True,
|
| 819 |
+
flip_weight=(not flip_weight),
|
| 820 |
+
)
|
| 821 |
+
x = upfirdn2d(
|
| 822 |
+
x=x,
|
| 823 |
+
f=f,
|
| 824 |
+
padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
|
| 825 |
+
gain=up**2,
|
| 826 |
+
flip_filter=flip_filter,
|
| 827 |
+
)
|
| 828 |
+
if down > 1:
|
| 829 |
+
x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
| 830 |
+
return x
|
| 831 |
+
|
| 832 |
+
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
| 833 |
+
if up == 1 and down == 1:
|
| 834 |
+
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
| 835 |
+
return _conv2d_wrapper(
|
| 836 |
+
x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
# Fallback: Generic reference implementation.
|
| 840 |
+
x = upfirdn2d(
|
| 841 |
+
x=x,
|
| 842 |
+
f=(f if up > 1 else None),
|
| 843 |
+
up=up,
|
| 844 |
+
padding=[px0, px1, py0, py1],
|
| 845 |
+
gain=up**2,
|
| 846 |
+
flip_filter=flip_filter,
|
| 847 |
+
)
|
| 848 |
+
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
| 849 |
+
if down > 1:
|
| 850 |
+
x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
| 851 |
+
return x
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
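# --- Illustrative shape check (not part of the original commit) ---
# conv2d_resample() above fuses a 3x3 convolution with 2x upsampling: a strided
# transposed convolution followed by FIR filtering. In this trimmed-down copy
# `padding` must be a single int; the conv2d/conv_transpose2d wrappers are the
# ones imported earlier in this module.
import torch

x = torch.randn(1, 8, 16, 16)
w = torch.randn(4, 8, 3, 3)        # [out_channels, in_channels, kh, kw]
f = setup_filter([1, 3, 3, 1])
y = conv2d_resample(x, w, f=f, up=2, padding=1)
assert y.shape == (1, 4, 32, 32)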
class Conv2dLayer(torch.nn.Module):
|
| 855 |
+
def __init__(
|
| 856 |
+
self,
|
| 857 |
+
in_channels, # Number of input channels.
|
| 858 |
+
out_channels, # Number of output channels.
|
| 859 |
+
kernel_size, # Width and height of the convolution kernel.
|
| 860 |
+
bias=True, # Apply additive bias before the activation function?
|
| 861 |
+
activation="linear", # Activation function: 'relu', 'lrelu', etc.
|
| 862 |
+
up=1, # Integer upsampling factor.
|
| 863 |
+
down=1, # Integer downsampling factor.
|
| 864 |
+
resample_filter=[
|
| 865 |
+
1,
|
| 866 |
+
3,
|
| 867 |
+
3,
|
| 868 |
+
1,
|
| 869 |
+
], # Low-pass filter to apply when resampling activations.
|
| 870 |
+
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
|
| 871 |
+
channels_last=False, # Expect the input to have memory_format=channels_last?
|
| 872 |
+
trainable=True, # Update the weights of this layer during training?
|
| 873 |
+
):
|
| 874 |
+
super().__init__()
|
| 875 |
+
self.activation = activation
|
| 876 |
+
self.up = up
|
| 877 |
+
self.down = down
|
| 878 |
+
self.register_buffer("resample_filter", setup_filter(resample_filter))
|
| 879 |
+
self.conv_clamp = conv_clamp
|
| 880 |
+
self.padding = kernel_size // 2
|
| 881 |
+
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
|
| 882 |
+
self.act_gain = activation_funcs[activation].def_gain
|
| 883 |
+
|
| 884 |
+
memory_format = (
|
| 885 |
+
torch.channels_last if channels_last else torch.contiguous_format
|
| 886 |
+
)
|
| 887 |
+
weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
|
| 888 |
+
memory_format=memory_format
|
| 889 |
+
)
|
| 890 |
+
bias = torch.zeros([out_channels]) if bias else None
|
| 891 |
+
if trainable:
|
| 892 |
+
self.weight = torch.nn.Parameter(weight)
|
| 893 |
+
self.bias = torch.nn.Parameter(bias) if bias is not None else None
|
| 894 |
+
else:
|
| 895 |
+
self.register_buffer("weight", weight)
|
| 896 |
+
if bias is not None:
|
| 897 |
+
self.register_buffer("bias", bias)
|
| 898 |
+
else:
|
| 899 |
+
self.bias = None
|
| 900 |
+
|
| 901 |
+
def forward(self, x, gain=1):
|
| 902 |
+
w = self.weight * self.weight_gain
|
| 903 |
+
x = conv2d_resample(
|
| 904 |
+
x=x,
|
| 905 |
+
w=w,
|
| 906 |
+
f=self.resample_filter,
|
| 907 |
+
up=self.up,
|
| 908 |
+
down=self.down,
|
| 909 |
+
padding=self.padding,
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
act_gain = self.act_gain * gain
|
| 913 |
+
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
|
| 914 |
+
out = bias_act(
|
| 915 |
+
x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
|
| 916 |
+
)
|
| 917 |
+
return out
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
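# --- Illustrative usage sketch (not part of the original commit) ---
# Conv2dLayer above combines conv2d_resample() with bias_act(): here a 3x3
# lrelu convolution with built-in 2x antialiased downsampling.
import torch

conv = Conv2dLayer(in_channels=16, out_channels=32, kernel_size=3,
                   activation="lrelu", down=2)
x = torch.randn(1, 16, 64, 64)
y = conv(x)
assert y.shape == (1, 32, 32, 32)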
def torch_gc():
|
| 921 |
+
if torch.cuda.is_available():
|
| 922 |
+
torch.cuda.empty_cache()
|
| 923 |
+
torch.cuda.ipc_collect()
|
| 924 |
+
gc.collect()
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def set_seed(seed: int):
|
| 928 |
+
random.seed(seed)
|
| 929 |
+
np.random.seed(seed)
|
| 930 |
+
torch.manual_seed(seed)
|
| 931 |
+
torch.cuda.manual_seed_all(seed)
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
def get_scheduler(sd_sampler, scheduler_config):
|
| 935 |
+
# https://github.com/huggingface/diffusers/issues/4167
|
| 936 |
+
keys_to_pop = ["use_karras_sigmas", "algorithm_type"]
|
| 937 |
+
scheduler_config = dict(scheduler_config)
|
| 938 |
+
for it in keys_to_pop:
|
| 939 |
+
scheduler_config.pop(it, None)
|
| 940 |
+
|
| 941 |
+
# fmt: off
|
| 942 |
+
samplers = {
|
| 943 |
+
SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler],
|
| 944 |
+
SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)],
|
| 945 |
+
SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")],
|
| 946 |
+
SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)],
|
| 947 |
+
SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler],
|
| 948 |
+
SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)],
|
| 949 |
+
SDSampler.dpm2: [KDPM2DiscreteScheduler],
|
| 950 |
+
SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)],
|
| 951 |
+
SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler],
|
| 952 |
+
SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)],
|
| 953 |
+
SDSampler.euler: [EulerDiscreteScheduler],
|
| 954 |
+
SDSampler.euler_a: [EulerAncestralDiscreteScheduler],
|
| 955 |
+
SDSampler.heun: [HeunDiscreteScheduler],
|
| 956 |
+
SDSampler.lms: [LMSDiscreteScheduler],
|
| 957 |
+
SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)],
|
| 958 |
+
SDSampler.ddim: [DDIMScheduler],
|
| 959 |
+
SDSampler.pndm: [PNDMScheduler],
|
| 960 |
+
SDSampler.uni_pc: [UniPCMultistepScheduler],
|
| 961 |
+
SDSampler.lcm: [LCMScheduler],
|
| 962 |
+
}
|
| 963 |
+
# fmt: on
|
| 964 |
+
if sd_sampler in samplers:
|
| 965 |
+
if len(samplers[sd_sampler]) == 2:
|
| 966 |
+
scheduler_cls, kwargs = samplers[sd_sampler]
|
| 967 |
+
else:
|
| 968 |
+
scheduler_cls, kwargs = samplers[sd_sampler][0], {}
|
| 969 |
+
return scheduler_cls.from_config(scheduler_config, **kwargs)
|
| 970 |
+
else:
|
| 971 |
+
raise ValueError(sd_sampler)
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
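# --- Illustrative usage sketch (not part of the original commit) ---
# get_scheduler() above maps an SDSampler value to a diffusers scheduler class
# plus extra kwargs and rebuilds it from an existing scheduler config.
# `pipe` below is a hypothetical, already-loaded diffusers pipeline.
scheduler = get_scheduler(SDSampler.dpm_plus_plus_2m_karras, pipe.scheduler.config)
pipe.scheduler = scheduler  # e.g. DPMSolverMultistepScheduler with use_karras_sigmas=True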
def is_local_files_only(**kwargs) -> bool:
|
| 975 |
+
from huggingface_hub.constants import HF_HUB_OFFLINE
|
| 976 |
+
|
| 977 |
+
return HF_HUB_OFFLINE or kwargs.get("local_files_only", False)
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
def handle_from_pretrained_exceptions(func, **kwargs):
|
| 981 |
+
try:
|
| 982 |
+
return func(**kwargs)
|
| 983 |
+
except ValueError as e:
|
| 984 |
+
if "You are trying to load the model files of the `variant=fp16`" in str(e):
|
| 985 |
+
logger.info("variant=fp16 not found, try revision=fp16")
|
| 986 |
+
try:
|
| 987 |
+
return func(**{**kwargs, "variant": None, "revision": "fp16"})
|
| 988 |
+
except Exception as e:
|
| 989 |
+
logger.info("revision=fp16 not found, try revision=main")
|
| 990 |
+
return func(**{**kwargs, "variant": None, "revision": "main"})
|
| 991 |
+
raise e
|
| 992 |
+
except OSError as e:
|
| 993 |
+
previous_traceback = traceback.format_exc()
|
| 994 |
+
if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
|
| 995 |
+
logger.info("revision=fp16 not found, try revision=main")
|
| 996 |
+
return func(**{**kwargs, "variant": None, "revision": "main"})
|
| 997 |
+
elif "Max retries exceeded" in previous_traceback:
|
| 998 |
+
logger.exception(
|
| 999 |
+
"Fetching model from HuggingFace failed. "
|
| 1000 |
+
"If this is your first time downloading the model, you may need to set up proxy in terminal."
|
| 1001 |
+
"If the model has already been downloaded, you can add --local-files-only when starting."
|
| 1002 |
+
)
|
| 1003 |
+
exit(-1)
|
| 1004 |
+
raise e
|
| 1005 |
+
except Exception as e:
|
| 1006 |
+
raise e
|
| 1007 |
+
|
| 1008 |
+
|
| 1009 |
+
def get_torch_dtype(device, no_half: bool):
|
| 1010 |
+
device = str(device)
|
| 1011 |
+
use_fp16 = not no_half
|
| 1012 |
+
use_gpu = device == "cuda"
|
| 1013 |
+
# https://github.com/huggingface/diffusers/issues/4480
|
| 1014 |
+
# pipe.enable_attention_slicing and float16 will cause black output on mps
|
| 1015 |
+
# if device in ["cuda", "mps"] and use_fp16:
|
| 1016 |
+
if device in ["cuda"] and use_fp16:
|
| 1017 |
+
return use_gpu, torch.float16
|
| 1018 |
+
return use_gpu, torch.float32
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
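# --- Illustrative usage sketch (not part of the original commit) ---
# get_torch_dtype() above only selects float16 on CUDA when half precision is
# not disabled; every other device falls back to float32.
use_gpu, dtype = get_torch_dtype("cuda", no_half=False)  # (True, torch.float16)
use_gpu, dtype = get_torch_dtype("mps", no_half=False)   # (False, torch.float32)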
def enable_low_mem(pipe, enable: bool):
|
| 1022 |
+
if torch.backends.mps.is_available():
|
| 1023 |
+
# https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
|
| 1024 |
+
# CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
|
| 1025 |
+
if enable:
|
| 1026 |
+
pipe.enable_attention_slicing("max")
|
| 1027 |
+
else:
|
| 1028 |
+
# https://huggingface.co/docs/diffusers/optimization/mps
|
| 1029 |
+
# Devices with less than 64GB of memory are recommended to use enable_attention_slicing
|
| 1030 |
+
pipe.enable_attention_slicing()
|
| 1031 |
+
|
| 1032 |
+
if enable:
|
| 1033 |
+
pipe.vae.enable_tiling()
|
iopaint/model/zits.py
ADDED
|
@@ -0,0 +1,476 @@
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from iopaint.helper import get_cache_path_by_url, load_jit_model, download_model
|
| 9 |
+
from iopaint.schema import InpaintRequest
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from .base import InpaintModel
|
| 13 |
+
|
| 14 |
+
ZITS_INPAINT_MODEL_URL = os.environ.get(
|
| 15 |
+
"ZITS_INPAINT_MODEL_URL",
|
| 16 |
+
"https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt",
|
| 17 |
+
)
|
| 18 |
+
ZITS_INPAINT_MODEL_MD5 = os.environ.get(
|
| 19 |
+
"ZITS_INPAINT_MODEL_MD5", "9978cc7157dc29699e42308d675b2154"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
ZITS_EDGE_LINE_MODEL_URL = os.environ.get(
|
| 23 |
+
"ZITS_EDGE_LINE_MODEL_URL",
|
| 24 |
+
"https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt",
|
| 25 |
+
)
|
| 26 |
+
ZITS_EDGE_LINE_MODEL_MD5 = os.environ.get(
|
| 27 |
+
"ZITS_EDGE_LINE_MODEL_MD5", "55e31af21ba96bbf0c80603c76ea8c5f"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get(
|
| 31 |
+
"ZITS_STRUCTURE_UPSAMPLE_MODEL_URL",
|
| 32 |
+
"https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt",
|
| 33 |
+
)
|
| 34 |
+
ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 = os.environ.get(
|
| 35 |
+
"ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5", "3d88a07211bd41b2ec8cc0d999f29927"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
ZITS_WIRE_FRAME_MODEL_URL = os.environ.get(
|
| 39 |
+
"ZITS_WIRE_FRAME_MODEL_URL",
|
| 40 |
+
"https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt",
|
| 41 |
+
)
|
| 42 |
+
ZITS_WIRE_FRAME_MODEL_MD5 = os.environ.get(
|
| 43 |
+
"ZITS_WIRE_FRAME_MODEL_MD5", "a9727c63a8b48b65c905d351b21ce46b"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def resize(img, height, width, center_crop=False):
|
| 48 |
+
imgh, imgw = img.shape[0:2]
|
| 49 |
+
|
| 50 |
+
if center_crop and imgh != imgw:
|
| 51 |
+
# center crop
|
| 52 |
+
side = np.minimum(imgh, imgw)
|
| 53 |
+
j = (imgh - side) // 2
|
| 54 |
+
i = (imgw - side) // 2
|
| 55 |
+
img = img[j : j + side, i : i + side, ...]
|
| 56 |
+
|
| 57 |
+
if imgh > height and imgw > width:
|
| 58 |
+
inter = cv2.INTER_AREA
|
| 59 |
+
else:
|
| 60 |
+
inter = cv2.INTER_LINEAR
|
| 61 |
+
img = cv2.resize(img, (height, width), interpolation=inter)  # note: cv2.resize takes (width, height); all call sites pass square sizes
|
| 62 |
+
|
| 63 |
+
return img
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def to_tensor(img, scale=True, norm=False):
|
| 67 |
+
if img.ndim == 2:
|
| 68 |
+
img = img[:, :, np.newaxis]
|
| 69 |
+
c = img.shape[-1]
|
| 70 |
+
|
| 71 |
+
if scale:
|
| 72 |
+
img_t = torch.from_numpy(img).permute(2, 0, 1).float().div(255)
|
| 73 |
+
else:
|
| 74 |
+
img_t = torch.from_numpy(img).permute(2, 0, 1).float()
|
| 75 |
+
|
| 76 |
+
if norm:
|
| 77 |
+
mean = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
|
| 78 |
+
std = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
|
| 79 |
+
img_t = (img_t - mean) / std
|
| 80 |
+
return img_t
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def load_masked_position_encoding(mask):
|
| 84 |
+
ones_filter = np.ones((3, 3), dtype=np.float32)
|
| 85 |
+
d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32)
|
| 86 |
+
d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32)
|
| 87 |
+
d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32)
|
| 88 |
+
d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32)
|
| 89 |
+
str_size = 256
|
| 90 |
+
pos_num = 128
|
| 91 |
+
|
| 92 |
+
ori_mask = mask.copy()
|
| 93 |
+
ori_h, ori_w = ori_mask.shape[0:2]
|
| 94 |
+
ori_mask = ori_mask / 255
|
| 95 |
+
mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA)
|
| 96 |
+
mask[mask > 0] = 255
|
| 97 |
+
h, w = mask.shape[0:2]
|
| 98 |
+
mask3 = mask.copy()
|
| 99 |
+
mask3 = 1.0 - (mask3 / 255.0)
|
| 100 |
+
pos = np.zeros((h, w), dtype=np.int32)
|
| 101 |
+
direct = np.zeros((h, w, 4), dtype=np.int32)
|
| 102 |
+
i = 0
|
| 103 |
+
while np.sum(1 - mask3) > 0:
|
| 104 |
+
i += 1
|
| 105 |
+
mask3_ = cv2.filter2D(mask3, -1, ones_filter)
|
| 106 |
+
mask3_[mask3_ > 0] = 1
|
| 107 |
+
sub_mask = mask3_ - mask3
|
| 108 |
+
pos[sub_mask == 1] = i
|
| 109 |
+
|
| 110 |
+
m = cv2.filter2D(mask3, -1, d_filter1)
|
| 111 |
+
m[m > 0] = 1
|
| 112 |
+
m = m - mask3
|
| 113 |
+
direct[m == 1, 0] = 1
|
| 114 |
+
|
| 115 |
+
m = cv2.filter2D(mask3, -1, d_filter2)
|
| 116 |
+
m[m > 0] = 1
|
| 117 |
+
m = m - mask3
|
| 118 |
+
direct[m == 1, 1] = 1
|
| 119 |
+
|
| 120 |
+
m = cv2.filter2D(mask3, -1, d_filter3)
|
| 121 |
+
m[m > 0] = 1
|
| 122 |
+
m = m - mask3
|
| 123 |
+
direct[m == 1, 2] = 1
|
| 124 |
+
|
| 125 |
+
m = cv2.filter2D(mask3, -1, d_filter4)
|
| 126 |
+
m[m > 0] = 1
|
| 127 |
+
m = m - mask3
|
| 128 |
+
direct[m == 1, 3] = 1
|
| 129 |
+
|
| 130 |
+
mask3 = mask3_
|
| 131 |
+
|
| 132 |
+
abs_pos = pos.copy()
|
| 133 |
+
rel_pos = pos / (str_size / 2) # to 0~1 maybe larger than 1
|
| 134 |
+
rel_pos = (rel_pos * pos_num).astype(np.int32)
|
| 135 |
+
rel_pos = np.clip(rel_pos, 0, pos_num - 1)
|
| 136 |
+
|
| 137 |
+
if ori_w != w or ori_h != h:
|
| 138 |
+
rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
|
| 139 |
+
rel_pos[ori_mask == 0] = 0
|
| 140 |
+
direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
|
| 141 |
+
direct[ori_mask == 0, :] = 0
|
| 142 |
+
|
| 143 |
+
return rel_pos, abs_pos, direct
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
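# --- Illustrative sketch (not part of the original commit) ---
# load_masked_position_encoding() above iteratively dilates the known region
# into the hole: `pos` counts how many dilation steps each hole pixel needed,
# and `direct` marks which of the four directional filters reached it.
import numpy as np

mask = np.zeros((256, 256), dtype=np.uint8)
mask[100:140, 100:140] = 255                      # square hole to be inpainted
rel_pos, abs_pos, direct = load_masked_position_encoding(mask)
assert rel_pos.shape == (256, 256)
assert abs_pos.shape == (256, 256)
assert direct.shape == (256, 256, 4)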
def load_image(img, mask, device, sigma256=3.0):
|
| 147 |
+
"""
|
| 148 |
+
Args:
|
| 149 |
+
img: [H, W, C] RGB
|
| 150 |
+
mask: [H, W], where 255 marks the masked region
|
| 151 |
+
sigma256:
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
|
| 155 |
+
"""
|
| 156 |
+
h, w, _ = img.shape
|
| 157 |
+
imgh, imgw = img.shape[0:2]
|
| 158 |
+
img_256 = resize(img, 256, 256)
|
| 159 |
+
|
| 160 |
+
mask = (mask > 127).astype(np.uint8) * 255
|
| 161 |
+
mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA)
|
| 162 |
+
mask_256[mask_256 > 0] = 255
|
| 163 |
+
|
| 164 |
+
mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA)
|
| 165 |
+
mask_512[mask_512 > 0] = 255
|
| 166 |
+
|
| 167 |
+
# original skimage implementation
|
| 168 |
+
# https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny
|
| 169 |
+
# low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype’s max.
|
| 170 |
+
# high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype’s max.
|
| 171 |
+
|
| 172 |
+
try:
|
| 173 |
+
import skimage
|
| 174 |
+
|
| 175 |
+
gray_256 = skimage.color.rgb2gray(img_256)
|
| 176 |
+
edge_256 = skimage.feature.canny(gray_256, sigma=3.0, mask=None).astype(float)
|
| 177 |
+
# cv2.imwrite("skimage_gray.jpg", (gray_256*255).astype(np.uint8))
|
| 178 |
+
# cv2.imwrite("skimage_edge.jpg", (edge_256*255).astype(np.uint8))
|
| 179 |
+
except:
|
| 180 |
+
gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY)
|
| 181 |
+
gray_256_blured = cv2.GaussianBlur(
|
| 182 |
+
gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256
|
| 183 |
+
)
|
| 184 |
+
edge_256 = cv2.Canny(
|
| 185 |
+
gray_256_blured, threshold1=int(255 * 0.1), threshold2=int(255 * 0.2)
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# cv2.imwrite("opencv_edge.jpg", edge_256)
|
| 189 |
+
|
| 190 |
+
# line
|
| 191 |
+
img_512 = resize(img, 512, 512)
|
| 192 |
+
|
| 193 |
+
rel_pos, abs_pos, direct = load_masked_position_encoding(mask)
|
| 194 |
+
|
| 195 |
+
batch = dict()
|
| 196 |
+
batch["images"] = to_tensor(img.copy()).unsqueeze(0).to(device)
|
| 197 |
+
batch["img_256"] = to_tensor(img_256, norm=True).unsqueeze(0).to(device)
|
| 198 |
+
batch["masks"] = to_tensor(mask).unsqueeze(0).to(device)
|
| 199 |
+
batch["mask_256"] = to_tensor(mask_256).unsqueeze(0).to(device)
|
| 200 |
+
batch["mask_512"] = to_tensor(mask_512).unsqueeze(0).to(device)
|
| 201 |
+
batch["edge_256"] = to_tensor(edge_256, scale=False).unsqueeze(0).to(device)
|
| 202 |
+
batch["img_512"] = to_tensor(img_512).unsqueeze(0).to(device)
|
| 203 |
+
batch["rel_pos"] = torch.LongTensor(rel_pos).unsqueeze(0).to(device)
|
| 204 |
+
batch["abs_pos"] = torch.LongTensor(abs_pos).unsqueeze(0).to(device)
|
| 205 |
+
batch["direct"] = torch.LongTensor(direct).unsqueeze(0).to(device)
|
| 206 |
+
batch["h"] = imgh
|
| 207 |
+
batch["w"] = imgw
|
| 208 |
+
|
| 209 |
+
return batch
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
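# --- Illustrative sketch (not part of the original commit) ---
# load_image() above bundles the multi-scale image/mask tensors, the Canny edge
# map, and the positional encodings into one batch dict. `rgb_image` (HxWx3
# uint8) and `hole_mask` (HxW uint8) are hypothetical inputs.
batch = load_image(rgb_image, hole_mask, device="cpu")
# batch["img_256"]: [1, 3, 256, 256], batch["mask_512"]: [1, 1, 512, 512],
# batch["rel_pos"]: [1, H, W], batch["h"] / batch["w"]: original size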
def to_device(data, device):
|
| 213 |
+
if isinstance(data, torch.Tensor):
|
| 214 |
+
return data.to(device)
|
| 215 |
+
if isinstance(data, dict):
|
| 216 |
+
for key in data:
|
| 217 |
+
if isinstance(data[key], torch.Tensor):
|
| 218 |
+
data[key] = data[key].to(device)
|
| 219 |
+
return data
|
| 220 |
+
if isinstance(data, list):
|
| 221 |
+
return [to_device(d, device) for d in data]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class ZITS(InpaintModel):
|
| 225 |
+
name = "zits"
|
| 226 |
+
min_size = 256
|
| 227 |
+
pad_mod = 32
|
| 228 |
+
pad_to_square = True
|
| 229 |
+
is_erase_model = True
|
| 230 |
+
|
| 231 |
+
def __init__(self, device, **kwargs):
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
device:
|
| 236 |
+
"""
|
| 237 |
+
super().__init__(device)
|
| 238 |
+
self.device = device
|
| 239 |
+
self.sample_edge_line_iterations = 1
|
| 240 |
+
|
| 241 |
+
def init_model(self, device, **kwargs):
|
| 242 |
+
self.wireframe = load_jit_model(
|
| 243 |
+
ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5
|
| 244 |
+
)
|
| 245 |
+
self.edge_line = load_jit_model(
|
| 246 |
+
ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5
|
| 247 |
+
)
|
| 248 |
+
self.structure_upsample = load_jit_model(
|
| 249 |
+
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
|
| 250 |
+
)
|
| 251 |
+
self.inpaint = load_jit_model(
|
| 252 |
+
ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
@staticmethod
|
| 256 |
+
def download():
|
| 257 |
+
download_model(ZITS_WIRE_FRAME_MODEL_URL, ZITS_WIRE_FRAME_MODEL_MD5)
|
| 258 |
+
download_model(ZITS_EDGE_LINE_MODEL_URL, ZITS_EDGE_LINE_MODEL_MD5)
|
| 259 |
+
download_model(
|
| 260 |
+
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
|
| 261 |
+
)
|
| 262 |
+
download_model(ZITS_INPAINT_MODEL_URL, ZITS_INPAINT_MODEL_MD5)
|
| 263 |
+
|
| 264 |
+
@staticmethod
|
| 265 |
+
def is_downloaded() -> bool:
|
| 266 |
+
model_paths = [
|
| 267 |
+
get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL),
|
| 268 |
+
get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL),
|
| 269 |
+
get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
|
| 270 |
+
get_cache_path_by_url(ZITS_INPAINT_MODEL_URL),
|
| 271 |
+
]
|
| 272 |
+
return all([os.path.exists(it) for it in model_paths])
|
| 273 |
+
|
| 274 |
+
def wireframe_edge_and_line(self, items, enable: bool):
|
| 275 |
+
# Finally adds the "edge" and "line" keys to items
|
| 276 |
+
if not enable:
|
| 277 |
+
items["edge"] = torch.zeros_like(items["masks"])
|
| 278 |
+
items["line"] = torch.zeros_like(items["masks"])
|
| 279 |
+
return
|
| 280 |
+
|
| 281 |
+
start = time.time()
|
| 282 |
+
try:
|
| 283 |
+
line_256 = self.wireframe_forward(
|
| 284 |
+
items["img_512"],
|
| 285 |
+
h=256,
|
| 286 |
+
w=256,
|
| 287 |
+
masks=items["mask_512"],
|
| 288 |
+
mask_th=0.85,
|
| 289 |
+
)
|
| 290 |
+
except:
|
| 291 |
+
line_256 = torch.zeros_like(items["mask_256"])
|
| 292 |
+
|
| 293 |
+
print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms")
|
| 294 |
+
|
| 295 |
+
# np_line = (line[0][0].numpy() * 255).astype(np.uint8)
|
| 296 |
+
# cv2.imwrite("line.jpg", np_line)
|
| 297 |
+
|
| 298 |
+
start = time.time()
|
| 299 |
+
edge_pred, line_pred = self.sample_edge_line_logits(
|
| 300 |
+
context=[items["img_256"], items["edge_256"], line_256],
|
| 301 |
+
mask=items["mask_256"].clone(),
|
| 302 |
+
iterations=self.sample_edge_line_iterations,
|
| 303 |
+
add_v=0.05,
|
| 304 |
+
mul_v=4,
|
| 305 |
+
)
|
| 306 |
+
print(f"sample_edge_line_logits time: {(time.time() - start) * 1000:.2f}ms")
|
| 307 |
+
|
| 308 |
+
# np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
|
| 309 |
+
# cv2.imwrite("edge_pred.jpg", np_edge_pred)
|
| 310 |
+
# np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
|
| 311 |
+
# cv2.imwrite("line_pred.jpg", np_line_pred)
|
| 312 |
+
# exit()
|
| 313 |
+
|
| 314 |
+
input_size = min(items["h"], items["w"])
|
| 315 |
+
if input_size != 256 and input_size > 256:
|
| 316 |
+
while edge_pred.shape[2] < input_size:
|
| 317 |
+
edge_pred = self.structure_upsample(edge_pred)
|
| 318 |
+
edge_pred = torch.sigmoid((edge_pred + 2) * 2)
|
| 319 |
+
|
| 320 |
+
line_pred = self.structure_upsample(line_pred)
|
| 321 |
+
line_pred = torch.sigmoid((line_pred + 2) * 2)
|
| 322 |
+
|
| 323 |
+
edge_pred = F.interpolate(
|
| 324 |
+
edge_pred,
|
| 325 |
+
size=(input_size, input_size),
|
| 326 |
+
mode="bilinear",
|
| 327 |
+
align_corners=False,
|
| 328 |
+
)
|
| 329 |
+
line_pred = F.interpolate(
|
| 330 |
+
line_pred,
|
| 331 |
+
size=(input_size, input_size),
|
| 332 |
+
mode="bilinear",
|
| 333 |
+
align_corners=False,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
# np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
|
| 337 |
+
# cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred)
|
| 338 |
+
# np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
|
| 339 |
+
# cv2.imwrite("line_pred_upsample.jpg", np_line_pred)
|
| 340 |
+
# exit()
|
| 341 |
+
|
| 342 |
+
items["edge"] = edge_pred.detach()
|
| 343 |
+
items["line"] = line_pred.detach()
|
| 344 |
+
|
| 345 |
+
@torch.no_grad()
|
| 346 |
+
def forward(self, image, mask, config: InpaintRequest):
|
| 347 |
+
"""Input images and output images have same size
|
| 348 |
+
images: [H, W, C] RGB
|
| 349 |
+
masks: [H, W]
|
| 350 |
+
return: BGR IMAGE
|
| 351 |
+
"""
|
| 352 |
+
mask = mask[:, :, 0]
|
| 353 |
+
items = load_image(image, mask, device=self.device)
|
| 354 |
+
|
| 355 |
+
self.wireframe_edge_and_line(items, config.zits_wireframe)
|
| 356 |
+
|
| 357 |
+
inpainted_image = self.inpaint(
|
| 358 |
+
items["images"],
|
| 359 |
+
items["masks"],
|
| 360 |
+
items["edge"],
|
| 361 |
+
items["line"],
|
| 362 |
+
items["rel_pos"],
|
| 363 |
+
items["direct"],
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
inpainted_image = inpainted_image * 255.0
|
| 367 |
+
inpainted_image = (
|
| 368 |
+
inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
|
| 369 |
+
)
|
| 370 |
+
inpainted_image = inpainted_image[:, :, ::-1]
|
| 371 |
+
|
| 372 |
+
# cv2.imwrite("inpainted.jpg", inpainted_image)
|
| 373 |
+
# exit()
|
| 374 |
+
|
| 375 |
+
return inpainted_image
|
| 376 |
+
|
| 377 |
+
def wireframe_forward(self, images, h, w, masks, mask_th=0.925):
|
| 378 |
+
lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1)
|
| 379 |
+
lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1)
|
| 380 |
+
images = images * 255.0
|
| 381 |
+
# the masks value of lcnn is 127.5
|
| 382 |
+
masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5
|
| 383 |
+
masked_images = (masked_images - lcnn_mean) / lcnn_std
|
| 384 |
+
|
| 385 |
+
def to_int(x):
|
| 386 |
+
return tuple(map(int, x))
|
| 387 |
+
|
| 388 |
+
lines_tensor = []
|
| 389 |
+
lmap = np.zeros((h, w))
|
| 390 |
+
|
| 391 |
+
output_masked = self.wireframe(masked_images)
|
| 392 |
+
|
| 393 |
+
output_masked = to_device(output_masked, "cpu")
|
| 394 |
+
if output_masked["num_proposals"] == 0:
|
| 395 |
+
lines_masked = []
|
| 396 |
+
scores_masked = []
|
| 397 |
+
else:
|
| 398 |
+
lines_masked = output_masked["lines_pred"].numpy()
|
| 399 |
+
lines_masked = [
|
| 400 |
+
[line[1] * h, line[0] * w, line[3] * h, line[2] * w]
|
| 401 |
+
for line in lines_masked
|
| 402 |
+
]
|
| 403 |
+
scores_masked = output_masked["lines_score"].numpy()
|
| 404 |
+
|
| 405 |
+
for line, score in zip(lines_masked, scores_masked):
|
| 406 |
+
if score > mask_th:
|
| 407 |
+
try:
|
| 408 |
+
import skimage
|
| 409 |
+
|
| 410 |
+
rr, cc, value = skimage.draw.line_aa(
|
| 411 |
+
*to_int(line[0:2]), *to_int(line[2:4])
|
| 412 |
+
)
|
| 413 |
+
lmap[rr, cc] = np.maximum(lmap[rr, cc], value)
|
| 414 |
+
except:
|
| 415 |
+
cv2.line(
|
| 416 |
+
lmap,
|
| 417 |
+
to_int(line[0:2][::-1]),
|
| 418 |
+
to_int(line[2:4][::-1]),
|
| 419 |
+
(1, 1, 1),
|
| 420 |
+
1,
|
| 421 |
+
cv2.LINE_AA,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8)
|
| 425 |
+
lines_tensor.append(to_tensor(lmap).unsqueeze(0))
|
| 426 |
+
|
| 427 |
+
lines_tensor = torch.cat(lines_tensor, dim=0)
|
| 428 |
+
return lines_tensor.detach().to(self.device)
|
| 429 |
+
|
| 430 |
+
def sample_edge_line_logits(
|
| 431 |
+
self, context, mask=None, iterations=1, add_v=0, mul_v=4
|
| 432 |
+
):
|
| 433 |
+
[img, edge, line] = context
|
| 434 |
+
|
| 435 |
+
img = img * (1 - mask)
|
| 436 |
+
edge = edge * (1 - mask)
|
| 437 |
+
line = line * (1 - mask)
|
| 438 |
+
|
| 439 |
+
for i in range(iterations):
|
| 440 |
+
edge_logits, line_logits = self.edge_line(img, edge, line, masks=mask)
|
| 441 |
+
|
| 442 |
+
edge_pred = torch.sigmoid(edge_logits)
|
| 443 |
+
line_pred = torch.sigmoid((line_logits + add_v) * mul_v)
|
| 444 |
+
edge = edge + edge_pred * mask
|
| 445 |
+
edge[edge >= 0.25] = 1
|
| 446 |
+
edge[edge < 0.25] = 0
|
| 447 |
+
line = line + line_pred * mask
|
| 448 |
+
|
| 449 |
+
b, _, h, w = edge_pred.shape
|
| 450 |
+
edge_pred = edge_pred.reshape(b, -1, 1)
|
| 451 |
+
line_pred = line_pred.reshape(b, -1, 1)
|
| 452 |
+
mask = mask.reshape(b, -1)
|
| 453 |
+
|
| 454 |
+
edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1)
|
| 455 |
+
line_probs = torch.cat([1 - line_pred, line_pred], dim=-1)
|
| 456 |
+
edge_probs[:, :, 1] += 0.5
|
| 457 |
+
line_probs[:, :, 1] += 0.5
|
| 458 |
+
edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100)
|
| 459 |
+
line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100)
|
| 460 |
+
|
| 461 |
+
indices = torch.sort(
|
| 462 |
+
edge_max_probs + line_max_probs, dim=-1, descending=True
|
| 463 |
+
)[1]
|
| 464 |
+
|
| 465 |
+
for ii in range(b):
|
| 466 |
+
keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))
|
| 467 |
+
|
| 468 |
+
assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "Error!!!"
|
| 469 |
+
mask[ii][indices[ii, :keep]] = 0
|
| 470 |
+
|
| 471 |
+
mask = mask.reshape(b, 1, h, w)
|
| 472 |
+
edge = edge * (1 - mask)
|
| 473 |
+
line = line * (1 - mask)
|
| 474 |
+
|
| 475 |
+
edge, line = edge.to(torch.float32), line.to(torch.float32)
|
| 476 |
+
return edge, line
|
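`sample_edge_line_logits` fills the hole progressively: each pass predicts edge/line probabilities inside the remaining mask, ranks masked pixels by confidence, commits the most confident `(i + 1) / iterations` fraction, and removes those pixels from the mask before the next pass. A toy sketch of that commit step (single map instead of the edge/line pair, hypothetical `commit_topk` helper):

```python
import torch

def commit_topk(pred, mask, frac):
    """One step of the iterative sampling above: keep the `frac` most confident
    masked pixels and drop them from the mask (toy, single-channel version)."""
    b, _, h, w = pred.shape
    conf = torch.maximum(pred, 1 - pred).reshape(b, -1)   # confidence per pixel
    flat_mask = mask.reshape(b, -1)
    conf = conf + (1 - flat_mask) * (-100)                 # ignore already-known pixels
    order = conf.argsort(dim=-1, descending=True)
    for i in range(b):
        keep = int(frac * flat_mask[i].sum())
        flat_mask[i, order[i, :keep]] = 0                  # these pixels are now "known"
    return flat_mask.reshape(b, 1, h, w)

pred = torch.rand(1, 1, 8, 8)
mask = torch.ones(1, 1, 8, 8)
mask = commit_topk(pred, mask, frac=0.5)   # half of the hole committed after one pass
```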
iopaint/plugins/segment_anything/modeling/tiny_vit_sam.py
ADDED
|
@@ -0,0 +1,822 @@
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# TinyViT Model Architecture
|
| 3 |
+
# Copyright (c) 2022 Microsoft
|
| 4 |
+
# Adapted from LeViT and Swin Transformer
|
| 5 |
+
# LeViT: (https://github.com/facebookresearch/levit)
|
| 6 |
+
# Swin: (https://github.com/microsoft/swin-transformer)
|
| 7 |
+
# Build the TinyViT Model
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import collections
|
| 11 |
+
import itertools
|
| 12 |
+
import math
|
| 13 |
+
import warnings
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import torch.utils.checkpoint as checkpoint
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _ntuple(n):
|
| 22 |
+
def parse(x):
|
| 23 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 24 |
+
return x
|
| 25 |
+
return tuple(itertools.repeat(x, n))
|
| 26 |
+
|
| 27 |
+
return parse
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
to_2tuple = _ntuple(2)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _trunc_normal_(tensor, mean, std, a, b):
|
| 34 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 35 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 36 |
+
def norm_cdf(x):
|
| 37 |
+
# Computes standard normal cumulative distribution function
|
| 38 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
| 39 |
+
|
| 40 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 41 |
+
warnings.warn(
|
| 42 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 43 |
+
"The distribution of values may be incorrect.",
|
| 44 |
+
stacklevel=2,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# Values are generated by using a truncated uniform distribution and
|
| 48 |
+
# then using the inverse CDF for the normal distribution.
|
| 49 |
+
# Get upper and lower cdf values
|
| 50 |
+
l = norm_cdf((a - mean) / std)
|
| 51 |
+
u = norm_cdf((b - mean) / std)
|
| 52 |
+
|
| 53 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 54 |
+
# [2l-1, 2u-1].
|
| 55 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 56 |
+
|
| 57 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 58 |
+
# standard normal
|
| 59 |
+
tensor.erfinv_()
|
| 60 |
+
|
| 61 |
+
# Transform to proper mean, std
|
| 62 |
+
tensor.mul_(std * math.sqrt(2.0))
|
| 63 |
+
tensor.add_(mean)
|
| 64 |
+
|
| 65 |
+
# Clamp to ensure it's in the proper range
|
| 66 |
+
tensor.clamp_(min=a, max=b)
|
| 67 |
+
return tensor
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
| 71 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
| 72 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
| 73 |
+
normal distribution. The values are effectively drawn from the
|
| 74 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
| 75 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
| 76 |
+
the bounds. The method used for generating the random values works
|
| 77 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
| 78 |
+
|
| 79 |
+
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
|
| 80 |
+
applied while sampling the normal with mean/std applied, therefore a, b args
|
| 81 |
+
should be adjusted to match the range of mean, std args.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
tensor: an n-dimensional `torch.Tensor`
|
| 85 |
+
mean: the mean of the normal distribution
|
| 86 |
+
std: the standard deviation of the normal distribution
|
| 87 |
+
a: the minimum cutoff value
|
| 88 |
+
b: the maximum cutoff value
|
| 89 |
+
Examples:
|
| 90 |
+
>>> w = torch.empty(3, 5)
|
| 91 |
+
>>> nn.init.trunc_normal_(w)
|
| 92 |
+
"""
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
return _trunc_normal_(tensor, mean, std, a, b)
|
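`trunc_normal_` draws uniformly between the CDF values of the two cut-offs and maps the result back through the inverse normal CDF, so no rejection sampling is needed. A quick sanity check, assuming this module is imported so `trunc_normal_` is in scope:

```python
import torch
import torch.nn as nn

# Initialize a linear layer the way TinyViT does and sanity-check the bounds.
layer = nn.Linear(384, 384)
trunc_normal_(layer.weight, std=0.02)   # N(0, 0.02^2), clipped to the default [-2, 2]
assert layer.weight.min() >= -2.0 and layer.weight.max() <= 2.0
print(layer.weight.std())               # ≈ 0.02
```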
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def drop_path(
|
| 98 |
+
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
| 99 |
+
):
|
| 100 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 101 |
+
|
| 102 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
| 103 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 104 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
| 105 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
| 106 |
+
'survival rate' as the argument.
|
| 107 |
+
|
| 108 |
+
"""
|
| 109 |
+
if drop_prob == 0.0 or not training:
|
| 110 |
+
return x
|
| 111 |
+
keep_prob = 1 - drop_prob
|
| 112 |
+
shape = (x.shape[0],) + (1,) * (
|
| 113 |
+
x.ndim - 1
|
| 114 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
| 115 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 116 |
+
if keep_prob > 0.0 and scale_by_keep:
|
| 117 |
+
random_tensor.div_(keep_prob)
|
| 118 |
+
return x * random_tensor
|
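`drop_path` implements stochastic depth: whole samples are zeroed in the residual branch and, with `scale_by_keep`, the survivors are rescaled by `1 / keep_prob` so the expected activation is unchanged. A small sketch, assuming this module is imported:

```python
import torch

torch.manual_seed(0)
x = torch.ones(1000, 16)                         # 1000 "samples" in a batch
out = drop_path(x, drop_prob=0.25, training=True)

# Whole rows are either zeroed or rescaled by 1 / keep_prob, so the mean is preserved.
dropped = (out.sum(dim=1) == 0).float().mean()
print(dropped)        # ≈ 0.25
print(out.mean())     # ≈ 1.0 (unbiased thanks to scale_by_keep)
```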
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class TimmDropPath(nn.Module):
|
| 122 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 123 |
+
|
| 124 |
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
| 125 |
+
super(TimmDropPath, self).__init__()
|
| 126 |
+
self.drop_prob = drop_prob
|
| 127 |
+
self.scale_by_keep = scale_by_keep
|
| 128 |
+
|
| 129 |
+
def forward(self, x):
|
| 130 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
| 131 |
+
|
| 132 |
+
def extra_repr(self):
|
| 133 |
+
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class Conv2d_BN(torch.nn.Sequential):
|
| 137 |
+
def __init__(
|
| 138 |
+
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1
|
| 139 |
+
):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.add_module(
|
| 142 |
+
"c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)
|
| 143 |
+
)
|
| 144 |
+
bn = torch.nn.BatchNorm2d(b)
|
| 145 |
+
torch.nn.init.constant_(bn.weight, bn_weight_init)
|
| 146 |
+
torch.nn.init.constant_(bn.bias, 0)
|
| 147 |
+
self.add_module("bn", bn)
|
| 148 |
+
|
| 149 |
+
@torch.no_grad()
|
| 150 |
+
def fuse(self):
|
| 151 |
+
c, bn = self._modules.values()
|
| 152 |
+
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
| 153 |
+
w = c.weight * w[:, None, None, None]
|
| 154 |
+
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
| 155 |
+
m = torch.nn.Conv2d(
|
| 156 |
+
w.size(1) * self.c.groups,
|
| 157 |
+
w.size(0),
|
| 158 |
+
w.shape[2:],
|
| 159 |
+
stride=self.c.stride,
|
| 160 |
+
padding=self.c.padding,
|
| 161 |
+
dilation=self.c.dilation,
|
| 162 |
+
groups=self.c.groups,
|
| 163 |
+
)
|
| 164 |
+
m.weight.data.copy_(w)
|
| 165 |
+
m.bias.data.copy_(b)
|
| 166 |
+
return m
|
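`Conv2d_BN.fuse` folds the BatchNorm statistics into the convolution weight and bias, producing a single `Conv2d` for inference. A sketch of the equivalence check (eval mode, so BN uses its running statistics):

```python
import torch

m = Conv2d_BN(8, 16, ks=3, pad=1).eval()   # eval: BN uses running statistics
fused = m.fuse()                            # plain Conv2d with BN folded into weight/bias

x = torch.randn(2, 8, 32, 32)
with torch.no_grad():
    print(torch.allclose(m(x), fused(x), atol=1e-5))   # True
```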
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class DropPath(TimmDropPath):
|
| 170 |
+
def __init__(self, drop_prob=None):
|
| 171 |
+
super().__init__(drop_prob=drop_prob)
|
| 172 |
+
self.drop_prob = drop_prob
|
| 173 |
+
|
| 174 |
+
def __repr__(self):
|
| 175 |
+
msg = super().__repr__()
|
| 176 |
+
msg += f"(drop_prob={self.drop_prob})"
|
| 177 |
+
return msg
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class PatchEmbed(nn.Module):
|
| 181 |
+
def __init__(self, in_chans, embed_dim, resolution, activation):
|
| 182 |
+
super().__init__()
|
| 183 |
+
img_size: Tuple[int, int] = to_2tuple(resolution)
|
| 184 |
+
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
| 185 |
+
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
|
| 186 |
+
self.in_chans = in_chans
|
| 187 |
+
self.embed_dim = embed_dim
|
| 188 |
+
n = embed_dim
|
| 189 |
+
self.seq = nn.Sequential(
|
| 190 |
+
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
|
| 191 |
+
activation(),
|
| 192 |
+
Conv2d_BN(n // 2, n, 3, 2, 1),
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
def forward(self, x):
|
| 196 |
+
return self.seq(x)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class MBConv(nn.Module):
|
| 200 |
+
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.in_chans = in_chans
|
| 203 |
+
self.hidden_chans = int(in_chans * expand_ratio)
|
| 204 |
+
self.out_chans = out_chans
|
| 205 |
+
|
| 206 |
+
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
|
| 207 |
+
self.act1 = activation()
|
| 208 |
+
|
| 209 |
+
self.conv2 = Conv2d_BN(
|
| 210 |
+
self.hidden_chans,
|
| 211 |
+
self.hidden_chans,
|
| 212 |
+
ks=3,
|
| 213 |
+
stride=1,
|
| 214 |
+
pad=1,
|
| 215 |
+
groups=self.hidden_chans,
|
| 216 |
+
)
|
| 217 |
+
self.act2 = activation()
|
| 218 |
+
|
| 219 |
+
self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
|
| 220 |
+
self.act3 = activation()
|
| 221 |
+
|
| 222 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 223 |
+
|
| 224 |
+
def forward(self, x):
|
| 225 |
+
shortcut = x
|
| 226 |
+
|
| 227 |
+
x = self.conv1(x)
|
| 228 |
+
x = self.act1(x)
|
| 229 |
+
|
| 230 |
+
x = self.conv2(x)
|
| 231 |
+
x = self.act2(x)
|
| 232 |
+
|
| 233 |
+
x = self.conv3(x)
|
| 234 |
+
|
| 235 |
+
x = self.drop_path(x)
|
| 236 |
+
|
| 237 |
+
x += shortcut
|
| 238 |
+
x = self.act3(x)
|
| 239 |
+
|
| 240 |
+
return x
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class PatchMerging(nn.Module):
|
| 244 |
+
def __init__(self, input_resolution, dim, out_dim, activation):
|
| 245 |
+
super().__init__()
|
| 246 |
+
|
| 247 |
+
self.input_resolution = input_resolution
|
| 248 |
+
self.dim = dim
|
| 249 |
+
self.out_dim = out_dim
|
| 250 |
+
self.act = activation()
|
| 251 |
+
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
| 252 |
+
stride_c = 2
|
| 253 |
+
if out_dim == 320 or out_dim == 448 or out_dim == 576:
|
| 254 |
+
stride_c = 1
|
| 255 |
+
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
| 256 |
+
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
| 257 |
+
|
| 258 |
+
def forward(self, x):
|
| 259 |
+
if x.ndim == 3:
|
| 260 |
+
H, W = self.input_resolution
|
| 261 |
+
B = len(x)
|
| 262 |
+
# (B, C, H, W)
|
| 263 |
+
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
| 264 |
+
|
| 265 |
+
x = self.conv1(x)
|
| 266 |
+
x = self.act(x)
|
| 267 |
+
|
| 268 |
+
x = self.conv2(x)
|
| 269 |
+
x = self.act(x)
|
| 270 |
+
x = self.conv3(x)
|
| 271 |
+
x = x.flatten(2).transpose(1, 2)
|
| 272 |
+
return x
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class ConvLayer(nn.Module):
|
| 276 |
+
def __init__(
|
| 277 |
+
self,
|
| 278 |
+
dim,
|
| 279 |
+
input_resolution,
|
| 280 |
+
depth,
|
| 281 |
+
activation,
|
| 282 |
+
drop_path=0.0,
|
| 283 |
+
downsample=None,
|
| 284 |
+
use_checkpoint=False,
|
| 285 |
+
out_dim=None,
|
| 286 |
+
conv_expand_ratio=4.0,
|
| 287 |
+
):
|
| 288 |
+
super().__init__()
|
| 289 |
+
self.dim = dim
|
| 290 |
+
self.input_resolution = input_resolution
|
| 291 |
+
self.depth = depth
|
| 292 |
+
self.use_checkpoint = use_checkpoint
|
| 293 |
+
|
| 294 |
+
# build blocks
|
| 295 |
+
self.blocks = nn.ModuleList(
|
| 296 |
+
[
|
| 297 |
+
MBConv(
|
| 298 |
+
dim,
|
| 299 |
+
dim,
|
| 300 |
+
conv_expand_ratio,
|
| 301 |
+
activation,
|
| 302 |
+
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 303 |
+
)
|
| 304 |
+
for i in range(depth)
|
| 305 |
+
]
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# patch merging layer
|
| 309 |
+
if downsample is not None:
|
| 310 |
+
self.downsample = downsample(
|
| 311 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
self.downsample = None
|
| 315 |
+
|
| 316 |
+
def forward(self, x):
|
| 317 |
+
for blk in self.blocks:
|
| 318 |
+
if self.use_checkpoint:
|
| 319 |
+
x = checkpoint.checkpoint(blk, x)
|
| 320 |
+
else:
|
| 321 |
+
x = blk(x)
|
| 322 |
+
if self.downsample is not None:
|
| 323 |
+
x = self.downsample(x)
|
| 324 |
+
return x
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class Mlp(nn.Module):
|
| 328 |
+
def __init__(
|
| 329 |
+
self,
|
| 330 |
+
in_features,
|
| 331 |
+
hidden_features=None,
|
| 332 |
+
out_features=None,
|
| 333 |
+
act_layer=nn.GELU,
|
| 334 |
+
drop=0.0,
|
| 335 |
+
):
|
| 336 |
+
super().__init__()
|
| 337 |
+
out_features = out_features or in_features
|
| 338 |
+
hidden_features = hidden_features or in_features
|
| 339 |
+
self.norm = nn.LayerNorm(in_features)
|
| 340 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 341 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 342 |
+
self.act = act_layer()
|
| 343 |
+
self.drop = nn.Dropout(drop)
|
| 344 |
+
|
| 345 |
+
def forward(self, x):
|
| 346 |
+
x = self.norm(x)
|
| 347 |
+
|
| 348 |
+
x = self.fc1(x)
|
| 349 |
+
x = self.act(x)
|
| 350 |
+
x = self.drop(x)
|
| 351 |
+
x = self.fc2(x)
|
| 352 |
+
x = self.drop(x)
|
| 353 |
+
return x
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class Attention(torch.nn.Module):
|
| 357 |
+
def __init__(
|
| 358 |
+
self,
|
| 359 |
+
dim,
|
| 360 |
+
key_dim,
|
| 361 |
+
num_heads=8,
|
| 362 |
+
attn_ratio=4,
|
| 363 |
+
resolution=(14, 14),
|
| 364 |
+
):
|
| 365 |
+
super().__init__()
|
| 366 |
+
# (h, w)
|
| 367 |
+
assert isinstance(resolution, tuple) and len(resolution) == 2
|
| 368 |
+
self.num_heads = num_heads
|
| 369 |
+
self.scale = key_dim**-0.5
|
| 370 |
+
self.key_dim = key_dim
|
| 371 |
+
self.nh_kd = nh_kd = key_dim * num_heads
|
| 372 |
+
self.d = int(attn_ratio * key_dim)
|
| 373 |
+
self.dh = int(attn_ratio * key_dim) * num_heads
|
| 374 |
+
self.attn_ratio = attn_ratio
|
| 375 |
+
h = self.dh + nh_kd * 2
|
| 376 |
+
|
| 377 |
+
self.norm = nn.LayerNorm(dim)
|
| 378 |
+
self.qkv = nn.Linear(dim, h)
|
| 379 |
+
self.proj = nn.Linear(self.dh, dim)
|
| 380 |
+
|
| 381 |
+
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
|
| 382 |
+
N = len(points)
|
| 383 |
+
attention_offsets = {}
|
| 384 |
+
idxs = []
|
| 385 |
+
for p1 in points:
|
| 386 |
+
for p2 in points:
|
| 387 |
+
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
| 388 |
+
if offset not in attention_offsets:
|
| 389 |
+
attention_offsets[offset] = len(attention_offsets)
|
| 390 |
+
idxs.append(attention_offsets[offset])
|
| 391 |
+
self.attention_biases = torch.nn.Parameter(
|
| 392 |
+
torch.zeros(num_heads, len(attention_offsets))
|
| 393 |
+
)
|
| 394 |
+
self.register_buffer(
|
| 395 |
+
"attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
@torch.no_grad()
|
| 399 |
+
def train(self, mode=True):
|
| 400 |
+
super().train(mode)
|
| 401 |
+
if mode and hasattr(self, "ab"):
|
| 402 |
+
del self.ab
|
| 403 |
+
else:
|
| 404 |
+
self.register_buffer(
|
| 405 |
+
"ab",
|
| 406 |
+
self.attention_biases[:, self.attention_bias_idxs],
|
| 407 |
+
persistent=False,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
def forward(self, x): # x (B,N,C)
|
| 411 |
+
B, N, _ = x.shape
|
| 412 |
+
|
| 413 |
+
# Normalization
|
| 414 |
+
x = self.norm(x)
|
| 415 |
+
|
| 416 |
+
qkv = self.qkv(x)
|
| 417 |
+
# (B, N, num_heads, d)
|
| 418 |
+
q, k, v = qkv.view(B, N, self.num_heads, -1).split(
|
| 419 |
+
[self.key_dim, self.key_dim, self.d], dim=3
|
| 420 |
+
)
|
| 421 |
+
# (B, num_heads, N, d)
|
| 422 |
+
q = q.permute(0, 2, 1, 3)
|
| 423 |
+
k = k.permute(0, 2, 1, 3)
|
| 424 |
+
v = v.permute(0, 2, 1, 3)
|
| 425 |
+
|
| 426 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale + (
|
| 427 |
+
self.attention_biases[:, self.attention_bias_idxs]
|
| 428 |
+
if self.training
|
| 429 |
+
else self.ab
|
| 430 |
+
)
|
| 431 |
+
attn = attn.softmax(dim=-1)
|
| 432 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
| 433 |
+
x = self.proj(x)
|
| 434 |
+
return x
|
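The constructor above learns one bias per (head, relative offset) pair rather than per token pair: `attention_bias_idxs` maps every (p1, p2) combination to its absolute offset. A sketch of the counting for a 7×7 window:

```python
import itertools

# For a 7x7 window there are 49*49 token pairs but only 7*7 = 49 distinct
# absolute offsets (|dy|, |dx|), so the learned bias table has 49 entries
# per head — the same counting the constructor above performs.
resolution = (7, 7)
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
offsets = {(abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) for p1 in points for p2 in points}
print(len(points) ** 2, len(offsets))   # 2401 pairs, 49 unique offsets
```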
| 435 |
+
|
| 436 |
+
|
| 437 |
+
class TinyViTBlock(nn.Module):
|
| 438 |
+
r"""TinyViT Block.
|
| 439 |
+
|
| 440 |
+
Args:
|
| 441 |
+
dim (int): Number of input channels.
|
| 442 |
+
input_resolution (tuple[int, int]): Input resolution.
|
| 443 |
+
num_heads (int): Number of attention heads.
|
| 444 |
+
window_size (int): Window size.
|
| 445 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 446 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 447 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 448 |
+
local_conv_size (int): the kernel size of the convolution between
|
| 449 |
+
Attention and MLP. Default: 3
|
| 450 |
+
activation: the activation function. Default: nn.GELU
|
| 451 |
+
"""
|
| 452 |
+
|
| 453 |
+
def __init__(
|
| 454 |
+
self,
|
| 455 |
+
dim,
|
| 456 |
+
input_resolution,
|
| 457 |
+
num_heads,
|
| 458 |
+
window_size=7,
|
| 459 |
+
mlp_ratio=4.0,
|
| 460 |
+
drop=0.0,
|
| 461 |
+
drop_path=0.0,
|
| 462 |
+
local_conv_size=3,
|
| 463 |
+
activation=nn.GELU,
|
| 464 |
+
):
|
| 465 |
+
super().__init__()
|
| 466 |
+
self.dim = dim
|
| 467 |
+
self.input_resolution = input_resolution
|
| 468 |
+
self.num_heads = num_heads
|
| 469 |
+
assert window_size > 0, "window_size must be greater than 0"
|
| 470 |
+
self.window_size = window_size
|
| 471 |
+
self.mlp_ratio = mlp_ratio
|
| 472 |
+
|
| 473 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 474 |
+
|
| 475 |
+
assert dim % num_heads == 0, "dim must be divisible by num_heads"
|
| 476 |
+
head_dim = dim // num_heads
|
| 477 |
+
|
| 478 |
+
window_resolution = (window_size, window_size)
|
| 479 |
+
self.attn = Attention(
|
| 480 |
+
dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 484 |
+
mlp_activation = activation
|
| 485 |
+
self.mlp = Mlp(
|
| 486 |
+
in_features=dim,
|
| 487 |
+
hidden_features=mlp_hidden_dim,
|
| 488 |
+
act_layer=mlp_activation,
|
| 489 |
+
drop=drop,
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
pad = local_conv_size // 2
|
| 493 |
+
self.local_conv = Conv2d_BN(
|
| 494 |
+
dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
def forward(self, x):
|
| 498 |
+
H, W = self.input_resolution
|
| 499 |
+
B, L, C = x.shape
|
| 500 |
+
assert L == H * W, "input feature has wrong size"
|
| 501 |
+
res_x = x
|
| 502 |
+
if H == self.window_size and W == self.window_size:
|
| 503 |
+
x = self.attn(x)
|
| 504 |
+
else:
|
| 505 |
+
x = x.view(B, H, W, C)
|
| 506 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
| 507 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
| 508 |
+
padding = pad_b > 0 or pad_r > 0
|
| 509 |
+
|
| 510 |
+
if padding:
|
| 511 |
+
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
|
| 512 |
+
|
| 513 |
+
pH, pW = H + pad_b, W + pad_r
|
| 514 |
+
nH = pH // self.window_size
|
| 515 |
+
nW = pW // self.window_size
|
| 516 |
+
# window partition
|
| 517 |
+
x = (
|
| 518 |
+
x.view(B, nH, self.window_size, nW, self.window_size, C)
|
| 519 |
+
.transpose(2, 3)
|
| 520 |
+
.reshape(B * nH * nW, self.window_size * self.window_size, C)
|
| 521 |
+
)
|
| 522 |
+
x = self.attn(x)
|
| 523 |
+
# window reverse
|
| 524 |
+
x = (
|
| 525 |
+
x.view(B, nH, nW, self.window_size, self.window_size, C)
|
| 526 |
+
.transpose(2, 3)
|
| 527 |
+
.reshape(B, pH, pW, C)
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
if padding:
|
| 531 |
+
x = x[:, :H, :W].contiguous()
|
| 532 |
+
|
| 533 |
+
x = x.view(B, L, C)
|
| 534 |
+
|
| 535 |
+
x = res_x + self.drop_path(x)
|
| 536 |
+
|
| 537 |
+
x = x.transpose(1, 2).reshape(B, C, H, W)
|
| 538 |
+
x = self.local_conv(x)
|
| 539 |
+
x = x.view(B, C, L).transpose(1, 2)
|
| 540 |
+
|
| 541 |
+
x = x + self.drop_path(self.mlp(x))
|
| 542 |
+
return x
|
| 543 |
+
|
| 544 |
+
def extra_repr(self) -> str:
|
| 545 |
+
return (
|
| 546 |
+
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
| 547 |
+
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
|
| 548 |
+
)
|
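The reshape/transpose pair in `forward` is a standard non-overlapping window partition and its inverse. A standalone sketch with shape comments and a round-trip check (helper names are illustrative):

```python
import torch

def window_partition(x, ws):
    """(B, H, W, C) -> (B*nH*nW, ws*ws, C); H and W assumed divisible by ws."""
    B, H, W, C = x.shape
    nH, nW = H // ws, W // ws
    return (x.view(B, nH, ws, nW, ws, C)
             .transpose(2, 3)
             .reshape(B * nH * nW, ws * ws, C))

def window_reverse(x, ws, B, H, W, C):
    """Inverse of window_partition."""
    nH, nW = H // ws, W // ws
    return (x.view(B, nH, nW, ws, ws, C)
             .transpose(2, 3)
             .reshape(B, H, W, C))

x = torch.randn(2, 14, 14, 64)
w = window_partition(x, 7)                                     # (8, 49, 64)
print(torch.equal(window_reverse(w, 7, 2, 14, 14, 64), x))     # True
```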
| 549 |
+
|
| 550 |
+
|
| 551 |
+
class BasicLayer(nn.Module):
|
| 552 |
+
"""A basic TinyViT layer for one stage.
|
| 553 |
+
|
| 554 |
+
Args:
|
| 555 |
+
dim (int): Number of input channels.
|
| 556 |
+
input_resolution (tuple[int]): Input resolution.
|
| 557 |
+
depth (int): Number of blocks.
|
| 558 |
+
num_heads (int): Number of attention heads.
|
| 559 |
+
window_size (int): Local window size.
|
| 560 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 561 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 562 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 563 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 564 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 565 |
+
local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
|
| 566 |
+
activation: the activation function. Default: nn.GELU
|
| 567 |
+
out_dim: the output dimension of the layer. Default: dim
|
| 568 |
+
"""
|
| 569 |
+
|
| 570 |
+
def __init__(
|
| 571 |
+
self,
|
| 572 |
+
dim,
|
| 573 |
+
input_resolution,
|
| 574 |
+
depth,
|
| 575 |
+
num_heads,
|
| 576 |
+
window_size,
|
| 577 |
+
mlp_ratio=4.0,
|
| 578 |
+
drop=0.0,
|
| 579 |
+
drop_path=0.0,
|
| 580 |
+
downsample=None,
|
| 581 |
+
use_checkpoint=False,
|
| 582 |
+
local_conv_size=3,
|
| 583 |
+
activation=nn.GELU,
|
| 584 |
+
out_dim=None,
|
| 585 |
+
):
|
| 586 |
+
super().__init__()
|
| 587 |
+
self.dim = dim
|
| 588 |
+
self.input_resolution = input_resolution
|
| 589 |
+
self.depth = depth
|
| 590 |
+
self.use_checkpoint = use_checkpoint
|
| 591 |
+
|
| 592 |
+
# build blocks
|
| 593 |
+
self.blocks = nn.ModuleList(
|
| 594 |
+
[
|
| 595 |
+
TinyViTBlock(
|
| 596 |
+
dim=dim,
|
| 597 |
+
input_resolution=input_resolution,
|
| 598 |
+
num_heads=num_heads,
|
| 599 |
+
window_size=window_size,
|
| 600 |
+
mlp_ratio=mlp_ratio,
|
| 601 |
+
drop=drop,
|
| 602 |
+
drop_path=drop_path[i]
|
| 603 |
+
if isinstance(drop_path, list)
|
| 604 |
+
else drop_path,
|
| 605 |
+
local_conv_size=local_conv_size,
|
| 606 |
+
activation=activation,
|
| 607 |
+
)
|
| 608 |
+
for i in range(depth)
|
| 609 |
+
]
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
# patch merging layer
|
| 613 |
+
if downsample is not None:
|
| 614 |
+
self.downsample = downsample(
|
| 615 |
+
input_resolution, dim=dim, out_dim=out_dim, activation=activation
|
| 616 |
+
)
|
| 617 |
+
else:
|
| 618 |
+
self.downsample = None
|
| 619 |
+
|
| 620 |
+
def forward(self, x):
|
| 621 |
+
for blk in self.blocks:
|
| 622 |
+
if self.use_checkpoint:
|
| 623 |
+
x = checkpoint.checkpoint(blk, x)
|
| 624 |
+
else:
|
| 625 |
+
x = blk(x)
|
| 626 |
+
if self.downsample is not None:
|
| 627 |
+
x = self.downsample(x)
|
| 628 |
+
return x
|
| 629 |
+
|
| 630 |
+
def extra_repr(self) -> str:
|
| 631 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
class LayerNorm2d(nn.Module):
|
| 635 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
| 636 |
+
super().__init__()
|
| 637 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
| 638 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
| 639 |
+
self.eps = eps
|
| 640 |
+
|
| 641 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 642 |
+
u = x.mean(1, keepdim=True)
|
| 643 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 644 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 645 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 646 |
+
return x
|
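`LayerNorm2d` normalizes over the channel dimension of an NCHW tensor; it is numerically the same as permuting to channels-last and applying `nn.LayerNorm`. A quick equivalence sketch, assuming this module is imported:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 256, 64, 64)

ln2d = LayerNorm2d(256)
ref = nn.LayerNorm(256, eps=1e-6)          # normalizes the last dimension

a = ln2d(x)
b = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(a, b, atol=1e-5))     # True
```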
| 647 |
+
|
| 648 |
+
|
| 649 |
+
class TinyViT(nn.Module):
|
| 650 |
+
def __init__(
|
| 651 |
+
self,
|
| 652 |
+
img_size=224,
|
| 653 |
+
in_chans=3,
|
| 654 |
+
num_classes=1000,
|
| 655 |
+
embed_dims=[96, 192, 384, 768],
|
| 656 |
+
depths=[2, 2, 6, 2],
|
| 657 |
+
num_heads=[3, 6, 12, 24],
|
| 658 |
+
window_sizes=[7, 7, 14, 7],
|
| 659 |
+
mlp_ratio=4.0,
|
| 660 |
+
drop_rate=0.0,
|
| 661 |
+
drop_path_rate=0.1,
|
| 662 |
+
use_checkpoint=False,
|
| 663 |
+
mbconv_expand_ratio=4.0,
|
| 664 |
+
local_conv_size=3,
|
| 665 |
+
layer_lr_decay=1.0,
|
| 666 |
+
):
|
| 667 |
+
super().__init__()
|
| 668 |
+
self.img_size = img_size
|
| 669 |
+
self.num_classes = num_classes
|
| 670 |
+
self.depths = depths
|
| 671 |
+
self.num_layers = len(depths)
|
| 672 |
+
self.mlp_ratio = mlp_ratio
|
| 673 |
+
|
| 674 |
+
activation = nn.GELU
|
| 675 |
+
|
| 676 |
+
self.patch_embed = PatchEmbed(
|
| 677 |
+
in_chans=in_chans,
|
| 678 |
+
embed_dim=embed_dims[0],
|
| 679 |
+
resolution=img_size,
|
| 680 |
+
activation=activation,
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
patches_resolution = self.patch_embed.patches_resolution
|
| 684 |
+
self.patches_resolution = patches_resolution
|
| 685 |
+
|
| 686 |
+
# stochastic depth
|
| 687 |
+
dpr = [
|
| 688 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
| 689 |
+
] # stochastic depth decay rule
|
| 690 |
+
|
| 691 |
+
# build layers
|
| 692 |
+
self.layers = nn.ModuleList()
|
| 693 |
+
for i_layer in range(self.num_layers):
|
| 694 |
+
kwargs = dict(
|
| 695 |
+
dim=embed_dims[i_layer],
|
| 696 |
+
input_resolution=(
|
| 697 |
+
patches_resolution[0]
|
| 698 |
+
// (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
| 699 |
+
patches_resolution[1]
|
| 700 |
+
// (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
| 701 |
+
),
|
| 702 |
+
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
| 703 |
+
# patches_resolution[1] // (2 ** i_layer)),
|
| 704 |
+
depth=depths[i_layer],
|
| 705 |
+
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
| 706 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
| 707 |
+
use_checkpoint=use_checkpoint,
|
| 708 |
+
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
|
| 709 |
+
activation=activation,
|
| 710 |
+
)
|
| 711 |
+
if i_layer == 0:
|
| 712 |
+
layer = ConvLayer(
|
| 713 |
+
conv_expand_ratio=mbconv_expand_ratio,
|
| 714 |
+
**kwargs,
|
| 715 |
+
)
|
| 716 |
+
else:
|
| 717 |
+
layer = BasicLayer(
|
| 718 |
+
num_heads=num_heads[i_layer],
|
| 719 |
+
window_size=window_sizes[i_layer],
|
| 720 |
+
mlp_ratio=self.mlp_ratio,
|
| 721 |
+
drop=drop_rate,
|
| 722 |
+
local_conv_size=local_conv_size,
|
| 723 |
+
**kwargs,
|
| 724 |
+
)
|
| 725 |
+
self.layers.append(layer)
|
| 726 |
+
|
| 727 |
+
# Classifier head
|
| 728 |
+
self.norm_head = nn.LayerNorm(embed_dims[-1])
|
| 729 |
+
self.head = (
|
| 730 |
+
nn.Linear(embed_dims[-1], num_classes)
|
| 731 |
+
if num_classes > 0
|
| 732 |
+
else torch.nn.Identity()
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
# init weights
|
| 736 |
+
self.apply(self._init_weights)
|
| 737 |
+
self.set_layer_lr_decay(layer_lr_decay)
|
| 738 |
+
self.neck = nn.Sequential(
|
| 739 |
+
nn.Conv2d(
|
| 740 |
+
embed_dims[-1],
|
| 741 |
+
256,
|
| 742 |
+
kernel_size=1,
|
| 743 |
+
bias=False,
|
| 744 |
+
),
|
| 745 |
+
LayerNorm2d(256),
|
| 746 |
+
nn.Conv2d(
|
| 747 |
+
256,
|
| 748 |
+
256,
|
| 749 |
+
kernel_size=3,
|
| 750 |
+
padding=1,
|
| 751 |
+
bias=False,
|
| 752 |
+
),
|
| 753 |
+
LayerNorm2d(256),
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
def set_layer_lr_decay(self, layer_lr_decay):
|
| 757 |
+
decay_rate = layer_lr_decay
|
| 758 |
+
|
| 759 |
+
# layers -> blocks (depth)
|
| 760 |
+
depth = sum(self.depths)
|
| 761 |
+
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
|
| 762 |
+
# print("LR SCALES:", lr_scales)
|
| 763 |
+
|
| 764 |
+
def _set_lr_scale(m, scale):
|
| 765 |
+
for p in m.parameters():
|
| 766 |
+
p.lr_scale = scale
|
| 767 |
+
|
| 768 |
+
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
|
| 769 |
+
i = 0
|
| 770 |
+
for layer in self.layers:
|
| 771 |
+
for block in layer.blocks:
|
| 772 |
+
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
|
| 773 |
+
i += 1
|
| 774 |
+
if layer.downsample is not None:
|
| 775 |
+
layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
|
| 776 |
+
assert i == depth
|
| 777 |
+
for m in [self.norm_head, self.head]:
|
| 778 |
+
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
|
| 779 |
+
|
| 780 |
+
for k, p in self.named_parameters():
|
| 781 |
+
p.param_name = k
|
| 782 |
+
|
| 783 |
+
def _check_lr_scale(m):
|
| 784 |
+
for p in m.parameters():
|
| 785 |
+
assert hasattr(p, "lr_scale"), p.param_name
|
| 786 |
+
|
| 787 |
+
self.apply(_check_lr_scale)
|
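`set_layer_lr_decay` only tags each parameter with an `lr_scale` attribute; nothing in this file consumes the tag. A hedged sketch of how a training script could turn those tags into per-parameter optimizer groups (hypothetical helper, not part of this commit):

```python
import torch

def build_param_groups(model, base_lr=1e-3, weight_decay=0.05):
    """Hypothetical helper: turn the `lr_scale` tags set above into
    per-parameter optimizer groups, one group per distinct scale."""
    groups = {}
    for p in model.parameters():
        if not p.requires_grad:
            continue
        scale = getattr(p, "lr_scale", 1.0)
        groups.setdefault(scale, []).append(p)
    return [
        {"params": params, "lr": base_lr * scale, "weight_decay": weight_decay}
        for scale, params in groups.items()
    ]

# optimizer = torch.optim.AdamW(build_param_groups(tiny_vit), lr=1e-3)
```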
| 788 |
+
|
| 789 |
+
def _init_weights(self, m):
|
| 790 |
+
if isinstance(m, nn.Linear):
|
| 791 |
+
trunc_normal_(m.weight, std=0.02)
|
| 792 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 793 |
+
nn.init.constant_(m.bias, 0)
|
| 794 |
+
elif isinstance(m, nn.LayerNorm):
|
| 795 |
+
nn.init.constant_(m.bias, 0)
|
| 796 |
+
nn.init.constant_(m.weight, 1.0)
|
| 797 |
+
|
| 798 |
+
@torch.jit.ignore
|
| 799 |
+
def no_weight_decay_keywords(self):
|
| 800 |
+
return {"attention_biases"}
|
| 801 |
+
|
| 802 |
+
def forward_features(self, x):
|
| 803 |
+
# x: (N, C, H, W)
|
| 804 |
+
x = self.patch_embed(x)
|
| 805 |
+
|
| 806 |
+
x = self.layers[0](x)
|
| 807 |
+
start_i = 1
|
| 808 |
+
|
| 809 |
+
for i in range(start_i, len(self.layers)):
|
| 810 |
+
layer = self.layers[i]
|
| 811 |
+
x = layer(x)
|
| 812 |
+
B, _, C = x.size()
|
| 813 |
+
x = x.view(B, 64, 64, C)
|
| 814 |
+
x = x.permute(0, 3, 1, 2)
|
| 815 |
+
x = self.neck(x)
|
| 816 |
+
return x
|
| 817 |
+
|
| 818 |
+
def forward(self, x):
|
| 819 |
+
x = self.forward_features(x)
|
| 820 |
+
# x = self.norm_head(x)
|
| 821 |
+
# x = self.head(x)
|
| 822 |
+
return x
|
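In this plugin TinyViT acts as a lightweight SAM image encoder: `forward_features` reshapes the final tokens to a 64×64 grid and the `neck` projects them to 256 channels, the feature map the SAM mask decoder consumes. A shape sketch with MobileSAM-style hyper-parameters (the specific dims/heads below are assumptions, not read from this commit):

```python
import torch

encoder = TinyViT(
    img_size=1024,
    embed_dims=[64, 128, 160, 320],
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 5, 10],
    window_sizes=[7, 7, 14, 7],
)
with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 1024, 1024))
print(feats.shape)   # torch.Size([1, 256, 64, 64])
```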
iopaint/plugins/segment_anything/modeling/transformer.py
ADDED
|
@@ -0,0 +1,240 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor, nn
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
from typing import Tuple, Type
|
| 12 |
+
|
| 13 |
+
from .common import MLPBlock
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TwoWayTransformer(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
depth: int,
|
| 20 |
+
embedding_dim: int,
|
| 21 |
+
num_heads: int,
|
| 22 |
+
mlp_dim: int,
|
| 23 |
+
activation: Type[nn.Module] = nn.ReLU,
|
| 24 |
+
attention_downsample_rate: int = 2,
|
| 25 |
+
) -> None:
|
| 26 |
+
"""
|
| 27 |
+
A transformer decoder that attends to an input image using
|
| 28 |
+
queries whose positional embedding is supplied.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
depth (int): number of layers in the transformer
|
| 32 |
+
embedding_dim (int): the channel dimension for the input embeddings
|
| 33 |
+
num_heads (int): the number of heads for multihead attention. Must
|
| 34 |
+
divide embedding_dim
|
| 35 |
+
mlp_dim (int): the channel dimension internal to the MLP block
|
| 36 |
+
activation (nn.Module): the activation to use in the MLP block
|
| 37 |
+
"""
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.depth = depth
|
| 40 |
+
self.embedding_dim = embedding_dim
|
| 41 |
+
self.num_heads = num_heads
|
| 42 |
+
self.mlp_dim = mlp_dim
|
| 43 |
+
self.layers = nn.ModuleList()
|
| 44 |
+
|
| 45 |
+
for i in range(depth):
|
| 46 |
+
self.layers.append(
|
| 47 |
+
TwoWayAttentionBlock(
|
| 48 |
+
embedding_dim=embedding_dim,
|
| 49 |
+
num_heads=num_heads,
|
| 50 |
+
mlp_dim=mlp_dim,
|
| 51 |
+
activation=activation,
|
| 52 |
+
attention_downsample_rate=attention_downsample_rate,
|
| 53 |
+
skip_first_layer_pe=(i == 0),
|
| 54 |
+
)
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
self.final_attn_token_to_image = Attention(
|
| 58 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
| 59 |
+
)
|
| 60 |
+
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
| 61 |
+
|
| 62 |
+
def forward(
|
| 63 |
+
self,
|
| 64 |
+
image_embedding: Tensor,
|
| 65 |
+
image_pe: Tensor,
|
| 66 |
+
point_embedding: Tensor,
|
| 67 |
+
) -> Tuple[Tensor, Tensor]:
|
| 68 |
+
"""
|
| 69 |
+
Args:
|
| 70 |
+
image_embedding (torch.Tensor): image to attend to. Should be shape
|
| 71 |
+
B x embedding_dim x h x w for any h and w.
|
| 72 |
+
image_pe (torch.Tensor): the positional encoding to add to the image. Must
|
| 73 |
+
have the same shape as image_embedding.
|
| 74 |
+
point_embedding (torch.Tensor): the embedding to add to the query points.
|
| 75 |
+
Must have shape B x N_points x embedding_dim for any N_points.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
torch.Tensor: the processed point_embedding
|
| 79 |
+
torch.Tensor: the processed image_embedding
|
| 80 |
+
"""
|
| 81 |
+
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
| 82 |
+
bs, c, h, w = image_embedding.shape
|
| 83 |
+
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
| 84 |
+
image_pe = image_pe.flatten(2).permute(0, 2, 1)
|
| 85 |
+
|
| 86 |
+
# Prepare queries
|
| 87 |
+
queries = point_embedding
|
| 88 |
+
keys = image_embedding
|
| 89 |
+
|
| 90 |
+
# Apply transformer blocks and final layernorm
|
| 91 |
+
for layer in self.layers:
|
| 92 |
+
queries, keys = layer(
|
| 93 |
+
queries=queries,
|
| 94 |
+
keys=keys,
|
| 95 |
+
query_pe=point_embedding,
|
| 96 |
+
key_pe=image_pe,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Apply the final attention layer from the points to the image
|
| 100 |
+
q = queries + point_embedding
|
| 101 |
+
k = keys + image_pe
|
| 102 |
+
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
|
| 103 |
+
queries = queries + attn_out
|
| 104 |
+
queries = self.norm_final_attn(queries)
|
| 105 |
+
|
| 106 |
+
return queries, keys
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class TwoWayAttentionBlock(nn.Module):
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
embedding_dim: int,
|
| 113 |
+
num_heads: int,
|
| 114 |
+
mlp_dim: int = 2048,
|
| 115 |
+
activation: Type[nn.Module] = nn.ReLU,
|
| 116 |
+
attention_downsample_rate: int = 2,
|
| 117 |
+
skip_first_layer_pe: bool = False,
|
| 118 |
+
) -> None:
|
| 119 |
+
"""
|
| 120 |
+
A transformer block with four layers: (1) self-attention of sparse
|
| 121 |
+
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
| 122 |
+
block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
| 123 |
+
inputs.
|
| 124 |
+
|
| 125 |
+
Arguments:
|
| 126 |
+
embedding_dim (int): the channel dimension of the embeddings
|
| 127 |
+
num_heads (int): the number of heads in the attention layers
|
| 128 |
+
mlp_dim (int): the hidden dimension of the mlp block
|
| 129 |
+
activation (nn.Module): the activation of the mlp block
|
| 130 |
+
skip_first_layer_pe (bool): skip the PE on the first layer
|
| 131 |
+
"""
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.self_attn = Attention(embedding_dim, num_heads)
|
| 134 |
+
self.norm1 = nn.LayerNorm(embedding_dim)
|
| 135 |
+
|
| 136 |
+
self.cross_attn_token_to_image = Attention(
|
| 137 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
| 138 |
+
)
|
| 139 |
+
self.norm2 = nn.LayerNorm(embedding_dim)
|
| 140 |
+
|
| 141 |
+
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
|
| 142 |
+
self.norm3 = nn.LayerNorm(embedding_dim)
|
| 143 |
+
|
| 144 |
+
self.norm4 = nn.LayerNorm(embedding_dim)
|
| 145 |
+
self.cross_attn_image_to_token = Attention(
|
| 146 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self.skip_first_layer_pe = skip_first_layer_pe
|
| 150 |
+
|
| 151 |
+
def forward(
|
| 152 |
+
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
|
| 153 |
+
) -> Tuple[Tensor, Tensor]:
|
| 154 |
+
# Self attention block
|
| 155 |
+
if self.skip_first_layer_pe:
|
| 156 |
+
queries = self.self_attn(q=queries, k=queries, v=queries)
|
| 157 |
+
else:
|
| 158 |
+
q = queries + query_pe
|
| 159 |
+
attn_out = self.self_attn(q=q, k=q, v=queries)
|
| 160 |
+
queries = queries + attn_out
|
| 161 |
+
queries = self.norm1(queries)
|
| 162 |
+
|
| 163 |
+
# Cross attention block, tokens attending to image embedding
|
| 164 |
+
q = queries + query_pe
|
| 165 |
+
k = keys + key_pe
|
| 166 |
+
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
|
| 167 |
+
queries = queries + attn_out
|
| 168 |
+
queries = self.norm2(queries)
|
| 169 |
+
|
| 170 |
+
# MLP block
|
| 171 |
+
mlp_out = self.mlp(queries)
|
| 172 |
+
queries = queries + mlp_out
|
| 173 |
+
queries = self.norm3(queries)
|
| 174 |
+
|
| 175 |
+
# Cross attention block, image embedding attending to tokens
|
| 176 |
+
q = queries + query_pe
|
| 177 |
+
k = keys + key_pe
|
| 178 |
+
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
|
| 179 |
+
keys = keys + attn_out
|
| 180 |
+
keys = self.norm4(keys)
|
| 181 |
+
|
| 182 |
+
return queries, keys
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class Attention(nn.Module):
|
| 186 |
+
"""
|
| 187 |
+
An attention layer that allows for downscaling the size of the embedding
|
| 188 |
+
after projection to queries, keys, and values.
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
def __init__(
|
| 192 |
+
self,
|
| 193 |
+
embedding_dim: int,
|
| 194 |
+
num_heads: int,
|
| 195 |
+
downsample_rate: int = 1,
|
| 196 |
+
) -> None:
|
| 197 |
+
super().__init__()
|
| 198 |
+
self.embedding_dim = embedding_dim
|
| 199 |
+
self.internal_dim = embedding_dim // downsample_rate
|
| 200 |
+
self.num_heads = num_heads
|
| 201 |
+
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
|
| 202 |
+
|
| 203 |
+
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
| 204 |
+
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
| 205 |
+
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
| 206 |
+
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
| 207 |
+
|
| 208 |
+
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
| 209 |
+
b, n, c = x.shape
|
| 210 |
+
x = x.reshape(b, n, num_heads, c // num_heads)
|
| 211 |
+
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
| 212 |
+
|
| 213 |
+
def _recombine_heads(self, x: Tensor) -> Tensor:
|
| 214 |
+
b, n_heads, n_tokens, c_per_head = x.shape
|
| 215 |
+
x = x.transpose(1, 2)
|
| 216 |
+
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
| 217 |
+
|
| 218 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
| 219 |
+
# Input projections
|
| 220 |
+
q = self.q_proj(q)
|
| 221 |
+
k = self.k_proj(k)
|
| 222 |
+
v = self.v_proj(v)
|
| 223 |
+
|
| 224 |
+
# Separate into heads
|
| 225 |
+
q = self._separate_heads(q, self.num_heads)
|
| 226 |
+
k = self._separate_heads(k, self.num_heads)
|
| 227 |
+
v = self._separate_heads(v, self.num_heads)
|
| 228 |
+
|
| 229 |
+
# Attention
|
| 230 |
+
_, _, _, c_per_head = q.shape
|
| 231 |
+
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
|
| 232 |
+
attn = attn / math.sqrt(c_per_head)
|
| 233 |
+
attn = torch.softmax(attn, dim=-1)
|
| 234 |
+
|
| 235 |
+
# Get output
|
| 236 |
+
out = attn @ v
|
| 237 |
+
out = self._recombine_heads(out)
|
| 238 |
+
out = self.out_proj(out)
|
| 239 |
+
|
| 240 |
+
return out
|
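A shape walkthrough of `TwoWayTransformer`: prompt tokens and flattened image tokens attend to each other in both directions, and both come back out. A toy sketch (the import path assumes this commit's layout and that the package's `common.MLPBlock` is present):

```python
import torch
from iopaint.plugins.segment_anything.modeling.transformer import TwoWayTransformer

# 2 prompt tokens attending to a 64x64 image embedding.
t = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)
point_embedding = torch.randn(1, 2, 256)

queries, keys = t(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)   # torch.Size([1, 2, 256]) torch.Size([1, 4096, 256])
```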
iopaint/plugins/segment_anything/utils/transforms.py
ADDED
|
@@ -0,0 +1,112 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
|
| 11 |
+
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
from typing import Tuple
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ResizeLongestSide:
|
| 17 |
+
"""
|
| 18 |
+
Resizes images to longest side 'target_length', as well as provides
|
| 19 |
+
methods for resizing coordinates and boxes. Provides methods for
|
| 20 |
+
transforming both numpy array and batched torch tensors.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, target_length: int) -> None:
|
| 24 |
+
self.target_length = target_length
|
| 25 |
+
|
| 26 |
+
def apply_image(self, image: np.ndarray) -> np.ndarray:
|
| 27 |
+
"""
|
| 28 |
+
Expects a numpy array with shape HxWxC in uint8 format.
|
| 29 |
+
"""
|
| 30 |
+
target_size = self.get_preprocess_shape(
|
| 31 |
+
image.shape[0], image.shape[1], self.target_length
|
| 32 |
+
)
|
| 33 |
+
return np.array(resize(to_pil_image(image), target_size))
|
| 34 |
+
|
| 35 |
+
def apply_coords(
|
| 36 |
+
self, coords: np.ndarray, original_size: Tuple[int, ...]
|
| 37 |
+
) -> np.ndarray:
|
| 38 |
+
"""
|
| 39 |
+
Expects a numpy array of length 2 in the final dimension. Requires the
|
| 40 |
+
original image size in (H, W) format.
|
| 41 |
+
"""
|
| 42 |
+
old_h, old_w = original_size
|
| 43 |
+
new_h, new_w = self.get_preprocess_shape(
|
| 44 |
+
original_size[0], original_size[1], self.target_length
|
| 45 |
+
)
|
| 46 |
+
coords = deepcopy(coords).astype(float)
|
| 47 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
| 48 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
| 49 |
+
return coords
|
| 50 |
+
|
| 51 |
+
def apply_boxes(
|
| 52 |
+
self, boxes: np.ndarray, original_size: Tuple[int, ...]
|
| 53 |
+
) -> np.ndarray:
|
| 54 |
+
"""
|
| 55 |
+
Expects a numpy array shape Bx4. Requires the original image size
|
| 56 |
+
in (H, W) format.
|
| 57 |
+
"""
|
| 58 |
+
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
|
| 59 |
+
return boxes.reshape(-1, 4)
|
| 60 |
+
|
| 61 |
+
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
|
| 62 |
+
"""
|
| 63 |
+
Expects batched images with shape BxCxHxW and float format. This
|
| 64 |
+
transformation may not exactly match apply_image. apply_image is
|
| 65 |
+
the transformation expected by the model.
|
| 66 |
+
"""
|
| 67 |
+
# Expects an image in BCHW format. May not exactly match apply_image.
|
| 68 |
+
target_size = self.get_preprocess_shape(
|
| 69 |
+
image.shape[0], image.shape[1], self.target_length
|
| 70 |
+
)
|
| 71 |
+
return F.interpolate(
|
| 72 |
+
image, target_size, mode="bilinear", align_corners=False, antialias=True
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def apply_coords_torch(
|
| 76 |
+
self, coords: torch.Tensor, original_size: Tuple[int, ...]
|
| 77 |
+
) -> torch.Tensor:
|
| 78 |
+
"""
|
| 79 |
+
Expects a torch tensor with length 2 in the last dimension. Requires the
|
| 80 |
+
original image size in (H, W) format.
|
| 81 |
+
"""
|
| 82 |
+
old_h, old_w = original_size
|
| 83 |
+
new_h, new_w = self.get_preprocess_shape(
|
| 84 |
+
original_size[0], original_size[1], self.target_length
|
| 85 |
+
)
|
| 86 |
+
coords = deepcopy(coords).to(torch.float)
|
| 87 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
| 88 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
| 89 |
+
return coords
|
| 90 |
+
|
| 91 |
+
def apply_boxes_torch(
|
| 92 |
+
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
|
| 93 |
+
) -> torch.Tensor:
|
| 94 |
+
"""
|
| 95 |
+
Expects a torch tensor with shape Bx4. Requires the original image
|
| 96 |
+
size in (H, W) format.
|
| 97 |
+
"""
|
| 98 |
+
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
|
| 99 |
+
return boxes.reshape(-1, 4)
|
| 100 |
+
|
| 101 |
+
@staticmethod
|
| 102 |
+
def get_preprocess_shape(
|
| 103 |
+
oldh: int, oldw: int, long_side_length: int
|
| 104 |
+
) -> Tuple[int, int]:
|
| 105 |
+
"""
|
| 106 |
+
Compute the output size given input size and target long side length.
|
| 107 |
+
"""
|
| 108 |
+
scale = long_side_length * 1.0 / max(oldh, oldw)
|
| 109 |
+
newh, neww = oldh * scale, oldw * scale
|
| 110 |
+
neww = int(neww + 0.5)
|
| 111 |
+
newh = int(newh + 0.5)
|
| 112 |
+
return (newh, neww)
|
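`ResizeLongestSide` scales the longer image side to `target_length` and rescales prompt coordinates and boxes with the same factors. A small usage sketch (import path assumes this commit's layout):

```python
import numpy as np
from iopaint.plugins.segment_anything.utils.transforms import ResizeLongestSide

tr = ResizeLongestSide(1024)

# A 600x800 (H x W) image: the longer side (800) is scaled to 1024.
print(tr.get_preprocess_shape(600, 800, 1024))   # (768, 1024)

# Point prompts are rescaled with the same factors, in (x, y) order.
coords = np.array([[400.0, 300.0]])              # centre of the original image
print(tr.apply_coords(coords, (600, 800)))       # [[512. 384.]]
```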
iopaint/tests/test_sdxl.py
ADDED
|
@@ -0,0 +1,172 @@
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from iopaint.tests.utils import check_device, current_dir
|
| 4 |
+
|
| 5 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from iopaint.model_manager import ModelManager
|
| 11 |
+
from iopaint.schema import HDStrategy, SDSampler, FREEUConfig
|
| 12 |
+
from iopaint.tests.test_model import get_config, assert_equal
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
| 16 |
+
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
| 17 |
+
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
| 18 |
+
def test_sdxl(device, strategy, sampler):
|
| 19 |
+
sd_steps = check_device(device)
|
| 20 |
+
|
| 21 |
+
model = ModelManager(
|
| 22 |
+
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
| 23 |
+
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )
    cfg = get_config(
        strategy=strategy,
        prompt="face of a fox, sitting on a bench",
        sd_steps=sd_steps,
        sd_strength=1.0,
        sd_guidance_scale=7.0,
    )
    cfg.sd_sampler = sampler

    assert_equal(
        model,
        cfg,
        f"sdxl_device_{device}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=2,
        fy=2,
    )


@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_sdxl_cpu_text_encoder(device, strategy, sampler):
    sd_steps = check_device(device)

    model = ModelManager(
        name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=True,
    )
    cfg = get_config(
        strategy=strategy,
        prompt="face of a fox, sitting on a bench",
        sd_steps=sd_steps,
        sd_strength=1.0,
        sd_guidance_scale=7.0,
    )
    cfg.sd_sampler = sampler

    assert_equal(
        model,
        cfg,
        f"sdxl_device_{device}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=2,
        fy=2,
    )


@pytest.mark.parametrize("device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler):
    sd_steps = check_device(device)

    model = ModelManager(
        name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )
    cfg = get_config(
        strategy=strategy,
        prompt="face of a fox, sitting on a bench",
        sd_steps=sd_steps,
        sd_strength=1.0,
        sd_guidance_scale=2.0,
        sd_lcm_lora=True,
    )
    cfg.sd_sampler = sampler

    name = f"device_{device}_{sampler}"

    assert_equal(
        model,
        cfg,
        f"sdxl_{name}_lcm_lora.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=2,
        fy=2,
    )

    cfg = get_config(
        strategy=strategy,
        prompt="face of a fox, sitting on a bench",
        sd_steps=sd_steps,
        sd_guidance_scale=7.5,
        sd_freeu=True,
        sd_freeu_config=FREEUConfig(),
    )

    assert_equal(
        model,
        cfg,
        f"sdxl_{name}_freeu_device_{device}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=2,
        fy=2,
    )


@pytest.mark.parametrize("device", ["cuda", "mps"])
@pytest.mark.parametrize(
    "rect",
    [
        [-128, -128, 1024, 1024],
    ],
)
def test_sdxl_outpainting(device, rect):
    sd_steps = check_device(device)

    model = ModelManager(
        name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        device=torch.device(device),
        disable_nsfw=True,
        sd_cpu_textencoder=False,
    )

    cfg = get_config(
        strategy=HDStrategy.ORIGINAL,
        prompt="a dog sitting on a bench in the park",
        sd_steps=sd_steps,
        use_extender=True,
        extender_x=rect[0],
        extender_y=rect[1],
        extender_width=rect[2],
        extender_height=rect[3],
        sd_strength=1.0,
        sd_guidance_scale=8.0,
        sd_sampler=SDSampler.ddim,
    )

    assert_equal(
        model,
        cfg,
        f"sdxl_outpainting_dog_ddim_{'_'.join(map(str, rect))}_device_{device}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=1.5,
        fy=1.5,
    )
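
In test_sdxl_outpainting above, the rect parameter is unpacked positionally into the extender fields of the config. As a reading aid (an interpretation based on the field names, not something this commit documents), the four values describe the extended canvas as (x, y, width, height) in the coordinate frame of the original image, so negative x/y grow the canvas beyond the top-left corner:

# Hypothetical annotation of the parametrized rect; names mirror the config fields above.
rect = [-128, -128, 1024, 1024]
extender = dict(
    extender_x=rect[0],       # left edge of the extended canvas (128 px left of the image)
    extender_y=rect[1],       # top edge of the extended canvas (128 px above the image)
    extender_width=rect[2],   # total width of the extended canvas
    extender_height=rect[3],  # total height of the extended canvas
)
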
iopaint/tests/utils.py
ADDED
@@ -0,0 +1,77 @@
from pathlib import Path
import cv2
import pytest
import torch

from iopaint.helper import encode_pil_to_base64
from iopaint.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
from PIL import Image

current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)


def check_device(device: str) -> int:
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available, skip test on cuda")
    if device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("mps is not available, skip test on mps")
    steps = 2 if device == "cpu" else 20
    return steps


def assert_equal(
    model,
    config: InpaintRequest,
    gt_name,
    fx: float = 1,
    fy: float = 1,
    img_p=current_dir / "image.png",
    mask_p=current_dir / "mask.png",
):
    img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
    print(f"Input image shape: {img.shape}")
    res = model(img, mask, config)
    ok = cv2.imwrite(
        str(save_dir / gt_name),
        res,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )
    assert ok, save_dir / gt_name

    """
    Note that JPEG is lossy compression, so even if it is the highest quality 100,
    when the saved images is reloaded, a difference occurs with the original pixel value.
    If you want to save the original images as it is, save it as PNG or BMP.
    """
    # gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
    # assert np.array_equal(res, gt)


def get_data(
    fx: float = 1,
    fy: float = 1.0,
    img_p=current_dir / "image.png",
    mask_p=current_dir / "mask.png",
):
    img = cv2.imread(str(img_p))
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
    mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
    return img, mask


def get_config(**kwargs):
    data = dict(
        sd_sampler=kwargs.get("sd_sampler", SDSampler.uni_pc),
        ldm_steps=1,
        ldm_sampler=LDMSampler.plms,
        hd_strategy=kwargs.get("strategy", HDStrategy.ORIGINAL),
        hd_strategy_crop_margin=32,
        hd_strategy_crop_trigger_size=200,
        hd_strategy_resize_limit=200,
    )
    data.update(**kwargs)
    return InpaintRequest(image="", mask="", **data)
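
The helpers above can be exercised on their own as a quick sanity check. This is a minimal sketch, not part of the commit; it assumes the iopaint package is importable and that the default test assets (image.png, mask.png) sit next to iopaint/tests/utils.py:

from iopaint.schema import HDStrategy
from iopaint.tests.utils import get_config, get_data

# get_config fills cheap test defaults (ldm_steps=1, crop/resize limits) and
# lets keyword overrides win via data.update(**kwargs).
cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt="a red apple", sd_steps=4)
print(cfg.hd_strategy, cfg.sd_steps)

# get_data returns an RGB image and a grayscale mask, both rescaled by fx/fy.
img, mask = get_data(fx=0.5, fy=0.5)
print(img.shape, mask.shape)
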
iopaint/web_config.py
ADDED
@@ -0,0 +1,307 @@
import json
import os
from pathlib import Path

from iopaint.schema import (
    Device,
    InteractiveSegModel,
    RemoveBGModel,
    RealESRGANModel,
    ApiConfig,
)

os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

from datetime import datetime
from json import JSONDecodeError

import gradio as gr
from iopaint.download import scan_models
from loguru import logger

from iopaint.const import *


_config_file: Path = None


default_configs = dict(
    host="127.0.0.1",
    port=8080,
    inbrowser=True,
    model=DEFAULT_MODEL,
    model_dir=DEFAULT_MODEL_DIR,
    no_half=False,
    low_mem=False,
    cpu_offload=False,
    disable_nsfw_checker=False,
    local_files_only=False,
    cpu_textencoder=False,
    device=Device.cuda,
    input=None,
    output_dir=None,
    quality=95,
    enable_interactive_seg=False,
    interactive_seg_model=InteractiveSegModel.vit_b,
    interactive_seg_device=Device.cpu,
    enable_remove_bg=False,
    remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
    enable_anime_seg=False,
    enable_realesrgan=False,
    realesrgan_device=Device.cpu,
    realesrgan_model=RealESRGANModel.realesr_general_x4v3,
    enable_gfpgan=False,
    gfpgan_device=Device.cpu,
    enable_restoreformer=False,
    restoreformer_device=Device.cpu,
)


class WebConfig(ApiConfig):
    model_dir: str = DEFAULT_MODEL_DIR


def load_config(p: Path) -> WebConfig:
    if p.exists():
        with open(p, "r", encoding="utf-8") as f:
            try:
                return WebConfig(**{**default_configs, **json.load(f)})
            except JSONDecodeError:
                print(f"Load config file failed, using default configs")
                return WebConfig(**default_configs)
    else:
        return WebConfig(**default_configs)


def save_config(
    host,
    port,
    model,
    model_dir,
    no_half,
    low_mem,
    cpu_offload,
    disable_nsfw_checker,
    local_files_only,
    cpu_textencoder,
    device,
    input,
    output_dir,
    quality,
    enable_interactive_seg,
    interactive_seg_model,
    interactive_seg_device,
    enable_remove_bg,
    remove_bg_model,
    enable_anime_seg,
    enable_realesrgan,
    realesrgan_device,
    realesrgan_model,
    enable_gfpgan,
    gfpgan_device,
    enable_restoreformer,
    restoreformer_device,
    inbrowser,
):
    config = WebConfig(**locals())
    if str(config.input) == ".":
        config.input = None
    if str(config.output_dir) == ".":
        config.output_dir = None
    config.model = config.model.strip()
    print(config.model_dump_json(indent=4))
    if config.input and not os.path.exists(config.input):
        return "[Error] Input file or directory does not exist"

    current_time = datetime.now().strftime("%H:%M:%S")
    msg = f"[{current_time}] Successful save config to: {str(_config_file.absolute())}"
    logger.info(msg)
    try:
        with open(_config_file, "w", encoding="utf-8") as f:
            f.write(config.model_dump_json(indent=4))
    except Exception as e:
        return f"Save configure file failed: {str(e)}"
    return msg


def change_current_model(new_model):
    return new_model


def main(config_file: Path):
    global _config_file
    _config_file = config_file

    init_config = load_config(config_file)
    downloaded_models = [it.name for it in scan_models()]

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Textbox(config_file, label="Config file", interactive=False)
            with gr.Column():
                save_btn = gr.Button(value="Save configurations")
                message = gr.HTML()

        with gr.Tabs():
            with gr.Tab("Common"):
                with gr.Row():
                    host = gr.Textbox(init_config.host, label="Host")
                    port = gr.Number(init_config.port, label="Port", precision=0)
                    inbrowser = gr.Checkbox(init_config.inbrowser, label=INBROWSER_HELP)

                with gr.Column():
                    model = gr.Textbox(
                        init_config.model,
                        label="Current Model. This is the model that will be used when the service starts. "
                        "If the model has not been downloaded before, it will be automatically downloaded. "
                        "You can select a model from the dropdown box below or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
                    )
                    with gr.Row():
                        recommend_model = gr.Dropdown(
                            ["lama", "mat", "migan"] + DIFFUSION_MODELS,
                            label="Recommended Models",
                        )
                        downloaded_model = gr.Dropdown(
                            downloaded_models, label="Downloaded Models"
                        )

                device = gr.Radio(
                    Device.values(), label="Device", value=init_config.device
                )
                quality = gr.Slider(
                    value=95,
                    label=f"Image Quality ({QUALITY_HELP})",
                    minimum=75,
                    maximum=100,
                    step=1,
                )

                no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
                cpu_offload = gr.Checkbox(
                    init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
                )
                low_mem = gr.Checkbox(init_config.low_mem, label=f"{LOW_MEM_HELP}")
                cpu_textencoder = gr.Checkbox(
                    init_config.cpu_textencoder, label=f"{CPU_TEXTENCODER_HELP}"
                )
                disable_nsfw_checker = gr.Checkbox(
                    init_config.disable_nsfw_checker, label=f"{DISABLE_NSFW_HELP}"
                )
                local_files_only = gr.Checkbox(
                    init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
                )

                with gr.Column():
                    model_dir = gr.Textbox(
                        init_config.model_dir, label=f"{MODEL_DIR_HELP}"
                    )
                    input = gr.Textbox(
                        init_config.input,
                        label=f"Input file or directory. {INPUT_HELP}",
                    )
                    output_dir = gr.Textbox(
                        init_config.output_dir,
                        label=f"Output directory. {OUTPUT_DIR_HELP}",
                    )

            with gr.Tab("Plugins"):
                with gr.Row():
                    enable_interactive_seg = gr.Checkbox(
                        init_config.enable_interactive_seg, label=INTERACTIVE_SEG_HELP
                    )
                    interactive_seg_model = gr.Radio(
                        InteractiveSegModel.values(),
                        label=f"Segment Anything models. {INTERACTIVE_SEG_MODEL_HELP}",
                        value=init_config.interactive_seg_model,
                    )
                    interactive_seg_device = gr.Radio(
                        Device.values(),
                        label="Segment Anything Device",
                        value=init_config.interactive_seg_device,
                    )
                with gr.Row():
                    enable_remove_bg = gr.Checkbox(
                        init_config.enable_remove_bg, label=REMOVE_BG_HELP
                    )
                    remove_bg_model = gr.Radio(
                        RemoveBGModel.values(),
                        label="Remove bg model",
                        value=init_config.remove_bg_model,
                    )
                with gr.Row():
                    enable_anime_seg = gr.Checkbox(
                        init_config.enable_anime_seg, label=ANIMESEG_HELP
                    )

                with gr.Row():
                    enable_realesrgan = gr.Checkbox(
                        init_config.enable_realesrgan, label=REALESRGAN_HELP
                    )
                    realesrgan_device = gr.Radio(
                        Device.values(),
                        label="RealESRGAN Device",
                        value=init_config.realesrgan_device,
                    )
                    realesrgan_model = gr.Radio(
                        RealESRGANModel.values(),
                        label="RealESRGAN model",
                        value=init_config.realesrgan_model,
                    )
                with gr.Row():
                    enable_gfpgan = gr.Checkbox(
                        init_config.enable_gfpgan, label=GFPGAN_HELP
                    )
                    gfpgan_device = gr.Radio(
                        Device.values(),
                        label="GFPGAN Device",
                        value=init_config.gfpgan_device,
                    )
                with gr.Row():
                    enable_restoreformer = gr.Checkbox(
                        init_config.enable_restoreformer, label=RESTOREFORMER_HELP
                    )
                    restoreformer_device = gr.Radio(
                        Device.values(),
                        label="RestoreFormer Device",
                        value=init_config.restoreformer_device,
                    )

        downloaded_model.change(change_current_model, [downloaded_model], model)
        recommend_model.change(change_current_model, [recommend_model], model)

        save_btn.click(
            save_config,
            [
                host,
                port,
                model,
                model_dir,
                no_half,
                low_mem,
                cpu_offload,
                disable_nsfw_checker,
                local_files_only,
                cpu_textencoder,
                device,
                input,
                output_dir,
                quality,
                enable_interactive_seg,
                interactive_seg_model,
                interactive_seg_device,
                enable_remove_bg,
                remove_bg_model,
                enable_anime_seg,
                enable_realesrgan,
                realesrgan_device,
                realesrgan_model,
                enable_gfpgan,
                gfpgan_device,
                enable_restoreformer,
                restoreformer_device,
                inbrowser,
            ],
            message,
        )
    demo.launch(inbrowser=True, show_api=False)
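
web_config.py is the Gradio launcher for editing the service settings; clicking "Save configurations" serializes the WebConfig model back to the JSON file handed to main(). A minimal sketch of driving it directly (not part of the commit; in practice the page is reached through iopaint's own CLI):

from pathlib import Path
from iopaint.web_config import main

# Opens the settings page in the browser; edits are written to config.json on save.
main(Path("config.json"))
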
pretrained-model/version.txt
ADDED
@@ -0,0 +1 @@
1
pretrained-model/version_diffusers_cache.txt
ADDED
@@ -0,0 +1 @@
1
utils/tools.py
ADDED
@@ -0,0 +1,505 @@
import os
import torch
import yaml
import numpy as np
from PIL import Image

import torch.nn.functional as F


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def default_loader(path):
    return pil_loader(path)


def tensor_img_to_npimg(tensor_img):
    """
    Turn a tensor image with shape CxHxW to a numpy array image with shape HxWxC
    :param tensor_img:
    :return: a numpy array image with shape HxWxC
    """
    if not (torch.is_tensor(tensor_img) and tensor_img.ndimension() == 3):
        raise NotImplementedError("Not supported tensor image. Only tensors with dimension CxHxW are supported.")
    npimg = np.transpose(tensor_img.numpy(), (1, 2, 0))
    npimg = npimg.squeeze()
    assert isinstance(npimg, np.ndarray) and (npimg.ndim in {2, 3})
    return npimg


# Change the values of tensor x from range [0, 1] to [-1, 1]
def normalize(x):
    return x.mul_(2).add_(-1)


def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images


def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param padding:
    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
     each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: A Tensor
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}.\
                Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks


def random_bbox(config, batch_size):
    """Generate a random tlhw with configuration.

    Args:
        config: Config should have configuration including img

    Returns:
        tuple: (top, left, height, width)

    """
    img_height, img_width, _ = config['image_shape']
    h, w = config['mask_shape']
    margin_height, margin_width = config['margin']
    maxt = img_height - margin_height - h
    maxl = img_width - margin_width - w
    bbox_list = []
    if config['mask_batch_same']:
        t = np.random.randint(margin_height, maxt)
        l = np.random.randint(margin_width, maxl)
        bbox_list.append((t, l, h, w))
        bbox_list = bbox_list * batch_size
    else:
        for i in range(batch_size):
            t = np.random.randint(margin_height, maxt)
            l = np.random.randint(margin_width, maxl)
            bbox_list.append((t, l, h, w))

    return torch.tensor(bbox_list, dtype=torch.int64)


def test_random_bbox():
    image_shape = [256, 256, 3]
    mask_shape = [128, 128]
    margin = [0, 0]
    bbox = random_bbox(image_shape)
    return bbox


def bbox2mask(bboxes, height, width, max_delta_h, max_delta_w):
    batch_size = bboxes.size(0)
    mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32)
    for i in range(batch_size):
        bbox = bboxes[i]
        delta_h = np.random.randint(max_delta_h // 2 + 1)
        delta_w = np.random.randint(max_delta_w // 2 + 1)
        mask[i, :, bbox[0] + delta_h:bbox[0] + bbox[2] - delta_h, bbox[1] + delta_w:bbox[1] + bbox[3] - delta_w] = 1.
    return mask


def test_bbox2mask():
    image_shape = [256, 256, 3]
    mask_shape = [128, 128]
    margin = [0, 0]
    max_delta_shape = [32, 32]
    bbox = random_bbox(image_shape)
    mask = bbox2mask(bbox, image_shape[0], image_shape[1], max_delta_shape[0], max_delta_shape[1])
    return mask


def local_patch(x, bbox_list):
    assert len(x.size()) == 4
    patches = []
    for i, bbox in enumerate(bbox_list):
        t, l, h, w = bbox
        patches.append(x[i, :, t:t + h, l:l + w])
    return torch.stack(patches, dim=0)


def mask_image(x, bboxes, config):
    height, width, _ = config['image_shape']
    max_delta_h, max_delta_w = config['max_delta_shape']
    mask = bbox2mask(bboxes, height, width, max_delta_h, max_delta_w)
    if x.is_cuda:
        mask = mask.cuda()

    if config['mask_type'] == 'hole':
        result = x * (1. - mask)
    elif config['mask_type'] == 'mosaic':
        # TODO: Matching the mosaic patch size and the mask size
        mosaic_unit_size = config['mosaic_unit_size']
        downsampled_image = F.interpolate(x, scale_factor=1. / mosaic_unit_size, mode='nearest')
        upsampled_image = F.interpolate(downsampled_image, size=(height, width), mode='nearest')
        result = upsampled_image * mask + x * (1. - mask)
    else:
        raise NotImplementedError('Not implemented mask type.')

    return result, mask


def spatial_discounting_mask(config):
    """Generate spatial discounting mask constant.

    Spatial discounting mask is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

    Args:
        config: Config should have configuration including HEIGHT, WIDTH,
            DISCOUNTED_MASK.

    Returns:
        tf.Tensor: spatial discounting mask

    """
    gamma = config['spatial_discounting_gamma']
    height, width = config['mask_shape']
    shape = [1, 1, height, width]
    if config['discounted_mask']:
        mask_values = np.ones((height, width))
        for i in range(height):
            for j in range(width):
                mask_values[i, j] = max(
                    gamma ** min(i, height - i),
                    gamma ** min(j, width - j))
        mask_values = np.expand_dims(mask_values, 0)
        mask_values = np.expand_dims(mask_values, 0)
    else:
        mask_values = np.ones(shape)
    spatial_discounting_mask_tensor = torch.tensor(mask_values, dtype=torch.float32)
    if config['cuda']:
        spatial_discounting_mask_tensor = spatial_discounting_mask_tensor.cuda()
    return spatial_discounting_mask_tensor


def reduce_mean(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x


def reduce_std(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x


def reduce_sum(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x


def flow_to_image(flow):
    """Transfer flow map to image.
    Part of code forked from flownet.
    """
    out = []
    maxu = -999.
    maxv = -999.
    minu = 999.
    minv = 999.
    maxrad = -1
    for i in range(flow.shape[0]):
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
        u[idxunknow] = 0
        v[idxunknow] = 0
        maxu = max(maxu, np.max(u))
        minu = min(minu, np.min(u))
        maxv = max(maxv, np.max(v))
        minv = min(minv, np.min(v))
        rad = np.sqrt(u ** 2 + v ** 2)
        maxrad = max(maxrad, np.max(rad))
        u = u / (maxrad + np.finfo(float).eps)
        v = v / (maxrad + np.finfo(float).eps)
        img = compute_color(u, v)
        out.append(img)
    return np.float32(np.uint8(out))


def pt_flow_to_image(flow):
    """Transfer flow map to image.
    Part of code forked from flownet.
    """
    out = []
    maxu = torch.tensor(-999)
    maxv = torch.tensor(-999)
    minu = torch.tensor(999)
    minv = torch.tensor(999)
    maxrad = torch.tensor(-1)
    if torch.cuda.is_available():
        maxu = maxu.cuda()
        maxv = maxv.cuda()
        minu = minu.cuda()
        minv = minv.cuda()
        maxrad = maxrad.cuda()
    for i in range(flow.shape[0]):
        u = flow[i, 0, :, :]
        v = flow[i, 1, :, :]
        idxunknow = (torch.abs(u) > 1e7) + (torch.abs(v) > 1e7)
        u[idxunknow] = 0
        v[idxunknow] = 0
        maxu = torch.max(maxu, torch.max(u))
        minu = torch.min(minu, torch.min(u))
        maxv = torch.max(maxv, torch.max(v))
        minv = torch.min(minv, torch.min(v))
        rad = torch.sqrt((u ** 2 + v ** 2).float()).to(torch.int64)
        maxrad = torch.max(maxrad, torch.max(rad))
        u = u / (maxrad + torch.finfo(torch.float32).eps)
        v = v / (maxrad + torch.finfo(torch.float32).eps)
        # TODO: change the following to pytorch
        img = pt_compute_color(u, v)
        out.append(img)

    return torch.stack(out, dim=0)


def highlight_flow(flow):
    """Convert flow into middlebury color code image.
    """
    out = []
    s = flow.shape
    for i in range(flow.shape[0]):
        img = np.ones((s[1], s[2], 3)) * 144.
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        for h in range(s[1]):
            for w in range(s[1]):
                ui = u[h, w]
                vi = v[h, w]
                img[ui, vi, :] = 255.
        out.append(img)
    return np.float32(np.uint8(out))


def pt_highlight_flow(flow):
    """Convert flow into middlebury color code image.
    """
    out = []
    s = flow.shape
    for i in range(flow.shape[0]):
        img = np.ones((s[1], s[2], 3)) * 144.
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        for h in range(s[1]):
            for w in range(s[1]):
                ui = u[h, w]
                vi = v[h, w]
                img[ui, vi, :] = 255.
        out.append(img)
    return np.float32(np.uint8(out))


def compute_color(u, v):
    h, w = u.shape
    img = np.zeros([h, w, 3])
    nanIdx = np.isnan(u) | np.isnan(v)
    u[nanIdx] = 0
    v[nanIdx] = 0
    # colorwheel = COLORWHEEL
    colorwheel = make_color_wheel()
    ncols = np.size(colorwheel, 0)
    rad = np.sqrt(u ** 2 + v ** 2)
    a = np.arctan2(-v, -u) / np.pi
    fk = (a + 1) / 2 * (ncols - 1) + 1
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1
    f = fk - k0
    for i in range(np.size(colorwheel, 1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1] / 255
        col1 = tmp[k1 - 1] / 255
        col = (1 - f) * col0 + f * col1
        idx = rad <= 1
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        notidx = np.logical_not(idx)
        col[notidx] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
    return img


def pt_compute_color(u, v):
    h, w = u.shape
    img = torch.zeros([3, h, w])
    if torch.cuda.is_available():
        img = img.cuda()
    nanIdx = (torch.isnan(u) + torch.isnan(v)) != 0
    u[nanIdx] = 0.
    v[nanIdx] = 0.
    # colorwheel = COLORWHEEL
    colorwheel = pt_make_color_wheel()
    if torch.cuda.is_available():
        colorwheel = colorwheel.cuda()
    ncols = colorwheel.size()[0]
    rad = torch.sqrt((u ** 2 + v ** 2).to(torch.float32))
    a = torch.atan2(-v.to(torch.float32), -u.to(torch.float32)) / np.pi
    fk = (a + 1) / 2 * (ncols - 1) + 1
    k0 = torch.floor(fk).to(torch.int64)
    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1
    f = fk - k0.to(torch.float32)
    for i in range(colorwheel.size()[1]):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1]
        col1 = tmp[k1 - 1]
        col = (1 - f) * col0 + f * col1
        idx = rad <= 1. / 255.
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        notidx = (idx != 0)
        col[notidx] *= 0.75
        img[i, :, :] = col * (1 - nanIdx).to(torch.float32)
    return img


def make_color_wheel():
    RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros([ncols, 3])
    col = 0
    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
    col += RY
    # YG
    colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
    colorwheel[col:col + YG, 1] = 255
    col += YG
    # GC
    colorwheel[col:col + GC, 1] = 255
    colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
    col += GC
    # CB
    colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
    colorwheel[col:col + CB, 2] = 255
    col += CB
    # BM
    colorwheel[col:col + BM, 2] = 255
    colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
    col += + BM
    # MR
    colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
    colorwheel[col:col + MR, 0] = 255
    return colorwheel


def pt_make_color_wheel():
    RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = torch.zeros([ncols, 3])
    col = 0
    # RY
    colorwheel[0:RY, 0] = 1.
    colorwheel[0:RY, 1] = torch.arange(0, RY, dtype=torch.float32) / RY
    col += RY
    # YG
    colorwheel[col:col + YG, 0] = 1. - (torch.arange(0, YG, dtype=torch.float32) / YG)
    colorwheel[col:col + YG, 1] = 1.
    col += YG
    # GC
    colorwheel[col:col + GC, 1] = 1.
    colorwheel[col:col + GC, 2] = torch.arange(0, GC, dtype=torch.float32) / GC
    col += GC
    # CB
    colorwheel[col:col + CB, 1] = 1. - (torch.arange(0, CB, dtype=torch.float32) / CB)
    colorwheel[col:col + CB, 2] = 1.
    col += CB
    # BM
    colorwheel[col:col + BM, 2] = 1.
    colorwheel[col:col + BM, 0] = torch.arange(0, BM, dtype=torch.float32) / BM
    col += BM
    # MR
    colorwheel[col:col + MR, 2] = 1. - (torch.arange(0, MR, dtype=torch.float32) / MR)
    colorwheel[col:col + MR, 0] = 1.
    return colorwheel


def is_image_file(filename):
    IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
    filename_lower = filename.lower()
    return any(filename_lower.endswith(extension) for extension in IMG_EXTENSIONS)


def deprocess(img):
    img = img.add_(1).div_(2)
    return img


# get configs
def get_config(config):
    with open(config, 'r') as stream:
        return yaml.load(stream, Loader=yaml.Loader)


# Get model list for resume
def get_model_list(dirname, key, iteration=0):
    if os.path.exists(dirname) is False:
        return None
    gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
                  os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
    if gen_models is None:
        return None
    gen_models.sort()
    if iteration == 0:
        last_model_name = gen_models[-1]
    else:
        for model_name in gen_models:
            if '{:0>8d}'.format(iteration) in model_name:
                return model_name
        raise ValueError('Not found models with this iteration')
    return last_model_name


if __name__ == '__main__':
    test_random_bbox()
    mask = test_bbox2mask()
    print(mask.shape)
    import matplotlib.pyplot as plt

    plt.imshow(mask, cmap='gray')
    plt.show()
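
For orientation, extract_image_patches in utils/tools.py reproduces TensorFlow-style patch extraction on top of torch.nn.Unfold, with same_padding emulating 'SAME' padding. A small sketch of the resulting shape, not part of the commit:

import torch
from utils.tools import extract_image_patches

x = torch.randn(2, 3, 64, 64)
# 'same' padding keeps ceil(64 / stride) = 32 sliding positions per spatial dim.
patches = extract_image_patches(x, ksizes=[4, 4], strides=[2, 2], rates=[1, 1], padding='same')
print(patches.shape)  # [N, C*k*k, L] -> torch.Size([2, 48, 1024])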