import csv
import json

import torch
from transformers import BertTokenizer


class CNerTokenizer(BertTokenizer):
    """Character-level tokenizer for NER: emits one token per character.

    Chinese NER is typically done at character granularity, so WordPiece
    sub-word splitting is deliberately bypassed.
    """

    def __init__(self, vocab_file, do_lower_case=True):
        super().__init__(vocab_file=str(vocab_file), do_lower_case=do_lower_case)
        self.vocab_file = str(vocab_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Split *text* into single characters; out-of-vocab chars become [UNK]."""
        tokens = []
        for char in text:
            if self.do_lower_case:
                char = char.lower()
            tokens.append(char if char in self.vocab else '[UNK]')
        return tokens


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file into a list of rows."""
        with open(input_file, "r", encoding="utf-8-sig") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

    @classmethod
    def _read_text(cls, input_file):
        """Read a CoNLL-style file: one "token label" pair per line, blank
        lines (or -DOCSTART- markers) separate sentences.

        Returns a list of ``{"words": [...], "labels": [...]}`` dicts.
        """
        lines = []
        words, labels = [], []
        # NOTE(fix): encoding was previously unspecified (platform default);
        # this data is expected to be UTF-8.
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    if words:
                        lines.append({"words": words, "labels": labels})
                        words, labels = [], []
                else:
                    splits = line.split(" ")
                    words.append(splits[0])
                    if len(splits) > 1:
                        labels.append(splits[-1].replace("\n", ""))
                    else:
                        # Examples could have no label for mode = "test"
                        labels.append("O")
        if words:
            # Flush the final sentence when the file has no trailing blank line.
            lines.append({"words": words, "labels": labels})
        return lines

    @classmethod
    def _read_json(cls, input_file):
        """Read a JSON-lines file with ``text`` and an optional ``label``
        mapping of ``{entity_type: {mention: [[start, end], ...]}}``,
        converting each span to character-level BIOES tags.
        """
        lines = []
        with open(input_file, 'r', encoding='utf8') as f:
            for raw in f:
                item = json.loads(raw.strip())
                words = list(item['text'])
                labels = ['O'] * len(words)
                label_entities = item.get('label', None)
                if label_entities is not None:
                    for key, value in label_entities.items():
                        for sub_name, sub_index in value.items():
                            for start_index, end_index in sub_index:
                                # Spans are inclusive; the mention text must match.
                                assert ''.join(words[start_index:end_index + 1]) == sub_name
                                if start_index == end_index:
                                    labels[start_index] = 'S-' + key
                                else:
                                    # For a two-char span the interior slice is
                                    # empty, so B-/E- alone is produced — this
                                    # folds the original redundant length-2 branch.
                                    labels[start_index] = 'B-' + key
                                    labels[start_index + 1:end_index] = (
                                        ['I-' + key] * (end_index - start_index - 1))
                                    labels[end_index] = 'E-' + key
                lines.append({"words": words, "labels": labels})
        return lines


def get_entity_bios(seq, id2label, middle_prefix='I-'):
    """Gets entities from sequence.
    note: BIOS
    Args:
        seq (list): sequence of labels (strings or ids decoded via *id2label*).
    Returns:
        list: list of [chunk_type, chunk_start, chunk_end] (inclusive indices).
    Example:
        seq = ['B-PER', 'I-PER', 'O', 'S-LOC']
        get_entity_bios(seq)
        #output
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]; end == -1 means "not closed yet"
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            # Single-token entity: emit immediately.
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[2] = indx
            chunk[0] = tag.split('-')[1]
            chunks.append(chunk)
            # NOTE(fix): was reassigned to a *tuple* here; use a list for
            # consistency with every other reset of `chunk`.
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[0] = tag.split('-')[1]
        elif tag.startswith(middle_prefix) and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks


def get_entity_bio(seq, id2label, middle_prefix='I-'):
    """Gets entities from sequence.
    note: BIO
    Args:
        seq (list): sequence of labels (strings or ids decoded via *id2label*).
    Returns:
        list: list of [chunk_type, chunk_start, chunk_end] (inclusive indices).
    Example:
        seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
        get_entity_bio(seq)
        #output
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[0] = tag.split('-')[1]
            # In BIO a lone B- is already a complete one-token entity.
            chunk[2] = indx
            if indx == len(seq) - 1:
                chunks.append(chunk)
        elif tag.startswith(middle_prefix) and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks


def get_entity_bioes(seq, id2label, middle_prefix='I-'):
    """Gets entities from sequence.
    note: BIOES (same as BIOS decoding, but E- also closes a chunk).
    Args:
        seq (list): sequence of labels (strings or ids decoded via *id2label*).
    Returns:
        list: list of [chunk_type, chunk_start, chunk_end] (inclusive indices).
    Example:
        seq = ['B-PER', 'I-PER', 'O', 'S-LOC']
        get_entity_bioes(seq)
        #output
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[2] = indx
            chunk[0] = tag.split('-')[1]
            chunks.append(chunk)
            # NOTE(fix): was reassigned to a *tuple* here; use a list for
            # consistency with every other reset of `chunk`.
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[0] = tag.split('-')[1]
        elif (tag.startswith(middle_prefix) or tag.startswith("E-")) and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks


def get_entities(seq, id2label, markup='bio', middle_prefix='I-'):
    """Dispatch entity extraction by tagging scheme.

    :param seq: sequence of labels (strings or ids).
    :param id2label: mapping from label id to label string.
    :param markup: one of 'bio', 'bios', 'bioes'.
    :param middle_prefix: prefix of inside tags (usually 'I-' or 'M-').
    :return: list of [chunk_type, chunk_start, chunk_end].
    :raises ValueError: if *markup* is not a supported scheme.
    """
    # NOTE(fix): was an `assert`, which disappears under `python -O`;
    # validate explicitly instead.
    if markup not in ('bio', 'bios', 'bioes'):
        raise ValueError("markup must be one of 'bio', 'bios', 'bioes', got %r" % (markup,))
    if markup == 'bio':
        return get_entity_bio(seq, id2label, middle_prefix)
    elif markup == 'bios':
        return get_entity_bios(seq, id2label, middle_prefix)
    else:
        return get_entity_bioes(seq, id2label, middle_prefix)


def bert_extract_item(start_logits, end_logits):
    """Pair start/end position predictions from a span-based NER head.

    Works on batch element 0 only; position 0 and the last position
    (the [CLS]/[SEP] slots) are stripped before matching.  Each predicted
    start label is paired with the nearest following end position carrying
    the same label.

    :return: list of (label_id, start_index, end_index) tuples
             (indices relative to the stripped sequence).
    """
    S = []
    start_pred = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
    end_pred = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
    for i, s_l in enumerate(start_pred):
        if s_l == 0:  # 0 == non-entity
            continue
        for j, e_l in enumerate(end_pred[i:]):
            if s_l == e_l:
                S.append((s_l, i, i + j))
                break  # greedy: take the first matching end
    return S