# zipnerf / internal/raw_utils.py
# Cr4yfish's picture
# copy files from SuLvXiangXin
# c165cd8
import glob
import json
import os
from internal import image as lib_image
from internal import math
from internal import utils
import numpy as np
import rawpy
def postprocess_raw(raw, camtorgb, exposure=None):
    """Maps a demosaicked raw image to sRGB with a minimal pipeline.

    Steps: color-correct with `camtorgb`, divide by `exposure` and clip to
    [0, 1], then apply the sRGB gamma curve as a simple tonemap.

    Args:
      raw: [H, W, 3] demosaicked raw camera image.
      camtorgb: [3, 3] color correction matrix applied to every raw pixel.
      exposure: color value that should map to pure white after color
        correction. When None, the 97th percentile of the color-corrected
        image is used ("auto-exposure").

    Returns:
      srgb: [H, W, 3] color-corrected, exposed and gamma-mapped image.

    Raises:
      ValueError: if `raw` does not have 3 channels or `camtorgb` is not 3x3.
    """
    if raw.shape[-1] != 3:
        raise ValueError(f'raw.shape[-1] is {raw.shape[-1]}, expected 3')
    if camtorgb.shape != (3, 3):
        raise ValueError(f'camtorgb.shape is {camtorgb.shape}, expected (3, 3)')

    # Pixels are row vectors, so right-multiplying by the transpose applies
    # camtorgb to each pixel.
    linear_rgb = np.matmul(raw, camtorgb.T)
    white_point = np.percentile(linear_rgb, 97) if exposure is None else exposure

    # Anything brighter than the white point saturates at 1.
    exposed = np.clip(linear_rgb / white_point, 0, 1)
    return lib_image.linear_to_srgb_np(exposed)
def pixels_to_bayer_mask(pix_x, pix_y):
    """Returns 0/1 RGB Bayer mask values for integer pixel coordinates.

    The mosaic is RGGB: red where both coordinates are even, blue where both
    are odd, and green at the two mixed parities.

    Args:
      pix_x: integer x (column) coordinates.
      pix_y: integer y (row) coordinates, same shape as `pix_x`.

    Returns:
      [..., 3] float32 array whose last axis holds the (R, G, B) mask values.
    """
    x_even, x_odd = pix_x % 2 == 0, pix_x % 2 == 1
    y_even, y_odd = pix_y % 2 == 0, pix_y % 2 == 1

    red = x_even & y_even                              # top-left of each 2x2
    green = (x_odd & y_even) | (x_even & y_odd)        # top-right, bottom-left
    blue = x_odd & y_odd                               # bottom-right
    return np.stack([red, green, blue], axis=-1).astype(np.float32)
def bilinear_demosaic(bayer):
    """Demosaicks a Bayer mosaic into a full RGB image by bilinear interpolation.

    The input must use the 2x2 RGGB tiling:
      -------------
      |red  |green|
      -------------
      |green|blue |
      -------------
    Red and blue are bilinearly upsampled by 2x. Each missing green value is
    the mean of its four green neighbours in a "+" (cross) pattern. Neighbour
    lookups use np.roll, so the last row/column wrap around to the first.

    Args:
      bayer: [H, W] array holding the Bayer mosaic. H and W are assumed even;
        otherwise the four quarter-resolution planes differ in shape and
        np.stack fails.

    Returns:
      rgb: [H, W, 3] array, the demosaicked RGB image.
    """

    def interleave(*planes):
        """Tiles four [h, w] planes into one [2h, 2w] image of 2x2 quads.

        Plane k is placed at offset (k // 2, k % 2) inside every quad.
        """
        stacked = np.stack(planes, -1)
        rows, cols = stacked.shape[:-1]
        quads = stacked.reshape((rows, cols, 2, 2))
        # Move each quad's row offset next to the row index (and likewise for
        # columns) so that a plain reshape produces the interleaved image.
        quads = np.transpose(quads, (0, 2, 1, 3))
        return quads.reshape((rows * 2, cols * 2))

    def upsample_2x(plane):
        """Bilinear 2x upsample of a plane sampled at each quad's top-left."""
        # np.roll wraps the right and bottom edges around. The raw data has a
        # few garbage border rows/columns that are discarded anyway, so the
        # wrap-around does not matter in practice.
        horiz = .5 * (plane + np.roll(plane, -1, axis=-1))
        vert = .5 * (plane + np.roll(plane, -1, axis=-2))
        diag = .5 * (horiz + np.roll(horiz, -1, axis=-2))
        return interleave(plane, horiz, vert, diag)

    def fill_green(green_top, green_bottom):
        """Builds the full-resolution green plane from the two green sub-planes."""
        blank = np.zeros_like(green_top)
        sparse = interleave(blank, green_top, green_bottom, blank)
        # Green sites only have red/blue (zero) cross neighbours, so the mean
        # below is 0 there. Red/blue sites get the mean of 4 green neighbours.
        # Adding `sparse` therefore fills every pixel.
        cross_mean = 0
        for shift, axis in ((-1, -1), (1, -1), (-1, -2), (1, -2)):
            cross_mean = cross_mean + .25 * np.roll(sparse, shift, axis=axis)
        return cross_mean + sparse

    red = bayer[0::2, 0::2]
    green_top = bayer[0::2, 1::2]
    green_bottom = bayer[1::2, 0::2]
    blue = bayer[1::2, 1::2]

    # upsample_2x assumes samples sit at the top-left of each quad. Blue sits
    # at the bottom-right, so flip both axes before and after upsampling.
    blue_full = upsample_2x(blue[::-1, ::-1])[::-1, ::-1]
    red_full = upsample_2x(red)
    green_full = fill_green(green_top, green_bottom)
    return np.stack([red_full, green_full, blue_full], -1)
