Spaces:
Restarting
on
Zero
Restarting
on
Zero
import io | |
import json | |
import urllib.parse | |
import urllib.request | |
from math import pi | |
import comfy.model_management as model_management | |
import comfy.utils | |
import numpy as np | |
import torch | |
from PIL import Image | |
from ..log import log | |
from ..utils import ( | |
EASINGS, | |
apply_easing, | |
get_server_info, | |
numpy_NFOV, | |
pil2tensor, | |
tensor2np, | |
) | |
def get_image(filename, subfolder, folder_type):
    """Download one image from the running server's ``/view`` endpoint.

    Args:
        filename: name of the image file, passed as the ``filename`` query
            parameter.
        subfolder: passed as the ``subfolder`` query parameter (only mentioned
            in the debug log when truthy).
        folder_type: passed as the ``type`` query parameter.

    Returns:
        An ``io.BytesIO`` holding the raw bytes of the HTTP response body.
    """
    where = f"in subfolder: {subfolder}" if subfolder else ""
    log.debug(f"Getting image {filename} from foldertype {folder_type} {where}")

    query = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    host, port = get_server_info()
    url = f"http://{host}:{port}/view?{urllib.parse.urlencode(query)}"
    log.debug(f"Fetching image from {url}")

    with urllib.request.urlopen(url) as response:
        payload = response.read()
    return io.BytesIO(payload)
