File size: 6,039 Bytes
19b3da3 a3d6c18 19b3da3 a3d6c18 86248f3 19b3da3 a3d6c18 19b3da3 a3d6c18 19b3da3 a3d6c18 19b3da3 a3d6c18 19b3da3 a3d6c18 19b3da3 a3d6c18 19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
from io import BytesIO
import torch
from internals.data.dataAccessor import update_db
from internals.data.task import ModelType, Task, TaskType
from internals.pipelines.inpainter import InPainter
from internals.pipelines.object_remove import ObjectRemoval
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
from internals.pipelines.replace_background import ReplaceBackground
from internals.pipelines.safety_checker import SafetyChecker
from internals.pipelines.upscaler import Upscaler
from internals.util.avatar import Avatar
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
from internals.util.commons import construct_default_s3_url, upload_image, upload_images
from internals.util.config import (
num_return_sequences,
set_configs_from_task,
set_root_dir,
)
from internals.util.failure_hander import FailureHandler
from internals.util.slack import Slack
# Global PyTorch backend switches, applied once at import time:
# cuDNN autotuning of convolution algorithms and TF32 for CUDA matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# NOTE(review): not read anywhere in the visible part of this file — confirm
# whether another module depends on it.
auto_mode = False

# Slack client; used for the auto_send_alert decorator on the task handlers
# and for error_alert in predict_fn.
slack = Slack()