def load_raw_images(image_dir, image_names=None):
    """Loads raw images and their metadata from disk.

    For every requested name, reads `<name>.dng` (sensor data) and
    `<name>.json` (EXIF data; the first record of the JSON list is used).

    Args:
      image_dir: directory containing raw image and EXIF data.
      image_names: files to load (ignores file extension), loads all DNGs if None.

    Returns:
      A tuple (images, exifs).
      images: [N, height, width] float32 array of raw (still mosaicked) sensor
        data, one 2D Bayer frame per image.
      exifs: length-N tuple of dicts, one per image, containing the EXIF data.

    Raises:
      ValueError: The requested `image_dir` does not exist on disk, or there
        are no images to load.
    """
    if not utils.file_exists(image_dir):
        raise ValueError(f'Raw image folder {image_dir} does not exist.')

    # Load raw images (dng files) and exif metadata (json files).
    def load_raw_exif(image_name):
        """Returns (raw sensor array, EXIF dict) for one image name."""
        base = os.path.join(image_dir, os.path.splitext(image_name)[0])
        with utils.open_file(base + '.dng', 'rb') as f:
            # Copy the sensor data out so the rawpy handle can be closed
            # deterministically instead of being kept alive by the array.
            with rawpy.imread(f) as raw_file:
                raw = raw_file.raw_image.copy()
        with utils.open_file(base + '.json', 'rb') as f:
            exif = json.load(f)[0]
        return raw, exif

    if image_names is None:
        image_names = [
            os.path.basename(f)
            for f in sorted(glob.glob(os.path.join(image_dir, '*.dng')))
        ]
    data = [load_raw_exif(x) for x in image_names]
    if not data:
        # Previously surfaced as an opaque "not enough values to unpack".
        raise ValueError(f'No raw images to load in {image_dir}.')
    raws, exifs = zip(*data)
    raws = np.stack(raws, axis=0).astype(np.float32)
    return raws, exifs
