Spaces:
Configuration error
Configuration error
import os | |
import torch | |
import numpy as np | |
from PIL import Image, ImageOps | |
from .utils import BIGMAX | |
from .logger import logger | |
class LoadImagesFromDirectory: | |
def INPUT_TYPES(s): | |
return { | |
"required": { | |
"directory": ("STRING", {"default": ""}), | |
}, | |
"optional": { | |
"image_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}), | |
"start_index": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}), | |
} | |
} | |
RETURN_TYPES = ("IMAGE", "MASK", "INT") | |
FUNCTION = "load_images" | |
CATEGORY = "" | |
def load_images(self, directory: str, image_load_cap: int = 0, start_index: int = 0): | |
if not os.path.isdir(directory): | |
raise FileNotFoundError(f"Directory '{directory} cannot be found.'") | |
dir_files = os.listdir(directory) | |
if len(dir_files) == 0: | |
raise FileNotFoundError(f"No files in directory '{directory}'.") | |
dir_files = sorted(dir_files) | |
dir_files = [os.path.join(directory, x) for x in dir_files] | |
# start at start_index | |
dir_files = dir_files[start_index:] | |
images = [] | |
masks = [] | |
limit_images = False | |
if image_load_cap > 0: | |
limit_images = True | |
image_count = 0 | |
for image_path in dir_files: | |
if os.path.isdir(image_path): | |
continue | |
if limit_images and image_count >= image_load_cap: | |
break | |
i = Image.open(image_path) | |
i = ImageOps.exif_transpose(i) | |
image = i.convert("RGB") | |
image = np.array(image).astype(np.float32) / 255.0 | |
image = torch.from_numpy(image)[None,] | |
if 'A' in i.getbands(): | |
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 | |
mask = 1. - torch.from_numpy(mask) | |
else: | |
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") | |
images.append(image) | |
masks.append(mask) | |
image_count += 1 | |
if len(images) == 0: | |
raise FileNotFoundError(f"No images could be loaded from directory '{directory}'.") | |
return (torch.cat(images, dim=0), torch.stack(masks, dim=0), image_count) | |