Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,772 Bytes
79c5088 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import torch
import torch.nn.functional as F
import io
import cv2
import numpy as np
from PIL import Image
from rembg import remove
def normalize(
    z_t,
    i,
    max_norm_zs,
):
    """Clip the global norm of ``z_t`` to ``max_norm_zs[i]``.

    A negative entry in ``max_norm_zs`` disables clipping for that index.
    Returns the (possibly rescaled) tensor together with the scale factor
    that was applied (``1`` when no rescaling took place).
    """
    limit = max_norm_zs[i]
    # Negative limit is the "no clipping" sentinel for this step.
    if limit < 0:
        return z_t, 1
    current_norm = torch.norm(z_t)
    # Already within budget — return unchanged.
    if current_norm < limit:
        return z_t, 1
    scale = limit / current_norm
    return z_t * scale, scale
def normalize2(x, dim):
    """Standardize ``x`` to zero mean and unit std along ``dim``."""
    centered = x - x.mean(dim=dim, keepdim=True)
    return centered / x.std(dim=dim, keepdim=True)
def find_lambda_via_newton_batched(Qp, K_source, K_target, max_iter=50, tol=1e-7):
dot_QpK_source = torch.einsum("bcd,bmd->bcm", Qp, K_source) # shape [B]
dot_QpK_target = torch.einsum("bcd,bmd->bcm", Qp, K_target) # shape [B]
X = torch.exp(dot_QpK_source)
lmbd = torch.zeros([1], device=Qp.device, dtype=Qp.dtype) + 0.7
for it in range(max_iter):
y = torch.exp(lmbd * dot_QpK_target)
Z = (X + y).sum(dim=(2), keepdim=True)
x = X / Z
y = y / Z
val = (x.sum(dim=(1,2)) - y.sum(dim=(1,2))).sum()
grad = - (dot_QpK_target * y).sum()
if not (val.abs() > tol and grad.abs() > 1e-12):
break
lmbd = lmbd - val / grad
if lmbd.item() < 0.4:
return 0.1
elif lmbd.item() > 0.9:
return 0.65
return lmbd.item()
def find_lambda_via_super_halley(Qp, K_source, K_target, max_iter=50, tol=1e-7):
dot_QpK_source = torch.einsum("bcd,bmd->bcm", Qp, K_source)
dot_QpK_target = torch.einsum("bcd,bmd->bcm", Qp, K_target)
X = torch.exp(dot_QpK_source)
lmbd = torch.zeros([], device=Qp.device, dtype=Qp.dtype) + 0.8
for it in range(max_iter):
y = torch.exp(lmbd * dot_QpK_target)
Z = (X + y).sum(dim=2, keepdim=True)
x = X / Z
y = y / Z
val = (x.sum(dim=(1,2)) - y.sum(dim=(1,2))).sum()
grad = - (dot_QpK_target * y).sum()
f2 = - (dot_QpK_target**2 * y).sum()
if not (val.abs() > tol and grad.abs() > 1e-12):
break
denom = grad**2 - val * f2
if denom.abs() < 1e-20:
break
update = (val * grad) / denom
lmbd = lmbd - update
print(f"iter={it}, λ={lmbd.item():.6f}, val={val.item():.6e}, grad={grad.item():.6e}")
return lmbd
def find_smallest_key_with_suffix(features_dict: dict, suffix: str = "_1") -> str:
    """Return the key ending in ``suffix`` with the smallest integer prefix.

    The prefix is the text before the first underscore, parsed as an int.
    Keys whose prefix is not numeric are skipped. Returns ``None`` when no
    key qualifies; on ties, the first matching key encountered wins.
    """
    best_key = None
    best_number = float("inf")
    for candidate in features_dict:
        if not candidate.endswith(suffix):
            continue
        try:
            number = int(candidate.split("_")[0])
        except ValueError:
            # Non-numeric prefix — not a candidate.
            continue
        if number < best_number:
            best_number, best_key = number, candidate
    return best_key
def extract_mask(masks, original_width, original_height, canvas_size=512, out_size=64):
    """Rasterize rectangle annotations into a small binary mask.

    Each entry of ``masks`` is a dict with ``"start_point"`` and
    ``"end_point"`` (x, y) corners in original-image coordinates. The
    rectangles are scaled onto a ``canvas_size`` x ``canvas_size`` canvas,
    unioned, binarized, and nearest-neighbour resized to
    ``out_size`` x ``out_size``.

    Args:
        masks: list of rectangle dicts; falsy input returns ``None``.
        original_width: width of the source image the points refer to.
        original_height: height of the source image the points refer to.
        canvas_size: intermediate rasterization resolution (default 512,
            matching the original hard-coded value).
        out_size: side length of the returned mask (default 64).

    Returns:
        A float tensor of shape (out_size, out_size) with values in
        {0.0, 1.0}, or ``None`` when ``masks`` is empty.
    """
    if not masks:
        return None
    combined_mask = torch.zeros(canvas_size, canvas_size)
    scale_x = canvas_size / original_width
    scale_y = canvas_size / original_height
    for mask in masks:
        start_x, start_y = mask["start_point"]
        end_x, end_y = mask["end_point"]
        # Normalize corner order so start <= end regardless of drag direction.
        start_x, end_x = min(start_x, end_x), max(start_x, end_x)
        start_y, end_y = min(start_y, end_y), max(start_y, end_y)
        scaled_start_x, scaled_start_y = int(start_x * scale_x), int(start_y * scale_y)
        scaled_end_x, scaled_end_y = int(end_x * scale_x), int(end_y * scale_y)
        combined_mask[scaled_start_y:scaled_end_y, scaled_start_x:scaled_end_x] += 1
    # Union of all rectangles -> {0, 1} mask.
    binary_mask = (combined_mask > 0).float()
    resized_mask = F.interpolate(binary_mask[None, None, :, :], size=(out_size, out_size), mode="nearest")[0, 0]
    return resized_mask
def remove_foreground(pil_image, threshold=128):
    """Build a 64x64 binary foreground mask for ``pil_image`` via rembg.

    Runs rembg's alpha-matting segmentation on a PNG encoding of the
    image, thresholds the resulting alpha channel at ``threshold``,
    erodes the mask with a 15x15 kernel to pull it off the object
    boundary, and returns it resized to 64x64 (nearest neighbour).
    Returns ``None`` on any failure (best-effort: the error is printed,
    not raised).
    """
    try:
        # Encode to PNG bytes for rembg.
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        segmented_bytes = remove(buffer.getvalue(), alpha_matting=True)
        segmented = Image.open(io.BytesIO(segmented_bytes))
        # The alpha channel of rembg's output marks the foreground.
        alpha = np.array(segmented.split()[-1])
        foreground = (alpha >= threshold).astype(np.float32)
        # Shrink the mask away from the object edge.
        foreground = cv2.erode(foreground, np.ones((15, 15), np.uint8), iterations=1)
        mask_tensor = torch.from_numpy(foreground)
        return F.interpolate(mask_tensor[None, None, :, :], size=(64, 64), mode="nearest")[0, 0]
    except Exception as e:
        print(f"Error while removing foreground: {e}")
        return None