jiangyzy commited on
Commit
48758c6
1 Parent(s): 28b27d8

Update ldm/util.py

Browse files
Files changed (1) hide show
  1. ldm/util.py +30 -1
ldm/util.py CHANGED
@@ -286,4 +286,33 @@ def load_state_dict(ckpt_path, location='cpu'):
286
  state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
287
  state_dict = get_state_dict(state_dict)
288
  print(f'Loaded state_dict from [{ckpt_path}]')
289
- return state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
287
  state_dict = get_state_dict(state_dict)
288
  print(f'Loaded state_dict from [{ckpt_path}]')
289
+ return state_dict
290
+
291
+
292
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
293
+ """Numpy array to tensor.
294
+
295
+ Args:
296
+ imgs (list[ndarray] | ndarray): Input images.
297
+ bgr2rgb (bool): Whether to change bgr to rgb.
298
+ float32 (bool): Whether to change to float32.
299
+
300
+ Returns:
301
+ list[tensor] | tensor: Tensor images. If returned results only have
302
+ one element, just return tensor.
303
+ """
304
+
305
+ def _totensor(img, bgr2rgb, float32):
306
+ if img.shape[2] == 3 and bgr2rgb:
307
+ if img.dtype == 'float64':
308
+ img = img.astype('float32')
309
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
310
+ img = torch.from_numpy(img.transpose(2, 0, 1))
311
+ if float32:
312
+ img = img.float()
313
+ return img
314
+
315
+ if isinstance(imgs, list):
316
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
317
+ else:
318
+ return _totensor(imgs, bgr2rgb, float32)