from typing import TYPE_CHECKING, Optional, Callable, Dict, List, Union |
from ditk import logging |
from easydict import EasyDict |
from matplotlib import pyplot as plt |
from matplotlib import animation |
import os |
import numpy as np |
import torch |
import wandb |
import pickle |
import treetensor.numpy as tnp |
from ding.framework import task |
from ding.envs import BaseEnvManagerV2 |
from ding.utils import DistributedWriter |
from ding.torch_utils import to_ndarray |
from ding.utils.default_helper import one_time_warning |
from ding.framework import OnlineRLContext, OfflineRLContext |
def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Create an online RL tensorboard logger for recording training and evaluation metrics.
    Arguments:
        - record_train_iter (:obj:`bool`): Whether to additionally record every value against the training \
            iteration (besides the env step). Default is False.
        - train_show_freq (:obj:`int`): Minimum number of training iterations between two consecutive dumps of \
            ``ctx.train_output``. Default is 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".
    Examples:
        >>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    writer = DistributedWriter.get_instance()
    if writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
    last_train_show_iter = -1

    def _logger(ctx: "OnlineRLContext"):
        nonlocal last_train_show_iter
        if task.finish:
            writer.close()

        # An infinite eval_value is treated as "no evaluation result to record".
        if not np.isinf(ctx.eval_value):
            if record_train_iter:
                writer.add_scalar('basic/eval_episode_return_mean-env_step', ctx.eval_value, ctx.env_step)
                writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)
            else:
                writer.add_scalar('basic/eval_episode_return_mean', ctx.eval_value, ctx.env_step)

        if ctx.train_output is None or ctx.train_iter - last_train_show_iter < train_show_freq:
            return
        if isinstance(ctx.train_output, List):
            if len(ctx.train_output) == 0:
                # Nothing has been produced yet; try again on the next call.
                return
            # Only the most recent entry is logged. Read it without ``pop()`` so that the list stored in the
            # context stays intact for any middleware that runs after this logger.
            output = ctx.train_output[-1]
        else:
            output = ctx.train_output
        last_train_show_iter = ctx.train_iter

        for k, v in output.items():
            if k in ['priority', 'td_error_priority']:
                continue
            if "[scalars]" in k:
                raise NotImplementedError("train_output key `{}` is not supported by online_logger".format(k))
            elif "[histogram]" in k:
                # Strip the "[histogram]" tag, the remainder is the tensorboard tag.
                new_k = k.split(']')[-1]
                writer.add_histogram(new_k, v, ctx.env_step)
                if record_train_iter:
                    writer.add_histogram(new_k, v, ctx.train_iter)
            else:
                if record_train_iter:
                    writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter)
                    writer.add_scalar('basic/train_{}-env_step'.format(k), v, ctx.env_step)
                else:
                    writer.add_scalar('basic/train_{}'.format(k), v, ctx.env_step)

    return _logger
