Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
from argparse import ArgumentParser | |
from io import BytesIO | |
from os import listdir, makedirs | |
from os.path import basename, isdir, join, splitext | |
from random import randint | |
from typing import Union | |
from cairosvg import svg2png | |
import numpy as np | |
from imageio.v3 import imread, imwrite | |
from skimage.transform import rescale | |
from svgpathtools import CubicBezier, Line, QuadraticBezier, disvg, wsvg | |
import onnx | |
import onnxruntime as ort | |
def raster_bezier_hard(all_points, image_width=128, image_height=128, stroke_width=2., colors=None, white_background=True, mark=None):
    """Rasterize Bezier curves by rendering them to SVG and converting that to PNG.

    Args:
        all_points: Array indexed as ``[curve, control_point, xy]``. The x
            values are multiplied by ``image_width`` and the y values by
            ``image_height``; the input itself is left untouched.
        image_width: Width of the rendered image in pixels.
        image_height: Height of the rendered image in pixels.
        stroke_width: Stroke width written into the SVG attributes.
        colors: ``None`` (every curve is drawn black), a single colour
            ``[r, g, b, a]`` that is applied to every curve, or one such row
            per curve. The first three components are multiplied by 255
            before being written as ``rgb(...)``.
        white_background: Render on a white background if true, otherwise no
            background colour is passed to the PNG converter.
        mark: Optional two-element point. Its x/y are scaled like the curve
            points and it is drawn as a blue line from ``mark - stroke_width``
            to ``mark + stroke_width``.

    Returns:
        A tuple ``(image, scaled_points)``: the decoded PNG as a ``float32``
        array divided by 255 with the channel axis moved to the front, and
        the scaled copy of ``all_points``.
    """
    curve_count = len(all_points)
    if colors is None:
        # Default: opaque black for every curve.
        rgba = np.tile([0., 0., 0., 1.], (curve_count, 1))
    else:
        rgba = np.array(colors)
        if rgba.ndim == 1:
            # A single colour was given: repeat it for every curve. The
            # original `colors is list` test compared against the type object
            # itself, was always False, and sent this case down a path that
            # failed on 2-D indexing.
            rgba = np.tile(rgba, (curve_count, 1))
    rgba[:, :3] *= 255
    stroke_colors = ["rgb(" + ",".join(map(str, color[:3])) + ")" for color in rgba]

    background_color = "white" if white_background else None

    # Work on a copy so the caller's control points are not scaled in place.
    all_points = np.array(all_points)
    all_points[:, :, 0] *= image_width
    all_points[:, :, 1] *= image_height

    bezier_curves = [numpy_to_bezier(points) for points in all_points]
    attributes = [
        {"stroke": stroke_colors[i], "stroke-width": str(stroke_width), "fill": "none"}
        for i in range(len(bezier_curves))
    ]

    if mark is not None:
        mark = np.array(mark)
        mark[0] *= image_width
        mark[1] *= image_height
        # Two points -> numpy_to_bezier builds a straight line through the mark.
        mark_points = np.vstack([mark - stroke_width, mark + stroke_width])
        bezier_curves.append(numpy_to_bezier(mark_points))
        attributes.append({"stroke": "blue", "stroke-width": str(stroke_width * 2), "fill": "blue"})

    svg_attributes = {"width": f"{image_width}px", "height": f"{image_height}px"}
    svg_string = disvg(bezier_curves, attributes=attributes, svg_attributes=svg_attributes, paths2Drawing=True).tostring()
    png_string = svg2png(bytestring=svg_string, background_color=background_color)
    image = imread(BytesIO(png_string), extension=".png")

    output = image.astype("float32")
    output /= 255
    # Move the channel axis to the front.
    output = np.moveaxis(output, 2, 0)
    return output, all_points
def diff_remaining_img(raster_img: np.ndarray, recons_img: np.ndarray):
    """Return a copy of ``raster_img`` with already-reconstructed pixels set to 1.

    Both images are binarized for the comparison only (every value below 1
    becomes 0). Wherever the binarized images agree, the corresponding entry
    of the returned copy is set to 1; all other entries keep the value from
    ``raster_img``.

    Neither input is modified. The previous implementation binarized
    ``recons_img`` in place, which silently changed the caller's array.

    Args:
        raster_img: The image that should be reconstructed.
        recons_img: The rendering of what has been reconstructed so far.

    Returns:
        A new array with the shape of ``raster_img``.
    """
    binarized_raster = np.where(raster_img < 1, 0., raster_img)
    binarized_recons = np.where(recons_img < 1, 0., recons_img)
    remaining_img = raster_img.copy()
    remaining_img[binarized_raster == binarized_recons] = 1
    return remaining_img
