from typing import List, Union

import torch
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline

from internals.pipelines.commons import AbstractPipeline
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
    get_base_inpaint_model_variant,
    get_hf_cache_dir,
    get_hf_token,
    get_inpaint_model_path,
    get_is_sdxl,
    get_model_dir,
)


class InPainter(AbstractPipeline):
    __loaded = False

    def init(self, pipeline: AbstractPipeline):
        # Keep a reference to an already-loaded base pipeline so its
        # components can be reused instead of loading a second checkpoint.
        self.__base = pipeline

    def load(self):
        if self.__loaded:
            return

        # "__base" is name-mangled to "_InPainter__base" inside the class, so
        # check the mangled name; hasattr(self, "__base") would always be False.
        if hasattr(self, "_InPainter__base") and get_inpaint_model_path() == get_model_dir():
            self.create(self.__base)
            self.__loaded = True
            return
        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
                variant=get_base_inpaint_model_variant(),
            ).to("cuda")
        else:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
            ).to("cuda")

        disable_safety_checker(self.pipe)
        self.__patch()

        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        # Build the inpaint pipeline from the base pipeline's components so
        # the model weights are shared rather than loaded twice.
        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        else:
            self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )

        disable_safety_checker(self.pipe)
        self.__patch()

    def __patch(self):
        # Memory optimizations: VAE tiling/slicing for the larger SDXL UNet,
        # memory-efficient attention for both variants.
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    def unload(self):
        self.__loaded = False
        self.pipe = None
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        **kwargs,
    ):
        torch.manual_seed(seed)

        # Fetch the source image and mask and resize both to the target size.
        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))

        kwargs = {
            "prompt": prompt,
            "image": input_img,
            "mask_image": mask_img,
            "height": height,
            "width": width,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            # strength=1.0 fully regenerates the masked region.
            "strength": 1.0,
            **kwargs,
        }

        return self.pipe(**kwargs).images
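

# Example usage (illustrative sketch, not part of the module): assumes a base
# pipeline wrapped in an AbstractPipeline ("base_pipeline" here is hypothetical)
# has been loaded elsewhere, and that the image/mask URLs are reachable.
#
#   inpainter = InPainter()
#   inpainter.init(base_pipeline)  # optional: enables component reuse in load()
#   inpainter.load()
#   images = inpainter.process(
#       image_url="https://example.com/photo.png",
#       mask_image_url="https://example.com/mask.png",
#       width=1024,
#       height=1024,
#       seed=42,
#       prompt="a red sofa in a bright living room",
#       negative_prompt="blurry, low quality",
#       num_inference_steps=30,
#   )
#   images[0].save("inpainted.png")
#   inpainter.unload()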