Prism / parallax.py
Masaaki Kawata
Update parallax.py
7ad46ff
raw
history blame
17.6 kB
import io
import os
import numpy as np
import torch
import torch.nn.functional as F
from tempfile import NamedTemporaryFile
from dotenv import load_dotenv
from omegaconf import OmegaConf
from PIL import Image, ImageFilter
from huggingface_hub import hf_hub_download
from depth_anything_v2.dpt import DepthAnythingV2
from ultralytics import YOLO
from simple_lama_inpainting import SimpleLama
from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.evaluation.data import pad_tensor_to_modulo
load_dotenv(verbose=False)

# Optional Hugging Face token. When HF_TOKEN is unset, token=None is passed so
# huggingface_hub can fall back to a cached login / anonymous access instead of
# this module failing at import with a KeyError. Repositories that really need
# authentication still fail, with an explicit error from hf_hub_download.
_HF_TOKEN = os.environ.get('HF_TOKEN')
_DEPTH_DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
# NOTE(review): LaMa is never placed on MPS, unlike the depth model; presumably
# deliberate — confirm before unifying the two device choices.
_LAMA_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Depth Anything V2, ViT-L variant. A lighter alternative is the ViT-B model:
# encoder='vitb', features=128, out_channels=[96, 192, 384, 768], loaded from
# repo 'depth-anything/Depth-Anything-V2-Base', file 'depth_anything_v2_vitb.pth'.
DEPTH_ANYTHING = DepthAnythingV2(**{'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]})
DEPTH_ANYTHING.load_state_dict(torch.load(hf_hub_download(repo_id='depth-anything/Depth-Anything-V2-Large', filename='depth_anything_v2_vitl.pth', repo_type='model', token=_HF_TOKEN), map_location='cpu'))
DEPTH_ANYTHING = DEPTH_ANYTHING.to(_DEPTH_DEVICE).eval()

# YOLOv8 hand and person detectors from the Bingsu/adetailer repository.
HAND_YOLO = YOLO(hf_hub_download('Bingsu/adetailer', 'hand_yolov8n.pt', token=_HF_TOKEN))
PERSON_YOLO = YOLO(hf_hub_download('Bingsu/adetailer', 'person_yolov8n-seg.pt', token=_HF_TOKEN))

# Big-LaMa inpainting model, loaded from a local 'big-lama' directory.
LAMA_TRAIN_CFG = OmegaConf.load('big-lama/config.yaml')
LAMA_TRAIN_CFG['training_model']['predict_only'] = True
LAMA = load_checkpoint(LAMA_TRAIN_CFG, 'big-lama/models/best.ckpt', strict=False, map_location='cpu')
LAMA = LAMA.to(_LAMA_DEVICE).eval()
def resize_iamge(image, maximum=2048, resample=Image.Resampling.LANCZOS):
    """Downscale *image* so that its longer side is at most *maximum* pixels.

    The aspect ratio is preserved. Images already within the limit are
    returned unchanged (the same object, not a copy). The misspelled name is
    kept for existing callers.

    Args:
        image: PIL image to resize.
        maximum: upper bound for the longer side, in pixels.
        resample: PIL resampling filter used when shrinking.

    Returns:
        The original image, or a resized copy.
    """
    width, height = image.size
    longest = max(width, height)
    if longest <= maximum:
        return image
    scale = maximum / longest
    # Clamp to 1 px so very elongated images never round a side down to 0.
    new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
    return image.resize(new_size, resample=resample)
def kmeans_pp(X, n_clusters, n_init=1, max_iter=300, tol=1e-4, random_state=None):
    """Cluster the rows of *X* with Lloyd's k-means and k-means++ seeding.

    Args:
        X: array-like of shape ``(N, D)``; converted to ``float32``.
        n_clusters: requested number of clusters; capped at ``N``.
        n_init: number of independent restarts; the restart with the lowest
            inertia (sum of squared distances to the assigned centre) wins.
        max_iter: maximum Lloyd iterations per restart.
        tol: a restart stops once no centre moves farther than this.
        random_state: seed passed to ``numpy.random.default_rng``.

    Returns:
        ``(labels, centers)``: ``labels`` is an ``int32`` array of shape
        ``(N,)`` and ``centers`` a ``float32`` array of shape ``(k, D)`` with
        ``k = min(n_clusters, N)``. For ``N == 0`` both arrays are empty.

    Raises:
        ValueError: if ``n_clusters`` or ``n_init`` is less than 1.
    """
    X = np.asarray(X, dtype=np.float32)
    if n_clusters < 1:
        raise ValueError(f'n_clusters must be >= 1, got {n_clusters}')
    if n_init < 1:
        raise ValueError(f'n_init must be >= 1, got {n_init}')
    N, D = X.shape
    if N == 0:
        return np.empty(0, dtype=np.int32), np.empty((0, D), dtype=np.float32)
    n_clusters = min(n_clusters, N)
    rng = np.random.default_rng(random_state)

    def assign(centers):
        """Return (labels, squared distance to the nearest centre).

        Ties go to the lowest centre index; rows whose distances never compare
        below infinity (e.g. NaN rows) keep label -1.
        """
        labels = np.full(N, -1, dtype=np.int32)
        dmin = np.full(N, np.inf, dtype=np.float32)
        for c in range(n_clusters):
            d = np.sum((X - centers[c]) ** 2, axis=1)
            better = d < dmin
            labels[better] = c
            dmin[better] = d[better]
        return labels, dmin

    def init_plus_plus():
        """Pick initial centres, each drawn with probability proportional to D(x)^2."""
        centers = np.empty((n_clusters, D), dtype=np.float32)
        centers[0] = X[rng.integers(N)]
        d2 = np.sum((X - centers[0]) ** 2, axis=1)
        for c in range(1, n_clusters):
            total = d2.sum()
            if not np.isfinite(total) or total <= 0:
                idx = rng.integers(N)
            else:
                r = rng.random() * total
                # float32 rounding can leave r just past the cumulative sum.
                idx = min(int(np.searchsorted(np.cumsum(d2), r)), N - 1)
            centers[c] = X[idx]
            d2 = np.minimum(d2, np.sum((X - centers[c]) ** 2, axis=1))
        return centers

    best = None  # (inertia, labels, centers)
    for _ in range(n_init):
        centers = init_plus_plus()
        for _ in range(max_iter):
            labels, dmin = assign(centers)
            new_centers = centers.copy()
            empty = []
            for c in range(n_clusters):
                members = X[labels == c]
                if members.size == 0:
                    empty.append(c)
                else:
                    new_centers[c] = members.mean(axis=0).astype(np.float32)
            if empty:
                # Re-seed each empty cluster at a distinct, currently worst-served
                # point (largest first; stable sort keeps argmax's tie order) so
                # several empty clusters cannot collapse onto the same location.
                farthest = np.argsort(-dmin, kind='stable')
                for c, idx in zip(empty, farthest):
                    new_centers[c] = X[idx]
            shift = np.sqrt(((centers - new_centers) ** 2).sum(axis=1)).max()
            centers = new_centers
            if shift <= tol:
                break
        labels, dmin = assign(centers)
        inertia = float(dmin.sum())
        # Always keep the first restart so a NaN inertia cannot yield (None, None).
        if best is None or inertia < best[0]:
            best = (inertia, labels, centers)
    return best[1], best[2]
def connected_components_8(mask: np.ndarray):
    """Label the 8-connected components of a 2-D boolean *mask*.

    Uses ``scipy.ndimage`` instead of a per-pixel Python flood fill, which was
    far too slow on multi-megapixel masks.

    Returns:
        ``(labels, bboxes)``: ``labels`` is an ``int32`` array shaped like
        *mask*, 0 for background and ``1..n`` for components in the order
        ``scipy.ndimage.label`` assigns them (raster order of each
        component's first pixel). ``bboxes[k - 1]`` is the inclusive
        ``(min_x, min_y, max_x, max_y)`` box of component ``k``.
    """
    from scipy import ndimage

    # A full 3x3 structuring element gives 8-connectivity (diagonals count).
    labels, _ = ndimage.label(mask, structure=np.ones((3, 3), dtype=bool))
    labels = labels.astype(np.int32, copy=False)
    bboxes = [
        (cols.start, rows.start, cols.stop - 1, rows.stop - 1)
        for rows, cols in ndimage.find_objects(labels)
    ]
    return labels, bboxes
def bbox_contained(inner, outer):
    """Return whether box *inner* lies within box *outer*; shared edges count as inside.

    Both boxes are ``(x1, y1, x2, y2)`` sequences.
    """
    left, top, right, bottom = inner
    o_left, o_top, o_right, o_bottom = outer
    return (o_left <= left) and (o_top <= top) and (right <= o_right) and (bottom <= o_bottom)
def expand_bbox(b, H, W, pad=1):
    """Grow box *b* = ``(x1, y1, x2, y2)`` by *pad* on every side.

    The lower edges are clamped at 0, the right edge at ``W - 1`` and the
    bottom edge at ``H - 1``.
    """
    left, top, right, bottom = b
    grown_left = max(0, left - pad)
    grown_top = max(0, top - pad)
    grown_right = min(W - 1, right + pad)
    grown_bottom = min(H - 1, bottom + pad)
    return (grown_left, grown_top, grown_right, grown_bottom)
def overlap_ratio(a, b):
    """Return the fraction of box *b*'s area that box *a* covers.

    Boxes are ``(x1, y1, x2, y2)`` with area ``(x2 - x1) * (y2 - y1)``.
    Returns 0.0 when the intersection has no positive width or height.
    """
    left, right = max(a[0], b[0]), min(a[2], b[2])
    top, bottom = max(a[1], b[1]), min(a[3], b[3])
    if right <= left or bottom <= top:
        return 0.0
    return (right - left) * (bottom - top) / ((b[2] - b[0]) * (b[3] - b[1]))
def lama_inpaint(model, image, mask, modulo):
    """Fill the masked region of *image* with a LaMa model.

    Args:
        model: LaMa module exposing ``.device``; called with a batch dict
            holding ``'image'`` and ``'mask'`` and expected to return a dict
            with ``'inpainted'``.
        image: RGB input that ``np.array`` turns into an ``H x W x 3`` uint8
            array (assumed — other layouts are not handled; TODO confirm).
        mask: single-channel image; pixels > 127 are passed to the model as
            1s (the region to fill), all others as 0s.
        modulo: spatial multiple the model needs; inputs are padded up to it
            and the output is cropped back to the image size.

    Returns:
        PIL image of the model output at the input size.
    """
    image_t = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0) / 255.
    hole_t = (torch.from_numpy(np.array(mask)) > 127).float()[None, None]
    height, width = image_t.shape[-2:]
    image_t = pad_tensor_to_modulo(image_t, modulo)
    # Zero-pad the mask on the bottom/right up to the next multiple of
    # `modulo`; zeros mark known pixels, so the padding is not a hole.
    mask_h, mask_w = hole_t.shape[-2:]
    hole_t = F.pad(hole_t, (0, -mask_w % modulo, 0, -mask_h % modulo), mode='constant', value=0)
    batch = move_to_device({'image': image_t, 'mask': hole_t}, model.device)
    with torch.no_grad():
        pixels = model(batch)['inpainted'][0].permute(1, 2, 0).detach().cpu().numpy()
    pixels = pixels[:height, :width]
    return Image.fromarray((pixels.clip(0, 1) * 255).astype('uint8'))
def feather(image: Image.Image, gauss_radius=1, band_px=1, strength=1.0) -> Image.Image:
    """Soften the alpha edge of an image.

    The "edge band" is every pixel where a ``(2 * band_px + 1)`` max filter
    and min filter of the alpha channel disagree. Inside the band, colour and
    alpha are replaced by a Gaussian blur computed in premultiplied-alpha
    space, so RGB stored under transparent pixels does not bleed into the
    edge. Pixels outside the band are kept as they are.

    Args:
        image: source image; converted to RGBA first if it has another mode.
        gauss_radius: radius for ``ImageFilter.GaussianBlur``.
        band_px: half-width of the edge band, in pixels.
        strength: blend in the band between the original alpha (0) and the
            blurred alpha (1); clipped to ``[0, 1]``.

    Returns:
        A new RGBA image.
    """
    if image.mode != 'RGBA':
        image = image.convert('RGBA')
    alpha_img = image.getchannel('A')
    k = 2 * int(band_px) + 1  # odd kernel size required by Max/MinFilter
    dilated = np.asarray(alpha_img.filter(ImageFilter.MaxFilter(k)), dtype=np.uint8)
    eroded = np.asarray(alpha_img.filter(ImageFilter.MinFilter(k)), dtype=np.uint8)
    band = (dilated != eroded)[..., None]

    arr = np.asarray(image, dtype=np.float32) / 255.0
    rgb, alpha = arr[..., :3], arr[..., 3:4]
    premult = np.empty(arr.shape, dtype=np.uint8)
    premult[..., :3] = np.clip(rgb * alpha * 255.0, 0, 255).astype(np.uint8)
    premult[..., 3] = (alpha[..., 0] * 255.0 + 0.5).astype(np.uint8)
    blurred = Image.fromarray(premult).filter(ImageFilter.GaussianBlur(gauss_radius))
    blurred = np.asarray(blurred, dtype=np.float32) / 255.0
    rgb_pm_blur, alpha_blur = blurred[..., :3], blurred[..., 3:4]

    # Un-premultiply with the alpha that was blurred together with the colour.
    # Dividing by the strength-mixed alpha (as before) over-brightened edge
    # colours whenever strength < 1.
    rgb_blur = rgb_pm_blur / np.maximum(alpha_blur, 1e-6)
    s = float(np.clip(strength, 0.0, 1.0))
    if s < 1.0:
        alpha_blur = (1.0 - s) * alpha + s * alpha_blur

    out_rgb = np.where(band, rgb_blur, rgb)
    out_alpha = np.where(band, alpha_blur, alpha)
    out = np.concatenate([out_rgb, out_alpha], axis=-1)
    out = (np.clip(out, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    return Image.fromarray(out)
def convert_webp(image: Image.Image) -> str:
    """Write *image* as a lossless WebP to a new temporary file and return its path.

    Encoding happens in memory first, so a failed encode creates no file. The
    file is created with ``delete=False``; removing it is up to the caller.
    """
    with io.BytesIO() as encoded:
        image.save(encoded, format='WEBP', lossless=True, method=6)
        payload = encoded.getvalue()
    with NamedTemporaryFile(delete=False, suffix='.webp') as handle:
        handle.write(payload)
    return handle.name
def _quantize_depth_uniform(depth, n_layers):
    """Quantize *depth* into evenly spaced bins between 0 and its maximum.

    Returns ``(quantized, edges)`` where ``edges[k] = k / (n_layers - 1)`` is
    the lower bound of layer ``k``.
    """
    bins = np.linspace(0, np.max(depth), n_layers + 1)
    # np.digitize puts values equal to the maximum one index past the last
    # interval; those pixels come out above edges[-1] and so still fall into
    # the front layer.
    quantized = np.digitize(depth, bins) - 1
    step = 1 / (n_layers - 1)
    return quantized * step, np.arange(n_layers) * step


def _quantize_depth_kmeans(depth, n_layers):
    """Quantize *depth* with 1-D k-means into at most ``n_layers`` levels.

    Finite pixels are replaced by their cluster centre, and the centres sorted
    ascending become the layer edges. Depth and edges are normalized with the
    same offset and scale, so pixels of the k-th lowest cluster equal
    ``edges[k]`` exactly. Non-finite pixels get the lowest level; previously
    they were set to 0, which skewed the normalization.
    """
    flat = depth.reshape(-1, 1)
    valid = np.isfinite(flat[:, 0])
    labels, centers = kmeans_pp(flat[valid].astype(np.float32), n_clusters=n_layers, n_init=1, max_iter=100, tol=1e-4, random_state=None)
    centers = centers.reshape(-1)
    order = np.argsort(centers)
    rank_of_label = np.empty_like(order)
    rank_of_label[order] = np.arange(order.size)
    levels = centers[order].astype(np.float64)
    low = levels.min()
    scale = levels.max() - low + 1e-8
    quantized = np.full(flat.shape[0], low, dtype=np.float64)
    quantized[valid] = levels[rank_of_label[labels]]
    return ((quantized - low) / scale).reshape(depth.shape), (levels - low) / scale


def _detection_boxes(model, rgb, device):
    """Run YOLO *model* on *rgb*; return one ``xyxy`` array per box of the first result."""
    results = model.predict(source=rgb, conf=0.5, iou=0.45, verbose=False, device=device)
    if len(results) == 0 or results[0].boxes is None or len(results[0].boxes) == 0:
        return []
    return [box.xyxy.detach().cpu().numpy()[0] for box in results[0].boxes]


def _matches_detection(region, boxes, threshold=0.75):
    """Return True if some box in *boxes* overlaps *region* by at least *threshold*.

    Overlap is ``overlap_ratio(larger, smaller)``: the share of the smaller
    box's area (on equal areas, the detection box's) covered by the other.
    """
    region_area = (region[2] - region[0]) * (region[3] - region[1])
    for box in boxes:
        box_area = (box[2] - box[0]) * (box[3] - box[1])
        larger, smaller = (box, region) if box_area > region_area else (region, box)
        if overlap_ratio(larger, smaller) >= threshold:
            return True
    return False


def _front_is_inpaintable(front_bboxes, hand_boxes, person_boxes):
    """Decide whether the front components may be lifted out and filled in behind.

    Returns False when there are no front components, or when any component
    matches a person detection without also matching a hand detection.

    No containment test against the ``depth >= edges[1]`` components is
    needed: the front mask is a subset of that mask, so every front component
    always lies inside one of those components.
    """
    if not front_bboxes:
        return False
    for x1, y1, x2, y2 in front_bboxes:
        # Inclusive pixel bbox -> end-exclusive box so its area counts whole pixels.
        region = np.array([x1, y1, x2 + 1, y2 + 1], dtype=np.int64)
        if not _matches_detection(region, hand_boxes) and _matches_detection(region, person_boxes):
            return False
    return True


def _component_medians(values, labels):
    """Return ``lut`` with ``lut[k] == np.median(values[labels == k])`` for each label k >= 1.

    Uses a single sort instead of one full-image pass per component.
    ``lut[0]`` (and any label with no pixels) is 0.
    """
    n_components = int(labels.max()) if labels.size else 0
    lut = np.zeros(n_components + 1, dtype=np.float64)
    if n_components == 0:
        return lut
    flat_labels = labels.ravel()
    inside = flat_labels > 0
    comp = flat_labels[inside]
    vals = values.ravel()[inside].astype(np.float64)
    order = np.lexsort((vals, comp))  # by component, then by value
    comp, vals = comp[order], vals[order]
    counts = np.bincount(comp, minlength=n_components + 1)[1:]
    starts = np.concatenate(([0], np.cumsum(counts)[:-1]))
    present = counts > 0
    lower = (starts + (counts - 1) // 2)[present]
    upper = (starts + counts // 2)[present]
    # Odd counts: lower == upper. Even counts: mean of the two middle values,
    # exactly as np.median computes it.
    lut[1:][present] = (vals[lower] + vals[upper]) / 2
    return lut


def _flatten_components(depth, edges, keep_front):
    """Replace each 8-connected region at or above ``edges[1]`` by its median depth.

    Pixels below ``edges[1]`` keep their depth. With *keep_front*, pixels at
    or above ``edges[-1]`` are left out of the regions and pinned to
    ``edges[-1]``. Pixels in none of these sets become 0.
    """
    region = depth >= edges[1]
    if keep_front:
        region &= depth < edges[-1]
    labels, _ = connected_components_8(region)
    flattened = _component_medians(depth, labels)[labels]
    background = depth < edges[1]
    flattened[background] = depth[background]
    if keep_front:
        flattened[depth >= edges[-1]] = edges[-1]
    return flattened


def _empty_layer():
    """Return the path of a 1x1 fully transparent WebP used for empty layers."""
    return convert_webp(Image.new('RGBA', (1, 1), (0, 0, 0, 0)))


def _cutout_layer(pixels, mask):
    """Return the path of a feathered RGBA WebP showing *pixels* only where *mask* is set."""
    rgba = np.zeros(mask.shape + (4,), np.uint8)
    rgba[..., :3][mask] = pixels[mask]
    rgba[..., 3][mask] = 255
    return convert_webp(feather(Image.fromarray(rgba)))


def _inpaint(rgb_image, hole_mask, modulo):
    """Run LaMa on *rgb_image*, resizing the result back to the input size if needed."""
    result = lama_inpaint(LAMA, rgb_image, hole_mask, modulo)
    if result.size != rgb_image.size:
        result = result.resize(rgb_image.size, Image.Resampling.BICUBIC)
    return result


def generate_parallax_images(image, n_layers=5, maximum=2048, strategy=None):
    """Split *image* into depth layers for a parallax effect.

    Processing steps:

    1. Resize the image so its longer side is at most *maximum*.
    2. Estimate depth with ``DEPTH_ANYTHING``. Higher values form the front
       layer.
    3. Quantize depth into *n_layers* levels, with ``strategy='k-means'`` or,
       for any other value, with evenly spaced bins.
    4. Flatten each connected region above the background level to its median
       depth.
    5. Optionally inpaint behind the front layer. This happens only when every
       front-most component passes the hand/person check
       (``_front_is_inpaintable``). In that case the front layer keeps its
       depth, and the first non-empty middle layer behind it is filled in with
       LaMa where the front components were.

    Args:
        image: PIL image (converted to RGB).
        n_layers: number of layers; must be at least 2.
        maximum: longest-side limit passed to ``resize_iamge``.
        strategy: ``'k-means'`` for k-means quantization, anything else for
            uniform bins.

    Returns:
        List of temporary ``.webp`` paths, one per layer, ordered from the
        background (index 0) to the front-most layer (last).

        - The background is a full RGB image, inpainted everywhere outside
          the background region.
        - The other layers are feathered RGBA cut-outs.
        - An empty layer is a 1x1 transparent placeholder.

        The files are created with ``delete=False`` by ``convert_webp``; the
        caller must remove them.

    Raises:
        ValueError: if ``n_layers < 2`` (checked before any model runs).
    """
    if n_layers < 2:
        raise ValueError(f'n_layers must be at least 2, got {n_layers}')
    rgb_image = resize_iamge(image.convert('RGB'), maximum)
    rgb = np.asarray(rgb_image)
    # The depth model is fed the channels in reversed (BGR) order.
    depth = DEPTH_ANYTHING.infer_image(rgb[:, :, ::-1])
    if strategy == 'k-means':
        depth, edges = _quantize_depth_kmeans(depth, n_layers)
    else:
        depth, edges = _quantize_depth_uniform(depth, n_layers)

    front_mask = depth >= edges[-1]
    _, front_bboxes = connected_components_8(front_mask)
    device = '0' if torch.cuda.is_available() else 'cpu'
    person_boxes = _detection_boxes(PERSON_YOLO, rgb, device)
    hand_boxes = _detection_boxes(HAND_YOLO, rgb, device)
    need_inpaint = _front_is_inpaintable(front_bboxes, hand_boxes, person_boxes)
    # Every front component is filled in, so the hole is exactly the front mask.
    inpaint_mask = front_mask.copy() if need_inpaint else np.zeros_like(front_mask)
    depth = _flatten_components(depth, edges, keep_front=need_inpaint)
    modulo = LAMA_TRAIN_CFG.get('dataset', {}).get('pad_out_to_modulo', 8)

    # NOTE(review): lama_inpaint thresholds the box-blurred hole masks below
    # at >127. That roughly restores the unblurred mask (corners slightly
    # rounded) rather than growing it. If a dilated hole was intended, a
    # MaxFilter would do that — confirm.
    layers = []  # built front-to-back, reversed before returning
    for i in reversed(range(n_layers)):
        if i == 0:
            background = depth < edges[1]
            if background.any():
                hole = Image.fromarray((~background).astype(np.uint8) * 255).filter(ImageFilter.BoxBlur(16))
                layers.append(convert_webp(_inpaint(rgb_image, hole, modulo)))
            else:
                layers.append(_empty_layer())
            continue
        if i == n_layers - 1:
            mask = depth >= edges[i]
        else:
            mask = (depth >= edges[i]) & (depth < edges[i + 1])
        if not mask.any():
            layers.append(_empty_layer())
        elif need_inpaint and i < n_layers - 1:
            # Only the first non-empty middle layer receives the inpainted fill.
            need_inpaint = False
            hole = Image.fromarray(inpaint_mask.astype(np.uint8) * 255).filter(ImageFilter.BoxBlur(16))
            filled = np.asarray(_inpaint(rgb_image, hole, modulo).convert('RGB'))
            layers.append(_cutout_layer(filled, inpaint_mask | mask))
        else:
            layers.append(_cutout_layer(rgb, mask))
    layers.reverse()
    return layers