def place_point_on_img(image, point):
    """Paint a marker for ``point`` into ``image`` in place and return ``image``.

    ``point`` is indexed as ``(x, y)``; ``image`` is indexed as
    ``[channel, y, x]``. When the first axis of ``image`` has length 3 the
    marker sets channels 0 and 1 to 0 and channel 2 to 1; otherwise channel 0
    is set to 0.5.
    """
    truncated = point.astype(int)
    if np.any(point == truncated):
        # At least one coordinate is a whole number: floor/ceil would give an
        # empty slice for it, so span exactly one pixel starting there.
        lower, upper = truncated, truncated + 1
    else:
        lower = np.floor(point).astype(int)
        upper = np.ceil(point).astype(int)

    rows = slice(lower[1], upper[1])
    cols = slice(lower[0], upper[0])

    if image.shape[0] == 3:
        image[:2, rows, cols] = 0
        image[2, rows, cols] = 1
    else:
        image[0, rows, cols] = 0.5
    return image
def rgb_to_grayscale(image: np.ndarray): | |
image = image[0] * .2989 + image[1] *.587 + image[2] *.114 | |
return image | |
def sample_black_pixel(image: np.ndarray):
    """Pick a random dark pixel and return its position as ``(x, y)`` fractions.

    A pixel counts as dark when its grayscale value is not within 0.5 of 1
    (as judged by ``np.isclose``). The chosen row index is divided by the
    grayscale image's first dimension and the column index by its second,
    then the pair is reordered to ``(column, row)``.

    Raises:
        ValueError: from ``randint`` when there is no dark pixel to choose.
    """
    gray = rgb_to_grayscale(image.copy())
    white = np.ones_like(gray, dtype="float32")
    is_dark = ~np.isclose(gray, white, atol=0.5)
    candidates = np.argwhere(is_dark)

    picked = candidates[randint(0, len(candidates) - 1)].astype("float32")
    row_count, col_count = gray.shape[0], gray.shape[1]
    picked[0] /= row_count
    picked[1] /= col_count
    # argwhere yields (row, column); swap to (x, y).
    return picked[[1, 0]]
def numpy_to_bezier(points: np.ndarray):
    """Build an svgpathtools segment from 2, 3 or 4 control points.

    Two points give a ``Line``, three a ``QuadraticBezier`` and four a
    ``CubicBezier``. Each point's first two entries become the real and
    imaginary parts of a complex coordinate. Any other number of points
    yields ``None``.
    """
    segment_by_point_count = {2: Line, 3: QuadraticBezier, 4: CubicBezier}
    segment_type = segment_by_point_count.get(len(points))
    if segment_type is None:
        return None
    control_points = [complex(point[0], point[1]) for point in points]
    return segment_type(*control_points)
def center_on_point(image, point, new_width=None, new_height=None):
    """Shift a channel-first image so that ``point`` sits at the padding offset.

    ``point`` holds ``(x, y)`` as fractions of the image width and height.
    The image is padded with ones by half its height/width on every side and
    a window of ``new_height`` x ``new_width`` (defaulting to the input size)
    is cut out of the padded canvas.

    Returns:
        The cropped view of the padded image.
    """
    _, height, width = image.shape
    crop_width = width if new_width is None else new_width
    crop_height = height if new_height is None else new_height

    pad_x = round(width / 2)
    pad_y = round(height / 2)

    pixel = point.copy()
    pixel[0] *= width
    pixel[1] *= height
    pixel = pixel.round().astype(int)

    # In the padded canvas the window starts at the point's own pixel
    # coordinates, which places the point (pad_x, pad_y) pixels into the crop.
    left, top = pixel[0], pixel[1]

    canvas = np.pad(image, ((0, 0), (pad_y, pad_y), (pad_x, pad_x)), constant_values=1)
    return canvas[:, top:top + crop_height, left:left + crop_width]
def reverse_center_on_point(paths, point):
    """Shift every path in place by ``point[i] - 0.5`` along x and y.

    ``paths`` is indexed as ``[path, control_point, xy]`` and ``point`` as
    ``[path, xy]``. Nothing is returned; ``paths`` is modified directly.
    """
    for index, path in enumerate(paths):
        centre = point[index]
        path[:, 0] -= 0.5 - centre[0]
        path[:, 1] -= 0.5 - centre[1]
