File size: 3,418 Bytes
a660631
 
 
 
 
f521e88
 
 
 
 
 
 
 
 
 
 
 
a660631
 
 
 
7a1ec93
a660631
 
 
f521e88
a660631
 
 
7a1ec93
f521e88
a660631
 
 
 
7a1ec93
 
 
 
f521e88
a660631
f521e88
a660631
f521e88
a660631
f521e88
a660631
f521e88
a660631
f521e88
a660631
f521e88
a660631
f521e88
a660631
f521e88
a660631
f521e88
a660631
f521e88
a660631
f521e88
a660631
7a1ec93
 
a660631
 
7a1ec93
 
 
a660631
7a1ec93
a660631
 
f521e88
 
 
a660631
 
 
 
 
f521e88
 
 
a660631
 
 
 
7a1ec93
 
a660631
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gc

import numpy as np
import PIL.Image
import torch
from controlnet_aux import (
    CannyDetector,
    ContentShuffleDetector,
    HEDdetector,
    LineartAnimeDetector,
    LineartDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
)
from controlnet_aux.util import HWC3

from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor, ImageSegmentorOneFormer


class Preprocessor:
    """Load ControlNet annotator models by name and run the active one.

    Every model created by :meth:`load` is kept in ``self.models``, so switching
    back to a previously used preprocessor reuses the existing instance instead
    of constructing it again.
    """

    # Identifier passed to ``from_pretrained`` for the detectors that need weights.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None  # currently active model; None until a load() succeeds
        self.models = {}  # name -> model instance for every model loaded so far
        self.name = ""  # name of the active model; "" means nothing loaded yet

    def _create_model(self, name: str):
        """Instantiate and return the model registered under *name*.

        Raises:
            ValueError: if *name* is not a known preprocessor name.
        """
        model_id = self.MODEL_ID
        # Values are zero-argument callables so that only the requested model
        # is actually constructed (weights are fetched only for that one).
        factories = {
            "HED": lambda: HEDdetector.from_pretrained(model_id),
            "Midas": lambda: MidasDetector.from_pretrained(model_id),
            "MLSD": lambda: MLSDdetector.from_pretrained(model_id),
            "Openpose": lambda: OpenposeDetector.from_pretrained(model_id),
            "PidiNet": lambda: PidiNetDetector.from_pretrained(model_id),
            "NormalBae": lambda: NormalBaeDetector.from_pretrained(model_id),
            "Lineart": lambda: LineartDetector.from_pretrained(model_id),
            "LineartAnime": lambda: LineartAnimeDetector.from_pretrained(model_id),
            "Canny": CannyDetector,
            "ContentShuffle": ContentShuffleDetector,
            "DPT": DepthEstimator,
            "UPerNet": ImageSegmentor,
            "OneFormer": ImageSegmentorOneFormer,
        }
        try:
            factory = factories[name]
        except KeyError:
            known = ", ".join(sorted(factories))
            raise ValueError(
                f"Unknown preprocessor {name!r}; expected one of: {known}"
            ) from None
        return factory()

    def load(self, name: str) -> None:
        """Make the preprocessor called *name* the active model.

        A model that was loaded before is taken from the ``self.models`` cache;
        otherwise it is created and cached. If creation fails (including an
        unknown *name*), the previously active model stays active.

        Raises:
            ValueError: if *name* is not a known preprocessor name.
        """
        if name == self.name:
            return
        if name not in self.models:
            # Create first, cache second: a failure leaves state untouched.
            self.models[name] = self._create_model(name)
        self.name = name
        self.model = self.models[name]

    @staticmethod
    def _to_resized_array(image, resolution):
        """Convert *image* to a numpy array, normalize it with ``HWC3`` and
        resize it with ``resize_image`` to *resolution*."""
        array = HWC3(np.array(image))
        return resize_image(array, resolution=resolution)

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the active model on *image*.

        * ``"Canny"``: if ``detect_resolution`` is given, the input is first
          converted and resized to it. The detector output is wrapped with
          ``PIL.Image.fromarray``.
        * ``"Midas"``: the input is resized to ``detect_resolution`` (default
          512) before detection. The output is resized to ``image_resolution``
          (default 512) and returned as a PIL image.
        * any other model: *image* and the remaining keyword arguments are
          passed through, and the model's return value is returned as-is.

        The resolution keywords are consumed here and not forwarded to the
        model. Calling this before any successful :meth:`load` fails, because
        ``self.model`` is still None.
        """
        if self.name == "Canny":
            if "detect_resolution" in kwargs:
                image = self._to_resized_array(image, kwargs.pop("detect_resolution"))
            # NOTE(review): without detect_resolution the raw input is handed to
            # the detector unconverted — confirm CannyDetector accepts it.
            return PIL.Image.fromarray(self.model(image, **kwargs))
        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            result = self.model(self._to_resized_array(image, detect_resolution), **kwargs)
            if isinstance(result, tuple):
                # The detector returned several maps: keep the last one and
                # reverse its channel axis. Presumably this is the normal map
                # being flipped between BGR and RGB — TODO confirm.
                result = result[-1][..., ::-1]
            result = resize_image(HWC3(result), resolution=image_resolution)
            return PIL.Image.fromarray(result)
        return self.model(image, **kwargs)