# Source: workerflux / src / rp_handler.py
# Uploaded by Peter-Young using huggingface_hub (commit 5193146, verified).
import runpod
from runpod.serverless.utils import rp_upload
import json
import urllib.request
import urllib.parse
import time
import os
import requests
import base64
from io import BytesIO
from PIL import Image
# Time to wait between API check attempts in milliseconds
COMFY_API_AVAILABLE_INTERVAL_MS = 100
# Maximum number of API check attempts
COMFY_API_AVAILABLE_MAX_RETRIES = 500
# Time to wait between poll attempts in milliseconds.
# Environment variables are always strings, so the value is coerced to int;
# otherwise `COMFY_POLLING_INTERVAL_MS / 1000` raises TypeError as soon as the
# variable is overridden.
COMFY_POLLING_INTERVAL_MS = int(os.environ.get("COMFY_POLLING_INTERVAL_MS", 1000))
# Maximum number of poll attempts (coerced to int for the same reason: it is
# compared against an integer retry counter).
COMFY_POLLING_MAX_RETRIES = int(os.environ.get("COMFY_POLLING_MAX_RETRIES", 1000))
# Host where ComfyUI is running
COMFY_HOST = "127.0.0.1:8188"
# Enforce a clean state after each job is done
# see https://docs.runpod.io/docs/handler-additional-controls#refresh-worker
REFRESH_WORKER = os.environ.get("REFRESH_WORKER", "false").lower() == "true"
# Whether to convert output images to WebP, which makes them considerably smaller
OUTPUT_WEBP = os.environ.get("OUTPUT_WEBP", "true").lower() == "true"
# Whether to also include the unprocessed ComfyUI outputs in the job result
OUTPUT_RAW_OUTPUTS = os.environ.get("OUTPUT_RAW_OUTPUTS", "false").lower() == "true"
def validate_input(job_input):
    """
    Validate the input for the handler function.

    Args:
        job_input (dict | str | None): The job input, either already a dict
            or a JSON string that decodes to one.

    Returns:
        tuple: ``(validated_data, error_message)``. On success
        ``validated_data`` is ``{"workflow": ..., "args": ...}`` (``args`` is
        ``None`` when it was not supplied) and ``error_message`` is ``None``.
        On failure ``validated_data`` is ``None`` and ``error_message``
        describes the problem.
    """
    # Validate if job_input is provided
    if job_input is None:
        return None, "Please provide input"

    # Check if input is a string and try to parse it as JSON
    if isinstance(job_input, str):
        try:
            job_input = json.loads(job_input)
        except json.JSONDecodeError:
            return None, "Invalid JSON format in input"

    # Valid JSON may still be a list/number/string; the lookups below need a dict
    if not isinstance(job_input, dict):
        return None, "Input must be a JSON object"

    # Validate 'workflow' in input
    workflow = job_input.get("workflow")
    if workflow is None:
        return None, "Missing 'workflow' parameter"
    # The workflow is iterated as a mapping of node id -> node further down
    if not isinstance(workflow, dict):
        return None, "'workflow' must be a dict"

    # Validate 'args' in input, if provided
    args = job_input.get("args")
    if args is not None and not isinstance(args, dict):
        return None, "'args' must be a dict"

    # Return validated data and no error
    return {"workflow": workflow, "args": args}, None
def check_server(url, retries=500, delay=50):
    """
    Check if a server is reachable via HTTP GET request

    Args:
    - url (str): The URL to check
    - retries (int, optional): The number of times to attempt connecting to the server. Default is 500
    - delay (int, optional): The time in milliseconds to wait between retries. Default is 50

    Returns:
    bool: True if the server answered with HTTP 200 within the given number of retries, otherwise False
    """
    for _ in range(retries):
        try:
            # Without a timeout a stalled connection would block the worker
            # forever and the retry budget would never be consumed.
            response = requests.get(url, timeout=5)
        except requests.RequestException:
            # The server may simply not be ready yet; fall through to the retry
            pass
        else:
            # If the response status code is 200, the server is up and running
            if response.status_code == 200:
                print("runpod-worker-comfy - API is reachable")
                return True

        # Wait for the specified delay before retrying
        time.sleep(delay / 1000)

    print(
        f"runpod-worker-comfy - Failed to connect to server at {url} after {retries} attempts."
    )
    return False
