import os
import random
import string

import numpy as np
import torch
from torch.utils.data import Dataset


class TokenClfDataset(Dataset):
    """Custom dataset for token classification: tokenizes raw texts and
    pads/truncates them to a fixed length."""

    def __init__(
        self,
        texts,
        max_len=512,  # 256 for PhoBERT, 512 for XLM-RoBERTa
        tokenizer=None,
        model_name="m_bert",
    ):
        self.len = len(texts)
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.model_name = model_name
        if "m_bert" in model_name:
            self.cls_token = "[CLS]"
            self.sep_token = "[SEP]"
            self.unk_token = "[UNK]"
            self.pad_token = "[PAD]"
            self.mask_token = "[MASK]"
        elif "xlm-roberta" in model_name:  # covers both base and large
            self.bos_token = "<s>"
            self.eos_token = "</s>"
            self.sep_token = "</s>"
            self.cls_token = "<s>"
            self.unk_token = "<unk>"
            self.pad_token = "<pad>"
            self.mask_token = "<mask>"
        elif "phobert" in model_name:
            # PhoBERT uses the same RoBERTa-style special tokens.
            self.bos_token = "<s>"
            self.eos_token = "</s>"
            self.sep_token = "</s>"
            self.cls_token = "<s>"
            self.unk_token = "<unk>"
            self.pad_token = "<pad>"
            self.mask_token = "<mask>"
        else:
            raise NotImplementedError()

    def __getitem__(self, index):
        text = self.texts[index]
        tokenized_text = self.tokenizer.tokenize(text)

        tokenized_text = (
            [self.cls_token] + tokenized_text + [self.sep_token]
        )  # add special tokens

        if len(tokenized_text) > self.max_len:
            # Note: truncation may drop the trailing sep_token.
            tokenized_text = tokenized_text[: self.max_len]
        else:
            tokenized_text = tokenized_text + [
                self.pad_token for _ in range(self.max_len - len(tokenized_text))
            ]

        attn_mask = [1 if tok != self.pad_token else 0 for tok in tokenized_text]

        ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(attn_mask, dtype=torch.long),
        }

    def __len__(self):
        return self.len


def seed_everything(seed: int):
    """Seed every RNG used in training for reproducible runs."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
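
# --- Usage sketch (illustrative, not part of the original pipeline) ---
# The function below shows how TokenClfDataset is expected to be driven by a
# HuggingFace tokenizer and a DataLoader. The checkpoint name
# "bert-base-multilingual-cased" and the batch size are assumptions chosen
# only to make the example self-contained.
def _demo_dataloader():
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    seed_everything(42)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    dataset = TokenClfDataset(
        ["xin chào thế giới"],  # a Vietnamese sample, matching the pipeline's target language
        max_len=512,
        tokenizer=tokenizer,
        model_name="m_bert",
    )
    loader = DataLoader(dataset, batch_size=8)
    batch = next(iter(loader))
    print(batch["ids"].shape, batch["mask"].shape)  # torch.Size([1, 512]) for both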
def is_begin_of_new_word(token, model_name, force_tokens, token_map):
    """Check whether `token` starts a new word under the model's subword scheme."""
    if "m_bert" in model_name:
        if token.lstrip("##") in force_tokens or token.lstrip("##") in set(
            token_map.values()
        ):
            return True
        return not token.startswith("##")  # WordPiece marks continuations with "##"
    elif "xlm-roberta" in model_name:  # covers both base and large
        if (
            token in string.punctuation
            or token in force_tokens
            or token in set(token_map.values())
        ):
            return True
        return token.startswith("▁")  # SentencePiece marks word starts with "▁"
    elif "phobert" in model_name:
        if (
            token in string.punctuation
            or token in force_tokens
            or token in set(token_map.values())
        ):
            return True
        # PhoBERT's BPE appends "@@" to non-final subword pieces, so a token
        # without that trailing marker is treated as beginning a new word.
        return not token.endswith("@@")
    else:
        raise NotImplementedError()


def replace_added_token(token, token_map):
    """Map tokenizer-added placeholder tokens back to their original strings."""
    for ori_token, new_token in token_map.items():
        token = token.replace(new_token, ori_token)
    return token


def get_pure_token(token, model_name):
    """Return the bare token with the model-specific subword marker removed."""
    if "m_bert" in model_name:
        return token.lstrip("##")  # strip the leading WordPiece marker
    elif "xlm-roberta" in model_name:  # covers both base and large
        return token.lstrip("▁")  # strip the leading SentencePiece marker
    elif "phobert" in model_name:
        return token.rstrip("@@")  # strip the trailing BPE continuation marker
    else:
        raise NotImplementedError()
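
# --- Composition sketch (illustrative assumption, not the original code) ---
# Shows how the three helpers above can be combined to rebuild whole words
# from mBERT WordPiece output. `force_tokens` (tokens that must be kept as-is)
# and `token_map` (original token -> tokenizer-added placeholder, as implied
# by replace_added_token) are left empty here for simplicity.
def _demo_merge_subwords():
    force_tokens = []
    token_map = {}
    tokens = ["un", "##believ", "##able", "results"]  # WordPiece pieces
    words = []
    for token in tokens:
        token = replace_added_token(token, token_map)
        if is_begin_of_new_word(token, "m_bert", force_tokens, token_map) or not words:
            words.append(get_pure_token(token, "m_bert"))
        else:
            words[-1] += get_pure_token(token, "m_bert")
    print(words)  # ['unbelievable', 'results']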