Spaces:
Runtime error
Runtime error
Update ldm/util.py
Browse files- ldm/util.py +30 -1
ldm/util.py
CHANGED
|
@@ -286,4 +286,33 @@ def load_state_dict(ckpt_path, location='cpu'):
|
|
| 286 |
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
|
| 287 |
state_dict = get_state_dict(state_dict)
|
| 288 |
print(f'Loaded state_dict from [{ckpt_path}]')
|
| 289 |
-
return state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
|
| 287 |
state_dict = get_state_dict(state_dict)
|
| 288 |
print(f'Loaded state_dict from [{ckpt_path}]')
|
| 289 |
+
return state_dict
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def img2tensor(imgs, bgr2rgb=True, float32=True):
|
| 293 |
+
"""Numpy array to tensor.
|
| 294 |
+
|
| 295 |
+
Args:
|
| 296 |
+
imgs (list[ndarray] | ndarray): Input images.
|
| 297 |
+
bgr2rgb (bool): Whether to change bgr to rgb.
|
| 298 |
+
float32 (bool): Whether to change to float32.
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
list[tensor] | tensor: Tensor images. If returned results only have
|
| 302 |
+
one element, just return tensor.
|
| 303 |
+
"""
|
| 304 |
+
|
| 305 |
+
def _totensor(img, bgr2rgb, float32):
|
| 306 |
+
if img.shape[2] == 3 and bgr2rgb:
|
| 307 |
+
if img.dtype == 'float64':
|
| 308 |
+
img = img.astype('float32')
|
| 309 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 310 |
+
img = torch.from_numpy(img.transpose(2, 0, 1))
|
| 311 |
+
if float32:
|
| 312 |
+
img = img.float()
|
| 313 |
+
return img
|
| 314 |
+
|
| 315 |
+
if isinstance(imgs, list):
|
| 316 |
+
return [_totensor(img, bgr2rgb, float32) for img in imgs]
|
| 317 |
+
else:
|
| 318 |
+
return _totensor(imgs, bgr2rgb, float32)
|