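"""Artifact-removal pipeline step.

Flattens each RGBA "original" image onto a white background, segments detected
artifacts (jewelry, watches, glasses, ...) with SAM, and removes them with LaMa
inpainting, falling back to CPU inference or a plain white fill when the
TorchScript inpainting model fails.
"""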

import logging
import os
from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForMaskGeneration
from simple_lama_inpainting import SimpleLama


SAM_REPO = "facebook/sam-vit-huge"

SAM_FULL_PRECISION = True
LAMA_FULL_PRECISION = True

SAM_PROCESSOR = None
SAM_MODEL = None
SIMPLE_LAMA = None


def initialize_sam_and_lama(device="cuda"):
    global SAM_PROCESSOR, SAM_MODEL, SIMPLE_LAMA

    if SAM_PROCESSOR is None or SAM_MODEL is None or SIMPLE_LAMA is None:
        logging.info("Loading SAM model...")
        SAM_PROCESSOR = AutoProcessor.from_pretrained(SAM_REPO)
        SAM_MODEL = load_sam_model(SAM_REPO, SAM_FULL_PRECISION)

        logging.info("Loading LaMa inpainting model...")
        lama_device = "cpu"
        logging.info("LAMA will use CPU - this is intentional for compatibility")
        SIMPLE_LAMA = SimpleLama(device=lama_device)
        logging.info(f"Successfully loaded LAMA model on {lama_device.upper()}")


def load_sam_model(repo_id: str, full_precision: bool):
    try:
        torch.cuda.empty_cache()

        model = AutoModelForMaskGeneration.from_pretrained(
            repo_id,
            device_map="auto",
            torch_dtype=torch.float32 if full_precision else torch.float16
        )

        if not hasattr(model, 'hf_device_map'):
            model = model.cuda()
            if not full_precision:
                model = model.half()

        model.eval()

        with torch.no_grad():
            logging.info("Verifying SAM model is on CUDA")
            param = next(model.parameters())
            if not param.is_cuda:
                model = model.cuda()
                param = next(model.parameters())
                logging.warning("Forced SAM model to CUDA")
            logging.info(f"SAM model device: {param.device}")

        return model
    except Exception as e:
        logging.error(f"Failed to load SAM model: {e}")
        raise


# Detection labels that are treated as removable artifacts.
ARTIFACTS_LIST = ["jewelry", "necklace", "bracelet", "ring", "earrings", "watch", "glasses"]


# `pipeline_step` and `ProcessingContext` are provided by the surrounding pipeline
# framework; they are not defined in this module.
@pipeline_step
def remove_object_batch(contexts: List[ProcessingContext], batch_logs: List[dict]) -> None:
    initialize_sam_and_lama()

    logging.info(f"[DEBUG] remove_object_batch => Starting with {len(contexts)} contexts.")

    for ctx_idx, ctx in enumerate(contexts):

        step_log = {
            "function": "remove_object_batch",
            "image_url": getattr(ctx, "url", "unknown"),
            "status": None,
            "artifacts_found": [],
            "image_dimensions": None,
            "artifact_boxes": []
        }

        if ctx.skip_run or ctx.skip_processing:
            step_log["status"] = "skipped"
            batch_logs.append(step_log)
            continue

        if "original" not in ctx.pil_img:
            logging.debug(f"(Context #{ctx_idx}) => RBC 'original' missing => {ctx.url}")
            step_log["status"] = "error"
            step_log["exception"] = "No RBC 'original' in ctx"
            ctx.skip_run = True
            batch_logs.append(step_log)
            continue

        dr = ctx.detection_result
        if not dr or dr.get("status") != "ok":
            logging.debug(f"(Context #{ctx_idx}) => No valid detection => {ctx.url}")
            step_log["status"] = "no_detection"
            batch_logs.append(step_log)
            continue

        boxes = dr.get("boxes", [])
        kws = dr.get("final_keywords", [])
        if len(boxes) != len(kws) or not boxes:
            logging.debug(f"(Context #{ctx_idx}) => mismatch or no boxes => {ctx.url}")
            step_log["status"] = "no_boxes_in_detection"
            batch_logs.append(step_log)
            continue

        artifact_indices = [i for i, kw_ in enumerate(kws) if kw_ in ARTIFACTS_LIST]
        if not artifact_indices:
            logging.debug(f"(Context #{ctx_idx}) => No artifacts found => {ctx.url}. Skipping flatten.")
            step_log["status"] = "no_artifacts_found"
            batch_logs.append(step_log)
            continue

        pil_rgba, orig_fname, _ = ctx.pil_img["original"]
        logging.debug(f"(Context #{ctx_idx}) Flattening RBC image to white background (since artifacts exist).")

        flattened = Image.new("RGB", pil_rgba.size, (255, 255, 255))
        flattened.paste(pil_rgba.convert("RGB"), mask=pil_rgba.getchannel("A"))
        logging.debug(f"(Context #{ctx_idx}) Background conversion complete.")

        updated_img = flattened
        found_labels = []

        for art_i in artifact_indices:
            box_ = boxes[art_i]
            kw_ = kws[art_i]
            step_log["artifact_boxes"].append({
                "original_box": box_,
                "label": kw_
            })

            w_img, h_img = updated_img.size
            expanded = expand_bbox(box_, w_img, h_img, pad=24)
            logging.debug(f"(Context #{ctx_idx}) Artifact {art_i}: Expanded box from {box_} to {expanded}.")
            step_log["artifact_boxes"][-1]["expanded_box"] = expanded

            logging.debug(f"(Context #{ctx_idx}) Removing object in region {expanded}.")
            try:
                updated_img = remove_object_inplace(
                    updated_img,
                    expanded,
                    SAM_PROCESSOR,
                    SAM_MODEL,
                    SIMPLE_LAMA,
                    device="cuda",
                    debug_save_prefix=f"dbg_ctx{ctx_idx}_artifact{art_i}",
                    dilate_mask=True,
                    dilate_kernel_size=40
                )
                logging.debug(f"(Context #{ctx_idx}) Object removal complete for artifact {art_i}.")
                found_labels.append(kw_)
            except RuntimeError as re_err:
                logging.warning(f"[WARNING] TorchScript inpainting failed for artifact {art_i}, skipping removal.\n{re_err}")
                step_log["artifact_boxes"][-1]["skipped_inpainting"] = True

        ctx.pil_img["original"] = [updated_img, orig_fname, None]
        step_log["artifacts_found"] = found_labels
        step_log["status"] = "artifacts_removed"
        step_log["image_dimensions"] = (updated_img.width, updated_img.height)
        logging.debug(f"(Context #{ctx_idx}) => Artifacts removed => {ctx.url}")
        batch_logs.append(step_log)

    logging.debug("[DEBUG] remove_object_batch => Finished.\n")


def expand_bbox(box, w, h, pad=24):
    # Pad the detection box and clamp it to the image bounds; return ints so the
    # result can be used both for PIL cropping and for NumPy mask slicing.
    x1, y1, x2, y2 = box
    expanded_box = [
        int(max(0, x1 - pad)),
        int(max(0, y1 - pad)),
        int(min(w, x2 + pad)),
        int(min(h, y2 + pad))
    ]
    logging.debug(f"expand_bbox => Original: {box}, Expanded: {expanded_box}")
    return expanded_box


def remove_object_inplace(
    pil_rgb: Image.Image,
    bbox: List[int],
    sam_processor,
    sam_model,
    lama_model_jit,
    device="cuda",
    debug_save_prefix=None,
    dilate_mask=False,
    dilate_kernel_size=15
) -> Image.Image:
    logging.debug(f"remove_object_inplace => Processing bbox {bbox} on image size {pil_rgb.size}")

    image_rgb = pil_rgb.convert("RGB")

    inputs = sam_processor(
        images=image_rgb,
        input_boxes=[[[bbox[0], bbox[1], bbox[2], bbox[3]]]],
        return_tensors="pt"
    ).to(device)

    if not SAM_FULL_PRECISION and sam_model.dtype == torch.float16:
        inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

    with torch.no_grad():
        out_sam = sam_model(**inputs)

    # Map the low-resolution SAM masks back to the original image resolution
    # (standard SAM post-processing) so the binary mask lines up with `bbox`
    # when it is cropped below.
    post_masks = sam_processor.image_processor.post_process_masks(
        out_sam.pred_masks.float().cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    # post_masks[0] has shape (num_boxes, num_masks, H, W); keep the first mask of the first box.
    mask_bin = post_masks[0][0, 0].numpy().astype(np.uint8)

    if dilate_mask:
        kernel = np.ones((dilate_kernel_size, dilate_kernel_size), dtype=np.uint8)
        mask_bin = cv2.dilate(mask_bin, kernel, iterations=1)
        logging.debug(f"remove_object_inplace => Dilated mask mean: {mask_bin.mean():.6f}")

    updated_crop = inpaint_region_with_lama_multi_fallback(
        image_rgb,
        mask_bin,
        bbox,
        lama_model_jit
    )

    logging.debug(f"remove_object_inplace => Inpainting complete for bbox {bbox}")
    return updated_crop


def inpaint_region_with_lama_multi_fallback(
    image_rgb: Image.Image,
    mask_bin: np.ndarray,
    bbox: List[int],
    lama_model_jit
) -> Image.Image:
    x1, y1, x2, y2 = bbox
    subregion = image_rgb.crop((x1, y1, x2, y2))
    mask_sub = mask_bin[y1:y2, x1:x2].copy()
    orig_w, orig_h = subregion.size
    logging.debug(f"inpaint_region_with_lama_multi_fallback => Cropped region: w={orig_w}, h={orig_h}")

    if orig_w < 2 or orig_h < 2:
        logging.warning("Subregion too small for inpainting. Filling with white instead.")
        return fill_white(image_rgb, bbox)

    # Downscale large regions so the longest side is at most 1024 px before inpainting.
    max_dim = max(orig_w, orig_h)
    target_size = 1024
    scale = 1.0
    if max_dim > target_size:
        scale = target_size / float(max_dim)
        new_w = max(1, int(round(orig_w * scale)))
        new_h = max(1, int(round(orig_h * scale)))
        subregion = subregion.resize((new_w, new_h), Image.Resampling.LANCZOS)
        mask_sub = cv2.resize(mask_sub, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        logging.debug(f"inpaint_region_with_lama_multi_fallback => scaled to {new_w}x{new_h} (factor={scale:.3f})")
    else:
        new_w, new_h = orig_w, orig_h

    # Pad the region so both spatial dimensions are multiples of 32 before running LaMa.
    pad_w = (32 - (new_w % 32)) % 32
    pad_h = (32 - (new_h % 32)) % 32
    logging.debug(f"inpaint_region_with_lama_multi_fallback => pad_w={pad_w}, pad_h={pad_h}")

    sub_tensor = (
        torch.from_numpy(np.array(subregion))
        .permute(2, 0, 1)
        .unsqueeze(0)
        .float() / 255.0
    )
    mask_tensor = torch.from_numpy(mask_sub.astype(np.float32)).unsqueeze(0).unsqueeze(0)

    # Temporarily swap reflection padding for replication padding while the
    # TorchScript LaMa graph runs; the originals are restored in the finally block.
    original_f_pad = F.pad
    original_torch_pad = getattr(torch, "pad", None)
    original_reflection = None
    if hasattr(torch._C._nn, "reflection_pad2d"):
        original_reflection = torch._C._nn.reflection_pad2d

    def custom_f_pad(inp, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_f_pad(inp, pad_vals, mode=mode, value=value)

    def custom_torch_pad(inp, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_torch_pad(inp, pad_vals, mode=mode, value=value)

    def replicate_pad2d(*args, **kwargs):
        return F.replication_pad2d(*args, **kwargs)

    try:
        F.pad = custom_f_pad
        if original_torch_pad is not None:
            torch.pad = custom_torch_pad
        if original_reflection is not None:
            torch._C._nn.reflection_pad2d = replicate_pad2d

        sub_tensor_padded = F.pad(sub_tensor, (0, pad_w, 0, pad_h), mode='reflect')
        mask_tensor_padded = F.pad(mask_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0)

        result_tensor = None
        try:
            with torch.no_grad():
                sub_tensor_gpu = sub_tensor_padded.to("cuda")
                mask_tensor_gpu = mask_tensor_padded.to("cuda")
                result_tensor = lama_model_jit.model.forward(sub_tensor_gpu, mask_tensor_gpu)
        except RuntimeError as re_gpu:
            logging.warning(f"TorchScript GPU inpainting failed => {re_gpu}\nAttempting CPU fallback...")
            try:
                result_tensor = inpaint_torchscript_cpu_fallback(sub_tensor_padded, mask_tensor_padded, lama_model_jit)
            except RuntimeError as re_cpu:
                logging.warning(f"TorchScript CPU fallback also failed => {re_cpu}\nFilling with white region.")
                return fill_white(image_rgb, bbox)

    finally:
        F.pad = original_f_pad
        if original_torch_pad is not None:
            torch.pad = original_torch_pad
        if original_reflection is not None:
            torch._C._nn.reflection_pad2d = original_reflection

    if result_tensor is None:
        logging.warning("Result is None after fallback => filling with white region.")
        return fill_white(image_rgb, bbox)

    result_tensor_cropped = result_tensor[:, :, :new_h, :new_w]
    out_np = (
        result_tensor_cropped.squeeze(0)
        .permute(1, 2, 0)
        .mul(255)
        .clamp(0, 255)
        .byte()
        .cpu()
        .numpy()
    )
    inpainted_pil = Image.fromarray(out_np)

    if scale != 1.0:
        inpainted_pil = inpainted_pil.resize((orig_w, orig_h), Image.Resampling.LANCZOS)

    final_sub = Image.new("RGB", (orig_w, orig_h), (255, 255, 255))
    final_sub.paste(inpainted_pil, (0, 0))
    out_img = image_rgb.copy()
    out_img.paste(final_sub, (x1, y1))
    logging.debug(f"inpaint_region_with_lama_multi_fallback => done for region {bbox}")
    return out_img


def inpaint_torchscript_cpu_fallback(
    sub_tensor_padded: torch.Tensor,
    mask_tensor_padded: torch.Tensor,
    lama_model_jit
) -> torch.Tensor:
    # Move the TorchScript LaMa model to CPU, run it there, then restore its device.
    orig_device = next(lama_model_jit.model.parameters()).device
    lama_model_jit.model.to("cpu")
    sub_cpu = sub_tensor_padded.cpu()
    mask_cpu = mask_tensor_padded.cpu()
    with torch.no_grad():
        result_cpu = lama_model_jit.model.forward(sub_cpu, mask_cpu)
    lama_model_jit.model.to(orig_device)
    return result_cpu


def fill_white(image_rgb: Image.Image, bbox: List[int]) -> Image.Image:
    x1, y1, x2, y2 = bbox
    ret_img = image_rgb.copy()
    draw = ImageDraw.Draw(ret_img)
    draw.rectangle([x1, y1, x2, y2], fill=(255, 255, 255))
    return ret_img


def inpaint_region_with_lama_gpu_only(
    image_rgb: Image.Image,
    mask_bin: np.ndarray,
    bbox: List[int],
    lama_model,
    debug_save_prefix: Optional[str] = None
) -> Image.Image:
    x1, y1, x2, y2 = bbox
    subregion = image_rgb.crop((x1, y1, x2, y2))
    mask_sub = mask_bin[y1:y2, x1:x2].copy()
    orig_w, orig_h = subregion.size
    if orig_w < 2 or orig_h < 2:
        return image_rgb

    target_size = 1024
    scale = 1.0
    max_dim = max(orig_w, orig_h)
    if max_dim > target_size:
        scale = target_size / float(max_dim)
        new_w = max(1, int(round(orig_w * scale)))
        new_h = max(1, int(round(orig_h * scale)))
        subregion = subregion.resize((new_w, new_h), Image.Resampling.LANCZOS)
        mask_sub = cv2.resize(mask_sub, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    else:
        new_w, new_h = orig_w, orig_h

    pad_w = (32 - (new_w % 32)) % 32
    pad_h = (32 - (new_h % 32)) % 32
    sub_np = np.array(subregion)
    sub_tensor = (
        torch.from_numpy(sub_np)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .float()
        .to("cuda")
        / 255.0
    ).contiguous()

    mask_tensor = (
        torch.from_numpy((mask_sub > 0).astype(np.float32))
        .unsqueeze(0)
        .unsqueeze(0)
        .float()
        .to("cuda")
    ).contiguous()

    original_F_pad = F.pad
    original_torch_pad = getattr(torch, "pad", None)

    def custom_F_pad(input, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_F_pad(input, pad_vals, mode=mode, value=value)

    def custom_torch_pad(input, pad_vals, mode="constant", value=0):
        if mode == "reflect":
            mode = "replicate"
        return original_torch_pad(input, pad_vals, mode=mode, value=value)

    original_reflection_pad2d = None
    if hasattr(torch._C._nn, 'reflection_pad2d'):
        original_reflection_pad2d = torch._C._nn.reflection_pad2d

        def no_reflection_pad2d(*args, **kwargs):
            return F.replication_pad2d(*args, **kwargs)

    try:
        F.pad = custom_F_pad
        if original_torch_pad is not None:
            torch.pad = custom_torch_pad
        if original_reflection_pad2d is not None:
            torch._C._nn.reflection_pad2d = no_reflection_pad2d

        sub_tensor_padded = F.pad(sub_tensor, (0, pad_w, 0, pad_h), mode='reflect')
        mask_tensor_padded = F.pad(mask_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0)

        try:
            with torch.no_grad():
                result_tensor = lama_model.model.forward(sub_tensor_padded, mask_tensor_padded)
        except RuntimeError as e:
            logging.warning(f"GPU inpainting failed => {e}\nFalling back to CPU.")
            result_tensor = run_lama_on_cpu_fallback(
                sub_tensor_padded.cpu(),
                mask_tensor_padded.cpu(),
                lama_model
            )

    finally:
        F.pad = original_F_pad
        if original_torch_pad is not None:
            torch.pad = original_torch_pad
        if original_reflection_pad2d is not None:
            torch._C._nn.reflection_pad2d = original_reflection_pad2d

    result_tensor_cropped = result_tensor[:, :, :new_h, :new_w]
    result_np = (
        result_tensor_cropped.squeeze(0)
        .permute(1, 2, 0)
        .mul(255)
        .clamp(0, 255)
        .cpu()
        .numpy()
        .astype(np.uint8)
    )
    inpainted_pil = Image.fromarray(result_np)

    if scale != 1.0:
        inpainted_pil = inpainted_pil.resize((orig_w, orig_h), Image.Resampling.LANCZOS)

    final_sub = Image.new("RGB", (orig_w, orig_h), (255, 255, 255))
    final_sub.paste(inpainted_pil, (0, 0))
    out_img = image_rgb.copy()
    out_img.paste(final_sub, (x1, y1))
    torch.cuda.empty_cache()
    return out_img.convert("RGB")


def run_lama_on_cpu_fallback(
    sub_tensor_padded_cpu: torch.Tensor,
    mask_tensor_padded_cpu: torch.Tensor,
    lama_model
) -> torch.Tensor:
    with torch.no_grad():
        orig_device = next(lama_model.model.parameters()).device
        lama_model.model.to("cpu")
        result = lama_model.model.forward(sub_tensor_padded_cpu, mask_tensor_padded_cpu)
        lama_model.model.to(orig_device)
        return result
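

# Minimal usage sketch (not part of the original pipeline): run the single-region
# removal path directly on one image. "example.jpg" and the box coordinates below
# are placeholders standing in for a real file and a real detector output, and a
# CUDA device is assumed to be available.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    initialize_sam_and_lama()

    demo_img = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    demo_box = expand_bbox([120, 200, 260, 340], demo_img.width, demo_img.height, pad=24)
    cleaned = remove_object_inplace(
        demo_img,
        demo_box,
        SAM_PROCESSOR,
        SAM_MODEL,
        SIMPLE_LAMA,
        device="cuda",
        dilate_mask=True,
        dilate_kernel_size=40,
    )
    cleaned.save("example_cleaned.jpg")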