import torch
import torch.nn.functional as F
import numpy

from torch_ac.utils import DictList

# dictionary that defines what head is required for each extra info used for auxiliary supervision
required_heads = {'seen_state': 'binary',
                  'see_door': 'binary',
                  'see_obj': 'binary',
                  'obj_in_instr': 'binary',
                  'in_front_of_what': 'multiclass9',  # multi-class classifier with 9 possible classes
                  'visit_proportion': 'continuous01',  # continuous regressor with outputs in [0, 1]
                  'bot_action': 'binary'
                  }


class ExtraInfoCollector:
    '''
    This class, used in rl.algos.base, allows connecting the extra information from the
    environment with the corresponding predictions made by the specific heads of the model.
    It transforms both so that they are easy to use when evaluating losses.
    '''

    def __init__(self, aux_info, shape, device):
        self.aux_info = aux_info
        self.shape = shape
        self.device = device

        self.collected_info = dict()
        self.extra_predictions = dict()
        for info in self.aux_info:
            self.collected_info[info] = torch.zeros(*shape, device=self.device)
            if required_heads[info] == 'binary' or required_heads[info].startswith('continuous'):
                # we predict one number only
                self.extra_predictions[info] = torch.zeros(*shape, 1, device=self.device)
            elif required_heads[info].startswith('multiclass'):
                # this is a multi-class classification, so we need to predict the whole probability distribution
                n_classes = int(required_heads[info].replace('multiclass', ''))
                self.extra_predictions[info] = torch.zeros(*shape, n_classes, device=self.device)
            else:
                raise ValueError("{} not supported".format(required_heads[info]))

    def process(self, env_info):
        # env_info is a tuple of dicts, one per process; keep only the keys used for auxiliary supervision
        env_info = [{k: v for k, v in dic.items() if k in self.aux_info} for dic in env_info]
        env_info = {k: [env_info[_][k] for _ in range(len(env_info))] for k in env_info[0].keys()}
        # env_info is now a dict of lists
        return env_info

    def fill_dictionaries(self, index, env_info, extra_predictions):
        for info in self.aux_info:
            dtype = torch.long if required_heads[info].startswith('multiclass') else torch.float
            self.collected_info[info][index] = torch.tensor(env_info[info], dtype=dtype, device=self.device)
            self.extra_predictions[info][index] = extra_predictions[info]

    def end_collection(self, exps):
        collected_info = dict()
        extra_predictions = dict()
        for info in self.aux_info:
            # T x P -> P x T -> P * T
            collected_info[info] = self.collected_info[info].transpose(0, 1).reshape(-1)
            if required_heads[info] == 'binary' or required_heads[info].startswith('continuous'):
                # T x P x 1 -> P x T x 1 -> P * T
                extra_predictions[info] = self.extra_predictions[info].transpose(0, 1).reshape(-1)
            elif required_heads[info].startswith('multiclass'):
                # T x P x k -> P x T x k -> (P * T) x k
                k = int(required_heads[info].replace('multiclass', ''))  # number of classes
                extra_predictions[info] = self.extra_predictions[info].transpose(0, 1).reshape(-1, k)

        # convert the dicts to DictLists, and add them to the exps DictList
        exps.collected_info = DictList(collected_info)
        exps.extra_predictions = DictList(extra_predictions)

        return exps
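
# Illustrative sketch, not called anywhere in the codebase: a minimal run of
# ExtraInfoCollector on fabricated data, assuming T = 2 steps and P = 3 parallel
# processes. The values and the 'unused_key' entry below are made up for the example.
def _example_extra_info_collector():
    T, P = 2, 3
    device = torch.device('cpu')
    collector = ExtraInfoCollector(['seen_state', 'in_front_of_what'], (T, P), device)
    for t in range(T):
        # what a vectorized environment step would return: one info dict per process
        infos = tuple({'seen_state': 1, 'in_front_of_what': 4, 'unused_key': 0} for _ in range(P))
        env_info = collector.process(infos)  # tuple of dicts -> dict of lists, unused keys dropped
        extra_predictions = {
            'seen_state': torch.zeros(P, 1),        # one logit per process
            'in_front_of_what': torch.zeros(P, 9),  # one 9-class score vector per process
        }
        collector.fill_dictionaries(t, env_info, extra_predictions)
    exps = collector.end_collection(DictList())
    # everything is flattened to P * T: (6,) targets, (6, 9) multiclass predictions
    assert exps.collected_info.seen_state.shape == (P * T,)
    assert exps.extra_predictions.in_front_of_what.shape == (P * T, 9)
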
class SupervisedLossUpdater:
    '''
    This class, used by PPO, allows the evaluation of the supervised loss when using
    extra information from the environment. It also handles logging accuracies,
    L2 distances, prevalences, etc...
    '''

    def __init__(self, aux_info, supervised_loss_coef, recurrence, device):
        self.aux_info = aux_info
        self.supervised_loss_coef = supervised_loss_coef
        self.recurrence = recurrence
        self.device = device

        self.log_supervised_losses = []
        self.log_supervised_accuracies = []
        self.log_supervised_L2_losses = []
        self.log_supervised_prevalences = []

        self.batch_supervised_loss = 0
        self.batch_supervised_accuracy = 0
        self.batch_supervised_L2_loss = 0
        self.batch_supervised_prevalence = 0

    def init_epoch(self):
        self.log_supervised_losses = []
        self.log_supervised_accuracies = []
        self.log_supervised_L2_losses = []
        self.log_supervised_prevalences = []

    def init_batch(self):
        self.batch_supervised_loss = 0
        self.batch_supervised_accuracy = 0
        self.batch_supervised_L2_loss = 0
        self.batch_supervised_prevalence = 0

    def eval_subbatch(self, extra_predictions, sb):
        supervised_loss = torch.tensor(0., device=self.device)
        supervised_accuracy = torch.tensor(0., device=self.device)
        supervised_L2_loss = torch.tensor(0., device=self.device)
        supervised_prevalence = torch.tensor(0., device=self.device)

        binary_classification_tasks = 0
        classification_tasks = 0
        regression_tasks = 0

        for pos, info in enumerate(self.aux_info):
            coef = self.supervised_loss_coef[pos]
            pred = extra_predictions[info]
            # bypass DictList.__getitem__, which indexes into the values instead of fetching a key
            target = dict.__getitem__(sb.collected_info, info)
            if required_heads[info] == 'binary':
                binary_classification_tasks += 1
                classification_tasks += 1
                supervised_loss += coef * F.binary_cross_entropy_with_logits(pred.reshape(-1), target)
                supervised_accuracy += ((pred.reshape(-1) > 0).float() == target).float().mean()
                supervised_prevalence += target.mean()
            elif required_heads[info].startswith('continuous'):
                regression_tasks += 1
                mse = F.mse_loss(pred.reshape(-1), target)
                supervised_loss += coef * mse
                supervised_L2_loss += mse
            elif required_heads[info].startswith('multiclass'):
                classification_tasks += 1
                supervised_accuracy += (pred.argmax(1).float() == target).float().mean()
                supervised_loss += coef * F.cross_entropy(pred, target.long())
            else:
                raise ValueError("{} not supported".format(required_heads[info]))

        # per-kind averages are only meaningful when at least one task of that kind exists
        if binary_classification_tasks > 0:
            supervised_prevalence /= binary_classification_tasks
        else:
            supervised_prevalence = torch.tensor(-1)
        if classification_tasks > 0:
            supervised_accuracy /= classification_tasks
        else:
            supervised_accuracy = torch.tensor(-1)
        if regression_tasks > 0:
            supervised_L2_loss /= regression_tasks
        else:
            supervised_L2_loss = torch.tensor(-1)

        self.batch_supervised_loss += supervised_loss.item()
        self.batch_supervised_accuracy += supervised_accuracy.item()
        self.batch_supervised_L2_loss += supervised_L2_loss.item()
        self.batch_supervised_prevalence += supervised_prevalence.item()

        return supervised_loss

    def update_batch_values(self):
        self.batch_supervised_loss /= self.recurrence
        self.batch_supervised_accuracy /= self.recurrence
        self.batch_supervised_L2_loss /= self.recurrence
        self.batch_supervised_prevalence /= self.recurrence

    def update_epoch_logs(self):
        self.log_supervised_losses.append(self.batch_supervised_loss)
        self.log_supervised_accuracies.append(self.batch_supervised_accuracy)
        self.log_supervised_L2_losses.append(self.batch_supervised_L2_loss)
        self.log_supervised_prevalences.append(self.batch_supervised_prevalence)

    def end_training(self, logs):
        logs["supervised_loss"] = numpy.mean(self.log_supervised_losses)
        logs["supervised_accuracy"] = numpy.mean(self.log_supervised_accuracies)
        logs["supervised_L2_loss"] = numpy.mean(self.log_supervised_L2_losses)
        logs["supervised_prevalence"] = numpy.mean(self.log_supervised_prevalences)

        return logs
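

# Illustrative sketch, not called anywhere in the codebase: how a PPO update is
# expected to drive SupervisedLossUpdater over one epoch with a single batch and
# a single sub-batch. The sub-batch, logits and targets below are fabricated.
def _example_supervised_loss_updater():
    device = torch.device('cpu')
    updater = SupervisedLossUpdater(aux_info=['seen_state'], supervised_loss_coef=[1.0],
                                    recurrence=1, device=device)
    updater.init_epoch()
    updater.init_batch()
    # fabricated sub-batch: 6 binary targets and the matching (6, 1) logits
    sb = DictList({'collected_info': DictList({'seen_state': torch.tensor([0., 1., 1., 0., 1., 0.])})})
    extra_predictions = {'seen_state': torch.tensor([[-1.], [2.], [.5], [-2.], [1.], [-.5]])}
    loss = updater.eval_subbatch(extra_predictions, sb)  # differentiable, added to the PPO loss
    updater.update_batch_values()
    updater.update_epoch_logs()
    logs = updater.end_training(dict())
    assert 0. <= logs["supervised_accuracy"] <= 1.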