# Spaces: Runtime error
import os
import random
import string

import numpy as np
import torch
from torch.utils.data import Dataset
class TokenClfDataset(Dataset):  # Custom dataset for token classification
    """Dataset that turns raw texts into fixed-length token-id tensors.

    Each item is a dict with:

    * ``"ids"``  - LongTensor of length ``max_len``: ``[cls] + tokens + [sep]``
      followed by pad-token ids. Over-long texts are truncated *before* the
      special tokens are added, so the closing separator is always kept
      (as long as ``max_len >= 2``).
    * ``"mask"`` - LongTensor of length ``max_len``: 1 for real positions,
      0 for padding.

    Args:
        texts: indexable collection of strings.
        max_len: total sequence length including special tokens
            (the original author notes 256 for PhoBERT, 512 for XLM-RoBERTa).
        tokenizer: object exposing ``tokenize(str) -> list[str]`` and
            ``convert_tokens_to_ids(list[str]) -> list[int]``; it must be set
            before items are accessed.
        model_name: selects the special-token strings. Names containing
            ``"m_bert"`` use BERT-style tokens; names containing
            ``"xlm-roberta"`` or ``"phobert"`` use RoBERTa-style tokens.
            Any other name falls back to the tokenizer's own
            ``cls_token`` / ``sep_token`` / ``pad_token`` (etc.) attributes.

    Raises:
        NotImplementedError: if ``model_name`` is unrecognised and the
            tokenizer does not provide ``cls_token``, ``sep_token`` and
            ``pad_token``.
    """

    # Special-token strings for WordPiece (multilingual BERT) tokenizers.
    _BERT_SPECIAL_TOKENS = {
        "cls_token": "[CLS]",
        "sep_token": "[SEP]",
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "mask_token": "[MASK]",
    }
    # Special-token strings shared by XLM-RoBERTa (base/large) and PhoBERT.
    _ROBERTA_SPECIAL_TOKENS = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "sep_token": "</s>",
        "cls_token": "<s>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
        "mask_token": "<mask>",
    }
    # Attributes probed on the tokenizer when model_name is not recognised.
    _SPECIAL_TOKEN_NAMES = (
        "bos_token",
        "eos_token",
        "sep_token",
        "cls_token",
        "unk_token",
        "pad_token",
        "mask_token",
    )
    # Special tokens that __getitem__ cannot work without.
    _REQUIRED_TOKEN_NAMES = ("cls_token", "sep_token", "pad_token")

    def __init__(
        self,
        texts,
        max_len=512,  # 256 (phobert) 512 (xlm-roberta)
        tokenizer=None,
        model_name="m_bert",
    ):
        self.len = len(texts)
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.model_name = model_name
        special = self._resolve_special_tokens(model_name, tokenizer)
        for attr, value in special.items():
            setattr(self, attr, value)

    @classmethod
    def _resolve_special_tokens(cls, model_name, tokenizer):
        """Return ``{attribute_name: token_string}`` for *model_name*."""
        if "m_bert" in model_name:
            return dict(cls._BERT_SPECIAL_TOKENS)
        # "xlm-roberta" also matches "xlm-roberta-large"; both used identical tokens.
        if "xlm-roberta" in model_name or "phobert" in model_name:
            return dict(cls._ROBERTA_SPECIAL_TOKENS)
        # Unrecognised family: borrow the tokenizer's own special tokens,
        # skipping attributes that are missing or None.
        found = {}
        for name in cls._SPECIAL_TOKEN_NAMES:
            value = getattr(tokenizer, name, None)
            if value is not None:
                found[name] = value
        missing = [name for name in cls._REQUIRED_TOKEN_NAMES if name not in found]
        if missing:
            raise NotImplementedError(
                f"Unsupported model_name {model_name!r}: tokenizer does not "
                f"provide {', '.join(missing)}"
            )
        return found

    def __getitem__(self, index):
        pieces = self.tokenizer.tokenize(self.texts[index])
        # Reserve two slots for the special tokens so truncation never drops
        # the closing separator.
        pieces = pieces[: max(self.max_len - 2, 0)]
        # The outer slice only matters for the degenerate case max_len < 2.
        tokens = ([self.cls_token] + pieces + [self.sep_token])[: self.max_len]
        n_real = len(tokens)
        n_pad = self.max_len - n_real
        tokens += [self.pad_token] * n_pad
        # Mask by position, not by string comparison, so a literal pad string
        # occurring inside the text is not mistaken for padding.
        attn_mask = [1] * n_real + [0] * n_pad
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(attn_mask, dtype=torch.long),
        }

    def __len__(self):
        return self.len
