harelcain committed on
Commit
e210426
·
verified ·
1 Parent(s): c211703

Upload 4 files

Browse files
Files changed (1) hide show
  1. app.py +185 -15
app.py CHANGED
@@ -388,12 +388,56 @@ def postprocess_foreground(aligned, target, level=2):
388
  return np.clip(result, 0, 255).astype(np.uint8)
389
 
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  # ============== Alignment Pipeline ==============
392
 
393
  def align_image(source_img, target_img, pp_level=2):
394
  target_h, target_w = target_img.shape[:2]
395
  target_size = (target_w, target_h)
396
  source_resized = cv2.resize(source_img, target_size, interpolation=cv2.INTER_LANCZOS4)
 
397
 
398
  kp_src, desc_src = extract_features(source_resized)
399
  kp_tgt, desc_tgt = extract_features(target_img)
@@ -416,11 +460,94 @@ def align_image(source_img, target_img, pp_level=2):
416
 
417
  result = full_histogram_matching(aligned, target_img, mask=color_mask)
418
 
419
- # Post-processing
420
- if pp_level > 0:
421
- result = postprocess_foreground(result, target_img, level=pp_level)
 
422
 
423
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
 
426
  # ============== FastAPI App ==============
@@ -468,8 +595,8 @@ async def align_api(
468
  if source_img is None or target_img is None:
469
  raise HTTPException(status_code=400, detail="Failed to decode images")
470
 
471
- aligned = align_image(source_img, target_img, pp_level=pp_level)
472
- png_bytes = encode_image_png(aligned)
473
 
474
  return Response(content=png_bytes, media_type="image/png")
475
 
@@ -498,8 +625,8 @@ async def align_base64_api(
498
  if source_img is None or target_img is None:
499
  raise HTTPException(status_code=400, detail="Failed to decode images")
500
 
501
- aligned = align_image(source_img, target_img, pp_level=pp_level)
502
- png_bytes = encode_image_png(aligned)
503
  b64 = base64.b64encode(png_bytes).decode('utf-8')
504
 
505
  return {"image": f"data:image/png;base64,{b64}"}
@@ -508,6 +635,50 @@ async def align_base64_api(
508
  raise HTTPException(status_code=500, detail=str(e))
509
 
510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511
  HTML_CONTENT = """
512
  <!DOCTYPE html>
513
  <html lang="en">
@@ -687,8 +858,8 @@ HTML_CONTENT = """
687
  </div>
688
 
689
  <div class="result" id="result">
690
- <h2>&#10024; Aligned Result</h2>
691
- <img id="resultImg" src="">
692
  <br>
693
  <a id="downloadLink" download="aligned.png">Download Aligned Image</a>
694
  </div>
@@ -762,18 +933,17 @@ console.log(data.image); // data:image/png;base64,...</code></pre>
762
  formData.append('target', targetFile);
763
  formData.append('pp', document.getElementById('ppLevel').value);
764
 
765
- const response = await fetch('/api/align', {
766
  method: 'POST',
767
  body: formData
768
  });
769
 
770
  if (!response.ok) throw new Error('Alignment failed');
771
 
772
- const blob = await response.blob();
773
- const url = URL.createObjectURL(blob);
774
 
775
- document.getElementById('resultImg').src = url;
776
- document.getElementById('downloadLink').href = url;
777
  result.classList.add('show');
778
  } catch (err) {
779
  alert('Error: ' + err.message);
 
388
  return np.clip(result, 0, 255).astype(np.uint8)
389
 
390
 
391
+ # ============== Paste-back unedited regions ==============
392
+
393
def detect_unedited_mask(aligned, target, threshold=45, min_edit_area=2000,
                         safety_radius=8, blur_size=31):
    """Return a soft mask that is ~1.0 where `aligned` matches `target`.

    The absolute difference of the two images is thresholded into an
    "edited" binary map, cleaned up morphologically, inverted into an
    "unedited" map and finally feathered with a Gaussian blur.

    Args:
        aligned: Image to compare against the target. The BGR2GRAY
            conversion below requires a 3-channel image
            (assumed uint8 and same size as `target` -- TODO confirm).
        target: Reference image.
        threshold: Gray-level difference above which a pixel counts as edited.
        min_edit_area: Connected edited regions with fewer pixels than this
            are discarded as noise.
        safety_radius: Radius (pixels) by which surviving edited regions are
            grown before inversion; 0 or less disables the growth.
        blur_size: Gaussian kernel size used for feathering; forced to odd.

    Returns:
        2-D float32 array: the inverted binary map scaled to 0..1 and blurred.
    """
    diff = cv2.absdiff(aligned, target)
    diff_gray = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
    _, edited_binary = cv2.threshold(diff_gray, threshold, 255, cv2.THRESH_BINARY)

    # Grow the raw detections slightly, then bridge gaps between nearby ones.
    grow_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    edited_binary = cv2.dilate(edited_binary, grow_kernel, iterations=1)

    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (31, 31))
    edited_binary = cv2.morphologyEx(edited_binary, cv2.MORPH_CLOSE, close_kernel, iterations=1)

    # Drop small components. A per-label lookup table indexed by the label
    # image filters every component in one pass instead of scanning the
    # whole image once per label.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(edited_binary, connectivity=8)
    keep = stats[:, cv2.CC_STAT_AREA] >= min_edit_area
    keep[0] = False  # label 0 is the background component
    label_to_value = np.where(keep, 255, 0).astype(np.uint8)
    edited_binary = label_to_value[labels]

    if safety_radius > 0:
        safety_kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (safety_radius * 2 + 1, safety_radius * 2 + 1))
        edited_binary = cv2.dilate(edited_binary, safety_kernel, iterations=1)

    unedited_binary = 255 - edited_binary

    # Opening removes thin slivers of "unedited" area.
    open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (41, 41))
    unedited_binary = cv2.morphologyEx(unedited_binary, cv2.MORPH_OPEN, open_kernel, iterations=2)

    blur_size = blur_size | 1  # GaussianBlur needs an odd kernel size
    soft_mask = cv2.GaussianBlur(unedited_binary.astype(np.float32) / 255.0,
                                 (blur_size, blur_size), 0)
    return soft_mask
