|
import hashlib |
|
import json |
|
import os |
|
import re |
|
from pathlib import Path |
|
|
|
import folder_paths |
|
import numpy as np |
|
import torch |
|
from PIL import Image, ImageOps |
|
from PIL.PngImagePlugin import PngInfo |
|
|
|
from ..log import log |
|
|
|
|
|
class MTB_LoadImageSequence:
    """Load an image sequence from a folder. The current frame is used to determine which image to load.

    Usually used in conjunction with the `Primitive` node set to increment to load a sequence of images from a folder.
    Use -1 to load all matching frames as a batch.
    Set `range` to "start-end" to load a subset as a batch: with a `#` pattern the
    bounds are frame numbers, otherwise they are 0-based indexes into the sorted file list.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "path": ("STRING", {"default": "videos/####.png"}),
                "current_frame": (
                    "INT",
                    {"default": 0, "min": -1, "max": 9999999},
                ),
            },
            "optional": {
                "range": ("STRING", {"default": ""}),
            },
        }

    CATEGORY = "mtb/IO"
    FUNCTION = "load_image"
    RETURN_TYPES = (
        "IMAGE",
        "MASK",
        "INT",
        "INT",
    )
    RETURN_NAMES = (
        "image",
        "mask",
        "current_frame",
        "total_frames",
    )

    # The parameter keeps the name `range` (shadowing the builtin) so it
    # matches the "range" input declared in INPUT_TYPES.
    def load_image(self, path=None, current_frame=0, range=""):
        """Load one frame, a range of frames, or every frame of the sequence.

        Args:
            path: directory, glob (`*`) or `#` frame pattern.
            current_frame: frame number to load, or -1 to load all frames.
            range: optional "start-end" selection; takes precedence over
                `current_frame`.

        Returns:
            (image, mask, current_frame, total_frames). In the batch modes
            (range or -1) `current_frame` is reported as -1.
        """
        if range:
            frames = self.get_frames_from_range(path, range)
            out_img, out_mask, total_frames = self._load_batch(path, frames)
            return (out_img, out_mask, -1, total_frames)

        if current_frame == -1:
            log.debug(f"Loading all frames from {path}")
            frames = resolve_all_frames(path)
            log.debug(f"Found {len(frames)} frames")
            out_img, out_mask, total_frames = self._load_batch(path, frames)
            return (out_img, out_mask, -1, total_frames)

        log.debug(f"Loading image: {path}, {current_frame}")
        resolved_path = resolve_path(path, current_frame)
        image_path = folder_paths.get_annotated_filepath(resolved_path)
        image, mask = img_from_path(image_path)
        return (image, mask, current_frame, 1)

    @staticmethod
    def _load_batch(path, frames):
        """Load `frames` and stack them into (images, masks, count).

        Raises:
            ValueError: if `frames` is empty (instead of an opaque unpack error).
        """
        if not frames:
            raise ValueError(f"No frames found for: {path}")
        images = []
        masks = []
        for frame in frames:
            image, mask = img_from_path(frame)
            images.append(image)
            # Per-frame masks may be 2-D; add a batch axis so concatenation
            # yields (N, H, W) instead of stacking rows into (N*H, W).
            masks.append(mask.unsqueeze(0) if mask.dim() == 2 else mask)
        return torch.cat(images, dim=0), torch.cat(masks, dim=0), len(frames)

    def get_frames_from_range(self, path, range_str):
        """Return the frame paths of `path` selected by a "start-end" string."""
        start, end = self._parse_range(range_str)
        frames = resolve_all_frames(path)
        return self._select_frames(path, frames, start, end)

    @staticmethod
    def _parse_range(range_str):
        """Parse "start-end" into an (int, int) tuple with start <= end."""
        try:
            start, end = map(int, range_str.split("-"))
        except ValueError:
            raise ValueError(
                f"Invalid range format: {range_str}. Expected format is 'start-end'."
            ) from None
        if start > end:
            raise ValueError(
                f"Invalid range {range_str}: start must not be greater than end."
            )
        return start, end

    @staticmethod
    def _select_frames(path, frames, start, end):
        """Pick the entries of `frames` that fall inside [start, end].

        With a `#` in the file name of `path`, start/end are frame numbers
        parsed from each file name; otherwise they are inclusive indexes into
        `frames`.

        Raises:
            ValueError: if nothing matches (frame-number mode) or the indexes
                are out of bounds (index mode).
        """
        name_pattern = os.path.basename(path)
        if "#" in name_pattern:
            # A run of N '#' matches N or more digits as ONE group; escaping
            # and replacing each '#' separately made group(1) capture a single
            # digit (e.g. "0" for "0012").
            frame_regex = re.compile(
                "".join(
                    r"(\d{" + str(len(part)) + ",})"
                    if part.startswith("#")
                    else re.escape(part)
                    for part in re.split(r"(#+)", name_pattern)
                )
            )
            selected = []
            for frame in frames:
                match = frame_regex.fullmatch(os.path.basename(frame))
                if match and start <= int(match.group(1)) <= end:
                    selected.append(frame)
            # Frame numbers are not indexes, so no bound check against
            # len(frames) here; only report an empty selection.
            if not selected:
                raise ValueError(
                    f"No frames numbered {start}-{end} found for: {path}"
                )
            return selected

        log.warning(
            f"Wildcard pattern or directory will use indexes instead of frame numbers for : {path}"
        )
        if start < 0 or end >= len(frames):
            raise ValueError(
                f"Range {start}-{end} is out of bounds. Total frames available: {len(frames)}"
            )
        return frames[start : end + 1]

    @staticmethod
    def IS_CHANGED(path="", current_frame=0, range=""):
        """Return a cache key that changes when the inputs to load change.

        Batch modes hash the modification times of the same file paths that
        `load_image` reads; single-frame mode hashes the file contents.
        """
        log.debug(f"Checking if changed: {path}, {current_frame}")
        if range or current_frame == -1:
            # load_image reads these paths as-is (not via
            # get_annotated_filepath), so stat exactly the same files.
            timestamps = [os.path.getmtime(p) for p in resolve_all_frames(path)]
            combined_hash = hashlib.sha256(
                "".join(map(str, timestamps)).encode()
            )
            return combined_hash.hexdigest()

        resolved_path = resolve_path(path, current_frame)
        image_path = folder_paths.get_annotated_filepath(resolved_path)
        if os.path.exists(image_path):
            digest = hashlib.sha256()
            with open(image_path, "rb") as f:
                digest.update(f.read())
            return digest.hexdigest()
        return "NONE"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import glob |
|
|
|
|
|
def img_from_path(path):
    """Load an image file as an ``(image, mask)`` pair of float32 tensors.

    The image is EXIF-orientation corrected, converted to RGB, scaled to
    [0, 1] and given a leading batch axis of size 1. If the file has an alpha
    band the mask is ``1 - alpha``; otherwise it is a 64x64 zero placeholder
    allocated on the CPU.
    """
    # Close the file handle as soon as the pixels are read; when loading long
    # sequences an unclosed handle per frame accumulates until GC runs.
    with Image.open(path) as opened:
        img = ImageOps.exif_transpose(opened)
        rgb = np.array(img.convert("RGB")).astype(np.float32) / 255.0
        alpha = (
            np.array(img.getchannel("A")).astype(np.float32) / 255.0
            if "A" in img.getbands()
            else None
        )

    image = torch.from_numpy(rgb)[None,]
    if alpha is None:
        mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
    else:
        mask = 1.0 - torch.from_numpy(alpha)
    return (
        image,
        mask,
    )
|
|
|
|
|
def resolve_all_frames(path: str) -> list[str]:
    """Return every file path that belongs to the sequence described by ``path``.

    ``path`` can take three forms:

    * a directory: every ``.jpg``/``.jpeg``/``.png`` file directly inside it
      (extension compared case-insensitively), sorted by path;
    * a glob containing ``*``: every match of ``glob.glob``, sorted by path;
    * a pattern with ``#`` in the file name (e.g. ``videos/####.png``): every
      file in that folder whose whole name matches the pattern, where a run of
      N ``#`` stands for N or more digits. Sorted numerically by the first
      run's number.

    Raises:
        ValueError: if ``path`` has no ``#``/``*`` and is not a directory, or
            if the ``#`` placeholder is only in the folder part.
    """
    if "#" not in path:
        pth = Path(path)
        if pth.is_dir():
            image_suffixes = {".jpg", ".jpeg", ".png"}
            frames = [
                entry.as_posix()
                for entry in pth.iterdir()
                if entry.is_file() and entry.suffix.lower() in image_suffixes
            ]
        elif "*" in path:
            frames = glob.glob(path)
        else:
            raise ValueError(
                "The path doesn't contain a # or a * or is not a directory"
            )
        frames.sort()
        return frames

    folder_path, file_pattern = os.path.split(path)
    if "#" not in file_pattern:
        raise ValueError(
            f"The # frame placeholder must be in the file name, not the folder: {path}"
        )

    log.debug(f"Resolving all frames in {folder_path}")

    # Coarse pre-filter for glob: every run of '#' becomes a '*' wildcard.
    glob_pattern = re.sub(r"#+", "*", file_pattern)
    log.debug(f"Found pattern: {glob_pattern}")
    matching_files = glob.glob(os.path.join(folder_path, glob_pattern))
    log.debug(f"Found {len(matching_files)} matching files")

    # Strict filter: a run of N '#' must be N or more digits (resolve_path
    # zero-pads to N, larger numbers overflow the padding). One capture group
    # per run, so group(1) is the whole first number, not a single digit.
    frame_regex = re.compile(
        "".join(
            r"(\d{" + str(len(part)) + ",})"
            if part.startswith("#")
            else re.escape(part)
            for part in re.split(r"(#+)", file_pattern)
        )
    )

    numbered: list[tuple[int, str]] = []
    for file in matching_files:
        # fullmatch on the file name so "####.png" does not also pick up
        # "frame_0001.png" via a substring match.
        match = frame_regex.fullmatch(os.path.basename(file))
        if match is None:
            continue
        frame_number = int(match.group(1))
        log.debug(f"Found frame number: {frame_number}")
        numbered.append((frame_number, file))

    # Numeric order keeps e.g. 10000.png after 9999.png once padding overflows.
    numbered.sort()
    return [file for _, file in numbered]
|
|
|
|
|
def resolve_path(path, frame):
    """Substitute ``frame`` into every run of ``#`` in ``path``.

    Each run is zero-padded to its own length ("####" with 12 -> "0012");
    numbers wider than the run are inserted unpadded. A path without ``#`` is
    returned unchanged.
    """
    frame_str = str(frame)
    # Pad per run rather than to the file-wide '#' count, so "s##_f####"
    # yields "s05_f0005" instead of "s000005_f000005". A callable replacement
    # also sidesteps any escape processing of the substituted text.
    return re.sub(r"#+", lambda run: frame_str.zfill(len(run.group(0))), path)
|
|
|
|
|
class MTB_SaveImageSequence:
    """Save an image sequence to a folder. The current frame is used to determine which image to save.

    Each call writes a single PNG named `<name>_<current_frame:05>.png` into
    `<output>/<filename_prefix>/`, where `<name>` is the last path component of
    `filename_prefix`.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "filename_prefix": ("STRING", {"default": "Sequence"}),
                "current_frame": (
                    "INT",
                    {"default": 0, "min": 0, "max": 9999999},
                ),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save_images"
    OUTPUT_NODE = True
    CATEGORY = "mtb/IO"

    def _resolve_output(self, filename_prefix, current_frame):
        """Compute where a frame is written.

        Returns:
            (folder, image_file, subfolder): the target directory, the PNG
            path inside it, and the folder relative to the output directory
            (posix separators, "" for the root) as reported to the UI.

        Raises:
            ValueError: if the prefix resolves outside the output directory
                (e.g. "../x" or an absolute path).
        """
        output_root = os.path.abspath(self.output_dir)
        folder = os.path.abspath(os.path.join(output_root, filename_prefix))
        if os.path.commonpath((output_root, folder)) != output_root:
            raise ValueError(
                f"Saving image outside the output folder is not allowed: {filename_prefix}"
            )
        # Only the last component names the file; using the full prefix
        # ("a/b") produced "<out>/a/b/a/b_00000.png", whose parent dir
        # never exists.
        stem = Path(filename_prefix).name
        image_file = Path(folder) / f"{stem}_{current_frame:05}.png"
        subfolder = Path(os.path.relpath(folder, output_root)).as_posix()
        if subfolder == ".":
            subfolder = ""
        return Path(folder), image_file, subfolder

    def save_images(
        self,
        images,
        filename_prefix="Sequence",
        current_frame=0,
        prompt=None,
        extra_pnginfo=None,
    ):
        """Write the single image in `images` as frame `current_frame`.

        `prompt` and `extra_pnginfo` are embedded as JSON text chunks in the
        PNG. Returns the ComfyUI "ui" payload describing the saved file.

        Raises:
            ValueError: if `images` holds zero or more than one image, or the
                prefix escapes the output directory.
        """
        if len(images) > 1:
            raise ValueError("Can only save one image at a time")
        if len(images) == 0:
            raise ValueError("No image to save")

        folder, image_file, subfolder = self._resolve_output(
            filename_prefix, current_frame
        )
        folder.mkdir(parents=True, exist_ok=True)

        output_image = images[0].cpu().numpy()
        img = Image.fromarray(
            np.clip(output_image * 255.0, 0, 255).astype(np.uint8)
        )
        metadata = PngInfo()
        if prompt is not None:
            metadata.add_text("prompt", json.dumps(prompt))
        if extra_pnginfo is not None:
            for key, value in extra_pnginfo.items():
                metadata.add_text(key, json.dumps(value))

        img.save(image_file, pnginfo=metadata, compress_level=4)
        return {
            "ui": {
                "images": [
                    {
                        "filename": image_file.name,
                        "subfolder": subfolder,
                        "type": self.type,
                    }
                ]
            }
        }
|
|
|
|
|
# Node classes defined in this module; presumably collected by the package's
# node registration via the `__nodes__` name (confirm against the loader).
__nodes__ = [MTB_LoadImageSequence, MTB_SaveImageSequence]
|
|