from typing import TYPE_CHECKING, Optional, Callable, Dict, List, Union from ditk import logging from easydict import EasyDict from matplotlib import pyplot as plt from matplotlib import animation import os import numpy as np import torch import wandb import pickle import treetensor.numpy as tnp from ding.framework import task from ding.envs import BaseEnvManagerV2 from ding.utils import DistributedWriter from ding.torch_utils import to_ndarray from ding.utils.default_helper import one_time_warning if TYPE_CHECKING: from ding.framework import OnlineRLContext, OfflineRLContext def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable: """ Overview: Create an online RL tensorboard logger for recording training and evaluation metrics. Arguments: - record_train_iter (:obj:`bool`): Whether to record training iteration. Default is False. - train_show_freq (:obj:`int`): Frequency of showing training logs. Default is 100. Returns: - _logger (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input. Raises: - RuntimeError: If writer is None. - NotImplementedError: If the key of train_output is not supported, such as "scalars". Examples: >>> task.use(online_logger(record_train_iter=False, train_show_freq=1000)) """ if task.router.is_active and not task.has_role(task.role.LEARNER): return task.void() writer = DistributedWriter.get_instance() if writer is None: raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.") last_train_show_iter = -1 def _logger(ctx: "OnlineRLContext"): if task.finish: writer.close() nonlocal last_train_show_iter if not np.isinf(ctx.eval_value): if record_train_iter: writer.add_scalar('basic/eval_episode_return_mean-env_step', ctx.eval_value, ctx.env_step) writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter) else: writer.add_scalar('basic/eval_episode_return_mean', ctx.eval_value, ctx.env_step) if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq: last_train_show_iter = ctx.train_iter if isinstance(ctx.train_output, List): output = ctx.train_output.pop() # only use latest output for some algorithms, like PPO else: output = ctx.train_output for k, v in output.items(): if k in ['priority', 'td_error_priority']: continue if "[scalars]" in k: new_k = k.split(']')[-1] raise NotImplementedError elif "[histogram]" in k: new_k = k.split(']')[-1] writer.add_histogram(new_k, v, ctx.env_step) if record_train_iter: writer.add_histogram(new_k, v, ctx.train_iter) else: if record_train_iter: writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) writer.add_scalar('basic/train_{}-env_step'.format(k), v, ctx.env_step) else: writer.add_scalar('basic/train_{}'.format(k), v, ctx.env_step) return _logger def offline_logger(train_show_freq: int = 100) -> Callable: """ Overview: Create an offline RL tensorboard logger for recording training and evaluation metrics. Arguments: - train_show_freq (:obj:`int`): Frequency of showing training logs. Defaults to 100. Returns: - _logger (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input. Raises: - RuntimeError: If writer is None. - NotImplementedError: If the key of train_output is not supported, such as "scalars". Examples: >>> task.use(offline_logger(train_show_freq=1000)) """ if task.router.is_active and not task.has_role(task.role.LEARNER): return task.void() writer = DistributedWriter.get_instance() if writer is None: raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.") last_train_show_iter = -1 def _logger(ctx: "OfflineRLContext"): nonlocal last_train_show_iter if task.finish: writer.close() if not np.isinf(ctx.eval_value): writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter) if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq: last_train_show_iter = ctx.train_iter output = ctx.train_output for k, v in output.items(): if k in ['priority']: continue if "[scalars]" in k: new_k = k.split(']')[-1] raise NotImplementedError elif "[histogram]" in k: new_k = k.split(']')[-1] writer.add_histogram(new_k, v, ctx.train_iter) else: writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) return _logger # four utility functions for wandb logger def softmax(logit: np.ndarray) -> np.ndarray: v = np.exp(logit) return v / v.sum(axis=-1, keepdims=True) def action_prob(num, action_prob, ln): ax = plt.gca() ax.set_ylim([0, 1]) for rect, x in zip(ln, action_prob[num]): rect.set_height(x) return ln def return_prob(num, return_prob, ln): return ln def return_distribution(episode_return): num = len(episode_return) max_return = max(episode_return) min_return = min(episode_return) hist, bins = np.histogram(episode_return, bins=np.linspace(min_return - 50, max_return + 50, 6)) gap = (max_return - min_return + 100) / 5 x_dim = ['{:.1f}'.format(min_return - 50 + gap * x) for x in range(5)] return hist / num, x_dim def wandb_online_logger( record_path: str = None, cfg: Union[dict, EasyDict] = None, exp_config: Union[dict, EasyDict] = None, metric_list: Optional[List[str]] = None, env: Optional[BaseEnvManagerV2] = None, model: Optional[torch.nn.Module] = None, anonymous: bool = False, project_name: str = 'default-project', run_name: str = None, wandb_sweep: bool = False, ) -> Callable: """ Overview: Wandb visualizer to track the experiment. Arguments: - record_path (:obj:`str`): The path to save the replay of simulation. - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings: - gradient_logger: boolean. Whether to track the gradient. - plot_logger: boolean. Whether to track the metrics like reward and loss. - video_logger: boolean. Whether to upload the rendering video replay. - action_logger: boolean. `q_value` or `action probability`. - return_logger: boolean. Whether to track the return value. - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies. - env (:obj:`BaseEnvManagerV2`): Evaluator environment. - model (:obj:`nn.Module`): Policy neural network model. - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \ of data without wandb count. - project_name (:obj:`str`): The name of wandb project. - run_name (:obj:`str`): The name of wandb run. - wandb_sweep (:obj:`bool`): Whether to use wandb sweep. ''' Returns: - _plot (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input. """ if task.router.is_active and not task.has_role(task.role.LEARNER): return task.void() color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"] if metric_list is None: metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"] # Initialize wandb with default settings # Settings can be covered by calling wandb.init() at the top of the script if exp_config: if not wandb_sweep: if run_name is not None: if anonymous: wandb.init(project=project_name, config=exp_config, reinit=True, name=run_name, anonymous="must") else: wandb.init(project=project_name, config=exp_config, reinit=True, name=run_name) else: if anonymous: wandb.init(project=project_name, config=exp_config, reinit=True, anonymous="must") else: wandb.init(project=project_name, config=exp_config, reinit=True) else: if run_name is not None: if anonymous: wandb.init(project=project_name, config=exp_config, name=run_name, anonymous="must") else: wandb.init(project=project_name, config=exp_config, name=run_name) else: if anonymous: wandb.init(project=project_name, config=exp_config, anonymous="must") else: wandb.init(project=project_name, config=exp_config) else: if not wandb_sweep: if run_name is not None: if anonymous: wandb.init(project=project_name, reinit=True, name=run_name, anonymous="must") else: wandb.init(project=project_name, reinit=True, name=run_name) else: if anonymous: wandb.init(project=project_name, reinit=True, anonymous="must") else: wandb.init(project=project_name, reinit=True) else: if run_name is not None: if anonymous: wandb.init(project=project_name, name=run_name, anonymous="must") else: wandb.init(project=project_name, name=run_name) else: if anonymous: wandb.init(project=project_name, anonymous="must") else: wandb.init(project=project_name) plt.switch_backend('agg') if cfg is None: cfg = EasyDict( dict( gradient_logger=False, plot_logger=True, video_logger=False, action_logger=False, return_logger=False, ) ) else: if not isinstance(cfg, EasyDict): cfg = EasyDict(cfg) for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]: if key not in cfg.keys(): cfg[key] = False # The visualizer is called to save the replay of the simulation # which will be uploaded to wandb later if env is not None and cfg.video_logger is True and record_path is not None: env.enable_save_replay(replay_path=record_path) if cfg.gradient_logger: wandb.watch(model, log="all", log_freq=100, log_graph=True) else: one_time_warning( "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config." ) first_plot = True def _plot(ctx: "OnlineRLContext"): nonlocal first_plot if first_plot: first_plot = False ctx.wandb_url = wandb.run.get_project_url() info_for_logging = {} if cfg.plot_logger: for metric in metric_list: if isinstance(ctx.train_output, Dict) and metric in ctx.train_output: if isinstance(ctx.train_output[metric], torch.Tensor): info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()}) else: info_for_logging.update({metric: ctx.train_output[metric]}) elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]: metric_value_list = [] for item in ctx.train_output: if isinstance(item[metric], torch.Tensor): metric_value_list.append(item[metric].cpu().detach().numpy()) else: metric_value_list.append(item[metric]) metric_value = np.mean(metric_value_list) info_for_logging.update({metric: metric_value}) else: one_time_warning( "If you want to use wandb to visualize the result, please set plot_logger = True in the config." ) if ctx.eval_value != -np.inf: if hasattr(ctx, "eval_value_min"): info_for_logging.update({ "episode return min": ctx.eval_value_min, }) if hasattr(ctx, "eval_value_max"): info_for_logging.update({ "episode return max": ctx.eval_value_max, }) if hasattr(ctx, "eval_value_std"): info_for_logging.update({ "episode return std": ctx.eval_value_std, }) if hasattr(ctx, "eval_value"): info_for_logging.update({ "episode return mean": ctx.eval_value, }) if hasattr(ctx, "train_iter"): info_for_logging.update({ "train iter": ctx.train_iter, }) if hasattr(ctx, "env_step"): info_for_logging.update({ "env step": ctx.env_step, }) eval_output = ctx.eval_output['output'] episode_return = ctx.eval_output['episode_return'] episode_return = np.array(episode_return) if len(episode_return.shape) == 2: episode_return = episode_return.squeeze(1) if cfg.video_logger: if 'replay_video' in ctx.eval_output: # save numpy array "images" of shape (N,1212,3,224,320) to N video files in mp4 format # The numpy tensor must be either 4 dimensional or 5 dimensional. # Channels should be (time, channel, height, width) or (batch, time, channel, height width) video_images = ctx.eval_output['replay_video'] video_images = video_images.astype(np.uint8) info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)}) elif record_path is not None: file_list = [] for p in os.listdir(record_path): if os.path.splitext(p)[-1] == ".mp4": file_list.append(p) file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn))) video_path = os.path.join(record_path, file_list[-2]) info_for_logging.update({"video": wandb.Video(video_path, format="mp4")}) if cfg.action_logger: action_path = os.path.join(record_path, (str(ctx.env_step) + "_action.gif")) if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"): if isinstance(eval_output, tnp.ndarray): action_prob = softmax(eval_output.logit) else: action_prob = [softmax(to_ndarray(v['logit'])) for v in eval_output] fig, ax = plt.subplots() plt.ylim([-1, 1]) action_dim = len(action_prob[1]) x_range = [str(x + 1) for x in range(action_dim)] ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim]) ani = animation.FuncAnimation( fig, action_prob, fargs=(action_prob, ln), blit=True, save_count=len(action_prob) ) ani.save(action_path, writer='pillow') info_for_logging.update({"action": wandb.Video(action_path, format="gif")}) elif all(['action' in v for v in eval_output[0]]): for i, action_trajectory in enumerate(eval_output): fig, ax = plt.subplots() fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)]) steps = fig_data[:, 0] actions = fig_data[:, 1:] plt.ylim([-1, 1]) for j in range(actions.shape[1]): ax.scatter(steps, actions[:, j]) info_for_logging.update({"actions_of_trajectory_{}".format(i): fig}) if cfg.return_logger: return_path = os.path.join(record_path, (str(ctx.env_step) + "_return.gif")) fig, ax = plt.subplots() ax = plt.gca() ax.set_ylim([0, 1]) hist, x_dim = return_distribution(episode_return) assert len(hist) == len(x_dim) ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7) ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1) ani.save(return_path, writer='pillow') info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")}) if bool(info_for_logging): wandb.log(data=info_for_logging, step=ctx.env_step) plt.clf() return _plot def wandb_offline_logger( record_path: str = None, cfg: Union[dict, EasyDict] = None, exp_config: Union[dict, EasyDict] = None, metric_list: Optional[List[str]] = None, env: Optional[BaseEnvManagerV2] = None, model: Optional[torch.nn.Module] = None, anonymous: bool = False, project_name: str = 'default-project', run_name: str = None, wandb_sweep: bool = False, ) -> Callable: """ Overview: Wandb visualizer to track the experiment. Arguments: - record_path (:obj:`str`): The path to save the replay of simulation. - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings: - gradient_logger: boolean. Whether to track the gradient. - plot_logger: boolean. Whether to track the metrics like reward and loss. - video_logger: boolean. Whether to upload the rendering video replay. - action_logger: boolean. `q_value` or `action probability`. - return_logger: boolean. Whether to track the return value. - vis_dataset: boolean. Whether to visualize the dataset. - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies. - env (:obj:`BaseEnvManagerV2`): Evaluator environment. - model (:obj:`nn.Module`): Policy neural network model. - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \ of data without wandb count. - project_name (:obj:`str`): The name of wandb project. - run_name (:obj:`str`): The name of wandb run. - wandb_sweep (:obj:`bool`): Whether to use wandb sweep. ''' Returns: - _plot (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input. """ if task.router.is_active and not task.has_role(task.role.LEARNER): return task.void() color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"] if metric_list is None: metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"] # Initialize wandb with default settings # Settings can be covered by calling wandb.init() at the top of the script if exp_config: if not wandb_sweep: if run_name is not None: if anonymous: wandb.init(project=project_name, config=exp_config, reinit=True, name=run_name, anonymous="must") else: wandb.init(project=project_name, config=exp_config, reinit=True, name=run_name) else: if anonymous: wandb.init(project=project_name, config=exp_config, reinit=True, anonymous="must") else: wandb.init(project=project_name, config=exp_config, reinit=True) else: if run_name is not None: if anonymous: wandb.init(project=project_name, config=exp_config, name=run_name, anonymous="must") else: wandb.init(project=project_name, config=exp_config, name=run_name) else: if anonymous: wandb.init(project=project_name, config=exp_config, anonymous="must") else: wandb.init(project=project_name, config=exp_config) else: if not wandb_sweep: if run_name is not None: if anonymous: wandb.init(project=project_name, reinit=True, name=run_name, anonymous="must") else: wandb.init(project=project_name, reinit=True, name=run_name) else: if anonymous: wandb.init(project=project_name, reinit=True, anonymous="must") else: wandb.init(project=project_name, reinit=True) else: if run_name is not None: if anonymous: wandb.init(project=project_name, name=run_name, anonymous="must") else: wandb.init(project=project_name, name=run_name) else: if anonymous: wandb.init(project=project_name, anonymous="must") else: wandb.init(project=project_name) plt.switch_backend('agg') plt.switch_backend('agg') if cfg is None: cfg = EasyDict( dict( gradient_logger=False, plot_logger=True, video_logger=False, action_logger=False, return_logger=False, vis_dataset=True, ) ) else: if not isinstance(cfg, EasyDict): cfg = EasyDict(cfg) for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]: if key not in cfg.keys(): cfg[key] = False # The visualizer is called to save the replay of the simulation # which will be uploaded to wandb later if env is not None and cfg.video_logger is True and record_path is not None: env.enable_save_replay(replay_path=record_path) if cfg.gradient_logger: wandb.watch(model, log="all", log_freq=100, log_graph=True) else: one_time_warning( "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config." ) first_plot = True def _vis_dataset(datasetpath: str): try: from sklearn.manifold import TSNE except ImportError: import sys logging.warning("Please install sklearn first, such as `pip3 install scikit-learn`.") sys.exit(1) try: import h5py except ImportError: import sys logging.warning("Please install h5py first, such as `pip3 install h5py`.") sys.exit(1) assert os.path.splitext(datasetpath)[-1] in ['.pkl', '.h5', '.hdf5'] if os.path.splitext(datasetpath)[-1] == '.pkl': with open(datasetpath, 'rb') as f: data = pickle.load(f) obs = [] action = [] reward = [] for i in range(len(data)): obs.extend(data[i]['observations']) action.extend(data[i]['actions']) reward.extend(data[i]['rewards']) elif os.path.splitext(datasetpath)[-1] in ['.h5', '.hdf5']: with h5py.File(datasetpath, 'r') as f: obs = f['obs'][()] action = f['action'][()] reward = f['reward'][()] cmap = plt.cm.hsv obs = np.array(obs) reward = np.array(reward) obs_action = np.hstack((obs, np.array(action))) reward = reward / (max(reward) - min(reward)) embedded_obs = TSNE(n_components=2).fit_transform(obs) embedded_obs_action = TSNE(n_components=2).fit_transform(obs_action) x_min, x_max = np.min(embedded_obs, 0), np.max(embedded_obs, 0) embedded_obs = embedded_obs / (x_max - x_min) x_min, x_max = np.min(embedded_obs_action, 0), np.max(embedded_obs_action, 0) embedded_obs_action = embedded_obs_action / (x_max - x_min) fig = plt.figure() f, axes = plt.subplots(nrows=1, ncols=3) axes[0].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(reward)) axes[1].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(action)) axes[2].scatter(embedded_obs_action[:, 0], embedded_obs_action[:, 1], c=cmap(reward)) axes[0].set_title('state-reward') axes[1].set_title('state-action') axes[2].set_title('stateAction-reward') plt.savefig('dataset.png') wandb.log({"dataset": wandb.Image("dataset.png")}) if cfg.vis_dataset is True: _vis_dataset(exp_config.dataset_path) def _plot(ctx: "OfflineRLContext"): nonlocal first_plot if first_plot: first_plot = False ctx.wandb_url = wandb.run.get_project_url() info_for_logging = {} if cfg.plot_logger: for metric in metric_list: if isinstance(ctx.train_output, Dict) and metric in ctx.train_output: if isinstance(ctx.train_output[metric], torch.Tensor): info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()}) else: info_for_logging.update({metric: ctx.train_output[metric]}) elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]: metric_value_list = [] for item in ctx.train_output: if isinstance(item[metric], torch.Tensor): metric_value_list.append(item[metric].cpu().detach().numpy()) else: metric_value_list.append(item[metric]) metric_value = np.mean(metric_value_list) info_for_logging.update({metric: metric_value}) else: one_time_warning( "If you want to use wandb to visualize the result, please set plot_logger = True in the config." ) if ctx.eval_value != -np.inf: if hasattr(ctx, "eval_value_min"): info_for_logging.update({ "episode return min": ctx.eval_value_min, }) if hasattr(ctx, "eval_value_max"): info_for_logging.update({ "episode return max": ctx.eval_value_max, }) if hasattr(ctx, "eval_value_std"): info_for_logging.update({ "episode return std": ctx.eval_value_std, }) if hasattr(ctx, "eval_value"): info_for_logging.update({ "episode return mean": ctx.eval_value, }) if hasattr(ctx, "train_iter"): info_for_logging.update({ "train iter": ctx.train_iter, }) if hasattr(ctx, "train_epoch"): info_for_logging.update({ "train_epoch": ctx.train_epoch, }) eval_output = ctx.eval_output['output'] episode_return = ctx.eval_output['episode_return'] episode_return = np.array(episode_return) if len(episode_return.shape) == 2: episode_return = episode_return.squeeze(1) if cfg.video_logger: if 'replay_video' in ctx.eval_output: # save numpy array "images" of shape (N,1212,3,224,320) to N video files in mp4 format # The numpy tensor must be either 4 dimensional or 5 dimensional. # Channels should be (time, channel, height, width) or (batch, time, channel, height width) video_images = ctx.eval_output['replay_video'] video_images = video_images.astype(np.uint8) info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)}) elif record_path is not None: file_list = [] for p in os.listdir(record_path): if os.path.splitext(p)[-1] == ".mp4": file_list.append(p) file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn))) video_path = os.path.join(record_path, file_list[-2]) info_for_logging.update({"video": wandb.Video(video_path, format="mp4")}) if cfg.action_logger: action_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_action.gif")) if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"): if isinstance(eval_output, tnp.ndarray): action_prob = softmax(eval_output.logit) else: action_prob = [softmax(to_ndarray(v['logit'])) for v in eval_output] fig, ax = plt.subplots() plt.ylim([-1, 1]) action_dim = len(action_prob[1]) x_range = [str(x + 1) for x in range(action_dim)] ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim]) ani = animation.FuncAnimation( fig, action_prob, fargs=(action_prob, ln), blit=True, save_count=len(action_prob) ) ani.save(action_path, writer='pillow') info_for_logging.update({"action": wandb.Video(action_path, format="gif")}) elif all(['action' in v for v in eval_output[0]]): for i, action_trajectory in enumerate(eval_output): fig, ax = plt.subplots() fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)]) steps = fig_data[:, 0] actions = fig_data[:, 1:] plt.ylim([-1, 1]) for j in range(actions.shape[1]): ax.scatter(steps, actions[:, j]) info_for_logging.update({"actions_of_trajectory_{}".format(i): fig}) if cfg.return_logger: return_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_return.gif")) fig, ax = plt.subplots() ax = plt.gca() ax.set_ylim([0, 1]) hist, x_dim = return_distribution(episode_return) assert len(hist) == len(x_dim) ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7) ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1) ani.save(return_path, writer='pillow') info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")}) if bool(info_for_logging): wandb.log(data=info_for_logging, step=ctx.trained_env_step) plt.clf() return _plot