| import logging | |
| import sys | |
| import torch | |
| import wandb | |
| from .plot_color import Plot | |
| from .metrics import Metric | |
class Logger:
    """Evaluate training progress, checkpoint the best policy, and optionally log to wandb.

    Each call evaluates the current metrics through ``self.metric()``. It saves
    ``policy.state_dict()`` whenever the RMSD improves on the lowest value seen
    so far. When wandb logging is enabled, it also logs the metrics together
    with the training loss.
    """

    def __init__(self, args, mds):
        """Build the logger from run arguments.

        Args:
            args: Run configuration. Attributes read here are ``molecule``,
                ``save_dir``, ``wandb`` (truthy to enable wandb logging), plus
                whatever ``Plot`` and ``Metric`` read from it.
            mds: Passed through unchanged to ``Plot`` and ``Metric``. Its exact
                type is defined outside this view (NOTE(review): confirm).
        """
        self.molecule = args.molecule  # not used by __call__; kept for other users
        self.save_dir = args.save_dir  # directory receiving policy.pt checkpoints
        self.wandb = args.wandb  # enables wandb.log in __call__
        # NOTE(review): self.plot is not used by __call__; presumably called
        # directly by other code — confirm before removing.
        self.plot = Plot(args, mds)
        self.metric = Metric(args, mds)
        # Best (lowest) RMSD seen so far; +inf guarantees the first call saves.
        self.rmsd = float("inf")

    @staticmethod
    def _wandb_payload(metrics, loss):
        """Build the dict passed to ``wandb.log`` from a metrics dict and the loss.

        ``ets`` and ``ets_std`` are included only when ``metrics["ets"]`` is not
        None. Key order is rmsd, rmsd_std, thp, [ets, ets_std], loss.
        """
        payload = {
            "rmsd": metrics["rmsd"],
            "rmsd_std": metrics["rmsd_std"],
            "thp": metrics["thp"],
        }
        if metrics["ets"] is not None:
            payload["ets"] = metrics["ets"]
            payload["ets_std"] = metrics["ets_std"]
        payload["loss"] = loss
        return payload

    def __call__(self, loss, rollout, policy):
        """Evaluate metrics, checkpoint on RMSD improvement, and log to wandb.

        Args:
            loss: Training loss value, logged as-is to wandb.
            rollout: Not used by this method; kept for call-site compatibility.
            policy: Object exposing ``state_dict()``; serialized with
                ``torch.save`` to ``{save_dir}/policy.pt`` when RMSD improves.

        Returns:
            None.
        """
        metrics = self.metric()

        # Strict comparison: only a strictly lower RMSD triggers a save. A NaN
        # RMSD compares False here, so it never overwrites the checkpoint.
        if self.rmsd > metrics["rmsd"]:
            self.rmsd = metrics["rmsd"]
            torch.save(policy.state_dict(), f"{self.save_dir}/policy.pt")

        if self.wandb:
            wandb.log(self._wandb_payload(metrics, loss))