# Brightness percentiles to use for re-exposing and tonemapping raw images.
_PERCENTILE_LIST = (80, 90, 97, 99, 100)
# Relevant fields to extract from raw image EXIF metadata.
# For details regarding EXIF parameters, see:
# https://www.adobe.com/content/dam/acom/en/products/photoshop/pdfs/dng_spec_1.4.0.0.pdf.
_EXIF_KEYS = (
'BlackLevel', # Black level offset added to sensor measurements.
'WhiteLevel', # Maximum possible sensor measurement.
'AsShotNeutral', # RGB white balance coefficients.
'ColorMatrix2', # XYZ to camera color space conversion matrix.
'NoiseProfile', # Shot and read noise levels.
)
# Color conversion from reference illuminant XYZ to RGB color space.
# See http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html.
_RGB2XYZ = np.array([[0.4124564, 0.3575761, 0.1804375],
[0.2126729, 0.7151522, 0.0721750],
[0.0193339, 0.1191920, 0.9503041]])
def process_exif(exifs):
"""Processes list of raw image EXIF data into useful metadata dict.
Input should be a list of dictionaries loaded from JSON files.
These JSON files are produced by running
$ exiftool -json IMAGE.dng > IMAGE.json
for each input raw file.
We extract only the parameters relevant to
1. Rescaling the raw data to [0, 1],
2. White balance and color correction, and
3. Noise level estimation.
Args:
exifs: a list of dicts containing EXIF data as loaded from JSON files.
Returns:
meta: a dict of the relevant metadata for running RawNeRF.
"""
meta = {}
exif = exifs[0]
# Convert from array of dicts (exifs) to dict of arrays (meta).
for key in _EXIF_KEYS:
exif_value = exif.get(key)
if exif_value is None:
continue
# Values can be a single int or float...
if isinstance(exif_value, int) or isinstance(exif_value, float):
vals = [x[key] for x in exifs]
# Or a string of numbers with ' ' between.
elif isinstance(exif_value, str):
vals = [[float(z) for z in x[key].split(' ')] for x in exifs]
meta[key] = np.squeeze(np.array(vals))
# Shutter speed is a special case, a string written like 1/N.
meta['ShutterSpeed'] = np.fromiter(
(1. / float(exif['ShutterSpeed'].split('/')[1]) for exif in exifs), float)
# Create raw-to-sRGB color transform matrices. Pipeline is:
# cam space -> white balanced cam space ("camwb") -> XYZ space -> RGB space.
# 'AsShotNeutral' is an RGB triplet representing how pure white would measure
# on the sensor, so dividing by these numbers corrects the white balance.
whitebalance = meta['AsShotNeutral'].reshape(-1, 3)
cam2camwb = np.array([np.diag(1. / x) for x in whitebalance])
# ColorMatrix2 converts from XYZ color space to "reference illuminant" (white
# balanced) camera space.
xyz2camwb = meta['ColorMatrix2'].reshape(-1, 3, 3)
rgb2camwb = xyz2camwb @ _RGB2XYZ
# We normalize the rows of the full color correction matrix, as is done in
# https://github.com/AbdoKamel/simple-camera-pipeline.
rgb2camwb /= rgb2camwb.sum(axis=-1, keepdims=True)
# Combining color correction with white balance gives the entire transform.
cam2rgb = np.linalg.inv(rgb2camwb) @ cam2camwb
meta['cam2rgb'] = cam2rgb
return meta
def load_raw_dataset(split, data_dir, image_names, exposure_percentile, n_downsample):
    """Loads and processes a set of RawNeRF input images.

    Includes logic necessary for special "test" scenes that include a noiseless
    ground truth frame, produced by HDR+ merge.

    Args:
      split: DataSplit.TRAIN or DataSplit.TEST, only used for test scene logic.
      data_dir: base directory for scene data.
      image_names: which images were successfully posed by COLMAP.
      exposure_percentile: what brightness percentile to expose to white.
      n_downsample: returned images are downsampled by a factor of n_downsample.

    Returns:
      A tuple (images, meta, testscene).
      images: [N, height // n_downsample, width // n_downsample, 3] array of
        demosaicked raw image data, rescaled so black level -> 0 and white
        level -> 1 (times the test-scene shutter ratio, if any).
      meta: EXIF metadata (from `process_exif`) plus per-image exposure
        information that can be passed into the NeRF model with each ray:
          'unique_shutters': distinct shutter speeds, sorted largest first.
          'exposure_idx': [N] int32 index into 'unique_shutters' for each
            image (index 0 is the largest shutter speed).
          'exposure_values': [N] shutter speeds divided by the largest one, so
            the maximum `exposure_values` entry is 1.
        and tonemapping helpers:
          'exposure': color value at `exposure_percentile` of image 0, used
            as white when tonemapping.
          'exposure_levels': dict mapping each percentile in _PERCENTILE_LIST
            to the corresponding color value of image 0.
          'postprocess_fn': callable(z, x=exposure) mapping a demosaicked raw
            image z to tonemapped sRGB via `postprocess_raw`.
      testscene: True when dataset includes ground truth test image, else False.
    """
    image_dir = os.path.join(data_dir, 'raw')

    testimg_file = os.path.join(data_dir, 'hdrplus_test/merged.dng')
    testscene = utils.file_exists(testimg_file)
    if testscene:
        # Test scenes have train/ and test/ split subdirectories inside raw/.
        image_dir = os.path.join(image_dir, split.value)
        if split == utils.DataSplit.TEST:
            # COLMAP image names not valid for test split of test scene.
            image_names = None
        else:
            # Discard the first COLMAP image name as it is a copy of the test image.
            image_names = image_names[1:]

    raws, exifs = load_raw_images(image_dir, image_names)
    meta = process_exif(exifs)

    if testscene and split == utils.DataSplit.TEST:
        # Test split for test scene must load the "ground truth" HDR+ merged image.
        with utils.open_file(testimg_file, 'rb') as imgin:
            # astype() copies the data, so the rawpy handle can be closed here.
            with rawpy.imread(imgin) as raw_file:
                # HDR+ output has 2 extra bits of fixed precision, need to divide by 4.
                testraw = raw_file.raw_image.astype(np.float32) / 4.
        # Need to rescale long exposure test image by fast:slow shutter speed ratio.
        fast_shutter = meta['ShutterSpeed'][0]
        slow_shutter = meta['ShutterSpeed'][-1]
        shutter_ratio = fast_shutter / slow_shutter
        # Replace loaded raws with the "ground truth" test image.
        raws = testraw[None]
        # Test image shares metadata with the first loaded image (fast exposure).
        meta = {k: meta[k][:1] for k in meta}
    else:
        shutter_ratio = 1.

    # Next we determine an index for each unique shutter speed in the data.
    shutter_speeds = meta['ShutterSpeed']
    # Sort the shutter speeds from slowest (largest) to fastest (smallest).
    # This way index 0 will always correspond to the brightest image.
    unique_shutters = np.sort(np.unique(shutter_speeds))[::-1]
    exposure_idx = np.zeros_like(shutter_speeds, dtype=np.int32)
    for i, shutter in enumerate(unique_shutters):
        # Assign index `i` to all images with shutter speed `shutter`.
        exposure_idx[shutter_speeds == shutter] = i
    meta['exposure_idx'] = exposure_idx
    meta['unique_shutters'] = unique_shutters
    # Rescale to use relative shutter speeds, where 1. is the brightest.
    # This way the NeRF output with exposure=1 will always be reasonable.
    meta['exposure_values'] = shutter_speeds / unique_shutters[0]

    # Rescale raw sensor measurements to [0, 1] (plus noise).
    blacklevel = meta['BlackLevel'].reshape(-1, 1, 1)
    whitelevel = meta['WhiteLevel'].reshape(-1, 1, 1)
    images = (raws - blacklevel) / (whitelevel - blacklevel) * shutter_ratio

    # Calculate value for exposure level when gamma mapping, defaults to 97%.
    # Always based on full resolution image 0 (for consistency).
    cam2rgb0 = meta['cam2rgb'][0]
    image0_raw_demosaic = np.array(bilinear_demosaic(images[0]))
    image0_rgb = image0_raw_demosaic @ cam2rgb0.T
    exposure = np.percentile(image0_rgb, exposure_percentile)
    meta['exposure'] = exposure
    # Sweep over various exposure percentiles to visualize in training logs.
    exposure_levels = {p: np.percentile(image0_rgb, p) for p in _PERCENTILE_LIST}
    meta['exposure_levels'] = exposure_levels

    # Create postprocessing function mapping raw images to tonemapped sRGB space.
    meta['postprocess_fn'] = lambda z, x=exposure: postprocess_raw(z, cam2rgb0, x)

    def processing_fn(x):
        """Demosaicks one [H, W] Bayer frame and optionally downsamples it."""
        x_ = np.array(x)
        x_demosaic = bilinear_demosaic(x_)
        if n_downsample > 1:
            x_demosaic = lib_image.downsample(x_demosaic, n_downsample)
        return np.array(x_demosaic)

    images = np.stack([processing_fn(im) for im in images], axis=0)

    return images, meta, testscene
