| |
|
|
| from typing import Dict |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .utils import mean_with_lens, repeat_tensor |
|
|
|
|
class CaptionModel(nn.Module):
    """
    Encoder-decoder captioning model.

    Base class that wires an encoder and a decoder together and implements
    the decoding strategies (teacher-forced / step-wise sampling, beam
    search and diverse beam search).

    Subclasses are expected to provide:

    * ``compatible_decoders``: the decoder class (or tuple of classes)
      accepted by :meth:`check_decoder_compatibility`;
    * ``seq_forward``, ``prepare_decoder_input``,
      ``prepare_beamsearch_decoder_input`` and
      ``prepare_dbs_decoder_input``, which raise ``NotImplementedError``
      here.
    """

    # Special token indices; start/end can be overridden with `set_index`.
    pad_idx = 0
    start_idx = 1
    end_idx = 2
    # Default number of decoding steps when "max_length" is not supplied.
    max_length = 20

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        """
        Args:
            encoder: module called with the raw ``input_dict``; its output
                dict is merged into the dict passed to the decoding methods.
            decoder: module exposing ``vocab_size`` (and ``d_model``, read by
                :meth:`prepare_output`).
            **kwargs: ``freeze_encoder`` (bool, default ``False``) disables
                gradients for every encoder parameter.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size
        # Keys copied from the caller's input dict in each mode.
        self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
        self.inference_forward_keys = ["sample_method", "max_length", "temp"]
        freeze_encoder = kwargs.get("freeze_encoder", False)
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.check_decoder_compatibility()

    def check_decoder_compatibility(self):
        """Verify that ``self.decoder`` is one of ``self.compatible_decoders``.

        Raises:
            AssertionError: if the decoder is not an instance of any of the
                compatible decoder classes.
        """
        allowed = self.compatible_decoders
        if isinstance(allowed, type):
            # Tolerate a single class instead of a tuple of classes.
            allowed = (allowed,)
        if isinstance(self.decoder, tuple(allowed)):
            return
        # Entries are classes, so their own __name__ is the readable name.
        compatible_decoders = [x.__name__ for x in allowed]
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(
            f"{self.decoder.__class__.__name__} is incompatible with "
            f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
        )

    @classmethod
    def set_index(cls, start_idx, end_idx):
        """Override the class-level start/end token indices."""
        cls.start_idx = start_idx
        cls.end_idx = end_idx

    def forward(self, input_dict: Dict):
        """
        Encode the input and dispatch to training or inference decoding.

        input_dict: {
            (required)
            mode: train/inference,
            plus whatever the encoder reads (e.g. spec, spec_len, fc, attn,
            attn_len - consumed by the encoder, not by this method)

            (required, mode=train)
            cap,
            cap_len,
            ss_ratio,

            (optional, mode=inference)
            sample_method: greedy (default) / beam / dbs / other sampling
                methods understood by `sample_next_word`,
            max_length (default: self.max_length),
            temp (default: 1.0),
            beam_size (sample_method=beam: default 3; dbs: default 6),
            n_best (sample_method=beam, default False),
            n_best_size (sample_method=beam, default beam_size),
            group_size (sample_method=dbs, default 3),
            diversity_lambda (sample_method=dbs, default 0.5),
            group_nbest (sample_method=dbs, default True),
        }

        Returns:
            The output dict produced by `train_forward` or
            `inference_forward`.

        Raises:
            Exception: if ``mode`` is neither "train" nor "inference".
        """
        encoder_output_dict = self.encoder(input_dict)
        if input_dict["mode"] == "train":
            # Sampling settings used when scheduled sampling decodes stepwise.
            forward_dict = {
                "mode": "train", "sample_method": "greedy", "temp": 1.0
            }
            for key in self.train_forward_keys:
                forward_dict[key] = input_dict[key]
            forward_dict.update(encoder_output_dict)
            output = self.train_forward(forward_dict)
        elif input_dict["mode"] == "inference":
            forward_dict = {"mode": "inference"}
            default_args = {
                "sample_method": "greedy",
                "max_length": self.max_length,
                "temp": 1.0,
            }
            for key in self.inference_forward_keys:
                forward_dict[key] = input_dict.get(key, default_args[key])

            if forward_dict["sample_method"] == "beam":
                forward_dict["beam_size"] = input_dict.get("beam_size", 3)
                forward_dict["n_best"] = input_dict.get("n_best", False)
                forward_dict["n_best_size"] = input_dict.get(
                    "n_best_size", forward_dict["beam_size"])
            elif forward_dict["sample_method"] == "dbs":
                forward_dict["beam_size"] = input_dict.get("beam_size", 6)
                forward_dict["group_size"] = input_dict.get("group_size", 3)
                forward_dict["diversity_lambda"] = input_dict.get(
                    "diversity_lambda", 0.5)
                forward_dict["group_nbest"] = input_dict.get("group_nbest", True)

            forward_dict.update(encoder_output_dict)
            output = self.inference_forward(forward_dict)
        else:
            raise Exception("mode should be either 'train' or 'inference'")

        return output

    def prepare_output(self, input_dict):
        """Allocate the output buffers filled during step-wise decoding.

        The number of steps is ``cap.size(1) - 1`` in train mode and
        ``max_length`` in inference mode.

        Returns:
            dict with
            ``seq`` (batch, steps) long tensor pre-filled with ``end_idx``,
            ``logit`` (batch, steps, vocab_size) uninitialised,
            ``sampled_logprob`` (batch, steps) zeros,
            ``embed`` (batch, steps, decoder.d_model) uninitialised.
            ``logit`` and ``embed`` live on the device of ``fc_emb``; ``seq``
            and ``sampled_logprob`` are created on the default device.

        Raises:
            Exception: if ``mode`` is neither "train" nor "inference".
        """
        output = {}
        batch_size = input_dict["fc_emb"].size(0)
        if input_dict["mode"] == "train":
            max_length = input_dict["cap"].size(1) - 1
        elif input_dict["mode"] == "inference":
            max_length = input_dict["max_length"]
        else:
            raise Exception("mode should be either 'train' or 'inference'")
        device = input_dict["fc_emb"].device
        output["seq"] = torch.full((batch_size, max_length), self.end_idx,
                                   dtype=torch.long)
        output["logit"] = torch.empty(batch_size, max_length,
                                      self.vocab_size, device=device)
        output["sampled_logprob"] = torch.zeros(batch_size, max_length)
        output["embed"] = torch.empty(batch_size, max_length,
                                      self.decoder.d_model, device=device)
        return output

    def train_forward(self, input_dict):
        """Training pass.

        When ``ss_ratio`` is not 1 the sequence is decoded step by step;
        otherwise the whole sequence is decoded at once through `seq_forward`
        followed by the `train_process` hook.
        """
        if input_dict["ss_ratio"] != 1:
            input_dict["mode"] = "train"
            return self.stepwise_forward(input_dict)
        output = self.seq_forward(input_dict)
        self.train_process(output, input_dict)
        return output

    def seq_forward(self, input_dict):
        """Whole-sequence decoding; must be implemented by subclasses."""
        raise NotImplementedError

    def train_process(self, output, input_dict):
        """Hook called after `seq_forward` in training; no-op by default."""
        pass

    def inference_forward(self, input_dict):
        """Dispatch inference to beam search, diverse beam search or
        step-wise sampling according to ``sample_method``."""
        if input_dict["sample_method"] == "beam":
            return self.beam_search(input_dict)
        elif input_dict["sample_method"] == "dbs":
            return self.diverse_beam_search(input_dict)
        return self.stepwise_forward(input_dict)

    def stepwise_forward(self, input_dict):
        """Step-by-step decoding"""
        output = self.prepare_output(input_dict)
        max_length = output["seq"].size(1)

        for t in range(max_length):
            input_dict["t"] = t
            self.decode_step(input_dict, output)
            if input_dict["mode"] == "inference":
                # Track which sequences have not emitted <end> yet; once a
                # sequence has finished, every later token is forced to <end>.
                unfinished_t = output["seq"][:, t] != self.end_idx
                if t == 0:
                    unfinished = unfinished_t
                else:
                    unfinished *= unfinished_t
                output["seq"][:, t][~unfinished] = self.end_idx
                # Stop early when every sequence in the batch has finished.
                if unfinished.sum() == 0:
                    break
        self.stepwise_process(output)
        return output

    @staticmethod
    def _last_step_logit(logit_t):
        """Return the slice of ``logit_t`` for the latest position on dim 1.

        Raises:
            Exception: if dim 1 is empty.
        """
        num_steps = logit_t.size(1)
        if num_steps == 1:
            return logit_t.squeeze(1)
        if num_steps > 1:
            return logit_t[:, -1, :]
        raise Exception("no logit output")

    def decode_step(self, input_dict, output):
        """Decoding operation of timestep t"""
        decoder_input = self.prepare_decoder_input(input_dict, output)

        output_t = self.decoder(decoder_input)
        raw_logit = output_t["logit"]

        # Keep only the newest timestep of the decoder output.
        logit_t = self._last_step_logit(raw_logit)
        if raw_logit.size(1) == 1:
            embed_t = output_t["embed"].squeeze(1)
        else:
            embed_t = output_t["embed"][:, -1, :]

        sampled = self.sample_next_word(logit_t,
                                        method=input_dict["sample_method"],
                                        temp=input_dict["temp"])

        output_t.update(sampled)
        output_t["t"] = input_dict["t"]
        output_t["logit"] = logit_t
        output_t["embed"] = embed_t
        self.stepwise_process_step(output, output_t)

    def prepare_decoder_input(self, input_dict, output):
        """Prepare the input dict for the decoder"""
        raise NotImplementedError

    def stepwise_process_step(self, output, output_t):
        """Postprocessing (save output values) after each timestep t"""
        t = output_t["t"]
        output["logit"][:, t, :] = output_t["logit"]
        output["seq"][:, t] = output_t["word"]
        output["sampled_logprob"][:, t] = output_t["probs"]
        output["embed"][:, t, :] = output_t["embed"]

    def stepwise_process(self, output):
        """Postprocessing after the whole step-by-step autoregressive decoding"""
        pass

    def sample_next_word(self, logit, method, temp):
        """Sample the next word, given the logits output by the decoder.

        Args:
            logit: tensor normalised with ``log_softmax`` over dim 1.
            method: ``"greedy"`` (argmax), ``"gumbel"`` (Gumbel-max with
                temperature), ``"top<p>"`` with 0 < p < 1 (nucleus sampling),
                ``"top<k>"`` otherwise (top-k sampling); any other value
                samples from the temperature-scaled categorical distribution.
            temp: sampling temperature (unused for ``"greedy"``).

        Returns:
            dict with ``word`` (sampled indices, long, one per row) and
            ``probs`` (one value per row: the log-probability entry gathered
            for the sampled word - note that despite the key name these are
            log-space values).
        """
        logprob = torch.log_softmax(logit, dim=1)
        if method == "greedy":
            sampled_logprob, word = torch.max(logprob.detach(), 1)
        elif method == "gumbel":
            def sample_gumbel(shape, eps=1e-20):
                uniform = torch.rand(shape).to(logprob.device)
                return -torch.log(-torch.log(uniform + eps) + eps)

            def gumbel_softmax_sample(scores, temperature):
                perturbed = scores + sample_gumbel(scores.size())
                return torch.log_softmax(perturbed / temperature, dim=-1)

            _logprob = gumbel_softmax_sample(logprob, temp)
            _, word = torch.max(_logprob.data, 1)
            # squeeze so "probs" has one entry per row like every other
            # method (it is written into a per-timestep column afterwards).
            sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
        else:
            logprob = logprob / temp
            if method.startswith("top"):
                top_num = float(method[3:])
                if 0 < top_num < 1:
                    # Nucleus sampling: keep the highest-probability words
                    # until their cumulative mass reaches `top_num`.
                    probs = torch.softmax(logit, dim=1)
                    sorted_probs, sorted_indices = torch.sort(
                        probs, descending=True, dim=1)
                    _cumsum = sorted_probs.cumsum(1)
                    mask = _cumsum < top_num
                    # Shift the mask right by one so the most likely word is
                    # always kept and the word crossing the threshold is too.
                    mask = torch.cat(
                        [torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)
                    sorted_probs = sorted_probs * mask.to(sorted_probs)
                    sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
                    logprob.scatter_(1, sorted_indices, sorted_probs.log())
                else:
                    # Top-k sampling: everything outside the k best is -inf.
                    k = int(top_num)
                    tmp = torch.empty_like(logprob).fill_(float('-inf'))
                    topk, indices = torch.topk(logprob, k, dim=1)
                    tmp = tmp.scatter(1, indices, topk)
                    logprob = tmp
            word = torch.distributions.Categorical(logits=logprob.detach()).sample()
            sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
        word = word.detach().long()

        return {"word": word, "probs": sampled_logprob}

    def beam_search(self, input_dict):
        """Beam search, run independently for every sample of the batch.

        Returns the dict from `prepare_output` whose ``seq`` holds the best
        beam per sample, or - when ``n_best`` is truthy - a
        (batch, n_best_size, max_length) tensor with the best beams.
        Beam scores are length-normalised log-probabilities.
        """
        output = self.prepare_output(input_dict)
        max_length = input_dict["max_length"]
        beam_size = input_dict["beam_size"]
        if input_dict["n_best"]:
            n_best_size = input_dict["n_best_size"]
            batch_size, max_length = output["seq"].size()
            output["seq"] = torch.full((batch_size, n_best_size, max_length),
                                       self.end_idx, dtype=torch.long)

        temp = input_dict["temp"]

        for i in range(output["seq"].size(0)):
            output_i = self.prepare_beamsearch_output(input_dict)
            input_dict["sample_idx"] = i
            for t in range(max_length):
                input_dict["t"] = t
                output_t = self.beamsearch_step(input_dict, output_i)

                # Merge with the previous beams and select the best ones.
                logit_t = self._last_step_logit(output_t["logit"])
                logprob_t = torch.log_softmax(logit_t, dim=1)
                logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
                logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
                if t == 0:
                    # Only the first row is used at the first step.
                    topk_logprob, topk_words = logprob_t[0].topk(
                        beam_size, 0, True, True)
                else:
                    # Flatten (beam, vocab) and pick the best continuations.
                    topk_logprob, topk_words = logprob_t.view(-1).topk(
                        beam_size, 0, True, True)
                topk_words = topk_words.cpu()
                output_i["topk_logprob"] = topk_logprob

                # Recover (source beam, word) from the flattened index.
                output_i["prev_words_beam"] = torch.div(
                    topk_words, self.vocab_size, rounding_mode='trunc')
                output_i["next_word"] = topk_words % self.vocab_size
                if t == 0:
                    output_i["seq"] = output_i["next_word"].unsqueeze(1)
                else:
                    output_i["seq"] = torch.cat([
                        output_i["seq"][output_i["prev_words_beam"]],
                        output_i["next_word"].unsqueeze(1)], dim=1)

                # Collect finished beams; at the last step every beam ends.
                is_end = output_i["next_word"] == self.end_idx
                if t == max_length - 1:
                    is_end.fill_(1)

                for beam_idx in range(beam_size):
                    if is_end[beam_idx]:
                        final_beam = {
                            "seq": output_i["seq"][beam_idx].clone(),
                            "score": output_i["topk_logprob"][beam_idx].item()
                        }
                        # Length-normalise so long captions are not penalised.
                        final_beam["score"] = final_beam["score"] / (t + 1)
                        output_i["done_beams"].append(final_beam)
                # Large penalty so finished beams are not extended further.
                output_i["topk_logprob"][is_end] -= 1000

                self.beamsearch_process_step(output_i, output_t)

            self.beamsearch_process(output, output_i, input_dict)
        return output

    def prepare_beamsearch_output(self, input_dict):
        """Create the per-sample beam search state."""
        beam_size = input_dict["beam_size"]
        device = input_dict["fc_emb"].device
        output = {
            "topk_logprob": torch.zeros(beam_size).to(device),
            "seq": None,
            "prev_words_beam": None,
            "next_word": None,
            "done_beams": [],
        }
        return output

    def beamsearch_step(self, input_dict, output_i):
        """Run the decoder for one beam search step and tag the result
        with the current timestep."""
        decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
        output_t = self.decoder(decoder_input)
        output_t["t"] = input_dict["t"]
        return output_t

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Prepare the decoder input for beam search; subclass hook."""
        raise NotImplementedError

    def beamsearch_process_step(self, output_i, output_t):
        """Hook called after each beam search step; no-op by default."""
        pass

    def beamsearch_process(self, output, output_i, input_dict):
        """Write the best finished beam(s) of sample ``sample_idx`` into
        ``output["seq"]`` (sorted by descending score)."""
        i = input_dict["sample_idx"]
        done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
        if input_dict["n_best"]:
            done_beams = done_beams[:input_dict["n_best_size"]]
            for out_idx, done_beam in enumerate(done_beams):
                seq = done_beam["seq"]
                output["seq"][i][out_idx, :len(seq)] = seq
        else:
            seq = done_beams[0]["seq"]
            output["seq"][i][:len(seq)] = seq

    @staticmethod
    def _add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
        """Penalise words already chosen by earlier groups at this position.

        Returns the penalised log-probabilities and an untouched copy.
        """
        local_time = t - divm
        unaug_logprob = logprob.clone()

        if divm > 0:
            # Count, per vocabulary entry, how often previous groups picked it
            # at the same local timestep.
            change = torch.zeros(logprob.size(-1))
            for prev_choice in range(divm):
                prev_decisions = seq_table[prev_choice][..., local_time]
                for prev_labels in range(bdash):
                    change.scatter_add_(
                        0, prev_decisions[prev_labels], change.new_ones(1))

            change = change.to(logprob.device)
            # repeat_tensor is a project helper; presumably it tiles `change`
            # across the `bdash` beams - TODO confirm against .utils.
            logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda

        return logprob, unaug_logprob

    def diverse_beam_search(self, input_dict):
        """Diverse beam search, run independently for every sample.

        The ``beam_size`` beams are split into ``group_size`` groups of
        ``bdash = beam_size // group_size`` beams; group ``divm`` starts
        decoding ``divm`` steps after group 0 and is penalised (weight
        ``diversity_lambda``) for re-using words picked by earlier groups.

        ``output["seq"]`` has shape (batch, beam_size, max_length) when
        ``group_nbest`` is truthy (all kept beams of every group), otherwise
        (batch, group_size, max_length) (best beam of every group).
        """
        output = self.prepare_output(input_dict)
        group_size = input_dict["group_size"]
        beam_size = input_dict["beam_size"]
        bdash = beam_size // group_size
        input_dict["bdash"] = bdash
        diversity_lambda = input_dict["diversity_lambda"]
        device = input_dict["fc_emb"].device
        temp = input_dict["temp"]
        group_nbest = input_dict["group_nbest"]
        batch_size, max_length = output["seq"].size()
        num_candidates = beam_size if group_nbest else group_size
        output["seq"] = torch.full((batch_size, num_candidates, max_length),
                                   self.end_idx, dtype=torch.long)

        for i in range(batch_size):
            input_dict["sample_idx"] = i
            seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)]
            logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
            done_beams_table = [[] for _ in range(group_size)]

            output_i = {
                "prev_words_beam": [None for _ in range(group_size)],
                "next_word": [None for _ in range(group_size)],
                "state": [None for _ in range(group_size)]
            }

            # Groups are staggered, hence the extra `group_size - 1` steps.
            for t in range(max_length + group_size - 1):
                input_dict["t"] = t
                for divm in range(group_size):
                    input_dict["divm"] = divm
                    if t >= divm and t <= max_length + divm - 1:
                        local_time = t - divm
                        decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
                        output_t = self.decoder(decoder_input)
                        output_t["divm"] = divm
                        logit_t = self._last_step_logit(output_t["logit"])
                        logprob_t = torch.log_softmax(logit_t, dim=1)
                        logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
                        logprob_t, unaug_logprob_t = self._add_diversity(
                            seq_table, logprob_t, t, divm, diversity_lambda, bdash)
                        logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
                        if local_time == 0:
                            topk_logprob, topk_words = logprob_t[0].topk(
                                bdash, 0, True, True)
                        else:
                            topk_logprob, topk_words = logprob_t.view(-1).topk(
                                bdash, 0, True, True)
                        topk_words = topk_words.cpu()
                        logprob_table[divm] = topk_logprob
                        # Same index decomposition as in `beam_search`.
                        output_i["prev_words_beam"][divm] = torch.div(
                            topk_words, self.vocab_size, rounding_mode='trunc')
                        output_i["next_word"][divm] = topk_words % self.vocab_size
                        if local_time > 0:
                            seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
                        seq_table[divm] = torch.cat([
                            seq_table[divm],
                            output_i["next_word"][divm].unsqueeze(-1)], -1)

                        is_end = seq_table[divm][:, t-divm] == self.end_idx
                        assert seq_table[divm].shape[-1] == t - divm + 1
                        # At the group's last step every beam is finished.
                        if t == max_length + divm - 1:
                            is_end.fill_(1)
                        for beam_idx in range(bdash):
                            if is_end[beam_idx]:
                                final_beam = {
                                    "seq": seq_table[divm][beam_idx].clone(),
                                    "score": logprob_table[divm][beam_idx].item()
                                }
                                final_beam["score"] = final_beam["score"] / (t - divm + 1)
                                done_beams_table[divm].append(final_beam)
                        # Large penalty so finished beams are not extended.
                        logprob_table[divm][is_end] -= 1000
                        self.dbs_process_step(output_i, output_t)
            done_beams_table = [
                sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash]
                for divm in range(group_size)
            ]
            if group_nbest:
                done_beams = sum(done_beams_table, [])
            else:
                done_beams = [group_beam[0] for group_beam in done_beams_table]
            for out_idx, done_beam in enumerate(done_beams):
                output["seq"][i, out_idx, :len(done_beam["seq"])] = done_beam["seq"]

        return output

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        """Prepare the decoder input for diverse beam search; subclass hook."""
        raise NotImplementedError

    def dbs_process_step(self, output_i, output_t):
        """Hook called after each diverse beam search step; no-op by default."""
        pass
