import base64
import json
import os
import time
import urllib.parse
import urllib.request
from io import BytesIO

import requests
import runpod
from PIL import Image
from runpod.serverless.utils import rp_upload

# Time to wait between API check attempts in milliseconds
COMFY_API_AVAILABLE_INTERVAL_MS = 100
# Maximum number of API check attempts
COMFY_API_AVAILABLE_MAX_RETRIES = 500
# Time to wait between poll attempts in milliseconds.
# Environment values are strings, so they are converted to int explicitly.
COMFY_POLLING_INTERVAL_MS = int(os.environ.get("COMFY_POLLING_INTERVAL_MS", 1000))
# Maximum number of poll attempts
COMFY_POLLING_MAX_RETRIES = int(os.environ.get("COMFY_POLLING_MAX_RETRIES", 1000))
# Host where ComfyUI is running
COMFY_HOST = "127.0.0.1:8188"
# Enforce a clean state after each job is done
# see https://docs.runpod.io/docs/handler-additional-controls#refresh-worker
REFRESH_WORKER = os.environ.get("REFRESH_WORKER", "false").lower() == "true"
# Whether to convert output images to WebP, which makes them considerably smaller
OUTPUT_WEBP = os.environ.get("OUTPUT_WEBP", "true").lower() == "true"
# Whether to also include ComfyUI's raw "outputs" dict in the handler result
OUTPUT_RAW_OUTPUTS = os.environ.get("OUTPUT_RAW_OUTPUTS", "false").lower() == "true"


def validate_input(job_input):
    """
    Validates the input for the handler function.

    Args:
        job_input (dict | str): The input data to validate. A string is parsed
            as JSON first.

    Returns:
        tuple: A tuple containing the validated data and an error message, if any.
               The structure is (validated_data, error_message).
    """
    # Validate if job_input is provided
    if job_input is None:
        return None, "Please provide input"

    # Check if input is a string and try to parse it as JSON
    if isinstance(job_input, str):
        try:
            job_input = json.loads(job_input)
        except json.JSONDecodeError:
            return None, "Invalid JSON format in input"

    # Valid JSON may still be a list/number/etc.; everything below needs a dict
    if not isinstance(job_input, dict):
        return None, "Input must be a JSON object"

    # Validate 'workflow' in input
    workflow = job_input.get("workflow")
    if workflow is None:
        return None, "Missing 'workflow' parameter"
    # The workflow is iterated with .items() later on, so it has to be a dict
    if not isinstance(workflow, dict):
        return None, "'workflow' must be a dict"

    # Validate 'args' in input, if provided
    args = job_input.get("args")
    if args is not None and not isinstance(args, dict):
        return None, "'args' must be a dict"

    # Return validated data and no error
    return {"workflow": workflow, "args": args}, None


def check_server(url, retries=500, delay=50):
    """
    Check if a server is reachable via HTTP GET request

    Args:
    - url (str): The URL to check
    - retries (int, optional): The number of times to attempt connecting to the server. Default is 500
    - delay (int, optional): The time in milliseconds to wait between retries. Default is 50

    Returns:
    bool: True if the server is reachable within the given number of retries, otherwise False
    """
    for _ in range(retries):
        try:
            response = requests.get(url)

            # If the response status code is 200, the server is up and running
            if response.status_code == 200:
                print("runpod-worker-comfy - API is reachable")
                return True
        except requests.RequestException:
            # If an exception occurs, the server may not be ready yet; retry
            pass

        # Wait for the specified delay before retrying
        time.sleep(delay / 1000)

    print(
        f"runpod-worker-comfy - Failed to connect to server at {url} after {retries} attempts."
    )
    return False


