"""Helpers for loading inks/images and rendering them as plots or videos."""

import base64
import colorsys
import copy
import json
import os
import random
import warnings
import xml.etree.ElementTree as ET
import zipfile
from io import BytesIO
from pathlib import Path
from xml.dom import minidom

import matplotlib.animation as animation
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import requests
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter
from matplotlib.collections import LineCollection
from matplotlib.figure import Figure
from matplotlib.patheffects import withStroke
from PIL import Image
from PIL import ImageEnhance
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Side length (in pixels) of the square canvas that images are padded to and
# that inks are plotted on.
_CANVAS_SIZE = 224


def get_svg_content(svg_path):
    """Return the full text content of the file at ``svg_path``.

    The file is decoded as UTF-8 so the result does not depend on the
    platform's default locale encoding.
    """
    with open(svg_path, "r", encoding="utf-8") as file:
        return file.read()


def download_file(url, filename):
    """Download ``url`` with an HTTP GET and write the body to ``filename``.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so
            an error page is never silently stored as the requested file.
        requests.Timeout: if the server stops responding for 60 seconds.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open(filename, "wb") as f:
        f.write(response.content)


def unzip_file(filename, extract_to="."):
    """Extract every member of the zip archive ``filename`` into ``extract_to``."""
    with zipfile.ZipFile(filename, "r") as zip_ref:
        zip_ref.extractall(extract_to)


def get_base64_encoded_gif(gif_path):
    """Return the bytes of the file at ``gif_path`` as a base64 ``str``."""
    with open(gif_path, "rb") as gif_file:
        return base64.b64encode(gif_file.read()).decode("utf-8")


def load_and_pad_img_dir(file_dir):
    """Load an image, fit it inside a 224x224 square and pad it with white.

    The image is scaled (aspect ratio preserved) so that its longer side is
    224 pixels, then pasted centred onto a white 224x224 RGB canvas.

    Args:
        file_dir: Path of the image file to load.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode of size exactly 224x224.
    """
    # The context manager releases the underlying file handle once the
    # resized copy has been produced.
    with Image.open(file_dir) as image:
        width, height = image.size
        ratio = min(_CANVAS_SIZE / width, _CANVAS_SIZE / height)
        # Clamp to at least one pixel: a very elongated image would
        # otherwise be truncated to a zero-sized side.
        new_width = max(1, int(width * ratio))
        new_height = max(1, int(height * ratio))
        resized = image.resize((new_width, new_height))

    # Always pad in both directions: float truncation can leave the "long"
    # side one pixel short of 224, and the output must still be square.
    left_padding = (_CANVAS_SIZE - new_width) // 2
    top_padding = (_CANVAS_SIZE - new_height) // 2
    padded_image = Image.new(
        "RGB", (_CANVAS_SIZE, _CANVAS_SIZE), (255, 255, 255)
    )
    padded_image.paste(resized, (left_padding, top_padding))
    return padded_image


def _dim_background(input_image):
    """Return a copy of ``input_image`` with brightness scaled to 45%."""
    img = copy.deepcopy(input_image)
    return ImageEnhance.Brightness(img).enhance(0.45)


def _stroke_segments(stroke):
    """Return the stroke's consecutive point pairs as an (n-1, 2, 2) array."""
    points = np.array([np.array(stroke.x), np.array(stroke.y)]).T.reshape(
        -1, 1, 2
    )
    return np.concatenate([points[:-1], points[1:]], axis=1)


def _darkened_stroke_color(colormap, stroke_count, stroke_index):
    """Return the colormap colour for a stroke with its HSV value cut to 65%.

    Strokes are mapped onto the colormap in reverse order (the first stroke
    takes the last colour).
    """
    base_color = colormap(stroke_count - 1 - stroke_index)
    hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
    return colorsys.hsv_to_rgb(hue, saturation, max(0, value * 0.65))


def _stroke_colormap(stroke_count):
    """Return the "rainbow" colormap resampled to one entry per stroke."""
    # ``plt.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
    # 3.9; ``plt.get_cmap`` accepts the same (name, lut) arguments.
    return plt.get_cmap("rainbow", max(1, stroke_count))


