Spaces:
Running
on
Zero
Running
on
Zero
# -*- coding: utf-8 -*- | |
"""Utils for monoDepth.""" | |
import re | |
import sys | |
import cv2 | |
import numpy as np | |
import torch | |
def read_pfm(path):
    """Load a PFM (portable float map) file.

    The header consists of three ASCII lines: the magic ('PF' for colour,
    'Pf' for greyscale), "<width> <height>", and a scale whose sign encodes
    the byte order of the float payload (negative means little-endian).

    Args:
        path (str): path to file

    Returns:
        tuple: (data, scale) where data has shape (H, W, 3) for colour files
            or (H, W) for greyscale files, and scale is the absolute value
            of the scale stored in the header.

    Raises:
        Exception: if the magic line is not 'PF'/'Pf' or the dimension line
            does not match "<width> <height>".
    """
    with open(path, 'rb') as pfm:
        magic = pfm.readline().rstrip().decode('ascii')
        if magic == 'PF':
            is_color = True
        elif magic == 'Pf':
            is_color = False
        else:
            raise Exception('Not a PFM file: ' + path)

        dims = re.match(r'^(\d+)\s(\d+)\s$', pfm.readline().decode('ascii'))
        if not dims:
            raise Exception('Malformed PFM header.')
        width, height = (int(group) for group in dims.groups())

        scale = float(pfm.readline().decode('ascii').rstrip())
        if scale < 0:
            # A negative scale marks a little-endian payload.
            byte_order = '<'
            scale = -scale
        else:
            byte_order = '>'

        raw = np.fromfile(pfm, byte_order + 'f')

    target_shape = (height, width, 3) if is_color else (height, width)
    # The payload is flipped vertically so the first stored row ends up last.
    pixels = np.flipud(np.reshape(raw, target_shape))
    return pixels, scale
def write_pfm(path, image, scale=1):
    """Write a float32 image to a PFM (portable float map) file.

    Args:
        path (str): path to file
        image (array): float32 data with H x W x 3 (colour), H x W x 1 or
            H x W (greyscale) dimensions
        scale (int, optional): Scale. Defaults to 1. It is written negated
            when the data is little-endian, which is how PFM encodes byte
            order.

    Raises:
        Exception: if the image dtype is not float32 or its shape is not one
            of the supported layouts.
    """
    # Validate before opening the file so that invalid input does not create
    # or truncate the destination.
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    if image.ndim == 3 and image.shape[2] == 3:  # color image
        color = True
    elif image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        color = False  # greyscale
    else:
        raise Exception(
            'Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    # Rows are written in reverse vertical order.
    image = np.flipud(image)

    endian = image.dtype.byteorder
    if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
        scale = -scale

    with open(path, 'wb') as file:
        # Bytes literals for both branches: the previous
        # `'PF\n' if color else 'Pf\n'.encode()` only encoded the greyscale
        # magic, so colour images raised TypeError on a binary file.
        file.write(b'PF\n' if color else b'Pf\n')
        file.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
        file.write(b'%f\n' % scale)
        image.tofile(file)
def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image with values divided by 255

    Raises:
        IOError: if the image could not be loaded from path.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of an AttributeError below.
        raise IOError('Could not read image: ' + str(path))

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # OpenCV loads channels in BGR order; convert to RGB before scaling.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
def resize_image(img):
    """Resize an image and convert it to a batched network input.

    The longer side is scaled to 384 pixels, then both sides are rounded up
    to the next multiple of 32.

    Args:
        img (array): image, indexed as (height, width, channels)

    Returns:
        tensor: float tensor of shape (1, channels, height, width)
    """
    orig_h, orig_w = img.shape[0], img.shape[1]

    # Ratio that maps the longer side to 384 pixels.
    ratio = max(orig_w, orig_h) / 384

    new_h = (np.ceil(orig_h / ratio / 32) * 32).astype(int)
    new_w = (np.ceil(orig_w / ratio / 32) * 32).astype(int)

    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # HWC -> CHW, then add a leading batch dimension.
    channels_first = np.transpose(resized, (2, 0, 1))
    tensor = torch.from_numpy(channels_first).contiguous().float()
    return tensor.unsqueeze(0)
def resize_depth(depth, width, height):
    """Resize a depth prediction and return it as a numpy array on the CPU.

    Args:
        depth (tensor): depth, indexed with four dimensions; only the first
            entry along dimension 0 is used
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    first_item = depth[0, :, :, :]
    depth_cpu = torch.squeeze(first_item).to('cpu')
    return cv2.resize(depth_cpu.numpy(), (width, height),
                      interpolation=cv2.INTER_CUBIC)
def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    The PFM receives the depth values cast to float32. The PNG receives the
    depth linearly rescaled to the range [0, 2**(8 * bits) - 1].

    Args:
        path (str): filepath without extension
        depth (array): depth
        bits (int, optional): bytes per PNG sample; 1 writes uint8 and 2
            writes uint16. For any other value only the PFM is written.
            Defaults to 1.
    """
    write_pfm(path + '.pfm', depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()
    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo('float').eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        # Constant depth map: avoid dividing by ~0 and emit all zeros.
        # Uses `depth.dtype`; ndarrays have no `type` attribute, so the
        # previous `depth.type` raised AttributeError on this branch.
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if bits == 1:
        cv2.imwrite(path + '.png', out.astype('uint8'))
    elif bits == 2:
        cv2.imwrite(path + '.png', out.astype('uint16'))