from pathlib import Path

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm


class ImageWithPathDataset(torch.utils.data.Dataset):
    """Dataset pairing every ``*.jpg`` under a root directory with an output path.

    Each item is the decoded RGB image (optionally transformed) together with
    a destination path that mirrors the image's location relative to
    ``root_image_path`` underneath ``output_path``, with the file extension
    dropped.

    Args:
        root_image_path: Directory searched recursively for ``*.jpg`` files.
            May be a ``str`` or a ``pathlib.Path``.
        output_path: Directory under which the mirrored output paths are
            built. May be a ``str`` or a ``pathlib.Path``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, root_image_path, output_path, transform=None):
        super().__init__()
        # Coerce to Path so callers may pass plain strings; ``.glob`` and the
        # ``/`` operator used below are only available on Path objects.
        self.root_image_path = Path(root_image_path)
        self.output_path = Path(output_path)
        # Sort the matches: glob order depends on the underlying filesystem,
        # and a stable order keeps index -> image mapping reproducible.
        self.image_paths = sorted(self.root_image_path.glob("**/*.jpg"))
        self.transform = transform

    def __len__(self):
        """Return the number of images discovered under the root directory."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one image and compute where its output should be stored.

        The output's parent directory is created on access (including any
        missing parents) so the caller can write to the returned path
        directly after appending a suffix.

        Args:
            idx: Index into the sorted list of discovered image paths.

        Returns:
            A ``(image, output_emb_path)`` tuple. ``image`` is the RGB image,
            or whatever ``transform`` returns for it. ``output_emb_path`` is a
            ``pathlib.Path`` without a file extension.

        NOTE(review): the second element is a ``Path``, which torch's default
        batch collation presumably does not handle; callers likely supply a
        custom ``collate_fn`` -- confirm against the call site.
        """
        image_path = self.image_paths[idx]

        # ``Image.open`` is lazy and holds the file open; ``convert`` forces
        # the pixel data to load and returns an independent image, so the
        # underlying file can be closed as soon as the ``with`` block exits.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Compare against None so a callable that happens to be falsy
        # (e.g. one defining ``__len__``) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        # Mirror the image's sub-directory layout below ``output_path``.
        output_dir = self.output_path / image_path.parent.relative_to(
            self.root_image_path
        )
        output_dir.mkdir(exist_ok=True, parents=True)

        output_emb_path = output_dir / image_path.stem
        return image, output_emb_path