# PromptNet: modules/caption_model.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import modules.utils as utils
class CaptionModel(nn.Module):
def __init__(self):
super(CaptionModel, self).__init__()
def forward(self, *args, **kwargs):
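        """Dispatch to the method named by the 'mode' kwarg (e.g. mode='sample'
        calls self._sample); defaults to self._forward for training."""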
        mode = kwargs.pop('mode', 'forward')
return getattr(self, '_' + mode)(*args, **kwargs)
def beam_search(self, init_state, init_logprobs, *args, **kwargs):
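        """Beam search (diverse beam search when opt['group_size'] > 1).

        Calls beam_step for each group at each time step and returns, per image,
        the final list of completed beams, each a dict with keys 'seq', 'logps',
        'unaug_p' and 'p'. Log-probabilities are augmented with a diversity
        penalty when the number of groups is greater than 1.
        """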
        # applies the diversity penalty of diverse beam search: tokens already
        # chosen by earlier groups at this time step have their logprobs reduced
def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
local_time = t - divm
unaug_logprobs = logprobs.clone()
batch_size = beam_seq_table[0].shape[0]
if divm > 0:
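                # count how many times each vocab token was already picked by earlier
                # groups at this time step; each occurrence costs diversity_lambda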
change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
for prev_choice in range(divm):
prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
for prev_labels in range(bdash):
change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1),
change.new_ones(batch_size, 1))
if local_time == 0:
logprobs = logprobs - change * diversity_lambda
else:
logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda
return logprobs, unaug_logprobs
# does one step of classical beam search
def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobs: log-probabilities after the diversity penalty, N*bxV
            # beam_size: number of beams kept per image
            # t: time step
            # beam_seq: tensor containing the beams
            # beam_seq_logprobs: tensor containing the beam logprobs
            # beam_logprobs_sum: tensor containing the joint logprobs
            # OUTPUTS:
            # beam_seq: tensor containing the word indices of the decoded captions, Nxbxl
            # beam_seq_logprobs: log-probability of each decision made, NxbxlxV
            # beam_logprobs_sum: joint log-probability of each beam, Nxb
batch_size = beam_logprobs_sum.shape[0]
vocab_size = logprobs.shape[-1]
logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
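            # at t == 0 every beam is identical, so only a single beam row is expanded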
if t == 0:
assert logprobs.shape[1] == 1
beam_logprobs_sum = beam_logprobs_sum[:, :1]
candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
ys, ix = ys[:, :beam_size], ix[:, :beam_size]
beam_ix = ix // vocab_size # Nxb which beam
            selected_ix = ix % vocab_size  # Nxb which word
state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(
-1) # N*b which in Nxb beams
if t > 0:
# gather according to beam_ix
assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) ==
beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(
beam_seq_logprobs))
beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
logprobs.reshape(batch_size, -1).gather(1, ix)
assert (beam_logprobs_sum == ys).all()
_tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1,
beam_ix.unsqueeze(-1).expand(-1,
-1,
vocab_size)) # NxbxV
assert (_tmp_beam_logprobs == beam_logprobs).all()
beam_seq_logprobs = torch.cat([
beam_seq_logprobs,
beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)
new_state = [None for _ in state]
for _ix in range(len(new_state)):
# copy over state in previous beam q to new beam at vix
new_state[_ix] = state[_ix][:, state_ix]
state = new_state
return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state
# Start diverse_beam_search
opt = kwargs['opt']
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
beam_size = opt.get('beam_size', 10)
group_size = opt.get('group_size', 1)
diversity_lambda = opt.get('diversity_lambda', 0.5)
decoding_constraint = opt.get('decoding_constraint', 0)
suppress_UNK = opt.get('suppress_UNK', 0)
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
bdash = beam_size // group_size # beam per group
batch_size = init_logprobs.shape[0]
device = init_logprobs.device
# INITIALIZATIONS
beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in
range(group_size)]
beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]
        # logprobs_table: log-probabilities predicted at the previous time step, one tensor per group
done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
# END INIT
# Chunk elements in the args
args = list(args)
args = utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
if self.__class__.__name__ == 'AttEnsemble':
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in
range(group_size)] # group_name, arg_name, model_name
else:
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
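        # Groups are staggered in time: group divm takes its first step at t == divm,
        # so the full schedule runs for max_seq_length + group_size - 1 steps.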
for t in range(self.max_seq_length + group_size - 1):
for divm in range(group_size):
if t >= divm and t <= self.max_seq_length + divm - 1:
# add diversity
logprobs = logprobs_table[divm]
# suppress previous word
if decoding_constraint and t - divm > 0:
logprobs.scatter_(1, beam_seq_table[divm][:, :, t - divm - 1].reshape(-1, 1).to(device),
float('-inf'))
# suppress UNK tokens in the decoding
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1) - 1)] == 'UNK':
logprobs[:, logprobs.size(1) - 1] = logprobs[:, logprobs.size(1) - 1] - 1000
                    # diversity is added here
                    # add_diversity also returns the unaugmented logprobs, which are
                    # used for recording the per-step scores of the selected candidates
logprobs, unaug_logprobs = add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash)
# infer new beams
beam_seq_table[divm], \
beam_seq_logprobs_table[divm], \
beam_logprobs_sum_table[divm], \
state_table[divm] = beam_step(logprobs,
unaug_logprobs,
bdash,
t - divm,
beam_seq_table[divm],
beam_seq_logprobs_table[divm],
beam_logprobs_sum_table[divm],
state_table[divm])
# if time's up... or if end token is reached then copy beams
for b in range(batch_size):
is_end = beam_seq_table[divm][b, :, t - divm] == self.eos_idx
assert beam_seq_table[divm].shape[-1] == t - divm + 1
if t == self.max_seq_length + divm - 1:
is_end.fill_(1)
for vix in range(bdash):
if is_end[vix]:
final_beam = {
'seq': beam_seq_table[divm][b, vix].clone(),
'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
'p': beam_logprobs_sum_table[divm][b, vix].item()
}
final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
done_beams_table[b][divm].append(final_beam)
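                        # suppress finished beams so they are not expanded further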
beam_logprobs_sum_table[divm][b, is_end] -= 1000
# move the current group one step forward in time
it = beam_seq_table[divm][:, :, t - divm].reshape(-1)
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(
                        it.to(device), *(args[divm] + [state_table[divm]]))
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
# all beams are sorted by their log-probabilities
done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
for b in range(batch_size)]
done_beams = [sum(_, []) for _ in done_beams_table]
return done_beams
def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
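        """Legacy single-image variant of beam_search, kept for reference.

        Operates on (time, beam) shaped tables instead of the batched
        (batch, beam, time) tables used by beam_search above.
        """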
        # applies the diversity penalty in place and returns the unaugmented logprobs
def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
local_time = t - divm
unaug_logprobsf = logprobsf.clone()
for prev_choice in range(divm):
prev_decisions = beam_seq_table[prev_choice][local_time]
for sub_beam in range(bdash):
for prev_labels in range(bdash):
logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[
prev_labels]] - diversity_lambda
return unaug_logprobsf
# does one step of classical beam search
def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobsf: log-probabilities after the diversity penalty
            # beam_size: number of beams kept per image
            # t: time step
            # beam_seq: tensor containing the beams
            # beam_seq_logprobs: tensor containing the beam logprobs
            # beam_logprobs_sum: tensor containing the joint logprobs
            # OUTPUTS:
            # beam_seq: tensor containing the word indices of the decoded captions
            # beam_seq_logprobs: log-probability of each decision made, same size as beam_seq
            # beam_logprobs_sum: joint log-probability of each beam
ys, ix = torch.sort(logprobsf, 1, True)
candidates = []
cols = min(beam_size, ys.size(1))
rows = beam_size
if t == 0:
rows = 1
for c in range(cols): # for each column (word, essentially)
for q in range(rows): # for each beam expansion
# compute logprob of expanding beam q with word in (sorted) position c
local_logprob = ys[q, c].item()
candidate_logprob = beam_logprobs_sum[q] + local_logprob
# local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
candidates.append({'c': ix[q, c], 'q': q, 'p': candidate_logprob, 'r': unaug_logprobsf[q]})
candidates = sorted(candidates, key=lambda x: -x['p'])
new_state = [_.clone() for _ in state]
# beam_seq_prev, beam_seq_logprobs_prev
if t >= 1:
                # we'll need these as reference when we fork beams around
beam_seq_prev = beam_seq[:t].clone()
beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
for vix in range(beam_size):
v = candidates[vix]
# fork beam index q into index vix
if t >= 1:
beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
# rearrange recurrent states
for state_ix in range(len(new_state)):
# copy over state in previous beam q to new beam at vix
new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
# append new end terminal at the end of this beam
beam_seq[t, vix] = v['c'] # c'th word is the continuation
beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
state = new_state
return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates
# Start diverse_beam_search
opt = kwargs['opt']
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
beam_size = opt.get('beam_size', 10)
group_size = opt.get('group_size', 1)
diversity_lambda = opt.get('diversity_lambda', 0.5)
decoding_constraint = opt.get('decoding_constraint', 0)
suppress_UNK = opt.get('suppress_UNK', 0)
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
bdash = beam_size // group_size # beam per group
# INITIALIZATIONS
beam_seq_table = [torch.LongTensor(self.max_seq_length, bdash).zero_() for _ in range(group_size)]
beam_seq_logprobs_table = [torch.FloatTensor(self.max_seq_length, bdash, self.vocab_size + 1).zero_() for _ in
range(group_size)]
beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
        # logprobs_table: log-probabilities predicted at the previous time step, one tensor per group
done_beams_table = [[] for _ in range(group_size)]
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
logprobs_table = list(init_logprobs.chunk(group_size, 0))
# END INIT
# Chunk elements in the args
args = list(args)
if self.__class__.__name__ == 'AttEnsemble':
args = [[_.chunk(group_size) if _ is not None else [None] * group_size for _ in args_] for args_ in
args] # arg_name, model_name, group_name
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in
range(group_size)] # group_name, arg_name, model_name
else:
args = [_.chunk(group_size) if _ is not None else [None] * group_size for _ in args]
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
for t in range(self.max_seq_length + group_size - 1):
for divm in range(group_size):
if t >= divm and t <= self.max_seq_length + divm - 1:
# add diversity
logprobsf = logprobs_table[divm].float()
# suppress previous word
if decoding_constraint and t - divm > 0:
                        logprobsf.scatter_(1, beam_seq_table[divm][t - divm - 1].unsqueeze(1).to(logprobsf.device),
                                           float('-inf'))
# suppress UNK tokens in the decoding
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1) - 1)] == 'UNK':
logprobsf[:, logprobsf.size(1) - 1] = logprobsf[:, logprobsf.size(1) - 1] - 1000
                    # diversity is added here
                    # the function modifies logprobsf in place, so it returns the
                    # unaugmented copy, which is used when scoring the candidates
unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash)
# infer new beams
beam_seq_table[divm], \
beam_seq_logprobs_table[divm], \
beam_logprobs_sum_table[divm], \
state_table[divm], \
candidates_divm = beam_step(logprobsf,
unaug_logprobsf,
bdash,
t - divm,
beam_seq_table[divm],
beam_seq_logprobs_table[divm],
beam_logprobs_sum_table[divm],
state_table[divm])
# if time's up... or if end token is reached then copy beams
for vix in range(bdash):
if beam_seq_table[divm][t - divm, vix] == self.eos_idx or t == self.max_seq_length + divm - 1:
final_beam = {
'seq': beam_seq_table[divm][:, vix].clone(),
'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
'p': beam_logprobs_sum_table[divm][vix].item()
}
final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
done_beams_table[divm].append(final_beam)
# don't continue beams from finished sequences
beam_logprobs_sum_table[divm][vix] = -1000
# move the current group one step forward in time
it = beam_seq_table[divm][t - divm]
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(
                        it.to(init_logprobs.device), *(args[divm] + [state_table[divm]]))
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
# all beams are sorted by their log-probabilities
done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
done_beams = sum(done_beams_table, [])
return done_beams
def sample_next_word(self, logprobs, sample_method, temperature):
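        """Sample the next word from a batch of log-probabilities.

        sample_method: 'greedy' (argmax), 'gumbel' (Gumbel-max trick),
        'topk' for an integer k (e.g. 'top5'), 'topp' with 0 < p < 1 for
        nucleus sampling (e.g. 'top0.9'), or anything else for plain
        temperature sampling. Returns (it, sampleLogprobs).
        """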
if sample_method == 'greedy':
sampleLogprobs, it = torch.max(logprobs.data, 1)
it = it.view(-1).long()
elif sample_method == 'gumbel': # gumbel softmax
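            # Gumbel-max trick: adding Gumbel(0, 1) noise to the logits and taking
            # the argmax yields an exact sample from the categorical distribution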
def sample_gumbel(shape, eps=1e-20):
                U = torch.rand(shape, device=logprobs.device)
return -torch.log(-torch.log(U + eps) + eps)
def gumbel_softmax_sample(logits, temperature):
y = logits + sample_gumbel(logits.size())
return F.log_softmax(y / temperature, dim=-1)
_logprobs = gumbel_softmax_sample(logprobs, temperature)
_, it = torch.max(_logprobs.data, 1)
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
else:
logprobs = logprobs / temperature
if sample_method.startswith('top'): # topk sampling
top_num = float(sample_method[3:])
if 0 < top_num < 1:
                    # nucleus sampling from "The Curious Case of Neural Text Degeneration"
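                    # e.g. top_num = 0.9 keeps the smallest set of top-probability tokens
                    # whose cumulative mass reaches 0.9 (always at least the top token)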
probs = F.softmax(logprobs, dim=1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
_cumsum = sorted_probs.cumsum(1)
mask = _cumsum < top_num
mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)
sorted_probs = sorted_probs * mask.float()
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
logprobs.scatter_(1, sorted_indices, sorted_probs.log())
else:
the_k = int(top_num)
tmp = torch.empty_like(logprobs).fill_(float('-inf'))
topk, indices = torch.topk(logprobs, the_k, dim=1)
tmp = tmp.scatter(1, indices, topk)
logprobs = tmp
it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
return it, sampleLogprobs
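
# Usage sketch (illustrative only): a concrete subclass is expected to define
# self.vocab_size, self.max_seq_length, self.eos_idx and implement
# get_logprobs_state / _forward / _sample; the names below are hypothetical.
#
#   model = MyCaptioner(opt)                     # subclass of CaptionModel
#   logprobs, state = ...                        # log-probs/state for the first step
#   done_beams = model.beam_search(state, logprobs, *extra_args,
#                                  opt={'beam_size': 5, 'group_size': 1})
#   best_seq = done_beams[0][0]['seq']           # a top beam for image 0
#                                                # (beams are sorted within each group)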