|
import io |
|
import os |
|
from datetime import datetime |
|
from typing import Optional |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import tensorflow as tf |
|
from ganime.visualization.images import display_true_pred, unnormalize_if_necessary |
|
|
|
|
|
def get_logdir(parent_folder: str, experiment_name: Optional[str] = None) -> str:
    """Build a timestamped TensorBoard log directory path.

    The returned path is ``parent_folder`` joined with a sub-folder named
    ``<experiment_name>_<YYYY-mm-dd_HH-MM-SS>``. When ``experiment_name`` is
    ``None`` or empty, the sub-folder is just the timestamp. Only the path is
    computed; no directory is created.

    Args:
        parent_folder (str): Folder under which the log directory lives.
        experiment_name (str, optional): Optional experiment name used as a
            prefix of the sub-folder. Defaults to None (timestamp only).

    Returns:
        str: The log directory path, suitable for TensorBoard.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    if not experiment_name:
        # No (or empty) experiment name: the timestamp alone names the run.
        return os.path.join(parent_folder, timestamp)
    return os.path.join(parent_folder, f"{experiment_name}_{timestamp}")
|
|
|
|
|
def plot_to_image(figure):
    """Render a matplotlib figure to an image tensor for ``tf.summary.image``.

    The figure is saved as PNG into an in-memory buffer and then closed. The
    PNG bytes are decoded by TensorFlow into a 4-channel image, and a leading
    batch axis of size 1 is added, giving shape ``(1, height, width, 4)``.

    The supplied figure is closed and inaccessible after this call.

    Args:
        figure: The matplotlib figure to render. ``None`` falls back to
            pyplot's current figure, which matches the previous behaviour of
            saving via ``plt.savefig``.

    Returns:
        tf.Tensor: The decoded PNG with a batch axis, shape
        ``(1, height, width, 4)``.
    """
    fig = plt.gcf() if figure is None else figure
    with io.BytesIO() as buf:
        try:
            # Save the figure we were given explicitly. ``plt.savefig`` would
            # save whatever pyplot considers the *current* figure, which is
            # not necessarily ``figure``.
            fig.savefig(buf, format="png")
        finally:
            # Always release the figure so pyplot does not keep it alive,
            # even if saving failed.
            plt.close(fig)
        # getvalue() returns the whole buffer regardless of position, so no
        # seek(0) is needed.
        png_bytes = buf.getvalue()

    image = tf.image.decode_png(png_bytes, channels=4)
    return tf.expand_dims(image, 0)
|
|
|
|
|
class TensorboardImage(tf.keras.callbacks.Callback):
    """Keras callback that logs true-vs-predicted images to TensorBoard.

    At the end of every epoch, the model predicts on the first ``n_images``
    samples of the training data. If validation data was given, it does the
    same for the validation data. The figure produced by
    ``display_true_pred`` for the targets and predictions is written as an
    image summary.

    Args:
        logdir: Directory the summary file writer writes to.
        train: An ``(X, y)`` pair of training inputs and targets; it is
            unpacked into two values.
        validation: Optional ``(X, y)`` pair of validation inputs and targets.
        n_images: Number of leading samples of each split to log.
    """

    def __init__(
        self,
        logdir: str,
        train: np.array,
        validation: np.array = None,
        n_images: int = 8,
    ):
        super().__init__()
        self.logdir = logdir
        self.train = train
        self.validation = validation
        self.file_writer = tf.summary.create_file_writer(logdir)
        self.n_images = n_images

    def on_epoch_end(self, epoch, logs=None):
        """Log training (and, if available, validation) predictions.

        ``logs`` defaults to ``None`` to match the Keras ``Callback`` signature;
        it is not used.
        """
        self._log_split(self.train, "Training data", epoch)
        if self.validation is not None:
            self._log_split(self.validation, "Validation data", epoch)

    def _log_split(self, data, tag, step):
        """Predict on the first ``n_images`` samples of an ``(X, y)`` pair and log them."""
        X, y = data
        X, y = self.truncate_X_y(X, y, self.n_images)
        pred = self.model.predict(X)
        self.write_to_tensorboard(y, pred, tag, step)

    def truncate_X_y(self, X, y, n_images):
        """Truncate the X and y arrays to the first n_images."""
        X = X[:n_images]
        y = y[:n_images]
        return X, y

    def write_to_tensorboard(self, y_true, y_pred, tag, step):
        """Write the ``display_true_pred`` figure for ``y_true``/``y_pred`` under ``tag``."""
        with self.file_writer.as_default():
            tf.summary.image(
                tag,
                plot_to_image(display_true_pred(y_true, y_pred, n_cols=len(y_true))),
                step=step,
            )
        # Flush so the summary reaches disk now, rather than waiting for the
        # writer's own flush schedule or for it to be closed.
        self.file_writer.flush()
|
|
|
|
|
class TensorboardVideo(tf.keras.callbacks.Callback):
    """Keras callback that logs predicted vs. true videos to TensorBoard.

    At the end of every epoch, the model predicts on the whole ``train`` data,
    and on ``validation`` if it was given. For each split it writes:

    * an animated GIF, encoded with ``ffmpeg``, whose frames show the
      prediction stacked above the target. If ``ffmpeg`` fails, a static
      frame grid is written instead.
    * a static image grid whose rows alternate prediction and target.

    Args:
        logdir: Directory the summary file writer writes to.
        train: Data passed directly to ``model.predict``. It is also indexed
            with ``["y"]`` to obtain the targets (presumably a dict-like
            batch -- TODO confirm against callers).
        validation: Optional validation data with the same structure.
        n_videos: Stored as ``self.n_videos``. NOTE(review): it is not read
            anywhere in this class, so the whole split is currently logged.
    """

    def __init__(
        self,
        logdir: str,
        train: np.array,
        validation: np.array = None,
        n_videos: int = 3,
    ):
        super().__init__()
        self.logdir = logdir
        self.train = train
        self.validation = validation
        self.file_writer = tf.summary.create_file_writer(logdir)
        self.n_videos = n_videos

    def on_epoch_end(self, epoch, logs=None):
        """Log training (and, if available, validation) videos.

        ``logs`` defaults to ``None`` to match the Keras ``Callback`` signature;
        it is not used.
        """
        self._log_split(self.train, "Training data", epoch)
        if self.validation is not None:
            self._log_split(self.validation, "Validation data", epoch)

    def _log_split(self, data, tag, step):
        """Predict on ``data`` and log predictions against ``data["y"]``."""
        pred = self.model.predict(data)
        self.write_to_tensorboard(
            unnormalize_if_necessary(data["y"]),
            pred,
            tag,
            step,
        )

    def truncate_X_y(self, X, y, n_videos):
        """Truncate the X and y arrays to the first n_videos.

        NOTE(review): not called within this class.
        """
        X = X[:n_videos]
        y = y[:n_videos]
        return X, y

    def write_to_tensorboard(self, y_true, y_pred, tag, step):
        """Write the GIF/grid video summary and the image-grid summary for one split."""
        # Videos are (batch, frames, height, width, channels) (see
        # image_summary). Concatenating on axis 2 stacks the prediction above
        # the target inside every frame.
        stacked = tf.concat([y_pred, y_true], axis=2)
        self.video_summary(tag, stacked, step)
        self.image_summary(tag + "/images", y_true, y_pred, step)
        # Flush so the summaries reach disk now, rather than waiting for the
        # writer's own flush schedule or for it to be closed.
        self.file_writer.flush()

    def image_summary(self, tag, y_true, y_pred, step):
        """Write one image with samples as rows and frames left to right.

        Rows alternate: prediction of sample 0, target of sample 0,
        prediction of sample 1, and so on.
        """
        batch, n_frames, height, width, channels = y_true.shape
        images = np.empty(
            (batch * 2, n_frames, height, width, channels), dtype=np.float32
        )

        # Interleave predictions (even rows) and targets (odd rows).
        images[0::2] = y_pred
        images[1::2] = y_true
        # (2B, T, H, W, C) -> (2B, H, T, W, C) -> (2B*H, T*W, C): one sample per
        # row band, frames side by side.
        images = tf.transpose(images, (0, 2, 1, 3, 4))
        images = tf.reshape(images, (height * batch * 2, width * n_frames, channels))

        with self.file_writer.as_default():
            tf.summary.image(tag, [images], step=step)

    def add_red_border(self, image_batch):
        """Return a copy of ``image_batch`` with a 2-pixel red border.

        Assumes ``(N, H, W, C)`` images with RGB channel order (channel 0 set
        to max, channels 1 and 2 set to 0) -- TODO confirm. Float dtypes use a
        max of 1, other dtypes use 255.
        """
        image_batch = image_batch.copy()
        dtype = image_batch.dtype
        min_value = 0
        max_value = 1 if dtype in [np.float16, np.float32, np.float64] else 255

        # Top rows.
        image_batch[:, 0:2, :, 0] = max_value
        image_batch[:, 0:2, :, 1] = min_value
        image_batch[:, 0:2, :, 2] = min_value

        # Bottom rows.
        image_batch[:, -2:, :, 0] = max_value
        image_batch[:, -2:, :, 1] = min_value
        image_batch[:, -2:, :, 2] = min_value

        # Left columns.
        image_batch[:, :, 0:2, 0] = max_value
        image_batch[:, :, 0:2, 1] = min_value
        image_batch[:, :, 0:2, 2] = min_value

        # Right columns.
        image_batch[:, :, -2:, 0] = max_value
        image_batch[:, :, -2:, 1] = min_value
        image_batch[:, :, -2:, 2] = min_value
        return image_batch

    def video_summary(self, name, video, step=None, fps=10):
        """Write ``video`` (B, T, H, W, C) as an animated GIF summary.

        Float videos are assumed to be in [0, 1]; they are scaled to uint8 and
        clipped. Each GIF frame lays the batch out side by side. If GIF
        encoding fails with an ``IOError``/``OSError`` (e.g. ``ffmpeg`` is
        missing), a static grid of all frames is written instead.
        """
        name = tf.constant(name).numpy().decode("utf-8")
        video = np.array(video)
        # Cover every float width (including float16), not just 32/64-bit.
        if np.issubdtype(video.dtype, np.floating):
            video = np.clip(255 * video, 0, 255).astype(np.uint8)
        B, T, H, W, C = video.shape

        with self.file_writer.as_default():
            try:
                # (B, T, H, W, C) -> (T, H, B, W, C) -> (T, H, B*W, C).
                frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
                summary = tf.compat.v1.Summary()
                # Metadata must describe the encoded GIF frame: H x (B*W).
                image = tf.compat.v1.Summary.Image(
                    height=H, width=B * W, colorspace=C
                )
                image.encoded_image_string = self.encode_gif(frames, fps)
                summary.value.add(tag=name + "/gif", image=image)
                tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
            except (IOError, OSError) as e:
                print("GIF summaries require ffmpeg in $PATH.", e)
                # (B, T, H, W, C) -> (1, B*H, T*W, C): samples as rows, frames as columns.
                frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C))
                tf.summary.image(name + "/grid", frames, step)

    def encode_gif(self, frames, fps):
        """Encode uint8 frames of shape (T, H, W, C) to GIF bytes with ffmpeg.

        C must be 1 (gray) or 3 (RGB); any other value raises ``KeyError``.

        Raises:
            IOError: If ffmpeg exits with a non-zero status; the message holds
                the command line and ffmpeg's stderr. A missing ffmpeg binary
                raises ``FileNotFoundError`` (an ``OSError``) from ``Popen``.
        """
        from subprocess import PIPE, Popen

        h, w, c = frames[0].shape
        pxfmt = {1: "gray", 3: "rgb24"}[c]
        cmd = [
            "ffmpeg", "-y",
            "-f", "rawvideo", "-vcodec", "rawvideo",
            "-r", f"{fps:.02f}", "-s", f"{w}x{h}", "-pix_fmt", pxfmt,
            "-i", "-",
            "-filter_complex",
            "[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse",
            "-r", f"{fps:.02f}", "-f", "gif", "-",
        ]
        # ndarray.tostring() was removed in NumPy 2.0; tobytes() is the
        # equivalent (C-order raw bytes).
        raw = b"".join(frame.tobytes() for frame in frames)
        proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
        # communicate() feeds stdin while draining stdout/stderr. Writing all
        # frames to stdin first can deadlock once ffmpeg's output pipe fills.
        out, err = proc.communicate(input=raw)
        if proc.returncode:
            raise IOError("\n".join([" ".join(cmd), err.decode("utf8", errors="replace")]))
        return out
|
|