def offline_logger(train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Build the tensorboard logging middleware used by offline RL pipelines. Evaluation returns and the \
        entries of ``ctx.train_output`` are written against the training iteration.
    Arguments:
        - train_show_freq (:obj:`int`): Minimum number of training iterations between two dumps of \
            ``ctx.train_output``. Defaults to 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".
    Examples:
        >>> task.use(offline_logger(train_show_freq=1000))
    """
    skip_logging = task.router.is_active and not task.has_role(task.role.LEARNER)
    if skip_logging:
        return task.void()

    tb_writer = DistributedWriter.get_instance()
    if tb_writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")

    # Training iteration at which train_output was last dumped.
    shown_at_iter = -1

    def _logger(ctx: "OfflineRLContext"):
        nonlocal shown_at_iter
        if task.finish:
            tb_writer.close()

        # An infinite eval_value is treated as "no evaluation result to record".
        has_eval_result = not np.isinf(ctx.eval_value)
        if has_eval_result:
            tb_writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)

        if ctx.train_output is None:
            return
        if ctx.train_iter - shown_at_iter < train_show_freq:
            return
        shown_at_iter = ctx.train_iter

        for name, value in ctx.train_output.items():
            if name == 'priority':
                continue
            if "[scalars]" in name:
                raise NotImplementedError
            if "[histogram]" in name:
                # The text after the "[histogram]" tag is used as the tensorboard tag.
                tb_writer.add_histogram(name.split(']')[-1], value, ctx.train_iter)
                continue
            tb_writer.add_scalar('basic/train_{}-train_iter'.format(name), value, ctx.train_iter)

    return _logger
def softmax(logit: np.ndarray) -> np.ndarray:
    """
    Overview:
        Numerically stable softmax over the last axis.
    Arguments:
        - logit (:obj:`np.ndarray`): Array (or array-like) of logits with at least one dimension.
    Returns:
        - prob (:obj:`np.ndarray`): Array of the same shape whose last axis sums to 1.
    """
    logit = np.asarray(logit)
    if logit.shape[-1:] == (0, ):
        # Empty last axis: there is nothing to normalise (and ``max`` would raise on it).
        return np.exp(logit)
    # Softmax is invariant to a per-row shift; subtracting the row maximum keeps ``exp`` from overflowing.
    shifted = logit - logit.max(axis=-1, keepdims=True)
    v = np.exp(shifted)
    return v / v.sum(axis=-1, keepdims=True)
def action_prob(num, action_prob, ln):
    """
    Overview:
        Per-frame update callback for the action-probability bar animation: resize every bar in ``ln`` \
        to the values stored at ``action_prob[num]``.
    Arguments:
        - num: Index of the frame being drawn.
        - action_prob: Indexable collection holding one sequence of bar heights per frame.
        - ln: Iterable of bar artists (anything exposing ``set_height``).
    Returns:
        - ln: The same ``ln`` object that was passed in, after its bars were updated.
    """
    # Pin the y-range of the current axes to [0, 1] on every frame.
    plt.gca().set_ylim([0, 1])
    frame_heights = action_prob[num]
    for bar, height in zip(ln, frame_heights):
        bar.set_height(height)
    return ln
def return_prob(num, return_prob, ln):
    """
    Overview:
        Frame callback for the static return-distribution animation. Nothing changes between frames, so \
        the artists are handed back untouched.
    Arguments:
        - num: Frame index (unused).
        - return_prob: Histogram values (unused).
        - ln: The bar artists of the plot.
    Returns:
        - ln: The very same object that was passed in.
    """
    artists = ln
    return artists
def return_distribution(episode_return):
    """
    Overview:
        Bucket the episode returns into five equal-width bins and report the share of episodes per bin.
    Arguments:
        - episode_return: Non-empty sequence of scalar episode returns.
    Returns:
        - frequency: Array of length 5, the fraction of episodes that fall into each bin.
        - labels: List of 5 strings, the left edge of each bin formatted with one decimal place.
    """
    episode_count = len(episode_return)
    highest = max(episode_return)
    lowest = min(episode_return)

    # The histogram range is the observed range padded by 50 on both sides.
    left_edge = lowest - 50
    bin_edges = np.linspace(left_edge, highest + 50, 6)
    counts, _ = np.histogram(episode_return, bins=bin_edges)

    bin_width = (highest - lowest + 100) / 5
    labels = ['{:.1f}'.format(left_edge + bin_width * idx) for idx in range(5)]
    return counts / episode_count, labels
def wandb_online_logger(
        record_path: str = None,
        cfg: Union[dict, EasyDict] = None,
        exp_config: Union[dict, EasyDict] = None,
        metric_list: Optional[List[str]] = None,
        env: Optional[BaseEnvManagerV2] = None,
        model: Optional[torch.nn.Module] = None,
        anonymous: bool = False,
        project_name: str = 'default-project',
        run_name: str = None,
        wandb_sweep: bool = False,
) -> Callable:
    """
    Overview:
        Wandb visualizer to track the experiment.
    Arguments:
        - record_path (:obj:`str`): The path to save the replay of simulation. The action / return GIFs are \
            written into the same directory, so it has to be set when ``action_logger`` or ``return_logger`` \
            is enabled.
        - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings:
            - gradient_logger: boolean. Whether to track the gradient.
            - plot_logger: boolean. Whether to track the metrics like reward and loss.
            - video_logger: boolean. Whether to upload the rendering video replay.
            - action_logger: boolean. `q_value` or `action probability`.
            - return_logger: boolean. Whether to track the return value.
          Missing keys default to False. If ``cfg`` is None only ``plot_logger`` is enabled.
        - exp_config (:obj:`Union[dict, EasyDict]`): Experiment config, forwarded to ``wandb.init`` as \
            ``config`` when it is not empty.
        - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
        - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
        - model (:obj:`nn.Module`): Policy neural network model.
        - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \
            of data without wandb count.
        - project_name (:obj:`str`): The name of wandb project.
        - run_name (:obj:`str`): The name of wandb run.
        - wandb_sweep (:obj:`bool`): Whether to use wandb sweep. When False, ``reinit=True`` is passed to \
            ``wandb.init``.
    Returns:
        - _plot (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]

    # Every optional argument maps to exactly one ``wandb.init`` keyword, so assemble the call once instead
    # of enumerating all combinations.
    init_kwargs = dict(project=project_name)
    if exp_config:
        init_kwargs['config'] = exp_config
    if not wandb_sweep:
        init_kwargs['reinit'] = True
    if run_name is not None:
        init_kwargs['name'] = run_name
    if anonymous:
        init_kwargs['anonymous'] = "must"
    wandb.init(**init_kwargs)

    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False

    # Ask the evaluator env to save replays; the saved files are uploaded from ``_plot``.
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )

    first_plot = True

    def _to_loggable(value):
        # Tensors are detached and converted to numpy, everything else is passed through unchanged.
        if isinstance(value, torch.Tensor):
            return value.cpu().detach().numpy()
        return value

    def _plot(ctx: "OnlineRLContext"):
        nonlocal first_plot
        if first_plot:
            first_plot = False
            ctx.wandb_url = wandb.run.get_project_url()

        info_for_logging = {}
        # Figures created during this call; they are closed once wandb has received them.
        opened_figures = []

        if cfg.plot_logger:
            train_output = ctx.train_output
            for metric in metric_list:
                if isinstance(train_output, Dict) and metric in train_output:
                    info_for_logging[metric] = _to_loggable(train_output[metric])
                elif isinstance(train_output, List) and len(train_output) > 0 and metric in train_output[0]:
                    # A list of outputs is reduced to the mean of the metric over all entries.
                    info_for_logging[metric] = np.mean([_to_loggable(item[metric]) for item in train_output])
        else:
            one_time_warning(
                "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
            )

        # ``-inf`` is treated as "no evaluation result available in this iteration".
        if ctx.eval_value != -np.inf:
            eval_fields = (
                ("eval_value_min", "episode return min"),
                ("eval_value_max", "episode return max"),
                ("eval_value_std", "episode return std"),
                ("eval_value", "episode return mean"),
                ("train_iter", "train iter"),
                ("env_step", "env step"),
            )
            for attr, label in eval_fields:
                if hasattr(ctx, attr):
                    info_for_logging[label] = getattr(ctx, attr)

            eval_output = ctx.eval_output['output']
            episode_return = np.array(ctx.eval_output['episode_return'])
            if len(episode_return.shape) == 2:
                episode_return = episode_return.squeeze(1)

            if cfg.video_logger:
                if 'replay_video' in ctx.eval_output:
                    video_images = ctx.eval_output['replay_video'].astype(np.uint8)
                    info_for_logging["replay_video"] = wandb.Video(video_images, fps=60)
                elif record_path is not None:
                    file_list = [p for p in os.listdir(record_path) if os.path.splitext(p)[-1] == ".mp4"]
                    file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
                    # The second most recent file is uploaded.
                    # NOTE(review): presumably the newest file may still be incomplete - confirm.
                    if len(file_list) >= 2:
                        video_path = os.path.join(record_path, file_list[-2])
                        info_for_logging["video"] = wandb.Video(video_path, format="mp4")
                    else:
                        logging.warning(
                            "Less than two mp4 files found in {}, skip uploading the replay video.".format(record_path)
                        )

            if cfg.action_logger:
                action_path = os.path.join(record_path, (str(ctx.env_step) + "_action.gif"))
                if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
                    # Keep this local name different from the module-level ``action_prob`` function:
                    # the function is the animation callback, the local holds the per-step probabilities.
                    if isinstance(eval_output, tnp.ndarray):
                        action_probs = softmax(eval_output.logit)
                    else:
                        action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
                    fig, ax = plt.subplots()
                    opened_figures.append(fig)
                    plt.ylim([-1, 1])
                    action_dim = len(action_probs[0])
                    x_range = [str(x + 1) for x in range(action_dim)]
                    ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
                    ani = animation.FuncAnimation(
                        fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
                    )
                    ani.save(action_path, writer='pillow')
                    info_for_logging["action"] = wandb.Video(action_path, format="gif")
                elif all(['action' in v for v in eval_output[0]]):
                    for i, action_trajectory in enumerate(eval_output):
                        fig, ax = plt.subplots()
                        opened_figures.append(fig)
                        # One row per step: [1-based step index, action components...].
                        fig_data = np.array(
                            [[step + 1, *v['action']] for step, v in enumerate(action_trajectory)]
                        )
                        steps = fig_data[:, 0]
                        actions = fig_data[:, 1:]
                        plt.ylim([-1, 1])
                        for j in range(actions.shape[1]):
                            ax.scatter(steps, actions[:, j])
                        info_for_logging["actions_of_trajectory_{}".format(i)] = fig

            if cfg.return_logger:
                return_path = os.path.join(record_path, (str(ctx.env_step) + "_return.gif"))
                fig, ax = plt.subplots()
                opened_figures.append(fig)
                ax.set_ylim([0, 1])
                hist, x_dim = return_distribution(episode_return)
                assert len(hist) == len(x_dim)
                ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
                ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
                ani.save(return_path, writer='pillow')
                info_for_logging["return distribution"] = wandb.Video(return_path, format="gif")

        if info_for_logging:
            wandb.log(data=info_for_logging, step=ctx.env_step)
        # Release the figures created above, otherwise they pile up over the whole training run.
        for fig in opened_figures:
            plt.close(fig)

    return _plot
