import torch
from .topk import TopK


class BeamNode(object):
    """One hypothesis in the beam: a token sequence plus its decoder state.

    Nodes are ordered and compared by ``avg_score`` (the accumulated score
    divided by the sequence length), so longer hypotheses are not penalised
    merely for having accumulated more per-step scores.
    """

    def __init__(self, seq, state, score):
        """Store the hypothesis and pre-compute its length-normalised score.

        Args:
            seq: Non-empty list of tokens generated so far.
            state: Decoder state associated with the last token of ``seq``.
            score: Accumulated score of the whole sequence.
        """
        self.seq = seq
        self.state = state
        self.score = score
        self.avg_score = score / len(seq)

    def __cmp__(self, other):
        """Three-way comparison on ``avg_score`` (only used by Python 2)."""
        if self.avg_score == other.avg_score:
            return 0
        elif self.avg_score < other.avg_score:
            return -1
        else:
            return 1

    def __lt__(self, other):
        """Return whether this node has a lower ``avg_score`` than ``other``."""
        return self.avg_score < other.avg_score

    # NOTE: defining __eq__ without __hash__ makes instances unhashable on
    # Python 3; nothing in this module hashes nodes.
    def __eq__(self, other):
        """Return whether both nodes have the same ``avg_score``."""
        return self.avg_score == other.avg_score


class BeamSearch(object):
    """Class to generate sequences from an image-to-text model."""

    def __init__(self, decode_step, eos, beam_size=2, max_seq_len=32):
        """Configure the search.

        Args:
            decode_step: Callable invoked as
                ``decode_step(inputs, states, k=K)`` that returns a
                ``(words, scores, states)`` triple. ``words[i][j]`` and
                ``scores[i][j]`` are read as the j-th candidate for input
                ``i``.
            eos: End-of-sequence token; a hypothesis is complete once this
                token is appended to it.
            beam_size: Number of partial hypotheses kept per batch entry.
            max_seq_len: Upper bound on the number of decode steps, including
                the initial one.
        """
        self.decode_step = decode_step
        self.eos = eos
        self.beam_size = beam_size
        self.max_seq_len = max_seq_len

    def beam_search(self, init_inputs, init_states):
        """Run beam search and return the best sequence per batch entry.

        Args:
            init_inputs: Initial decoder inputs; ``len(init_inputs)`` is taken
                as the batch size.
            init_states: Initial decoder states passed to the first
                ``decode_step`` call.

        Returns:
            A ``(seqs, seq_scores)`` pair of lists with one element per batch
            entry: the token list of the best hypothesis and its
            length-normalised score. Completed (EOS-terminated) hypotheses are
            preferred; if a batch entry has none, its best partial hypothesis
            is returned instead.
        """
        batch_size = len(init_inputs)
        part_seqs = [TopK(self.beam_size) for _ in range(batch_size)]
        comp_seqs = [TopK(self.beam_size) for _ in range(batch_size)]

        # Seed each beam with the top `beam_size` first tokens.
        # States are indexed as states[:, :, i, :], i.e. the hypothesis index
        # lives on dim 2 of a 4-D tensor.
        # NOTE(review): an EOS among these first tokens is not treated as a
        # completed sequence here -- confirm that this is intended.
        words, scores, states = self.decode_step(
            init_inputs, init_states, k=self.beam_size)
        for batch_id in range(batch_size):
            for i in range(self.beam_size):
                node = BeamNode(
                    [words[batch_id][i]],
                    states[:, :, batch_id, :],
                    scores[batch_id][i],
                )
                part_seqs[batch_id].push(node)

        for _ in range(self.max_seq_len - 1):
            # Take the current partial hypotheses out of each beam and empty
            # the beam so that it can receive their expansions.
            part_seq_list = []
            for beam in part_seqs:
                part_seq_list.append(beam.extract())
                beam.reset()

            # Flatten all hypotheses of all batch entries into one decode
            # batch; `idx` below walks this flat order.
            inputs, states = [], []
            for seq_list in part_seq_list:
                for node in seq_list:
                    inputs.append(node.seq[-1])
                    states.append(node.state)
            if not inputs:
                break
            inputs = torch.stack(inputs)
            states = torch.stack(states, dim=2)

            # Ask for one extra candidate so that `beam_size` non-EOS
            # expansions remain available when one candidate is EOS.
            words, scores, states = self.decode_step(
                inputs, states, k=self.beam_size + 1)

            idx = 0
            for batch_id in range(batch_size):
                for node in part_seq_list[batch_id]:
                    new_state = states[:, :, idx, :]
                    rank = 0
                    num_hyp = 0
                    # Keep consuming candidates until `beam_size` partial
                    # expansions were pushed; EOS candidates go to the
                    # completed beam and do not count towards that quota.
                    while num_hyp < self.beam_size:
                        word = words[idx][rank]
                        candidate = BeamNode(
                            node.seq + [word],
                            new_state,
                            node.score + scores[idx][rank],
                        )
                        rank += 1
                        if word == self.eos:
                            comp_seqs[batch_id].push(candidate)
                        else:
                            part_seqs[batch_id].push(candidate)
                            num_hyp += 1
                    idx += 1

        seqs, seq_scores = [], []
        for batch_id in range(batch_size):
            # Fall back to partial hypotheses when nothing reached EOS.
            finished = comp_seqs[batch_id]
            if not finished.size():
                finished = part_seqs[batch_id]
            # Extract exactly once per beam: this avoids sorting twice and
            # stays correct even if extract() consumes the container.
            best = finished.extract(sort=True)[0]
            seqs.append(best.seq)
            seq_scores.append(best.avg_score)
        return seqs, seq_scores