class MTB_ToDevice: | |
"""Send a image or mask tensor to the given device.""" | |
def INPUT_TYPES(cls): | |
devices = ["cpu"] | |
if torch.backends.mps.is_available(): | |
devices.append("mps") | |
if torch.cuda.is_available(): | |
devices.append("cuda") | |
for i in range(torch.cuda.device_count()): | |
devices.append(f"cuda{i}") | |
return { | |
"required": { | |
"ignore_errors": ("BOOLEAN", {"default": False}), | |
"device": (devices, {"default": "cpu"}), | |
}, | |
"optional": { | |
"image": ("IMAGE",), | |
"mask": ("MASK",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE", "MASK") | |
RETURN_NAMES = ("images", "masks") | |
CATEGORY = "mtb/utils" | |
FUNCTION = "to_device" | |
def to_device( | |
self, | |
*, | |
ignore_errors=False, | |
device="cuda", | |
image: torch.Tensor | None = None, | |
mask: torch.Tensor | None = None, | |
): | |
if not ignore_errors and image is None and mask is None: | |
raise ValueError( | |
"You must either provide an image or a mask," | |
" use ignore_error to passthrough" | |
) | |
if image is not None: | |
image = image.to(device) | |
if mask is not None: | |
mask = mask.to(device) | |
return (image, mask) | |
class MTB_ApplyTextTemplate:
    """
    Experimental node to interpolate strings from inputs.

    Each ``{name}`` placeholder in the template is replaced by the matching
    extra input, for instance:

    Some string {var_1} and {var_2}

    Placeholders with no matching input are left untouched.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "template": ("STRING", {"default": "", "multiline": True}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("string",)
    CATEGORY = "mtb/utils"
    FUNCTION = "execute"

    def execute(self, *, template: str, **kwargs):
        """Substitute ``{key}`` placeholders in ``template`` with ``kwargs`` values.

        Substitution is done in a single pass, so a value that itself contains
        ``{other_key}`` is inserted literally rather than being expanded again
        (the result no longer depends on the order of the inputs).

        Returns:
            One-element tuple with the interpolated string.
        """
        import re

        text = f"{template}"
        if not kwargs:
            return (text,)

        pattern = re.compile(
            r"\{(" + "|".join(re.escape(name) for name in kwargs) + r")\}"
        )
        return (pattern.sub(lambda match: f"{kwargs[match.group(1)]}", text),)
class MTB_MatchDimensions:
    """Match images dimensions along the given dimension, preserving aspect ratio."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "source": ("IMAGE",),
                "reference": ("IMAGE",),
                "match": (["height", "width"], {"default": "height"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT")
    RETURN_NAMES = ("image", "new_width", "new_height")
    CATEGORY = "mtb/utils"
    FUNCTION = "execute"

    @staticmethod
    def _target_size(height, width, ref_height, ref_width, match):
        """Return ``(new_width, new_height)`` for a source of ``height`` x ``width``.

        The matched side is copied from the reference; the other side keeps
        the source aspect ratio (truncated with ``int``). Both sides are at
        least 1 px so extreme aspect ratios cannot produce an empty image.
        """
        aspect_ratio = width / height
        if match == "height":
            new_height = ref_height
            new_width = int(ref_height * aspect_ratio)
        else:
            new_width = ref_width
            new_height = int(ref_width / aspect_ratio)
        return max(1, new_width), max(1, new_height)

    def execute(
        self, source: torch.Tensor, reference: torch.Tensor, match: str
    ):
        """Resize ``source`` so its height or width equals the reference's.

        Both inputs are indexed as (batch, height, width, channels), per the
        shape unpacking below; only the reference's size is used.

        Returns:
            ``(resized_images, new_width, new_height)``.
        """
        import torchvision.transforms.functional as VF

        _batch_size, height, width, _channels = source.shape
        _rbatch_size, ref_height, ref_width, _rchannels = reference.shape
        new_width, new_height = self._target_size(
            height, width, ref_height, ref_width, match
        )

        # VF.resize accepts (..., C, H, W) tensors, so the whole batch is
        # resized in one call instead of image by image.
        resized = VF.resize(
            source.permute(0, 3, 1, 2),
            [new_height, new_width],
            interpolation=VF.InterpolationMode.BICUBIC,
            antialias=True,
        )
        return (resized.permute(0, 2, 3, 1), new_width, new_height)
class MTB_FloatToFloats:
    """Conversion utility for compatibility with other extensions (AD, IPA, Fitz are using FLOAT to represent list of floats.)"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "float": ("FLOAT", {"default": 0.0, "forceInput": True}),
            }
        }

    RETURN_TYPES = ("FLOATS",)
    RETURN_NAMES = ("floats",)
    CATEGORY = "mtb/utils"
    FUNCTION = "convert"

    def convert(self, float: float):
        """Return the incoming value unchanged, re-typed as FLOATS.

        No conversion is performed: whatever arrives on the FLOAT socket
        (a scalar or a list of floats) is passed straight through. The
        parameter name shadows the builtin because it must match the
        ``"float"`` input key.
        """
        return (float,)
class MTB_FloatsToInts:
    """Conversion utility for compatibility with frame interpolation."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "floats": ("FLOATS", {"forceInput": True}),
            }
        }

    RETURN_TYPES = ("INTS", "INT")
    CATEGORY = "mtb/utils"
    FUNCTION = "convert"

    def convert(self, floats: list[float]):
        """Truncate every value with ``int()`` (toward zero).

        The same list is returned on both outputs (INTS and INT), since some
        consumers use INT to carry a list of integers.
        """
        vals = [int(x) for x in floats]
        return (vals, vals)
class MTB_FloatsToFloat:
    """Conversion utility for compatibility with other extensions (AD, IPA, Fitz are using FLOAT to represent list of floats.)"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "floats": ("FLOATS",),
            }
        }

    RETURN_TYPES = ("FLOAT",)
    RETURN_NAMES = ("float",)
    CATEGORY = "mtb/utils"
    FUNCTION = "convert"

    def convert(self, floats):
        """Return the FLOATS value unchanged on a FLOAT output (no conversion)."""
        return (floats,)
class MTB_AutoPanEquilateral: | |
"""Generate a 360 panning video from an equilateral image.""" | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"equilateral_image": ("IMAGE",), | |
"fovX": ("FLOAT", {"default": 45.0}), | |
"fovY": ("FLOAT", {"default": 45.0}), | |
"elevation": ("FLOAT", {"default": 0.5}), | |
"frame_count": ("INT", {"default": 100}), | |
"width": ("INT", {"default": 768}), | |
"height": ("INT", {"default": 512}), | |
}, | |
"optional": { | |
"floats_fovX": ("FLOATS",), | |
"floats_fovY": ("FLOATS",), | |
"floats_elevation": ("FLOATS",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
RETURN_NAMES = ("image",) | |
CATEGORY = "mtb/utils" | |
FUNCTION = "generate_frames" | |
def check_floats(self, f: list[float] | None, expected_count: int): | |
if f: | |
if len(f) == expected_count: | |
return True | |
return False | |
return True | |
def generate_frames( | |
self, | |
equilateral_image: torch.Tensor, | |
fovX: float, | |
fovY: float, | |
elevation: float, | |
frame_count: int, | |
width: int, | |
height: int, | |
floats_fovX: list[float] | None = None, | |
floats_fovY: list[float] | None = None, | |
floats_elevation: list[float] | None = None, | |
): | |
source = tensor2np(equilateral_image) | |
if len(source) > 1: | |
log.warn( | |
"You provided more than one image in the equilateral_image input, only the first will be used." | |
) | |
if not all( | |
[ | |
self.check_floats(x, frame_count) | |
for x in [floats_fovX, floats_fovY, floats_elevation] | |
] | |
): | |
raise ValueError( | |
"You provided less than the expected number of fovX, fovY, or elevation values." | |
) | |
source = source[0] | |
frames = [] | |
pbar = comfy.utils.ProgressBar(frame_count) | |
for i in range(frame_count): | |
rotation_angle = (i / frame_count) * 2 * pi | |
if floats_elevation: | |
elevation = floats_elevation[i] | |
if floats_fovX: | |
fovX = floats_fovX[i] | |
if floats_fovY: | |
fovY = floats_fovY[i] | |
fov = [fovX / 100, fovY / 100] | |
center_point = [rotation_angle / (2 * pi), elevation] | |
nfov = numpy_NFOV(fov, height, width) | |
frame = nfov.to_nfov(source, center_point=center_point) | |
frames.append(frame) | |
model_management.throw_exception_if_processing_interrupted() | |
pbar.update(1) | |
return (pil2tensor(frames),) | |
class MTB_GetBatchFromHistory:
    """Very experimental node to load images from the history of the server.

    Queue items without output are ignored in the count.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "enable": ("BOOLEAN", {"default": True}),
                "count": ("INT", {"default": 1, "min": 0}),
                "offset": ("INT", {"default": 0, "min": -1e9, "max": 1e9}),
                "internal_count": ("INT", {"default": 0}),
            },
            "optional": {
                "passthrough_image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    CATEGORY = "mtb/animation"
    FUNCTION = "load_from_history"

    def load_from_history(
        self,
        *,
        enable=True,
        count=0,
        offset=0,
        internal_count=0,  # hacky way to invalidate the node
        passthrough_image=None,
    ):
        """Fetch ``/history`` from the local server and load a batch of its images.

        When disabled (or ``count`` is 0) the passthrough image is returned if
        given, otherwise an empty tensor.
        """
        if not enable or count == 0:
            if passthrough_image is not None:
                log.debug("Using passthrough image")
                return (passthrough_image,)
            log.debug("Load from history is disabled for this iteration")
            return (torch.zeros(0),)

        base_url, port = get_server_info()
        history_url = f"http://{base_url}:{port}/history"
        log.debug(f"Fetching history from {history_url}")
        with urllib.request.urlopen(history_url) as response:
            output = self.load_batch_frames(response, offset, count, [])
        if output.size(0) == 0:
            log.warning("No output found in history")
        return (output,)

    @staticmethod
    def _collect_image_refs(history):
        """Flatten ``(filename, subfolder, type)`` for every image, in history order.

        Runs without an ``outputs`` entry and node outputs without ``images``
        are skipped.
        """
        return [
            (image["filename"], image["subfolder"], image["type"])
            for run in history.values()
            for node_output in run.get("outputs", {}).values()
            for image in node_output.get("images", ())
        ]

    @staticmethod
    def _select_range(items, offset, count):
        """Return up to ``count`` items ending ``offset`` items before the end.

        The end index is clamped at 0 so an offset larger than the number of
        items yields an empty selection instead of a wrapped negative slice.
        """
        raw_end = len(items) - offset
        start_index = max(raw_end - count, 0)
        end_index = max(raw_end, 0)
        return items[start_index:end_index]

    def load_batch_frames(self, response, offset, count, frames):
        """Parse a history response and load the selected images as a batch.

        Only the selected images are downloaded. ``frames`` is unused and kept
        for signature compatibility.

        Returns:
            The ``pil2tensor`` batch, or ``torch.zeros(0)`` when nothing matched.
        """
        history = json.loads(response.read())
        selected = self._select_range(
            self._collect_image_refs(history), offset, count
        )
        frames = [Image.open(get_image(*ref)) for ref in selected]
        if not frames:
            return torch.zeros(0)
        if len(frames) != count:
            log.warning(f"Expected {count} images, got {len(frames)} instead")
        return pil2tensor(frames)
class MTB_AnyToString:
    """Tries to take any input and convert it to a string."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"input": ("*",)},
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "do_str"
    CATEGORY = "mtb/converters"

    def do_str(self, input):
        """Describe ``input`` as a string.

        Strings pass through unchanged. Tensors, numpy arrays, dicts and PIL
        images get a short summary (shape/dtype, keys, size/mode); anything
        else falls back to ``str()``. The parameter name shadows the builtin
        because it must match the ``"input"`` key.
        """
        if isinstance(input, str):
            return (input,)
        if isinstance(input, torch.Tensor):
            return (f"Tensor of shape {input.shape} and dtype {input.dtype}",)
        if isinstance(input, np.ndarray):
            return (
                f"Numpy array of shape {input.shape} and dtype {input.dtype}",
            )
        if isinstance(input, dict):
            return (
                f"Dictionary of {len(input)} items, with keys {input.keys()}",
            )
        if isinstance(input, Image.Image):
            return (f"PIL Image of size {input.size} and mode {input.mode}",)
        log.debug(f"Falling back to string conversion of {input}")
        return (str(input),)
class MTB_StringReplace:
    """Basic string replacement."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "string": ("STRING", {"forceInput": True}),
                "old": ("STRING", {"default": ""}),
                "new": ("STRING", {"default": ""}),
            }
        }

    FUNCTION = "replace_str"
    RETURN_TYPES = ("STRING",)
    CATEGORY = "mtb/string"

    def replace_str(self, string: str, old: str, new: str):
        """Replace every occurrence of ``old`` with ``new`` (``str.replace``).

        The input, search, replacement and result strings are logged at
        debug level.
        """
        log.debug(f"Current string: {string}")
        log.debug(f"Find string: {old}")
        log.debug(f"Replace string: {new}")
        string = string.replace(old, new)
        log.debug(f"New string: {string}")
        return (string,)
class MTB_MathExpression: | |
"""Node to evaluate a simple math expression string""" | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"expression": ("STRING", {"default": "", "multiline": True}), | |
} | |
} | |
FUNCTION = "eval_expression" | |
RETURN_TYPES = ("FLOAT", "INT") | |
RETURN_NAMES = ("result (float)", "result (int)") | |
CATEGORY = "mtb/math" | |
DESCRIPTION = ( | |
"evaluate a simple math expression string, only supports literal_eval" | |
) | |
def eval_expression(self, expression: str, **kwargs): | |
from ast import literal_eval | |
for key, value in kwargs.items(): | |
log.debug(f"Replacing placeholder <{key}> with value {value}") | |
expression = expression.replace(f"<{key}>", str(value)) | |
result = -1 | |
try: | |
result = literal_eval(expression) | |
except SyntaxError as e: | |
raise ValueError( | |
f"The expression syntax is wrong '{expression}': {e}" | |
) from e | |
except Exception as e: | |
raise ValueError( | |
f"Math expression only support literal_eval now: {e}" | |
) | |
return (result, int(result)) | |
class MTB_FitNumber:
    """Fit the input float using a source and target range"""

    @classmethod
    def INPUT_TYPES(cls):
        def range_widget(default):
            # Explicit "max" mirrors the existing -1e5 "min". Without it the
            # frontend applies its own default maximum (believed to be
            # 2048 — verify), which caps the range inputs.
            return (
                "FLOAT",
                {"default": default, "step": 0.01, "min": -1e5, "max": 1e5},
            )

        return {
            "required": {
                "value": ("FLOAT", {"default": 0, "forceInput": True}),
                "clamp": ("BOOLEAN", {"default": False}),
                "source_min": range_widget(0.0),
                "source_max": range_widget(1.0),
                "target_min": range_widget(0.0),
                "target_max": range_widget(1.0),
                "easing": (
                    EASINGS,
                    {"default": "Linear"},
                ),
            }
        }

    FUNCTION = "set_range"
    RETURN_TYPES = ("FLOAT",)
    CATEGORY = "mtb/math"
    DESCRIPTION = "Fit the input float using a source and target range"

    @staticmethod
    def _normalize(value, source_min, source_max, clamp):
        """Map ``value`` from [source_min, source_max] onto [0, 1].

        A degenerate source range (min == max) maps to 0. With ``clamp`` the
        result is limited to [0, 1]; otherwise it may fall outside.
        """
        if source_min == source_max:
            normalized = 0
        else:
            normalized = (value - source_min) / (source_max - source_min)
        if clamp:
            normalized = max(min(normalized, 1), 0)
        return normalized

    def set_range(
        self,
        value: float,
        clamp: bool,
        source_min: float,
        source_max: float,
        target_min: float,
        target_max: float,
        easing: str,
    ):
        """Normalize ``value``, apply ``easing``, then scale into the target range.

        Returns:
            One-element tuple with the remapped float.
        """
        normalized_value = self._normalize(value, source_min, source_max, clamp)
        eased_value = apply_easing(normalized_value, easing)
        # Convert the eased value to the target range.
        return (target_min + (target_max - target_min) * eased_value,)
