# This file implements ensemble evaluation: several AttModel-style captioners are
# wrapped into a single model and decoded jointly.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *

from .CaptionModel import CaptionModel
from .AttModel import pack_wrapper, AttModel


class AttEnsemble(AttModel):
    def __init__(self, models, weights=None):
        CaptionModel.__init__(self)
        # super(AttEnsemble, self).__init__()

        self.models = nn.ModuleList(models)
        self.vocab_size = models[0].vocab_size
        self.seq_length = models[0].seq_length
        self.bad_endings_ix = models[0].bad_endings_ix
        self.ss_prob = 0
        # Per-model mixing weights; default to equal weighting.
        weights = weights or [1.0] * len(self.models)
        self.register_buffer('weights', torch.tensor(weights))

    def init_hidden(self, batch_size):
        state = [m.init_hidden(batch_size) for m in self.models]
        return self.pack_state(state)

    def pack_state(self, state):
        # Flatten the per-model state lists into one flat list, remembering each
        # model's state length so unpack_state can split it back.
        self.state_lengths = [len(_) for _ in state]
        return sum([list(_) for _ in state], [])

    def unpack_state(self, state):
        # Split the flat state list back into one state list per model.
        out = []
        for l in self.state_lengths:
            out.append(state[:l])
            state = state[l:]
        return out

    def embed(self, it):
        return [m.embed(it) for m in self.models]

    def core(self, *args):
        # Run each wrapped model's core on its own slice of the inputs and
        # regroup the results into per-output tuples.
        return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])

    def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1):
        # 'it' contains a word index
        xt = self.embed(it)

        state = self.unpack_state(state)
        output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
        # Ensemble by taking the weighted average of the per-model word
        # probabilities, then converting back to log probabilities.
        logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i, m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log()

        return logprobs, self.pack_state(state)

    def _prepare_feature(self, *args):
        return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))

    def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            # Expand the k-th image's features to beam_size copies for every model.
            tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i, m in enumerate(self.models)]
            tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,) + att_feats[i].size()[1:])).contiguous() for i, m in enumerate(self.models)]
            tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,) + p_att_feats[i].size()[1:])).contiguous() for i, m in enumerate(self.models)]
            tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,) + att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i, m in enumerate(self.models)]

            it = fc_feats[0].data.new(beam_size).long().zero_()
            logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)

            self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
            seq[:, k] = self.done_beams[k][0]['seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
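

# A minimal, self-contained sketch (not part of the original file) of the
# ensembling step used in get_logprobs_state above: per-model probabilities are
# stacked, averaged with the buffered weights, and only then converted to log
# probabilities. In practice the ensemble is built as AttEnsemble(models, weights)
# from already-trained captioners; the logits and weights below are made up
# purely for illustration.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch_size, vocab_plus_one, n_models = 2, 5, 3
    # Fake per-model logits standing in for m.logit(output[i]).
    logits = [torch.randn(batch_size, vocab_plus_one) for _ in range(n_models)]
    weights = torch.tensor([1.0, 1.0, 2.0])  # e.g. trust the third model twice as much

    probs = torch.stack([F.softmax(l, dim=1) for l in logits], 2)   # (B, V+1, n_models)
    logprobs = probs.mul(weights).div(weights.sum()).sum(-1).log()  # weighted mean, then log

    # Each row is a proper log-distribution over the vocabulary.
    print(logprobs.exp().sum(dim=1))  # ~1.0 per row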