jayparmr committed
Commit 99a0484 · 1 Parent(s): 1cd09a3

Upload folder using huggingface_hub

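For reference, a commit like this is usually produced with the `upload_folder` API from `huggingface_hub`; a minimal sketch (the local path and repo id below are placeholders, not values taken from this commit):

from huggingface_hub import HfApi

api = HfApi()  # assumes an HF token is configured via login() or the HF_TOKEN env var
api.upload_folder(
    folder_path=".",                # local folder to push (placeholder)
    repo_id="user/model-repo",      # placeholder repo id
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
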
.gitignore CHANGED
@@ -1,6 +1,7 @@
 *.pyc
 .DS_Store
 .ipynb_checkpoints
+.vscode
 env
 test.py
 *.jpeg
inference.py CHANGED
@@ -4,26 +4,38 @@ from typing import List, Optional
 
 import pydash as _
 import torch
+from botocore.vendored.six import BytesIO
 from numpy import who
 
 import internals.util.prompt as prompt_util
 from internals.data.dataAccessor import update_db, update_db_source_failed
-from internals.data.task import Task, TaskType
+from internals.data.task import ModelType, Task, TaskType
 from internals.pipelines.commons import Img2Img, Text2Img
 from internals.pipelines.controlnets import ControlNet
 from internals.pipelines.high_res import HighRes
 from internals.pipelines.img_classifier import ImageClassifier
 from internals.pipelines.img_to_text import Image2Text
 from internals.pipelines.inpainter import InPainter
+from internals.pipelines.object_remove import ObjectRemoval
 from internals.pipelines.pose_detector import PoseDetector
 from internals.pipelines.prompt_modifier import PromptModifier
+from internals.pipelines.realtime_draw import RealtimeDraw
+from internals.pipelines.remove_background import RemoveBackgroundV2
 from internals.pipelines.replace_background import ReplaceBackground
 from internals.pipelines.safety_checker import SafetyChecker
 from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
+from internals.pipelines.upscaler import Upscaler
 from internals.util.args import apply_style_args
 from internals.util.avatar import Avatar
 from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
-from internals.util.commons import download_image, upload_image, upload_images
+from internals.util.commons import (
+    base64_to_image,
+    construct_default_s3_url,
+    download_image,
+    image_to_base64,
+    upload_image,
+    upload_images,
+)
 from internals.util.config import (
     get_is_sdxl,
     get_model_dir,
@@ -43,11 +55,15 @@ torch.backends.cuda.matmul.allow_tf32 = True
 auto_mode = False
 
 prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
+upscaler = Upscaler()
 pose_detector = PoseDetector()
 inpainter = InPainter()
 high_res = HighRes()
 img2text = Image2Text()
 img_classifier = ImageClassifier()
+object_removal = ObjectRemoval()
+replace_background = ReplaceBackground()
+remove_background_v2 = RemoveBackgroundV2()
 replace_background = ReplaceBackground()
 controlnet = ControlNet()
 lora_style = LoraStyle()
@@ -56,6 +72,7 @@ img2img_pipe = Img2Img()
 safety_checker = SafetyChecker()
 slack = Slack()
 avatar = Avatar()
+realtime_draw = RealtimeDraw()
 sdxl_tileupscaler = SDXLTileUpscaler()
 
 
@@ -149,7 +166,9 @@ def tile_upscale(task: Task):
     prompt = get_patched_prompt_tile_upscale(task)
 
     if get_is_sdxl():
-        lora_patcher = lora_style.get_patcher(sdxl_tileupscaler.pipe, task.get_style())
+        lora_patcher = lora_style.get_patcher(
+            [sdxl_tileupscaler.pipe, high_res.pipe], task.get_style()
+        )
         lora_patcher.patch()
 
         images, has_nsfw = sdxl_tileupscaler.process(
@@ -555,6 +574,124 @@ def replace_bg(task: Task):
     }
 
 
+@update_db
+@slack.auto_send_alert
+def remove_bg(task: Task):
+    output_image = remove_background_v2.remove(
+        task.get_imageUrl(), model_type=task.get_modelType()
+    )
+
+    output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
+    image_url = upload_image(output_image, output_key)
+
+    return {"generated_image_url": image_url}
+
+
+@update_db
+@slack.auto_send_alert
+def upscale_image(task: Task):
+    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())
+    out_img = None
+    if (
+        task.get_modelType() == ModelType.ANIME
+        or task.get_modelType() == ModelType.COMIC
+    ):
+        print("Using Anime model")
+        out_img = upscaler.upscale_anime(
+            image=task.get_imageUrl(),
+            width=task.get_width(),
+            height=task.get_height(),
+            face_enhance=task.get_face_enhance(),
+            resize_dimension=task.get_resize_dimension(),
+        )
+    else:
+        print("Using Real model")
+        out_img = upscaler.upscale(
+            image=task.get_imageUrl(),
+            width=task.get_width(),
+            height=task.get_height(),
+            face_enhance=task.get_face_enhance(),
+            resize_dimension=task.get_resize_dimension(),
+        )
+
+    image_url = upload_image(BytesIO(out_img), output_key)
+
+    clear_cuda_and_gc()
+
+    return {"generated_image_url": image_url}
+
+
+@update_db
+@slack.auto_send_alert
+def remove_object(task: Task):
+    output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())
+
+    images = object_removal.process(
+        image_url=task.get_imageUrl(),
+        mask_image_url=task.get_maskImageUrl(),
+        seed=task.get_seed(),
+        width=task.get_width(),
+        height=task.get_height(),
+    )
+    generated_image_urls = upload_image(images[0], output_key)
+
+    clear_cuda()
+
+    return {"generated_image_urls": generated_image_urls}
+
+
+def rt_draw_seg(task: Task):
+    image = task.get_imageUrl()
+    if image.startswith("http"):
+        image = download_image(image)
+    else:  # consider image as base64
+        image = base64_to_image(image)
+
+    img = realtime_draw.process_seg(
+        image=image,
+        prompt=task.get_prompt(),
+        negative_prompt=task.get_negative_prompt(),
+        seed=task.get_seed(),
+    )
+
+    clear_cuda_and_gc()
+
+    base64_image = image_to_base64(img)
+
+    return {"image": base64_image}
+
+
+def rt_draw_img(task: Task):
+    image = task.get_imageUrl()
+    aux_image = task.get_auxilary_imageUrl()
+
+    if image:
+        if image.startswith("http"):
+            image = download_image(image)
+        else:  # consider image as base64
+            image = base64_to_image(image)
+
+    if aux_image:
+        if aux_image.startswith("http"):
+            aux_image = download_image(aux_image)
+        else:  # consider image as base64
+            aux_image = base64_to_image(aux_image)
+
+    img = realtime_draw.process_img(
+        image=image,  # pyright: ignore
+        image2=aux_image,  # pyright: ignore
+        prompt=task.get_prompt(),
+        negative_prompt=task.get_negative_prompt(),
+        seed=task.get_seed(),
+    )
+
+    clear_cuda_and_gc()
+
+    base64_image = image_to_base64(img)
+
+    return {"image": base64_image}
+
+
 def custom_action(task: Task):
     from external.scripts import __scripts__
 
@@ -581,7 +718,7 @@ def custom_action(task: Task):
     return script(task, data)
 
 
-def load_model_by_task(task: Task):
+def load_model_by_task(task_type: TaskType, model_id=-1):
     if not text2img_pipe.is_loaded():
         text2img_pipe.load(get_model_dir())
         img2img_pipe.create(text2img_pipe)
@@ -593,24 +730,30 @@ def load_model_by_task(task: Task):
     safety_checker.apply(text2img_pipe)
     safety_checker.apply(img2img_pipe)
 
-    if task.get_type() == TaskType.INPAINT:
+    if task_type == TaskType.INPAINT:
         inpainter.load()
         safety_checker.apply(inpainter)
-    elif task.get_type() == TaskType.REPLACE_BG:
+    elif task_type == TaskType.REPLACE_BG:
         replace_background.load(base=text2img_pipe, high_res=high_res)
+    elif task_type == TaskType.RT_DRAW_SEG or task_type == TaskType.RT_DRAW_IMG:
+        realtime_draw.load(text2img_pipe)
+    elif task_type == TaskType.OBJECT_REMOVAL:
+        object_removal.load(get_model_dir())
+    elif task_type == TaskType.UPSCALE_IMAGE:
+        upscaler.load()
     else:
-        if task.get_type() == TaskType.TILE_UPSCALE:
+        if task_type == TaskType.TILE_UPSCALE:
             if get_is_sdxl():
-                sdxl_tileupscaler.create(text2img_pipe, task.get_model_id())
+                sdxl_tileupscaler.create(high_res, text2img_pipe, model_id)
             else:
                 controlnet.load_model("tile_upscaler")
-        elif task.get_type() == TaskType.CANNY:
+        elif task_type == TaskType.CANNY:
             controlnet.load_model("canny")
-        elif task.get_type() == TaskType.SCRIBBLE:
+        elif task_type == TaskType.SCRIBBLE:
             controlnet.load_model("scribble")
-        elif task.get_type() == TaskType.LINEARART:
+        elif task_type == TaskType.LINEARART:
             controlnet.load_model("linearart")
-        elif task.get_type() == TaskType.POSE:
+        elif task_type == TaskType.POSE:
             controlnet.load_model("pose")
 
         safety_checker.apply(controlnet)
@@ -629,6 +772,8 @@ def model_fn(model_dir):
 
     lora_style.load(model_dir)
 
+    load_model_by_task(TaskType.TEXT_TO_IMAGE)
+
     print("Logs: model loaded ....")
     return
 
@@ -643,11 +788,21 @@ def predict_fn(data, pipe):
     FailureHandler.handle(task)
 
     try:
+        task_type = task.get_type()
+
         # Set set_environment
         set_configs_from_task(task)
 
         # Load model based on task
-        load_model_by_task(task)
+        load_model_by_task(
+            task.get_type() or TaskType.TEXT_TO_IMAGE, task.get_model_id()
+        )
+
+        # Realtime generation apis
+        if task_type == TaskType.RT_DRAW_SEG:
+            return rt_draw_seg(task)
+        if task_type == TaskType.RT_DRAW_IMG:
+            return rt_draw_img(task)
 
         # Apply arguments
         apply_style_args(data)
@@ -658,8 +813,6 @@ def predict_fn(data, pipe):
         # Fetch avatars
         avatar.fetch_from_network(task.get_model_id())
 
-        task_type = task.get_type()
-
         if task_type == TaskType.TEXT_TO_IMAGE:
             # character sheet
             # if "character sheet" in task.get_prompt().lower():
@@ -684,8 +837,20 @@ def predict_fn(data, pipe):
             return replace_bg(task)
         elif task_type == TaskType.CUSTOM_ACTION:
            return custom_action(task)
+        elif task_type == TaskType.REMOVE_BG:
+            return remove_bg(task)
+        elif task_type == TaskType.UPSCALE_IMAGE:
+            return upscale_image(task)
+        elif task_type == TaskType.OBJECT_REMOVAL:
+            return remove_object(task)
         elif task_type == TaskType.SYSTEM_CMD:
             os.system(task.get_prompt())
+        elif task_type == TaskType.PRELOAD_MODEL:
+            try:
+                task_type = TaskType(task.get_prompt())
+            except:
+                task_type = TaskType.SYSTEM_CMD
+            load_model_by_task(task_type)
         else:
             raise Exception("Invalid task type")
     except Exception as e:
internals/data/task.py CHANGED
@@ -1,6 +1,6 @@
 from enum import Enum
 from functools import lru_cache
-from typing import Union
+from typing import Optional, Union
 
 import numpy as np
 
@@ -18,6 +18,9 @@ class TaskType(Enum):
     SCRIBBLE = "SCRIBBLE"
     LINEARART = "LINEARART"
     REPLACE_BG = "REPLACE_BG"
+    RT_DRAW_SEG = "RT_DRAW_SEG"
+    RT_DRAW_IMG = "RT_DRAW_IMG"
+    PRELOAD_MODEL = "PRELOAD_MODEL"
     CUSTOM_ACTION = "CUSTOM_ACTION"
     SYSTEM_CMD = "SYSTEM_CMD"
 
@@ -52,7 +55,7 @@ class Task:
     def get_imageUrl(self) -> str:
         return self.__data.get("imageUrl", None)
 
-    def get_auxilary_imageUrl(self) -> str:
+    def get_auxilary_imageUrl(self) -> Optional[str]:
         return self.__data.get("aux_imageUrl", None)
 
     def get_prompt(self) -> str:
internals/pipelines/object_remove.py CHANGED
@@ -18,7 +18,11 @@ from saicinpainting.training.trainers import load_checkpoint
 
 
 class ObjectRemoval:
+    __loaded = False
+
     def load(self, model_dir):
+        if self.__loaded:
+            return
         print("Downloading LAMA model...")
 
         self.lama_path = Path.home() / ".cache" / "lama"
@@ -36,6 +40,8 @@ class ObjectRemoval:
         self.model.freeze()
         self.model.to("cuda")
 
+        self.__loaded = True
+
     @torch.no_grad()
     def process(
         self,
internals/pipelines/realtime_draw.py ADDED
@@ -0,0 +1,97 @@
+from typing import Optional
+
+import torch
+from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
+from PIL import Image
+
+import internals.util.image as ImageUtil
+from internals.pipelines.commons import AbstractPipeline
+from internals.pipelines.controlnets import ControlNet
+from internals.util.config import get_hf_cache_dir
+
+
+class RealtimeDraw(AbstractPipeline):
+    def load(self, pipeline: AbstractPipeline):
+        if hasattr(self, "pipe"):
+            return
+
+        self.__controlnet_scribble = ControlNetModel.from_pretrained(
+            "lllyasviel/control_v11p_sd15_scribble",
+            torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
+        )
+
+        self.__controlnet_seg = ControlNetModel.from_pretrained(
+            "lllyasviel/control_v11p_sd15_seg",
+            torch_dtype=torch.float16,
+            cache_dir=get_hf_cache_dir(),
+        )
+
+        kwargs = {**pipeline.pipe.components}  # pyright: ignore
+        kwargs.pop("image_encoder", None)
+        self.pipe = StableDiffusionControlNetImg2ImgPipeline(
+            **kwargs, controlnet=self.__controlnet_seg
+        ).to("cuda")
+        self.pipe.safety_checker = None
+        self.pipe2 = StableDiffusionControlNetImg2ImgPipeline(
+            **kwargs, controlnet=[self.__controlnet_scribble, self.__controlnet_seg]
+        ).to("cuda")
+        self.pipe2.safety_checker = None
+
+    def process_seg(
+        self,
+        image: Image.Image,
+        prompt: str,
+        negative_prompt: str,
+        seed: int,
+    ):
+        torch.manual_seed(seed)
+
+        image = ImageUtil.resize_image(image, 512)
+
+        img = self.pipe.__call__(
+            image=image,
+            control_image=image,
+            prompt=prompt,
+            num_inference_steps=15,
+            negative_prompt=negative_prompt,
+            guidance_scale=10,
+            strength=0.8,
+        ).images[0]
+
+        return img
+
+    def process_img(
+        self,
+        prompt: str,
+        negative_prompt: str,
+        seed: int,
+        image: Optional[Image.Image] = None,
+        image2: Optional[Image.Image] = None,
+    ):
+        torch.manual_seed(seed)
+
+        if not image:
+            image = Image.new("RGB", (512, 512), color=0)
+
+        if not image2:
+            image2 = Image.new("RGB", image.size, color=0)
+
+        image = ImageUtil.resize_image(image, 512)
+
+        scribble = ControlNet.scribble_image(image)
+
+        image2 = ImageUtil.resize_image(image2, 512)
+
+        img = self.pipe2.__call__(
+            image=image,
+            control_image=[scribble, image2],
+            prompt=prompt,
+            num_inference_steps=15,
+            negative_prompt=negative_prompt,
+            guidance_scale=10,
+            strength=0.9,
+            controlnet_conditioning_scale=[1.0, 0.8],
+        ).images[0]
+
+        return img
internals/pipelines/sdxl_tile_upscale.py CHANGED
@@ -4,10 +4,12 @@ from PIL import Image
 from torchvision import transforms
 
 import internals.util.image as ImageUtils
+from carvekit.api import high
 from internals.data.result import Result
 from internals.pipelines.commons import AbstractPipeline, Text2Img
 from internals.pipelines.controlnets import ControlNet
 from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
+from internals.pipelines.high_res import HighRes
 from internals.util.commons import download_image
 from internals.util.config import get_base_dimension
 
@@ -15,7 +17,7 @@ controlnet = ControlNet()
 
 
 class SDXLTileUpscaler(AbstractPipeline):
-    def create(self, pipeline: Text2Img, model_id: int):
+    def create(self, high_res: HighRes, pipeline: Text2Img, model_id: int):
         # temporal hack for upscale model till multicontrolnet support is added
         model = (
             "thibaud/controlnet-openpose-sdxl-1.0"
@@ -32,6 +34,8 @@ class SDXLTileUpscaler(AbstractPipeline):
         pipe.enable_vae_slicing()
         pipe.enable_xformers_memory_efficient_attention()
 
+        self.high_res = high_res
+
         self.pipe = pipe
 
     def process(
@@ -58,18 +62,28 @@ class SDXLTileUpscaler(AbstractPipeline):
 
         image_lr = self.load_and_process_image(img)
         print("img", img2.size, img.size)
-        images = self.pipe.__call__(
-            image_lr=image_lr,
-            prompt=prompt,
-            condition_image=condition_image,
-            negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
-            guidance_scale=11,
-            sigma=0.8,
-            num_inference_steps=24,
-            width=img2.size[0],
-            height=img2.size[1],
-        )
-        images = images[::-1]
+        if int(model_id) == 2000173:
+            kwargs = {
+                "prompt": prompt,
+                "negative_prompt": negative_prompt,
+                "image": img2,
+                "strength": 0.3,
+                "num_inference_steps": 30,
+            }
+            images = self.high_res.pipe.__call__(**kwargs).images
+        else:
+            images = self.pipe.__call__(
+                image_lr=image_lr,
+                prompt=prompt,
+                condition_image=condition_image,
+                negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
+                guidance_scale=11,
+                sigma=0.8,
+                num_inference_steps=24,
+                width=img2.size[0],
+                height=img2.size[1],
+            )
+            images = images[::-1]
         return images, False
 
     def load_and_process_image(self, pil_image):
internals/pipelines/upscaler.py CHANGED
@@ -148,7 +148,7 @@ class Upscaler:
             model=model,
             half=False,
             gpu_id="0",
-            tile=128,
+            tile=320,
             tile_pad=10,
             pre_pad=0,
         )
internals/util/commons.py CHANGED
@@ -1,3 +1,4 @@
+import base64
 import json
 import os
 import pprint
@@ -163,6 +164,18 @@ def download_file(url, out_path: Path):
             f.write(chunk)
 
 
+def base64_to_image(base64_string):
+    imgdata = base64.b64decode(base64_string)
+    return Image.open(io.BytesIO(imgdata)).convert("RGB")
+
+
+def image_to_base64(image):
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue())
+    return img_str.decode("ascii")
+
+
 def pickPoses():
     random_images = random.sample(characterSheets, 4)
     poses = []
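
As a rough illustration, the two helpers added to internals/util/commons.py above can be exercised with a simple round trip; this is a minimal sketch (the in-memory test image is illustrative only, not part of the commit):

from PIL import Image

from internals.util.commons import base64_to_image, image_to_base64

img = Image.new("RGB", (64, 64), "red")   # any PIL image stands in here
encoded = image_to_base64(img)            # PNG bytes as an ASCII base64 string
decoded = base64_to_image(encoded)        # back to an RGB PIL image
assert decoded.size == img.size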