'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-19 15:39:48
Description: All classes for loading data.
'''
import json
import os

import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader

from common.utils import InputData

ABS_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../")


class DataFactory(object):
    def __init__(self, tokenizer, use_multi_intent=False, to_lower_case=True):
        """Tokenize raw SLU data and align slot/intent labels to sub-tokens.

        Args:
            tokenizer (Tokenizer): tokenizer used to encode the input text.
            use_multi_intent (bool, optional): split "#"-joined multi-intent labels. Defaults to False.
            to_lower_case (bool, optional): lower-case the input text before tokenization. Defaults to True.
        """
        self.tokenizer = tokenizer
        self.slot_label_list = []
        self.intent_label_list = []
        self.use_multi = use_multi_intent
        self.to_lower_case = to_lower_case
        self.slot_label_dict = None
        self.intent_label_dict = None

    def __is_supported_datasets(self, dataset_name: str) -> bool:
        return dataset_name.lower() in ["atis", "snips", "mix-atis", "mix-snips"]

    def load_dataset(self, dataset_config, split="train"):
        dataset_name = None
        if split not in dataset_config:
            dataset_name = dataset_config.get("dataset_name")
        elif self.__is_supported_datasets(dataset_config[split]):
            dataset_name = dataset_config[split].lower()
        if dataset_name is not None:
            return load_dataset("LightChen2333/OpenSLU", dataset_name, split=split)
        else:
            # Fall back to a local JSON-lines file with "text", "slot" and "intent" fields.
            data_file = dataset_config[split]
            data_dict = {"text": [], "slot": [], "intent": []}
            with open(data_file, encoding="utf-8") as f:
                for line in f:
                    row = json.loads(line)
                    data_dict["text"].append(row["text"])
                    data_dict["slot"].append(row["slot"])
                    data_dict["intent"].append(row["intent"])
            return Dataset.from_dict(data_dict)

    def update_label_names(self, dataset):
        for intent_labels in dataset["intent"]:
            if self.use_multi:
                intent_label = intent_labels.split("#")
            else:
                intent_label = [intent_labels]
            for x in intent_label:
                if x not in self.intent_label_list:
                    self.intent_label_list.append(x)
        for slot_label in dataset["slot"]:
            for x in slot_label:
                if x not in self.slot_label_list:
                    self.slot_label_list.append(x)
        self.intent_label_dict = {key: index for index, key in enumerate(self.intent_label_list)}
        self.slot_label_dict = {key: index for index, key in enumerate(self.slot_label_list)}

    def update_vocabulary(self, dataset):
        if self.tokenizer.name_or_path in ["word_tokenizer"]:
            for data in dataset:
                self.tokenizer.add_instance(data["text"])

    @staticmethod
    def fast_align_data(text, padding_side="right"):
        for i in range(len(text.input_ids)):
            desired_output = []
            for word_id in text.word_ids(i):
                if word_id is not None:
                    start, end = text.word_to_tokens(
                        i, word_id, sequence_index=0 if padding_side == "right" else 1)
                    if start == end - 1:
                        tokens = [start]
                    else:
                        tokens = [start, end - 1]
                    if len(desired_output) == 0 or desired_output[-1] != tokens:
                        desired_output.append(tokens)
            yield desired_output
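    # NOTE (illustrative sketch, not part of the original code): with a fast
    # Hugging Face tokenizer, fast_align_data yields one list per sentence that
    # maps each whitespace word to its sub-token span. Assuming "listen to westbam"
    # tokenizes as ["[CLS]", "listen", "to", "west", "##bam", "[SEP]"], the generator
    # would yield [[1], [2], [3, 4]]; only the first sub-token of each span (e.g.
    # index 3 for "westbam") later receives the slot label in fast_align().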
    def fast_align(self,
                   batch,
                   ignore_index=-100,
                   device="cuda",
                   config=None,
                   enable_label=True,
                   label2tensor=True):
        if self.to_lower_case:
            input_list = [[t.lower() for t in x["text"]] for x in batch]
        else:
            input_list = [x["text"] for x in batch]
        text = self.tokenizer(input_list,
                              return_tensors="pt",
                              padding=True,
                              is_split_into_words=True,
                              truncation=True,
                              **config).to(device)
        if enable_label:
            if label2tensor:
                # Label only the first sub-token of every word; all other positions keep ignore_index.
                slot_mask = torch.ones_like(text.input_ids) * ignore_index
                for i, offsets in enumerate(
                        DataFactory.fast_align_data(text, padding_side=self.tokenizer.padding_side)):
                    num = 0
                    assert len(offsets) == len(batch[i]["text"])
                    assert len(offsets) == len(batch[i]["slot"])
                    for off in offsets:
                        slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
                        num += 1
                slot = slot_mask.clone()
                # Propagate the previous label to still-unlabeled, non-special positions.
                attention_id = 0 if self.tokenizer.padding_side == "right" else 1
                for i, slot_batch in enumerate(slot):
                    for j, x in enumerate(slot_batch):
                        if x == ignore_index and text.attention_mask[i][j] == attention_id and (
                                text.input_ids[i][j] not in self.tokenizer.all_special_ids
                                or text.input_ids[i][j] == self.tokenizer.unk_token_id):
                            slot[i][j] = slot[i][j - 1]
                slot = slot.to(device)
                if not self.use_multi:
                    intent = torch.tensor(
                        [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
                else:
                    one_hot = torch.zeros(
                        (len(batch), len(self.intent_label_list)), dtype=torch.float)
                    for index, b in enumerate(batch):
                        for x in b["intent"].split("#"):
                            one_hot[index][self.intent_label_dict[x]] = 1.
                    intent = one_hot.to(device)
            else:
                slot_mask = None
                slot = [['#' for _ in range(text.input_ids.shape[1])]
                        for _ in range(text.input_ids.shape[0])]
                for i, offsets in enumerate(DataFactory.fast_align_data(text)):
                    num = 0
                    for off in offsets:
                        slot[i][off[0]] = batch[i]["slot"][num]
                        num += 1
                if not self.use_multi:
                    intent = [x["intent"] for x in batch]
                else:
                    intent = [[x for x in b["intent"].split("#")] for b in batch]
            return InputData((text, slot, intent))
        else:
            return InputData((text, None, None))

    def general_align_data(self, split_text_list, raw_text_list, encoded_text):
        for i in range(len(split_text_list)):
            desired_output = []
            jdx = 0
            offset = encoded_text.offset_mapping[i].tolist()
            split_texts = split_text_list[i]
            raw_text = raw_text_list[i]
            last = 0
            temp_offset = []
            # Drop overlapping character spans, keeping the shorter one in place.
            for off in offset:
                s, e = off
                if len(temp_offset) > 0 and (e != 0 and last == s):
                    len_1 = off[1] - off[0]
                    len_2 = temp_offset[-1][1] - temp_offset[-1][0]
                    if len_1 > len_2:
                        temp_offset.pop(-1)
                        temp_offset.append([0, 0])
                        temp_offset.append(off)
                        continue
                temp_offset.append(off)
                last = s
            offset = temp_offset
            for split_text in split_texts:
                while jdx < len(offset) and offset[jdx][0] == 0 and offset[jdx][1] == 0:
                    jdx += 1
                if jdx == len(offset):
                    continue
                start_, end_ = offset[jdx]
                tokens = None
                if split_text == raw_text[start_:end_].strip():
                    tokens = [jdx]
                else:
                    # Compute "xxx" -> "xx" "#x"
                    temp_jdx = jdx
                    last_str = raw_text[start_:end_].strip()
                    while last_str != split_text and temp_jdx < len(offset) - 1:
                        temp_jdx += 1
                        last_str += raw_text[offset[temp_jdx][0]:offset[temp_jdx][1]].strip()
                    if temp_jdx == jdx:
                        raise ValueError("Illegal Input data")
                    elif last_str == split_text:
                        tokens = [jdx, temp_jdx]
                        jdx = temp_jdx
                    else:
                        jdx -= 1
                jdx += 1
                if tokens is not None:
                    desired_output.append(tokens)
            yield desired_output
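    # NOTE (illustrative sketch, not part of the original code): general_align_data
    # re-derives the word-to-token mapping from the tokenizer's offset_mapping when
    # the word_ids()/word_to_tokens() fast path is unavailable, e.g. when the text is
    # encoded as a single raw string. Assuming the raw string "new york" maps to
    # character offsets [(0, 0), (0, 3), (4, 8), (0, 0)], the generator would yield
    # [[1], [2]] for the two whitespace words "new" and "york".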
    def general_align(self,
                      batch,
                      ignore_index=-100,
                      device="cuda",
                      config=None,
                      enable_label=True,
                      label2tensor=True,
                      locale="en-US"):
        if self.to_lower_case:
            raw_data = [" ".join(x["text"]).lower() if locale not in ['ja-JP', 'zh-CN', 'zh-TW']
                        else "".join(x["text"]) for x in batch]
            input_list = [[t.lower() for t in x["text"]] for x in batch]
        else:
            input_list = [x["text"] for x in batch]
            raw_data = [" ".join(x["text"]) if locale not in ['ja-JP', 'zh-CN', 'zh-TW']
                        else "".join(x["text"]) for x in batch]
        text = self.tokenizer(raw_data,
                              return_tensors="pt",
                              padding=True,
                              truncation=True,
                              return_offsets_mapping=True,
                              **config).to(device)
        if enable_label:
            if label2tensor:
                slot_mask = torch.ones_like(text.input_ids) * ignore_index
                for i, offsets in enumerate(
                        self.general_align_data(input_list, raw_data, encoded_text=text)):
                    num = 0
                    # if len(offsets) != len(batch[i]["text"]) or len(offsets) != len(batch[i]["slot"]):
                    #     if
                    for off in offsets:
                        slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
                        num += 1
                # slot = slot_mask.clone()
                # attention_id = 0 if self.tokenizer.padding_side == "right" else 1
                # for i, slot_batch in enumerate(slot):
                #     for j, x in enumerate(slot_batch):
                #         if x == ignore_index and text.attention_mask[i][j] == attention_id and text.input_ids[i][
                #                 j] not in self.tokenizer.all_special_ids:
                #             slot[i][j] = slot[i][j - 1]
                slot = slot_mask.to(device)
                if not self.use_multi:
                    intent = torch.tensor(
                        [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
                else:
                    one_hot = torch.zeros(
                        (len(batch), len(self.intent_label_list)), dtype=torch.float)
                    for index, b in enumerate(batch):
                        for x in b["intent"].split("#"):
                            one_hot[index][self.intent_label_dict[x]] = 1.
                    intent = one_hot.to(device)
            else:
                slot_mask = None
                slot = [['#' for _ in range(text.input_ids.shape[1])]
                        for _ in range(text.input_ids.shape[0])]
                for i, offsets in enumerate(self.general_align_data(input_list, raw_data, encoded_text=text)):
                    num = 0
                    for off in offsets:
                        slot[i][off[0]] = batch[i]["slot"][num]
                        num += 1
                if not self.use_multi:
                    intent = [x["intent"] for x in batch]
                else:
                    intent = [[x for x in b["intent"].split("#")] for b in batch]
            return InputData((text, slot, intent))
        else:
            return InputData((text, None, None))

    def batch_fn(self,
                 batch,
                 ignore_index=-100,
                 device="cuda",
                 config=None,
                 align_mode="fast",
                 enable_label=True,
                 label2tensor=True):
        if align_mode == "fast":
            # try:
            return self.fast_align(batch,
                                   ignore_index=ignore_index,
                                   device=device,
                                   config=config,
                                   enable_label=enable_label,
                                   label2tensor=label2tensor)
            # except:
            #     return self.general_align(batch,
            #                               ignore_index=ignore_index,
            #                               device=device,
            #                               config=config,
            #                               enable_label=enable_label,
            #                               label2tensor=label2tensor)
        else:
            return self.general_align(batch,
                                      ignore_index=ignore_index,
                                      device=device,
                                      config=config,
                                      enable_label=enable_label,
                                      label2tensor=label2tensor)

    def get_data_loader(self,
                        dataset,
                        batch_size,
                        shuffle=False,
                        device="cuda",
                        enable_label=True,
                        align_mode="fast",
                        label2tensor=True,
                        **config):
        data_loader = DataLoader(dataset,
                                 shuffle=shuffle,
                                 batch_size=batch_size,
                                 collate_fn=lambda x: self.batch_fn(x,
                                                                    device=device,
                                                                    config=config,
                                                                    enable_label=enable_label,
                                                                    align_mode=align_mode,
                                                                    label2tensor=label2tensor))
        return data_loader
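

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a fast
# Hugging Face tokenizer and a local JSON-lines file whose rows contain
# "text", "slot" and "intent" fields; the model name and file path below are
# hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed model
    factory = DataFactory(tokenizer, use_multi_intent=False, to_lower_case=True)

    # "train" points at a local JSONL file, so load_dataset() takes the
    # file-reading branch instead of pulling from the Hugging Face Hub.
    dataset_config = {"train": "data/train.jsonl"}  # hypothetical path
    train_dataset = factory.load_dataset(dataset_config, split="train")

    # Build the label vocabularies before any batch is collated.
    factory.update_label_names(train_dataset)
    factory.update_vocabulary(train_dataset)

    loader = factory.get_data_loader(train_dataset,
                                     batch_size=16,
                                     shuffle=True,
                                     device="cpu",
                                     align_mode="fast")
    first_batch = next(iter(loader))  # an InputData of (text, slot, intent)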