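"""Combine segmentation masks with SAM masks into final binary masks.

For every image whose basename appears in both ``seg_path`` and ``sam_path``,
the two masks are intersected with a bitwise AND, connected components with an
area below 400 pixels are discarded, and the surviving regions are written to
``result_path`` as 0/255 masks under the SAM file's name.
"""
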
import glob
import os

import cv2
import numpy as np
from tqdm import tqdm


def make_final_mask(seg_path, sam_path, result_path):
    """Intersect matching seg/SAM mask pairs and save the filtered results."""
    # Collect candidate mask files for both sources (png/jpg/jpeg).
    seg_patterns = [
        os.path.join(seg_path, '*.png'),
        os.path.join(seg_path, '*.jpg'),
        os.path.join(seg_path, '*.jpeg'),
    ]
    seg_full_path = []
    for pattern in seg_patterns:
        seg_full_path.extend(glob.glob(pattern))
    seg_full_path = sorted(set(seg_full_path))

    sam_patterns = [
        os.path.join(sam_path, '*.jpg'),
        os.path.join(sam_path, '*.png'),
        os.path.join(sam_path, '*.jpeg'),
    ]
    sam_full_path = []
    for pattern in sam_patterns:
        sam_full_path.extend(glob.glob(pattern))
    sam_full_path = sorted(set(sam_full_path))

    print(f"[DEBUG] seg_path: {seg_path}")
    print(f"[DEBUG] Found {len(seg_full_path)} seg images")
    if len(seg_full_path) > 0:
        print(f"[DEBUG] First 3 seg files: {seg_full_path[:3]}")

    print(f"[DEBUG] sam_path: {sam_path}")
    print(f"[DEBUG] Found {len(sam_full_path)} sam images")
    if len(sam_full_path) > 0:
        print(f"[DEBUG] First 3 sam files: {sam_full_path[:3]}")

    if len(seg_full_path) == 0:
        print(f"[ERROR] No seg images found in {seg_path}")
        return

    if len(sam_full_path) == 0:
        print(f"[ERROR] No sam images found in {sam_path}")
        return

    # Index both sets of masks by basename (extension stripped) so that files
    # can be paired even when the two directories use different extensions.
    seg_dict = {}
    for path in seg_full_path:
        basename = os.path.splitext(os.path.basename(path))[0]
        seg_dict[basename] = path

    sam_dict = {}
    for path in sam_full_path:
        basename = os.path.splitext(os.path.basename(path))[0]
        sam_dict[basename] = path

    matched_pairs = []
    for name in seg_dict:
        if name in sam_dict:
            matched_pairs.append((seg_dict[name], sam_dict[name]))

    print(f"[INFO] Found {len(matched_pairs)} matching pairs")

    if len(matched_pairs) == 0:
        print("[ERROR] No matching pairs found!")
        print(f"[DEBUG] Seg basenames: {list(seg_dict.keys())[:5]}")
        print(f"[DEBUG] Sam basenames: {list(sam_dict.keys())[:5]}")
        return

    for seg, sam in tqdm(matched_pairs):
        seg_img = cv2.imread(seg)
        sam_img = cv2.imread(sam)

        if seg_img is None:
            print(f"[WARN] Failed to read seg image: {seg}")
            continue
        if sam_img is None:
            print(f"[WARN] Failed to read sam image: {sam}")
            continue

        # The two masks must have the same size before they can be combined.
        if seg_img.shape != sam_img.shape:
            print(f"[INFO] Resizing sam {sam_img.shape} to match seg {seg_img.shape}")
            sam_img = cv2.resize(sam_img, (seg_img.shape[1], seg_img.shape[0]))

        # Intersect the masks, then drop connected components smaller than
        # 400 pixels to suppress spurious specks.
        img_name = os.path.basename(sam)
        added_img = cv2.bitwise_and(seg_img, sam_img)
        binary_map = cv2.cvtColor(added_img, cv2.COLOR_BGR2GRAY)

        nlabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
            binary_map, None, None, None, 8, cv2.CV_32S  # 8-connectivity, 32-bit labels
        )

        areas = stats[1:, cv2.CC_STAT_AREA]  # stats row 0 is the background
        result = np.zeros(labels.shape, np.uint8)
        for i in range(nlabels - 1):
            if areas[i] >= 400:
                result[labels == i + 1] = 255  # keep component with label i + 1

        output_path = os.path.join(result_path, img_name)
        cv2.imwrite(output_path, result)
        print(f"[INFO] Saved: {output_path}")


if __name__ == '__main__':
    seg_path = '/Users/Admin/ScalpVision/datasets/seg_train'
    sam_path = '/Users/Admin/ScalpVision/prediction/sam_result/sam_val'
    result_path = 'prediction/ensemble_result/ensemble_val'
    os.makedirs(result_path, exist_ok=True)
    make_final_mask(seg_path, sam_path, result_path)
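# Usage sketch (the script filename below is an assumption, not taken from the repo):
#   python make_final_mask.py
# Adjust seg_path, sam_path, and result_path above to point at the mask
# directories you want to combine.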