def best_fit_affine(x, y, axis):
    """Least-squares fit of `a * x + b ~= y`, reduced over `axis`.

    Args:
      x: input values.
      y: target values, same shape as `x`.
      axis: axis (or tuple of axes) over which the statistics are averaged.

    Returns:
      (a, b): slope and intercept with the reduced axes removed. The slope is
      Cov(x, y) / Var(x), so a constant `x` along `axis` divides by zero.
    """
    mean_x = x.mean(axis=axis)
    mean_y = y.mean(axis=axis)
    covariance = (x * y).mean(axis=axis) - mean_x * mean_y
    variance = (x * x).mean(axis=axis) - mean_x * mean_x
    slope = covariance / variance
    intercept = mean_y - slope * mean_x
    return slope, intercept
def match_images_affine(est, gt, axis=(0, 1)):
    """Affinely maps `est` into the value space of `gt`.

    The fit is computed in the gt -> est direction, which is more robust
    because `est` may be very noisy. The fit is then inverted so that the
    returned image lives in gt's space, which keeps metrics consistent.

    Args:
      est: estimated image.
      gt: ground-truth image, same shape as `est`.
      axis: axes reduced when fitting (per-channel fit by default).

    Returns:
      `est` after applying the inverse of the fitted gt -> est affine map.
    """
    slope, offset = best_fit_affine(gt, est, axis=axis)
    return (est - offset) / slope