def upload_images(images):
    """
    Upload a list of base64 encoded images to the ComfyUI server using the /upload/image endpoint.

    Args:
        images (list): A list of dictionaries, each containing the 'name' of the image
            and the 'image' as a base64 encoded string.

    Returns:
        dict: A dictionary with 'status' ('success' or 'error'), a 'message', and
            'details' (a list with one entry per uploaded or failed image).
    """
    if not images:
        return {"status": "success", "message": "No images to upload", "details": []}

    responses = []
    upload_errors = []

    print("runpod-worker-comfy - image(s) upload")

    for image in images:
        name = image["name"]
        image_data = image["image"]
        blob = base64.b64decode(image_data)

        # Prepare the form data
        files = {
            "image": (name, BytesIO(blob), "image/png"),
            "overwrite": (None, "true"),
        }

        # POST request to upload the image
        response = requests.post(f"http://{COMFY_HOST}/upload/image", files=files)
        if response.status_code != 200:
            upload_errors.append(f"Error uploading {name}: {response.text}")
        else:
            responses.append(f"Successfully uploaded {name}")

    if upload_errors:
        print("runpod-worker-comfy - image(s) upload with errors")
        return {
            "status": "error",
            "message": "Some images failed to upload",
            "details": upload_errors,
        }

    print("runpod-worker-comfy - image(s) upload complete")
    return {
        "status": "success",
        "message": "All images uploaded successfully",
        "details": responses,
    }


def queue_workflow(workflow):
    """
    Queue a workflow to be processed by ComfyUI

    Args:
        workflow (dict): A dictionary containing the workflow to be processed

    Returns:
        dict: The JSON response from ComfyUI after processing the workflow
    """
    # The top level element "prompt" is required by ComfyUI
    data = json.dumps({"prompt": workflow}).encode("utf-8")

    req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data)
    # Use a context manager so the HTTP response is always closed
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())


def get_history(prompt_id):
    """
    Retrieve the history of a given prompt using its ID

    Args:
        prompt_id (str): The ID of the prompt whose history is to be retrieved

    Returns:
        dict: The history of the prompt, containing all the processing steps and results
    """
    with urllib.request.urlopen(f"http://{COMFY_HOST}/history/{prompt_id}") as response:
        return json.loads(response.read())


