udion commited on
Commit
d2d3b59
1 Parent(s): 269330d

fixed utils

Browse files
Files changed (1) hide show
  1. utils.py +64 -1
utils.py CHANGED
@@ -51,4 +51,67 @@ def ensure_checkpoint_exists(model_weights_filename):
51
  print(
52
  model_weights_filename,
53
  " not found, you may need to manually download the model weights."
54
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  print(
52
  model_weights_filename,
53
  " not found, you may need to manually download the model weights."
54
+ )
55
+
56
+ def normalize(image: np.ndarray) -> np.ndarray:
57
+ """Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
58
+ Args:
59
+ image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
60
+ Returns:
61
+ Normalized image data. Data range [0, 1].
62
+ """
63
+ return image.astype(np.float64) / 255.0
64
+
65
+
66
+ def unnormalize(image: np.ndarray) -> np.ndarray:
67
+ """Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
68
+ Args:
69
+ image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
70
+ Returns:
71
+ Denormalized image data. Data range [0, 255].
72
+ """
73
+ return image.astype(np.float64) * 255.0
74
+
75
+
76
+ def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
77
+ """Convert ``PIL.Image`` to Tensor.
78
+ Args:
79
+ image (np.ndarray): The image data read by ``PIL.Image``
80
+ range_norm (bool): Scale [0, 1] data to between [-1, 1]
81
+ half (bool): Whether to convert torch.float32 similarly to torch.half type.
82
+ Returns:
83
+ Normalized image data
84
+ Examples:
85
+ >>> image = Image.open("image.bmp")
86
+ >>> tensor_image = image2tensor(image, range_norm=False, half=False)
87
+ """
88
+ tensor = F.to_tensor(image)
89
+
90
+ if range_norm:
91
+ tensor = tensor.mul_(2.0).sub_(1.0)
92
+ if half:
93
+ tensor = tensor.half()
94
+
95
+ return tensor
96
+
97
+
98
+ def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
99
+ """Converts ``torch.Tensor`` to ``PIL.Image``.
100
+ Args:
101
+ tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
102
+ range_norm (bool): Scale [-1, 1] data to between [0, 1]
103
+ half (bool): Whether to convert torch.float32 similarly to torch.half type.
104
+ Returns:
105
+ Convert image data to support PIL library
106
+ Examples:
107
+ >>> tensor = torch.randn([1, 3, 128, 128])
108
+ >>> image = tensor2image(tensor, range_norm=False, half=False)
109
+ """
110
+ if range_norm:
111
+ tensor = tensor.add_(1.0).div_(2.0)
112
+ if half:
113
+ tensor = tensor.half()
114
+
115
+ image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
116
+
117
+ return image