File size: 3,218 Bytes
19b3da3
 
 
10230ea
19b3da3
 
22df957
19b3da3
10230ea
22df957
10230ea
 
 
22df957
10230ea
 
19b3da3
 
 
fd5252e
 
10230ea
 
 
19b3da3
fd5252e
 
 
10230ea
 
 
 
 
 
 
 
 
 
 
22df957
10230ea
 
 
 
 
 
 
 
fd5252e
19b3da3
 
10230ea
 
fd5252e
 
0daeeb0
10230ea
 
 
 
 
 
 
 
0daeeb0
 
10230ea
 
 
 
 
 
 
 
22df957
 
 
 
 
19b3da3
 
 
 
 
 
 
 
 
 
f70725b
 
19b3da3
 
 
 
 
 
f70725b
 
 
 
 
 
 
 
22df957
f70725b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from typing import List, Union

import torch
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline

from internals.pipelines.commons import AbstractPipeline
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
    get_base_inpaint_model_variant,
    get_hf_cache_dir,
    get_hf_token,
    get_inpaint_model_path,
    get_is_sdxl,
    get_model_dir,
)


class InPainter(AbstractPipeline):
    """Lazy-loading wrapper around the Stable Diffusion inpainting pipelines.

    Picks ``StableDiffusionXLInpaintPipeline`` or
    ``StableDiffusionInpaintPipeline`` based on ``get_is_sdxl()``.  When the
    configured inpaint model is the same as the current base model, it reuses
    the base pipeline's components (via :meth:`create`) instead of downloading
    a second copy of the weights.
    """

    # Guards against re-initializing the pipeline on repeated load() calls.
    __loaded = False

    def init(self, pipeline: AbstractPipeline):
        """Register the base pipeline whose components load() may reuse."""
        self.__base = pipeline

    def load(self):
        """Create ``self.pipe`` on first call; subsequent calls are no-ops."""
        if self.__loaded:
            return

        # BUG FIX: `self.__base` is name-mangled to `_InPainter__base`, so the
        # original check `hasattr(self, "__base")` was always False and this
        # component-sharing fast path was unreachable dead code.
        if (
            hasattr(self, "_InPainter__base")
            and get_inpaint_model_path() == get_model_dir()
        ):
            self.create(self.__base)
            self.__loaded = True
            return

        # No shareable base pipeline: load the inpaint model from scratch.
        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
                variant=get_base_inpaint_model_variant(),
            ).to("cuda")
        else:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
            ).to("cuda")

        disable_safety_checker(self.pipe)

        self.__patch()

        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build the inpaint pipeline from an existing pipeline's components.

        Avoids downloading/duplicating model weights when the base and inpaint
        models are the same checkpoint.
        """
        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        else:
            self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        disable_safety_checker(self.pipe)

        self.__patch()

    def __patch(self):
        """Apply memory optimizations (VAE tiling/slicing for SDXL, xformers)."""
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    def unload(self):
        """Drop the pipeline and free GPU memory; load() may be called again."""
        self.__loaded = False
        self.pipe = None
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        **kwargs,
    ):
        """Run inpainting on a downloaded image/mask pair.

        Both images are resized to (width, height) before being fed to the
        pipeline.  Extra ``kwargs`` are forwarded and may override the
        defaults assembled below (including strength=1.0).  Returns the list
        of generated PIL images.  Requires load() (or init()+load()) first.
        """
        # Seed globally for reproducible generation.
        torch.manual_seed(seed)

        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))

        kwargs = {
            "prompt": prompt,
            "image": input_img,
            "mask_image": mask_img,
            "height": height,
            "width": width,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "strength": 1.0,
            **kwargs,
        }
        return self.pipe(**kwargs).images