def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="white"):
    """Draw ``ink`` onto the Matplotlib axes ``ax``.

    Each stroke gets its own darkened rainbow colour, fading from opaque at
    its start towards half-transparent at its end.

    Args:
        ink: Object exposing ``strokes``; every stroke exposes ``x`` and ``y``
            coordinate sequences.
        ax: Matplotlib axes to draw on.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown (darkened) behind the ink.
        with_path: Whether to outline the strokes with ``path_color``.
        path_color: Colour of the outline drawn around each stroke.
    """
    if input_image is not None:
        ax.imshow(_dim_background(input_image))

    stroke_count = len(ink.strokes)
    base_colors = _stroke_colormap(stroke_count)

    for i, stroke in enumerate(ink.strokes):
        point_count = len(stroke.x)
        segments = _stroke_segments(stroke)
        if len(segments) == 0:
            # Fewer than two points: there is no line to draw.
            continue
        darker_color = _darkened_stroke_color(base_colors, stroke_count, i)
        colors = [
            mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / point_count))
            for j in range(len(segments))
        ]
        lc = LineCollection(segments, colors=colors, linewidth=lw)
        if with_path:
            lc.set_path_effects(
                [withStroke(linewidth=lw * 1.25, foreground=path_color)]
            )
        ax.add_collection(lc)

    ax.set_xlim(0, _CANVAS_SIZE)
    ax.set_ylim(0, _CANVAS_SIZE)
    ax.invert_yaxis()


def plot_ink_to_video(
    ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30
):
    """Render ``ink`` being drawn segment by segment and save it as a video.

    Args:
        ink: Object exposing ``strokes``; every stroke exposes ``x`` and ``y``
            coordinate sequences.
        output_name: Path of the video file to write (encoded via ffmpeg).
        lw: Line width of the strokes.
        input_image: Optional PIL image shown (darkened) behind the ink.
        path_color: Colour of the outline drawn around each stroke.
        fps: Frames per second of the output video.
    """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
    try:
        img = _dim_background(input_image) if input_image is not None else None

        def reset_axes():
            """Restore background, limits and hidden axes after a clear."""
            if img is not None:
                ax.imshow(img)
            ax.set_xlim(0, _CANVAS_SIZE)
            ax.set_ylim(0, _CANVAS_SIZE)
            ax.invert_yaxis()
            ax.axis("off")

        reset_axes()

        # Per-stroke geometry and colour do not depend on the frame, so they
        # are computed once instead of on every animation step.
        stroke_count = len(ink.strokes)
        base_colors = _stroke_colormap(stroke_count)
        all_segments = [_stroke_segments(stroke) for stroke in ink.strokes]
        all_colors = [
            _darkened_stroke_color(base_colors, stroke_count, index)
            for index in range(stroke_count)
        ]
        all_points = sum(len(stroke.x) for stroke in ink.strokes)

        def update(frame):
            """Redraw the first ``frame`` segments of the ink."""
            ax.clear()
            reset_axes()

            points_drawn = 0
            for segments, darker_color in zip(all_segments, all_colors):
                remaining = frame - points_drawn
                visible_segments = (
                    segments[:remaining]
                    if remaining < len(segments)
                    else segments
                )
                if len(visible_segments) > 0:
                    colors = [
                        mcolors.to_rgba(
                            darker_color,
                            alpha=1 - (0.5 * j / len(visible_segments)),
                        )
                        for j in range(len(visible_segments))
                    ]
                    lc = LineCollection(
                        visible_segments, colors=colors, linewidth=lw
                    )
                    lc.set_path_effects(
                        [withStroke(linewidth=lw * 1.25, foreground=path_color)]
                    )
                    ax.add_collection(lc)

                points_drawn += len(segments)
                if points_drawn >= frame:
                    # Later strokes have not started yet in this frame.
                    break

        ani = FuncAnimation(fig, update, frames=all_points + 1, blit=False)
        writer = FFMpegWriter(fps=fps)
        plt.tight_layout()
        ani.save(output_name, writer=writer)
    finally:
        # Close the figure even if encoding fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)


class Stroke:
    """A single pen stroke stored as parallel ``x`` and ``y`` lists."""

    def __init__(self, list_of_coordinates=None) -> None:
        """Build a stroke from an optional iterable of ``(x, y, ...)`` points.

        Only the first two items of each point are stored.
        """
        self.x = []
        self.y = []
        if list_of_coordinates:
            for point in list_of_coordinates:
                self.x.append(point[0])
                self.y.append(point[1])

    def __len__(self):
        """Return the number of points in the stroke."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the point at ``index`` as an ``(x, y)`` tuple."""
        return (self.x[index], self.y[index])


