qianhuiwu's picture
Initial commit.
cdba444
raw history blame
No virus
3.16 kB
import torch
from torch.utils.data import Dataset
import random, os
import numpy as np
import torch
import string
class TokenClfDataset(Dataset):
    """Turn raw texts into fixed-length token-id / attention-mask tensors.

    Each item is tokenized with ``tokenizer``, wrapped in the model's
    classifier and separator special tokens, then truncated or right-padded
    to exactly ``max_len`` positions.

    Args:
        texts: Indexable sequence of strings (``len(texts)`` is taken once).
        max_len: Length every returned tensor is truncated/padded to.
        tokenizer: Object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)`` (e.g. a Hugging Face
            tokenizer). Needed by ``__getitem__`` despite the ``None`` default.
        model_name: Selects the special-token strings; must contain
            ``"bert-base-multilingual-cased"`` or ``"xlm-roberta-large"``.

    Raises:
        NotImplementedError: If ``model_name`` matches neither supported model.
    """

    def __init__(
        self,
        texts,
        max_len=512,
        tokenizer=None,
        model_name="bert-base-multilingual-cased",
    ):
        self.len = len(texts)
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.model_name = model_name
        if "bert-base-multilingual-cased" in model_name:
            self.cls_token = "[CLS]"
            self.sep_token = "[SEP]"
            self.unk_token = "[UNK]"
            self.pad_token = "[PAD]"
            self.mask_token = "[MASK]"
        elif "xlm-roberta-large" in model_name:
            self.bos_token = "<s>"
            self.eos_token = "</s>"
            self.sep_token = "</s>"
            self.cls_token = "<s>"
            self.unk_token = "<unk>"
            self.pad_token = "<pad>"
            self.mask_token = "<mask>"
        else:
            raise NotImplementedError()

    def __getitem__(self, index):
        """Return ``{"ids", "mask"}`` long tensors of length ``max_len`` for ``texts[index]``.

        ``mask`` is 1 for the special tokens and text tokens and 0 for padding.
        """
        text = self.texts[index]
        # Reserve two slots for the special tokens so truncation keeps the
        # trailing separator instead of cutting it off.
        budget = max(self.max_len - 2, 0)
        tokens = (
            [self.cls_token] + self.tokenizer.tokenize(text)[:budget] + [self.sep_token]
        )
        # Only has an effect when max_len < 2: never exceed the requested length.
        tokens = tokens[: self.max_len]
        n_real = len(tokens)
        n_pad = self.max_len - n_real
        tokens = tokens + [self.pad_token] * n_pad
        # Build the mask from positions rather than token strings, so a text
        # token that happens to equal the pad-token string is not masked out.
        attn_mask = [1] * n_real + [0] * n_pad
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(attn_mask, dtype=torch.long),
        }

    def __len__(self):
        """Return the number of texts captured at construction time."""
        return self.len
def seed_everything(seed: int):
    """Seed the random generators used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global generator and PyTorch's CPU
    and CUDA generators with ``seed``. It also exports ``PYTHONHASHSEED``
    and disables cuDNN benchmarking in favour of deterministic kernels.

    Note: ``PYTHONHASHSEED`` is read when an interpreter starts, so setting
    it here affects child processes rather than the running one.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def is_begin_of_new_word(token, model_name, force_tokens, token_map):
    """Decide whether ``token`` starts a new word instead of continuing one.

    Args:
        token: A single tokenizer token string.
        model_name: Must contain ``"bert-base-multilingual-cased"`` or
            ``"xlm-roberta-large"``; selects the subword convention.
        force_tokens: Container of tokens that always start a new word.
        token_map: Mapping whose *values* are tokens that also always start
            a new word.

    Returns:
        ``True`` if ``token`` begins a new word, else ``False``.

    Raises:
        NotImplementedError: If ``model_name`` matches neither supported model.
    """
    if "bert-base-multilingual-cased" in model_name:
        # Remove exactly one "##" marker. ``str.lstrip("##")`` strips every
        # leading "#", which turned the continuation "###" (char "#") into "".
        pure = token[2:] if token.startswith("##") else token
        if pure in force_tokens or pure in token_map.values():
            return True
        return not token.startswith("##")
    elif "xlm-roberta-large" in model_name:
        # ``token in string.punctuation`` is a substring test ("()" matched
        # but "))" did not); check every character instead.
        is_punctuation = all(ch in string.punctuation for ch in token)
        if is_punctuation or token in force_tokens or token in token_map.values():
            return True
        return token.startswith("▁")
    else:
        raise NotImplementedError()
def replace_added_token(token, token_map):
    """Undo placeholder substitution inside ``token``.

    ``token_map`` maps each original string to the placeholder standing in
    for it. Every occurrence of each placeholder is replaced by its original.
    Mappings are applied one at a time in dict iteration order, so later
    mappings also see text produced by earlier ones.
    """
    restored = token
    for original, placeholder in token_map.items():
        restored = restored.replace(placeholder, original)
    return restored
def get_pure_token(token, model_name):
    """Return ``token`` with its subword / word-boundary marker removed.

    For multilingual BERT the marker is a leading ``"##"`` (continuation
    piece); for XLM-RoBERTa it is a leading ``"▁"``. At most one marker is
    removed; tokens without it are returned unchanged.

    Raises:
        NotImplementedError: If ``model_name`` matches neither supported model.
    """
    if "bert-base-multilingual-cased" in model_name:
        marker = "##"
    elif "xlm-roberta-large" in model_name:
        marker = "▁"
    else:
        raise NotImplementedError()
    # ``str.lstrip`` treats its argument as a character set and would also
    # strip genuine leading characters (e.g. "#" -> "", "###" -> "").
    return token[len(marker):] if token.startswith(marker) else token