426
+
427
+
428
def paste_unedited_regions(aligned, target, mask):
    """Alpha-blend `target` over `aligned` using `mask` as per-pixel weight.

    Where `mask` is 1.0 the output takes the target pixel, where it is 0.0
    the aligned pixel; values in between mix the two linearly. The result is
    clipped to 0..255 and returned as uint8.
    """
    # Add a trailing channel axis so the 2-D mask broadcasts over channels.
    alpha = mask[:, :, None]
    from_target = target.astype(np.float32) * alpha
    from_aligned = aligned.astype(np.float32) * (1.0 - alpha)
    blended = from_target + from_aligned
    return blended.clip(0, 255).astype(np.uint8)
432
+
433
+
434
  # ============== Alignment Pipeline ==============
435
 
436
  def align_image(source_img, target_img, pp_level=2):
437
  target_h, target_w = target_img.shape[:2]
438
  target_size = (target_w, target_h)
439
  source_resized = cv2.resize(source_img, target_size, interpolation=cv2.INTER_LANCZOS4)
440
+ naive_resized = source_resized.copy()
441
 
442
  kp_src, desc_src = extract_features(source_resized)
443
  kp_tgt, desc_tgt = extract_features(target_img)
 
460
 
461
  result = full_histogram_matching(aligned, target_img, mask=color_mask)
462
 
463
+ # Paste back unedited regions from target
464
+ pre_paste = result.copy()
465
+ unedited_mask = detect_unedited_mask(result, target_img)
466
+ result = paste_unedited_regions(result, target_img, unedited_mask)
467
 
