# Originally from Microsoft Corporation.
# Licensed under the MIT License.
""" Wrapper for ngram_repeat_block cuda extension """
import math
import warnings
from typing import Dict, List, Optional

import torch
from torch import nn
try:
    from fairseq import ngram_repeat_block_cuda

    EXTENSION_BUILT = True
except ImportError:
    EXTENSION_BUILT = False

def is_cuda_extension_usable() -> bool:
    """Check whether ngram_repeat_block_cuda is built properly"""
    if not EXTENSION_BUILT or not torch.cuda.is_available():
        return False
    bsz = 2
    tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda")
    lprobs = torch.rand((8, 12), device="cuda")
    try:
        outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3)
        outputs = outputs + 4  # This line breaks if the extension is built incorrectly.
        return True
    except RuntimeError:
        warnings.warn(
            "NGramRepeatBlock extension must be rebuilt. "
            'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace'
        )
        return False
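
# Hypothetical usage sketch (not part of the original file): NGramRepeatBlock
# below already runs this probe in __init__ when use_extension=True, so calling
# it directly is mainly useful for logging or diagnostics, e.g.:
#
#     if not is_cuda_extension_usable():
#         print("falling back to the pure-PyTorch ngram blocking path")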

class NGramRepeatBlock(nn.Module):
    """Wrapper class for calling ngram_repeat_block cuda extension"""

    def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True):
        super().__init__()
        self.use_extension = is_cuda_extension_usable() if use_extension else False
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def reset_parameters(self):
        pass
    def call_cuda_extension(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        # Note: the CUDA kernel expects step before beam_size, the reverse of
        # this method's own argument order.
        return ngram_repeat_block_cuda.forward(
            tokens, lprobs, bsz, step, beam_size, self.no_repeat_ngram_size
        )
    def forward(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
""" | |
Args: | |
tokens(Tensor): Input tokens(Bsz*beam, seq_len) | |
lprobs(Tensor): likelihood probability, | |
Expected to be updated in place.(Bsz*beam, vocab_size) | |
bsz(int): batch size | |
step(int): current step | |
beam_size(int): beam size | |
no_repeat_ngram_size(int): Ngram size | |
""" | |
        msg = f"expected {bsz * beam_size} got"
        assert tokens.size(0) == bsz * beam_size, f"{msg} {tokens.size(0)}"
        assert lprobs.size(0) == bsz * beam_size, f"{msg} {lprobs.size(0)}"
        if self.use_extension:
            return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step)
        else:
            return self._no_repeat_ngram(
                tokens,
                lprobs,
                bsz,
                beam_size,
                step,
            )
    def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int):
        """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf"""
        gen_ngrams: List[Dict[str, List[int]]] = [
            torch.jit.annotate(Dict[str, List[int]], {})
            for bbsz_idx in range(bsz * beam_size)
        ]
        cpu_tokens = tokens.cpu()
        for bbsz_idx in range(bsz * beam_size):
            gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist()
            for ngram in self.transpose_list(
                [gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]
            ):
                key = ",".join([str(x) for x in ngram[:-1]])
                gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get(
                    key, torch.jit.annotate(List[int], [])
                ) + [ngram[-1]]
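        # Example (hypothetical values): with no_repeat_ngram_size=3 and a
        # hypothesis [1, 2, 3, 2, 3], the trigrams are (1,2,3), (2,3,2), (3,2,3),
        # so gen_ngrams[bbsz_idx] == {"1,2": [3], "2,3": [2], "3,2": [3]}.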
        # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        if step + 2 - self.no_repeat_ngram_size >= 0:
            banned_tokens = [
                self.calculate_banned_tokens(
                    tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx
                )
                for bbsz_idx in range(bsz * beam_size)
            ]
        else:
            banned_tokens = [
                torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
            ]
        for bbsz_idx in range(bsz * beam_size):
            lprobs[bbsz_idx][
                torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64)
            ] = torch.tensor(-math.inf).to(lprobs)
        return lprobs
    @staticmethod
    def calculate_banned_tokens(
        tokens,
        step: int,
        gen_ngrams: List[Dict[str, List[int]]],
        no_repeat_ngram_size: int,
        bbsz_idx: int,
    ):
        tokens_list: List[int] = tokens[
            bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1
        ].tolist()
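        # Example (hypothetical values): with no_repeat_ngram_size=3 and a
        # hypothesis ending in ..., 7, 5, tokens_list is [7, 5] and the lookup
        # key below becomes "7,5"; every continuation recorded for (7, 5) in
        # gen_ngrams is then banned.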
        # before decoding the next token, prevent decoding of ngrams that have already appeared
        ngram_index = ",".join([str(x) for x in tokens_list])
        return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], []))
    @staticmethod
    def transpose_list(l: List[List[int]]):
        # GeneratorExp aren't supported in TS so ignoring the lint
        min_len = min([len(x) for x in l])  # noqa
        l2 = [[row[i] for row in l] for i in range(min_len)]
        return l2
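

# Minimal CPU-only sketch (not part of the original file) exercising the
# pure-PyTorch fallback path. All values below are made up for illustration;
# shapes follow forward()'s contract: tokens is (bsz * beam_size, seq_len)
# and lprobs is (bsz * beam_size, vocab_size).
if __name__ == "__main__":
    bsz, beam_size, vocab_size = 1, 2, 8
    step = 3  # step + 1 = 4 tokens generated so far
    # Both beams contain the bigram (5, 6) and currently end in 5, so with
    # no_repeat_ngram_size=2 the continuation 6 must be banned on both beams.
    tokens = torch.tensor([[5, 6, 7, 5], [5, 6, 3, 5]], dtype=torch.long)
    lprobs = torch.zeros(bsz * beam_size, vocab_size)
    blocker = NGramRepeatBlock(no_repeat_ngram_size=2, use_extension=False)
    out = blocker(tokens, lprobs, bsz, beam_size, step)
    assert out[0, 6].item() == -math.inf
    assert out[1, 6].item() == -math.inf
    print("lprobs for token 6 after blocking:", out[:, 6])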