DiffIR2VR / utils /common.py
jimmycv07's picture
first commit
1de8821
raw
history blame
5.85 kB
from typing import Mapping, Any, Tuple, Callable
import importlib
import os
from urllib.parse import urlparse
import torch
from torch import Tensor
from torch.nn import functional as F
import numpy as np
from torch.hub import download_url_to_file, get_dir
def get_obj_from_str(string: str, reload: bool=False) -> Any:
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def instantiate_from_config(config: Mapping[str, Any]) -> Any:
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
# import ipdb; ipdb.set_trace()
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
Args:
url (str): URL to be downloaded.
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
Default: None.
progress (bool): Whether to show the download progress. Default: True.
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
Returns:
str: The path to the downloaded file.
"""
if model_dir is None: # use the pytorch hub_dir
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> Tuple[int, int, int, int]:
hi_list = list(range(0, h - tile_size + 1, tile_stride))
if (h - tile_size) % tile_stride != 0:
hi_list.append(h - tile_size)
wi_list = list(range(0, w - tile_size + 1, tile_stride))
if (w - tile_size) % tile_stride != 0:
wi_list.append(w - tile_size)
coords = []
for hi in hi_list:
for wi in wi_list:
coords.append((hi, hi + tile_size, wi, wi + tile_size))
return coords
# https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503
def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray:
"""Generates a gaussian mask of weights for tile contributions"""
latent_width = tile_width
latent_height = tile_height
var = 0.01
midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
x_probs = [
np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / np.sqrt(2 * np.pi * var)
for x in range(latent_width)]
midpoint = latent_height / 2
y_probs = [
np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / np.sqrt(2 * np.pi * var)
for y in range(latent_height)]
weights = np.outer(y_probs, x_probs)
return weights
COUNT_VRAM = bool(os.environ.get("COUNT_VRAM", False))
def count_vram_usage(func: Callable) -> Callable:
if not COUNT_VRAM:
return func
def wrapper(*args, **kwargs):
peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3)
ret = func(*args, **kwargs)
torch.cuda.synchronize()
peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB")
return ret
return wrapper