Spaces:
Runtime error
Runtime error
Update annotator/hed/__init__.py
Browse files
annotator/hed/__init__.py
CHANGED
@@ -100,13 +100,20 @@ class HEDdetector:
|
|
100 |
if not os.path.exists(modelpath):
|
101 |
from basicsr.utils.download_util import load_file_from_url
|
102 |
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
103 |
-
|
|
|
|
|
|
|
|
|
104 |
|
105 |
def __call__(self, input_image):
|
106 |
assert input_image.ndim == 3
|
107 |
input_image = input_image[:, :, ::-1].copy()
|
108 |
with torch.no_grad():
|
109 |
-
|
|
|
|
|
|
|
110 |
image_hed = image_hed / 255.0
|
111 |
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
112 |
edge = self.netNetwork(image_hed)[0]
|
|
|
100 |
if not os.path.exists(modelpath):
|
101 |
from basicsr.utils.download_util import load_file_from_url
|
102 |
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
103 |
+
if torch.cuda.is_available():
|
104 |
+
self.netNetwork = Network(modelpath).cuda().eval()
|
105 |
+
else:
|
106 |
+
self.netNetwork = Network(modelpath).eval()
|
107 |
+
|
108 |
|
109 |
def __call__(self, input_image):
|
110 |
assert input_image.ndim == 3
|
111 |
input_image = input_image[:, :, ::-1].copy()
|
112 |
with torch.no_grad():
|
113 |
+
if torch.cuda.is_available():
|
114 |
+
image_hed = torch.from_numpy(input_image).float().cuda()
|
115 |
+
else:
|
116 |
+
image_hed = torch.from_numpy(input_image).float()
|
117 |
image_hed = image_hed / 255.0
|
118 |
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
119 |
edge = self.netNetwork(image_hed)[0]
|