File size: 3,016 Bytes
68916cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gc

import numpy as np
import PIL.Image
import torch
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
                            LineartAnimeDetector, LineartDetector,
                            MidasDetector, MLSDdetector, NormalBaeDetector,
                            OpenposeDetector, PidiNetDetector)
from controlnet_aux.util import HWC3

from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor


class Preprocessor:
    """Lazily load and run one ControlNet annotator at a time.

    Only a single model is held (``self.model``); calling :meth:`load` with a
    different name releases the current model before the next one is built.
    """

    MODEL_ID = 'lllyasviel/Annotators'

    def __init__(self):
        # Currently loaded annotator; None until load() succeeds.
        self.model = None
        # Name given to the last successful load(); '' means nothing loaded.
        self.name = ''

    def load(self, name: str) -> None:
        """Load the annotator registered under *name*, replacing any current one.

        Does nothing if *name* is already the loaded annotator.

        Args:
            name: One of 'HED', 'Midas', 'MLSD', 'Openpose', 'PidiNet',
                'NormalBae', 'Lineart', 'LineartAnime', 'Canny',
                'ContentShuffle', 'DPT', 'UPerNet'.

        Raises:
            ValueError: If *name* is not a known preprocessor. The currently
                loaded model (if any) is left untouched in that case.
        """
        if name == self.name:
            return
        # Lambdas defer construction so nothing is built (and nothing is
        # released) until the name has been validated.
        factories = {
            'HED': lambda: HEDdetector.from_pretrained(self.MODEL_ID),
            'Midas': lambda: MidasDetector.from_pretrained(self.MODEL_ID),
            'MLSD': lambda: MLSDdetector.from_pretrained(self.MODEL_ID),
            'Openpose': lambda: OpenposeDetector.from_pretrained(self.MODEL_ID),
            'PidiNet': lambda: PidiNetDetector.from_pretrained(self.MODEL_ID),
            'NormalBae': lambda: NormalBaeDetector.from_pretrained(self.MODEL_ID),
            'Lineart': lambda: LineartDetector.from_pretrained(self.MODEL_ID),
            'LineartAnime': lambda: LineartAnimeDetector.from_pretrained(self.MODEL_ID),
            'Canny': CannyDetector,
            'ContentShuffle': ContentShuffleDetector,
            'DPT': DepthEstimator,
            'UPerNet': ImageSegmentor,
        }
        try:
            factory = factories[name]
        except KeyError:
            raise ValueError(f'Unknown preprocessor name: {name!r}') from None

        # Drop the old model *before* building the new one so the two are
        # never referenced at the same time. Run gc.collect() first so that
        # empty_cache() can actually return the freed CUDA blocks.
        self.model = None
        self.name = ''
        gc.collect()
        torch.cuda.empty_cache()

        self.model = factory()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded annotator on *image*.

        - 'Canny': the image is converted to an HWC3 numpy array and, if a
          ``detect_resolution`` kwarg is given, resized to it before
          detection. The detector output is wrapped with
          ``PIL.Image.fromarray``.
        - 'Midas': the image is converted to HWC3 and resized to
          ``detect_resolution`` (default 512) before inference. The output is
          converted to HWC3 and resized to ``image_resolution`` (default 512).
        - Anything else: *image* and all kwargs are forwarded unchanged and the
          model's return value is passed through as-is (presumably a PIL
          image; not enforced here).

        Remaining kwargs are forwarded to the underlying model.
        """
        if self.name == 'Canny':
            detect_resolution = kwargs.pop('detect_resolution', None)
            # Always feed the detector the same input type (HWC uint8-style
            # numpy array), whether or not a resize was requested.
            image = HWC3(np.array(image))
            if detect_resolution is not None:
                image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            return PIL.Image.fromarray(image)
        elif self.name == 'Midas':
            detect_resolution = kwargs.pop('detect_resolution', 512)
            image_resolution = kwargs.pop('image_resolution', 512)
            image = HWC3(np.array(image))
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)
        else:
            return self.model(image, **kwargs)