# pixelnet / controlled_downscale.py
# Last commit by thomaseding: "Add some debug code" (801759b)
import argparse
import sys
from PIL import Image
from typing import List, Optional, Tuple
# An (x, y) coordinate: either a pixel position in an image or a cell index
# in the downscaled grid (see DownBox.down_pos).
Pos = Tuple[int, int]
# A (width, height) size.
Dim = Tuple[int, int]
class Box:
    """Axis-aligned rectangle stored as two corner points, both inclusive.

    Because the max corner is inclusive, a box whose min and max coincide
    is one pixel wide and one pixel tall.
    """

    def __init__(self, min: Pos, max: Pos) -> None:
        self._min = min
        self._max = max

    def min(self) -> Tuple[int, int]:
        """Return the (x, y) corner with the smallest coordinates (inclusive)."""
        return self._min

    def max(self) -> Tuple[int, int]:
        """Return the (x, y) corner with the largest coordinates (inclusive)."""
        return self._max

    def width(self) -> int:
        """Return the number of columns covered, counting both end columns."""
        first_column = self._min[0]
        last_column = self._max[0]
        return last_column - first_column + 1

    def height(self) -> int:
        """Return the number of rows covered, counting both end rows."""
        first_row = self._min[1]
        last_row = self._max[1]
        return last_row - first_row + 1

    def dimensions(self) -> Tuple[int, int]:
        """Return (width, height)."""
        return (self.width(), self.height())

    def as_tuple(self) -> Tuple[int, int, int, int]:
        """Return (left, upper, right, lower), where right/lower are inclusive.

        Note that PIL crop boxes are half-open (right/lower exclusive), so a
        caller handing this tuple to such an API must add 1 to the last two
        values to cover the whole box.
        """
        left = self._min[0]
        upper = self._min[1]
        right = self._max[0]
        lower = self._max[1]
        return (left, upper, right, lower)
class DownBox(Box):
    """A Box that also records which cell of the downscaled grid it maps to."""

    def __init__(self, min: Pos, max: Pos, down_pos: Pos) -> None:
        super().__init__(min, max)
        # (column, row) of this box in the downscaled image.
        self._down_pos = down_pos

    def down_pos(self) -> Tuple[int, int]:
        """Return the (x, y) position of this box in the downscaled image."""
        return self._down_pos
class ExtractedBoxes:
    """Ordered collection of grid cells parsed from a control image.

    The size helpers read only the last box in the list, so that box is
    assumed to be the bottom-right cell (this is the order extract_boxes
    produces).
    """

    def __init__(self, boxes: List[DownBox]) -> None:
        self._boxes = boxes

    def boxes(self) -> List[DownBox]:
        """Return the underlying list of boxes (not a copy)."""
        return self._boxes

    def down_dimensions(self) -> Dim:
        """Return the (width, height) of the downscaled grid, or (0, 0) if empty."""
        if not self._boxes:
            return (0, 0)
        last_cell = self._boxes[-1].down_pos()
        # Grid positions are zero-based, so the count is the last index + 1.
        return (last_cell[0] + 1, last_cell[1] + 1)

    def full_dimensions(self) -> Dim:
        """Return the (width, height) in pixels covered by the boxes, or (0, 0) if empty."""
        if not self._boxes:
            return (0, 0)
        far_corner = self._boxes[-1].max()
        # The max corner is inclusive, hence the + 1.
        return (far_corner[0] + 1, far_corner[1] + 1)

    def to_colored_checkers(self, *, full=True) -> Image.Image:
        """Render the boxes as a debug image, one palette colour per box.

        Colours are assigned in list order and wrap around after the palette
        is exhausted.

        :param full: When true, paint every box at its source position and
            size; otherwise paint a single pixel per box at its down_pos.
        :return: A new RGB image (0x0 when there is nothing to draw).
        """
        size = self.full_dimensions() if full else self.down_dimensions()
        width, height = size
        if width == 0 or height == 0:
            return Image.new("RGB", (0, 0))

        canvas = Image.new("RGB", (width, height))
        palette = [
            (255, 255, 255),
            (0, 0, 0),
            (255, 0, 0),
            (255, 127, 0),
            (255, 255, 0),
            (0, 255, 0),
            (0, 0, 255),
            (75, 0, 130),
            (148, 0, 211),
            (255, 0, 255),
        ]
        palette_size = len(palette)

        for index, box in enumerate(self._boxes):
            fill = palette[index % palette_size]
            if full:
                patch_size = box.dimensions()
                anchor = box.min()
            else:
                patch_size = (1, 1)
                anchor = box.down_pos()
            canvas.paste(Image.new("RGB", patch_size, fill), anchor)
        return canvas