468
+ # Post-processing (only affects edited regions, then re-paste)
469
+ pp_result = None
470
+ if pp_level > 0:
471
+ pp_result = postprocess_foreground(result, target_img, level=pp_level)
472
+ pp_result = paste_unedited_regions(pp_result, target_img, unedited_mask)
473
+
474
+ final = pp_result if pp_result is not None else result
475
+ return final, naive_resized, result, pre_paste, unedited_mask, pp_result
476
+
477
+
478
def compute_diff_image(img1, img2, amplify=3.0):
    """Visualise the grayscale difference between two BGR images.

    Both inputs are converted to grayscale, their absolute difference is
    scaled by `amplify`, clipped to 0..255 and returned as a 3-channel
    uint8 image.
    """
    def to_gray_f32(img):
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype(np.float32)

    delta = np.abs(to_gray_f32(img1) - to_gray_f32(img2))
    scaled = np.clip(delta * amplify, 0, 255).astype(np.uint8)
    return cv2.cvtColor(scaled, cv2.COLOR_GRAY2BGR)
484
+
485
+
486
def create_visualization_panel(naive_resized, aligned, target, pre_paste=None,
                               unedited_mask=None, postprocessed=None):
    """Compose a two-row debug panel: labelled images on top, diffs below.

    Top row, left to right: naive resize, aligned+pasted result, the
    post-processed result (only if `postprocessed` is given), the target,
    and the unedited mask (only if `unedited_mask` is given).

    Bottom row: the diff of each non-target image against `target`, in the
    same order, with the pre-paste diff in second position (a blank tile
    when `pre_paste` is None) and a blank tile under the mask.

    Args:
        naive_resized: Source image resized to the target size.
        aligned: Aligned image after unedited regions were pasted back.
        target: Reference image; its height/width size the blank tiles.
        pre_paste: Optional aligned image before the paste-back step.
        unedited_mask: Optional 2-D mask, scaled by 255 for display
            (assumed float in 0..1 -- TODO confirm against caller).
        postprocessed: Optional post-processed image.

    Returns:
        A single image with both rows stacked vertically. All inputs are
        assumed to share the target's size and to be 3-channel so the tiles
        can be stacked -- this is not checked here.
    """
    h, w = target.shape[:2]
    label_height = 40
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    font_thickness = 2
    font_color = (255, 255, 255)
    bg_color = (40, 40, 40)

    def add_label(img, text):
        # Stack a caption bar with horizontally centred text above the image.
        label_bar = np.full((label_height, img.shape[1], 3), bg_color, dtype=np.uint8)
        text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
        text_x = (img.shape[1] - text_size[0]) // 2
        text_y = (label_height + text_size[1]) // 2
        cv2.putText(label_bar, text, (text_x, text_y), font, font_scale, font_color, font_thickness)
        return np.vstack([label_bar, img])

    # Blank tile the same size as a labelled image, used to keep rows aligned.
    empty = np.full((h + label_height, w, 3), bg_color, dtype=np.uint8)

    if pre_paste is not None:
        pre_paste_tile = add_label(compute_diff_image(pre_paste, target),
                                   "Diff: Pre-paste vs Target")
    else:
        pre_paste_tile = empty

    row1_items = [
        add_label(naive_resized, "Naive Resize"),
        add_label(aligned, "Aligned+Pasted"),
    ]
    row2_items = [
        add_label(compute_diff_image(naive_resized, target), "Diff: Naive vs Target"),
        pre_paste_tile,
        add_label(compute_diff_image(aligned, target), "Diff: Pasted vs Target"),
    ]

    # The post-processed column is inserted before the target reference.
    if postprocessed is not None:
        row1_items.append(add_label(postprocessed, "Post-processed"))
        row2_items.append(add_label(compute_diff_image(postprocessed, target),
                                    "Diff: Post-proc vs Target"))

    row1_items.append(add_label(target, "Target Reference"))

    if unedited_mask is not None:
        mask_vis = (unedited_mask * 255).astype(np.uint8)
        mask_vis_bgr = cv2.cvtColor(mask_vis, cv2.COLOR_GRAY2BGR)
        row1_items.append(add_label(mask_vis_bgr, "Unedited Mask"))
        row2_items.append(empty)

    row1 = np.hstack(row1_items)
    row2 = np.hstack(row2_items)
    return np.vstack([row1, row2])