def wandb_offline_logger(
        record_path: str = None,
        cfg: Union[dict, EasyDict] = None,
        exp_config: Union[dict, EasyDict] = None,
        metric_list: Optional[List[str]] = None,
        env: Optional[BaseEnvManagerV2] = None,
        model: Optional[torch.nn.Module] = None,
        anonymous: bool = False,
        project_name: str = 'default-project',
        run_name: str = None,
        wandb_sweep: bool = False,
) -> Callable:
    """
    Overview:
        Wandb visualizer to track the experiment.
    Arguments:
        - record_path (:obj:`str`): The path to save the replay of simulation. The action / return GIFs are \
            written into the same directory, so it has to be set when ``action_logger`` or ``return_logger`` \
            is enabled.
        - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings:
            - gradient_logger: boolean. Whether to track the gradient.
            - plot_logger: boolean. Whether to track the metrics like reward and loss.
            - video_logger: boolean. Whether to upload the rendering video replay.
            - action_logger: boolean. `q_value` or `action probability`.
            - return_logger: boolean. Whether to track the return value.
            - vis_dataset: boolean. Whether to visualize the dataset.
          Missing keys default to False. If ``cfg`` is None, ``plot_logger`` and ``vis_dataset`` are enabled.
        - exp_config (:obj:`Union[dict, EasyDict]`): Experiment config, forwarded to ``wandb.init`` as \
            ``config`` when it is not empty. Its ``dataset_path`` attribute is read when ``vis_dataset`` is True.
        - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
        - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
        - model (:obj:`nn.Module`): Policy neural network model.
        - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \
            of data without wandb count.
        - project_name (:obj:`str`): The name of wandb project.
        - run_name (:obj:`str`): The name of wandb run.
        - wandb_sweep (:obj:`bool`): Whether to use wandb sweep. When False, ``reinit=True`` is passed to \
            ``wandb.init``.
    Returns:
        - _plot (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]

    # Every optional argument maps to exactly one ``wandb.init`` keyword, so assemble the call once instead
    # of enumerating all combinations.
    init_kwargs = dict(project=project_name)
    if exp_config:
        init_kwargs['config'] = exp_config
    if not wandb_sweep:
        init_kwargs['reinit'] = True
    if run_name is not None:
        init_kwargs['name'] = run_name
    if anonymous:
        init_kwargs['anonymous'] = "must"
    wandb.init(**init_kwargs)

    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
                vis_dataset=True,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False

    # Ask the evaluator env to save replays; the saved files are uploaded from ``_plot``.
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )

    first_plot = True

    def _vis_dataset(datasetpath: str):
        # Embed the dataset with t-SNE and upload a three-panel scatter plot to wandb.
        try:
            from sklearn.manifold import TSNE
        except ImportError:
            import sys
            logging.warning("Please install sklearn first, such as `pip3 install scikit-learn`.")
            sys.exit(1)

        suffix = os.path.splitext(datasetpath)[-1]
        assert suffix in ['.pkl', '.h5', '.hdf5'], "unsupported dataset file type: {}".format(suffix)

        if suffix == '.pkl':
            # NOTE: unpickling can execute arbitrary code, only point this at trusted dataset files.
            with open(datasetpath, 'rb') as f:
                data = pickle.load(f)
            obs = []
            action = []
            reward = []
            for i in range(len(data)):
                obs.extend(data[i]['observations'])
                action.extend(data[i]['actions'])
                reward.extend(data[i]['rewards'])
        else:
            # h5py is only needed for HDF5 datasets, so it is imported on this path only.
            try:
                import h5py
            except ImportError:
                import sys
                logging.warning("Please install h5py first, such as `pip3 install h5py`.")
                sys.exit(1)
            with h5py.File(datasetpath, 'r') as f:
                obs = f['obs'][()]
                action = f['action'][()]
                reward = f['reward'][()]

        cmap = plt.cm.hsv
        obs = np.array(obs)
        reward = np.array(reward)
        obs_action = np.hstack((obs, np.array(action)))
        # NOTE(review): this scales by the value range without subtracting the minimum, so the result is not
        # guaranteed to lie in [0, 1]; confirm whether a full min-max normalisation was intended.
        reward = reward / (max(reward) - min(reward))

        embedded_obs = TSNE(n_components=2).fit_transform(obs)
        embedded_obs_action = TSNE(n_components=2).fit_transform(obs_action)
        x_min, x_max = np.min(embedded_obs, 0), np.max(embedded_obs, 0)
        embedded_obs = embedded_obs / (x_max - x_min)
        x_min, x_max = np.min(embedded_obs_action, 0), np.max(embedded_obs_action, 0)
        embedded_obs_action = embedded_obs_action / (x_max - x_min)

        fig, axes = plt.subplots(nrows=1, ncols=3)
        axes[0].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(reward))
        axes[1].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(action))
        axes[2].scatter(embedded_obs_action[:, 0], embedded_obs_action[:, 1], c=cmap(reward))
        axes[0].set_title('state-reward')
        axes[1].set_title('state-action')
        axes[2].set_title('stateAction-reward')
        fig.savefig('dataset.png')
        wandb.log({"dataset": wandb.Image("dataset.png")})
        # The image has been written to disk, the figure itself is no longer needed.
        plt.close(fig)

    if cfg.vis_dataset is True:
        _vis_dataset(exp_config.dataset_path)

    def _to_loggable(value):
        # Tensors are detached and converted to numpy, everything else is passed through unchanged.
        if isinstance(value, torch.Tensor):
            return value.cpu().detach().numpy()
        return value

    def _plot(ctx: "OfflineRLContext"):
        nonlocal first_plot
        if first_plot:
            first_plot = False
            ctx.wandb_url = wandb.run.get_project_url()

        info_for_logging = {}
        # Figures created during this call; they are closed once wandb has received them.
        opened_figures = []

        if cfg.plot_logger:
            train_output = ctx.train_output
            for metric in metric_list:
                if isinstance(train_output, Dict) and metric in train_output:
                    info_for_logging[metric] = _to_loggable(train_output[metric])
                elif isinstance(train_output, List) and len(train_output) > 0 and metric in train_output[0]:
                    # A list of outputs is reduced to the mean of the metric over all entries.
                    info_for_logging[metric] = np.mean([_to_loggable(item[metric]) for item in train_output])
        else:
            one_time_warning(
                "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
            )

        # ``-inf`` is treated as "no evaluation result available in this iteration".
        if ctx.eval_value != -np.inf:
            eval_fields = (
                ("eval_value_min", "episode return min"),
                ("eval_value_max", "episode return max"),
                ("eval_value_std", "episode return std"),
                ("eval_value", "episode return mean"),
                ("train_iter", "train iter"),
                ("train_epoch", "train_epoch"),
            )
            for attr, label in eval_fields:
                if hasattr(ctx, attr):
                    info_for_logging[label] = getattr(ctx, attr)

            eval_output = ctx.eval_output['output']
            episode_return = np.array(ctx.eval_output['episode_return'])
            if len(episode_return.shape) == 2:
                episode_return = episode_return.squeeze(1)

            if cfg.video_logger:
                if 'replay_video' in ctx.eval_output:
                    video_images = ctx.eval_output['replay_video'].astype(np.uint8)
                    info_for_logging["replay_video"] = wandb.Video(video_images, fps=60)
                elif record_path is not None:
                    file_list = [p for p in os.listdir(record_path) if os.path.splitext(p)[-1] == ".mp4"]
                    file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
                    # The second most recent file is uploaded.
                    # NOTE(review): presumably the newest file may still be incomplete - confirm.
                    if len(file_list) >= 2:
                        video_path = os.path.join(record_path, file_list[-2])
                        info_for_logging["video"] = wandb.Video(video_path, format="mp4")
                    else:
                        logging.warning(
                            "Less than two mp4 files found in {}, skip uploading the replay video.".format(record_path)
                        )

            if cfg.action_logger:
                action_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_action.gif"))
                if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
                    # Keep this local name different from the module-level ``action_prob`` function:
                    # the function is the animation callback, the local holds the per-step probabilities.
                    if isinstance(eval_output, tnp.ndarray):
                        action_probs = softmax(eval_output.logit)
                    else:
                        action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
                    fig, ax = plt.subplots()
                    opened_figures.append(fig)
                    plt.ylim([-1, 1])
                    action_dim = len(action_probs[0])
                    x_range = [str(x + 1) for x in range(action_dim)]
                    ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
                    ani = animation.FuncAnimation(
                        fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
                    )
                    ani.save(action_path, writer='pillow')
                    info_for_logging["action"] = wandb.Video(action_path, format="gif")
                elif all(['action' in v for v in eval_output[0]]):
                    for i, action_trajectory in enumerate(eval_output):
                        fig, ax = plt.subplots()
                        opened_figures.append(fig)
                        # One row per step: [1-based step index, action components...].
                        fig_data = np.array(
                            [[step + 1, *v['action']] for step, v in enumerate(action_trajectory)]
                        )
                        steps = fig_data[:, 0]
                        actions = fig_data[:, 1:]
                        plt.ylim([-1, 1])
                        for j in range(actions.shape[1]):
                            ax.scatter(steps, actions[:, j])
                        info_for_logging["actions_of_trajectory_{}".format(i)] = fig

            if cfg.return_logger:
                return_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_return.gif"))
                fig, ax = plt.subplots()
                opened_figures.append(fig)
                ax.set_ylim([0, 1])
                hist, x_dim = return_distribution(episode_return)
                assert len(hist) == len(x_dim)
                ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
                ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
                ani.save(return_path, writer='pillow')
                info_for_logging["return distribution"] = wandb.Video(return_path, format="gif")

        if info_for_logging:
            wandb.log(data=info_for_logging, step=ctx.trained_env_step)
        # Release the figures created above, otherwise they pile up over the whole training run.
        for fig in opened_figures:
            plt.close(fig)

    return _plot