Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Upload 4 files
Browse files
app.py
CHANGED
|
@@ -388,12 +388,56 @@ def postprocess_foreground(aligned, target, level=2):
|
|
| 388 |
return np.clip(result, 0, 255).astype(np.uint8)
|
| 389 |
|
| 390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
# ============== Alignment Pipeline ==============
|
| 392 |
|
| 393 |
def align_image(source_img, target_img, pp_level=2):
|
| 394 |
target_h, target_w = target_img.shape[:2]
|
| 395 |
target_size = (target_w, target_h)
|
| 396 |
source_resized = cv2.resize(source_img, target_size, interpolation=cv2.INTER_LANCZOS4)
|
|
|
|
| 397 |
|
| 398 |
kp_src, desc_src = extract_features(source_resized)
|
| 399 |
kp_tgt, desc_tgt = extract_features(target_img)
|
|
@@ -416,11 +460,94 @@ def align_image(source_img, target_img, pp_level=2):
|
|
| 416 |
|
| 417 |
result = full_histogram_matching(aligned, target_img, mask=color_mask)
|
| 418 |
|
| 419 |
-
#
|
| 420 |
-
|
| 421 |
-
|
|
|
|
| 422 |
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
|
| 425 |
|
| 426 |
# ============== FastAPI App ==============
|
|
@@ -468,8 +595,8 @@ async def align_api(
|
|
| 468 |
if source_img is None or target_img is None:
|
| 469 |
raise HTTPException(status_code=400, detail="Failed to decode images")
|
| 470 |
|
| 471 |
-
|
| 472 |
-
png_bytes = encode_image_png(
|
| 473 |
|
| 474 |
return Response(content=png_bytes, media_type="image/png")
|
| 475 |
|
|
@@ -498,8 +625,8 @@ async def align_base64_api(
|
|
| 498 |
if source_img is None or target_img is None:
|
| 499 |
raise HTTPException(status_code=400, detail="Failed to decode images")
|
| 500 |
|
| 501 |
-
|
| 502 |
-
png_bytes = encode_image_png(
|
| 503 |
b64 = base64.b64encode(png_bytes).decode('utf-8')
|
| 504 |
|
| 505 |
return {"image": f"data:image/png;base64,{b64}"}
|
|
@@ -508,6 +635,50 @@ async def align_base64_api(
|
|
| 508 |
raise HTTPException(status_code=500, detail=str(e))
|
| 509 |
|
| 510 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
HTML_CONTENT = """
|
| 512 |
<!DOCTYPE html>
|
| 513 |
<html lang="en">
|
|
@@ -687,8 +858,8 @@ HTML_CONTENT = """
|
|
| 687 |
</div>
|
| 688 |
|
| 689 |
<div class="result" id="result">
|
| 690 |
-
<h2>✨
|
| 691 |
-
<img id="
|
| 692 |
<br>
|
| 693 |
<a id="downloadLink" download="aligned.png">Download Aligned Image</a>
|
| 694 |
</div>
|
|
@@ -762,18 +933,17 @@ console.log(data.image); // data:image/png;base64,...</code></pre>
|
|
| 762 |
formData.append('target', targetFile);
|
| 763 |
formData.append('pp', document.getElementById('ppLevel').value);
|
| 764 |
|
| 765 |
-
const response = await fetch('/api/align', {
|
| 766 |
method: 'POST',
|
| 767 |
body: formData
|
| 768 |
});
|
| 769 |
|
| 770 |
if (!response.ok) throw new Error('Alignment failed');
|
| 771 |
|
| 772 |
-
const
|
| 773 |
-
const url = URL.createObjectURL(blob);
|
| 774 |
|
| 775 |
-
document.getElementById('
|
| 776 |
-
document.getElementById('downloadLink').href =
|
| 777 |
result.classList.add('show');
|
| 778 |
} catch (err) {
|
| 779 |
alert('Error: ' + err.message);
|
|
|
|
| 388 |
return np.clip(result, 0, 255).astype(np.uint8)
|
| 389 |
|
| 390 |
|
| 391 |
+
# ============== Paste-back unedited regions ==============
|
| 392 |
+
|
| 393 |
+
def detect_unedited_mask(aligned, target, threshold=45, min_edit_area=2000,
                         safety_radius=8, blur_size=31):
    """Build a soft mask marking where ``aligned`` already matches ``target``.

    The mask is 1.0 where the two images agree (regions considered unedited)
    and 0.0 inside regions that differ enough to be treated as edits, with a
    Gaussian-blurred transition between them.

    Args:
        aligned: 3-channel image (converted with ``COLOR_BGR2GRAY``) with the
            same shape as ``target``.
        target: 3-channel reference image.
        threshold: Grayscale absolute difference a pixel must exceed to be
            counted as edited.
        min_edit_area: Connected edited regions with fewer pixels than this
            are discarded as noise.
        safety_radius: Radius, in pixels, by which the surviving edited
            regions are grown; ``0`` (or less) skips this step.
        blur_size: Gaussian kernel size used to feather the mask; forced to
            an odd value.

    Returns:
        A 2-D ``float32`` array with the height and width of the inputs.
    """
    # Binary map of pixels whose grayscale difference exceeds the threshold.
    diff = cv2.absdiff(aligned, target)
    diff_gray = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
    _, edited_binary = cv2.threshold(diff_gray, threshold, 255, cv2.THRESH_BINARY)

    # Grow the raw detections slightly, then close gaps so that scattered
    # hits inside one edit merge into a single region.
    grow_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    edited_binary = cv2.dilate(edited_binary, grow_kernel, iterations=1)

    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (31, 31))
    edited_binary = cv2.morphologyEx(edited_binary, cv2.MORPH_CLOSE, close_kernel, iterations=1)

    # Keep only edited components that are large enough. One lookup through
    # a per-label boolean table replaces a full-image scan per component.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(edited_binary, connectivity=8)
    keep = stats[:, cv2.CC_STAT_AREA] >= min_edit_area
    keep[0] = False  # label 0 is the background, never an edit
    edited_binary = keep[labels].astype(np.uint8) * 255

    # Safety margin around every edit so the paste-back never touches its border.
    if safety_radius > 0:
        safety_kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (safety_radius * 2 + 1, safety_radius * 2 + 1))
        edited_binary = cv2.dilate(edited_binary, safety_kernel, iterations=1)

    unedited_binary = 255 - edited_binary

    # Opening removes unedited slivers too small to be worth pasting back.
    open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (41, 41))
    unedited_binary = cv2.morphologyEx(unedited_binary, cv2.MORPH_OPEN, open_kernel, iterations=2)

    # GaussianBlur needs an odd kernel size; setting the low bit guarantees it.
    blur_size = blur_size | 1
    soft_mask = cv2.GaussianBlur(unedited_binary.astype(np.float32) / 255.0,
                                 (blur_size, blur_size), 0)
    return soft_mask
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def paste_unedited_regions(aligned, target, mask):
    """Blend ``target`` over ``aligned`` using ``mask`` as the per-pixel weight.

    Each output pixel is ``target * mask + aligned * (1 - mask)``, so a mask
    value of 1.0 takes the pixel from ``target`` and 0.0 keeps ``aligned``.

    Args:
        aligned: Image to keep where the mask is 0; 2-D (grayscale) or 3-D
            (height, width, channels).
        target: Image to paste where the mask is 1; same shape as ``aligned``.
        mask: Weights, expected in [0, 1]. Either 2-D (height, width) or
            already carrying a channel axis that broadcasts against the images.

    Returns:
        The blended image as ``uint8``, with the shape of the inputs.
    """
    weight = np.asarray(mask)
    # Add a channel axis only when a 2-D mask meets colour images. Always
    # adding it breaks grayscale inputs and masks that already have the axis.
    if weight.ndim == 2 and target.ndim == 3:
        weight = weight[:, :, np.newaxis]

    result = target.astype(np.float32) * weight + aligned.astype(np.float32) * (1.0 - weight)
    return np.clip(result, 0, 255).astype(np.uint8)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
# ============== Alignment Pipeline ==============
|
| 435 |
|
| 436 |
def align_image(source_img, target_img, pp_level=2):
|
| 437 |
target_h, target_w = target_img.shape[:2]
|
| 438 |
target_size = (target_w, target_h)
|
| 439 |
source_resized = cv2.resize(source_img, target_size, interpolation=cv2.INTER_LANCZOS4)
|
| 440 |
+
naive_resized = source_resized.copy()
|
| 441 |
|
| 442 |
kp_src, desc_src = extract_features(source_resized)
|
| 443 |
kp_tgt, desc_tgt = extract_features(target_img)
|
|
|
|
| 460 |
|
| 461 |
result = full_histogram_matching(aligned, target_img, mask=color_mask)
|
| 462 |
|
| 463 |
+
# Paste back unedited regions from target
|
| 464 |
+
pre_paste = result.copy()
|
| 465 |
+
unedited_mask = detect_unedited_mask(result, target_img)
|
| 466 |
+
result = paste_unedited_regions(result, target_img, unedited_mask)
|
| 467 |
|
| 468 |
+
# Post-processing (only affects edited regions, then re-paste)
|
| 469 |
+
pp_result = None
|
| 470 |
+
if pp_level > 0:
|
| 471 |
+
pp_result = postprocess_foreground(result, target_img, level=pp_level)
|
| 472 |
+
pp_result = paste_unedited_regions(pp_result, target_img, unedited_mask)
|
| 473 |
+
|
| 474 |
+
final = pp_result if pp_result is not None else result
|
| 475 |
+
return final, naive_resized, result, pre_paste, unedited_mask, pp_result
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def compute_diff_image(img1, img2, amplify=3.0):
    """Render the amplified grayscale difference of two images as a 3-channel image.

    Both inputs are converted with ``COLOR_BGR2GRAY``; their absolute
    difference is multiplied by ``amplify``, clipped to 0-255 and expanded
    back to three channels so it can be tiled next to colour images.
    """
    luma_a, luma_b = (
        cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
        for frame in (img1, img2)
    )
    boosted = np.abs(luma_a - luma_b) * amplify
    heat = np.clip(boosted, 0, 255).astype(np.uint8)
    return cv2.cvtColor(heat, cv2.COLOR_GRAY2BGR)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def create_visualization_panel(naive_resized, aligned, target, pre_paste=None,
                               unedited_mask=None, postprocessed=None):
    """Assemble a two-row debug panel comparing the pipeline stages to ``target``.

    The top row shows the labelled images (naive resize, aligned+pasted,
    optional post-processed, target, optional unedited mask). The bottom row
    shows the matching difference images against ``target``, with a blank
    tile where no diff applies.

    Args:
        naive_resized: Source image resized to the target size.
        aligned: Aligned image after unedited regions were pasted back.
        target: Reference image; its height and width size the blank tiles.
        pre_paste: Optional aligned image before paste-back; when ``None``
            its diff tile is left blank.
        unedited_mask: Optional 2-D float mask (scaled by 255 for display);
            adds a column when given.
        postprocessed: Optional post-processed image; adds a column when given.

    Returns:
        A single ``uint8`` image with both rows stacked vertically.

    Note:
        Tiles are joined with ``np.hstack``/``np.vstack``, so every image is
        assumed to share the height and width of ``target``.
    """
    h, w = target.shape[:2]
    label_height = 40
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    font_thickness = 2
    font_color = (255, 255, 255)
    bg_color = (40, 40, 40)

    def add_label(img, text):
        # Stack a bar with horizontally centred caption text above the image.
        label_bar = np.full((label_height, img.shape[1], 3), bg_color, dtype=np.uint8)
        text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
        text_x = (img.shape[1] - text_size[0]) // 2
        text_y = (label_height + text_size[1]) // 2
        cv2.putText(label_bar, text, (text_x, text_y), font, font_scale, font_color, font_thickness)
        return np.vstack([label_bar, img])

    # Blank tile, the size of a labelled image, for slots without content.
    empty = np.full((h + label_height, w, 3), bg_color, dtype=np.uint8)
    has_postprocessed = postprocessed is not None

    row1_items = [
        add_label(naive_resized, "Naive Resize"),
        add_label(aligned, "Aligned+Pasted"),
    ]
    if has_postprocessed:
        row1_items.append(add_label(postprocessed, "Post-processed"))
    row1_items.append(add_label(target, "Target Reference"))

    row2_items = [
        add_label(compute_diff_image(naive_resized, target), "Diff: Naive vs Target"),
        add_label(compute_diff_image(pre_paste, target), "Diff: Pre-paste vs Target")
        if pre_paste is not None else empty,
        add_label(compute_diff_image(aligned, target), "Diff: Pasted vs Target"),
    ]
    if has_postprocessed:
        row2_items.append(
            add_label(compute_diff_image(postprocessed, target), "Diff: Post-proc vs Target"))

    # The mask column goes last; it has no diff, so the row below stays blank.
    if unedited_mask is not None:
        mask_vis = (unedited_mask * 255).astype(np.uint8)
        mask_vis_bgr = cv2.cvtColor(mask_vis, cv2.COLOR_GRAY2BGR)
        row1_items.append(add_label(mask_vis_bgr, "Unedited Mask"))
        row2_items.append(empty)

    row1 = np.hstack(row1_items)
    row2 = np.hstack(row2_items)
    return np.vstack([row1, row2])
|
| 551 |
|
| 552 |
|
| 553 |
# ============== FastAPI App ==============
|
|
|
|
| 595 |
if source_img is None or target_img is None:
|
| 596 |
raise HTTPException(status_code=400, detail="Failed to decode images")
|
| 597 |
|
| 598 |
+
final, *_ = align_image(source_img, target_img, pp_level=pp_level)
|
| 599 |
+
png_bytes = encode_image_png(final)
|
| 600 |
|
| 601 |
return Response(content=png_bytes, media_type="image/png")
|
| 602 |
|
|
|
|
| 625 |
if source_img is None or target_img is None:
|
| 626 |
raise HTTPException(status_code=400, detail="Failed to decode images")
|
| 627 |
|
| 628 |
+
final, *_ = align_image(source_img, target_img, pp_level=pp_level)
|
| 629 |
+
png_bytes = encode_image_png(final)
|
| 630 |
b64 = base64.b64encode(png_bytes).decode('utf-8')
|
| 631 |
|
| 632 |
return {"image": f"data:image/png;base64,{b64}"}
|
|
|
|
| 635 |
raise HTTPException(status_code=500, detail=str(e))
|
| 636 |
|
| 637 |
|
| 638 |
+
@app.post("/api/align/viz")
async def align_viz_api(
    source: UploadFile = File(...),
    target: UploadFile = File(...),
    pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)")
):
    """
    Align source image to target and return visualization panel + final result.

    Returns a JSON object with two PNG data URLs: "panel" (the debug
    visualization) and "image" (the final aligned image). Responds with 400
    when either upload cannot be decoded and 500 on any other failure.
    """
    try:
        # Clamp the requested level into the supported 0-3 range.
        pp_level = max(0, min(3, pp))
        source_data = await source.read()
        target_data = await target.read()

        source_img = decode_image(source_data)
        target_img = decode_image(target_data)

        if source_img is None or target_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode images")

        final, naive_resized, pasted, pre_paste, unedited_mask, pp_result = \
            align_image(source_img, target_img, pp_level=pp_level)

        panel = create_visualization_panel(
            naive_resized, pasted, target_img,
            pre_paste=pre_paste,
            unedited_mask=unedited_mask,
            postprocessed=pp_result
        )

        panel_bytes = encode_image_png(panel)
        final_bytes = encode_image_png(final)
        panel_b64 = base64.b64encode(panel_bytes).decode('utf-8')
        final_b64 = base64.b64encode(final_bytes).decode('utf-8')

        return {
            "panel": f"data:image/png;base64,{panel_b64}",
            "image": f"data:image/png;base64,{final_b64}",
        }

    except HTTPException:
        # Let deliberate HTTP errors (the 400 above) through unchanged; the
        # generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
| 680 |
+
|
| 681 |
+
|
| 682 |
HTML_CONTENT = """
|
| 683 |
<!DOCTYPE html>
|
| 684 |
<html lang="en">
|
|
|
|
| 858 |
</div>
|
| 859 |
|
| 860 |
<div class="result" id="result">
|
| 861 |
+
<h2>✨ Visualization</h2>
|
| 862 |
+
<img id="panelImg" src="" style="max-width:100%">
|
| 863 |
<br>
|
| 864 |
<a id="downloadLink" download="aligned.png">Download Aligned Image</a>
|
| 865 |
</div>
|
|
|
|
| 933 |
formData.append('target', targetFile);
|
| 934 |
formData.append('pp', document.getElementById('ppLevel').value);
|
| 935 |
|
| 936 |
+
const response = await fetch('/api/align/viz', {
|
| 937 |
method: 'POST',
|
| 938 |
body: formData
|
| 939 |
});
|
| 940 |
|
| 941 |
if (!response.ok) throw new Error('Alignment failed');
|
| 942 |
|
| 943 |
+
const data = await response.json();
|
|
|
|
| 944 |
|
| 945 |
+
document.getElementById('panelImg').src = data.panel;
|
| 946 |
+
document.getElementById('downloadLink').href = data.image;
|
| 947 |
result.classList.add('show');
|
| 948 |
} catch (err) {
|
| 949 |
alert('Error: ' + err.message);
|