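"""Layered parallax image generation.

Pipeline (as implemented below): estimate depth with Depth Anything V2, quantize the
depth map into n_layers bands (uniform bins or k-means), use YOLO person/hand detections
to decide whether front-most components may be inpainted away, fill occluded regions with
LaMa, and export each depth band as a feathered, lossless WebP layer.
"""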
import io
import os
from tempfile import NamedTemporaryFile

import numpy as np
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from omegaconf import OmegaConf
from PIL import Image, ImageFilter
from huggingface_hub import hf_hub_download

from depth_anything_v2.dpt import DepthAnythingV2
from ultralytics import YOLO
from simple_lama_inpainting import SimpleLama
from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.evaluation.data import pad_tensor_to_modulo

load_dotenv(verbose=False)

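# Global model setup: Depth Anything V2 (ViT-L) for monocular depth, two YOLOv8 nano
# detectors from Bingsu/adetailer for hands and people, and the big-lama checkpoint
# loaded through saicinpainting for inpainting. Weights pulled from the Hub require
# HF_TOKEN in the environment.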
DEPTH_ANYTHING = DepthAnythingV2(**{'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]})
DEPTH_ANYTHING.load_state_dict(torch.load(hf_hub_download(repo_id='depth-anything/Depth-Anything-V2-Large', filename='depth_anything_v2_vitl.pth', repo_type='model', token=os.environ['HF_TOKEN']), map_location='cpu'))
DEPTH_ANYTHING = DEPTH_ANYTHING.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu').eval()

HAND_YOLO = YOLO(hf_hub_download('Bingsu/adetailer', 'hand_yolov8n.pt', token=os.environ['HF_TOKEN']))
PERSON_YOLO = YOLO(hf_hub_download('Bingsu/adetailer', 'person_yolov8n-seg.pt', token=os.environ['HF_TOKEN']))

LAMA = None
LAMA_TRAIN_CFG = OmegaConf.load('big-lama/config.yaml')
LAMA_TRAIN_CFG['training_model']['predict_only'] = True
LAMA = load_checkpoint(LAMA_TRAIN_CFG, 'big-lama/models/best.ckpt', strict=False, map_location='cpu')
LAMA = LAMA.to('cuda' if torch.cuda.is_available() else 'cpu').eval()

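# Downscale so the longer side is at most `maximum` pixels, preserving aspect ratio;
# images already within the limit are returned unchanged.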
def resize_image(image, maximum=2048, resample=Image.Resampling.LANCZOS):
    width, height = image.size

    if width < height:
        if maximum < height:
            scale = maximum / height
        else:
            return image
    elif maximum < width:
        scale = maximum / width
    else:
        return image

    return image.resize((round(width * scale), round(height * scale)), resample=resample)

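# Minimal NumPy-only k-means with k-means++ seeding (no scikit-learn dependency).
# Returns (labels, centers); empty clusters are re-seeded at the point farthest
# from its assigned center.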
def kmeans_pp(X, n_clusters, n_init=1, max_iter=300, tol=1e-4, random_state=None):
    X = np.asarray(X, dtype=np.float32)
    N, D = X.shape
    n_clusters = min(n_clusters, N)

    rng = np.random.default_rng(random_state)

    def init_plus_plus():
        centers = np.empty((n_clusters, D), dtype=np.float32)
        idx0 = rng.integers(N)
        centers[0] = X[idx0]
        d2 = np.sum((X - centers[0]) ** 2, axis=1)

        for c in range(1, n_clusters):
            s = d2.sum()

            if not np.isfinite(s) or s <= 0:
                idx = rng.integers(N)
            else:
                r = rng.random() * s
                idx = np.searchsorted(np.cumsum(d2), r)

            if idx >= N:
                idx = N - 1

            centers[c] = X[idx]
            d2 = np.minimum(d2, np.sum((X - centers[c]) ** 2, axis=1))

        return centers

    best_inertia = np.inf
    best_labels = None
    best_centers = None

    for _ in range(n_init):
        centers = init_plus_plus()
        labels = np.full(N, -1, dtype=np.int32)

        for _it in range(max_iter):
            dmin = np.full(N, np.inf, dtype=np.float32)

            for c in range(n_clusters):
                d = np.sum((X - centers[c]) ** 2, axis=1)
                better = d < dmin
                labels[better] = c
                dmin[better] = d[better]

            new_centers = centers.copy()
            empty = []

            for c in range(n_clusters):
                pts = X[labels == c]
                if pts.size == 0:
                    empty.append(c)
                else:
                    new_centers[c] = pts.mean(axis=0).astype(np.float32)

            if empty:
                far_idx = np.argmax(dmin)

                for c in empty:
                    new_centers[c] = X[far_idx]

            shift = np.sqrt(((centers - new_centers) ** 2).sum(axis=1)).max()
            centers = new_centers

            if shift <= tol:
                break

        dmin = np.full(N, np.inf, dtype=np.float32)

        for c in range(n_clusters):
            d = np.sum((X - centers[c]) ** 2, axis=1)
            better = d < dmin
            labels[better] = c
            dmin[better] = d[better]

        inertia = float(dmin.sum())

        if inertia < best_inertia:
            best_inertia = inertia
            best_labels = labels.copy()
            best_centers = centers.copy()

    return best_labels, best_centers

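# 8-connected components via a stack-based flood fill. Returns a label map
# (0 = background, components numbered from 1) and one inclusive
# (minx, miny, maxx, maxy) bounding box per component.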
def connected_components_8(mask: np.ndarray):
    H, W = mask.shape
    labels = np.zeros((H, W), dtype=np.int32)
    seen = np.zeros((H, W), dtype=bool)
    nbrs = [(-1, -1), (-1, 0), (-1, 1),
            ( 0, -1),          ( 0, 1),
            ( 1, -1), ( 1, 0), ( 1, 1)]
    comp_id = 0
    bboxes = []

    ys, xs = np.where(mask)

    for y0, x0 in zip(ys, xs):
        if seen[y0, x0]:
            continue

        comp_id += 1
        stack = [(y0, x0)]
        seen[y0, x0] = True
        labels[y0, x0] = comp_id

        minx = maxx = x0
        miny = maxy = y0

        while stack:
            y, x = stack.pop()

            if x < minx: minx = x
            if x > maxx: maxx = x
            if y < miny: miny = y
            if y > maxy: maxy = y

            for dy, dx in nbrs:
                ny, nx = y + dy, x + dx

                if 0 <= ny < H and 0 <= nx < W:
                    if mask[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        labels[ny, nx] = comp_id
                        stack.append((ny, nx))

        bboxes.append((minx, miny, maxx, maxy))

    return labels, bboxes

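# Bounding-box helpers. Boxes are (x1, y1, x2, y2); overlap_ratio returns the
# intersection area as a fraction of the second box's area, so callers pass the
# smaller box as `b`.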
def bbox_contained(inner, outer):
    fx1, fy1, fx2, fy2 = inner
    mx1, my1, mx2, my2 = outer

    return (fx1 >= mx1) and (fy1 >= my1) and (fx2 <= mx2) and (fy2 <= my2)

def expand_bbox(b, H, W, pad=1):
    x1, y1, x2, y2 = b

    return (max(0, x1 - pad), max(0, y1 - pad), min(W - 1, x2 + pad), min(H - 1, y2 + pad))

def overlap_ratio(a, b):
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])

    if ix1 >= ix2 or iy1 >= iy2:
        return 0.0

    inter = (ix2 - ix1) * (iy2 - iy1)
    area = (b[2] - b[0]) * (b[3] - b[1])

    return inter / area

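# Run LaMa on a PIL image/mask pair. Both tensors are padded on the right/bottom to a
# multiple of `modulo` (LaMa's pad_out_to_modulo), and the output is cropped back to
# the original size before conversion to a PIL image.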
def lama_inpaint(model, image, mask, modulo):
    img_t = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0) / 255.
    mask_t = (torch.from_numpy(np.array(mask)) > 127).float().unsqueeze(0).unsqueeze(0)

    orig_h, orig_w = img_t.shape[-2:]

    img_t = pad_tensor_to_modulo(img_t, modulo)

    h, w = mask_t.shape[-2:]
    pad_h = (modulo - h % modulo) % modulo
    pad_w = (modulo - w % modulo) % modulo
    mask_t = F.pad(mask_t, (0, pad_w, 0, pad_h), mode='constant', value=0)

    batch = {'image': img_t, 'mask': mask_t}
    batch = move_to_device(batch, model.device)

    with torch.no_grad():
        result = model(batch)['inpainted'][0].permute(1, 2, 0).detach().cpu().numpy()
        result = result[:orig_h, :orig_w, ...]
        result = (result.clip(0, 1) * 255).astype('uint8')

    return Image.fromarray(result)

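# Soften hard alpha edges: blur the image in premultiplied-alpha space and blend the
# blurred result back only inside a thin band around the alpha boundary (found via
# morphological dilation/erosion of the alpha channel).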
def feather(image: Image.Image, gauss_radius=1, band_px=1, strength=1.0) -> Image.Image:
    A_pil = image.getchannel('A')
    k = 2 * int(band_px) + 1
    a_dil = A_pil.filter(ImageFilter.MaxFilter(k))
    a_ero = A_pil.filter(ImageFilter.MinFilter(k))
    band = np.asarray(a_dil, dtype=np.uint8) != np.asarray(a_ero, dtype=np.uint8)

    arr = np.asarray(image, dtype=np.float32) / 255.0
    A = arr[..., 3:4]
    rgb_pm = arr[..., :3] * A

    pm_rgba_u8 = np.empty(arr.shape, dtype=np.uint8)
    pm_rgba_u8[..., :3] = np.clip(rgb_pm * 255.0, 0, 255).astype(np.uint8)
    pm_rgba_u8[..., 3] = (arr[..., 3] * 255.0 + 0.5).astype(np.uint8)

    blurred = Image.fromarray(pm_rgba_u8, 'RGBA').filter(ImageFilter.GaussianBlur(gauss_radius))
    blurred_f = np.asarray(blurred, dtype=np.float32) / 255.0
    rgb_pm_blur = blurred_f[..., :3]
    A_blur = blurred_f[..., 3:4]

    s = float(np.clip(strength, 0.0, 1.0))

    if s < 1.0:
        A_blur = (1.0 - s) * A + s * A_blur

    eps = 1e-6
    rgb_norm = rgb_pm_blur / np.maximum(A_blur, eps)

    band3 = band[..., None]
    out_rgb = np.where(band3, rgb_norm, arr[..., :3])
    out_A = np.where(band3, A_blur, A)

    out = np.concatenate([out_rgb, out_A], axis=-1)
    out = (np.clip(out, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)

    return Image.fromarray(out, 'RGBA')

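# Encode a layer as lossless WebP into a NamedTemporaryFile and return its path;
# the file is created with delete=False, so cleanup is left to the caller.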
def convert_webp(image: Image.Image) -> str:
    with io.BytesIO() as buffer:
        image.save(buffer, format='WEBP', lossless=True, method=6)
        buffer.seek(0)

        with NamedTemporaryFile(delete=False, suffix='.webp') as file:
            file.write(buffer.read())
            file.flush()

    return file.name

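# Main entry point. Produces n_layers WebP file paths ordered back-to-front:
#   1. resize, estimate depth, quantize the depth map into bands (uniform or k-means),
#   2. decide via YOLO person/hand checks whether the front-most components can be
#      safely inpainted over,
#   3. flatten each connected component to its median depth,
#   4. cut one RGBA layer per band, inpainting holes in the background layer
#      (and, when allowed, behind removable foreground components) with LaMa.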
def generate_parallax_images(image, n_layers=5, maximum=2048, strategy=None):
    global LAMA

    rgb_image = resize_image(image.convert('RGB'), maximum)
    width, height = rgb_image.size
    rgb = np.asarray(rgb_image)

    # Depth Anything V2 expects a BGR array (OpenCV convention), hence the channel flip.
    depth = DEPTH_ANYTHING.infer_image(rgb[:, :, ::-1])

    if strategy == 'k-means':
        n_clusters = n_layers
        x = depth.reshape(-1, 1)
        mask = np.isfinite(x[:, 0])
        labels, centers = kmeans_pp(x[mask].astype(np.float32), n_clusters=n_clusters, n_init=1, max_iter=100, tol=1e-4, random_state=None)
        centers = centers.reshape(-1)
        order = np.argsort(centers)
        rank_of_label = np.empty_like(order)
        rank_of_label[order] = np.arange(n_clusters)
        labels_full = np.full(x.shape[0], -1, dtype=int)
        labels_full[mask] = labels
        levels = centers[order].astype(np.float64)
        quantized_depth = np.zeros(x.shape[0], dtype=np.float32)
        valid_idx = np.where(mask)[0]
        quantized_depth[valid_idx] = levels[rank_of_label[labels_full[valid_idx]]]
        quantized_depth = quantized_depth.reshape(height, width)
        depth = quantized_depth.astype(np.float64)
        depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
        edges = (levels - levels.min()) / (levels.max() - levels.min() + 1e-8)
    else:
        bins = np.linspace(0, np.max(depth), n_layers + 1)
        quantized = np.digitize(depth, bins) - 1
        depth = quantized * (1 / (n_layers - 1))
        edges = np.arange(n_layers) * (1 / (n_layers - 1))

    depth_mod = np.zeros_like(depth, dtype=np.float64)
    front_mask = depth >= edges[-1]

    front_labels, front_bboxes = connected_components_8(front_mask)
    _, near_bboxes = connected_components_8(depth >= edges[1])

    inpaint_mask = np.zeros_like(front_mask, dtype=bool)

    yolo_device = '0' if torch.cuda.is_available() else 'cpu'
    person_results = PERSON_YOLO.predict(source=rgb, conf=0.5, iou=0.45, verbose=False, device=yolo_device)
    hand_results = HAND_YOLO.predict(source=rgb, conf=0.5, iou=0.45, verbose=False, device=yolo_device)
    person_boxes = []
    hand_boxes = []

    if len(person_results) > 0 and person_results[0].boxes is not None and len(person_results[0].boxes) > 0:
        for box in person_results[0].boxes:
            person_boxes.append(box.xyxy.detach().cpu().numpy()[0])

    if len(hand_results) > 0 and hand_results[0].boxes is not None and len(hand_results[0].boxes) > 0:
        for box in hand_results[0].boxes:
            hand_boxes.append(box.xyxy.detach().cpu().numpy()[0])

    if len(front_bboxes) > 0:
        inpaintable_indexes = []

        for fb in front_bboxes:
            contained = any(bbox_contained(fb, mb) for mb in near_bboxes)
            inpaintable = False

            if contained:
                # Inclusive bbox -> exclusive coordinates so it can be compared
                # against the YOLO xyxy boxes.
                fx1, fy1, fx2, fy2 = fb
                fb_exclusive = np.array([fx1, fy1, fx2 + 1, fy2 + 1], dtype=np.int32)
                detected_hand = False

                for xyxy in hand_boxes:
                    area_a = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
                    area_b = (fb_exclusive[2] - fb_exclusive[0]) * (fb_exclusive[3] - fb_exclusive[1])

                    # overlap_ratio is taken relative to the smaller of the two boxes.
                    if area_a > area_b:
                        a, b = xyxy, fb_exclusive
                    else:
                        a, b = fb_exclusive, xyxy

                    if overlap_ratio(a, b) >= 0.75:
                        detected_hand = True
                        break

                if detected_hand:
                    inpaintable = True
                else:
                    detected_person = False

                    for xyxy in person_boxes:
                        area_a = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
                        area_b = (fb_exclusive[2] - fb_exclusive[0]) * (fb_exclusive[3] - fb_exclusive[1])

                        if area_a > area_b:
                            a, b = xyxy, fb_exclusive
                        else:
                            a, b = fb_exclusive, xyxy

                        if overlap_ratio(a, b) >= 0.75:
                            detected_person = True
                            break

                    if not detected_person:
                        inpaintable = True

            inpaintable_indexes.append(inpaintable)

        # Only inpaint when every front-most component can be removed safely.
        need_inpaint = all(inpaintable_indexes)

        if need_inpaint:
            for i in range(1, len(front_bboxes) + 1):
                inpaint_mask |= (front_labels == i)
    else:
        need_inpaint = False

    if need_inpaint:
        hi_labels, _ = connected_components_8((depth >= edges[1]) & (depth < edges[-1]))

        for cid in range(1, hi_labels.max() + 1):
            comp = (hi_labels == cid)
            depth_mod[comp] = np.median(depth[comp])

        keep_mask = (depth < edges[1])
        depth_mod[keep_mask] = depth[keep_mask]
        depth_mod[depth >= edges[-1]] = edges[-1]
    else:
        hi_labels, _ = connected_components_8(depth >= edges[1])

        for cid in range(1, hi_labels.max() + 1):
            comp = (hi_labels == cid)
            depth_mod[comp] = np.median(depth[comp])

        keep_mask = (depth < edges[1])
        depth_mod[keep_mask] = depth[keep_mask]

    depth = depth_mod

    layers = []

    for i in reversed(range(n_layers)):
        if i > 0:
            if i < n_layers - 1:
                mask = (depth >= edges[i]) & (depth < edges[i + 1])

                if rgb[mask].size > 0:
                    if need_inpaint:
                        need_inpaint = False

                        hole_mask = Image.fromarray((inpaint_mask * 255).astype(np.uint8), mode='L').filter(ImageFilter.BoxBlur(16))
                        inpaint_image = lama_inpaint(LAMA, rgb_image, hole_mask, LAMA_TRAIN_CFG.get('dataset', {}).get('pad_out_to_modulo', 8))

                        if inpaint_image.size != (width, height):
                            inpaint_image = inpaint_image.resize((width, height), Image.Resampling.BICUBIC)

                        inpaint = np.asarray(inpaint_image.convert('RGB'))

                        rgba = np.zeros((height, width, 4), np.uint8)
                        rgba[..., :3][inpaint_mask] = inpaint[..., :3][inpaint_mask]
                        rgba[..., 3][inpaint_mask] = 255
                        rgba[..., :3][mask] = inpaint[..., :3][mask]
                        rgba[..., 3][mask] = 255

                        layers.insert(0, convert_webp(feather(Image.fromarray(rgba, 'RGBA'))))
                        continue
                else:
                    layers.insert(0, convert_webp(Image.new('RGBA', (1, 1), (0, 0, 0, 0))))
                    continue
            else:
                mask = (depth >= edges[i])

                if rgb[mask].size == 0:
                    layers.insert(0, convert_webp(Image.new('RGBA', (1, 1), (0, 0, 0, 0))))
                    continue

            rgba = np.zeros((height, width, 4), np.uint8)
            rgba[..., :3][mask] = rgb[mask]
            rgba[..., 3][mask] = 255

            layers.insert(0, convert_webp(feather(Image.fromarray(rgba, 'RGBA'))))
        else:
            mask = (depth < edges[1])

            if rgb[mask].size > 0:
                rgba = np.zeros((height, width, 4), np.uint8)
                rgba[..., :3][mask] = rgb[mask]
                rgba[..., 3][mask] = 255

                mask_image = Image.fromarray(((rgba[..., 3] == 0) * 255).astype(np.uint8), mode='L').filter(ImageFilter.BoxBlur(16))
                inpaint_image = lama_inpaint(LAMA, rgb_image, mask_image, LAMA_TRAIN_CFG.get('dataset', {}).get('pad_out_to_modulo', 8))

                if inpaint_image.size != (width, height):
                    inpaint_image = inpaint_image.resize((width, height), Image.Resampling.BICUBIC)

                layers.insert(0, convert_webp(inpaint_image))
            else:
                layers.insert(0, convert_webp(Image.new('RGBA', (1, 1), (0, 0, 0, 0))))

    return layers

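# Example usage: a minimal sketch only; the input path and the printout are
# illustrative and not part of the pipeline above.
if __name__ == '__main__':
    source = Image.open('input.jpg')

    for path in generate_parallax_images(source, n_layers=5, strategy='k-means'):
        print(path)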