| | import os |
| | import random |
| | import json |
| | import torch |
| | import pprint |
| | import collections |
| | import numpy as np |
| | from torch import nn |
| | from tensorboardX import SummaryWriter |
| | from tqdm import trange |
| |
|
class Module(nn.Module):

    def __init__(self, args, vocab):
        '''
        Base Seq2Seq agent with common train and val loops
        '''
        super().__init__()

        # sentinel indices shared by all subclasses
        self.pad = 0
        self.seg = 1

        # experiment arguments and vocabularies
        self.args = args
        self.vocab = vocab

        # embedding tables for input words and low-level actions
        self.emb_word = nn.Embedding(len(vocab['word']), args.demb)
        self.emb_action_low = nn.Embedding(len(vocab['action_low']), args.demb)

        # special-token indices (lookup only: train=False must not grow the vocab)
        self.stop_token = self.vocab['action_low'].word2index("<<stop>>", train=False)
        self.seg_token = self.vocab['action_low'].word2index("<<seg>>", train=False)

        # seed the RNG used for shuffling in run_train
        random.seed(a=args.seed)

        # tensorboard writer; created lazily inside run_train
        self.summary_writer = None
|
| | def run_train(self, splits, args=None, optimizer=None): |
| | ''' |
| | training loop |
| | ''' |
| |
|
| | |
| | args = args or self.args |
| |
|
| | |
| | train = splits['train'] |
| | valid_seen = splits['valid_seen'] |
| | valid_unseen = splits['valid_unseen'] |
| |
|
| | |
| | if self.args.dataset_fraction > 0: |
| | small_train_size = int(self.args.dataset_fraction * 0.7) |
| | small_valid_size = int((self.args.dataset_fraction * 0.3) / 2) |
| | train = train[:small_train_size] |
| | valid_seen = valid_seen[:small_valid_size] |
| | valid_unseen = valid_unseen[:small_valid_size] |
| |
|
| | |
| | if self.args.fast_epoch: |
| | train = train[:16] |
| | valid_seen = valid_seen[:16] |
| | valid_unseen = valid_unseen[:16] |
| |
|
| | |
| | self.summary_writer = SummaryWriter(log_dir=args.dout) |
| |
|
| | |
| | fconfig = os.path.join(args.dout, 'config.json') |
| | with open(fconfig, 'wt') as f: |
| | json.dump(vars(args), f, indent=2) |
| |
|
| | |
| | optimizer = optimizer or torch.optim.Adam(self.parameters(), lr=args.lr) |
| |
|
| | |
| | print("Saving to: %s" % self.args.dout) |
| | best_loss = {'train': 1e10, 'valid_seen': 1e10, 'valid_unseen': 1e10} |
| | train_iter, valid_seen_iter, valid_unseen_iter = 0, 0, 0 |
| | for epoch in trange(0, args.epoch, desc='epoch'): |
| | m_train = collections.defaultdict(list) |
| | self.train() |
| | self.adjust_lr(optimizer, args.lr, epoch, decay_epoch=args.decay_epoch) |
| | |
| | total_train_loss = list() |
| | random.shuffle(train) |
| | for batch, feat in self.iterate(train, args.batch): |
| | out = self.forward(feat) |
| | preds = self.extract_preds(out, batch, feat) |
| | |
| | loss = self.compute_loss(out, batch, feat) |
| | for k, v in loss.items(): |
| | ln = 'loss_' + k |
| | m_train[ln].append(v.item()) |
| | self.summary_writer.add_scalar('train/' + ln, v.item(), train_iter) |
| |
|
| | |
| | optimizer.zero_grad() |
| | sum_loss = sum(loss.values()) |
| | sum_loss.backward() |
| | optimizer.step() |
| |
|
| | self.summary_writer.add_scalar('train/loss', sum_loss, train_iter) |
| | sum_loss = sum_loss.detach().cpu() |
| | total_train_loss.append(float(sum_loss)) |
| | train_iter += self.args.batch |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | p_valid_seen, valid_seen_iter, total_valid_seen_loss, m_valid_seen = self.run_pred(valid_seen, args=args, name='valid_seen', iter=valid_seen_iter) |
| | m_valid_seen.update(self.compute_metric(p_valid_seen, valid_seen)) |
| | m_valid_seen['total_loss'] = float(total_valid_seen_loss) |
| | self.summary_writer.add_scalar('valid_seen/total_loss', m_valid_seen['total_loss'], valid_seen_iter) |
| |
|
| | |
| | p_valid_unseen, valid_unseen_iter, total_valid_unseen_loss, m_valid_unseen = self.run_pred(valid_unseen, args=args, name='valid_unseen', iter=valid_unseen_iter) |
| | m_valid_unseen.update(self.compute_metric(p_valid_unseen, valid_unseen)) |
| | m_valid_unseen['total_loss'] = float(total_valid_unseen_loss) |
| | self.summary_writer.add_scalar('valid_unseen/total_loss', m_valid_unseen['total_loss'], valid_unseen_iter) |
| |
|
| | stats = {'epoch': epoch, |
| | 'valid_seen': m_valid_seen, |
| | 'valid_unseen': m_valid_unseen} |
| |
|
| | |
| | if total_valid_seen_loss < best_loss['valid_seen']: |
| | print('\nFound new best valid_seen!! Saving...') |
| | fsave = os.path.join(args.dout, 'best_seen.pth') |
| | torch.save({ |
| | 'metric': stats, |
| | 'model': self.state_dict(), |
| | 'optim': optimizer.state_dict(), |
| | 'args': self.args, |
| | 'vocab': self.vocab, |
| | }, fsave) |
| | fbest = os.path.join(args.dout, 'best_seen.json') |
| | with open(fbest, 'wt') as f: |
| | json.dump(stats, f, indent=2) |
| |
|
| | fpred = os.path.join(args.dout, 'valid_seen.debug.preds.json') |
| | with open(fpred, 'wt') as f: |
| | json.dump(self.make_debug(p_valid_seen, valid_seen), f, indent=2) |
| | best_loss['valid_seen'] = total_valid_seen_loss |
| |
|
| | |
| | if total_valid_unseen_loss < best_loss['valid_unseen']: |
| | print('Found new best valid_unseen!! Saving...') |
| | fsave = os.path.join(args.dout, 'best_unseen.pth') |
| | torch.save({ |
| | 'metric': stats, |
| | 'model': self.state_dict(), |
| | 'optim': optimizer.state_dict(), |
| | 'args': self.args, |
| | 'vocab': self.vocab, |
| | }, fsave) |
| | fbest = os.path.join(args.dout, 'best_unseen.json') |
| | with open(fbest, 'wt') as f: |
| | json.dump(stats, f, indent=2) |
| |
|
| | fpred = os.path.join(args.dout, 'valid_unseen.debug.preds.json') |
| | with open(fpred, 'wt') as f: |
| | json.dump(self.make_debug(p_valid_unseen, valid_unseen), f, indent=2) |
| |
|
| | best_loss['valid_unseen'] = total_valid_unseen_loss |
| |
|
| | |
| | if args.save_every_epoch: |
| | fsave = os.path.join(args.dout, 'net_epoch_%d.pth' % epoch) |
| | else: |
| | fsave = os.path.join(args.dout, 'latest.pth') |
| | torch.save({ |
| | 'metric': stats, |
| | 'model': self.state_dict(), |
| | 'optim': optimizer.state_dict(), |
| | 'args': self.args, |
| | 'vocab': self.vocab, |
| | }, fsave) |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | for split in stats.keys(): |
| | if isinstance(stats[split], dict): |
| | for k, v in stats[split].items(): |
| | self.summary_writer.add_scalar(split + '/' + k, v, train_iter) |
| | pprint.pprint(stats) |
| |
|
| | def run_pred(self, dev, args=None, name='dev', iter=0): |
| | ''' |
| | validation loop |
| | ''' |
| | args = args or self.args |
| | m_dev = collections.defaultdict(list) |
| | p_dev = {} |
| | self.eval() |
| | total_loss = list() |
| | dev_iter = iter |
| | for batch, feat in self.iterate(dev, args.batch): |
| | out = self.forward(feat) |
| | preds = self.extract_preds(out, batch, feat) |
| | p_dev.update(preds) |
| | loss = self.compute_loss(out, batch, feat) |
| | for k, v in loss.items(): |
| | ln = 'loss_' + k |
| | m_dev[ln].append(v.item()) |
| | self.summary_writer.add_scalar("%s/%s" % (name, ln), v.item(), dev_iter) |
| | sum_loss = sum(loss.values()) |
| | self.summary_writer.add_scalar("%s/loss" % (name), sum_loss, dev_iter) |
| | total_loss.append(float(sum_loss.detach().cpu())) |
| | dev_iter += len(batch) |
| |
|
| | m_dev = {k: sum(v) / len(v) for k, v in m_dev.items()} |
| | total_loss = sum(total_loss) / len(total_loss) |
| | return p_dev, dev_iter, total_loss, m_dev |
| |
|
| | def featurize(self, batch): |
| | raise NotImplementedError() |
| |
|
| | def forward(self, feat, max_decode=100): |
| | raise NotImplementedError() |
| |
|
| | def extract_preds(self, out, batch, feat): |
| | raise NotImplementedError() |
| |
|
| | def compute_loss(self, out, batch, feat): |
| | raise NotImplementedError() |
| |
|
| | def compute_metric(self, preds, data): |
| | raise NotImplementedError() |
| |
|
| | def get_task_and_ann_id(self, ex): |
| | ''' |
| | single string for task_id and annotation repeat idx |
| | ''' |
| | return "%s_%s" % (ex['task_id'], str(ex['ann']['repeat_idx'])) |
| |
|
| | def make_debug(self, preds, data): |
| | ''' |
| | readable output generator for debugging |
| | ''' |
| | debug = {} |
| | for task in data: |
| | ex = self.load_task_json(task) |
| | i = self.get_task_and_ann_id(ex) |
| | debug[i] = { |
| | 'lang_goal': ex['turk_annotations']['anns'][ex['ann']['repeat_idx']]['task_desc'], |
| | 'action_low': [a['discrete_action']['action'] for a in ex['plan']['low_actions']], |
| | 'p_action_low': preds[i]['action_low'].split(), |
| | } |
| | return debug |
| |
|
| | def load_task_json(self, task): |
| | ''' |
| | load preprocessed json from disk |
| | ''' |
| | json_path = os.path.join(self.args.data, task['task'], '%s' % self.args.pp_folder, 'ann_%d.json' % task['repeat_idx']) |
| | with open(json_path) as f: |
| | data = json.load(f) |
| | return data |
| |
|
| | def get_task_root(self, ex): |
| | ''' |
| | returns the folder path of a trajectory |
| | ''' |
| | return os.path.join(self.args.data, ex['split'], *(ex['root'].split('/')[-2:])) |
| |
|
| | def iterate(self, data, batch_size): |
| | ''' |
| | breaks dataset into batch_size chunks for training |
| | ''' |
| | for i in trange(0, len(data), batch_size, desc='batch'): |
| | tasks = data[i:i+batch_size] |
| | batch = [self.load_task_json(task) for task in tasks] |
| | feat = self.featurize(batch) |
| | yield batch, feat |
| |
|
| | def zero_input(self, x, keep_end_token=True): |
| | ''' |
| | pad input with zeros (used for ablations) |
| | ''' |
| | end_token = [x[-1]] if keep_end_token else [self.pad] |
| | return list(np.full_like(x[:-1], self.pad)) + end_token |
| |
|
| | def zero_input_list(self, x, keep_end_token=True): |
| | ''' |
| | pad a list of input with zeros (used for ablations) |
| | ''' |
| | end_token = [x[-1]] if keep_end_token else [self.pad] |
| | lz = [list(np.full_like(i, self.pad)) for i in x[:-1]] + end_token |
| | return lz |
| |
|
| | @staticmethod |
| | def adjust_lr(optimizer, init_lr, epoch, decay_epoch=5): |
| | ''' |
| | decay learning rate every decay_epoch |
| | ''' |
| | lr = init_lr * (0.1 ** (epoch // decay_epoch)) |
| | for param_group in optimizer.param_groups: |
| | param_group['lr'] = lr |
| |
|
| | @classmethod |
| | def load(cls, fsave): |
| | ''' |
| | load pth model from disk |
| | ''' |
| | save = torch.load(fsave) |
| | model = cls(save['args'], save['vocab']) |
| | model.load_state_dict(save['model']) |
| | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) |
| | optimizer.load_state_dict(save['optim']) |
| | return model, optimizer |
| |
|
| | @classmethod |
| | def has_interaction(cls, action): |
| | ''' |
| | check if low-level action is interactive |
| | ''' |
| | non_interact_actions = ['MoveAhead', 'Rotate', 'Look', '<<stop>>', '<<pad>>', '<<seg>>'] |
| | if any(a in action for a in non_interact_actions): |
| | return False |
| | else: |
| | return True |
| |
|