import math
import os
import json
from typing import List, Optional, Set, Union

import numpy as np
import torch
from parlai.core.dict import DictionaryAgent
from transformers import RobertaForSequenceClassification, RobertaTokenizer


def cefr_to_int(cefr: str) -> int:
    """
    Map a CEFR level label ("A1" through "C2") to an integer in [0, 5].
    """
    mapping = {
        "A1": 0,
        "A2": 1,
        "B1": 2,
        "B2": 3,
        "C1": 4,
        "C2": 5,
    }
    clean_cefr = cefr.upper().strip()
    assert clean_cefr in mapping, f"CEFR must be one of {list(mapping.keys())}, not {cefr}"

    return mapping[clean_cefr]


def load_wordlist(path: str) -> List[str]:
    """
    Load a list of words from a text file containing one word per line.
    """
    vocab = []

    if not path:
        return vocab

    assert os.path.isfile(path), f"Wordlist file not found: {path}"

    with open(path, 'r', encoding="utf-8") as vocab_file:
        for row in vocab_file:
            token = row.strip()
            vocab.append(token)

    return vocab


class Wordlist:
    """
    Restricts which token IDs may be generated next so that the output stays within a
    list of allowed words.
    """

    def __init__(self, allowed_words: List[str], dict_agent: DictionaryAgent):
        self.dict_agent = dict_agent

        # Split the vocabulary into tokens that start a new word (leading "\u0120",
        # i.e. "Ġ", or non-alphabetic tokens such as punctuation) and tokens that can
        # only continue the current word.
        self.boundary_ids = []
        self.non_boundary_ids = []

        for idx, subtoken in dict_agent.ind2tok.items():
            if subtoken[0] == "\u0120" or not subtoken.isalpha():
                self.boundary_ids.append(idx)
            else:
                self.non_boundary_ids.append(idx)

        # allowed_sequences: repr() of every token ID sequence that spells out a
        # complete allowed word (including its casing variants).
        # allowed_continuations: maps the repr() of each proper prefix of such a
        # sequence to the token IDs that may follow it.
        self.allowed_sequences = []
        self.allowed_continuations = {}
        for word in allowed_words:
            for word_variant in self._get_word_variants(word):
                token_ids = dict_agent.txt2vec(word_variant)
                self.allowed_sequences.append(repr(token_ids))

                for i, idx in enumerate(token_ids[1:]):
                    prefix = repr(token_ids[:i + 1])
                    if prefix not in self.allowed_continuations:
                        self.allowed_continuations[prefix] = []
                    self.allowed_continuations[prefix].append(idx)

        self.allowed_sequences = set(self.allowed_sequences)

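    # Illustrative sketch (hypothetical token IDs, not taken from the real vocabulary):
    # if an allowed word "hello" tokenizes to [5891, 1021], then after __init__:
    #   "[5891, 1021]" is in self.allowed_sequences
    #   self.allowed_continuations["[5891]"] contains 1021
    # so a partially generated word can only be continued into some allowed word, and
    # a completed word may be followed by any boundary token.
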
    def get_allowed_ids(self, token_ids: List[int]) -> List[int]:
        """
        For a partially generated sequence of token IDs, return the token IDs that
        are allowed to come next.
        """
        last_word = self._get_last_word(token_ids)
        continuation_ids = self._get_continuation_ids(last_word)

        return continuation_ids

    def _is_word(self, token_ids: List[int]) -> bool:
        """
        For a given sequence of token IDs, determine whether that sequence is a complete word.
        """
        return token_ids == [] or repr(token_ids) in self.allowed_sequences

    def _get_continuation_ids(self, token_ids: List[int]) -> List[int]:
        """
        For a given sequence of last-word token IDs, determine which token IDs the word can continue with.
        """
        continuation_ids = []
        if repr(token_ids) in self.allowed_continuations:
            continuation_ids.extend(self.allowed_continuations[repr(token_ids)])

        # A complete word (or an empty last word) may also be followed by any
        # word-boundary token, which starts a new word.
        if self._is_word(token_ids):
            continuation_ids.extend(self.boundary_ids)

        return continuation_ids

    def _get_last_word(self, token_ids: List[int]) -> List[int]:
        """
        Get the sequence of token IDs after the last word boundary.

        Assumes that a word boundary is denoted by punctuation or whitespace (Ġ).
        """
        # Walk backwards until a boundary token is found; the range is inclusive of
        # the first token so that a boundary at position 0 is also detected.
        for i in range(-1, -len(token_ids) - 1, -1):
            last_word = token_ids[i:]
            check_token = self.dict_agent[last_word[0]]

            # Punctuation (or any non-alphabetic token) ends the previous word; the
            # last word is whatever follows it.
            if not check_token.isalpha():
                return last_word[1:]

            # A token starting with "Ġ" begins a new word and is part of it.
            if check_token[0] == "Ġ":
                return last_word

        raise ValueError("Boundary token not found")

    def _get_word_variants(self, word: str) -> Set[str]:
        """
        Return the casing variants of a word that should all be treated as allowed.
        """
        return {word, word.lower(), word.capitalize()}

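# Usage sketch (illustrative only; the path and variable names below are assumptions,
# not part of this module). `dict_agent` is a ParlAI DictionaryAgent over a GPT-2
# style BPE vocabulary:
#
#   allowed_words = load_wordlist("data/allowed_words.txt")
#   wordlist = Wordlist(allowed_words, dict_agent)
#   # At each decoding step, restrict the generator's next-token choices to the IDs
#   # returned for the hypothesis generated so far:
#   allowed_next_ids = wordlist.get_allowed_ids(generated_ids)

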
class Reranker:
    """
    Scores generated hypotheses for their closeness to a target CEFR level with a
    RoBERTa-based complexity regression model, and builds per-token penalty weights
    that down-weight vocabulary items above the target level.
    """

    def __init__(self,
                 cefr: int,
                 model: str,
                 tokenizer: str = "distilroberta-base",
                 device: Optional[str] = "cuda",
                 text_truncate: int = 128,
                 exempt_tokens: Union[str, List[int]] = "all",
                 penalty_stddev: int = 2,
                 vocab_size: int = 8008,
                 word_filter: Optional[List[str]] = None):

        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer)
        self.model = RobertaForSequenceClassification.from_pretrained(model)
        self.model.to(device)
        self.device = device

        self.target_cefr = cefr
        self.text_truncate = text_truncate
        self.word_filter = word_filter

        # CEFR level of each token in the generator's vocabulary, keyed by token ID
        # (as a string) and mapping to a (token string, CEFR level) pair.
        cefr_filepath = os.path.join(os.path.dirname(__file__), 'tokens_by_cefr.json')
        with open(cefr_filepath, 'r') as cefr_file:
            token_cefrs = json.load(cefr_file)

        if exempt_tokens == "all" or penalty_stddev < 0:
            # No penalisation: every token keeps a neutral weight of 1.
            self.token_penalties = torch.tensor([[1] * vocab_size])
        else:
            # Penalise tokens above the target CEFR level with a Gaussian falloff:
            # the further a token's level is above the target, the smaller its weight.
            normal_dist = torch.distributions.normal.Normal(0, penalty_stddev)
            cefr_penalties = [math.exp(normal_dist.log_prob(torch.tensor(i))) for i in range(6)]

            token_penalties = []
            for i in range(vocab_size):
                if i in exempt_tokens:
                    # Explicitly exempted tokens are never penalised.
                    token_penalties.append(cefr_penalties[0])

                elif str(i) in token_cefrs:
                    token_str, token_cefr = token_cefrs[str(i)]
                    if token_cefr <= self.target_cefr or not token_str.isalpha():
                        # Tokens at or below the target level, and non-alphabetic
                        # tokens such as punctuation, are not penalised.
                        penalty = cefr_penalties[0]
                    else:
                        penalty = cefr_penalties[int(token_cefr - self.target_cefr)]
                    token_penalties.append(penalty)

                else:
                    # Tokens with no known CEFR level get the harshest penalty.
                    token_penalties.append(cefr_penalties[int(5 - self.target_cefr)])

            self.token_penalties = torch.tensor([token_penalties])

    def get_complexity_scores(self, hyps: List[str]) -> np.ndarray:
        """
        Score each hypothesis by how close its predicted CEFR complexity is to the
        target level (higher is closer).
        """
        model_inputs = self.tokenizer(hyps,
                                      padding='max_length',
                                      truncation=True,
                                      max_length=self.text_truncate,
                                      return_tensors='pt',
                                      return_token_type_ids=True,
                                      return_attention_mask=True)

        # Inference only: no gradients are needed for reranking.
        with torch.no_grad():
            model_output = self.model(input_ids=model_inputs["input_ids"].to(self.device),
                                      attention_mask=model_inputs["attention_mask"].to(self.device),
                                      token_type_ids=model_inputs["token_type_ids"].to(self.device))

        # The regression head predicts a CEFR score per hypothesis; convert it into a
        # similarity to the target level so that larger values are better.
        complexity_scores = model_output.logits.cpu().numpy().flatten()
        complexity_diffs = 5 - np.absolute(complexity_scores - self.target_cefr)

        return complexity_diffs
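

# Usage sketch (illustrative only; the checkpoint path and the way scores are combined
# with beam scores are assumptions, not part of this module):
#
#   reranker = Reranker(cefr=cefr_to_int("B1"),
#                       model="path/to/cefr-regression-checkpoint",
#                       device="cpu")
#   scores = reranker.get_complexity_scores(["How are you today?",
#                                            "What do you make of the prognosis?"])
#   # Higher scores mean a hypothesis is closer to the target CEFR level; they can be
#   # combined with the generator's beam scores to pick the final response, and
#   # reranker.token_penalties can be used to down-weight over-complex tokens during
#   # decoding.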