551
 
552
 
553
  # ============== FastAPI App ==============
 
595
  if source_img is None or target_img is None:
596
  raise HTTPException(status_code=400, detail="Failed to decode images")
597
 
598
+ final, *_ = align_image(source_img, target_img, pp_level=pp_level)
599
+ png_bytes = encode_image_png(final)
600
 
601
  return Response(content=png_bytes, media_type="image/png")
602
 
 
625
  if source_img is None or target_img is None:
626
  raise HTTPException(status_code=400, detail="Failed to decode images")
627
 
628
+ final, *_ = align_image(source_img, target_img, pp_level=pp_level)
629
+ png_bytes = encode_image_png(final)
630
  b64 = base64.b64encode(png_bytes).decode('utf-8')
631
 
632
  return {"image": f"data:image/png;base64,{b64}"}
 
635
  raise HTTPException(status_code=500, detail=str(e))
636
 
637
 
638
@app.post("/api/align/viz")
async def align_viz_api(
    source: UploadFile = File(...),
    target: UploadFile = File(...),
    pp: int = Form(2, description="Post-processing level 0-3 (0=none, default=2)")
):
    """
    Align source image to target and return visualization panel + final result.

    Returns a JSON object with two PNG data URIs: "panel" (the debug
    visualization) and "image" (the final aligned image).

    Raises HTTPException 400 if either upload cannot be decoded and
    HTTPException 500 for any other failure during processing.
    """
    try:
        # Clamp the requested level into the supported 0-3 range.
        pp_level = max(0, min(3, pp))
        source_data = await source.read()
        target_data = await target.read()

        source_img = decode_image(source_data)
        target_img = decode_image(target_data)

        if source_img is None or target_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode images")

        final, naive_resized, pasted, pre_paste, unedited_mask, pp_result = \
            align_image(source_img, target_img, pp_level=pp_level)

        panel = create_visualization_panel(
            naive_resized, pasted, target_img,
            pre_paste=pre_paste,
            unedited_mask=unedited_mask,
            postprocessed=pp_result
        )

        panel_bytes = encode_image_png(panel)
        final_bytes = encode_image_png(final)
        panel_b64 = base64.b64encode(panel_bytes).decode('utf-8')
        final_b64 = base64.b64encode(final_bytes).decode('utf-8')

        return {
            "panel": f"data:image/png;base64,{panel_b64}",
            "image": f"data:image/png;base64,{final_b64}",
        }

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) through unchanged;
        # the generic handler below would otherwise rewrite them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
680
+
681
+
682
  HTML_CONTENT = """
683
  <!DOCTYPE html>
684
  <html lang="en">
 
858
  </div>
859
 
860
  <div class="result" id="result">
861
+ <h2>&#10024; Visualization</h2>
862
+ <img id="panelImg" src="" style="max-width:100%">
863
  <br>
864
  <a id="downloadLink" download="aligned.png">Download Aligned Image</a>
865
  </div>
 
933
  formData.append('target', targetFile);
934
  formData.append('pp', document.getElementById('ppLevel').value);
935
 
936
+ const response = await fetch('/api/align/viz', {
937
  method: 'POST',
938
  body: formData
939
  });
940
 
941
  if (!response.ok) throw new Error('Alignment failed');
942
 
943
+ const data = await response.json();
 
944
 
945
+ document.getElementById('panelImg').src = data.panel;
946
+ document.getElementById('downloadLink').href = data.image;
947
  result.classList.add('show');
948
  } catch (err) {
949
  alert('Error: ' + err.message);