# flux-outpaint / rp_handler.py
# Author: deneesk — last commit f17a349 ("Update rp_handler.py")
import base64
import traceback
import json
import runpod
from runpod.serverless.utils.rp_validator import validate
from runpod.serverless.utils.rp_download import file
from runpod.serverless.modules.rp_logger import RunPodLogger
from predict import main
# Module-level logger, used by `handler` for errors and by the `__main__` entry point.
logger: RunPodLogger = RunPodLogger()
def predict(
    job_id: int,
    image_url,
    prompt,
    negative_prompt,
    steps,
    cfg,
    denoise,
    mask_expand: int = 30,
    gaus_kernel_size: int = 100,
    gaus_sigma: int = 100,
    target_ratio: str = '1:1',
    quantity: int = 1
):
    """Download the input image and run the outpainting pipeline on it.

    Args:
        job_id: Identifier of the job (``handler`` passes ``job['id']``).
            Not used here; kept so the call signature stays stable.
            NOTE(review): annotated as ``int``, but RunPod job ids are
            usually strings — confirm and fix the annotation.
        image_url: Location of the source image, fetched with
            ``runpod...rp_download.file``. Required.
        prompt: Positive prompt forwarded to ``predict.main``.
        negative_prompt: Negative prompt forwarded to ``predict.main``.
        steps: Sampling step count forwarded to ``predict.main``.
        cfg: Guidance value forwarded to ``predict.main``.
        denoise: Denoise strength forwarded to ``predict.main``.
        mask_expand: Forwarded to ``predict.main`` (default 30).
        gaus_kernel_size: Forwarded to ``predict.main`` (default 100).
        gaus_sigma: Forwarded to ``predict.main`` (default 100).
        target_ratio: Aspect-ratio string forwarded to ``predict.main``
            (default ``'1:1'``).
        quantity: Forwarded to ``predict.main`` (default 1).

    Returns:
        Whatever ``predict.main`` returns. ``handler`` iterates it and opens
        each item as a file path.

    Raises:
        ValueError: If ``image_url`` is empty or ``None``.
    """
    if not image_url:
        # Fail fast with a clear message. Otherwise ``image_path`` below is
        # never bound and the call crashes with an opaque UnboundLocalError.
        raise ValueError('image_url is required')

    # ``file`` downloads the resource and returns a dict describing it; only
    # the local path is needed by the pipeline.
    image_path = file(image_url)["file_path"]

    return main(
        image_path=image_path,
        prompt=prompt,
        negative_prompt=negative_prompt,
        denoise=denoise,
        steps=steps,
        cfg=cfg,
        mask_expand=mask_expand,
        gaus_kernel_size=gaus_kernel_size,
        gaus_sigma=gaus_sigma,
        target_ratio=target_ratio,
        quantity=quantity
    )
def _parse_input(raw_input):
    """Return the job payload as a mapping.

    ``job['input']`` may arrive either as a JSON-encoded string or as an
    already-decoded object. Text is decoded with ``json.loads``; anything
    else is returned unchanged. A malformed JSON string raises
    ``json.JSONDecodeError`` so the caller can report a clear error.
    """
    if isinstance(raw_input, (str, bytes, bytearray)):
        return json.loads(raw_input)
    return raw_input


def _encode_images(paths):
    """Read every file in *paths* and return its bytes as base64 ASCII text."""
    encoded = []
    for path in paths:
        with open(path, "rb") as file_:
            encoded.append(base64.b64encode(file_.read()).decode('utf-8'))
    return encoded


def handler(job):
    """RunPod job entry point: run outpainting and return base64 images.

    Args:
        job: RunPod job dict. ``job['id']`` is passed to ``predict`` and
            ``job['input']`` holds the request parameters, either as a dict
            or as a JSON string.

    Returns:
        ``{'images': [...]}`` with one base64 string per generated file on
        success. On any failure it returns a dict with ``'error'`` (the
        message), ``'output'`` (the formatted traceback) and
        ``'refresh_worker': True``.
    """
    try:
        logger.info(f'Received job input: {job["input"]}')
        # Decode only when the input is text. Calling json.loads on a dict
        # would raise TypeError on every request just to be swallowed.
        payload = _parse_input(job['input'])

        result = predict(
            job['id'],
            image_url=payload.get('image_url', ''),
            prompt=payload.get("prompt", ""),
            negative_prompt=payload.get("negative_prompt", ""),
            steps=payload.get('steps', 20),
            cfg=payload.get('cfg', 3),
            denoise=payload.get('denoise', 1),
            mask_expand=payload.get("mask_expand", 30),
            gaus_kernel_size=payload.get('gaus_kernel_size', 100),
            gaus_sigma=payload.get('gaus_sigma', 100),
            target_ratio=payload.get('target_ratio', '1:1'),
            quantity=payload.get('quantity', 1),
        )

        return {
            'images': _encode_images(result),
        }
    except Exception as e:
        # Top-level worker boundary: report the failure to the caller
        # instead of letting it escape the handler.
        logger.error(f'An exception was raised: {e}')
        return {
            'error': str(e),
            'output': traceback.format_exc(),
            'refresh_worker': True
        }
# ---------------------------------------------------------------------------- #
# RunPod Handler #
# ---------------------------------------------------------------------------- #
if __name__ == '__main__':
    # Script entry point: register `handler` as the job callback and hand
    # control to runpod.serverless.start.
    logger.info('Starting RunPod Serverless...')
    worker_config = {'handler': handler}
    runpod.serverless.start(worker_config)