|
|
from typing import List, Optional |
|
|
|
|
|
import torch |
|
|
from data.dataAccessor import update_db |
|
|
from data.task import Task, TaskType |
|
|
from pipelines.commons import Img2Img, Text2Img |
|
|
from pipelines.controlnets import ControlNet |
|
|
from pipelines.prompt_modifier import PromptModifier |
|
|
from util.cache import auto_clear_cuda_and_gc, clear_cuda |
|
|
from util.commons import add_code_names, pickPoses, upload_images |
|
|
from util.lora_style import LoraStyle |
|
|
from util.slack import Slack |
|
|
|
|
|
# Global PyTorch performance flags, set once at import time.
# cudnn.benchmark lets cuDNN benchmark and cache the fastest convolution
# algorithms for the input shapes it encounters (see PyTorch torch.backends docs).
torch.backends.cudnn.benchmark = True
# Allow TensorFloat-32 for CUDA matmuls: faster on GPUs that support it, at
# reduced internal precision (see PyTorch torch.backends docs).
torch.backends.cuda.matmul.allow_tf32 = True

# Number of prompts/images produced per task: the handlers replicate the
# prompt and negative prompt this many times, and the value is passed to
# PromptModifier as num_of_sequences.
num_return_sequences = 4
# NOTE(review): not referenced anywhere in the code visible in this file.
auto_mode = False

# Module-level singletons shared by every handler. They are only constructed
# here; model_fn() later calls .load() on each of them.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
# Used both as the auto_send_alert decorator on the handlers and for
# error_alert() in predict_fn.
slack = Slack()
|
|
|
|
|
|
|
|
def get_patched_prompt(task: Task):
    """Build the prompt lists consumed by the generation handlers.

    Returns a ``(prompt, ori_prompt)`` tuple of lists:

    * ``prompt`` -- the output of ``prompt_modifier.modify`` when
      ``task.is_prompt_engineering()`` is true, otherwise the task prompt
      repeated ``num_return_sequences`` times;
    * ``ori_prompt`` -- the unmodified task prompt repeated
      ``num_return_sequences`` times.

    Every entry of both lists is passed through ``add_code_names`` and then
    has the task's style prepended via ``lora_style.prepend_style_to_prompt``.
    New lists are returned; the sequence produced by ``prompt_modifier`` is
    not mutated in place.
    """

    def add_style_and_character(prompts: List[str], style) -> List[str]:
        # Apply code-name substitution first, then prepend the LoRA style,
        # producing a fresh list rather than editing the caller's list.
        return [
            lora_style.prepend_style_to_prompt(add_code_names(p), style)
            for p in prompts
        ]

    raw_prompt = task.get_prompt()

    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(raw_prompt)
    else:
        prompt = [raw_prompt] * num_return_sequences

    # Style lookup is loop-invariant, so fetch it once.
    style = task.get_style()

    # Same processing order as before: the original prompts first, then the
    # (possibly engineered) prompts.
    ori_prompt = add_style_and_character([raw_prompt] * num_return_sequences, style)
    prompt = add_style_and_character(prompt, style)

    print({"prompts": prompt})

    return (prompt, ori_prompt)
|
|
|
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Generate images conditioned on Canny edges of the task's input image.

    Loads the Canny ControlNet, applies the task's LoRA style to
    ``controlnet.pipe``, renders images for the patched prompts and uploads
    them with the ``"_canny"`` key suffix.

    The LoRA patcher and the ControlNet are cleaned up in ``finally`` blocks,
    so a failure during generation or upload no longer leaves the shared
    ``controlnet.pipe`` patched for the next task.

    Args:
        task: The incoming task. Its prompt, image URL, seed, steps, width,
            height, style, negative prompt and task id are read.

    Returns:
        dict with ``"modified_prompts"`` (the patched prompt list) and
        ``"generated_image_urls"`` (the value returned by ``upload_images``).
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_canny()
    try:
        lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
        lora_patcher.patch()
        try:
            images = controlnet.process_canny(
                prompt=prompt,
                imageUrl=task.get_imageUrl(),
                seed=task.get_seed(),
                steps=task.get_steps(),
                width=task.get_width(),
                height=task.get_height(),
                # Fixed terms steer away from artifacts typical of edge
                # conditioning; the user's negative prompt is appended.
                negative_prompt=[
                    f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
                ]
                * num_return_sequences,
                **lora_patcher.kwargs(),
            )

            generated_image_urls = upload_images(images, "_canny", task.get_taskId())
        finally:
            # Release the LoRA patcher even when generation/upload fails
            # (presumably reverts the patch on controlnet.pipe).
            lora_patcher.cleanup()
    finally:
        controlnet.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
