# ControlNet-Preprocessors / core / pipelines / controlnet_preprocessor.py
# Uploaded by RioShiina ("Upload folder using huggingface_hub"), commit 79e946f (verified).
from typing import Dict, Any, List
import imageio
import tempfile
import numpy as np
import torch
import gradio as gr
from PIL import Image
from .base_pipeline import BasePipeline
from comfy_integration.nodes import NODE_CLASS_MAPPINGS
from nodes import NODE_DISPLAY_NAME_MAPPINGS
from utils.app_utils import get_value_at_index
# Lookup from a preprocessor's display name to its key in NODE_CLASS_MAPPINGS.
# Starts out unset; the pipeline refuses to run until startup code has
# populated it (see the RuntimeError raised in `_gpu_logic`).
REVERSE_DISPLAY_NAME_MAP = None

# Display names of preprocessors that `run` executes directly, without
# requesting a ZeroGPU allocation. The strings are lookup keys compared against
# the UI-selected name, so their spelling must not be "corrected" here.
CPU_ONLY_PREPROCESSORS = {
    "Binary Lines",
    "Canny Edge",
    "Color Pallete",
    "Content Shuffle",
    "Fake Scribble Lines (aka scribble_hed)",
    "Image Intensity",
    "Image Luminance",
    "Inpaint Preprocessor",
    "PyraCanny",
    "Scribble Lines",
    "Scribble XDoG Lines",
    "Standard Lineart",
    "Tile",
}
def run_node_by_function_name(node_instance: Any, **kwargs) -> Any:
    """Call a node through the method its class names in ``FUNCTION``.

    The node's class is expected to carry a ``FUNCTION`` attribute holding the
    name of the method to execute. That method is looked up on the instance
    and called with ``kwargs``.

    Args:
        node_instance: The node object to execute.
        **kwargs: Keyword arguments forwarded unchanged to the node method.

    Returns:
        Whatever the resolved node method returns.

    Raises:
        AttributeError: If the class has no (or an empty) ``FUNCTION``
            attribute, or if the named attribute is missing or not callable.
    """
    cls = type(node_instance)

    # FUNCTION is read from the class, not the instance.
    method_name = getattr(cls, 'FUNCTION', None)
    if not method_name:
        raise AttributeError(f"Node class '{cls.__name__}' is missing the required 'FUNCTION' attribute.")

    entry_point = getattr(node_instance, method_name, None)
    if callable(entry_point):
        return entry_point(**kwargs)

    raise AttributeError(f"Method '{method_name}' not found or not callable on node '{cls.__name__}'.")
