# In [11]:
import torch
from carvekit.api.high import HiInterface

# Build the carvekit background-removal interface.
# Check HiInterface's doc strings for more information on each option.

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    _device = 'cuda'
else:
    _device = 'cpu'

_interface_kwargs = dict(
    object_type="object",  # Can be "object" or "hairs-like".
    batch_size_seg=5,
    batch_size_matting=1,
    device=_device,
    seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
    matting_mask_size=2048,
    trimap_prob_threshold=231,
    trimap_dilation=30,
    trimap_erosion_iters=5,
    fp16=False,
)
interface = HiInterface(**_interface_kwargs)
import os

# Process a single source directory. To process the whole raw dataset
# instead, use:
#   input_dir = "../data/raw"
#   output_dir = "../data/nobg"
input_dir = "../data/raw/paldex.io"
output_dir = "../data/nobg/paldex.io"

# Lower-cased file extensions that are handed to the background remover.
# Any other file os.walk finds (e.g. .DS_Store or other non-image sidecar
# files) is skipped; passing it to `interface` would raise and abort the run.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Input images that could not be read/processed; reported after the loop.
failed_files = []

# Walk the input tree. Each result is written under the same relative
# subdirectory it had under input_dir.
for root, _dirs, files in os.walk(input_dir):
    output_subdir = os.path.join(output_dir, os.path.relpath(root, input_dir))
    # Sorted so repeated runs process files in the same order.
    for filename in sorted(files):
        stem, ext = os.path.splitext(filename)
        if ext.lower() not in IMAGE_EXTENSIONS:
            continue

        file_path = os.path.join(root, filename)
        try:
            # interface takes a list of inputs and returns a list of results;
            # one file is passed per call, so take the single result.
            image_wo_bg = interface([file_path])[0]
        except OSError as exc:
            # Presumably raised for missing/corrupt/undecodable images
            # (Pillow's UnidentifiedImageError is an OSError subclass).
            # Log and continue so one bad file does not lose the whole batch.
            print(f"Skipping {file_path}: {exc}")
            failed_files.append(file_path)
            continue

        # Create the mirrored output subdirectory only once there is
        # something to write into it.
        os.makedirs(output_subdir, exist_ok=True)

        # Since the image format is RGBA, save it as PNG to keep transparency;
        # the original extension is replaced.
        output_file_path = os.path.join(output_subdir, stem + ".png")
        image_wo_bg.save(output_file_path)

if failed_files:
    print(f"{len(failed_files)} file(s) could not be processed:")
    for failed_path in failed_files:
        print(f"  {failed_path}")


