unstructured-chipper-app / logits_ngrams.py
ajimeno's picture
Updated Chipper model
ec9e91a
raw
history blame
No virus
3.41 kB
import re
import torch
from transformers import DonutProcessor
from transformers.utils import add_start_docstrings
from transformers.generation.logits_process import LogitsProcessor, LOGITS_PROCESSOR_INPUTS_DOCSTRING
# Inspired by https://github.com/huggingface/transformers/blob/8e3980a290acc6d2f8ea76dba111b9ef0ef00309/src/transformers/generation/logits_process.py#L706
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    """Logits processor that forbids repeating n-grams, with optional exemptions.

    For every hypothesis, any token that would complete an n-gram of length
    ``ngram_size`` already present in ``input_ids`` has its score set to
    ``-inf``. Token ids contained in ``skip_tokens`` are never banned.

    Args:
        ngram_size: Length of the n-grams that may occur only once. Must be a
            strictly positive integer.
        skip_tokens: Optional container of token ids (anything supporting
            ``in``) that are exempt from the ban. ``None`` disables the
            exemption.

    Raises:
        ValueError: If ``ngram_size`` is not a strictly positive integer.
    """

    def __init__(self, ngram_size: int, skip_tokens = None):
        size_is_valid = isinstance(ngram_size, int) and ngram_size > 0
        if not size_is_valid:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.skip_tokens = skip_tokens
        self.ngram_size = ngram_size

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # The first dimension of `scores` is taken as the number of hypotheses,
        # and the last dimension of `input_ids` as the current sequence length.
        hypothesis_count = scores.shape[0]
        sequence_length = input_ids.shape[-1]
        # The helper writes -inf into `scores` in place and returns that tensor.
        return _no_repeat_ngram_logits(
            input_ids,
            sequence_length,
            scores,
            batch_size=hypothesis_count,
            no_repeat_ngram_size=self.ngram_size,
            skip_tokens=self.skip_tokens,
        )
def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_ngram_size=0, skip_tokens=None):
    """Set the logits of tokens that would repeat an n-gram to ``-inf``.

    Args:
        input_ids: Token ids generated so far, one row per hypothesis.
        cur_len: Number of tokens currently in each row of ``input_ids``.
        logits: Next-token scores, indexed as ``[hypothesis, token_id]``.
            Modified in place.
        batch_size: Number of hypotheses (rows) to process.
        no_repeat_ngram_size: N-gram length that must not repeat. Values that
            are not greater than zero disable the processing.
        skip_tokens: Optional container of token ids that are never banned.

    Returns:
        The same ``logits`` object that was passed in.
    """
    # Nothing to ban when the feature is switched off.
    if no_repeat_ngram_size <= 0:
        return logits

    # One list of banned token ids per hypothesis (approach taken from fairseq:
    # https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
    banned_per_hypothesis = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
    minus_infinity = -float("inf")

    for row in range(batch_size):
        banned = banned_per_hypothesis[row]
        if skip_tokens is not None:
            # Exempt tokens stay available even if they would repeat an n-gram.
            banned = [token for token in banned if int(token) not in skip_tokens]
        logits[row, banned] = minus_infinity

    return logits
def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    """Return, per hypothesis, the tokens that would complete a repeated n-gram.

    Adapted from the fairseq ``no_repeat_ngram`` logic used in beam search.

    Args:
        prev_input_ids: 2-D tensor-like of generated token ids; it must support
            ``[row].tolist()`` and ``[row, start:stop].tolist()``.
        num_hypos: Number of rows (hypotheses) to inspect.
        no_repeat_ngram_size: Length of the n-grams that must not repeat.
        cur_len: Number of tokens generated so far in each row.

    Returns:
        A list with ``num_hypos`` entries. Each entry lists the token ids that
        followed an earlier occurrence of the last
        ``no_repeat_ngram_size - 1`` tokens. Duplicates are kept, and the list
        is empty when there was no such occurrence.
    """
    # Too few tokens to form even one full n-gram: nothing can be banned yet.
    if cur_len + 1 < no_repeat_ngram_size:
        return [[] for _ in range(num_hypos)]

    # Start of the trailing (n-1)-token window that the next token would extend.
    window_start = cur_len + 1 - no_repeat_ngram_size

    banned_tokens = []
    for hypo_idx in range(num_hypos):
        tokens = prev_input_ids[hypo_idx].tolist()

        # Map each (n-1)-token prefix to every token that has followed it.
        followers_by_prefix = {}
        shifted_views = (tokens[offset:] for offset in range(no_repeat_ngram_size))
        for *prefix, follower in zip(*shifted_views):
            followers_by_prefix.setdefault(tuple(prefix), []).append(follower)

        # Tokens that followed the current trailing window earlier are banned.
        trailing_window = tuple(prev_input_ids[hypo_idx, window_start:cur_len].tolist())
        banned_tokens.append(followers_by_prefix.get(trailing_window, []))

    return banned_tokens
def get_table_token_ids(processor):
    """Collect the ids of added-vocabulary tokens that start with ``<t`` or ``</t``.

    Args:
        processor: Object exposing ``tokenizer.get_added_vocab()``, which must
            return a mapping from token string to token id.

    Returns:
        A set of token ids whose token string begins with ``<t`` or ``</t``.
        NOTE(review): the match is purely prefix-based, so any added token
        starting with ``<t`` is included, not only table markup -- confirm this
        is intended for the vocabulary in use.
    """
    tag_prefixes = ("<t", "</t")
    added_vocab = processor.tokenizer.get_added_vocab()

    matching_ids = set()
    for token, token_id in added_vocab.items():
        if token.startswith(tag_prefixes):
            matching_ids.add(token_id)
    return matching_ids