class ControlNetPreprocessorPipeline(BasePipeline):
    """Run a ControlNet preprocessor node over a single image or every frame of a video.

    `run` is the entry point: it loads the input frames, maps the positional
    UI values onto the preprocessor's parameter names, and executes
    `_gpu_logic` either directly or through `self._execute_gpu_logic`,
    depending on whether the preprocessor is listed in `CPU_ONLY_PREPROCESSORS`.
    """

    def _gpu_logic(
        self, pil_images: List[Image.Image], preprocessor_name: str, model_name: str,
        params: Dict[str, Any], progress=gr.Progress(track_tqdm=True)
    ) -> List[Image.Image]:
        """Apply one preprocessor node to each frame and return the processed frames.

        Despite the name, `run` also calls this directly for CPU-only
        preprocessors.

        Args:
            pil_images: Frames to process, in order.
            preprocessor_name: Display name, resolved to a node class through
                `REVERSE_DISPLAY_NAME_MAP` and `NODE_CLASS_MAPPINGS`.
            model_name: Forwarded to the node as the `ckpt_name` keyword.
            params: Extra keyword arguments for the node, keyed by parameter name.
            progress: Gradio progress callback, updated once per frame.

        Returns:
            One processed PIL image per input frame, in the same order.

        Raises:
            RuntimeError: If `REVERSE_DISPLAY_NAME_MAP` has not been built.
            ValueError: If `preprocessor_name` does not resolve to a known node class.
        """
        if REVERSE_DISPLAY_NAME_MAP is None:
            raise RuntimeError("REVERSE_DISPLAY_NAME_MAP has not been initialized. `build_reverse_map` must be called on startup.")

        class_name = REVERSE_DISPLAY_NAME_MAP.get(preprocessor_name)
        if not class_name or class_name not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Preprocessor '{preprocessor_name}' not found.")

        preprocessor_instance = NODE_CLASS_MAPPINGS[class_name]()
        call_args = {**params, 'ckpt_name': model_name}

        processed_pil_images = []
        total_frames = len(pil_images)
        for i, frame_pil in enumerate(pil_images):
            progress(i / total_frames, desc=f"Processing frame {i+1}/{total_frames} with {preprocessor_name}...")

            # Modes such as "L" or "P" convert to 2-D arrays, which cannot form
            # the 4-D (1, H, W, C) tensor whose dimensions are read below.
            if isinstance(frame_pil, Image.Image) and frame_pil.mode not in ("RGB", "RGBA"):
                frame_pil = frame_pil.convert("RGB")

            # np.array(PIL image) is (H, W, C); unsqueeze(0) makes it (1, H, W, C).
            frame_tensor = torch.from_numpy(np.array(frame_pil).astype(np.float32) / 255.0).unsqueeze(0)

            # The spatial dimensions are indices 1 and 2 in this layout. Index 3
            # is the channel count and must not take part in the comparison.
            height, width = frame_tensor.shape[1], frame_tensor.shape[2]

            # Merge into a single dict so that a 'resolution' supplied through
            # `params` overrides the derived value instead of raising a
            # duplicate-keyword TypeError.
            node_kwargs = {'resolution': max(height, width), **call_args}

            result_tuple = run_node_by_function_name(
                preprocessor_instance,
                image=frame_tensor,
                **node_kwargs
            )

            processed_tensor = get_value_at_index(result_tuple, 0)
            # Back to 8-bit: drop the leading batch axis, clamp to [0, 1], scale to 0-255.
            processed_np = (processed_tensor.squeeze(0).cpu().numpy().clip(0, 1) * 255.0).astype(np.uint8)
            processed_pil_images.append(Image.fromarray(processed_np))

        return processed_pil_images

    def run(self, input_type, image_input, video_input, preprocessor_name, model_name, zero_gpu_duration, *args, progress=gr.Progress(track_tqdm=True)):
        """Load the input, run the selected preprocessor, and return the result.

        Args:
            input_type: Either "Image" or "Video".
            image_input: The input image; required when `input_type` is "Image".
            video_input: The video source handed to `imageio.get_reader`;
                required when `input_type` is "Video".
            preprocessor_name: Display name of the preprocessor to run.
            model_name: Forwarded to the node as `ckpt_name`.
            zero_gpu_duration: Passed as `duration` to `self._execute_gpu_logic`.
            *args: Positional parameter values. They are matched, in order, to
                the preprocessor's configured parameters sorted as
                INT/FLOAT first, then list-typed (dropdown), then BOOLEAN.
                Surplus values are ignored.
            progress: Gradio progress callback.

        Returns:
            For image input, the list of processed PIL images. For video
            input, a one-element list holding the value returned by
            `self._encode_video_from_frames`.

        Raises:
            gr.Error: On missing or unreadable input, too few parameter
                values, a failing preprocessor, or an empty result.
            RuntimeError: If `app_utils.PREPROCESSOR_PARAMETER_MAP` is not built.
        """
        import traceback
        # Imported as a module so PREPROCESSOR_PARAMETER_MAP is read at call
        # time rather than bound to whatever it was at import time.
        from utils import app_utils

        pil_images, is_video, fps = [], False, 30
        progress(0, desc="Reading input file...")

        if input_type == "Image":
            if image_input is None:
                raise gr.Error("Please provide an input image.")
            pil_images = [image_input]
        elif input_type == "Video":
            if video_input is None:
                raise gr.Error("Please provide an input video.")
            try:
                video_reader = imageio.get_reader(video_input)
                try:
                    meta = video_reader.get_meta_data()
                    fps = meta.get('fps', 30)
                    pil_images = [Image.fromarray(frame) for frame in video_reader]
                finally:
                    # Release the reader even when metadata or decoding fails.
                    video_reader.close()
                is_video = True
            except Exception as e:
                raise gr.Error(f"Failed to read video file: {e}") from e
        else:
            raise gr.Error("Invalid input type selected.")

        if not pil_images:
            raise gr.Error("Could not extract any frames from the input.")

        if app_utils.PREPROCESSOR_PARAMETER_MAP is None:
            raise RuntimeError("Preprocessor parameter map is not built. Check startup logs.")

        params_config = app_utils.PREPROCESSOR_PARAMETER_MAP.get(preprocessor_name, [])

        # Order the parameter names the same way the positional values are
        # consumed: sliders, then dropdowns, then checkboxes.
        sliders_params = [p for p in params_config if p['type'] in ["INT", "FLOAT"]]
        dropdown_params = [p for p in params_config if isinstance(p['type'], list)]
        checkbox_params = [p for p in params_config if p['type'] == "BOOLEAN"]
        ordered_params_config = sliders_params + dropdown_params + checkbox_params
        param_names = [p['name'] for p in ordered_params_config]

        if len(args) < len(param_names):
            # Fail with a readable message instead of a bare IndexError.
            raise gr.Error(
                f"Preprocessor '{preprocessor_name}' expects {len(param_names)} parameter value(s) "
                f"but only {len(args)} were provided."
            )
        provided_params = dict(zip(param_names, args))

        if preprocessor_name not in CPU_ONLY_PREPROCESSORS:
            print(f"--- '{preprocessor_name}' requires GPU, requesting ZeroGPU. ---")
            try:
                processed_pil_images = self._execute_gpu_logic(
                    self._gpu_logic,
                    duration=zero_gpu_duration,
                    default_duration=60,
                    task_name=f"Preprocessor '{preprocessor_name}'",
                    pil_images=pil_images,
                    preprocessor_name=preprocessor_name,
                    model_name=model_name,
                    params=provided_params,
                    progress=progress
                )
            except Exception as e:
                traceback.print_exc()
                raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on GPU: {e}") from e
        else:
            print(f"--- Running '{preprocessor_name}' on CPU, no ZeroGPU requested. ---")
            try:
                processed_pil_images = self._gpu_logic(pil_images, preprocessor_name, model_name, provided_params, progress=progress)
            except Exception as e:
                traceback.print_exc()
                raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on CPU: {e}") from e

        if not processed_pil_images:
            raise gr.Error("Processing returned no frames.")

        progress(0.9, desc="Finalizing output...")

        if is_video:
            # Stack frames into one float tensor scaled to [0, 1] for the encoder.
            frames_np = [np.array(img) for img in processed_pil_images]
            frames_tensor = torch.from_numpy(np.stack(frames_np)).to(torch.float32) / 255.0
            video_path = self._encode_video_from_frames(frames_tensor, fps, progress)
            return [video_path]

        progress(1.0, desc="Done!")
        return processed_pil_images