|
|
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Generate images conditioned on pose images.

    Loads the pose ControlNet, applies the task's LoRA style to
    ``controlnet.pipe``, renders images for the patched prompts and uploads
    them.

    The LoRA patcher and the ControlNet are cleaned up in ``finally`` blocks,
    so a failure in pose detection, generation or upload no longer leaves
    the shared ``controlnet.pipe`` patched for the next task.

    Args:
        task: The incoming task. Its prompt, image URL, seed, steps, width,
            height, style, negative prompt and task id are read.
        s3_outkey: Key suffix passed to ``upload_images`` (other handlers pass
            e.g. ``"_canny"``). Defaults to ``"_pose"``.
        poses: Conditioning images passed to ``controlnet.process_pose``.
            When ``None``, one pose is detected from ``task.get_imageUrl()``
            and repeated ``num_return_sequences`` times.

    Returns:
        dict with ``"modified_prompts"`` (the patched prompt list) and
        ``"generated_image_urls"`` (the value returned by ``upload_images``).
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_pose()
    try:
        lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
        lora_patcher.patch()
        try:
            if poses is None:
                # Same detected pose reused for every returned image.
                poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences

            images = controlnet.process_pose(
                prompt=prompt,
                image=poses,
                seed=task.get_seed(),
                steps=task.get_steps(),
                negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
                width=task.get_width(),
                height=task.get_height(),
                **lora_patcher.kwargs(),
            )

            generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
        finally:
            # Release the LoRA patcher even when detection/generation/upload
            # fails (presumably reverts the patch on controlnet.pipe).
            lora_patcher.cleanup()
    finally:
        controlnet.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
|
|
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Generate images from text with the shared text-to-image pipeline.

    Applies the task's LoRA style to ``text2img_pipe.pipe``, seeds torch's
    global RNG with the task seed, renders images and uploads them with an
    empty key suffix.

    The LoRA patcher is cleaned up in a ``finally`` block, so a failure during
    generation or upload no longer leaves ``text2img_pipe.pipe`` patched for
    the next task.

    Args:
        task: The incoming task. Its prompt, style, seed, steps, width,
            height, negative prompt, iteration and task id are read.

    Returns:
        dict with ``"modified_prompts"`` (the patched, possibly engineered
        prompt list) and ``"generated_image_urls"`` (the value returned by
        ``upload_images``).
    """
    prompt, ori_prompt = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
    try:
        torch.manual_seed(task.get_seed())

        # The pipeline receives both the styled original prompts and the
        # styled (possibly engineered) prompts.
        images = text2img_pipe.process(
            prompt=ori_prompt,
            modified_prompts=prompt,
            num_inference_steps=task.get_steps(),
            guidance_scale=7.5,
            height=task.get_height(),
            width=task.get_width(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            iteration=task.get_iteration(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        # Release the LoRA patcher even when generation/upload fails
        # (presumably reverts the patch on text2img_pipe.pipe).
        lora_patcher.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
|
|
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Generate images from the task's input image and prompt.

    Applies the task's LoRA style to ``img2img_pipe.pipe``, seeds torch's
    global RNG with the task seed, renders images and uploads them with the
    ``"_imgtoimg"`` key suffix.

    The LoRA patcher is cleaned up in a ``finally`` block, so a failure during
    generation or upload no longer leaves ``img2img_pipe.pipe`` patched for
    the next task.

    Args:
        task: The incoming task. Its prompt, style, seed, image URL, negative
            prompt, steps and task id are read. Width/height are not passed
            to the pipeline here.

    Returns:
        dict with ``"modified_prompts"`` (the patched prompt list) and
        ``"generated_image_urls"`` (the value returned by ``upload_images``).
    """
    prompt, _ = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
    try:
        torch.manual_seed(task.get_seed())

        images = img2img_pipe.process(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            steps=task.get_steps(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
    finally:
        # Release the LoRA patcher even when generation/upload fails
        # (presumably reverts the patch on img2img_pipe.pipe).
        lora_patcher.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
|
|
|
|
|
|
|
def model_fn(model_dir):
    """Load every model used by the request handlers.

    Named after the model_fn/predict_fn inference-handler convention
    (presumably invoked once by the serving container at startup).

    Args:
        model_dir: Directory passed to each component's ``load`` method.
            ``prompt_modifier`` is loaded without it.

    Returns:
        None.
    """
    print("Logs: model loaded .... starts")

    prompt_modifier.load()

    # Same order as before: LoRA styles, ControlNet, then the text-to-image
    # and image-to-image pipelines.
    for component in (lora_style, controlnet, text2img_pipe, img2img_pipe):
        component.load(model_dir)

    print("Logs: model loaded ....")
    return None
|
|
|
|
|
|
|
|
def predict_fn(data, pipe):
    """Route one inference request to the handler matching its task type.

    Args:
        data: Raw request payload; wrapped in ``Task``.
        pipe: Not referenced here (the handlers use the module-level
            pipelines); kept for the hook's signature.

    Returns:
        The selected handler's result, or ``None`` if any exception is raised
        while dispatching. The error is printed and reported through
        ``slack.error_alert``. Note that ``Task(data)`` itself runs outside
        that guard.
    """
    task = Task(data)
    print("task is ", data)

    try:
        kind = task.get_type()

        if kind == TaskType.TEXT_TO_IMAGE:
            # "character sheet" prompts go through the pose ControlNet with
            # preset poses instead of plain text-to-image.
            if "character sheet" not in task.get_prompt().lower():
                return text2img(task)
            return pose(task, s3_outkey="", poses=pickPoses())

        # Remaining types map 1:1 to a handler; checked with == in this
        # order, exactly like an if/elif chain.
        for candidate, handler in (
            (TaskType.IMAGE_TO_IMAGE, img2img),
            (TaskType.CANNY, canny),
            (TaskType.POSE, pose),
        ):
            if kind == candidate:
                return handler(task)

        raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None
|
|
|