def average_box_dimensions(boxes: List[DownBox]) -> Dim:
    """Return a representative (width, height) for a non-empty list of boxes.

    A single box is returned as-is. Up to 16 boxes use the integer
    (floor) mean of each axis; larger lists use the element at index
    ``len // 2`` of the sorted sizes (the upper median for even counts).
    """
    count = len(boxes)
    assert count > 0
    if count == 1:
        return boxes[0].dimensions()

    widths = [box.width() for box in boxes]
    heights = [box.height() for box in boxes]

    if count <= 16:
        # mean
        return (sum(widths) // count, sum(heights) // count)

    # median
    middle = count // 2
    return (sorted(widths)[middle], sorted(heights)[middle])
def get_trimmed(boxes: List[DownBox]) -> Tuple[Box, Box]:
    """Compute the bounds left after dropping odd-sized boxes at either end.

    A box is an outlier when its width or height differs from the
    representative size (see average_box_dimensions) by more than one pixel.
    If the first box is an outlier, the bounds start at the first
    non-outlier found scanning forward; if the last box is an outlier, they
    end at the first non-outlier found scanning backward. When no
    non-outlier is found, the untrimmed bound is kept.

    :param boxes: Non-empty list of boxes.
    :return: ``(pixel_bounds, grid_bounds)`` - the kept area in source pixel
        coordinates and in downscaled-grid coordinates, both as inclusive
        boxes.
    """
    avg_width, avg_height = average_box_dimensions(boxes)
    tolerance = 1
    # threshold = 8
    # if avg[0] > threshold and avg[1] > threshold:
    #     outlier_dist = 2
    # threshold = 32
    # if avg[0] > threshold and avg[1] > threshold:
    #     outlier_dist = 3

    def is_outlier(box: DownBox) -> bool:
        box_width, box_height = box.dimensions()
        width_off = abs(box_width - avg_width) > tolerance
        height_off = abs(box_height - avg_height) > tolerance
        return width_off or height_off

    assert len(boxes) > 0
    first = boxes[0]
    last = boxes[-1]

    # Defaults: keep everything from the origin to the last box.
    min_out = (0, 0)
    min_down = (0, 0)
    max_out = last.max()
    max_down = last.down_pos()

    if is_outlier(first):
        for candidate in boxes[1:]:
            if not is_outlier(candidate):
                min_out = candidate.min()
                min_down = candidate.down_pos()
                break

    if is_outlier(last):
        for candidate in reversed(boxes[:-1]):
            if not is_outlier(candidate):
                max_out = candidate.max()
                max_down = candidate.down_pos()
                break

    return (Box(min_out, max_out), Box(min_down, max_down))
def calc_face_box(control_image: Image.Image, min_pos: Pos) -> Box:
    """Measure the uniformly coloured cell whose first pixel is ``min_pos``.

    The width is the run of pixels equal to the pixel at ``min_pos``
    scanning right along its row; the height is the equivalent run scanning
    down along its column. Both scans stop at the image edge.

    :param control_image: Image to scan.
    :param min_pos: (x, y) of the cell's first (top-left) pixel.
    :return: Box from ``min_pos`` to the last matching pixel on each axis
        (inclusive).
    """
    reference = control_image.getpixel(min_pos)
    image_width, image_height = control_image.size
    origin_x, origin_y = min_pos

    run_right = 0
    for px in range(origin_x, image_width):
        if control_image.getpixel((px, origin_y)) != reference:
            break
        run_right += 1

    run_down = 0
    for py in range(origin_y, image_height):
        if control_image.getpixel((origin_x, py)) != reference:
            break
        run_down += 1

    # Convert run lengths into offsets of the last matching pixel.
    last_dx = run_right - 1
    last_dy = run_down - 1
    # NOTE(review): these reject cells that are only one pixel wide or tall -
    # confirm that is intended.
    assert last_dx > 0
    assert last_dy > 0
    return Box(min_pos, (origin_x + last_dx, origin_y + last_dy))
def extract_boxes(control_image: Image.Image) -> ExtractedBoxes:
    """Split a control image into its grid of uniformly coloured cells.

    Cells are discovered left-to-right, then top-to-bottom. Each cell is
    measured with calc_face_box and tagged with its (column, row) index in
    the downscaled grid. The height of a row is taken from the last cell
    found in that row.

    :param control_image: Non-empty image to parse.
    :return: The cells in row-major order.
    """
    image_width, image_height = control_image.size
    assert image_width > 0
    assert image_height > 0

    found: List[DownBox] = []
    top = 0
    grid_row = 0
    while top < image_height:
        left = 0
        grid_col = 0
        while left < image_width:
            cell = calc_face_box(control_image, (left, top))
            found.append(DownBox(cell.min(), cell.max(), (grid_col, grid_row)))
            left += cell.width()
            grid_col += 1
        assert left == image_width

        # Advance by the height of the cell that closed this row.
        top += found[-1].height()
        grid_row += 1
    assert top == image_height
    return ExtractedBoxes(found)
def downsample_one(input_image: Image.Image, box: Box, sample_radius: Optional[int], downsampler: Image.Resampling) -> Tuple[int, int, int]:
    """Reduce the pixels around the centre of ``box`` to a single RGB colour.

    A window centred on the box is cropped from ``input_image`` and resized
    to 1x1 with ``downsampler``. The window extends ``radius`` pixels either
    side of the centre on each axis, where ``radius`` is ``sample_radius``
    clamped to half the box size (or exactly half the box size when
    ``sample_radius`` is None). The window never leaves the box and always
    contains at least one pixel.

    :param input_image: Source image; its pixels must be 3-tuples.
    :param box: Cell to sample, with inclusive corners.
    :param sample_radius: Maximum per-axis half-extent of the window, or
        None to use the whole box.
    :param downsampler: Resampling filter used for the 1x1 resize.
    :return: The sampled (r, g, b) value.
    """
    min_x, min_y = box.min()
    box_width = box.width()
    box_height = box.height()
    half_width = box_width // 2
    half_height = box_height // 2

    if sample_radius is not None:
        radius_x = max(0, min(sample_radius, half_width))
        radius_y = max(0, min(sample_radius, half_height))
    else:
        radius_x = half_width
        radius_y = half_height

    center_x = min_x + half_width
    center_y = min_y + half_height

    # PIL crop boxes are half-open (right/lower exclusive) while Box corners
    # are inclusive, so the exclusive box edge is min + size. Cropping with
    # the inclusive max would lose a row/column and could pull zero-filled
    # padding into the average.
    box_right = min_x + box_width
    box_lower = min_y + box_height

    left = max(min_x, center_x - radius_x)
    upper = max(min_y, center_y - radius_y)
    # max(left + 1, ...) keeps the window non-empty for a zero radius or a
    # one-pixel box; it then samples just the centre pixel.
    right = min(box_right, max(left + 1, center_x + radius_x))
    lower = min(box_lower, max(upper + 1, center_y + radius_y))

    window = input_image.crop((left, upper, right, lower))
    assert window.size[0] >= radius_x and window.size[1] >= radius_y

    sampled = window.resize((1, 1), downsampler)
    rgb_value = sampled.getpixel((0, 0))
    assert isinstance(rgb_value, tuple) and len(rgb_value) == 3
    return rgb_value
class ImageRef:
    """Mutable holder for an image, so a callee can swap in a replacement."""

    def __init__(self, ref: Image.Image) -> None:
        # Public on purpose: callers read and reassign this attribute directly.
        self.ref: Image.Image = ref
def _crop_inclusive(image: Image.Image, box: Box) -> Image.Image:
    """Crop ``image`` to ``box``, keeping the box's inclusive max row and column."""
    left, upper, right, lower = box.as_tuple()
    # PIL crop boxes exclude their right/lower edge; Box includes it.
    return image.crop((left, upper, right + 1, lower + 1))


def downsample_all(*, input_image: Image.Image, output_image: Optional[ImageRef], down_image: Optional[ImageRef], boxes: List[DownBox], sample_radius: Optional[int], downsampler: Image.Resampling, trim_cropped_edges: bool) -> None:
    """Sample every box and write the colours into the requested outputs.

    :param input_image: Image to sample from.
    :param output_image: If given, each box's area is filled with its
        sampled colour (full-resolution "quantized" result).
    :param down_image: If given, the pixel at each box's down_pos is set to
        its sampled colour (downscaled result).
    :param boxes: Cells to process.
    :param sample_radius: Passed through to downsample_one.
    :param downsampler: Passed through to downsample_one.
    :param trim_cropped_edges: When true, both outputs are replaced (via
        their ``ref`` attribute) by crops limited to the bounds returned by
        get_trimmed.
    """
    assert output_image or down_image

    for box in boxes:
        rgb_value = downsample_one(input_image, box, sample_radius, downsampler)
        if output_image:
            patch = Image.new("RGB", box.dimensions(), rgb_value)
            output_image.ref.paste(patch, box.min())
        if down_image:
            # One cell maps to exactly one pixel; writing it directly does not
            # rely on later boxes overwriting an oversized patch.
            down_image.ref.putpixel(box.down_pos(), rgb_value)

    if trim_cropped_edges:
        full_bounds, down_bounds = get_trimmed(boxes)
        if output_image:
            output_image.ref = _crop_inclusive(output_image.ref, full_bounds)
        if down_image:
            down_image.ref = _crop_inclusive(down_image.ref, down_bounds)
def str2bool(value) -> bool:
    """Parse a command-line boolean.

    Booleans are returned unchanged. Otherwise the value is lower-cased and
    "true"/"1" map to True, "false"/"0" map to False.

    :raises argparse.ArgumentTypeError: For any other spelling.
    """
    if isinstance(value, bool):
        return value

    lowered = value.lower()
    if lowered in ("true", "1"):
        return True
    if lowered in ("false", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def controlled_downscale(*, control_path: str, input_path: str, output_downscaled_path: Optional[str], output_quantized_path: Optional[str], sample_radius: Optional[int], downsampler: Image.Resampling, trim_cropped_edges: bool, output_colorized_full_path: Optional[str], output_colorized_down_path: Optional[str]) -> None:
    """
    Downsample and rescale an image.

    The control image is parsed into a grid of cells; each cell of the input
    image is reduced to one colour, which is written to the requested outputs.

    :param control_path: Path to the control image.
    :param input_path: Path to the input image. It is converted to RGB after loading.
    :param output_downscaled_path: Path to save the output downscaled image.
    :param output_quantized_path: Path to save the output quantized image (downscaled and then upscaled to the original size).
    :param sample_radius: Per-axis radius, in pixels, of the window sampled around each cell's centre (None samples the whole cell).
    :param downsampler: Downsampler to use.
    :param trim_cropped_edges: Drop mapped checker grid elements that are cropped in the control image.
    :param output_colorized_full_path: Colorize the full checker image to debug the checker parsing.
    :param output_colorized_down_path: Colorize the downscaled checker image to debug the checker parsing.
    :raises ValueError: If no output path is given, or the control and input images differ in size.
    """
    if not output_downscaled_path and not output_quantized_path:
        raise ValueError("At least one of output_downscaled_path and output_quantized_path must be specified.")

    # convert() returns a separate, fully loaded image, so the underlying
    # files can be closed as soon as the conversion is done.
    with Image.open(control_path) as control_file:
        control_image = control_file.convert("1")
    with Image.open(input_path) as input_file:
        # downsample_one asserts 3-tuple pixels; converting here lets RGBA,
        # greyscale and palette inputs through instead of failing there.
        input_image = input_file.convert("RGB")

    if control_image.size != input_image.size:
        raise ValueError("Control image and input image must have the same dimensions.")

    extracted_boxes = extract_boxes(control_image)

    # Optional debug renderings of the parsed grid.
    if output_colorized_full_path:
        extracted_boxes.to_colored_checkers(full=True).save(output_colorized_full_path)
    if output_colorized_down_path:
        extracted_boxes.to_colored_checkers(full=False).save(output_colorized_down_path)

    quantized_image: Optional[ImageRef] = None
    if output_quantized_path:
        quantized_image = ImageRef(Image.new("RGB", input_image.size))

    downscaled_image: Optional[ImageRef] = None
    if output_downscaled_path:
        downscaled_image = ImageRef(Image.new("RGB", extracted_boxes.down_dimensions()))

    downsample_all(
        input_image=input_image,
        output_image=quantized_image,
        down_image=downscaled_image,
        boxes=extracted_boxes.boxes(),
        sample_radius=sample_radius,
        downsampler=downsampler,
        trim_cropped_edges=trim_cropped_edges,
    )

    if quantized_image:
        assert output_quantized_path
        quantized_image.ref.save(output_quantized_path)
    if downscaled_image:
        assert output_downscaled_path
        downscaled_image.ref.save(output_downscaled_path)
def main(cli_args: List[str]) -> None:
    """Parse command-line arguments and run controlled_downscale.

    :param cli_args: Argument strings, without the program name.
    """
    arg_parser = argparse.ArgumentParser(description="Downsample and rescale image.")
    arg_parser.add_argument(
        "--control",
        type=str,
        required=True,
        help="Path to control image.",
    )
    arg_parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to input image.",
    )
    arg_parser.add_argument(
        "--output-downscaled",
        type=str,
        help="Path to save the output downscaled image.",
    )
    arg_parser.add_argument(
        "--output-quantized",
        type=str,
        help="Path to save the output quantized image (downscaled and then upscaled to the original size).",
    )
    arg_parser.add_argument(
        "--sample-radius",
        type=int,
        default=None,
        help="Radius for sampling (Manhattan distance).",
    )
    arg_parser.add_argument(
        "--downsampler",
        choices=["box", "bilinear", "bicubic", "hamming", "lanczos"],
        default="box",
        help="Downsampler to use.",
    )
    arg_parser.add_argument(
        "--trim-cropped-edges",
        type=str2bool,
        default=False,
        help="Drop mapped checker grid elements that are cropped in the control image.",
    )
    arg_parser.add_argument(
        "--output-colorized-full",
        type=str,
        help="Colorize the full checker image to debug the checker parsing.",
    )
    arg_parser.add_argument(
        "--output-colorized-down",
        type=str,
        help="Colorize the downscaled checker image to debug the checker parsing.",
    )
    options = arg_parser.parse_args(cli_args)

    # The CLI choices are the lower-case names of Image.Resampling members.
    resampling_name = options.downsampler.upper()
    resampling = Image.Resampling[resampling_name]

    controlled_downscale(
        control_path=options.control,
        input_path=options.input,
        output_downscaled_path=options.output_downscaled,
        output_quantized_path=options.output_quantized,
        sample_radius=options.sample_radius,
        downsampler=resampling,
        trim_cropped_edges=options.trim_cropped_edges,
        output_colorized_full_path=options.output_colorized_full,
        output_colorized_down_path=options.output_colorized_down,
    )
if __name__ == "__main__":
    # Script entry point: forward everything after the program name.
    script_args = sys.argv[1:]
    main(script_args)