from parlai.core.opt import Opt
from parlai.utils.typing import TShared
from parlai.agents.transformer.transformer import TransformerGeneratorAgent
from .generation_methods import VocabTopKSampling, RerankedTopKSampling
from .generation_utils import Wordlist, Reranker, load_wordlist, cefr_to_int
class ControllableBlender(TransformerGeneratorAgent):
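    """
    TransformerGeneratorAgent with controllable decoding.

    Supports two inference methods: 'vocab', which restricts top-k sampling to
    a whitelist of allowed words, and 'rerank', which applies a CEFR-based
    complexity reranker during top-k sampling.
    """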

    def __init__(self, opt: Opt, shared: TShared = None):
        super().__init__(opt, shared)
        if opt.get("inference", None) == "vocab":
            # Restrict sampling to an explicit whitelist of allowed words.
            wordlist_path = opt.get("wordlist_path", None)
            assert wordlist_path, "Please provide a path to the vocab list in order to use inference method 'vocab'."
            allowed_words = load_wordlist(wordlist_path)
            self.wordlist = Wordlist(allowed_words, self.dict)
elif opt.get("inference", None) == "rerank":
cefr = opt.get("rerank_cefr", None)
assert cefr, "Please provide CEFR level, in order to use inference method 'rerank'"
rerank_tokenizer = opt.get("rerank_tokenizer", None)
rerank_model = opt.get("rerank_model", None)
assert rerank_model, "Please provide path to directory containing model weights, in order to use inference method 'rerank'"
device = opt.get("complexity_model_device", None)
penalty_stddev = opt.get("penalty_stddev", None)
text_truncate = opt.get("text_truncate", None)
word_filter = None
filter_path = opt.get("filter_path", "")
if filter_path:
word_filter = load_wordlist(filter_path)
            # Special tokens are always exempt from the complexity penalty.
            exempt_tokens = [
                self.dict.tok2ind.get(self.dict.null_token),
                self.dict.tok2ind.get(self.dict.start_token),
                self.dict.tok2ind.get(self.dict.end_token),
                self.dict.tok2ind.get(self.dict.unk_token),
            ]
            # A negative penalty_stddev exempts all tokens from the penalty.
            if penalty_stddev is not None and penalty_stddev < 0:
                exempt_tokens = "all"
            self.reranker = Reranker(
                cefr=cefr_to_int(cefr),
                model=rerank_model,
                tokenizer=rerank_tokenizer,
                device=device,
                text_truncate=text_truncate,
                exempt_tokens=exempt_tokens,
                penalty_stddev=penalty_stddev,
                vocab_size=len(self.dict),
                word_filter=word_filter,
            )
        else:
            raise ValueError(
                f"Inference method {opt.get('inference', None)} is not supported. "
                "Please use 'vocab' or 'rerank'."
            )

    def _treesearch_factory(self, device, verbose=False):
        method = self.opt.get('inference', 'greedy')
        beam_size = self.opt.get('beam_size', 1)
        if method == 'vocab':
            return VocabTopKSampling(
                k=self.opt.get('topk', 40),
                wordlist=self.wordlist,
                beam_size=beam_size,
                min_length=self.beam_min_length,
                block_ngram=self.beam_block_ngram,
                context_block_ngram=self.beam_context_block_ngram,
                length_penalty=self.opt.get('beam_length_penalty', 0.65),
                padding_token=self.NULL_IDX,
                bos_token=self.START_IDX,
                eos_token=self.END_IDX,
                device=device,
                verbose=verbose,
            )
        elif method == "rerank":
            return RerankedTopKSampling(
                k=self.opt.get('topk', 40),
                reranker=self.reranker,
                tokenids_to_text=self._v2t,
                beam_size=beam_size,
                min_length=self.beam_min_length,
                block_ngram=self.beam_block_ngram,
                context_block_ngram=self.beam_context_block_ngram,
                length_penalty=self.opt.get('beam_length_penalty', 0.65),
                padding_token=self.NULL_IDX,
                bos_token=self.START_IDX,
                eos_token=self.END_IDX,
                device=device,
                verbose=verbose,
            )
        else:
            return super()._treesearch_factory(device, verbose=verbose)

    def share(self):
        """
        Share internal states between parent and child instances.
        """
        shared = super().share()
        if hasattr(self, 'wordlist'):
            shared['wordlist'] = self.wordlist
        if hasattr(self, 'reranker'):
            shared['reranker'] = self.reranker
        return shared
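
# Example invocation (a minimal sketch, not a verified command). The dashed
# flag names below are assumptions mapped from the opt keys read in __init__
# ('wordlist_path', 'rerank_cefr', 'rerank_model', ...); they only apply if
# this agent registers them elsewhere (e.g. via add_cmdline_args).
#
#   parlai interactive \
#       --model-file <path/to/generator/model> \
#       --model <module.path:ControllableBlender> \
#       --inference vocab --topk 40 \
#       --wordlist-path <path/to/allowed_words.txt>
#
#   parlai interactive \
#       --model-file <path/to/generator/model> \
#       --model <module.path:ControllableBlender> \
#       --inference rerank --rerank-cefr <CEFR level> \
#       --rerank-model <path/to/complexity/model/dir> \
#       --penalty-stddev <stddev>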