# Pipeline/helper singletons shared by every request. They are only
# constructed here; model_fn() is where their load()/load_local() methods
# are called.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()
safety_checker = SafetyChecker()
object_removal = ObjectRemoval()
# Has no load call of its own in model_fn(); it is handed to
# replace_background.load() together with the upscaler.
remove_background_v2 = RemoveBackgroundV2()
avatar = Avatar()
replace_background = ReplaceBackground()
@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background from the task's image and upload the result.

    A fresh ``RemoveBackground`` instance is created for the call and applied
    to ``task.get_imageUrl()``. The output is uploaded under a key derived
    from the task id.

    Returns:
        A dict with a single ``generated_image_url`` entry holding the default
        S3 URL built from the upload key.
    """
    cutout = RemoveBackground().remove(task.get_imageUrl())

    s3_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
    upload_image(cutout, s3_key)

    return {"generated_image_url": construct_default_s3_url(s3_key)}
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task's image and upload the outputs.

    The task prompt is first passed through ``avatar.add_code_names``. When
    the task asks for prompt engineering, the prompt modifier produces the
    prompt list; otherwise the same prompt is repeated
    ``num_return_sequences`` times. The negative prompt is always repeated
    ``num_return_sequences`` times.

    Returns:
        A dict with ``modified_prompts`` (the prompts sent to the inpainter)
        and ``generated_image_urls`` (the value returned by ``upload_images``).
    """
    base_prompt = avatar.add_code_names(task.get_prompt())
    prompts = (
        prompt_modifier.modify(base_prompt)
        if task.is_prompt_engineering()
        else [base_prompt] * num_return_sequences
    )
    print({"prompts": prompts})

    inpainted = inpainter.process(
        prompt=prompts,
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        seed=task.get_seed(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
    )
    uploaded_urls = upload_images(inpainted, "_inpaint", task.get_taskId())

    # Free CUDA memory once the images have been uploaded.
    clear_cuda()

    return {"modified_prompts": prompts, "generated_image_urls": uploaded_urls}
@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Remove the masked object from the task's image and upload the result.

    Only the first image returned by ``object_removal.process`` is uploaded.

    Returns:
        A dict whose ``generated_image_urls`` entry is whatever
        ``upload_image`` returned for that single upload.
    """
    s3_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    results = object_removal.process(
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
    )

    # NOTE(review): the other single-image handlers in this file ignore
    # upload_image's return value and build the URL with
    # construct_default_s3_url; here the return value is used directly under
    # a plural key — confirm upload_image returns what consumers expect.
    upload_result = upload_image(results[0], s3_key)

    clear_cuda()

    return {"generated_image_urls": upload_result}
@update_db
@slack.auto_send_alert
def replace_bg(task: Task):
    """Generate replacement backgrounds for the task's image and upload them.

    When the task asks for prompt engineering, the prompt modifier produces
    the prompt list; otherwise the task prompt is repeated
    ``num_return_sequences`` times. The negative prompt is always repeated
    ``num_return_sequences`` times.

    Returns:
        A dict with ``modified_prompts`` (the prompts used),
        ``generated_image_urls`` (the value returned by ``upload_images``) and
        ``has_nsfw`` (the second value returned by
        ``replace_background.replace``).
    """
    raw_prompt = task.get_prompt()
    prompts = (
        prompt_modifier.modify(raw_prompt)
        if task.is_prompt_engineering()
        else [raw_prompt] * num_return_sequences
    )

    outputs, nsfw_flag = replace_background.replace(
        image=task.get_imageUrl(),
        prompt=prompts,
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
        steps=task.get_steps(),
        resize_dimension=task.get_resize_dimension(),
        product_scale_width=task.get_image_scale(),
    )

    uploaded_urls = upload_images(outputs, "_replace_bg", task.get_taskId())

    result = {
        "modified_prompts": prompts,
        "generated_image_urls": uploaded_urls,
        "has_nsfw": nsfw_flag,
    }
    return result
@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task's image and upload the result.

    ``upscaler.upscale_anime`` is used when the task's model type is
    ``ModelType.ANIME``; every other model type goes through
    ``upscaler.upscale``. The upscaler output is wrapped in ``BytesIO`` before
    being uploaded.

    Returns:
        A dict with a single ``generated_image_url`` entry holding the default
        S3 URL built from the upload key.
    """
    s3_key = "crecoAI/{}_upscale.png".format(task.get_taskId())

    use_anime = task.get_modelType() == ModelType.ANIME
    print("Using Anime model" if use_anime else "Using Real model")

    run_upscale = upscaler.upscale_anime if use_anime else upscaler.upscale
    upscaled = run_upscale(
        image=task.get_imageUrl(), resize_dimension=task.get_resize_dimension()
    )

    upload_image(BytesIO(upscaled), s3_key)
    return {"generated_image_url": construct_default_s3_url(s3_key)}
def model_fn(model_dir):
    """Load every pipeline used by the task handlers.

    Sets the root directory from this file's path, registers the failure
    handler, then loads the module-level pipelines in a fixed order. The
    safety checker is applied to the inpainter as the last step.

    Args:
        model_dir: Passed to ``avatar.load_local`` and ``object_removal.load``.

    Returns:
        None.
    """
    print("Logs: model loaded .... starts")

    set_root_dir(__file__)
    FailureHandler.register()

    avatar.load_local(model_dir)

    # The load order below mirrors the original sequence exactly.
    for pipeline in (prompt_modifier, safety_checker):
        pipeline.load()

    object_removal.load(model_dir)

    for pipeline in (upscaler, inpainter):
        pipeline.load()

    replace_background.load(upscaler, remove_background_v2)

    safety_checker.apply(inpainter)

    print("Logs: model loaded ....")
@FailureHandler.clear
def predict_fn(data, pipe):
    """Build a ``Task`` from ``data`` and dispatch it to the matching handler.

    Args:
        data: Raw task payload; wrapped in ``Task``.
        pipe: Not used by this function; kept so the signature is unchanged.

    Returns:
        The value returned by the selected task handler, or ``None`` when an
        exception is raised during avatar fetching or dispatch (including an
        unrecognised task type). In that case the error is printed and
        reported through ``slack.error_alert``.

    Raises:
        Whatever ``Task(data)``, ``FailureHandler.handle`` or
        ``set_configs_from_task`` raise; those calls run outside the ``try``
        block and are not caught here.
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    # Apply the per-task configuration exactly once, before any handler runs.
    set_configs_from_task(task)

    try:
        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()

        # Every branch returns, so a flat sequence of checks is enough.
        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        if task_type == TaskType.INPAINT:
            return inpaint(task)
        if task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        if task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        if task_type == TaskType.REPLACE_BG:
            return replace_bg(task)

        raise Exception("Invalid task type")
    except Exception as e:
        # Top-level boundary: log, alert, and signal failure with None.
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None
|