Upload folder using huggingface_hub
Browse files- .gitignore +1 -0
- inference.py +180 -15
- internals/data/task.py +5 -2
- internals/pipelines/object_remove.py +6 -0
- internals/pipelines/realtime_draw.py +97 -0
- internals/pipelines/sdxl_tile_upscale.py +27 -13
- internals/pipelines/upscaler.py +1 -1
- internals/util/commons.py +13 -0
.gitignore
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
*.pyc
|
2 |
.DS_Store
|
3 |
.ipynb_checkpoints █
|
|
|
4 |
env
|
5 |
test.py
|
6 |
*.jpeg
|
|
|
1 |
*.pyc
|
2 |
.DS_Store
|
3 |
.ipynb_checkpoints █
|
4 |
+
.vscode
|
5 |
env
|
6 |
test.py
|
7 |
*.jpeg
|
inference.py
CHANGED
@@ -4,26 +4,38 @@ from typing import List, Optional
|
|
4 |
|
5 |
import pydash as _
|
6 |
import torch
|
|
|
7 |
from numpy import who
|
8 |
|
9 |
import internals.util.prompt as prompt_util
|
10 |
from internals.data.dataAccessor import update_db, update_db_source_failed
|
11 |
-
from internals.data.task import Task, TaskType
|
12 |
from internals.pipelines.commons import Img2Img, Text2Img
|
13 |
from internals.pipelines.controlnets import ControlNet
|
14 |
from internals.pipelines.high_res import HighRes
|
15 |
from internals.pipelines.img_classifier import ImageClassifier
|
16 |
from internals.pipelines.img_to_text import Image2Text
|
17 |
from internals.pipelines.inpainter import InPainter
|
|
|
18 |
from internals.pipelines.pose_detector import PoseDetector
|
19 |
from internals.pipelines.prompt_modifier import PromptModifier
|
|
|
|
|
20 |
from internals.pipelines.replace_background import ReplaceBackground
|
21 |
from internals.pipelines.safety_checker import SafetyChecker
|
22 |
from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
|
|
|
23 |
from internals.util.args import apply_style_args
|
24 |
from internals.util.avatar import Avatar
|
25 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
|
26 |
-
from internals.util.commons import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
from internals.util.config import (
|
28 |
get_is_sdxl,
|
29 |
get_model_dir,
|
@@ -43,11 +55,15 @@ torch.backends.cuda.matmul.allow_tf32 = True
|
|
43 |
auto_mode = False
|
44 |
|
45 |
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
|
|
|
46 |
pose_detector = PoseDetector()
|
47 |
inpainter = InPainter()
|
48 |
high_res = HighRes()
|
49 |
img2text = Image2Text()
|
50 |
img_classifier = ImageClassifier()
|
|
|
|
|
|
|
51 |
replace_background = ReplaceBackground()
|
52 |
controlnet = ControlNet()
|
53 |
lora_style = LoraStyle()
|
@@ -56,6 +72,7 @@ img2img_pipe = Img2Img()
|
|
56 |
safety_checker = SafetyChecker()
|
57 |
slack = Slack()
|
58 |
avatar = Avatar()
|
|
|
59 |
sdxl_tileupscaler = SDXLTileUpscaler()
|
60 |
|
61 |
|
@@ -149,7 +166,9 @@ def tile_upscale(task: Task):
|
|
149 |
prompt = get_patched_prompt_tile_upscale(task)
|
150 |
|
151 |
if get_is_sdxl():
|
152 |
-
lora_patcher = lora_style.get_patcher(
|
|
|
|
|
153 |
lora_patcher.patch()
|
154 |
|
155 |
images, has_nsfw = sdxl_tileupscaler.process(
|
@@ -555,6 +574,124 @@ def replace_bg(task: Task):
|
|
555 |
}
|
556 |
|
557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
558 |
def custom_action(task: Task):
|
559 |
from external.scripts import __scripts__
|
560 |
|
@@ -581,7 +718,7 @@ def custom_action(task: Task):
|
|
581 |
return script(task, data)
|
582 |
|
583 |
|
584 |
-
def load_model_by_task(
|
585 |
if not text2img_pipe.is_loaded():
|
586 |
text2img_pipe.load(get_model_dir())
|
587 |
img2img_pipe.create(text2img_pipe)
|
@@ -593,24 +730,30 @@ def load_model_by_task(task: Task):
|
|
593 |
safety_checker.apply(text2img_pipe)
|
594 |
safety_checker.apply(img2img_pipe)
|
595 |
|
596 |
-
if
|
597 |
inpainter.load()
|
598 |
safety_checker.apply(inpainter)
|
599 |
-
elif
|
600 |
replace_background.load(base=text2img_pipe, high_res=high_res)
|
|
|
|
|
|
|
|
|
|
|
|
|
601 |
else:
|
602 |
-
if
|
603 |
if get_is_sdxl():
|
604 |
-
sdxl_tileupscaler.create(text2img_pipe,
|
605 |
else:
|
606 |
controlnet.load_model("tile_upscaler")
|
607 |
-
elif
|
608 |
controlnet.load_model("canny")
|
609 |
-
elif
|
610 |
controlnet.load_model("scribble")
|
611 |
-
elif
|
612 |
controlnet.load_model("linearart")
|
613 |
-
elif
|
614 |
controlnet.load_model("pose")
|
615 |
|
616 |
safety_checker.apply(controlnet)
|
@@ -629,6 +772,8 @@ def model_fn(model_dir):
|
|
629 |
|
630 |
lora_style.load(model_dir)
|
631 |
|
|
|
|
|
632 |
print("Logs: model loaded ....")
|
633 |
return
|
634 |
|
@@ -643,11 +788,21 @@ def predict_fn(data, pipe):
|
|
643 |
FailureHandler.handle(task)
|
644 |
|
645 |
try:
|
|
|
|
|
646 |
# Set set_environment
|
647 |
set_configs_from_task(task)
|
648 |
|
649 |
# Load model based on task
|
650 |
-
load_model_by_task(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
651 |
|
652 |
# Apply arguments
|
653 |
apply_style_args(data)
|
@@ -658,8 +813,6 @@ def predict_fn(data, pipe):
|
|
658 |
# Fetch avatars
|
659 |
avatar.fetch_from_network(task.get_model_id())
|
660 |
|
661 |
-
task_type = task.get_type()
|
662 |
-
|
663 |
if task_type == TaskType.TEXT_TO_IMAGE:
|
664 |
# character sheet
|
665 |
# if "character sheet" in task.get_prompt().lower():
|
@@ -684,8 +837,20 @@ def predict_fn(data, pipe):
|
|
684 |
return replace_bg(task)
|
685 |
elif task_type == TaskType.CUSTOM_ACTION:
|
686 |
return custom_action(task)
|
|
|
|
|
|
|
|
|
|
|
|
|
687 |
elif task_type == TaskType.SYSTEM_CMD:
|
688 |
os.system(task.get_prompt())
|
|
|
|
|
|
|
|
|
|
|
|
|
689 |
else:
|
690 |
raise Exception("Invalid task type")
|
691 |
except Exception as e:
|
|
|
4 |
|
5 |
import pydash as _
|
6 |
import torch
|
7 |
+
from botocore.vendored.six import BytesIO
|
8 |
from numpy import who
|
9 |
|
10 |
import internals.util.prompt as prompt_util
|
11 |
from internals.data.dataAccessor import update_db, update_db_source_failed
|
12 |
+
from internals.data.task import ModelType, Task, TaskType
|
13 |
from internals.pipelines.commons import Img2Img, Text2Img
|
14 |
from internals.pipelines.controlnets import ControlNet
|
15 |
from internals.pipelines.high_res import HighRes
|
16 |
from internals.pipelines.img_classifier import ImageClassifier
|
17 |
from internals.pipelines.img_to_text import Image2Text
|
18 |
from internals.pipelines.inpainter import InPainter
|
19 |
+
from internals.pipelines.object_remove import ObjectRemoval
|
20 |
from internals.pipelines.pose_detector import PoseDetector
|
21 |
from internals.pipelines.prompt_modifier import PromptModifier
|
22 |
+
from internals.pipelines.realtime_draw import RealtimeDraw
|
23 |
+
from internals.pipelines.remove_background import RemoveBackgroundV2
|
24 |
from internals.pipelines.replace_background import ReplaceBackground
|
25 |
from internals.pipelines.safety_checker import SafetyChecker
|
26 |
from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
|
27 |
+
from internals.pipelines.upscaler import Upscaler
|
28 |
from internals.util.args import apply_style_args
|
29 |
from internals.util.avatar import Avatar
|
30 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
|
31 |
+
from internals.util.commons import (
|
32 |
+
base64_to_image,
|
33 |
+
construct_default_s3_url,
|
34 |
+
download_image,
|
35 |
+
image_to_base64,
|
36 |
+
upload_image,
|
37 |
+
upload_images,
|
38 |
+
)
|
39 |
from internals.util.config import (
|
40 |
get_is_sdxl,
|
41 |
get_model_dir,
|
|
|
55 |
auto_mode = False
|
56 |
|
57 |
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
|
58 |
+
upscaler = Upscaler()
|
59 |
pose_detector = PoseDetector()
|
60 |
inpainter = InPainter()
|
61 |
high_res = HighRes()
|
62 |
img2text = Image2Text()
|
63 |
img_classifier = ImageClassifier()
|
64 |
+
object_removal = ObjectRemoval()
|
65 |
+
replace_background = ReplaceBackground()
|
66 |
+
remove_background_v2 = RemoveBackgroundV2()
|
67 |
replace_background = ReplaceBackground()
|
68 |
controlnet = ControlNet()
|
69 |
lora_style = LoraStyle()
|
|
|
72 |
safety_checker = SafetyChecker()
|
73 |
slack = Slack()
|
74 |
avatar = Avatar()
|
75 |
+
realtime_draw = RealtimeDraw()
|
76 |
sdxl_tileupscaler = SDXLTileUpscaler()
|
77 |
|
78 |
|
|
|
166 |
prompt = get_patched_prompt_tile_upscale(task)
|
167 |
|
168 |
if get_is_sdxl():
|
169 |
+
lora_patcher = lora_style.get_patcher(
|
170 |
+
[sdxl_tileupscaler.pipe, high_res.pipe], task.get_style()
|
171 |
+
)
|
172 |
lora_patcher.patch()
|
173 |
|
174 |
images, has_nsfw = sdxl_tileupscaler.process(
|
|
|
574 |
}
|
575 |
|
576 |
|
577 |
+
@update_db
|
578 |
+
@slack.auto_send_alert
|
579 |
+
def remove_bg(task: Task):
|
580 |
+
output_image = remove_background_v2.remove(
|
581 |
+
task.get_imageUrl(), model_type=task.get_modelType()
|
582 |
+
)
|
583 |
+
|
584 |
+
output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
|
585 |
+
image_url = upload_image(output_image, output_key)
|
586 |
+
|
587 |
+
return {"generated_image_url": image_url}
|
588 |
+
|
589 |
+
|
590 |
+
@update_db
|
591 |
+
@slack.auto_send_alert
|
592 |
+
def upscale_image(task: Task):
|
593 |
+
output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())
|
594 |
+
out_img = None
|
595 |
+
if (
|
596 |
+
task.get_modelType() == ModelType.ANIME
|
597 |
+
or task.get_modelType() == ModelType.COMIC
|
598 |
+
):
|
599 |
+
print("Using Anime model")
|
600 |
+
out_img = upscaler.upscale_anime(
|
601 |
+
image=task.get_imageUrl(),
|
602 |
+
width=task.get_width(),
|
603 |
+
height=task.get_height(),
|
604 |
+
face_enhance=task.get_face_enhance(),
|
605 |
+
resize_dimension=task.get_resize_dimension(),
|
606 |
+
)
|
607 |
+
else:
|
608 |
+
print("Using Real model")
|
609 |
+
out_img = upscaler.upscale(
|
610 |
+
image=task.get_imageUrl(),
|
611 |
+
width=task.get_width(),
|
612 |
+
height=task.get_height(),
|
613 |
+
face_enhance=task.get_face_enhance(),
|
614 |
+
resize_dimension=task.get_resize_dimension(),
|
615 |
+
)
|
616 |
+
|
617 |
+
image_url = upload_image(BytesIO(out_img), output_key)
|
618 |
+
|
619 |
+
clear_cuda_and_gc()
|
620 |
+
|
621 |
+
return {"generated_image_url": image_url}
|
622 |
+
|
623 |
+
|
624 |
+
@update_db
|
625 |
+
@slack.auto_send_alert
|
626 |
+
def remove_object(task: Task):
|
627 |
+
output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())
|
628 |
+
|
629 |
+
images = object_removal.process(
|
630 |
+
image_url=task.get_imageUrl(),
|
631 |
+
mask_image_url=task.get_maskImageUrl(),
|
632 |
+
seed=task.get_seed(),
|
633 |
+
width=task.get_width(),
|
634 |
+
height=task.get_height(),
|
635 |
+
)
|
636 |
+
generated_image_urls = upload_image(images[0], output_key)
|
637 |
+
|
638 |
+
clear_cuda()
|
639 |
+
|
640 |
+
return {"generated_image_urls": generated_image_urls}
|
641 |
+
|
642 |
+
|
643 |
+
def rt_draw_seg(task: Task):
|
644 |
+
image = task.get_imageUrl()
|
645 |
+
if image.startswith("http"):
|
646 |
+
image = download_image(image)
|
647 |
+
else: # consider image as base64
|
648 |
+
image = base64_to_image(image)
|
649 |
+
|
650 |
+
img = realtime_draw.process_seg(
|
651 |
+
image=image,
|
652 |
+
prompt=task.get_prompt(),
|
653 |
+
negative_prompt=task.get_negative_prompt(),
|
654 |
+
seed=task.get_seed(),
|
655 |
+
)
|
656 |
+
|
657 |
+
clear_cuda_and_gc()
|
658 |
+
|
659 |
+
base64_image = image_to_base64(img)
|
660 |
+
|
661 |
+
return {"image": base64_image}
|
662 |
+
|
663 |
+
|
664 |
+
def rt_draw_img(task: Task):
|
665 |
+
image = task.get_imageUrl()
|
666 |
+
aux_image = task.get_auxilary_imageUrl()
|
667 |
+
|
668 |
+
if image:
|
669 |
+
if image.startswith("http"):
|
670 |
+
image = download_image(image)
|
671 |
+
else: # consider image as base64
|
672 |
+
image = base64_to_image(image)
|
673 |
+
|
674 |
+
if aux_image:
|
675 |
+
if aux_image.startswith("http"):
|
676 |
+
aux_image = download_image(aux_image)
|
677 |
+
else: # consider image as base64
|
678 |
+
aux_image = base64_to_image(aux_image)
|
679 |
+
|
680 |
+
img = realtime_draw.process_img(
|
681 |
+
image=image, # pyright: ignore
|
682 |
+
image2=aux_image, # pyright: ignore
|
683 |
+
prompt=task.get_prompt(),
|
684 |
+
negative_prompt=task.get_negative_prompt(),
|
685 |
+
seed=task.get_seed(),
|
686 |
+
)
|
687 |
+
|
688 |
+
clear_cuda_and_gc()
|
689 |
+
|
690 |
+
base64_image = image_to_base64(img)
|
691 |
+
|
692 |
+
return {"image": base64_image}
|
693 |
+
|
694 |
+
|
695 |
def custom_action(task: Task):
|
696 |
from external.scripts import __scripts__
|
697 |
|
|
|
718 |
return script(task, data)
|
719 |
|
720 |
|
721 |
+
def load_model_by_task(task_type: TaskType, model_id=-1):
|
722 |
if not text2img_pipe.is_loaded():
|
723 |
text2img_pipe.load(get_model_dir())
|
724 |
img2img_pipe.create(text2img_pipe)
|
|
|
730 |
safety_checker.apply(text2img_pipe)
|
731 |
safety_checker.apply(img2img_pipe)
|
732 |
|
733 |
+
if task_type == TaskType.INPAINT:
|
734 |
inpainter.load()
|
735 |
safety_checker.apply(inpainter)
|
736 |
+
elif task_type == TaskType.REPLACE_BG:
|
737 |
replace_background.load(base=text2img_pipe, high_res=high_res)
|
738 |
+
elif task_type == TaskType.RT_DRAW_SEG or task_type == TaskType.RT_DRAW_IMG:
|
739 |
+
realtime_draw.load(text2img_pipe)
|
740 |
+
elif task_type == TaskType.OBJECT_REMOVAL:
|
741 |
+
object_removal.load(get_model_dir())
|
742 |
+
elif task_type == TaskType.UPSCALE_IMAGE:
|
743 |
+
upscaler.load()
|
744 |
else:
|
745 |
+
if task_type == TaskType.TILE_UPSCALE:
|
746 |
if get_is_sdxl():
|
747 |
+
sdxl_tileupscaler.create(high_res, text2img_pipe, model_id)
|
748 |
else:
|
749 |
controlnet.load_model("tile_upscaler")
|
750 |
+
elif task_type == TaskType.CANNY:
|
751 |
controlnet.load_model("canny")
|
752 |
+
elif task_type == TaskType.SCRIBBLE:
|
753 |
controlnet.load_model("scribble")
|
754 |
+
elif task_type == TaskType.LINEARART:
|
755 |
controlnet.load_model("linearart")
|
756 |
+
elif task_type == TaskType.POSE:
|
757 |
controlnet.load_model("pose")
|
758 |
|
759 |
safety_checker.apply(controlnet)
|
|
|
772 |
|
773 |
lora_style.load(model_dir)
|
774 |
|
775 |
+
load_model_by_task(TaskType.TEXT_TO_IMAGE)
|
776 |
+
|
777 |
print("Logs: model loaded ....")
|
778 |
return
|
779 |
|
|
|
788 |
FailureHandler.handle(task)
|
789 |
|
790 |
try:
|
791 |
+
task_type = task.get_type()
|
792 |
+
|
793 |
# Set set_environment
|
794 |
set_configs_from_task(task)
|
795 |
|
796 |
# Load model based on task
|
797 |
+
load_model_by_task(
|
798 |
+
task.get_type() or TaskType.TEXT_TO_IMAGE, task.get_model_id()
|
799 |
+
)
|
800 |
+
|
801 |
+
# Realtime generation apis
|
802 |
+
if task_type == TaskType.RT_DRAW_SEG:
|
803 |
+
return rt_draw_seg(task)
|
804 |
+
if task_type == TaskType.RT_DRAW_IMG:
|
805 |
+
return rt_draw_img(task)
|
806 |
|
807 |
# Apply arguments
|
808 |
apply_style_args(data)
|
|
|
813 |
# Fetch avatars
|
814 |
avatar.fetch_from_network(task.get_model_id())
|
815 |
|
|
|
|
|
816 |
if task_type == TaskType.TEXT_TO_IMAGE:
|
817 |
# character sheet
|
818 |
# if "character sheet" in task.get_prompt().lower():
|
|
|
837 |
return replace_bg(task)
|
838 |
elif task_type == TaskType.CUSTOM_ACTION:
|
839 |
return custom_action(task)
|
840 |
+
elif task_type == TaskType.REMOVE_BG:
|
841 |
+
return remove_bg(task)
|
842 |
+
elif task_type == TaskType.UPSCALE_IMAGE:
|
843 |
+
return upscale_image(task)
|
844 |
+
elif task_type == TaskType.OBJECT_REMOVAL:
|
845 |
+
return remove_object(task)
|
846 |
elif task_type == TaskType.SYSTEM_CMD:
|
847 |
os.system(task.get_prompt())
|
848 |
+
elif task_type == TaskType.PRELOAD_MODEL:
|
849 |
+
try:
|
850 |
+
task_type = TaskType(task.get_prompt())
|
851 |
+
except:
|
852 |
+
task_type = TaskType.SYSTEM_CMD
|
853 |
+
load_model_by_task(task_type)
|
854 |
else:
|
855 |
raise Exception("Invalid task type")
|
856 |
except Exception as e:
|
internals/data/task.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from enum import Enum
|
2 |
from functools import lru_cache
|
3 |
-
from typing import Union
|
4 |
|
5 |
import numpy as np
|
6 |
|
@@ -18,6 +18,9 @@ class TaskType(Enum):
|
|
18 |
SCRIBBLE = "SCRIBBLE"
|
19 |
LINEARART = "LINEARART"
|
20 |
REPLACE_BG = "REPLACE_BG"
|
|
|
|
|
|
|
21 |
CUSTOM_ACTION = "CUSTOM_ACTION"
|
22 |
SYSTEM_CMD = "SYSTEM_CMD"
|
23 |
|
@@ -52,7 +55,7 @@ class Task:
|
|
52 |
def get_imageUrl(self) -> str:
|
53 |
return self.__data.get("imageUrl", None)
|
54 |
|
55 |
-
def get_auxilary_imageUrl(self) -> str:
|
56 |
return self.__data.get("aux_imageUrl", None)
|
57 |
|
58 |
def get_prompt(self) -> str:
|
|
|
1 |
from enum import Enum
|
2 |
from functools import lru_cache
|
3 |
+
from typing import Optional, Union
|
4 |
|
5 |
import numpy as np
|
6 |
|
|
|
18 |
SCRIBBLE = "SCRIBBLE"
|
19 |
LINEARART = "LINEARART"
|
20 |
REPLACE_BG = "REPLACE_BG"
|
21 |
+
RT_DRAW_SEG = "RT_DRAW_SEG"
|
22 |
+
RT_DRAW_IMG = "RT_DRAW_IMG"
|
23 |
+
PRELOAD_MODEL = "PRELOAD_MODEL"
|
24 |
CUSTOM_ACTION = "CUSTOM_ACTION"
|
25 |
SYSTEM_CMD = "SYSTEM_CMD"
|
26 |
|
|
|
55 |
def get_imageUrl(self) -> str:
|
56 |
return self.__data.get("imageUrl", None)
|
57 |
|
58 |
+
def get_auxilary_imageUrl(self) -> Optional[str]:
|
59 |
return self.__data.get("aux_imageUrl", None)
|
60 |
|
61 |
def get_prompt(self) -> str:
|
internals/pipelines/object_remove.py
CHANGED
@@ -18,7 +18,11 @@ from saicinpainting.training.trainers import load_checkpoint
|
|
18 |
|
19 |
|
20 |
class ObjectRemoval:
|
|
|
|
|
21 |
def load(self, model_dir):
|
|
|
|
|
22 |
print("Downloading LAMA model...")
|
23 |
|
24 |
self.lama_path = Path.home() / ".cache" / "lama"
|
@@ -36,6 +40,8 @@ class ObjectRemoval:
|
|
36 |
self.model.freeze()
|
37 |
self.model.to("cuda")
|
38 |
|
|
|
|
|
39 |
@torch.no_grad()
|
40 |
def process(
|
41 |
self,
|
|
|
18 |
|
19 |
|
20 |
class ObjectRemoval:
|
21 |
+
__loaded = False
|
22 |
+
|
23 |
def load(self, model_dir):
|
24 |
+
if self.__loaded:
|
25 |
+
return
|
26 |
print("Downloading LAMA model...")
|
27 |
|
28 |
self.lama_path = Path.home() / ".cache" / "lama"
|
|
|
40 |
self.model.freeze()
|
41 |
self.model.to("cuda")
|
42 |
|
43 |
+
self.__loaded = True
|
44 |
+
|
45 |
@torch.no_grad()
|
46 |
def process(
|
47 |
self,
|
internals/pipelines/realtime_draw.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
import internals.util.image as ImageUtil
|
8 |
+
from internals.pipelines.commons import AbstractPipeline
|
9 |
+
from internals.pipelines.controlnets import ControlNet
|
10 |
+
from internals.util.config import get_hf_cache_dir
|
11 |
+
|
12 |
+
|
13 |
+
class RealtimeDraw(AbstractPipeline):
|
14 |
+
def load(self, pipeline: AbstractPipeline):
|
15 |
+
if hasattr(self, "pipe"):
|
16 |
+
return
|
17 |
+
|
18 |
+
self.__controlnet_scribble = ControlNetModel.from_pretrained(
|
19 |
+
"lllyasviel/control_v11p_sd15_scribble",
|
20 |
+
torch_dtype=torch.float16,
|
21 |
+
cache_dir=get_hf_cache_dir(),
|
22 |
+
)
|
23 |
+
|
24 |
+
self.__controlnet_seg = ControlNetModel.from_pretrained(
|
25 |
+
"lllyasviel/control_v11p_sd15_seg",
|
26 |
+
torch_dtype=torch.float16,
|
27 |
+
cache_dir=get_hf_cache_dir(),
|
28 |
+
)
|
29 |
+
|
30 |
+
kwargs = {**pipeline.pipe.components} # pyright: ignore
|
31 |
+
kwargs.pop("image_encoder", None)
|
32 |
+
self.pipe = StableDiffusionControlNetImg2ImgPipeline(
|
33 |
+
**kwargs, controlnet=self.__controlnet_seg
|
34 |
+
).to("cuda")
|
35 |
+
self.pipe.safety_checker = None
|
36 |
+
self.pipe2 = StableDiffusionControlNetImg2ImgPipeline(
|
37 |
+
**kwargs, controlnet=[self.__controlnet_scribble, self.__controlnet_seg]
|
38 |
+
).to("cuda")
|
39 |
+
self.pipe2.safety_checker = None
|
40 |
+
|
41 |
+
def process_seg(
|
42 |
+
self,
|
43 |
+
image: Image.Image,
|
44 |
+
prompt: str,
|
45 |
+
negative_prompt: str,
|
46 |
+
seed: int,
|
47 |
+
):
|
48 |
+
torch.manual_seed(seed)
|
49 |
+
|
50 |
+
image = ImageUtil.resize_image(image, 512)
|
51 |
+
|
52 |
+
img = self.pipe.__call__(
|
53 |
+
image=image,
|
54 |
+
control_image=image,
|
55 |
+
prompt=prompt,
|
56 |
+
num_inference_steps=15,
|
57 |
+
negative_prompt=negative_prompt,
|
58 |
+
guidance_scale=10,
|
59 |
+
strength=0.8,
|
60 |
+
).images[0]
|
61 |
+
|
62 |
+
return img
|
63 |
+
|
64 |
+
def process_img(
|
65 |
+
self,
|
66 |
+
prompt: str,
|
67 |
+
negative_prompt: str,
|
68 |
+
seed: int,
|
69 |
+
image: Optional[Image.Image] = None,
|
70 |
+
image2: Optional[Image.Image] = None,
|
71 |
+
):
|
72 |
+
torch.manual_seed(seed)
|
73 |
+
|
74 |
+
if not image:
|
75 |
+
image = Image.new("RGB", (512, 512), color=0)
|
76 |
+
|
77 |
+
if not image2:
|
78 |
+
image2 = Image.new("RGB", image.size, color=0)
|
79 |
+
|
80 |
+
image = ImageUtil.resize_image(image, 512)
|
81 |
+
|
82 |
+
scribble = ControlNet.scribble_image(image)
|
83 |
+
|
84 |
+
image2 = ImageUtil.resize_image(image2, 512)
|
85 |
+
|
86 |
+
img = self.pipe2.__call__(
|
87 |
+
image=image,
|
88 |
+
control_image=[scribble, image2],
|
89 |
+
prompt=prompt,
|
90 |
+
num_inference_steps=15,
|
91 |
+
negative_prompt=negative_prompt,
|
92 |
+
guidance_scale=10,
|
93 |
+
strength=0.9,
|
94 |
+
controlnet_conditioning_scale=[1.0, 0.8],
|
95 |
+
).images[0]
|
96 |
+
|
97 |
+
return img
|
internals/pipelines/sdxl_tile_upscale.py
CHANGED
@@ -4,10 +4,12 @@ from PIL import Image
|
|
4 |
from torchvision import transforms
|
5 |
|
6 |
import internals.util.image as ImageUtils
|
|
|
7 |
from internals.data.result import Result
|
8 |
from internals.pipelines.commons import AbstractPipeline, Text2Img
|
9 |
from internals.pipelines.controlnets import ControlNet
|
10 |
from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
|
|
|
11 |
from internals.util.commons import download_image
|
12 |
from internals.util.config import get_base_dimension
|
13 |
|
@@ -15,7 +17,7 @@ controlnet = ControlNet()
|
|
15 |
|
16 |
|
17 |
class SDXLTileUpscaler(AbstractPipeline):
|
18 |
-
def create(self, pipeline: Text2Img, model_id: int):
|
19 |
# temporal hack for upscale model till multicontrolnet support is added
|
20 |
model = (
|
21 |
"thibaud/controlnet-openpose-sdxl-1.0"
|
@@ -32,6 +34,8 @@ class SDXLTileUpscaler(AbstractPipeline):
|
|
32 |
pipe.enable_vae_slicing()
|
33 |
pipe.enable_xformers_memory_efficient_attention()
|
34 |
|
|
|
|
|
35 |
self.pipe = pipe
|
36 |
|
37 |
def process(
|
@@ -58,18 +62,28 @@ class SDXLTileUpscaler(AbstractPipeline):
|
|
58 |
|
59 |
image_lr = self.load_and_process_image(img)
|
60 |
print("img", img2.size, img.size)
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
return images, False
|
74 |
|
75 |
def load_and_process_image(self, pil_image):
|
|
|
4 |
from torchvision import transforms
|
5 |
|
6 |
import internals.util.image as ImageUtils
|
7 |
+
from carvekit.api import high
|
8 |
from internals.data.result import Result
|
9 |
from internals.pipelines.commons import AbstractPipeline, Text2Img
|
10 |
from internals.pipelines.controlnets import ControlNet
|
11 |
from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
|
12 |
+
from internals.pipelines.high_res import HighRes
|
13 |
from internals.util.commons import download_image
|
14 |
from internals.util.config import get_base_dimension
|
15 |
|
|
|
17 |
|
18 |
|
19 |
class SDXLTileUpscaler(AbstractPipeline):
|
20 |
+
def create(self, high_res: HighRes, pipeline: Text2Img, model_id: int):
|
21 |
# temporal hack for upscale model till multicontrolnet support is added
|
22 |
model = (
|
23 |
"thibaud/controlnet-openpose-sdxl-1.0"
|
|
|
34 |
pipe.enable_vae_slicing()
|
35 |
pipe.enable_xformers_memory_efficient_attention()
|
36 |
|
37 |
+
self.high_res = high_res
|
38 |
+
|
39 |
self.pipe = pipe
|
40 |
|
41 |
def process(
|
|
|
62 |
|
63 |
image_lr = self.load_and_process_image(img)
|
64 |
print("img", img2.size, img.size)
|
65 |
+
if int(model_id) == 2000173:
|
66 |
+
kwargs = {
|
67 |
+
"prompt": prompt,
|
68 |
+
"negative_prompt": negative_prompt,
|
69 |
+
"image": img2,
|
70 |
+
"strength": 0.3,
|
71 |
+
"num_inference_steps": 30,
|
72 |
+
}
|
73 |
+
images = self.high_res.pipe.__call__(**kwargs).images
|
74 |
+
else:
|
75 |
+
images = self.pipe.__call__(
|
76 |
+
image_lr=image_lr,
|
77 |
+
prompt=prompt,
|
78 |
+
condition_image=condition_image,
|
79 |
+
negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
|
80 |
+
guidance_scale=11,
|
81 |
+
sigma=0.8,
|
82 |
+
num_inference_steps=24,
|
83 |
+
width=img2.size[0],
|
84 |
+
height=img2.size[1],
|
85 |
+
)
|
86 |
+
images = images[::-1]
|
87 |
return images, False
|
88 |
|
89 |
def load_and_process_image(self, pil_image):
|
internals/pipelines/upscaler.py
CHANGED
@@ -148,7 +148,7 @@ class Upscaler:
|
|
148 |
model=model,
|
149 |
half=False,
|
150 |
gpu_id="0",
|
151 |
-
tile=
|
152 |
tile_pad=10,
|
153 |
pre_pad=0,
|
154 |
)
|
|
|
148 |
model=model,
|
149 |
half=False,
|
150 |
gpu_id="0",
|
151 |
+
tile=320,
|
152 |
tile_pad=10,
|
153 |
pre_pad=0,
|
154 |
)
|
internals/util/commons.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import json
|
2 |
import os
|
3 |
import pprint
|
@@ -163,6 +164,18 @@ def download_file(url, out_path: Path):
|
|
163 |
f.write(chunk)
|
164 |
|
165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
def pickPoses():
|
167 |
random_images = random.sample(characterSheets, 4)
|
168 |
poses = []
|
|
|
1 |
+
import base64
|
2 |
import json
|
3 |
import os
|
4 |
import pprint
|
|
|
164 |
f.write(chunk)
|
165 |
|
166 |
|
167 |
+
def base64_to_image(base64_string):
|
168 |
+
imgdata = base64.b64decode(base64_string)
|
169 |
+
return Image.open(io.BytesIO(imgdata)).convert("RGB")
|
170 |
+
|
171 |
+
|
172 |
+
def image_to_base64(image):
|
173 |
+
buffered = BytesIO()
|
174 |
+
image.save(buffered, format="PNG")
|
175 |
+
img_str = base64.b64encode(buffered.getvalue())
|
176 |
+
return img_str.decode("ascii")
|
177 |
+
|
178 |
+
|
179 |
def pickPoses():
|
180 |
random_images = random.sample(characterSheets, 4)
|
181 |
poses = []
|