def upload_images(images):
    """
    Send base64 encoded images to ComfyUI through its /upload/image endpoint.

    Args:
        images (list): Dictionaries holding the image 'name' and the 'image'
            payload as a base64 encoded string.

    Returns:
        dict: A summary with the keys 'status' ("success" or "error"),
        'message' and 'details' (one entry per uploaded or failed image).
    """
    # Nothing to do for an empty or missing list
    if not images:
        return {"status": "success", "message": "No images to upload", "details": []}

    succeeded = []
    failed = []
    endpoint = f"http://{COMFY_HOST}/upload/image"

    print("runpod-worker-comfy - image(s) upload")

    for entry in images:
        filename = entry["name"]
        raw_bytes = base64.b64decode(entry["image"])

        # Multipart form: the file itself plus the overwrite flag
        form = {
            "image": (filename, BytesIO(raw_bytes), "image/png"),
            "overwrite": (None, "true"),
        }

        reply = requests.post(endpoint, files=form)
        if reply.status_code == 200:
            succeeded.append(f"Successfully uploaded {filename}")
        else:
            failed.append(f"Error uploading {filename}: {reply.text}")

    if failed:
        print("runpod-worker-comfy - image(s) upload with errors")
        return {
            "status": "error",
            "message": "Some images failed to upload",
            "details": failed,
        }

    print("runpod-worker-comfy - image(s) upload complete")
    return {
        "status": "success",
        "message": "All images uploaded successfully",
        "details": succeeded,
    }
def queue_workflow(workflow):
    """
    Queue a workflow to be processed by ComfyUI

    Args:
        workflow (dict): A dictionary containing the workflow to be processed

    Returns:
        dict: The decoded JSON response of ComfyUI's /prompt endpoint

    Raises:
        urllib.error.URLError: If the request to ComfyUI fails.
    """
    # The top level element "prompt" is required by ComfyUI
    data = json.dumps({"prompt": workflow}).encode("utf-8")
    req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data)
    # Close the HTTP response deterministically instead of leaking the socket
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_history(prompt_id):
    """
    Fetch the ComfyUI history entry for a prompt.

    Args:
        prompt_id (str): The ID of the prompt to look up

    Returns:
        dict: The decoded JSON body of the /history/<prompt_id> endpoint
    """
    history_url = f"http://{COMFY_HOST}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as reply:
        raw_body = reply.read()
    return json.loads(raw_body)
