aki-0421
F: add
a3a3ae4 unverified
raw
history blame
4.69 kB
# -*- coding: utf-8 -*-
"""Utils for monoDepth."""
import re
import sys
import cv2
import numpy as np
import torch
def read_pfm(path):
    """Read a PFM (Portable Float Map) file.

    Args:
        path (str): path to file

    Returns:
        tuple: (data, scale) where data is a float32 array of shape
            (height, width, 3) for a colour ('PF') file or (height, width)
            for a greyscale ('Pf') file, with rows ordered top to bottom,
            and scale is the absolute value of the header scale factor.

    Raises:
        ValueError: if the file does not start with a PFM magic line, if
            the dimension or scale line is malformed, or if the payload
            does not match the declared dimensions.
    """
    with open(path, 'rb') as file:
        # Compare raw bytes: decoding first would turn any non-ASCII file
        # (e.g. a PNG passed by mistake) into a UnicodeDecodeError instead
        # of the intended "not a PFM file" error.
        header = file.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise ValueError('Not a PFM file: ' + str(path))

        # Width and height; tolerate repeated separators and a missing
        # trailing newline.
        dim_match = re.match(rb'^(\d+)\s+(\d+)\s*$', file.readline())
        if not dim_match:
            raise ValueError('Malformed PFM header.')
        width, height = map(int, dim_match.groups())

        try:
            scale = float(file.readline().decode('ascii').rstrip())
        except ValueError as err:  # also covers UnicodeDecodeError
            raise ValueError('Malformed PFM header.') from err

        # The sign of the scale encodes the byte order of the payload.
        if scale < 0:
            endian = '<'  # little-endian
            scale = -scale
        else:
            endian = '>'  # big-endian

        data = np.fromfile(file, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # The data read from the file is flipped vertically before returning.
    return np.flipud(data), scale
def write_pfm(path, image, scale=1):
    """Write a PFM (Portable Float Map) file.

    Args:
        path (str): path to file
        image (array): float32 data of shape H x W x 3 (colour), or
            H x W x 1 / H x W (greyscale)
        scale (int, optional): Scale. Defaults to 1. It is written negated
            when the data is little-endian, which is how the byte order is
            recorded in the header.

    Raises:
        ValueError: if the dtype is not float32 or the shape is unsupported.
    """
    # Validate before opening the file so that a rejected image does not
    # leave an empty or truncated file behind.
    if image.dtype.name != 'float32':
        raise ValueError('Image dtype must be float32.')

    if image.ndim == 3 and image.shape[2] == 3:  # color image
        magic = b'PF\n'
    elif image.ndim == 2 or (image.ndim == 3
                             and image.shape[2] == 1):  # greyscale
        magic = b'Pf\n'
    else:
        raise ValueError(
            'Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    # The image is flipped vertically before being written.
    image = np.flipud(image)

    endian = image.dtype.byteorder
    if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
        scale = -scale

    with open(path, 'wb') as file:
        # The file is binary, so every header line must be bytes.
        file.write(magic)
        file.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
        file.write(b'%f\n' % scale)
        image.tofile(file)
def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)

    Raises:
        OSError: if the file cannot be read or decoded as an image.
    """
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise OSError('Could not read image: ' + str(path))

    # Promote a single-channel result to three channels so the colour
    # conversion below always receives BGR input.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # Reorder BGR to RGB, then divide by 255.0.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
    return img
def resize_image(img, size=384, multiple=32):
    """Resize image and make it fit for network.

    The longer side is scaled to ``size`` and both sides are then rounded
    up to the next multiple of ``multiple``.

    Args:
        img (array): image, indexed as (height, width, channels)
        size (int, optional): target length of the longer side before
            rounding. Defaults to 384.
        multiple (int, optional): both output sides are rounded up to a
            multiple of this value. Defaults to 32.

    Returns:
        tensor: data ready for network, float tensor of shape
            (1, channels, height, width)
    """
    height_orig, width_orig = img.shape[:2]

    # Scale factor that maps the longer side onto `size`.
    scale = max(width_orig, height_orig) / size

    # Plain Python ints keep the cv2.resize size argument unambiguous.
    height = int(np.ceil(height_orig / scale / multiple) * multiple)
    width = int(np.ceil(width_orig / scale / multiple) * multiple)

    img_resized = cv2.resize(img, (width, height),
                             interpolation=cv2.INTER_AREA)

    # HWC -> CHW, then add the leading batch dimension.
    chw = np.transpose(img_resized, (2, 0, 1))
    tensor = torch.from_numpy(chw).contiguous().float()
    return tensor.unsqueeze(0)
def resize_depth(depth, width, height):
    """Resize depth map and bring to CPU (numpy).

    Args:
        depth (tensor): depth, a 4-D tensor; only the first item along
            dimension 0 is used
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    # Take the first item and drop singleton dimensions.
    # detach() is required because Tensor.numpy() refuses tensors that are
    # still attached to the autograd graph.
    depth = torch.squeeze(depth[0, :, :, :]).detach().to('cpu')

    depth_resized = cv2.resize(depth.numpy(), (width, height),
                               interpolation=cv2.INTER_CUBIC)
    return depth_resized
def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    The values are written as float32 to ``path + '.pfm'``, and a min-max
    normalised copy is written to ``path + '.png'``.

    Args:
        path (str): filepath without extension
        depth (array): depth
        bits (int, optional): number of bytes per PNG sample: 1 writes
            uint8, 2 writes uint16. For any other value no PNG is written.
            Defaults to 1.
    """
    write_pfm(path + '.pfm', depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    # Largest value representable with `bits` bytes.
    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo('float').eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        # Constant depth map: avoid dividing by zero and emit all zeros.
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if bits == 1:
        cv2.imwrite(path + '.png', out.astype('uint8'))
    elif bits == 2:
        cv2.imwrite(path + '.png', out.astype('uint16'))