def save_as_svg(curves: np.ndarray, filename, img_width, img_height, stroke_width=2.0):
    """Write ``curves`` to ``filename`` as black, round-capped, unfilled SVG paths.

    Each entry of ``curves`` is converted with ``numpy_to_bezier``; the SVG's
    width and height attributes are set to ``img_width`` and ``img_height``
    in pixels.
    """
    path_style = {
        "stroke": "black",
        "stroke-width": stroke_width,
        "stroke-linecap": "round",
        "fill": "none",
    }
    svg_paths = list(map(numpy_to_bezier, curves))
    wsvg(
        svg_paths,
        attributes=[path_style] * len(svg_paths),
        svg_attributes={"width": f"{img_width}px", "height": f"{img_height}px"},
        filename=filename,
    )
def save_as_png(filename: str, image: np.ndarray):
    """Write a channel-first image to ``filename`` as 8-bit data.

    The values are multiplied by 255, rounded and cast to ``uint8`` on a
    copy, so ``image`` itself is not changed.
    """
    pixels = np.moveaxis(image, 0, 2).copy()
    np.multiply(pixels, 255, out=pixels)
    imwrite(filename, np.round(pixels).astype("uint8"))
def setup_model(model_path):
    """Validate the ONNX model at ``model_path`` and open an inference session.

    The model is loaded once with ``onnx`` solely to run the checker, then an
    ONNX Runtime session is created from the same path.

    Args:
        model_path: Path to the ``*.onnx`` file.

    Returns:
        The ``onnxruntime.InferenceSession`` for the model.
    """
    onnx.checker.check_model(onnx.load(model_path))
    # Providers are listed in order of preference: use CUDA when it is
    # available and fall back to the CPU explicitly on machines without it.
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    return ort.InferenceSession(model_path, providers=providers)
def vectorize_image(input_image_path, model: Union[str, ort.InferenceSession], output=None, threshold_ratio=0.1, stroke_width=0.512, width=512, height=512, binarization_threshold=0, force_grayscale=False):
    """Vectorize a raster image curve by curve and save the result as an SVG.

    This is a generator: it must be iterated for any work to happen. After
    each reconstructed curve it yields a rendering of all curves found so
    far, with the channel axis last. The SVG file is written once the loop
    has finished, i.e. only when the generator is exhausted.

    Args:
        input_image_path: Path of the raster image to vectorize.
        model: Path to an ONNX model file or an existing
            ``onnxruntime.InferenceSession``.
        output: ``None`` (write ``<input stem>.svg`` to the working
            directory), a path ending in ``.svg`` (used as the file name), or
            a directory (created if missing) that receives ``<input stem>.svg``.
        threshold_ratio: Reconstruction stops once the number of remaining
            values below 0.5 is no larger than this fraction of the initial
            count.
        stroke_width: Stroke width used for rendering and in the SVG.
        width: Ignored; overwritten by the model's input shape.
        height: Ignored; overwritten by the model's input shape.
        binarization_threshold: If greater than 0, every value below it is
            set to 0 before reconstruction.
        force_grayscale: Accepted for compatibility; not used.

    Yields:
        The current reconstruction rendered by ``raster_bezier_hard``.

    Raises:
        ValueError: If ``model`` is neither a string nor an inference session.
    """
    if isinstance(model, str):
        ort_sess = setup_model(model)
    elif isinstance(model, ort.InferenceSession):
        ort_sess = model
    else:
        raise ValueError("Invalid value for the model argument")

    # Get dimensions expected by the model
    _, channels, height, width = ort_sess.get_inputs()[0].shape

    input_image = imread(input_image_path, pilmode="RGB") / 255
    original_height, original_width, _ = input_image.shape

    # Scale the longer side to the model's size, then pad with white
    # (the canvas is initialised with ones) to the expected dimensions.
    if original_height >= original_width:
        scale = height / original_height
    else:
        scale = width / original_width
    print(f"Rescale factor: {scale}")
    input_image = rescale(input_image, scale, channel_axis=2, order=5)

    raster_img = np.ones((height, width, channels), dtype="float32")
    raster_img[:input_image.shape[0], :input_image.shape[1]] = input_image
    # convert to CHW
    raster_img = np.moveaxis(raster_img, 2, 0)
    if binarization_threshold > 0:
        raster_img[raster_img < binarization_threshold] = 0.

    curve_pixels = (raster_img < .5).sum()
    threshold = curve_pixels * threshold_ratio
    print(f"Reconstruction candidate pixels: {curve_pixels}")
    print(f"Reconstruction threshold: {threshold.astype(int)}")

    recons_points = None
    recons_img = np.ones_like(raster_img, dtype="float32")
    remaining_img = raster_img.copy()
    while (remaining_img < .5).sum() > threshold:
        remaining_img = diff_remaining_img(raster_img, recons_img)
        try:
            mark = sample_black_pixel(remaining_img)
        except ValueError:
            # No dark pixel is left to sample from.
            break

        # Mark the sampled pixel and shift the image so the mark is centred.
        mark_real = mark.copy()
        mark_real[0] *= width
        mark_real[1] *= height
        centered_img = place_point_on_img(remaining_img.copy(), mark_real)
        centered_img = center_on_point(centered_img, mark)

        result = ort_sess.run(None, {"marked_raster_image": np.expand_dims(centered_img, 0)})
        points = result[0]
        # Undo the centring shift on the predicted control points (in place).
        reverse_center_on_point(points, np.expand_dims(mark, 0))
        points = np.expand_dims(points, 1)

        if recons_points is None:
            recons_points = points
        else:
            recons_points = np.concatenate((recons_points, points), axis=1)

        recons_img, _ = raster_bezier_hard(recons_points.squeeze(0), image_width=width, image_height=height, stroke_width=stroke_width)
        yield np.moveaxis(recons_img, 0, 2)

    if recons_points is None:
        # Nothing was reconstructed (for example a blank image), so there is
        # no curve to save; previously this crashed on `None.squeeze`.
        print("No curves were reconstructed; no SVG file was written.")
        return

    output_filepath = _resolve_svg_output_path(input_image_path, output)

    # Map the normalised control points back to the original image's pixels.
    recons_points = recons_points.squeeze(0)
    recons_points[:, :, 0] *= width * (1 / scale)
    recons_points[:, :, 1] *= height * (1 / scale)
    save_as_svg(recons_points, output_filepath, original_width, original_height, stroke_width=stroke_width)


