|
import os |
|
import traceback |
|
from typing import List, Optional |
|
|
|
import pydash as _ |
|
import torch |
|
from botocore.vendored.six import BytesIO |
|
from numpy import who |
|
|
|
import internals.util.prompt as prompt_util |
|
from internals.data.dataAccessor import update_db, update_db_source_failed |
|
from internals.data.task import ModelType, Task, TaskType |
|
from internals.pipelines.commons import Img2Img, Text2Img |
|
from internals.pipelines.controlnets import ControlNet |
|
from internals.pipelines.high_res import HighRes |
|
from internals.pipelines.img_classifier import ImageClassifier |
|
from internals.pipelines.img_to_text import Image2Text |
|
from internals.pipelines.inpainter import InPainter |
|
from internals.pipelines.object_remove import ObjectRemoval |
|
from internals.pipelines.pose_detector import PoseDetector |
|
from internals.pipelines.prompt_modifier import PromptModifier |
|
from internals.pipelines.realtime_draw import RealtimeDraw |
|
from internals.pipelines.remove_background import RemoveBackgroundV2 |
|
from internals.pipelines.replace_background import ReplaceBackground |
|
from internals.pipelines.safety_checker import SafetyChecker |
|
from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler |
|
from internals.pipelines.upscaler import Upscaler |
|
from internals.util.args import apply_style_args |
|
from internals.util.avatar import Avatar |
|
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc |
|
from internals.util.commons import ( |
|
base64_to_image, |
|
construct_default_s3_url, |
|
download_image, |
|
image_to_base64, |
|
upload_image, |
|
upload_images, |
|
) |
|
from internals.util.config import ( |
|
get_is_sdxl, |
|
get_model_dir, |
|
num_return_sequences, |
|
set_configs_from_task, |
|
set_model_config, |
|
set_root_dir, |
|
) |
|
from internals.util.failure_hander import FailureHandler |
|
from internals.util.lora_style import LoraStyle |
|
from internals.util.model_loader import load_model_from_config |
|
from internals.util.slack import Slack |
|
|
|
# Global torch switches: let cuDNN benchmark/auto-select convolution
# algorithms and allow TF32 for CUDA matmuls (faster, slightly lower precision).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

auto_mode = False

# Module-level singletons shared by every request handled in this process.
# Several of them are populated later by ``load_model_by_task``.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
pose_detector = PoseDetector()
inpainter = InPainter()
high_res = HighRes()
img2text = Image2Text()
img_classifier = ImageClassifier()
object_removal = ObjectRemoval()
# Constructed exactly once; a second construction would only discard this one.
replace_background = ReplaceBackground()
remove_background_v2 = RemoveBackgroundV2()
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
safety_checker = SafetyChecker()
slack = Slack()
avatar = Avatar()
realtime_draw = RealtimeDraw()
sdxl_tileupscaler = SDXLTileUpscaler()

# Cache of instantiated custom-action scripts, reused across calls by
# ``custom_action``.
custom_scripts: List = []
|
|
|
|
|
def get_patched_prompt(task: Task):
    """Build the final prompt(s) for ``task`` via ``prompt_util.get_patched_prompt``.

    Forwards this module's avatar, LoRA-style and prompt-modifier singletons.
    Callers in this module unpack the result as a 2-tuple whose first item is
    the prompt.
    """
    helpers = (avatar, lora_style, prompt_modifier)
    return prompt_util.get_patched_prompt(task, *helpers)
|
|
|
|
|
def get_patched_prompt_text2img(task: Task):
    """Build text2img prompt parameters via ``prompt_util``.

    The returned object is used by ``text2img`` through its ``prompt``
    attribute and its ``__dict__``.
    """
    patch_fn = prompt_util.get_patched_prompt_text2img
    return patch_fn(task, avatar, lora_style, prompt_modifier)
