Spaces:
Build error
Build error
# python3.7 | |
"""Utility functions for visualizing results.""" | |
import base64 | |
import os.path | |
import cv2 | |
import numpy as np | |
from bs4 import BeautifulSoup | |
# Names exported by `from <this module> import *`. Helpers defined in this file
# but not listed here (e.g. `get_sortable_html_header()`) are left out of the
# wildcard export.
__all__ = [
    'get_grid_shape', 'get_blank_image', 'load_image', 'save_image',
    'resize_image', 'postprocess_image', 'add_text_to_image',
    'parse_image_size', 'fuse_images', 'HtmlPageVisualizer', 'HtmlPageReader',
    'VideoReader', 'VideoWriter'
]
def get_grid_shape(size, row=0, col=0, is_portrait=False):
    """Computes a `(height, width)` grid layout holding exactly `size` cells.

    A positive `row` (or `col`) is honored whenever it evenly divides `size`.
    If both are given but their product differs from `size`, both hints are
    discarded. When no usable hint remains, the most square-like factorization
    of `size` is picked: landscape (height <= width) by default, e.g.
    `size = 16` gives `(4, 4)` and `size = 15` gives `(3, 5)`; portrait
    (height >= width) if `is_portrait` is set.

    Args:
        size: Total number of cells (height * width) of the target grid.
        row: Preferred number of rows. Non-positive means "not set".
            (default: 0)
        col: Preferred number of columns. Non-positive means "not set".
            (default: 0)
        is_portrait: Whether to return a portrait shape instead of a landscape
            one. Only affects the automatically derived layout.
            (default: False)

    Returns:
        A two-element tuple, representing height and width respectively.
    """
    for value in (size, row, col):
        assert isinstance(value, int)

    if size == 0:
        return (0, 0)

    # Contradictory hints: forget both and fall back to automatic layout.
    if row > 0 and col > 0 and row * col != size:
        row, col = 0, 0

    if row > 0 and size % row == 0:
        return (row, size // row)
    if col > 0 and size % col == 0:
        return (size // col, col)

    # Walk down from floor(sqrt(size)) to the largest divisor, which yields
    # the factor pair closest to a square.
    short_side = int(np.sqrt(size))
    long_side = col
    while short_side > 0:
        if size % short_side == 0:
            long_side = size // short_side
            break
        short_side -= 1

    if is_portrait:
        return (long_side, short_side)
    return (short_side, long_side)
def get_blank_image(height, width, channels=3, is_black=True):
    """Creates a uniformly filled image, either all black or all white.

    NOTE: The returned image has dtype `np.uint8`, i.e. pixel range [0, 255].

    Args:
        height: Height of the returned image.
        width: Width of the returned image.
        channels: Number of channels. (default: 3)
        is_black: Whether to return a black image (all 0) rather than a white
            one (all 255). (default: True)

    Returns:
        A `np.ndarray` with shape `(height, width, channels)`.
    """
    fill_value = 0 if is_black else 255
    return np.full((height, width, channels), fill_value, dtype=np.uint8)
def load_image(path, image_channels=3):
    """Loads an image from disk.

    NOTE: This function will always return an image with `RGB` channel order for
    color image and pixel range [0, 255].

    Args:
        path: Path to load the image from.
        image_channels: Number of image channels of returned image, either 1 or
            3. This field is employed since `cv2.imread()` will always return
            a 3-channel image, even for grayscale image. (default: 3)

    Returns:
        An image with dtype `np.ndarray`, or `None` if `path` does not exist or
            the file cannot be decoded as an image.
    """
    if not os.path.isfile(path):
        return None

    assert image_channels in [1, 3]

    image = cv2.imread(path)
    if image is None:
        # `cv2.imread()` reports an unreadable/unsupported file by returning
        # `None` instead of raising, so treat it like a missing image.
        return None
    assert image.ndim == 3 and image.shape[2] == 3
    if image_channels == 1:
        # Keep a single channel, preserving the trailing channel axis.
        return image[:, :, 0:1]
    # Flip the channel axis: BGR (OpenCV order) -> RGB.
    return image[:, :, ::-1]
def save_image(path, image):
    """Writes an image to disk; does nothing when `image` is `None`.

    NOTE: The input image (if colorful) is assumed to be with `RGB` channel
    order and pixel range [0, 255].

    Args:
        path: Path to save the image to.
        image: Image to save, with shape `(height, width, channels)` where
            `channels` is 1 or 3.
    """
    if image is not None:
        assert image.ndim == 3 and image.shape[2] in [1, 3]
        # Reverse the channel axis so that OpenCV receives `BGR` data.
        bgr_image = image[:, :, ::-1]
        cv2.imwrite(path, bgr_image)
def resize_image(image, *args, **kwargs):
    """Resizes an image while keeping it three-dimensional (`HWC`).

    This is a thin wrapper around `cv2.resize()`; all extra positional and
    keyword arguments are forwarded to it unchanged.

    NOTE: The channel order of the input image will not be changed.

    Args:
        image: Image to resize, with shape `(height, width, channels)` where
            `channels` is 1 or 3. `None` is passed through.
        *args: Positional arguments forwarded to `cv2.resize()`.
        **kwargs: Keyword arguments forwarded to `cv2.resize()`.

    Returns:
        The resized image with three dimensions, or `None` if `image` is
            `None`.
    """
    if image is None:
        return None

    assert image.ndim == 3 and image.shape[2] in [1, 3]
    resized = cv2.resize(image, *args, **kwargs)
    # If the result comes back two-dimensional, restore the channel axis.
    return resized[:, :, np.newaxis] if resized.ndim == 2 else resized
def postprocess_image(image, min_val=-1.0, max_val=1.0, data_format='NCHW'):
    """Maps an image to pixel range [0, 255] with dtype `uint8`.

    Values are linearly rescaled from `[min_val, max_val]` to `[0, 255]`,
    shifted by 0.5, clipped and cast to `uint8`.

    NOTE: The returned image will always be channel-last, i.e. `NHWC` for
    batched input and `HWC` for a single image.

    Args:
        image: The input image(s), as a `np.ndarray`.
        min_val: Minimum value of the input image. (default: -1.0)
        max_val: Maximum value of the input image. (default: 1.0)
        data_format: Data format of the input image, case-insensitive.
            Supporting `NCHW`, `NHWC`, `CHW`, `HWC`. (default: `NCHW`)

    Returns:
        The post-processed image.

    Raises:
        NotImplementedError: If the input `data_format` is not supported.
    """
    assert isinstance(image, np.ndarray)

    pixels = image.astype(np.float64)
    pixels = (pixels - min_val) * 255 / (max_val - min_val)
    pixels = np.clip(pixels + 0.5, 0, 255).astype(np.uint8)

    layout = data_format.upper()
    # layout -> (expected ndim, channel axis, axes order to reach channel-last;
    # `None` means the layout is already channel-last).
    known_layouts = {
        'NCHW': (4, 1, (0, 2, 3, 1)),
        'NHWC': (4, 3, None),
        'CHW': (3, 0, (1, 2, 0)),
        'HWC': (3, 2, None),
    }
    if layout not in known_layouts:
        raise NotImplementedError(f'Data format `{layout}` is not supported!')

    expected_ndim, channel_axis, axes_order = known_layouts[layout]
    assert pixels.ndim == expected_ndim and pixels.shape[channel_axis] in [1, 3]
    if axes_order is None:
        return pixels
    return pixels.transpose(*axes_order)
def add_text_to_image(image,
                      text='',
                      position=None,
                      font=cv2.FONT_HERSHEY_TRIPLEX,
                      font_size=1.0,
                      line_type=cv2.LINE_8,
                      line_width=1,
                      color=(255, 255, 255)):
    """Overlays text on given image.

    NOTE: The input image is assumed to be with `RGB` channel order. The text
    is drawn onto `image` in place, and the same object is returned.

    Args:
        image: The image to overlay text on.
        text: Text content to overlay on the image. (default: '')
        position: Target position, as `(x, y)`, of the bottom-left corner of
            the text. If not set, center of the image will be used by default.
            (default: None)
        font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
        font_size: Font size of the text added. (default: 1.0)
        line_type: Line type used to depict the text. (default: cv2.LINE_8)
        line_width: Line width used to depict the text. (default: 1)
        color: Color of the text added in `RGB` channel order. (default:
            (255, 255, 255))

    Returns:
        An image with target text overlayed on. The input is returned
            untouched if it is `None` or if `text` is empty.
    """
    if image is None or not text:
        return image

    if position is None:
        # `cv2.putText()` needs explicit coordinates, so apply the documented
        # default here: anchor the text at the image center, as (x, y).
        position = (image.shape[1] // 2, image.shape[0] // 2)

    cv2.putText(img=image,
                text=text,
                org=position,
                fontFace=font,
                fontScale=font_size,
                color=color,
                thickness=line_width,
                lineType=line_type,
                bottomLeftOrigin=False)
    return image
def parse_image_size(obj):
    """Parses object to a pair of image size, i.e., (width, height).

    Supported inputs:
        - `None` or an empty (or whitespace-only) string: parsed as `(0, 0)`.
        - An integer (Python `int` or NumPy integer scalar): used for both
          width and height.
        - A list, tuple or `np.ndarray` with zero, one or two elements.
        - A comma-separated string such as `'256'` or `'512, 256'`.

    Negative values are clamped to 0.

    Args:
        obj: The input object to parse image size from.

    Returns:
        A two-element tuple, indicating image width and height respectively.

    Raises:
        ValueError: If the input is of an unsupported type, holds more than two
            elements, or is a string containing a non-integer field.
    """
    # Strings are handled before any `==` comparison: comparing an
    # `np.ndarray` against `''` is evaluated elementwise and cannot be used as
    # a plain boolean condition.
    if obj is None:
        numbers = ()
    elif isinstance(obj, str):
        compact = obj.replace(' ', '')
        numbers = tuple(map(int, compact.split(','))) if compact else ()
    elif isinstance(obj, (int, np.integer)):
        numbers = (obj,)
    elif isinstance(obj, (list, tuple, np.ndarray)):
        numbers = tuple(obj)
    else:
        raise ValueError(f'Invalid type of input: {type(obj)}!')

    if len(numbers) > 2:
        raise ValueError('At most two elements for image size.')

    if not numbers:
        width = height = 0
    elif len(numbers) == 1:
        width = height = numbers[0]
    else:
        width, height = numbers
    return (max(0, width), max(0, height))
def fuse_images(images,
                image_size=None,
                row=0,
                col=0,
                is_row_major=True,
                is_portrait=False,
                row_spacing=0,
                col_spacing=0,
                border_left=0,
                border_right=0,
                border_top=0,
                border_bottom=0,
                black_background=True):
    """Fuses a collection of images into an entire image.

    Args:
        images: A collection of images to fuse. Should be with shape [num,
            height, width, channels].
        image_size: This field is used to resize the image before fusion. `0`
            disables resizing. (default: None)
        row: Number of rows used for image fusion. If not set, this field will
            be automatically assigned based on `col` and total number of images.
            (default: 0)
        col: Number of columns used for image fusion. If not set, this field
            will be automatically assigned based on `row` and total number of
            images. (default: 0)
        is_row_major: Whether the input images should be arranged row-major or
            column-major. (default: True)
        is_portrait: Only active when both `row` and `col` should be assigned
            automatically. (default: False)
        row_spacing: Space between rows. (default: 0)
        col_spacing: Space between columns. (default: 0)
        border_left: Width of left border. (default: 0)
        border_right: Width of right border. (default: 0)
        border_top: Width of top border. (default: 0)
        border_bottom: Width of bottom border. (default: 0)
        black_background: Whether spacing and borders are filled with black
            instead of white. (default: True)

    Returns:
        The fused image, or `None` if `images` is `None`.

    Raises:
        ValueError: If the input `images` is not with shape [num, height, width,
            channels].
    """
    if images is None:
        return images

    if images.ndim != 4:
        raise ValueError(f'Input `images` should be with shape [num, height, '
                         f'width, channels], but {images.shape} is received!')

    num, image_height, image_width, channels = images.shape
    width, height = parse_image_size(image_size)
    height = height or image_height
    width = width or image_width
    row, col = get_grid_shape(num, row=row, col=col, is_portrait=is_portrait)

    # `max(..., 0)` keeps the spacing term non-negative for an empty grid.
    fused_height = (height * row + row_spacing * max(row - 1, 0) +
                    border_top + border_bottom)
    fused_width = (width * col + col_spacing * max(col - 1, 0) +
                   border_left + border_right)
    fused_image = get_blank_image(
        fused_height, fused_width, channels=channels, is_black=black_background)

    if is_row_major:
        images = images.reshape(row, col, image_height, image_width, channels)
    else:
        # Column-major input lists one full column after another, so group by
        # column first and then swap axes to make `images[i, j]` address
        # (row i, column j). Reshaping to (row, col) before transposing only
        # works for square grids.
        images = images.reshape(col, row, image_height, image_width, channels)
        images = images.transpose(1, 0, 2, 3, 4)

    need_resize = height != image_height or width != image_width
    for i in range(row):
        y = border_top + i * (height + row_spacing)
        for j in range(col):
            x = border_left + j * (width + col_spacing)
            image = images[i, j]
            if need_resize:
                image = cv2.resize(image, (width, height))
                if image.ndim == 2:
                    # Restore the channel axis that gets dropped when a
                    # single-channel image is resized.
                    image = image[:, :, np.newaxis]
            fused_image[y:y + height, x:x + width] = image

    return fused_image
def get_sortable_html_header(column_name_list, sort_by_ascending=False):
    """Builds the leading part of a sortable html page.

    The page consists of a table whose rows can be re-ordered by clicking a
    column head; every click flips the sorting direction. The returned string
    covers the script, the style sheet and the table head, and ends with the
    opening `<tbody>` tag so that table rows can be appended directly.

    Example:

        column_name_list = [name_1, name_2, name_3]
        header = get_sortable_html_header(column_name_list)
        footer = get_sortable_html_footer()
        sortable_table = ...
        html_page = header + sortable_table + footer

    Args:
        column_name_list: List of column header names.
        sort_by_ascending: Default sorting order. If set as `True`, the html
            page will be sorted by ascending order when the header is clicked
            for the first time. (default: False)

    Returns:
        A string, which represents for the header for a sortable html page.
    """
    # Python `True`/`False` -> javascript `true`/`false`.
    js_flag = str(sort_by_ascending).lower()

    script_lines = [
        '<script type="text/javascript">',
        'var column_idx;',
        'var sort_by_ascending = ' + js_flag + ';',
        '',
        'function sorting(tbody, column_idx){',
        ' this.column_idx = column_idx;',
        ' Array.from(tbody.rows)',
        ' .sort(compareCells)',
        ' .forEach(function(row) { tbody.appendChild(row); })',
        ' sort_by_ascending = !sort_by_ascending;',
        '}',
        '',
        'function compareCells(row_a, row_b) {',
        ' var val_a = row_a.cells[column_idx].innerText;',
        ' var val_b = row_b.cells[column_idx].innerText;',
        ' var flag = sort_by_ascending ? 1 : -1;',
        ' return flag * (val_a > val_b ? 1 : -1);',
        '}',
        '</script>',
    ]
    head_lines = [
        '<html>',
        '',
        '<head>',
        '<style>',
        ' table {',
        ' border-spacing: 0;',
        ' border: 1px solid black;',
        ' }',
        ' th {',
        ' cursor: pointer;',
        ' }',
        ' th, td {',
        ' text-align: left;',
        ' vertical-align: middle;',
        ' border-collapse: collapse;',
        ' border: 0.5px solid black;',
        ' padding: 8px;',
        ' }',
        ' tr:nth-child(even) {',
        ' background-color: #d2d2d2;',
        ' }',
        '</style>',
        '</head>',
    ]
    table_open_lines = [
        '<body>',
        '',
        '<table>',
        '<thead>',
        '<tr>',
    ]
    # Sections are separated by a blank line; the trailing empty entry makes
    # the joined text end with a newline.
    preamble = '\n'.join(
        script_lines + [''] + head_lines + [''] + table_open_lines + [''])

    # One clickable head cell per column, wired to its column index.
    head_cells = ''.join(
        f' <th onclick="sorting(tbody, {idx})">{name}</th>\n'
        for idx, name in enumerate(column_name_list))

    return preamble + head_cells + '</tr>\n</thead>\n<tbody id="tbody">\n'
def get_sortable_html_footer():
    """Builds the trailing part of a sortable html page.

    It closes the table body, the table, the page body and the html document.
    Check function `get_sortable_html_header()` for more details.

    Returns:
        A string, which represents for the footer for a sortable html page.
    """
    closing_lines = ('</tbody>', '</table>', '', '</body>', '</html>', '')
    return '\n'.join(closing_lines)
def encode_image_to_html_str(image, image_size=None):
    """Encodes an image to html language.

    The image is JPEG-compressed, base64-encoded and wrapped into an `<img>`
    tag with an inline data URI.

    NOTE: Input image is always assumed to be with `RGB` channel order.

    Args:
        image: The input image to encode. Should be with `RGB` channel order.
        image_size: This field is used to resize the image before encoding. `0`
            disables resizing. (default: None)

    Returns:
        A string which represents the encoded image, or an empty string if
            `image` is `None`.
    """
    if image is None:
        return ''

    assert image.ndim == 3 and image.shape[2] in [1, 3]

    # Change channel order to `BGR`, which is opencv-friendly.
    image = image[:, :, ::-1]

    # Resize the image if needed. A zero side falls back to the original one.
    width, height = parse_image_size(image_size)
    if height or width:
        height = height or image.shape[0]
        width = width or image.shape[1]
        image = cv2.resize(image, (width, height))

    # Encode the image to html-format string. `tobytes()` replaces the
    # deprecated `ndarray.tostring()` and returns the same raw bytes.
    encoded_image = cv2.imencode('.jpg', image)[1].tobytes()
    encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8')
    html_str = f'<img src="data:image/jpeg;base64, {encoded_image_base64}"/>'

    return html_str
def decode_html_str_to_image(html_str, image_size=None):
    """Recovers an image from its html (base64) representation.

    Everything after the last comma of `html_str` is treated as the
    base64-encoded image file content.

    Args:
        html_str: Image string parsed from html.
        image_size: This field is used to resize the image after decoding. `0`
            disables resizing. (default: None)

    Returns:
        An image with `RGB` channel order, or `None` if `html_str` is empty.
    """
    if not html_str:
        return None
    assert isinstance(html_str, str)

    # Strip the `data:...;base64,` prefix (if any) and decode the payload.
    payload = html_str.rpartition(',')[2]
    raw_bytes = base64.b64decode(payload)
    byte_buffer = np.frombuffer(raw_bytes, dtype=np.uint8)
    image = cv2.imdecode(byte_buffer, flags=cv2.IMREAD_COLOR)

    # Resize the image if needed. A zero side falls back to the decoded one.
    target_width, target_height = parse_image_size(image_size)
    if target_height or target_width:
        target_height = target_height or image.shape[0]
        target_width = target_width or image.shape[1]
        image = cv2.resize(image, (target_width, target_height))

    # Reverse the channel axis: `BGR` (opencv order) -> `RGB`.
    return image[:, :, ::-1]
class HtmlPageVisualizer(object):
    """Defines the html page visualizer.

    This class collects image results in a table and writes them out as a
    sortable html page, relying on the helper functions
    `get_sortable_html_header()`, `get_sortable_html_footer()`, and
    `encode_image_to_html_str()`. Specifying the following fields is enough to
    create a visualization page:

    (1) num_rows: Number of rows of the table (header-row exclusive).
    (2) num_cols: Number of columns of the table.
    (3) header contents (optional): Title of each column.

    NOTE: `grid_size` can be used to assign `num_rows` and `num_cols`
    automatically.

    Example:

        html = HtmlPageVisualizer(num_rows, num_cols)
        html.set_headers([...])
        for i in range(num_rows):
            for j in range(num_cols):
                html.set_cell(i, j, text=..., image=..., highlight=False)
        html.save('visualize.html')
    """

    def __init__(self,
                 num_rows=0,
                 num_cols=0,
                 grid_size=0,
                 is_portrait=True,
                 viz_size=None):
        """Creates an empty table.

        Args:
            num_rows: Number of rows of the table. (default: 0)
            num_cols: Number of columns of the table. (default: 0)
            grid_size: Total number of cells. If positive, the table shape is
                derived with `get_grid_shape()`, using `num_rows` and
                `num_cols` as hints. (default: 0)
            is_portrait: Forwarded to `get_grid_shape()` when `grid_size` is
                used. (default: True)
            viz_size: Size that every image is resized to before being
                embedded, parsed with `parse_image_size()`. (default: None)
        """
        if grid_size > 0:
            num_rows, num_cols = get_grid_shape(
                grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait)
        assert num_rows > 0 and num_cols > 0

        self.num_rows = num_rows
        self.num_cols = num_cols
        self.viz_size = parse_image_size(viz_size)
        self.headers = [''] * self.num_cols

        # Every cell owns its own dictionary, so cells never share state.
        self.cells = []
        for _ in range(self.num_rows):
            blank_row = []
            for _ in range(self.num_cols):
                blank_row.append({
                    'text': '',
                    'image': '',
                    'highlight': False,
                })
            self.cells.append(blank_row)

    def set_header(self, col_idx, content):
        """Sets the content of a particular header by column index."""
        self.headers[col_idx] = content

    def set_headers(self, contents):
        """Sets the contents of all headers.

        Args:
            contents: A list or tuple with one entry per column. A single
                string is treated as a one-element list.
        """
        if isinstance(contents, str):
            contents = [contents]
        assert isinstance(contents, (list, tuple))
        assert len(contents) == self.num_cols
        for col_idx in range(len(contents)):
            self.set_header(col_idx, contents[col_idx])

    def set_cell(self, row_idx, col_idx, text='', image=None, highlight=False):
        """Sets the content of a particular cell.

        Basically, a cell contains some text as well as an image. Both text and
        image can be empty.

        Args:
            row_idx: Row index of the cell to edit.
            col_idx: Column index of the cell to edit.
            text: Text to add into the target cell. (default: '')
            image: Image to show in the target cell. Should be with `RGB`
                channel order. (default: None)
            highlight: Whether to highlight this cell. (default: False)
        """
        cell = self.cells[row_idx][col_idx]
        cell['text'] = text
        # The image is stored in its html-encoded form.
        cell['image'] = encode_image_to_html_str(image, self.viz_size)
        cell['highlight'] = bool(highlight)

    def save(self, save_path):
        """Saves the html page.

        Args:
            save_path: Path of the html file to write.
        """
        pieces = []
        for row_cells in self.cells[:self.num_rows]:
            pieces.append('<tr>\n')
            for cell in row_cells[:self.num_cols]:
                text = cell['text']
                image = cell['image']
                # Highlighted cells get a reddish background.
                color = ' bgcolor="#FF8888"' if cell['highlight'] else ''
                if text:
                    pieces.append(
                        f' <td{color}>{text}<br><br>{image}</td>\n')
                else:
                    pieces.append(f' <td{color}>{image}</td>\n')
            pieces.append('</tr>\n')
        html = ''.join(pieces)

        header = get_sortable_html_header(self.headers)
        footer = get_sortable_html_footer()

        with open(save_path, 'w') as f:
            f.write(header + html + footer)
class HtmlPageReader(object):
    """Defines the html page reader.

    This class can be used to parse results from the visualization page
    generated by `HtmlPageVisualizer`.

    Example:

        html = HtmlPageReader(html_path)
        for j in range(html.num_cols):
            header = html.get_header(j)
        for i in range(html.num_rows):
            for j in range(html.num_cols):
                text = html.get_text(i, j)
                image = html.get_image(i, j, image_size=None)
    """

    def __init__(self, html_path):
        """Initializes by loading the content from file.

        Args:
            html_path: Path to the html page to parse.

        Raises:
            ValueError: If `html_path` does not point to an existing file.
        """
        self.html_path = html_path
        if not os.path.isfile(html_path):
            raise ValueError(f'File `{html_path}` does not exist!')

        # Load content.
        with open(html_path, 'r') as f:
            self.html = BeautifulSoup(f, 'html.parser')

        # Parse headers.
        thead = self.html.find('thead')
        self.headers = [header.text for header in thead.find_all('th')]
        self.num_cols = len(self.headers)

        # Parse cells.
        tbody = self.html.find('tbody')
        self.cells = []
        for row in tbody.find_all('tr'):
            row_cells = []
            for cell in row.find_all('td'):
                # A cell may carry text only; store an empty image string in
                # that case instead of failing on the missing `<img>` tag.
                img = cell.find('img')
                row_cells.append({
                    'text': cell.text,
                    'image': img['src'] if img is not None else '',
                })
            assert len(row_cells) == self.num_cols
            self.cells.append(row_cells)
        self.num_rows = len(self.cells)

    def get_header(self, j):
        """Gets header for a particular column."""
        return self.headers[j]

    def get_text(self, i, j):
        """Gets text from a particular cell."""
        return self.cells[i][j]['text']

    def get_image(self, i, j, image_size=None):
        """Gets image from a particular cell.

        Args:
            i: Row index of the cell.
            j: Column index of the cell.
            image_size: Forwarded to `decode_html_str_to_image()` to resize the
                decoded image. (default: None)

        Returns:
            The result of `decode_html_str_to_image()` on the stored image
                string of the cell.
        """
        return decode_html_str_to_image(self.cells[i][j]['image'], image_size)
class VideoReader(object):
    """Defines the video reader.

    This class can be used to read frames from a given video.
    """

    def __init__(self, path):
        """Initializes the video reader by loading the video from disk.

        Args:
            path: Path to the video file.

        Raises:
            ValueError: If `path` does not point to an existing file.
        """
        if not os.path.isfile(path):
            raise ValueError(f'Video `{path}` does not exist!')

        self.path = path
        self.video = cv2.VideoCapture(path)
        assert self.video.isOpened()
        self.position = 0  # Index of the frame the next `read()` returns.

        self.length = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))
        self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.fps = self.video.get(cv2.CAP_PROP_FPS)

    def __del__(self):
        """Releases the opened video."""
        # `__del__` also runs when `__init__` raised before `self.video` was
        # assigned (e.g. missing file), so the attribute may not exist.
        video = getattr(self, 'video', None)
        if video is not None:
            video.release()

    def read(self, position=None):
        """Reads a certain frame.

        NOTE: The returned frame is assumed to be with `RGB` channel order.

        Args:
            position: Optional. If set, the reader will read frames from the
                exact position. Otherwise, the reader will read next frames.
                A position that is not smaller than the video length is
                ignored. (default: None)

        Returns:
            The frame, or `None` if no frame could be read.
        """
        if position is not None and position < self.length:
            self.video.set(cv2.CAP_PROP_POS_FRAMES, position)
            self.position = position

        success, frame = self.video.read()
        self.position = self.position + 1

        # Reverse the channel axis of the frame delivered by OpenCV.
        return frame[:, :, ::-1] if success else None
class VideoWriter(object):
    """Defines the video writer.

    This class can be used to create a video.

    NOTE: `.avi` and `DIVX` is the most recommended codec format since it does
    not rely on other dependencies.
    """

    def __init__(self, path, frame_height, frame_width, fps=24, codec='DIVX'):
        """Creates the video writer.

        Args:
            path: Path of the video file to write.
            frame_height: Height of each frame.
            frame_width: Width of each frame.
            fps: Frames per second of the video. (default: 24)
            codec: Four-character codec code, unpacked character by character
                into `cv2.VideoWriter_fourcc()`. (default: `DIVX`)
        """
        self.path = path
        self.frame_height = frame_height
        self.frame_width = frame_width
        self.fps = fps
        self.codec = codec

        # OpenCV takes the frame size as (width, height).
        self.video = cv2.VideoWriter(filename=path,
                                     fourcc=cv2.VideoWriter_fourcc(*codec),
                                     fps=fps,
                                     frameSize=(frame_width, frame_height))

    def __del__(self):
        """Releases the opened video."""
        # `__del__` also runs when `__init__` raised before `self.video` was
        # assigned (e.g. invalid codec), so the attribute may not exist.
        video = getattr(self, 'video', None)
        if video is not None:
            video.release()

    def write(self, frame):
        """Writes a target frame.

        NOTE: The input frame is assumed to be with `RGB` channel order.

        Args:
            frame: The frame to append to the video.
        """
        # Reverse the channel axis so that OpenCV receives `BGR` data.
        self.video.write(frame[:, :, ::-1])