|
|
|
|
class CaptionSequenceModel(nn.Module):
    """Wrap a captioning model and add a sequence-level output.

    The per-timestep decoder embeddings (``output["embed"]``) are pooled with
    ``mean_with_lens`` and passed through ``output_transform``; the result is
    stored under ``output["seq_output"]``.
    """

    def __init__(self, model, seq_output_size):
        """
        Args:
            model: captioning model exposing ``decoder.d_model`` and
                ``end_idx``; called with the input dict in `forward`.
            seq_output_size: size of the sequence-level output. A linear
                projection is added only when it differs from
                ``model.decoder.d_model``.
        """
        super().__init__()
        self.model = model
        if model.decoder.d_model != seq_output_size:
            self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size)
        else:
            # nn.Identity instead of a lambda so the module stays picklable.
            self.output_transform = nn.Identity()

    def forward(self, input_dict):
        """Run the wrapped model and attach ``seq_output`` to its output.

        For beam search and diverse beam search the output of the wrapped
        model is returned unchanged, because those decoding paths never fill
        the per-timestep ``embed`` buffer.

        Raises:
            Exception: if ``mode`` is neither "train" nor "inference".
        """
        output = self.model(input_dict)

        mode = input_dict["mode"]
        if mode == "train":
            # assumes cap_len counts the start token, hence the -1 to match
            # the number of decoded steps - TODO confirm against the caller.
            lens = input_dict["cap_len"] - 1
        elif mode == "inference":
            if input_dict.get("sample_method") in ("beam", "dbs"):
                return output
            seq = output["seq"]
            # Length = number of tokens that are not <end>.
            lens = (seq != self.model.end_idx).sum(dim=1)
        else:
            raise Exception("mode should be either 'train' or 'inference'")
        seq_output = mean_with_lens(output["embed"], lens)
        seq_output = self.output_transform(seq_output)
        output["seq_output"] = seq_output
        return output
|
|
|
|