wondervictor committed on
Commit
67ce9c3
·
verified ·
1 Parent(s): d0638c6

Update preprocessor.py

Browse files
Files changed (1) hide show
  1. preprocessor.py +104 -0
preprocessor.py CHANGED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import cv2
3
+ import numpy as np
4
+ import PIL.Image
5
+ import torch
6
+ from controlnet_aux import (
7
+ CannyDetector,
8
+ ContentShuffleDetector,
9
+ HEDdetector,
10
+ LineartAnimeDetector,
11
+ LineartDetector,
12
+ MidasDetector,
13
+ MLSDdetector,
14
+ NormalBaeDetector,
15
+ OpenposeDetector,
16
+ PidiNetDetector,
17
+ )
18
+ from controlnet_aux.util import HWC3
19
+ from transformers import pipeline
20
+ # from cv_utils import resize_image
21
+ # from depth_estimator import DepthEstimator
22
+
23
+
24
class DepthEstimator:
    """Depth-map preprocessor built on a ``transformers`` pipeline.

    Calling an instance takes an image, runs it through the pipeline and
    returns the value stored under the ``"depth"`` key of the pipeline
    output as a PIL image resized to ``image_resolution``.
    """

    def __init__(self):
        # The first positional argument of ``pipeline`` is the *task* name;
        # the local checkpoint directory has to be passed as ``model``.
        # Passing the path as the task makes ``pipeline`` reject it as an
        # unknown task.
        self.model = pipeline(
            "depth-estimation", model="condition/ckpts/dpt_large"
        )

    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        """Return a depth rendering of ``image``.

        Args:
            image: Input image; converted with ``np.array`` before use.
            **kwargs: ``detect_resolution`` (default 512) is the size the
                input is resized to before inference; ``image_resolution``
                (default 512) is the size of the returned image. Any other
                keyword is ignored.

        Returns:
            The depth map as a ``PIL.Image.Image``.
        """
        detect_resolution = kwargs.pop("detect_resolution", 512)
        image_resolution = kwargs.pop("image_resolution", 512)

        # Prepare the network input: array -> HWC3 -> resized -> PIL.
        source = HWC3(np.array(image))
        source = resize_image(source, resolution=detect_resolution)
        prediction = self.model(PIL.Image.fromarray(source))

        # Post-process the predicted depth image the same way so that the
        # result has the requested output resolution.
        depth = HWC3(np.array(prediction["depth"]))
        depth = resize_image(depth, resolution=image_resolution)
        return PIL.Image.fromarray(depth)
41
+
42
def resize_image(input_image, resolution, interpolation=None):
    """Resize an image so its longer side is about ``resolution`` pixels.

    The aspect ratio is kept approximately; both output dimensions are
    rounded to a multiple of 64 and never drop below 64, so very thin
    images or small resolutions cannot produce a zero-sized target.

    Args:
        input_image: Array whose first two axes are height and width
            (a trailing channel axis is optional).
        resolution: Target length of the longer side; must be positive.
        interpolation: OpenCV interpolation flag. When ``None``,
            ``cv2.INTER_LANCZOS4`` is used for upscaling and
            ``cv2.INTER_AREA`` otherwise.

    Returns:
        The resized image as returned by ``cv2.resize``.

    Raises:
        ValueError: If the image has a zero-length side or ``resolution``
            is not positive.
    """
    # Only the first two axes matter, so 2-D (grayscale) input is accepted.
    height, width = input_image.shape[:2]
    if height <= 0 or width <= 0:
        raise ValueError(
            f"input_image must be non-empty, got shape {input_image.shape}"
        )
    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution}")

    scale = float(resolution) / max(height, width)

    # Snap to the 64-pixel grid; the floor of 64 guards against a dimension
    # rounding down to zero, which cv2.resize would reject.
    target_h = max(64, int(np.round(height * scale / 64.0)) * 64)
    target_w = max(64, int(np.round(width * scale / 64.0)) * 64)

    if interpolation is None:
        interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(
        input_image, (target_w, target_h), interpolation=interpolation
    )
55
+
56
+
57
class Preprocessor:
    """Lazily loads one condition detector at a time and applies it.

    ``load(name)`` selects the detector; calling the instance runs the
    currently loaded detector on an image.
    """

    # Directory / repo id handed to the detectors' ``from_pretrained``.
    MODEL_ID = "condition/ckpts"

    def __init__(self):
        self.model = None  # currently loaded detector, if any
        self.name = ""  # name passed to the last successful ``load``

    def load(self, name: str) -> None:
        """Load the detector identified by ``name``.

        Supported names are ``"HED"``, ``"Lineart"``, ``"Canny"`` and
        ``"Depth"``. Loading the name that is already active is a no-op.

        Raises:
            ValueError: If ``name`` is not a supported detector name. The
                previously loaded detector is left untouched in that case.
        """
        if name == self.name:
            return
        if name == "HED":
            self.model = HEDdetector.from_pretrained(self.MODEL_ID)
        elif name == "Lineart":
            self.model = LineartDetector.from_pretrained(self.MODEL_ID)
        elif name == "Canny":
            self.model = CannyDetector()
        elif name == "Depth":
            # "Depth" is served by MidasDetector; the DepthEstimator class
            # in this file is an alternative that is currently not used.
            self.model = MidasDetector.from_pretrained(self.MODEL_ID)
        else:
            raise ValueError(f"Unknown preprocessor name: {name!r}")
        # Collect the replaced detector first so that its memory can then
        # be returned by emptying the CUDA cache.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded detector on ``image``.

        Args:
            image: Input image.
            **kwargs: Forwarded to the detector. For ``"Canny"`` an
                optional ``detect_resolution`` is consumed here and used to
                resize the input first. For ``"Midas"``,
                ``detect_resolution`` and ``image_resolution`` (both
                default 512) are consumed here.

        Returns:
            The detector output; a ``PIL.Image.Image`` for the branches
            that post-process the result themselves.
        """
        if self.name == "Canny":
            # Always hand the detector an HWC3 array; the resize is only
            # applied when the caller asked for a detect resolution.
            # Previously a call without ``detect_resolution`` did not
            # prepare the input correctly.
            image = HWC3(np.array(image))
            if "detect_resolution" in kwargs:
                image = resize_image(
                    image, resolution=kwargs.pop("detect_resolution")
                )
            image = self.model(image, **kwargs)
            return PIL.Image.fromarray(image)
        elif self.name == "Midas":
            # NOTE(review): ``load`` never sets the name to "Midas", so this
            # branch only runs if ``name`` is assigned externally; confirm
            # whether "Depth" was meant to take this path.
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            image = HWC3(np.array(image))
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)
        else:
            return self.model(image, **kwargs)