# NER utilities (fengshen) — upstream commit 50f0fbb
import csv
import json
import torch
from transformers import BertTokenizer
class CNerTokenizer(BertTokenizer):
    """BertTokenizer variant that tokenizes strictly character-by-character.

    Useful for Chinese NER, where token/label alignment must be one label
    per input character rather than per WordPiece.
    """

    def __init__(self, vocab_file, do_lower_case=True):
        super().__init__(vocab_file=str(vocab_file), do_lower_case=do_lower_case)
        self.vocab_file = str(vocab_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Return one token per character of *text*; out-of-vocab chars become [UNK]."""
        chars = (ch.lower() for ch in text) if self.do_lower_case else iter(text)
        return [ch if ch in self.vocab else '[UNK]' for ch in chars]
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses implement the three abstract getters; the `_read_*` helpers
    parse the common on-disk formats (TSV, CoNLL-style text, JSON lines).
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file into a list of row lists."""
        # utf-8-sig transparently strips a BOM if present
        with open(input_file, "r", encoding="utf-8-sig") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

    @classmethod
    def _read_text(cls, input_file):
        """Read a CoNLL-style file: one "word label" pair per line, blank line
        (or -DOCSTART-) separates sentences.

        Returns a list of {"words": [...], "labels": [...]} dicts.
        """
        # FIX: was a @classmethod whose first parameter was named `self`;
        # renamed to `cls` for consistency with `_read_tsv`. Also opens with
        # an explicit utf-8 encoding (matching `_read_json`) instead of the
        # platform default.
        lines = []
        with open(input_file, 'r', encoding='utf-8') as f:
            words = []
            labels = []
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    if words:
                        lines.append({"words": words, "labels": labels})
                        words = []
                        labels = []
                else:
                    splits = line.split(" ")
                    words.append(splits[0])
                    if len(splits) > 1:
                        labels.append(splits[-1].replace("\n", ""))
                    else:
                        # Examples could have no label for mode = "test"
                        labels.append("O")
            if words:
                # flush the final sentence when the file has no trailing blank line
                lines.append({"words": words, "labels": labels})
        return lines

    @classmethod
    def _read_json(cls, input_file):
        """Read a JSON-lines NER file (CLUENER format) and emit BIOES labels.

        Each line is a JSON object with "text" and an optional "label" dict of
        {entity_type: {entity_string: [[start, end], ...]}} with inclusive
        character indices. Returns [{"words": [...], "labels": [...]}, ...].
        """
        # FIX: was a @classmethod whose first parameter was named `self`.
        lines = []
        with open(input_file, 'r', encoding='utf8') as f:
            for line in f:
                line = json.loads(line.strip())
                text = line['text']
                label_entities = line.get('label', None)
                words = list(text)
                labels = ['O'] * len(words)
                if label_entities is not None:
                    for key, value in label_entities.items():
                        for sub_name, sub_index in value.items():
                            for start_index, end_index in sub_index:
                                # indices are inclusive; sanity-check the span text
                                assert ''.join(words[start_index:end_index + 1]) == sub_name
                                if start_index == end_index:
                                    labels[start_index] = 'S-' + key
                                else:
                                    # FIX: merged the redundant `end-start == 1`
                                    # branch and sized the I- run from the indices
                                    # (end - start - 1) rather than len(sub_name);
                                    # identical when the assert holds, but robust
                                    # against index/name disagreement.
                                    labels[start_index] = 'B-' + key
                                    labels[start_index + 1:end_index] = ['I-' + key] * (end_index - start_index - 1)
                                    labels[end_index] = 'E-' + key
                lines.append({"words": words, "labels": labels})
        return lines
def get_entity_bios(seq, id2label, middle_prefix='I-'):
    """Extract entity chunks from a BIOS-tagged sequence.

    Args:
        seq (list): sequence of label strings, or label ids (mapped through
            ``id2label`` when not already ``str``).
        id2label (dict): id -> label-string mapping; only consulted for
            non-string tags.
        middle_prefix (str): prefix marking inside-of-entity tags.

    Returns:
        list: list of ``[chunk_type, chunk_start, chunk_end]`` with inclusive
        indices.

    Example:
        >>> get_entity_bios(['B-PER', 'I-PER', 'O', 'S-LOC'], {})
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]; -1 means "unset"
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            # single-token entity: flush any completed open chunk, emit this one
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[2] = indx
            chunk[0] = tag.split('-')[1]
            chunks.append(chunk)
            # FIX: was `chunk = (-1, -1, -1)` — a tuple, unlike every other
            # reset. It was always replaced before mutation so the output was
            # unaffected, but any future in-place assignment would raise
            # TypeError; use a list for consistency.
            chunk = [-1, -1, -1]
        if tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[0] = tag.split('-')[1]
        elif tag.startswith(middle_prefix) and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                # NOTE(review): appended even when the I- type above mismatched,
                # matching the original behavior — confirm this is intended.
                chunks.append(chunk)
        else:
            # 'O' or an orphan middle tag closes any completed chunk
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks
def get_entity_bio(seq, id2label, middle_prefix='I-'):
    """Extract entity chunks from a BIO-tagged sequence.

    Args:
        seq (list): sequence of label strings, or label ids (mapped through
            ``id2label`` when not already ``str``).
        id2label (dict): id -> label-string mapping; only consulted for
            non-string tags.
        middle_prefix (str): prefix marking inside-of-entity tags.

    Returns:
        list: list of ``[chunk_type, chunk_start, chunk_end]`` with inclusive
        indices.

    Example:
        >>> get_entity_bio(['B-PER', 'I-PER', 'O', 'B-LOC'], {})
        [['PER', 0, 1], ['LOC', 3, 3]]
    """
    result = []
    current = [-1, -1, -1]  # [type, start, end]; -1 means "unset"
    last = len(seq) - 1
    for pos, label in enumerate(seq):
        if not isinstance(label, str):
            label = id2label[label]
        if label.startswith("B-"):
            # a new entity begins: flush the previous one if it was started
            if current[2] != -1:
                result.append(current)
            current = [label.split('-')[1], pos, pos]
            if pos == last:
                result.append(current)
        elif label.startswith(middle_prefix) and current[1] != -1:
            # extend the open chunk only when the entity type matches
            if label.split('-')[1] == current[0]:
                current[2] = pos
            if pos == last:
                result.append(current)
        else:
            # 'O' or an orphan middle tag closes any open chunk
            if current[2] != -1:
                result.append(current)
            current = [-1, -1, -1]
    return result
def get_entity_bioes(seq, id2label, middle_prefix='I-'):
    """Extract entity chunks from a BIOES-tagged sequence.

    Like :func:`get_entity_bios` but also accepts ``E-`` tags, which close a
    multi-token entity.

    Args:
        seq (list): sequence of label strings, or label ids (mapped through
            ``id2label`` when not already ``str``).
        id2label (dict): id -> label-string mapping; only consulted for
            non-string tags.
        middle_prefix (str): prefix marking inside-of-entity tags.

    Returns:
        list: list of ``[chunk_type, chunk_start, chunk_end]`` with inclusive
        indices.

    Example:
        >>> get_entity_bioes(['B-PER', 'I-PER', 'E-PER', 'O', 'S-LOC'], {})
        [['PER', 0, 2], ['LOC', 4, 4]]
    """
    # FIX: docstring previously said "note: BIOS" although this extractor
    # handles the BIOES scheme (it matches E- tags below).
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]; -1 means "unset"
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            # single-token entity: flush any completed open chunk, emit this one
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[2] = indx
            chunk[0] = tag.split('-')[1]
            chunks.append(chunk)
            # FIX: was `chunk = (-1, -1, -1)` — a tuple, unlike every other
            # reset. It was always replaced before mutation so the output was
            # unaffected, but any future in-place assignment would raise
            # TypeError; use a list for consistency.
            chunk = [-1, -1, -1]
        if tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[0] = tag.split('-')[1]
        elif (tag.startswith(middle_prefix) or tag.startswith("E-")) and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                # NOTE(review): appended even when the type above mismatched,
                # matching the original behavior — confirm this is intended.
                chunks.append(chunk)
        else:
            # 'O' or an orphan middle/end tag closes any completed chunk
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks
def get_entities(seq, id2label, markup='bio', middle_prefix='I-'):
    """Dispatch to the chunk extractor matching the tagging scheme.

    :param seq: sequence of labels (strings or ids)
    :param id2label: id -> label-string mapping
    :param markup: tagging scheme, one of 'bio', 'bios', 'bioes'
    :param middle_prefix: prefix marking inside-of-entity tags
    :return: list of [chunk_type, chunk_start, chunk_end]
    """
    assert markup in ['bio', 'bios', 'bioes']
    extractors = {
        'bio': get_entity_bio,
        'bios': get_entity_bios,
        'bioes': get_entity_bioes,
    }
    return extractors[markup](seq, id2label, middle_prefix)
def bert_extract_item(start_logits, end_logits):
    """Pair span-start and span-end predictions into entity spans.

    Takes per-token start/end logits (batch dimension 0 only is used), strips
    the first and last positions (the [CLS]/[SEP] slots), and greedily matches
    each non-O start label with the nearest end position carrying the same
    label.

    Returns a list of ``(label_id, start, end)`` tuples with inclusive,
    0-based positions relative to the stripped sequence.
    """
    spans = []
    # argmax over the label dimension, then drop [CLS]/[SEP] positions
    starts = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
    ends = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
    for start_pos, start_label in enumerate(starts):
        if start_label == 0:
            continue  # label 0 = not an entity start
        # nearest end at or after the start with a matching label wins
        for offset, end_label in enumerate(ends[start_pos:]):
            if end_label == start_label:
                spans.append((start_label, start_pos, start_pos + offset))
                break
    return spans