|
import base64 |
|
import traceback |
|
import json |
|
|
|
import runpod |
|
from runpod.serverless.utils.rp_validator import validate |
|
from runpod.serverless.utils.rp_download import file |
|
from runpod.serverless.modules.rp_logger import RunPodLogger |
|
|
|
from predict import main |
|
|
|
# Module-level RunPod logger, used by handler() and the __main__ entry point.
logger = RunPodLogger()
|
|
|
|
|
def predict(
    job_id,
    image_url,
    prompt,
    negative_prompt,
    steps,
    cfg,
    denoise,
    mask_expand: int = 30,
    gaus_kernel_size: int = 100,
    gaus_sigma: int = 100,
    target_ratio: str = '1:1',
    quantity: int = 1
):
    """Download the input image and run the ``main`` pipeline on it.

    Args:
        job_id: The ``job['id']`` value passed in by ``handler``. Not used
            here; kept so the call signature stays stable.
        image_url: Location of the input image, fetched with runpod's
            ``rp_download.file``. Required; must be non-empty.
        prompt, negative_prompt, steps, cfg, denoise, mask_expand,
        gaus_kernel_size, gaus_sigma, target_ratio, quantity:
            Forwarded unchanged to ``main`` as keyword arguments.

    Returns:
        Whatever ``main`` returns. ``handler`` iterates it as a sequence of
        output file paths.

    Raises:
        ValueError: If ``image_url`` is empty or None.
    """
    if not image_url:
        # Without an image there is nothing to download. Fail early with an
        # actionable message instead of calling main() with an unbound path.
        raise ValueError('image_url is required')

    downloaded = file(image_url)
    # rp_download.file returns a dict; the local path lives under 'file_path'.
    image_path = downloaded["file_path"]

    return main(
        image_path=image_path,
        prompt=prompt,
        negative_prompt=negative_prompt,
        denoise=denoise,
        steps=steps,
        cfg=cfg,
        mask_expand=mask_expand,
        gaus_kernel_size=gaus_kernel_size,
        gaus_sigma=gaus_sigma,
        target_ratio=target_ratio,
        quantity=quantity
    )
|
|
|
|
|
def _parse_payload(raw):
    """Return the job input as a dict, JSON-decoding it first if it is text.

    Raises:
        json.JSONDecodeError: If ``raw`` is text that is not valid JSON.
        TypeError: If the (decoded) input is not a dict.
    """
    if isinstance(raw, (str, bytes, bytearray)):
        raw = json.loads(raw)
    if not isinstance(raw, dict):
        raise TypeError(
            f'job input must be a JSON object, got {type(raw).__name__}'
        )
    return raw


def _encode_file(path):
    """Read ``path`` as bytes and return its base64 encoding as a str."""
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode('utf-8')


def handler(job):
    """Process one RunPod serverless job and return base64-encoded images.

    ``job['input']`` may be an already-decoded dict or a JSON string that
    encodes one. Recognised keys (default in parentheses): ``image_url``
    (''), ``prompt`` (''), ``negative_prompt`` (''), ``steps`` (20),
    ``cfg`` (3), ``denoise`` (1), ``mask_expand`` (30),
    ``gaus_kernel_size`` (100), ``gaus_sigma`` (100), ``target_ratio``
    ('1:1'), ``quantity`` (1).

    Returns:
        ``{'images': [...]}`` on success, holding one base64 string per path
        returned by ``predict``. If any exception is raised, returns a dict
        with ``error`` (the message), ``output`` (the formatted traceback)
        and ``refresh_worker`` set to True.
    """
    try:
        raw_input = job['input']
        logger.info(f'Received job input: {raw_input}')

        # Decode only when needed. Calling json.loads on a dict would raise
        # a spurious exception on every normal request.
        payload = _parse_payload(raw_input)

        result = predict(
            job['id'],
            image_url=payload.get('image_url', ''),
            prompt=payload.get("prompt", ""),
            negative_prompt=payload.get("negative_prompt", ""),
            steps=payload.get('steps', 20),
            cfg=payload.get('cfg', 3),
            denoise=payload.get('denoise', 1),
            mask_expand=payload.get("mask_expand", 30),
            gaus_kernel_size=payload.get('gaus_kernel_size', 100),
            gaus_sigma=payload.get('gaus_sigma', 100),
            target_ratio=payload.get('target_ratio', '1:1'),
            quantity=payload.get('quantity', 1),
        )

        return {
            'images': [_encode_file(path) for path in result],
        }

    except Exception as e:
        # Top-level boundary: report any failure back to RunPod, not crash.
        logger.error(f'An exception was raised: {e}')
        return {
            'error': str(e),
            'output': traceback.format_exc(),
            'refresh_worker': True
        }
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: register handler() with the RunPod serverless runtime.
    _worker_config = {'handler': handler}
    logger.info('Starting RunPod Serverless...')
    runpod.serverless.start(_worker_config)
|
|