class Ink:
    """An ordered collection of strokes."""

    def __init__(self, list_of_strokes=None) -> None:
        """Wrap ``list_of_strokes`` (stored by reference) or start empty."""
        self.strokes = []
        if list_of_strokes:
            self.strokes = list_of_strokes

    def __len__(self):
        """Return the number of strokes."""
        return len(self.strokes)

    def __getitem__(self, index):
        """Return the stroke at ``index``."""
        return self.strokes[index]


def inkml_to_ink(inkml_file):
    """Convert inkml file to Ink.

    Every ``<trace>`` directly under the root becomes one ``Stroke``. Trace
    text is read as whitespace-separated points, each written as ``x,y``.
    A trace without any text yields an empty stroke.

    Args:
        inkml_file: Path or file object accepted by ``ElementTree.parse``.

    Returns:
        An ``Ink`` holding the parsed strokes in document order.
    """
    tree = ET.parse(inkml_file)
    root = tree.getroot()

    inkml_namespace = {"inkml": "http://www.w3.org/2003/InkML"}

    strokes = []
    for trace in root.findall("inkml:trace", inkml_namespace):
        # ``trace.text`` is None for an empty element such as ``<trace/>``.
        points = (trace.text or "").strip().split()
        stroke_points = []
        for point in points:
            x, y = point.split(",")
            stroke_points.append((float(x), float(y)))
        strokes.append(Stroke(stroke_points))
    return Ink(strokes)


def parse_inkml_annotations(inkml_file):
    """Return a dict mapping each annotation's ``type`` attribute to its text.

    All ``<annotation>`` elements anywhere in the document are collected; when
    a type occurs more than once the last occurrence wins.
    """
    root = ET.parse(inkml_file).getroot()
    annotations = root.findall(".//{http://www.w3.org/2003/InkML}annotation")
    return {
        annotation.get("type"): annotation.text for annotation in annotations
    }


def pregenerate_videos(video_cache_dir):
    """Render every missing sample video into ``video_cache_dir``.

    For each dataset / model / query-mode combination, every image in
    ``./derendering_supp/<dataset>/images_sample`` that has a matching
    ``.inkml`` file is rendered to
    ``<model>_<dataset>_<mode>_<example_id>.mp4`` unless that file already
    exists.

    Args:
        video_cache_dir: Directory (``str`` or ``pathlib.Path``) receiving the
            videos; it is created if missing.
    """
    video_cache_dir = Path(video_cache_dir)
    video_cache_dir.mkdir(parents=True, exist_ok=True)

    datasets = ["IAM", "IMGUR5K", "HierText"]
    models = ["Small-i", "Large-i", "Small-p"]
    query_modes = ["d+t", "r+d", "vanilla"]

    for dataset in datasets:
        # The sample directory depends only on the dataset, so it is checked
        # and listed once rather than for every model/mode pair.
        path = f"./derendering_supp/{dataset}/images_sample"
        if not os.path.exists(path):
            continue
        samples = os.listdir(path)

        for model in models:
            inkml_path_base = f"./derendering_supp/{model.lower()}_{dataset}_inkml"
            for mode in query_modes:
                for name in tqdm(
                    samples, desc=f"Generating {model}-{dataset}-{mode} videos"
                ):
                    # ``str.strip(".png")`` would remove any of the characters
                    # ".", "p", "n", "g" from both ends of the name; only the
                    # literal extension must go.
                    if name.endswith(".png"):
                        example_id = name[: -len(".png")]
                    else:
                        example_id = name
                    inkml_file = os.path.join(
                        inkml_path_base, mode, f"{example_id}.inkml"
                    )
                    if not os.path.exists(inkml_file):
                        continue
                    video_filename = f"{model}_{dataset}_{mode}_{example_id}.mp4"
                    video_filepath = video_cache_dir / video_filename
                    if not video_filepath.exists():
                        img_path = os.path.join(path, name)
                        img = load_and_pad_img_dir(img_path)
                        ink = inkml_to_ink(inkml_file)
                        plot_ink_to_video(ink, str(video_filepath), input_image=img)