|
|
|
|
|
def get_patched_prompt_tile_upscale(task: Task):
    """Build the tile-upscale prompt via ``prompt_util``.

    Unlike the other prompt helpers, this one forwards the image classifier
    and image-to-text singletons instead of the prompt modifier.
    """
    return prompt_util.get_patched_prompt_tile_upscale(
        task,
        avatar,
        lora_style,
        img_classifier,
        img2text,
    )
|
|
|
|
|
def get_intermediate_dimension(task: Task):
    """Return the ``(width, height)`` to generate at before any high-res pass.

    If the task asks for a high-res fix, the size is computed by
    ``HighRes.get_intermediate_dimension`` from the task's target size;
    otherwise the task's own width/height are returned unchanged.
    """
    wants_high_res = task.get_high_res_fix()
    target_w, target_h = task.get_width(), task.get_height()
    if wants_high_res:
        return HighRes.get_intermediate_dimension(target_w, target_h)
    return target_w, target_h
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Generate images conditioned on the canny ControlNet.

    Loads the ``canny`` ControlNet, applies the task's LoRA style to the
    ControlNet and high-res pipelines, generates (at the intermediate size
    when a high-res fix is requested, then runs the high-res pass) and
    uploads the results.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _unused = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("canny")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # The patched pipelines are module-level singletons reused by later
    # requests, so the patcher's cleanup must run even if generation or
    # upload raises.
    try:
        kwargs = {
            "prompt": prompt,
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "width": width,
            "height": height,
            "negative_prompt": [
                f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
            ]
            * num_return_sequences,
            **task.cnc_kwargs(),
            **lora_patcher.kwargs(),
        }
        images, has_nsfw = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _unused = high_res.apply(**kwargs)

        generated_image_urls = upload_images(images, "_canny", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Upscale the task's image tile-wise and upload the first result.

    On SDXL the dedicated ``sdxl_tileupscaler`` pipeline is used; otherwise
    the ``tile_upscaler`` ControlNet is loaded and run. In both paths the
    task's LoRA style is patched in for the duration of generation.

    Returns:
        dict with ``modified_prompts``, ``generated_image_url`` and ``has_nsfw``.
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    if get_is_sdxl():
        lora_patcher = lora_style.get_patcher(
            [sdxl_tileupscaler.pipe, high_res.pipe], task.get_style()
        )
        lora_patcher.patch()
        # Always run the patcher's cleanup: the patched pipelines are shared
        # singletons reused by later requests.
        try:
            images, has_nsfw = sdxl_tileupscaler.process(
                prompt=prompt,
                imageUrl=task.get_imageUrl(),
                resize_dimension=task.get_resize_dimension(),
                negative_prompt=task.get_negative_prompt(),
                width=task.get_width(),
                height=task.get_height(),
                model_id=task.get_model_id(),
            )
        finally:
            lora_patcher.cleanup()
    else:
        controlnet.load_model("tile_upscaler")

        lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
        lora_patcher.patch()
        try:
            # NOTE(review): unlike canny/pose, lora_patcher.kwargs() is not
            # forwarded here — confirm whether that is intentional.
            kwargs = {
                "imageUrl": task.get_imageUrl(),
                "seed": task.get_seed(),
                "num_inference_steps": task.get_steps(),
                "negative_prompt": task.get_negative_prompt(),
                "width": task.get_width(),
                "height": task.get_height(),
                "prompt": prompt,
                "resize_dimension": task.get_resize_dimension(),
                **task.cnt_kwargs(),
            }
            images, has_nsfw = controlnet.process(**kwargs)
        finally:
            lora_patcher.cleanup()
        controlnet.cleanup()

    generated_image_url = upload_image(images[0], output_key)

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def scribble(task: Task):
    """Generate images conditioned on the scribble ControlNet.

    The input image is downloaded, resized to the generation size and
    converted with ``ControlNet.pidinet_image`` (SDXL) or
    ``ControlNet.scribble_image`` (otherwise). An optional high-res pass
    follows.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _unused = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("scribble")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Always run the patcher's cleanup, even on failure: the patched
    # pipelines are module-level singletons reused by later requests.
    try:
        image = download_image(task.get_imageUrl()).resize((width, height))
        if get_is_sdxl():
            image = ControlNet.pidinet_image(image)
        else:
            image = ControlNet.scribble_image(image)

        # NOTE(review): unlike canny/pose, lora_patcher.kwargs() is not
        # forwarded here — confirm whether that is intentional.
        kwargs = {
            "image": [image] * num_return_sequences,
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "width": width,
            "height": height,
            "prompt": prompt,
            "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
            **task.cns_kwargs(),
        }
        images, has_nsfw = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _unused = high_res.apply(**kwargs)

        generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def linearart(task: Task):
    """Generate images conditioned on the linearart ControlNet.

    Loads the ``linearart`` ControlNet, patches the task's LoRA style into
    the ControlNet and high-res pipelines, generates, optionally runs the
    high-res pass and uploads the results.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _unused = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("linearart")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Always run the patcher's cleanup, even on failure: the patched
    # pipelines are module-level singletons reused by later requests.
    try:
        # NOTE(review): unlike canny/pose, lora_patcher.kwargs() is not
        # forwarded here — confirm whether that is intentional.
        kwargs = {
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "width": width,
            "height": height,
            "prompt": prompt,
            "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
            **task.cnl_kwargs(),
        }
        images, has_nsfw = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _unused = high_res.apply(**kwargs)

        generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Generate images conditioned on a pose (plus a depth map on non-SDXL).

    Args:
        task: the incoming task.
        s3_outkey: suffix passed to ``upload_images`` for the generated images.
        poses: optional pre-computed pose image(s). When non-empty, pose
            derivation is skipped and ``poses[0]`` is used as the conditioning
            pose. When ``None``/empty (the default), the pose comes from the
            task: the raw input image if pose estimation is disabled, the
            client-supplied coordinates if present, or ControlNet detection.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _unused = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("pose")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Always run the patcher's cleanup, even on failure: the patched
    # pipelines are module-level singletons reused by later requests.
    try:
        # Previously ``poses`` was unconditionally overwritten here, so a
        # caller-supplied value was silently ignored.
        if not poses:
            if not task.get_pose_estimation():
                print("Not detecting pose")
                pose_image = download_image(task.get_imageUrl()).resize(
                    (task.get_width(), task.get_height())
                )
            elif task.get_pose_coordinates():
                pose_image = pose_detector.transform(
                    image=task.get_imageUrl(),
                    client_coordinates=task.get_pose_coordinates(),
                    width=task.get_width(),
                    height=task.get_height(),
                )
            else:
                pose_image = controlnet.detect_pose(task.get_imageUrl())
            poses = [pose_image] * num_return_sequences

        if not get_is_sdxl():
            # Non-SDXL uses a depth + pose multi-ControlNet; the depth map is
            # built from the task's auxiliary image.
            depth = download_image(task.get_auxilary_imageUrl()).resize(
                (task.get_width(), task.get_height())
            )
            depth = ControlNet.depth_image(depth)
            images = [depth, poses[0]]

            upload_image(depth, "crecoAI/{}_depth.png".format(task.get_taskId()))

            extra_kwargs = {
                "control_guidance_end": [0.5, 1.0],
            }
        else:
            images = poses[0]
            extra_kwargs = {}

        kwargs = {
            "prompt": prompt,
            "image": images,
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
            "width": width,
            "height": height,
            **extra_kwargs,
            **task.cnp_kwargs(),
            **lora_patcher.kwargs(),
        }
        images, has_nsfw = controlnet.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _unused = high_res.apply(**kwargs)

        upload_image(poses[0], "crecoAI/{}_pose.png".format(task.get_taskId()))

        generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
    finally:
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Generate images from text, with an optional high-res pass.

    The torch RNG is seeded from the task before generation. The result
    merges the prompt-params object's attributes with the upload URLs.

    Returns:
        dict of ``params.__dict__`` plus ``generated_image_urls`` and ``has_nsfw``.
    """
    params = get_patched_prompt_text2img(task)
    width, height = get_intermediate_dimension(task)

    lora_patcher = lora_style.get_patcher(
        [text2img_pipe.pipe, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()
    # Always run the patcher's cleanup, even on failure: the patched
    # pipelines are module-level singletons reused by later requests.
    try:
        torch.manual_seed(task.get_seed())

        kwargs = {
            "params": params,
            "num_inference_steps": task.get_steps(),
            "height": height,
            "width": width,
            "negative_prompt": task.get_negative_prompt(),
            **task.t2i_kwargs(),
            **lora_patcher.kwargs(),
        }
        images, has_nsfw = text2img_pipe.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": params.prompt if params.prompt else [""] * num_return_sequences,
                "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _unused = high_res.apply(**kwargs)

        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    return {
        **params.__dict__,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Image-to-image generation, with an optional high-res pass.

    On SDXL this runs the ``linearart`` ControlNet with a fixed
    ``adapter_conditioning_scale`` of 0.3; otherwise it uses the img2img
    pipeline. The task's LoRA style is patched in for the duration.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _unused = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    torch.manual_seed(task.get_seed())

    is_sdxl = get_is_sdxl()
    if is_sdxl:
        controlnet.load_model("linearart")
        lora_patcher = lora_style.get_patcher(
            [controlnet.pipe2, high_res.pipe], task.get_style()
        )
    else:
        lora_patcher = lora_style.get_patcher(
            [img2img_pipe.pipe, high_res.pipe], task.get_style()
        )
    lora_patcher.patch()
    # Always run the patcher's cleanup, even on failure: the patched
    # pipelines are module-level singletons reused by later requests.
    try:
        if is_sdxl:
            kwargs = {
                "imageUrl": task.get_imageUrl(),
                "seed": task.get_seed(),
                "num_inference_steps": task.get_steps(),
                "width": width,
                "height": height,
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
                **task.cnl_kwargs(),
                "adapter_conditioning_scale": 0.3,
            }
            images, has_nsfw = controlnet.process(**kwargs)
        else:
            kwargs = {
                "prompt": prompt,
                "imageUrl": task.get_imageUrl(),
                "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
                "num_inference_steps": task.get_steps(),
                "width": width,
                "height": height,
                **task.i2i_kwargs(),
                **lora_patcher.kwargs(),
            }
            images, has_nsfw = img2img_pipe.process(**kwargs)

        if task.get_high_res_fix():
            kwargs = {
                "prompt": prompt,
                "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
                "images": images,
                "width": task.get_width(),
                "height": task.get_height(),
                "num_inference_steps": task.get_steps(),
                **task.high_res_kwargs(),
            }
            images, _unused = high_res.apply(**kwargs)

        generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    if is_sdxl:
        # Same teardown as linearart(), which loads the same ControlNet.
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task's image and upload the results.

    Returns:
        dict with ``modified_prompts`` and ``generated_image_urls``.
    """
    prompt, _unused = get_patched_prompt(task)
    print({"prompts": prompt})

    negatives = [task.get_negative_prompt()] * num_return_sequences
    request = {
        "prompt": prompt,
        "image_url": task.get_imageUrl(),
        "mask_image_url": task.get_maskImageUrl(),
        "width": task.get_width(),
        "height": task.get_height(),
        "seed": task.get_seed(),
        "negative_prompt": negatives,
        "num_inference_steps": task.get_steps(),
        **task.ip_kwargs(),
    }
    results = inpainter.process(**request)

    urls = upload_images(results, "_inpaint", task.get_taskId())

    clear_cuda_and_gc()

    return {"modified_prompts": prompt, "generated_image_urls": urls}
|
|
|
|
|
@update_db
@slack.auto_send_alert
def replace_bg(task: Task):
    """Replace the background of the task's image and upload the results.

    The prompt is expanded by ``prompt_modifier`` when prompt engineering is
    enabled, otherwise it is repeated ``num_return_sequences`` times.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt = task.get_prompt()
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    lora_patcher = lora_style.get_patcher(replace_background.pipe, task.get_style())
    lora_patcher.patch()
    # Always run the patcher's cleanup, even on failure: the patched
    # pipeline is a module-level singleton reused by later requests.
    try:
        images, has_nsfw = replace_background.replace(
            image=task.get_imageUrl(),
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            seed=task.get_seed(),
            width=task.get_width(),
            height=task.get_height(),
            steps=task.get_steps(),
            apply_high_res=task.get_high_res_fix(),
            conditioning_scale=task.rbg_controlnet_conditioning_scale(),
            model_type=task.get_modelType(),
        )

        generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
    finally:
        lora_patcher.cleanup()

    clear_cuda_and_gc()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background from the task's image and upload the result.

    The output is stored under ``crecoAI/<taskId>_rmbg.png``.

    Returns:
        dict with ``generated_image_url``.
    """
    cutout = remove_background_v2.remove(
        task.get_imageUrl(), model_type=task.get_modelType()
    )
    destination = "crecoAI/{}_rmbg.png".format(task.get_taskId())
    return {"generated_image_url": upload_image(cutout, destination)}
|
|
|
|
|
@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task's image and upload it.

    ANIME and COMIC model types use ``upscaler.upscale_anime``; every other
    type uses ``upscaler.upscale``. The upscaler's output is wrapped in
    ``BytesIO`` before upload.

    Returns:
        dict with ``generated_image_url``.
    """
    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())

    stylized = (
        task.get_modelType() == ModelType.ANIME
        or task.get_modelType() == ModelType.COMIC
    )
    if stylized:
        print("Using Anime model")
        run_upscale = upscaler.upscale_anime
    else:
        print("Using Real model")
        run_upscale = upscaler.upscale

    raw_bytes = run_upscale(
        image=task.get_imageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        face_enhance=task.get_face_enhance(),
        resize_dimension=task.get_resize_dimension(),
    )

    uploaded_url = upload_image(BytesIO(raw_bytes), output_key)

    clear_cuda_and_gc()

    return {"generated_image_url": uploaded_url}
|
|
|
|
|
@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Erase the masked object from the task's image and upload the first result.

    Returns:
        dict with ``generated_image_urls``.
    """
    destination = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    outputs = object_removal.process(
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
    )
    # NOTE(review): the key is plural but holds the single value returned by
    # upload_image for the first output.
    first_url = upload_image(outputs[0], destination)

    clear_cuda()

    return {"generated_image_urls": first_url}
|
|
|
|
|
def rt_draw_seg(task: Task):
    """Run the real-time segmentation draw pipeline and return a base64 image.

    ``task.get_imageUrl()`` is downloaded when it starts with ``"http"`` and
    otherwise decoded as base64.

    Returns:
        dict with ``image`` (base64-encoded result).
    """
    source = task.get_imageUrl()
    if source.startswith("http"):
        decoded = download_image(source)
    else:
        decoded = base64_to_image(source)

    rendered = realtime_draw.process_seg(
        image=decoded,
        prompt=task.get_prompt(),
        negative_prompt=task.get_negative_prompt(),
        seed=task.get_seed(),
    )

    clear_cuda_and_gc()

    return {"image": image_to_base64(rendered)}
|
|
|
|
|
def rt_draw_img(task: Task):
    """Run the real-time image draw pipeline and return a base64 image.

    Both the primary and auxiliary image references may be ``http`` URLs
    (downloaded) or base64 strings (decoded). Falsy references are passed
    through unchanged.

    Returns:
        dict with ``image`` (base64-encoded result).
    """

    def _decode(ref):
        # Empty/None references are forwarded as-is to the pipeline.
        if not ref:
            return ref
        if ref.startswith("http"):
            return download_image(ref)
        return base64_to_image(ref)

    primary_ref = task.get_imageUrl()
    secondary_ref = task.get_auxilary_imageUrl()

    primary = _decode(primary_ref)
    secondary = _decode(secondary_ref)

    rendered = realtime_draw.process_img(
        image=primary,
        image2=secondary,
        prompt=task.get_prompt(),
        negative_prompt=task.get_negative_prompt(),
        seed=task.get_seed(),
    )

    clear_cuda_and_gc()

    return {"image": image_to_base64(rendered)}
|
|
|
|
|
def custom_action(task: Task):
    """Dispatch the task to the custom script whose ``__name__`` matches the action.

    Each entry of ``external.scripts.__scripts__`` has its ``Script`` built
    with the shared ControlNet and LoRA-style singletons. A previously cached
    instance with the same ``__name__`` replaces the fresh one; otherwise the
    fresh one is added to ``custom_scripts``. The first script whose name
    equals ``task.get_action_data()["name"]`` is invoked with
    ``(task, action_data)`` and its result returned.

    Returns:
        The matched script's result, or ``None`` when no script matches.
    """
    from external.scripts import __scripts__

    shared_components = {
        "CONTROLNET": controlnet,
        "LORASTYLE": lora_style,
    }

    torch.manual_seed(task.get_seed())

    for entry in __scripts__:
        # NOTE(review): a new Script is constructed on every call even when a
        # cached instance exists; the fresh one is then discarded.
        candidate = entry.Script(**shared_components)
        cached = next(
            (known for known in custom_scripts if known.__name__ == candidate.__name__),
            None,
        )
        if cached:
            candidate = cached
        else:
            custom_scripts.append(candidate)

        action = task.get_action_data()
        if action["name"] == candidate.__name__:
            return candidate(task, action)

    return None
|
|
|
|
|
def load_model_by_task(task_type: TaskType, model_id=-1):
    """Ensure the components needed for ``task_type`` are loaded.

    On first use the base text2img pipeline is loaded from the model dir, and
    the img2img, high-res, inpainter and ControlNet objects are built on top
    of it. Then the task-specific component is loaded; task types not listed
    below only need the base pipeline.

    Args:
        task_type: the task about to run.
        model_id: forwarded only to the SDXL tile upscaler.
    """
    if not text2img_pipe.is_loaded():
        text2img_pipe.load(get_model_dir())
        img2img_pipe.create(text2img_pipe)
        high_res.load(img2img_pipe)

        inpainter.init(text2img_pipe)
        controlnet.init(text2img_pipe)

    if task_type == TaskType.INPAINT:
        inpainter.load()
        safety_checker.apply(inpainter)
        return
    if task_type == TaskType.REPLACE_BG:
        replace_background.load(base=text2img_pipe, high_res=high_res)
        return
    if task_type == TaskType.RT_DRAW_SEG or task_type == TaskType.RT_DRAW_IMG:
        realtime_draw.load(text2img_pipe)
        return
    if task_type == TaskType.OBJECT_REMOVAL:
        object_removal.load(get_model_dir())
        return
    if task_type == TaskType.UPSCALE_IMAGE:
        upscaler.load()
        return

    if task_type == TaskType.TILE_UPSCALE:
        if get_is_sdxl():
            sdxl_tileupscaler.create(high_res, text2img_pipe, model_id)
        else:
            controlnet.load_model("tile_upscaler")
        return

    # Explicit ``==`` comparisons (not a dict lookup) so matching semantics
    # stay identical for whatever type ``task_type`` arrives as.
    controlnet_models = (
        (TaskType.CANNY, "canny"),
        (TaskType.SCRIBBLE, "scribble"),
        (TaskType.LINEARART, "linearart"),
        (TaskType.POSE, "pose"),
    )
    for kind, model_name in controlnet_models:
        if task_type == kind:
            controlnet.load_model(model_name)
            return
|
|
|
|
|
def apply_safety_checkers():
    """Call ``safety_checker.apply`` on the text2img, img2img and ControlNet pipelines."""
    for target in (text2img_pipe, img2img_pipe, controlnet):
        safety_checker.apply(target)
|
|
|
|
|
def model_fn(model_dir):
    """Load configuration and warm up the base pipeline from ``model_dir``.

    Presumably the hosting framework's model-loading hook (SageMaker-style
    ``model_fn``) — confirm. All loaded state lives in this module's
    singletons; the function returns ``None``.
    """
    print("Logs: model loaded .... starts")

    set_model_config(load_model_from_config(model_dir))
    set_root_dir(__file__)

    FailureHandler.register()

    for load_assets in (avatar.load_local, lora_style.load):
        load_assets(model_dir)

    load_model_by_task(TaskType.TEXT_TO_IMAGE)

    print("Logs: model loaded ....")
    return None
|
|
|
|
|
@FailureHandler.clear
def predict_fn(data, pipe):
    """Run one inference request.

    Args:
        data: raw request payload; wrapped in ``Task``.
        pipe: not used in this function (presumably required by the hosting
            framework's signature — confirm).

    Returns:
        The handler's result dict. ``None`` is returned for SYSTEM_CMD and
        PRELOAD_MODEL tasks, and on failure, after alerting Slack, cleaning up
        the ControlNet and marking the source as failed in the DB.
    """
    task = Task(data)
    print("task is ", data)

    clear_cuda_and_gc()

    FailureHandler.handle(task)

    try:
        task_type = task.get_type()

        set_configs_from_task(task)

        load_model_by_task(
            task.get_type() or TaskType.TEXT_TO_IMAGE, task.get_model_id()
        )

        apply_safety_checkers()

        # Real-time draw tasks skip style/avatar preparation below.
        if task_type == TaskType.RT_DRAW_SEG:
            return rt_draw_seg(task)
        if task_type == TaskType.RT_DRAW_IMG:
            return rt_draw_img(task)

        apply_style_args(data)

        lora_style.fetch_styles()

        avatar.fetch_from_network(task.get_model_id())

        if task_type == TaskType.TEXT_TO_IMAGE:
            return text2img(task)
        elif task_type == TaskType.IMAGE_TO_IMAGE:
            return img2img(task)
        elif task_type == TaskType.CANNY:
            return canny(task)
        elif task_type == TaskType.POSE:
            return pose(task)
        elif task_type == TaskType.TILE_UPSCALE:
            return tile_upscale(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.SCRIBBLE:
            return scribble(task)
        elif task_type == TaskType.LINEARART:
            return linearart(task)
        elif task_type == TaskType.REPLACE_BG:
            return replace_bg(task)
        elif task_type == TaskType.CUSTOM_ACTION:
            return custom_action(task)
        elif task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        elif task_type == TaskType.SYSTEM_CMD:
            # SECURITY(review): this executes the task prompt verbatim in a
            # shell. Anyone able to submit a task can run arbitrary commands
            # on this host; restrict this to trusted callers or remove it.
            os.system(task.get_prompt())
        elif task_type == TaskType.PRELOAD_MODEL:
            try:
                task_type = TaskType(task.get_prompt())
            except ValueError:
                # Not a known TaskType value: only ensure the base pipeline
                # is loaded (SYSTEM_CMD matches no task-specific branch).
                task_type = TaskType.SYSTEM_CMD
            load_model_by_task(task_type)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        slack.error_alert(task, e)
        controlnet.cleanup()
        traceback.print_exc()
        update_db_source_failed(task.get_sourceId(), task.get_userId())
        return None
|
|