Spaces:
Sleeping
Sleeping
import json | |
from tqdm import tqdm | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import xml.etree.ElementTree as ET | |
from xml.dom import minidom | |
import os | |
from PIL import Image | |
import matplotlib.animation as animation | |
import copy | |
from PIL import ImageEnhance | |
import colorsys | |
import matplotlib.colors as mcolors | |
from matplotlib.collections import LineCollection | |
from matplotlib.patheffects import withStroke | |
import random | |
import warnings | |
from matplotlib.figure import Figure | |
from io import BytesIO | |
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter | |
import requests | |
import zipfile | |
import base64 | |
warnings.filterwarnings("ignore") | |
def get_svg_content(svg_path):
    """Return the full text of the SVG file at ``svg_path``.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale encoding, so the result does not depend on the machine the code
    runs on.

    Args:
        svg_path: Path of the SVG file to read.

    Returns:
        The complete file content as one string.
    """
    with open(svg_path, "r", encoding="utf-8") as file:
        return file.read()
def download_file(url, filename, timeout=60):
    """Download ``url`` with an HTTP GET and store the body in ``filename``.

    Args:
        url: Address to fetch.
        filename: Path of the local file to create or overwrite.
        timeout: Seconds to wait for the server to connect or to send data
            before giving up; forwarded to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
            The check runs before the output file is opened, so an error
            page is never saved in place of the expected payload.
    """
    # stream=True keeps the body out of memory; it is written chunk by chunk.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 16):
                f.write(chunk)
def unzip_file(filename, extract_to="."):
    """Extract every member of the zip archive ``filename``.

    Args:
        filename: Path of the zip archive to read.
        extract_to: Directory that receives the extracted members; defaults
            to the current working directory.
    """
    # "r" is ZipFile's default mode, so it is not spelled out here.
    with zipfile.ZipFile(filename) as archive:
        archive.extractall(path=extract_to)
def get_base64_encoded_gif(gif_path):
    """Read the file at ``gif_path`` and return its bytes as base64 text.

    Args:
        gif_path: Path of the file to encode (opened in binary mode).

    Returns:
        The base64 encoding of the file content, as a ``str``.
    """
    with open(gif_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def load_and_pad_img_dir(file_dir, target_size=224):
    """Load an image, scale it to fit a square canvas and pad it with white.

    The image is resized with its aspect ratio preserved so that it fits
    inside ``target_size`` x ``target_size``, then pasted centred onto a
    white RGB canvas of exactly that size.

    Args:
        file_dir: Path of the image file to open (despite the name, this is
            a file path: it is passed straight to ``Image.open``).
        target_size: Side length, in pixels, of the square result.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(target_size, target_size)``.
    """
    with Image.open(file_dir) as source:
        width, height = source.size
        ratio = min(target_size / width, target_size / height)
        # int() truncates: clamp so a very thin image never collapses to a
        # zero-pixel side, which resize() cannot handle.
        new_width = min(target_size, max(1, int(width * ratio)))
        new_height = min(target_size, max(1, int(height * ratio)))
        resized = source.resize((new_width, new_height))

    # Always build the full square canvas and centre on both axes, so the
    # result is square even if truncation left the longer side a pixel short.
    padded_image = Image.new(
        "RGB", (target_size, target_size), (255, 255, 255)
    )
    left_padding = (target_size - new_width) // 2
    top_padding = (target_size - new_height) // 2
    padded_image.paste(resized, (left_padding, top_padding))
    return padded_image
def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="white"):
    """Draw every stroke of ``ink`` on ``ax``, optionally over a dimmed image.

    Each stroke gets its own colour from the "rainbow" colormap (the first
    stroke takes the last colormap entry), darkened by scaling its HSV value
    by 0.65. Along a stroke the alpha fades from 1.0 towards 0.5.

    Args:
        ink: Object exposing ``strokes``; each stroke exposes ``x`` and ``y``
            coordinate sequences.
        ax: Matplotlib axes to draw on. Its limits are set to 0..224 on both
            axes and the y axis is inverted.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown behind the ink at 45 %
            brightness. The image passed in is not modified.
        with_path: When true, each stroke gets an outline path effect of
            width ``lw * 1.25``.
        path_color: Colour of that outline.
    """
    if input_image is not None:
        # enhance() returns a new image, so no defensive copy is needed.
        ax.imshow(ImageEnhance.Brightness(input_image).enhance(0.45))

    stroke_count = len(ink.strokes)
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap accepts
    # the same (name, lut) arguments.
    base_colors = plt.get_cmap("rainbow", stroke_count)

    for i, stroke in enumerate(ink.strokes):
        x, y = np.array(stroke.x), np.array(stroke.y)
        point_count = len(x)

        base_color = base_colors(stroke_count - 1 - i)
        hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
        darker_color = colorsys.hsv_to_rgb(hue, saturation, max(0, value * 0.65))
        colors = [
            mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / point_count))
            for j in range(point_count)
        ]

        # Pair consecutive points into (start, end) line segments.
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        lc = LineCollection(segments, colors=colors, linewidth=lw)
        if with_path:
            lc.set_path_effects(
                [withStroke(linewidth=lw * 1.25, foreground=path_color)]
            )
        ax.add_collection(lc)

    ax.set_xlim(0, 224)
    ax.set_ylim(0, 224)
    ax.invert_yaxis()