def base64_encode(img_path):
    """
    Returns base64 encoded image.

    Args:
        img_path (str): The path to the image

    Returns:
        str: The base64 encoded image
    """
    with open(img_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def process_output_images(outputs, job_id):
    """
    Return the last generated image either as a bucket URL or as a base64 string.

    Args:
        outputs (dict): A dictionary containing the outputs from image generation,
                        typically includes node IDs and their respective output data.
        job_id (str): The unique identifier for the job.

    Returns:
        dict: A dictionary with the status ('success' or 'error') and the message,
              which is either the URL to the image in the AWS S3 bucket or a base64
              encoded string of the image. In case of error, the message details the issue.

    The function works as follows:
    - It first determines the output path for the images from an environment variable,
      defaulting to "/comfyui/output" if not set.
    - It then iterates through the outputs to find the filenames of the generated images;
      when several images are present, the last one found is used.
    - After confirming the existence of the image in the output folder, it checks if the
      AWS S3 bucket is configured via the BUCKET_ENDPOINT_URL environment variable.
    - If AWS S3 is configured, it uploads the image to the bucket and returns the URL.
    - If AWS S3 is not configured, it encodes the image in base64 and returns the string.
    - If no image is listed in the outputs, or the image file does not exist in the
      output folder, it returns an error status with a message describing the problem.
    """
    # The path where ComfyUI stores the generated images
    COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output")

    # Relative path (subfolder/filename) of the last image found, if any
    output_image = None
    for node_output in outputs.values():
        for image in node_output.get("images", []):
            output_image = os.path.join(image["subfolder"], image["filename"])

    print("runpod-worker-comfy - image generation is done")

    if output_image is None:
        # Nothing to look up on disk; report it instead of probing a bogus path
        print("runpod-worker-comfy - the image does not exist in the output folder")
        return {
            "status": "error",
            "message": "no images were found in the workflow outputs",
        }

    # expected image output folder
    local_image_path = f"{COMFY_OUTPUT_PATH}/{output_image}"

    print(f"runpod-worker-comfy - {local_image_path}")

    # The image is in the output folder
    if os.path.exists(local_image_path):
        if os.environ.get("BUCKET_ENDPOINT_URL", False):
            # URL to image in AWS S3
            image = rp_upload.upload_image(job_id, local_image_path)
            print(
                "runpod-worker-comfy - the image was generated and uploaded to AWS S3"
            )
        else:
            # base64 image
            image = base64_encode(local_image_path)
            print(
                "runpod-worker-comfy - the image was generated and converted to base64"
            )

        return {
            "status": "success",
            "message": image,
        }

    print("runpod-worker-comfy - the image does not exist in the output folder")
    return {
        "status": "error",
        "message": f"the image does not exist in the specified output folder: {local_image_path}",
    }


def process_input(workflow, args):
    """
    Inject caller-supplied argument values into the workflow's input nodes.

    Every node whose class_type is one of the *Input_fal types is matched by its
    inputs["name"] against the keys of args; on a match the node's value field is
    overwritten in place ("number" for integer/float nodes, "value" otherwise).

    Args:
        workflow (dict): Workflow mapping node IDs to node definitions. Modified in place.
        args (dict | None): Mapping of input names to replacement values. When it is
            None or empty the workflow is left untouched.
    """
    # 'args' is optional in the job input, so there may be nothing to substitute
    if not args:
        return

    numeric_types = ("IntegerInput_fal", "FloatInput_fal")
    other_types = ("BooleanInput_fal", "StringInput_fal")

    for node in workflow.values():
        class_type = node["class_type"]
        if class_type not in numeric_types and class_type not in other_types:
            continue

        input_name = node["inputs"]["name"]
        if input_name not in args:
            continue

        # Numeric input nodes store their value under "number", the others under "value"
        field = "number" if class_type in numeric_types else "value"
        node["inputs"][field] = args[input_name]


def convert_image_to_base64(filename):
    """
    Read an image from the ComfyUI output folder and return it as a base64 data URL.

    When OUTPUT_WEBP is enabled the image is re-encoded as WebP first; otherwise the
    file bytes are returned unchanged with a PNG data-URL prefix.

    Args:
        filename (str): Path of the image relative to the output folder
            (COMFY_OUTPUT_PATH, default "/comfyui/output").

    Returns:
        str | None: The data URL, or None if reading or converting the image failed.
    """
    try:
        COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output")
        fullpath = os.path.join(COMFY_OUTPUT_PATH, filename)

        if not OUTPUT_WEBP:
            return "data:image/png;base64," + base64_encode(fullpath)

        # Re-encode into an in-memory buffer so nothing extra is written to disk
        with Image.open(fullpath) as img, BytesIO() as output:
            img.save(output, format="WebP")
            encoded = base64.b64encode(output.getvalue()).decode("utf-8")
        return "data:image/webp;base64," + encoded
    except Exception as e:
        # Best effort: a single broken image must not fail the whole job
        print(f"Error converting image {filename}: {e}")
        return None


def process_output(workflow, outputs, jobid):
    """
    Collect the images written by SaveImage_fal nodes, keyed by their output name.

    The workflow has the form:
        {
            "433": {
                "inputs": {
                    "filename_prefix": "result",
                    "output_name": "upscale",
                    "images": ["466", 0]
                },
                "class_type": "SaveImage_fal",
                "_meta": {"title": "Save Image (fal)"}
            },
        }

    The outputs have the form:
        {"433": {"images": [{"filename": "xxx.png", "type": "output"}]}}

    The node ID (here 433) is used to look up the node's output_name in the
    workflow (here "upscale"), and the final result is:
        {
            "upscale": {
                "images": [
                    {"filename": "xxx.png", "type": "output", "url": "data:image/webp;base64,..."}
                ]
            }
        }

    Args:
        workflow (dict): The workflow that was queued.
        outputs (dict): The "outputs" entry of the prompt history. The image dicts
            inside it are modified in place (a "url" key is added to each).
        jobid (str): The job ID. Currently unused.

    Returns:
        dict: Mapping of output name to {"images": [...]}.
    """
    final_output = {}

    for output_id, workflow_data in workflow.items():
        # Only SaveImage_fal nodes produce named outputs
        if workflow_data["class_type"] != "SaveImage_fal":
            continue

        if output_id not in outputs:
            print(f"Warning: output_id {output_id} not found in outputs.")
            continue

        output_data = outputs[output_id]
        output_name = workflow_data["inputs"]["output_name"]

        # Add a "url" field holding the base64 data URL to every image
        for image in output_data["images"]:
            # Images may live in a subfolder of the output directory
            relative_path = os.path.join(
                image.get("subfolder") or "", image["filename"]
            )
            base64_image = convert_image_to_base64(relative_path)
            # None signals that the image could not be read or converted
            image["url"] = f"{base64_image}" if base64_image else None

        final_output[output_name] = {"images": output_data["images"]}

    # NOTE(review): this logs the complete base64 payload of every image, which can
    # be very large; consider trimming it if log volume becomes a problem.
    print(json.dumps(final_output, indent=4, ensure_ascii=False))
    return final_output


def handler(job):
    """
    The main function that handles a job of generating an image.

    This function validates the input, sends a prompt to ComfyUI for processing,
    polls ComfyUI for result, and retrieves generated images.

    Args:
        job (dict): A dictionary containing job details and input parameters.

    Returns:
        dict: A dictionary containing either an error message or a success status with generated images.
    """
    job_input = job["input"]

    # Make sure that the input is valid
    validated_data, error_message = validate_input(job_input)
    if error_message:
        return {"error": error_message}

    # Extract validated data
    workflow = validated_data["workflow"]
    args = validated_data.get("args")

    # Substitute caller-supplied values into the workflow's input nodes
    process_input(workflow, args)

    # Make sure that the ComfyUI API is available
    check_server(
        f"http://{COMFY_HOST}",
        COMFY_API_AVAILABLE_MAX_RETRIES,
        COMFY_API_AVAILABLE_INTERVAL_MS,
    )

    # Queue the workflow
    try:
        queued_workflow = queue_workflow(workflow)
        prompt_id = queued_workflow["prompt_id"]
        print(f"runpod-worker-comfy - queued workflow with ID {prompt_id}")
    except Exception as e:
        return {"error": f"Error queuing workflow: {str(e)}"}

    # Poll for completion
    print("runpod-worker-comfy - wait until image generation is complete")
    retries = 0
    try:
        while retries < COMFY_POLLING_MAX_RETRIES:
            history = get_history(prompt_id)
            entry = history.get(prompt_id)

            if entry is not None:
                # Exit the loop once the history contains outputs
                if entry.get("outputs"):
                    break

                # ComfyUI reported a failed execution: surface its exception message
                status = entry.get("status")
                if status and status.get("status_str") == "error":
                    return {
                        "error": status.get("messages")[-1][1]["exception_message"]
                    }

            # Not finished yet: always wait and count the attempt, so an entry
            # without outputs can never turn into an unbounded busy loop
            time.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
            retries += 1
        else:
            return {"error": "Max retries reached while waiting for image generation"}
    except Exception as e:
        return {"error": f"Error waiting for image generation: {str(e)}"}

    outputs = history[prompt_id].get("outputs")
    jobid = job["id"]

    # Map the saved images to their output names and attach base64 data URLs
    output_result = process_output(workflow, outputs, jobid)

    result = {"result": output_result, "refresh_worker": REFRESH_WORKER}
    if OUTPUT_RAW_OUTPUTS:
        result["outputs"] = outputs

    return result


# Start the handler only if this script is run directly
if __name__ == "__main__":
    runpod.serverless.start({"handler": handler})