# model-sd-multi / inference.py
# Uploaded by jayparmr ("Upload 18 files", commit 4adca93)
from typing import List, Optional
import torch
from data.dataAccessor import update_db
from data.task import Task, TaskType
from pipelines.commons import Img2Img, Text2Img
from pipelines.controlnets import ControlNet
from pipelines.prompt_modifier import PromptModifier
from util.cache import auto_clear_cuda_and_gc, clear_cuda
from util.commons import add_code_names, pickPoses, upload_images
from util.lora_style import LoraStyle
from util.slack import Slack
# Global torch backend switches for this process: enable cuDNN autotuning
# (cudnn.benchmark) and allow TF32 for CUDA matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Number of images produced per task; prompts, negative prompts and detected
# poses are all replicated this many times below.
num_return_sequences = 4 # the number of results to generate
# NOTE(review): not referenced anywhere in the code shown here; confirm it is
# used elsewhere before removing.
auto_mode = False
# Shared, module-level singletons used by every request handler. model_fn()
# calls .load() on each of these (prompt_modifier without arguments, the rest
# with model_dir).
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
# Error reporting: wraps handlers via @slack.auto_send_alert and is called
# directly from predict_fn's exception handler.
slack = Slack()
def get_patched_prompt(task: Task):
    """Build the prompt lists used by every generation entry point.

    Returns a ``(prompt, ori_prompt)`` tuple:

    * ``prompt``: the output of ``prompt_modifier.modify`` when
      ``task.is_prompt_engineering()`` is true, otherwise the task prompt
      repeated ``num_return_sequences`` times.
    * ``ori_prompt``: the task prompt (without prompt engineering) repeated
      ``num_return_sequences`` times.

    Every entry of both lists is passed through ``add_code_names`` and then
    has the task's style prepended by ``lora_style.prepend_style_to_prompt``.
    New lists are returned; the list produced by ``prompt_modifier.modify``
    is never modified in place.
    """

    def add_style_and_character(prompts: List[str]) -> List[str]:
        # Return a fresh list instead of overwriting the caller's list, so a
        # list owned by prompt_modifier is not mutated behind its back.
        style = task.get_style()
        return [
            lora_style.prepend_style_to_prompt(add_code_names(text), style)
            for text in prompts
        ]

    base_prompt = task.get_prompt()
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(base_prompt)
    else:
        prompt = [base_prompt] * num_return_sequences
    ori_prompt = [base_prompt] * num_return_sequences

    # Same processing order as before (original prompts first), in case
    # add_code_names is not a pure function.
    ori_prompt = add_style_and_character(ori_prompt)
    prompt = add_style_and_character(prompt)

    print({"prompts": prompt})
    return (prompt, ori_prompt)
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Generate images with the Canny ControlNet from ``task.get_imageUrl()``.

    Loads the Canny ControlNet, applies the task's LoRA style patch to
    ``controlnet.pipe``, runs ``controlnet.process_canny`` and uploads the
    results with the ``"_canny"`` key suffix. The LoRA patch and the
    ControlNet are cleaned up even if processing or upload raises, so a
    failed request does not leave the shared pipeline patched for the next
    task.

    Returns:
        dict with ``"modified_prompts"`` (the patched prompt list) and
        ``"generated_image_urls"`` (the value returned by ``upload_images``).
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_canny()
    try:
        lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
        lora_patcher.patch()
        try:
            # Fixed artefact-suppressing terms followed by the user's own
            # negative prompt. A missing (None) negative prompt becomes ""
            # rather than the literal text "None".
            user_negative = task.get_negative_prompt()
            if user_negative is None:
                user_negative = ""
            negative_prompt = [
                f"monochrome, neon, x-ray, negative image, oversaturated, {user_negative}"
            ] * num_return_sequences

            images = controlnet.process_canny(
                prompt=prompt,
                imageUrl=task.get_imageUrl(),
                seed=task.get_seed(),
                steps=task.get_steps(),
                width=task.get_width(),
                height=task.get_height(),
                negative_prompt=negative_prompt,
                **lora_patcher.kwargs(),
            )
            generated_image_urls = upload_images(images, "_canny", task.get_taskId())
        finally:
            # Always undo whatever patch() applied to the shared controlnet.pipe.
            lora_patcher.cleanup()
    finally:
        controlnet.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Generate images with the pose ControlNet.

    Args:
        task: the request being served.
        s3_outkey: passed to ``upload_images`` as the output key suffix
            (``"_pose"`` by default; ``predict_fn`` passes ``""`` for
            character-sheet prompts).
        poses: conditioning images passed to ``controlnet.process_pose`` as
            ``image``. When ``None``, a pose is detected once from
            ``task.get_imageUrl()`` and repeated ``num_return_sequences``
            times.

    The LoRA patch and the ControlNet are cleaned up even if detection,
    processing or upload raises.

    Returns:
        dict with ``"modified_prompts"`` and ``"generated_image_urls"``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_pose()
    try:
        lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
        lora_patcher.patch()
        try:
            if poses is None:
                # Detect once and reuse the same pose for every output.
                poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences
            images = controlnet.process_pose(
                prompt=prompt,
                image=poses,
                seed=task.get_seed(),
                steps=task.get_steps(),
                negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
                width=task.get_width(),
                height=task.get_height(),
                **lora_patcher.kwargs(),
            )
            generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
        finally:
            # Always undo whatever patch() applied to the shared controlnet.pipe.
            lora_patcher.cleanup()
    finally:
        controlnet.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Generate images from text with ``text2img_pipe``.

    Seeds torch's global RNG with ``task.get_seed()`` and calls
    ``text2img_pipe.process``. The un-engineered prompts are passed as
    ``prompt`` and the (possibly prompt-engineered) ones as
    ``modified_prompts``. Guidance scale is fixed at 7.5. Results are uploaded
    with an empty key suffix. The LoRA patch on ``text2img_pipe.pipe`` is
    removed even if generation or upload raises.

    Returns:
        dict with ``"modified_prompts"`` and ``"generated_image_urls"``.
    """
    prompt, ori_prompt = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
    try:
        torch.manual_seed(task.get_seed())
        images = text2img_pipe.process(
            prompt=ori_prompt,
            modified_prompts=prompt,
            num_inference_steps=task.get_steps(),
            guidance_scale=7.5,
            height=task.get_height(),
            width=task.get_width(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            iteration=task.get_iteration(),
            **lora_patcher.kwargs(),
        )
        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        # Without this, a failure would leave the shared pipeline patched
        # with this task's style for subsequent requests.
        lora_patcher.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Generate images from ``task.get_imageUrl()`` with ``img2img_pipe``.

    Seeds torch's global RNG with ``task.get_seed()`` and calls
    ``img2img_pipe.process``. Unlike the other handlers, no width/height is
    passed here. Results are uploaded with the ``"_imgtoimg"`` key suffix.
    The LoRA patch on ``img2img_pipe.pipe`` is removed even if generation or
    upload raises.

    Returns:
        dict with ``"modified_prompts"`` and ``"generated_image_urls"``.
    """
    prompt, _ = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
    try:
        torch.manual_seed(task.get_seed())
        images = img2img_pipe.process(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            steps=task.get_steps(),
            **lora_patcher.kwargs(),
        )
        generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
    finally:
        # Without this, a failure would leave the shared pipeline patched
        # with this task's style for subsequent requests.
        lora_patcher.cleanup()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