def plot_ink_to_video(
    ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30
):
    """Render ``ink`` being drawn segment by segment and save it as a video.

    Frame ``k`` shows the first ``k`` line segments of the ink, taken stroke
    by stroke in order. The animation has one frame per ink point plus one;
    because a stroke has one segment fewer than it has points, the last
    frames all show the complete drawing.

    Args:
        ink: Object exposing ``strokes``; each stroke exposes ``x`` and ``y``
            coordinate sequences.
        output_name: Path of the video file, written with ``FFMpegWriter``.
        lw: Line width of the strokes.
        input_image: Optional PIL image shown behind the ink at 45 %
            brightness. The image passed in is not modified.
        path_color: Colour of the outline drawn around every stroke.
        fps: Frames per second of the written video.
    """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
    try:
        img = None
        if input_image is not None:
            # enhance() returns a new image, so no defensive copy is needed.
            img = ImageEnhance.Brightness(input_image).enhance(0.45)

        def reset_axes():
            # Shared by the initial setup and by every frame after ax.clear().
            if img is not None:
                ax.imshow(img)
            ax.set_xlim(0, 224)
            ax.set_ylim(0, 224)
            ax.invert_yaxis()
            ax.axis("off")

        reset_axes()

        stroke_count = len(ink.strokes)
        # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
        # accepts the same (name, lut) arguments.
        base_colors = plt.get_cmap("rainbow", stroke_count)

        # Segments and colour of a stroke do not depend on the frame, so
        # they are computed once here instead of on every frame.
        prepared_strokes = []
        for stroke_index, stroke in enumerate(ink.strokes):
            x, y = np.array(stroke.x), np.array(stroke.y)
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            base_color = base_colors(stroke_count - 1 - stroke_index)
            hue, saturation, value = colorsys.rgb_to_hsv(*base_color[:3])
            darker_color = colorsys.hsv_to_rgb(
                hue, saturation, max(0, value * 0.65)
            )
            prepared_strokes.append((segments, darker_color))

        all_points = sum(len(stroke.x) for stroke in ink.strokes)

        def update(frame):
            ax.clear()
            reset_axes()
            segments_drawn = 0
            for segments, darker_color in prepared_strokes:
                # Never negative: the loop stops once ``frame`` segments are
                # covered. Slicing past the end yields the whole stroke.
                visible_segments = segments[: frame - segments_drawn]
                visible_count = len(visible_segments)
                if visible_count > 0:
                    colors = [
                        mcolors.to_rgba(
                            darker_color, alpha=1 - (0.5 * j / visible_count)
                        )
                        for j in range(visible_count)
                    ]
                    lc = LineCollection(
                        visible_segments, colors=colors, linewidth=lw
                    )
                    lc.set_path_effects(
                        [withStroke(linewidth=lw * 1.25, foreground=path_color)]
                    )
                    ax.add_collection(lc)
                segments_drawn += len(segments)
                if segments_drawn >= frame:
                    break

        ani = FuncAnimation(fig, update, frames=all_points + 1, blit=False)
        writer = FFMpegWriter(fps=fps)
        fig.tight_layout()
        ani.save(output_name, writer=writer)
    finally:
        # Release the figure even when encoding fails (e.g. ffmpeg missing).
        plt.close(fig)
class Stroke:
    """A single pen stroke stored as parallel ``x`` and ``y`` lists."""

    def __init__(self, list_of_coordinates=None) -> None:
        """Fill ``x``/``y`` from the first two items of each coordinate.

        Args:
            list_of_coordinates: Optional iterable of indexable points;
                ``point[0]`` goes to ``x`` and ``point[1]`` to ``y``. A
                falsy value (``None``, empty list) yields an empty stroke.
        """
        self.x = []
        self.y = []
        for coordinate in list_of_coordinates or ():
            self.x.append(coordinate[0])
            self.y.append(coordinate[1])

    def __len__(self):
        """Return the number of points in the stroke."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at ``index``."""
        return self.x[index], self.y[index]
