# full/segmentation/make_final_mask.py
# Uploaded by caubetotbunggg ("Upload folder using huggingface_hub"), commit 7bf5a8e (verified).
import cv2
import glob
from tqdm import tqdm
import numpy as np
import os
def _collect_images(folder, extensions=('.png', '.jpg', '.jpeg')):
    """Return the sorted paths of the image files directly inside *folder*.

    Extensions are matched case-insensitively (``.PNG`` counts as ``.png``).
    Hidden files such as macOS ``._name.png`` resource forks are skipped, as
    ``glob('*.png')`` would skip them. A missing folder yields ``[]``.
    """
    if not os.path.isdir(folder):
        return []
    wanted = tuple(ext.lower() for ext in extensions)
    paths = []
    for name in os.listdir(folder):
        if name.startswith('.') or not name.lower().endswith(wanted):
            continue
        full = os.path.join(folder, name)
        if os.path.isfile(full):
            paths.append(full)
    return sorted(paths)


def _index_by_stem(paths):
    """Map each file name without its extension to its path.

    If two files share a stem (e.g. ``a.png`` and ``a.jpg``), the one that
    comes last in *paths* wins.
    """
    return {os.path.splitext(os.path.basename(p))[0]: p for p in paths}


def _fuse_masks(seg_img, sam_img, min_area, threshold):
    """Intersect two equally-sized grayscale masks and drop small blobs.

    A pixel is foreground in a mask when its value is greater than
    ``threshold``. Connected components (8-connectivity) of the intersection
    with fewer than ``min_area`` pixels are removed. Returns a uint8 image in
    which kept pixels are 255 and all others are 0.
    """
    # Logical AND of the two foregrounds. A bitwise AND of raw intensities is
    # not a mask intersection: e.g. 128 & 127 == 0 and 1 & 254 == 0, which
    # punches holes into JPEG-compressed or non-0/255 masks.
    overlap = ((seg_img > threshold) & (sam_img > threshold)).astype(np.uint8)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        overlap, connectivity=8, ltype=cv2.CV_32S
    )
    # keep[k] says whether component k survives; label 0 is the background.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False
    # One lookup-table index instead of a full-image comparison per component.
    return np.where(keep[labels], 255, 0).astype(np.uint8)


def make_final_mask(seg_path, sam_path, result_path, *, min_area=400, threshold=0):
    """Fuse segmentation masks with SAM masks and write the cleaned result.

    Masks in ``seg_path`` and ``sam_path`` are paired by file name without
    extension, so ``x.png`` pairs with ``x.jpg``. For every pair the steps are:
    intersect the two foregrounds, remove connected components smaller than
    ``min_area`` pixels, and save the result to ``result_path``. The output is
    a single-channel 0/255 image named after the SAM file.

    Args:
        seg_path: Folder with the segmentation masks (.png/.jpg/.jpeg, any case).
        sam_path: Folder with the SAM masks (.png/.jpg/.jpeg, any case).
        result_path: Output folder; created if needed.
        min_area: Minimum component size in pixels to keep (default 400).
        threshold: A pixel counts as foreground when its grayscale value is
            greater than this (default 0, i.e. any nonzero pixel). Raising it,
            e.g. to 127, suppresses JPEG compression noise.

    Returns:
        The number of fused masks successfully written (0 if nothing matched).
    """
    seg_full_path = _collect_images(seg_path)
    sam_full_path = _collect_images(sam_path)

    # DEBUG
    print(f"[DEBUG] seg_path: {seg_path}")
    print(f"[DEBUG] Found {len(seg_full_path)} seg images")
    if seg_full_path:
        print(f"[DEBUG] First 3 seg files: {seg_full_path[:3]}")
    print(f"[DEBUG] sam_path: {sam_path}")
    print(f"[DEBUG] Found {len(sam_full_path)} sam images")
    if sam_full_path:
        print(f"[DEBUG] First 3 sam files: {sam_full_path[:3]}")

    if not seg_full_path:
        print(f"[ERROR] No seg images found in {seg_path}")
        return 0
    if not sam_full_path:
        print(f"[ERROR] No sam images found in {sam_path}")
        return 0

    # Match by filename (without extension), in sorted seg order.
    seg_dict = _index_by_stem(seg_full_path)
    sam_dict = _index_by_stem(sam_full_path)
    matched_pairs = [
        (seg_dict[name], sam_dict[name]) for name in seg_dict if name in sam_dict
    ]
    print(f"[INFO] Found {len(matched_pairs)} matching pairs")
    if not matched_pairs:
        print("[ERROR] No matching pairs found!")
        print(f"[DEBUG] Seg basenames: {list(seg_dict)[:5]}")
        print(f"[DEBUG] Sam basenames: {list(sam_dict)[:5]}")
        return 0

    # Without an existing output folder cv2.imwrite just returns False.
    os.makedirs(result_path, exist_ok=True)

    saved = 0
    for seg, sam in tqdm(matched_pairs):
        seg_img = cv2.imread(seg, cv2.IMREAD_GRAYSCALE)
        sam_img = cv2.imread(sam, cv2.IMREAD_GRAYSCALE)
        if seg_img is None:
            print(f"[WARN] Failed to read seg image: {seg}")
            continue
        if sam_img is None:
            print(f"[WARN] Failed to read sam image: {sam}")
            continue
        if seg_img.shape != sam_img.shape:
            print(f"[INFO] Resizing sam {sam_img.shape} to match seg {seg_img.shape}")
            # Nearest-neighbour keeps mask values intact; bilinear would blend edges.
            sam_img = cv2.resize(
                sam_img,
                (seg_img.shape[1], seg_img.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
        result = _fuse_masks(seg_img, sam_img, min_area, threshold)
        # NOTE(review): the output keeps the SAM file's name, so .jpg SAM masks
        # produce lossy JPEG outputs; consider writing .png instead.
        output_path = os.path.join(result_path, os.path.basename(sam))
        if not cv2.imwrite(output_path, result):
            print(f"[WARN] Failed to write: {output_path}")
            continue
        saved += 1
        print(f"[INFO] Saved: {output_path}")
    return saved
if __name__ == '__main__':
    # Command-line entry point. The defaults reproduce the original hard-coded
    # paths; each can be overridden without editing the file.
    import argparse

    parser = argparse.ArgumentParser(
        description='Fuse segmentation masks with SAM masks into final binary masks.'
    )
    # NOTE(review): the default seg folder is the *train* split while the SAM
    # and output folders are the *val* split. Masks are paired by file name, so
    # this likely yields no (or wrong) pairs -- confirm which split is intended.
    parser.add_argument(
        '--seg-path',
        default='/Users/Admin/ScalpVision/datasets/seg_train',
        help='folder with the original segmentation masks',
    )
    parser.add_argument(
        '--sam-path',
        default='/Users/Admin/ScalpVision/prediction/sam_result/sam_val',
        help='folder with the SAM masks',
    )
    parser.add_argument(
        '--result-path',
        default='prediction/ensemble_result/ensemble_val',
        help='output folder for the fused masks',
    )
    args = parser.parse_args()
    os.makedirs(args.result_path, exist_ok=True)
    make_final_mask(args.seg_path, args.sam_path, args.result_path)