def model_fn(model_dir):
    """Load every shared model component.

    The prompt modifier is loaded without arguments. The LoRA styles, the
    ControlNet, and the text-to-image and image-to-image pipelines are then
    each loaded from ``model_dir``, in that order. Returns ``None``; the
    handlers use the module-level singletons directly. Presumably this is the
    hosting container's model-loading hook (TODO confirm).
    """
    print("Logs: model loaded .... starts")

    prompt_modifier.load()
    for component in (lora_style, controlnet, text2img_pipe, img2img_pipe):
        component.load(model_dir)

    print("Logs: model loaded ....")
    return None
def predict_fn(data, pipe):
    """Build a ``Task`` from ``data`` and dispatch it to the matching handler.

    Routing by ``task.get_type()``:

    * ``TEXT_TO_IMAGE``: if the prompt contains "character sheet"
      (case-insensitive), ``pose`` is called with the poses from
      ``pickPoses()`` and an empty key suffix. Otherwise ``text2img`` is
      called.
    * ``IMAGE_TO_IMAGE``: ``img2img``.
    * ``CANNY``: ``canny``.
    * ``POSE``: ``pose`` with its defaults.

    Args:
        data: raw request payload, passed straight to ``Task(data)``.
        pipe: unused; kept for the handler signature.

    Returns:
        The selected handler's result, or ``None`` if any exception is raised
        while dispatching. In that case the error and its traceback are printed
        and the error is reported through ``slack.error_alert``. Exceptions
        from ``Task(data)`` itself are not caught here.
    """
    import traceback

    task = Task(data)
    print("task is ", data)
    try:
        task_type = task.get_type()
        if task_type == TaskType.TEXT_TO_IMAGE:
            # character sheet
            if "character sheet" in task.get_prompt().lower():
                return pose(task, s3_outkey="", poses=pickPoses())
            return text2img(task)
        if task_type == TaskType.IMAGE_TO_IMAGE:
            return img2img(task)
        if task_type == TaskType.CANNY:
            return canny(task)
        if task_type == TaskType.POSE:
            return pose(task)
        # Include the offending value so the alert is actionable.
        raise Exception(f"Invalid task type: {task_type}")
    except Exception as e:
        print(f"Error: {e}")
        # str(e) alone loses where the failure happened; keep the stack trace.
        traceback.print_exc()
        slack.error_alert(task, e)
        return None