class Ink:
    """An ordered collection of strokes."""

    def __init__(self, list_of_strokes=None) -> None:
        """Store ``list_of_strokes`` as ``strokes``.

        Args:
            list_of_strokes: Optional list of strokes. A truthy list is kept
                by reference (not copied); a falsy value gives a fresh empty
                list.
        """
        self.strokes = list_of_strokes if list_of_strokes else []

    def __len__(self):
        """Return the number of strokes."""
        return len(self.strokes)

    def __getitem__(self, index):
        """Return the stroke stored at ``index``."""
        return self.strokes[index]
def inkml_to_ink(inkml_file):
    """Convert inkml file to Ink.

    Reads every ``trace`` element (InkML namespace) that is a direct child
    of the document root. The text of a trace is expected to be
    whitespace-separated points, each written as ``x,y``.

    Args:
        inkml_file: Path or file object accepted by ``ET.parse``.

    Returns:
        An ``Ink`` holding one ``Stroke`` per trace, in document order. A
        trace without any text yields an empty stroke.

    Raises:
        ValueError: If a point is not exactly two comma-separated numbers.
    """
    root = ET.parse(inkml_file).getroot()
    inkml_namespace = {"inkml": "http://www.w3.org/2003/InkML"}
    strokes = []
    for trace in root.findall("inkml:trace", inkml_namespace):
        # Element.text is None for an empty element such as <trace/>;
        # treat it like whitespace-only text instead of crashing.
        trace_text = trace.text or ""
        stroke_points = []
        for point in trace_text.split():
            x, y = point.split(",")
            stroke_points.append((float(x), float(y)))
        strokes.append(Stroke(stroke_points))
    return Ink(strokes)
def parse_inkml_annotations(inkml_file):
    """Collect the ``annotation`` elements of an InkML document.

    Args:
        inkml_file: Path or file object accepted by ``ET.parse``.

    Returns:
        A dict mapping each annotation's ``type`` attribute to its text.
        When several annotations share a type, the last one in document
        order wins.
    """
    document_root = ET.parse(inkml_file).getroot()
    annotation_tag = ".//{http://www.w3.org/2003/InkML}annotation"
    return {
        element.get("type"): element.text
        for element in document_root.findall(annotation_tag)
    }
def pregenerate_videos(video_cache_dir):
    """Render a drawing video for every sample that has an inkml file.

    For each dataset / model / query-mode combination, every file in
    ``./derendering_supp/<dataset>/images_sample`` is matched with
    ``./derendering_supp/<model>_<dataset>_inkml/<mode>/<id>.inkml`` (model
    name lower-cased, ``<id>`` being the image file name without its
    extension). When that inkml exists and the target video does not, the
    video is written to
    ``<video_cache_dir>/<Model>_<Dataset>_<mode>_<id>.mp4``.

    Args:
        video_cache_dir: Directory that receives the videos; a ``str`` or a
            ``pathlib.Path``.
    """
    from pathlib import Path  # local import keeps this function self-contained

    video_cache_dir = Path(video_cache_dir)
    datasets = ["IAM", "IMGUR5K", "HierText"]
    models = ["Small-i", "Large-i", "Small-p"]
    query_modes = ["d+t", "r+d", "vanilla"]

    for dataset in datasets:
        # The sample directory only depends on the dataset, so it is checked
        # and listed once rather than for every model/mode pair.
        path = f"./derendering_supp/{dataset}/images_sample"
        if not os.path.exists(path):
            continue
        samples = os.listdir(path)

        for model in models:
            inkml_path_base = f"./derendering_supp/{model.lower()}_{dataset}_inkml"
            for mode in query_modes:
                for name in tqdm(
                    samples, desc=f"Generating {model}-{dataset}-{mode} videos"
                ):
                    # str.strip(".png") would strip any of the characters
                    # '.', 'p', 'n', 'g' from both ends (e.g. "gap.png" ->
                    # "a"); splitext removes only the extension.
                    example_id = os.path.splitext(name)[0]
                    inkml_file = os.path.join(
                        inkml_path_base, mode, f"{example_id}.inkml"
                    )
                    if not os.path.exists(inkml_file):
                        continue

                    video_filename = f"{model}_{dataset}_{mode}_{example_id}.mp4"
                    video_filepath = video_cache_dir / video_filename
                    if video_filepath.exists():
                        continue

                    img = load_and_pad_img_dir(os.path.join(path, name))
                    ink = inkml_to_ink(inkml_file)
                    plot_ink_to_video(ink, str(video_filepath), input_image=img)