def base64_encode(img_path):
    """
    Read a file and return its content as base64 text.

    Args:
        img_path (str): The path to the image

    Returns:
        str: The base64 encoded file content
    """
    with open(img_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
def process_output_images(outputs, job_id):
    """
    Return the generated image either as a bucket URL or as base64 text.

    The function looks through ``outputs`` for image entries and handles the
    LAST one it finds (earlier images are ignored). If the file exists in the
    ComfyUI output folder it is uploaded via ``rp_upload.upload_image`` when
    the BUCKET_ENDPOINT_URL environment variable is set, otherwise it is
    base64 encoded.

    Args:
        outputs (dict): Maps node IDs to their output data; nodes that
            produced images carry an "images" list whose items have a
            "filename" and optionally a "subfolder".
        job_id (str): The unique identifier for the job.

    Returns:
        dict: ``{"status": "success", "message": <url or base64 string>}`` on
        success, or ``{"status": "error", "message": <description>}`` when no
        image is listed in ``outputs`` or the file does not exist.
    """
    # The path where ComfyUI stores the generated images
    COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output")

    # Path of the last image found, relative to the output folder. None (not
    # an empty container) so that "no image" cannot leak into the path below.
    relative_image_path = None
    for node_output in outputs.values():
        for image in node_output.get("images") or []:
            # "subfolder" is not always present in the image entry
            relative_image_path = os.path.join(
                image.get("subfolder", ""), image["filename"]
            )

    print("runpod-worker-comfy - image generation is done")

    if relative_image_path is None:
        print("runpod-worker-comfy - no image was found in the outputs")
        return {
            "status": "error",
            "message": "no image was found in the workflow outputs",
        }

    # expected image output folder
    local_image_path = f"{COMFY_OUTPUT_PATH}/{relative_image_path}"
    print(f"runpod-worker-comfy - {local_image_path}")

    if not os.path.exists(local_image_path):
        print("runpod-worker-comfy - the image does not exist in the output folder")
        return {
            "status": "error",
            "message": f"the image does not exist in the specified output folder: {local_image_path}",
        }

    # The image is in the output folder
    if os.environ.get("BUCKET_ENDPOINT_URL", False):
        # URL to image in AWS S3
        image = rp_upload.upload_image(job_id, local_image_path)
        print("runpod-worker-comfy - the image was generated and uploaded to AWS S3")
    else:
        # base64 image
        image = base64_encode(local_image_path)
        print("runpod-worker-comfy - the image was generated and converted to base64")

    return {
        "status": "success",
        "message": image,
    }
def process_input(workflow, args):
    """
    Substitute caller-supplied arguments into the workflow's input nodes.

    Every node whose ``class_type`` is one of the ``*Input_fal`` types is
    matched against ``args`` by its ``inputs["name"]``. On a match the node
    is updated in place: ``inputs["number"]`` for IntegerInput_fal and
    FloatInput_fal, ``inputs["value"]`` for BooleanInput_fal and
    StringInput_fal.

    Args:
        workflow (dict): Maps node IDs to node definitions; mutated in place.
        args (dict | None): Maps input names to replacement values. ``None``
            or an empty dict leaves the workflow untouched.

    Returns:
        None
    """
    # 'args' is optional in the job input, so it may legitimately be None
    if not args:
        return

    number_types = ("IntegerInput_fal", "FloatInput_fal")
    value_types = ("BooleanInput_fal", "StringInput_fal")

    for node in workflow.values():
        class_type = node.get("class_type")
        if class_type in number_types:
            target_field = "number"
        elif class_type in value_types:
            target_field = "value"
        else:
            continue

        input_name = node.get("inputs", {}).get("name")
        if input_name in args:
            # Overwrite the node's default with the caller's value
            node["inputs"][target_field] = args[input_name]
def convert_image_to_base64(filename):
    """
    Return an output image as a base64 data URL.

    The file is looked up inside the ComfyUI output folder (COMFY_OUTPUT_PATH,
    defaulting to "/comfyui/output"). When OUTPUT_WEBP is enabled the image is
    re-encoded as WebP first; otherwise the file content is embedded as-is
    under a PNG data URL.

    Args:
        filename (str): Image path relative to the output folder.

    Returns:
        str | None: The data URL, or None if anything went wrong (the error
        is printed).
    """
    try:
        output_dir = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output")
        image_path = os.path.join(output_dir, filename)

        # Pass-through mode: no re-encoding, just wrap the file content
        if not OUTPUT_WEBP:
            return "data:image/png;base64," + base64_encode(image_path)

        # Re-encode into an in-memory buffer so nothing is written to disk
        with Image.open(image_path) as source, BytesIO() as buffer:
            source.save(buffer, format="WebP")
            webp_bytes = buffer.getvalue()

        encoded = base64.b64encode(webp_bytes).decode('utf-8')
        return "data:image/webp;base64," + encoded
    except Exception as e:
        print(f"Error converting image {filename}: {e}")
        return None
def process_output(workflow, outputs, jobid):
    """
    Build the named result for every SaveImage_fal node of the workflow.

    The workflow has the form:
    {
        "433": {
            "inputs": {
                "filename_prefix": "result",
                "output_name": "upscale",
                "images": ["466", 0]
            },
            "class_type": "SaveImage_fal",
            "_meta": {"title": "Save Image (fal)"}
        },
    }
    and the outputs have the form:
    {"433": {"images": [{"filename": "xxx.png", "type": "output"}]}}

    The node ID (433) is used to look up the node's ``output_name`` in the
    workflow (here "upscale"), giving the final result:
    {
        "upscale": {"images": [{"filename": "xxx.png", "type": "output", "url": "data:image/webp;base64,..."}]}
    }

    Args:
        workflow (dict): Maps node IDs to node definitions.
        outputs (dict): Maps node IDs to the outputs ComfyUI reported. It is
            not modified; image entries are copied before "url" is added.
        jobid (str): The job identifier (currently unused).

    Returns:
        dict: Maps each output name to ``{"images": [...]}``, where every
        image entry carries a "url" data URL, or None if conversion failed.
    """
    final_output = {}

    for output_id, workflow_data in workflow.items():
        # Only SaveImage_fal nodes contribute to the result
        if workflow_data.get("class_type") != "SaveImage_fal":
            continue

        if output_id not in outputs:
            print(f"Warning: output_id {output_id} not found in outputs.")
            continue

        output_name = workflow_data["inputs"]["output_name"]

        images = []
        for image in outputs[output_id].get("images") or []:
            # Files may live in a subfolder of the output directory; the
            # entry's "subfolder" is optional, so default to the root.
            relative_path = os.path.join(image.get("subfolder", ""), image["filename"])
            base64_image = convert_image_to_base64(relative_path)
            # Copy the entry so the raw ComfyUI outputs stay free of the
            # (large) base64 payload; url is None when conversion failed.
            images.append({**image, "url": base64_image or None})

        final_output[output_name] = {"images": images}

    # NOTE(review): this logs the complete base64 payload of every image.
    print(json.dumps(final_output, indent=4, ensure_ascii=False))
    return final_output
def handler(job):
    """
    The main function that handles a job of generating an image.

    This function validates the input, substitutes the caller's args into the
    workflow, sends the workflow to ComfyUI for processing, polls ComfyUI for
    the result, and converts the generated images.

    Args:
        job (dict): A dictionary containing job details ("id") and the input
            parameters ("input").

    Returns:
        dict: ``{"error": <message>}`` on failure, otherwise
        ``{"result": <named outputs>, "refresh_worker": <bool>}`` plus an
        "outputs" key with the raw ComfyUI outputs when OUTPUT_RAW_OUTPUTS is
        enabled.
    """
    job_input = job["input"]

    # Make sure that the input is valid
    validated_data, error_message = validate_input(job_input)
    if error_message:
        return {"error": error_message}

    # Extract validated data
    workflow = validated_data["workflow"]
    # 'args' is optional; normalise a missing value to an empty dict so the
    # substitution step can safely do membership tests on it.
    args = validated_data.get("args") or {}
    try:
        process_input(workflow, args)
    except Exception as e:
        return {"error": f"Error applying args to workflow: {str(e)}"}

    # Make sure that the ComfyUI API is available
    if not check_server(
        f"http://{COMFY_HOST}",
        COMFY_API_AVAILABLE_MAX_RETRIES,
        COMFY_API_AVAILABLE_INTERVAL_MS,
    ):
        return {"error": f"ComfyUI server at {COMFY_HOST} is not reachable"}

    # Queue the workflow
    try:
        queued_workflow = queue_workflow(workflow)
        prompt_id = queued_workflow["prompt_id"]
        print(f"runpod-worker-comfy - queued workflow with ID {prompt_id}")
    except Exception as e:
        return {"error": f"Error queuing workflow: {str(e)}"}

    # Poll for completion
    print("runpod-worker-comfy - wait until image generation is complete")
    history = {}
    try:
        # The polling settings may come from the environment as strings
        max_retries = int(COMFY_POLLING_MAX_RETRIES)
        poll_interval_s = int(COMFY_POLLING_INTERVAL_MS) / 1000

        for _ in range(max_retries):
            history = get_history(prompt_id)
            entry = history.get(prompt_id)
            if entry is not None:
                # Exit the loop once the outputs are available
                if entry.get("outputs"):
                    break
                status = entry.get("status") or {}
                if status.get("status_str") == "error":
                    return {
                        "error": status.get("messages")[-1][1]["exception_message"]
                    }
            # Every non-terminal iteration waits and consumes one retry, so
            # an entry without outputs or error cannot spin forever.
            time.sleep(poll_interval_s)
        else:
            return {"error": "Max retries reached while waiting for image generation"}
    except Exception as e:
        return {"error": f"Error waiting for image generation: {str(e)}"}

    outputs = history[prompt_id].get("outputs")
    jobid = job["id"]

    # Map the outputs of the save nodes to their names and embed the images
    output_result = process_output(workflow, outputs, jobid)

    result = {"result": output_result, "refresh_worker": REFRESH_WORKER}
    if OUTPUT_RAW_OUTPUTS:
        result["outputs"] = outputs

    return result
# Launch the RunPod serverless worker only when this file is executed as a script
if __name__ == "__main__":
    serverless_config = {"handler": handler}
    runpod.serverless.start(serverless_config)