|
from parlai.core.opt import Opt |
|
from parlai.utils.typing import TShared |
|
from parlai.agents.transformer.transformer import TransformerGeneratorAgent |
|
|
|
from .generation_methods import VocabTopKSampling, RerankedTopKSampling |
|
from .generation_utils import Wordlist, Reranker, load_wordlist, cefr_to_int |
|
|
|
class ControllableBlender(TransformerGeneratorAgent):
    """
    Transformer generator with complexity-controlled decoding.

    Supports two custom inference methods on top of the standard
    TransformerGeneratorAgent:

    * ``'vocab'``  -- top-k sampling restricted to tokens from a wordlist.
    * ``'rerank'`` -- top-k sampling where candidates are rescored by an
      external complexity model targeting a CEFR level.
    """

    def __init__(self, opt: Opt, shared: TShared = None):
        """
        Initialize the agent.

        :param opt:
            ParlAI options; must set ``inference`` to ``'vocab'`` or
            ``'rerank'`` (anything else raises ``ValueError``).
        :param shared:
            state dict from a parent instance's ``share()``; when it carries
            a wordlist/reranker, those objects are reused directly.
        """
        super().__init__(opt, shared)

        method = opt.get("inference", None)

        # Child instances (e.g. batch workers) receive state via share().
        # Reuse it instead of rebuilding: the rerank path would otherwise
        # reload model weights for every copy.
        if shared is not None and ("wordlist" in shared or "reranker" in shared):
            if "wordlist" in shared:
                self.wordlist = shared["wordlist"]
            if "reranker" in shared:
                self.reranker = shared["reranker"]
        elif method == "vocab":
            self._init_vocab(opt)
        elif method == "rerank":
            self._init_rerank(opt)
        else:
            raise ValueError(f"Inference method {method} does not exist. "
                             f"Please use 'vocab' or 'rerank'.")

    def _init_vocab(self, opt: Opt):
        """Load the wordlist that constrains sampling for 'vocab' inference."""
        wordlist_path = opt.get("wordlist_path", None)
        assert wordlist_path, "Please provide path to vocab list, in order to use inference method 'vocab'"

        allowed_words = load_wordlist(wordlist_path)
        self.wordlist = Wordlist(allowed_words, self.dict)

    def _init_rerank(self, opt: Opt):
        """Build the complexity reranker used by 'rerank' inference."""
        cefr = opt.get("rerank_cefr", None)
        assert cefr, "Please provide CEFR level, in order to use inference method 'rerank'"

        rerank_tokenizer = opt.get("rerank_tokenizer", None)
        rerank_model = opt.get("rerank_model", None)
        assert rerank_model, "Please provide path to directory containing model weights, in order to use inference method 'rerank'"

        device = opt.get("complexity_model_device", None)
        penalty_stddev = opt.get("penalty_stddev", None)
        text_truncate = opt.get("text_truncate", None)

        # Optional extra wordlist of tokens the reranker should filter out.
        word_filter = None
        filter_path = opt.get("filter_path", "")
        if filter_path:
            word_filter = load_wordlist(filter_path)

        # Special tokens are exempt from complexity penalties.
        exempt_tokens = [
            self.dict.tok2ind.get(self.dict.null_token),
            self.dict.tok2ind.get(self.dict.start_token),
            self.dict.tok2ind.get(self.dict.end_token),
            self.dict.tok2ind.get(self.dict.unk_token),
        ]
        # A negative stddev exempts every token from the penalty. Guard the
        # comparison: opt.get may return None, which is not orderable vs 0.
        if penalty_stddev is not None and penalty_stddev < 0:
            exempt_tokens = "all"

        self.reranker = Reranker(cefr=cefr_to_int(cefr),
                                 model=rerank_model,
                                 tokenizer=rerank_tokenizer,
                                 device=device,
                                 text_truncate=text_truncate,
                                 exempt_tokens=exempt_tokens,
                                 penalty_stddev=penalty_stddev,
                                 vocab_size=len(self.dict),
                                 word_filter=word_filter)

    def _treesearch_factory(self, device, verbose=False):
        """
        Return the tree-search object for the configured inference method.

        ``'vocab'`` and ``'rerank'`` map to the custom samplers; any other
        method falls through to the parent implementation.
        """
        method = self.opt.get('inference', 'greedy')
        beam_size = self.opt.get('beam_size', 1)
        if method == 'vocab':
            return VocabTopKSampling(
                k=self.opt.get('topk', 40),
                wordlist=self.wordlist,
                beam_size=beam_size,
                min_length=self.beam_min_length,
                block_ngram=self.beam_block_ngram,
                context_block_ngram=self.beam_context_block_ngram,
                length_penalty=self.opt.get('beam_length_penalty', 0.65),
                padding_token=self.NULL_IDX,
                bos_token=self.START_IDX,
                eos_token=self.END_IDX,
                device=device,
                verbose=verbose,
            )
        elif method == "rerank":
            return RerankedTopKSampling(
                k=self.opt.get('topk', 40),
                reranker=self.reranker,
                tokenids_to_text=self._v2t,
                beam_size=beam_size,
                min_length=self.beam_min_length,
                block_ngram=self.beam_block_ngram,
                context_block_ngram=self.beam_context_block_ngram,
                length_penalty=self.opt.get('beam_length_penalty', 0.65),
                padding_token=self.NULL_IDX,
                bos_token=self.START_IDX,
                eos_token=self.END_IDX,
                device=device,
                verbose=verbose,
            )
        else:
            return super()._treesearch_factory(device, verbose=verbose)

    def share(self):
        """
        Share internal states between parent and child instances.

        Exports the wordlist/reranker (when present) so child copies created
        with ``shared`` reuse them instead of rebuilding.
        """
        shared = super().share()
        if hasattr(self, 'wordlist'):
            shared['wordlist'] = self.wordlist
        if hasattr(self, 'reranker'):
            shared['reranker'] = self.reranker
        return shared
|
|
|
|
|
|