import runpod

from runpod.serverless.utils import rp_upload
import json
import urllib.request
import urllib.parse
import time
import os
import requests
import base64
from io import BytesIO
from PIL import Image

# Time to wait between API check attempts in milliseconds
COMFY_API_AVAILABLE_INTERVAL_MS = 100
# Maximum number of API check attempts
COMFY_API_AVAILABLE_MAX_RETRIES = 500
# Time to wait between poll attempts in milliseconds (env vars are strings, so cast to int)
COMFY_POLLING_INTERVAL_MS = int(os.environ.get("COMFY_POLLING_INTERVAL_MS", 1000))
# Maximum number of poll attempts
COMFY_POLLING_MAX_RETRIES = int(os.environ.get("COMFY_POLLING_MAX_RETRIES", 1000))
# Host where ComfyUI is running
COMFY_HOST = "127.0.0.1:8188"

# Enforce a clean state after each job is done
REFRESH_WORKER = os.environ.get("REFRESH_WORKER", "false").lower() == "true"

# Convert output images to WebP before base64 encoding them
OUTPUT_WEBP = os.environ.get("OUTPUT_WEBP", "true").lower() == "true"
# Include the raw ComfyUI outputs in the response
OUTPUT_RAW_OUTPUTS = os.environ.get("OUTPUT_RAW_OUTPUTS", "false").lower() == "true"
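
# Example configuration via environment variables (illustrative values only):
#   COMFY_POLLING_INTERVAL_MS=250
#   COMFY_POLLING_MAX_RETRIES=2000
#   OUTPUT_WEBP=true
#   OUTPUT_RAW_OUTPUTS=false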


def validate_input(job_input):
    """
    Validates the input for the handler function.

    Args:
        job_input (dict): The input data to validate.

    Returns:
        tuple: A tuple containing the validated data and an error message, if any.
               The structure is (validated_data, error_message); exactly one of
               the two is None.
    """
    # Validate that an input was provided at all
    if job_input is None:
        return None, "Please provide input"

    # If the input is a string, try to parse it as JSON
    if isinstance(job_input, str):
        try:
            job_input = json.loads(job_input)
        except json.JSONDecodeError:
            return None, "Invalid JSON format in input"

    # 'workflow' is required
    workflow = job_input.get("workflow")
    if workflow is None:
        return None, "Missing 'workflow' parameter"

    # 'args' is optional, but must be a dict when present
    args = job_input.get("args")
    if args is not None and not isinstance(args, dict):
        return None, "'args' must be a dict"

    return {"workflow": workflow, "args": args}, None


def check_server(url, retries=500, delay=50):
    """
    Check if a server is reachable via an HTTP GET request.

    Args:
        url (str): The URL to check
        retries (int, optional): The number of times to attempt connecting to the server. Default is 500
        delay (int, optional): The time in milliseconds to wait between retries. Default is 50

    Returns:
        bool: True if the server is reachable within the given number of retries, otherwise False
    """
    for i in range(retries):
        try:
            response = requests.get(url)

            # If the response status code is 200, the server is up and running
            if response.status_code == 200:
                print("runpod-worker-comfy - API is reachable")
                return True
        except requests.RequestException:
            # The server may not be ready yet
            pass

        # Wait for the specified delay before retrying
        time.sleep(delay / 1000)

    print(
        f"runpod-worker-comfy - Failed to connect to server at {url} after {retries} attempts."
    )
    return False


def upload_images(images):
    """
    Upload a list of base64 encoded images to the ComfyUI server using the /upload/image endpoint.

    Args:
        images (list): A list of dictionaries, each containing the 'name' of the image
            and the 'image' as a base64 encoded string.

    Returns:
        dict: A dictionary with the upload status, a message, and per-image details.
    """
    if not images:
        return {"status": "success", "message": "No images to upload", "details": []}

    responses = []
    upload_errors = []

    print("runpod-worker-comfy - uploading image(s)")

    for image in images:
        name = image["name"]
        image_data = image["image"]
        blob = base64.b64decode(image_data)

        # Prepare the form data for ComfyUI's /upload/image endpoint
        files = {
            "image": (name, BytesIO(blob), "image/png"),
            "overwrite": (None, "true"),
        }

        response = requests.post(f"http://{COMFY_HOST}/upload/image", files=files)
        if response.status_code != 200:
            upload_errors.append(f"Error uploading {name}: {response.text}")
        else:
            responses.append(f"Successfully uploaded {name}")

    if upload_errors:
        print("runpod-worker-comfy - image(s) uploaded with errors")
        return {
            "status": "error",
            "message": "Some images failed to upload",
            "details": upload_errors,
        }

    print("runpod-worker-comfy - image(s) upload complete")
    return {
        "status": "success",
        "message": "All images uploaded successfully",
        "details": responses,
    }
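
# Illustrative shape of the `images` argument (values are made up):
#   [{"name": "input.png", "image": "<base64 encoded image data>"}]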


def queue_workflow(workflow):
    """
    Queue a workflow to be processed by ComfyUI.

    Args:
        workflow (dict): A dictionary containing the workflow to be processed

    Returns:
        dict: The JSON response from ComfyUI after queuing the workflow
    """
    # The workflow has to be wrapped in a "prompt" object before it can be queued
    data = json.dumps({"prompt": workflow}).encode("utf-8")

    req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data)
    return json.loads(urllib.request.urlopen(req).read())
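
# For reference, ComfyUI's /prompt endpoint responds with a JSON object that
# contains at least a "prompt_id" field, e.g. (illustrative):
#   {"prompt_id": "a1b2c3...", "number": 3, "node_errors": {}}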


def get_history(prompt_id):
    """
    Retrieve the history of a given prompt using its ID.

    Args:
        prompt_id (str): The ID of the prompt whose history is to be retrieved

    Returns:
        dict: The history of the prompt, containing all the processing steps and results
    """
    with urllib.request.urlopen(f"http://{COMFY_HOST}/history/{prompt_id}") as response:
        return json.loads(response.read())


def base64_encode(img_path):
    """
    Returns the base64 encoded contents of an image file.

    Args:
        img_path (str): The path to the image

    Returns:
        str: The base64 encoded image
    """
    with open(img_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def process_output_images(outputs, job_id):
    """
    This function takes the "outputs" from image generation and the job ID,
    then determines the correct way to return the image, either as a direct URL
    to an AWS S3 bucket or as a base64 encoded string, depending on the
    environment configuration.

    Args:
        outputs (dict): A dictionary containing the outputs from image generation,
            typically includes node IDs and their respective output data.
        job_id (str): The unique identifier for the job.

    Returns:
        dict: A dictionary with the status ('success' or 'error') and the message,
            which is either the URL to the image in the AWS S3 bucket or a base64
            encoded string of the image. In case of error, the message details the issue.

    The function works as follows:
    - It first determines the output path for the images from an environment variable,
      defaulting to "/comfyui/output" if not set.
    - It then iterates through the outputs to find the filenames of the generated images.
    - After confirming the existence of the image in the output folder, it checks if the
      AWS S3 bucket is configured via the BUCKET_ENDPOINT_URL environment variable.
    - If AWS S3 is configured, it uploads the image to the bucket and returns the URL.
    - If AWS S3 is not configured, it encodes the image in base64 and returns the string.
    - If the image file does not exist in the output folder, it returns an error status
      with a message indicating the missing image file.
    """
    # The path where ComfyUI stores the generated images
    COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output")

    output_images = {}

    # Collect the relative path of the generated image; if several images were
    # produced, only the last one is kept
    for node_id, node_output in outputs.items():
        if "images" in node_output:
            for image in node_output["images"]:
                output_images = os.path.join(image["subfolder"], image["filename"])

    print("runpod-worker-comfy - image generation is done")

    # Expected location of the image in the output folder
    local_image_path = f"{COMFY_OUTPUT_PATH}/{output_images}"
    print(f"runpod-worker-comfy - {local_image_path}")

    if os.path.exists(local_image_path):
        if os.environ.get("BUCKET_ENDPOINT_URL", False):
            # Upload the image to AWS S3 and return its URL
            image = rp_upload.upload_image(job_id, local_image_path)
            print(
                "runpod-worker-comfy - the image was generated and uploaded to AWS S3"
            )
        else:
            # Encode the image as base64 and return the string
            image = base64_encode(local_image_path)
            print(
                "runpod-worker-comfy - the image was generated and converted to base64"
            )

        return {
            "status": "success",
            "message": image,
        }
    else:
        print("runpod-worker-comfy - the image does not exist in the output folder")
        return {
            "status": "error",
            "message": f"the image does not exist in the specified output folder: {local_image_path}",
        }


def process_input(workflow, args):
    """
    Substitute values from `args` into the workflow's fal input nodes
    (IntegerInput_fal, FloatInput_fal, BooleanInput_fal, StringInput_fal),
    matching each arg key against the node's "name" input. The workflow is
    modified in place.
    """
    if not args:
        return

    for key, node in workflow.items():
        if node["class_type"] in ["IntegerInput_fal", "FloatInput_fal", "BooleanInput_fal", "StringInput_fal"]:
            input_name = node["inputs"]["name"]
            if input_name in args:
                # Integer/float nodes store their value under "number",
                # boolean/string nodes under "value"
                if node["class_type"] in ["IntegerInput_fal", "FloatInput_fal"]:
                    node["inputs"]["number"] = args[input_name]
                else:
                    node["inputs"]["value"] = args[input_name]


def convert_image_to_base64(filename):
    """
    Convert an image file from the ComfyUI output folder to a base64 encoded
    data URL. If OUTPUT_WEBP is enabled, the image is re-encoded as WebP first;
    otherwise the original PNG bytes are returned.
    """
    try:
        COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output")
        fullpath = os.path.join(COMFY_OUTPUT_PATH, filename)
        if not OUTPUT_WEBP:
            return "data:image/png;base64," + base64_encode(fullpath)
        else:
            with Image.open(fullpath) as img:
                with BytesIO() as output:
                    # Re-encode the image as WebP before base64 encoding it
                    img.save(output, format="WebP")
                    output.seek(0)
                    return "data:image/webp;base64," + base64.b64encode(output.read()).decode("utf-8")
    except Exception as e:
        print(f"Error converting image {filename}: {e}")
        return None


def process_output(workflow, outputs, jobid):
    """
    Map the ComfyUI outputs back to the named outputs defined by the
    SaveImage_fal nodes in the workflow and attach base64 data URLs.

    The workflow looks like:
    {
        "433": {
            "inputs": {
                "filename_prefix": "result",
                "output_name": "upscale",
                "images": ["466", 0]
            },
            "class_type": "SaveImage_fal",
            "_meta": {"title": "Save Image (fal)"}
        },
    }

    The outputs look like:
    {"433": {"images": [{"filename": "xxx.png", "type": "output"}]}}

    The node ID ("433") is used to look up the output name in the workflow
    (here "upscale"), so the final result becomes:
    {
        "upscale": {"images": [{"filename": "xxx.png", "type": "output", "url": "data:image/webp;base64,..."}]}
    }
    """
    final_output = {}

    for output_id, workflow_data in workflow.items():
        # Only SaveImage_fal nodes define named outputs
        if workflow_data["class_type"] == "SaveImage_fal":
            if output_id in outputs:
                output_data = outputs[output_id]
                output_name = workflow_data["inputs"]["output_name"]

                # Attach a base64 data URL to every saved image
                for image in output_data["images"]:
                    filename = image["filename"]
                    base64_image = convert_image_to_base64(filename)
                    if base64_image:
                        image["url"] = base64_image
                    else:
                        image["url"] = None

                final_output[output_name] = {
                    "images": output_data["images"]
                }
            else:
                print(f"Warning: output_id {output_id} not found in outputs.")

    # Log the final output (note: this includes the base64 data URLs)
    print(json.dumps(final_output, indent=4, ensure_ascii=False))
    return final_output


def handler(job):
    """
    The main function that handles a job of generating an image.

    This function validates the input, sends a prompt to ComfyUI for processing,
    polls ComfyUI for the result, and retrieves the generated images.

    Args:
        job (dict): A dictionary containing job details and input parameters.

    Returns:
        dict: A dictionary containing either an error message or a success status
            with the generated images.
    """
    job_input = job["input"]

    # Make sure that the input is valid
    validated_data, error_message = validate_input(job_input)
    if error_message:
        return {"error": error_message}

    # Substitute the provided args into the workflow
    workflow = validated_data["workflow"]
    args = validated_data.get("args")
    process_input(workflow, args)

    # Make sure that the ComfyUI API is available
    check_server(
        f"http://{COMFY_HOST}",
        COMFY_API_AVAILABLE_MAX_RETRIES,
        COMFY_API_AVAILABLE_INTERVAL_MS,
    )

    # Queue the workflow
    try:
        queued_workflow = queue_workflow(workflow)
        prompt_id = queued_workflow["prompt_id"]
        print(f"runpod-worker-comfy - queued workflow with ID {prompt_id}")
    except Exception as e:
        return {"error": f"Error queuing workflow: {str(e)}"}

    # Poll until the workflow has finished (or failed)
    print("runpod-worker-comfy - wait until image generation is complete")
    retries = 0
    try:
        while retries < COMFY_POLLING_MAX_RETRIES:
            history = get_history(prompt_id)

            if prompt_id in history:
                # Exit the loop once the workflow has produced outputs
                if history[prompt_id].get("outputs"):
                    break
                # Surface workflow errors reported by ComfyUI
                status = history[prompt_id].get("status")
                if status and status.get("status_str") == "error":
                    return {"error": status.get("messages")[-1][1]["exception_message"]}

            # Wait before polling again
            time.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
            retries += 1
        else:
            return {"error": "Max retries reached while waiting for image generation"}
    except Exception as e:
        return {"error": f"Error waiting for image generation: {str(e)}"}

    # Map the raw ComfyUI outputs to the named SaveImage_fal outputs
    outputs = history[prompt_id].get("outputs")
    jobid = job["id"]
    output_result = process_output(workflow, outputs, jobid)

    result = {"result": output_result, "refresh_worker": REFRESH_WORKER}
    if OUTPUT_RAW_OUTPUTS:
        result["outputs"] = outputs
    return result
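
# Illustrative response shape (values are made up):
#   {
#       "result": {"upscale": {"images": [{"filename": "result_00001_.png", "type": "output", "url": "data:image/webp;base64,..."}]}},
#       "refresh_worker": False,
#   }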


if __name__ == "__main__":
    runpod.serverless.start({"handler": handler})