def _resolve_svg_output_path(input_image_path, output):
    """Return the SVG path to write for ``input_image_path`` given ``output``."""
    default_name = splitext(basename(input_image_path))[0] + ".svg"
    if not output:
        return default_name
    if isinstance(output, str) and output.endswith(".svg") and not isdir(output):
        return output
    # Anything else is treated as a directory and created when missing.
    makedirs(output, exist_ok=True)
    return join(output, default_name)
def main():
    """Command-line entry point: vectorize the given images with an ONNX model.

    Parses the arguments, seeds the random number generators unless the seed
    is negative, collects the input images from ``-i`` or ``-d`` and runs
    ``vectorize_image`` on each of them. Exits with status 1 when neither
    input option yields a list of images.
    """
    # Imported here because only `randint` is imported at module level.
    import random

    parser = ArgumentParser(description="Inference script for the marked curve reconstruction model in ONNX format.")
    parser.add_argument("model", metavar="FIlE", help="path to the *.onnx file")
    parser.add_argument("-i", "--input_images", nargs="*", metavar="FILE", help="one or multiple paths to raster images that should be vectorized.")
    parser.add_argument("-d", "--input_dir", metavar="DIR", help="path to a directory of raster images that should be vectorized.")
    parser.add_argument("-o", "--output", help="optional output directory or file")
    parser.add_argument("--threshold_ratio", "-t", default=0.1, type=float, help="The ratio of black pixels which need to be reconstructed before the algorithm terminates")
    parser.add_argument("--stroke_width", "-r", default=0.512, type=float, help="stroke width if it should be different from the one specified in the model")
    # type=int: without it a user-supplied seed is a string and `>= 0` fails.
    parser.add_argument("--seed", "-s", default=1234, type=int, help="Fixed random number generation seed. Set to negative number to deactivate")
    parser.add_argument("-b", "--binarization_threshold", default=0., type=float, help="Set to a value in (0,1) to binarize the image.")
    args = parser.parse_args()

    if args.seed >= 0:
        np.random.seed(args.seed)
        # Pixel sampling uses the standard library's `randint`, so that
        # generator has to be seeded as well for reproducible runs.
        random.seed(args.seed)

    if args.input_images is not None:
        input_images = args.input_images
    elif args.input_dir is not None and isdir(args.input_dir):
        input_images = [join(args.input_dir, f) for f in listdir(args.input_dir)]
    else:
        print("-i or -d need to be passed")
        raise SystemExit(1)

    for input_image in input_images:
        steps = vectorize_image(input_image, args.model, output=args.output, threshold_ratio=args.threshold_ratio, stroke_width=args.stroke_width, binarization_threshold=args.binarization_threshold, force_grayscale=False)
        # vectorize_image is a generator: it only runs (and only writes the
        # SVG) when it is iterated to the end. The intermediate renderings
        # are not needed on the command line.
        for _ in steps:
            pass
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()