def seed_everything(seed: int):
    """Seed the RNGs used by this module and put cuDNN in deterministic mode.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (via
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``), stores the seed in
    the ``PYTHONHASHSEED`` environment variable, and disables cuDNN
    benchmarking while enabling its deterministic algorithms.

    Note: ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it
    here only affects child processes launched afterwards.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
def is_begin_of_new_word(token, model_name, force_tokens, token_map):
    """Decide whether *token* starts a new word instead of continuing the previous one.

    The rule follows the subword convention of the tokenizer family selected
    by *model_name* (substring match, checked in this order):

    * ``"m_bert"`` - a token is a continuation iff it starts with ``"##"``.
    * ``"xlm-roberta"`` (also matches ``"xlm-roberta-large"``) - a token starts
      a word iff it starts with ``"▁"`` (U+2581, not an ASCII underscore).
    * ``"phobert"`` - a token is treated as a word start iff it does not end
      with ``"@@"``.

    Regardless of the convention, a token that appears in *force_tokens* or
    among the values of *token_map* always starts a new word. For BERT, the
    ``"##"`` prefix is removed before that lookup. For the XLM-R and PhoBERT
    families, a token for which ``token in string.punctuation`` holds also
    always starts a new word.

    Args:
        token: a single token string produced by the tokenizer.
        model_name: model identifier used to pick the convention.
        force_tokens: container of token strings that must stand alone.
        token_map: dict whose *values* are placeholder tokens that must stand alone.

    Returns:
        ``True``/``False`` for the supported families, ``None`` for any other
        *model_name* (the ``NotImplementedError`` branch is intentionally disabled).
    """
    # A dict view supports membership directly; no need to build a set per call.
    placeholders = token_map.values()

    if "m_bert" in model_name:
        # Remove exactly one "##" prefix. str.lstrip("##") strips *every*
        # leading "#", which would turn the continuation piece "###" into "".
        bare = token[2:] if token.startswith("##") else token
        if bare in force_tokens or bare in placeholders:
            return True
        return not token.startswith("##")

    if "xlm-roberta" in model_name or "phobert" in model_name:
        # NOTE(review): ``in string.punctuation`` is a *substring* test on a str,
        # so runs such as "()" or "<=>" (and the empty string) also match,
        # while ")," or "..." do not. Confirm this is the intended rule.
        if token in string.punctuation or token in force_tokens or token in placeholders:
            return True
        if "xlm-roberta" in model_name:
            # SentencePiece marks word-initial pieces with a leading "▁".
            return token.startswith("▁")
        # NOTE(review): in PhoBERT/fastBPE vocabularies "@@" conventionally marks a
        # piece that is continued by the *next* token. If so, this test identifies
        # the last piece of a word rather than its first, and a faithful check
        # would need the previous token. Confirm against real tokenizer output.
        return not token.endswith("@@")

    # Unsupported model family.
    return None
def replace_added_token(token, token_map):
    """Turn placeholder substrings in *token* back into their original text.

    *token_map* maps original text -> placeholder. Every occurrence of each
    placeholder is replaced by its original text. Pairs are applied one after
    another in the dict's iteration order, so each replacement operates on the
    result of the previous ones.
    """
    restored = token
    for original, placeholder in token_map.items():
        restored = restored.replace(placeholder, original)
    return restored
def get_pure_token(token, model_name):
    """Return *token* with its tokenizer-specific subword marker removed.

    * ``"m_bert"``: drop one leading ``"##"``.
    * ``"xlm-roberta"`` (also matches ``"xlm-roberta-large"``): drop leading
      ``"▁"`` characters (U+2581, not an ASCII underscore).
    * ``"phobert"``: drop one trailing ``"@@"``.
    * any other *model_name*: no marker is known, so *token* is returned as-is.

    Args:
        token: a single token string produced by the tokenizer.
        model_name: model identifier used to pick the convention (substring
            match, checked in the order above).

    Returns:
        The token text without its subword marker.
    """
    if "m_bert" in model_name:
        # Slice instead of lstrip("##"): lstrip treats its argument as a set of
        # characters and would turn the piece "###" into "" instead of "#".
        return token[2:] if token.startswith("##") else token
    if "xlm-roberta" in model_name:
        return token.lstrip("▁")
    if "phobert" in model_name:
        # Same reasoning as above: rstrip("@@") would also eat a literal trailing "@".
        return token[:-2] if token.endswith("@@") else token
    return token