wondervictor commited on
Commit
8369c1d
·
verified ·
1 Parent(s): 67ce9c3

Update preprocessor.py

Browse files
Files changed (1) hide show
  1. preprocessor.py +5 -4
preprocessor.py CHANGED
@@ -23,7 +23,7 @@ from transformers import pipeline
23
 
24
  class DepthEstimator:
25
  def __init__(self):
26
- self.model = pipeline("condition/ckpts/dpt_large")
27
 
28
  def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
29
  detect_resolution = kwargs.pop("detect_resolution", 512)
@@ -55,7 +55,8 @@ def resize_image(input_image, resolution, interpolation=None):
55
 
56
 
57
  class Preprocessor:
58
- MODEL_ID = "condition/ckpts"
 
59
 
60
  def __init__(self):
61
  self.model = None
@@ -73,8 +74,8 @@ class Preprocessor:
73
  elif name == "Canny":
74
  self.model = CannyDetector()
75
  elif name == "Depth":
76
- # self.model = DepthEstimator()
77
- self.model = MidasDetector.from_pretrained(self.MODEL_ID)
78
  else:
79
  raise ValueError
80
  torch.cuda.empty_cache()
 
23
 
24
  class DepthEstimator:
25
  def __init__(self):
26
+ self.model = pipeline("depth-estimation")
27
 
28
  def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
29
  detect_resolution = kwargs.pop("detect_resolution", 512)
 
55
 
56
 
57
  class Preprocessor:
58
+ # MODEL_ID = "condition/ckpts"
59
+ MODEL_ID = "lllyasviel/Annotators"
60
 
61
  def __init__(self):
62
  self.model = None
 
74
  elif name == "Canny":
75
  self.model = CannyDetector()
76
  elif name == "Depth":
77
+ self.model = DepthEstimator()
78
+ # self.model = MidasDetector.from_pretrained(self.MODEL_ID)
79
  else:
80
  raise ValueError
81
  torch.cuda.empty_cache()