class MTB_ConcatImages:
    """Add images to batch."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "concatenate_tensors"
    CATEGORY = "mtb/image"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"reverse": ("BOOLEAN", {"default": False})},
            "optional": {
                "on_mismatch": (
                    ["Error", "Smallest", "Largest"],
                    {"default": "Smallest"},
                )
            },
        }

    @staticmethod
    def _resize_to_common(tensors, on_mismatch):
        """Bilinearly resize every tensor to the smallest/largest (H, W) among them.

        Tensors are indexed as (batch, height, width, channels), per the
        shape[1]/shape[2] usage. Any mode other than "Smallest" picks the
        largest size.
        """
        import torch.nn.functional as F

        pick = min if on_mismatch == "Smallest" else max
        target_height, target_width = pick(
            (tensor.shape[1], tensor.shape[2]) for tensor in tensors
        )
        resized = []
        for tensor in tensors:
            if tensor.shape[1] == target_height and tensor.shape[2] == target_width:
                resized.append(tensor)
                continue
            scaled = F.interpolate(
                tensor.permute(0, 3, 1, 2),
                size=(target_height, target_width),
                mode="bilinear",
                align_corners=False,
            )
            resized.append(scaled.permute(0, 2, 3, 1))
        return resized

    def concatenate_tensors(
        self,
        reverse: bool,
        on_mismatch: str = "Smallest",
        **kwargs: torch.Tensor,
    ) -> tuple[torch.Tensor]:
        """Concatenate all image inputs along the batch dimension.

        Args:
            reverse: concatenate the inputs in reverse order.
            on_mismatch: "Error" requires every input to share the same
                per-image shape (batch sizes may differ); "Smallest"/"Largest"
                resize all inputs to the smallest/largest height x width.
            **kwargs: the image tensors, in input order.

        Raises:
            ValueError: if no image is provided, or on a shape mismatch in
                "Error" mode.
        """
        tensors = list(kwargs.values())
        if not tensors:
            raise ValueError("At least one image input is required.")
        if reverse:
            tensors.reverse()

        if on_mismatch == "Error":
            # torch.cat along dim 0 only needs the per-image shapes to agree.
            if len({tuple(tensor.shape[1:]) for tensor in tensors}) > 1:
                raise ValueError(
                    "All input images must have the same height, width and"
                    " channel count when on_mismatch is 'Error'."
                )
        else:
            tensors = self._resize_to_common(tensors, on_mismatch)

        return (torch.cat(tensors, dim=0),)
# Node classes defined in this module. Presumably collected by the package's
# node loader for registration; NOTE(review): confirm against the loader.
__nodes__ = [
    MTB_StringReplace,
    MTB_FitNumber,
    MTB_GetBatchFromHistory,
    MTB_AnyToString,
    MTB_ConcatImages,
    MTB_MathExpression,
    MTB_ToDevice,
    MTB_ApplyTextTemplate,
    MTB_MatchDimensions,
    MTB_AutoPanEquilateral,
    MTB_FloatsToFloat,
    MTB_FloatToFloats,
    MTB_FloatsToInts,
]