# NER data-processing utilities: character-level tokenizer, dataset
# readers (TSV / CoNLL text / JSON lines), and BIO/BIOS/BIOES span decoding.
import csv | |
import json | |
import torch | |
from transformers import BertTokenizer | |
class CNerTokenizer(BertTokenizer):
    """Character-level tokenizer for NER.

    Overrides ``tokenize`` to split text into individual characters rather
    than WordPiece units, so that token positions align one-to-one with
    character-level labels.
    """

    def __init__(self, vocab_file, do_lower_case=True):
        super().__init__(vocab_file=str(vocab_file), do_lower_case=do_lower_case)
        self.vocab_file = str(vocab_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Map each character to itself if it is in the vocab, else '[UNK]'."""
        pieces = []
        for ch in text:
            if self.do_lower_case:
                ch = ch.lower()
            pieces.append(ch if ch in self.vocab else '[UNK]')
        return pieces
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses implement the ``get_*`` hooks; the ``_read_*`` helpers parse
    the three on-disk formats used here: TSV, CoNLL-style text, and
    JSON-lines with character-span entity annotations.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file into a list of rows.

        Fix: the method took ``cls`` but was not decorated as a classmethod,
        so calling it on the class passed the file path as ``cls``.
        """
        with open(input_file, "r", encoding="utf-8-sig") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

    def _read_text(self, input_file):
        """Reads CoNLL-style 'token label' lines.

        Blank lines and ``-DOCSTART-`` markers delimit sentences.
        Returns a list of ``{"words": [...], "labels": [...]}`` dicts.
        Fix: pass an explicit encoding instead of the platform default.
        """
        lines = []
        with open(input_file, 'r', encoding='utf-8') as f:
            words = []
            labels = []
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    if words:
                        lines.append({"words": words, "labels": labels})
                        words = []
                        labels = []
                else:
                    splits = line.split(" ")
                    words.append(splits[0])
                    if len(splits) > 1:
                        # strip the trailing newline from the label column
                        labels.append(splits[-1].replace("\n", ""))
                    else:
                        # Examples could have no label for mode = "test"
                        labels.append("O")
            if words:
                lines.append({"words": words, "labels": labels})
        return lines

    def _read_json(self, input_file):
        """Reads JSON-lines with char-span entity annotations.

        Each line: ``{"text": str, "label": {type: {surface: [[start, end], ...]}}}``
        with *end* inclusive.  Emits S-/B-/I-/E- tags per character.
        """
        lines = []
        with open(input_file, 'r', encoding='utf8') as f:
            for line in f:
                record = json.loads(line.strip())
                text = record['text']
                label_entities = record.get('label', None)
                words = list(text)
                labels = ['O'] * len(words)
                if label_entities is not None:
                    for key, value in label_entities.items():
                        for sub_name, sub_index in value.items():
                            for start_index, end_index in sub_index:
                                # annotated surface form must match the span text
                                assert ''.join(words[start_index:end_index + 1]) == sub_name
                                if start_index == end_index:
                                    labels[start_index] = 'S-' + key
                                else:
                                    labels[start_index] = 'B-' + key
                                    # Fix: size the I- run from the span itself rather
                                    # than len(sub_name) - 2, so the fill cannot desync
                                    # from the indices.  (Empty for two-char spans.)
                                    labels[start_index + 1:end_index] = \
                                        ['I-' + key] * (end_index - start_index - 1)
                                    labels[end_index] = 'E-' + key
                lines.append({"words": words, "labels": labels})
        return lines
def get_entity_bios(seq, id2label, middle_prefix='I-'):
    """Gets entities from a BIOS-tagged sequence.

    Args:
        seq (list): sequence of tag strings, or label ids mapped via id2label.
        id2label: mapping from label id to tag string.
        middle_prefix (str): prefix of inside-of-chunk tags (default 'I-').
    Returns:
        list: list of [chunk_type, chunk_start, chunk_end] (end inclusive).
    Example:
        seq = ['B-PER', 'I-PER', 'O', 'S-LOC']
        get_entity_bios(seq, {}) -> [['PER', 0, 1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            # single-token entity: close any open chunk, emit immediately
            if chunk[2] != -1:
                chunks.append(chunk)
            chunks.append([tag.split('-')[1], indx, indx])
            # Fix: reset with a list — the original reset to a tuple here,
            # a latent TypeError had any later branch mutated it in place.
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            # end stays -1 until a matching middle tag confirms the span
            chunk = [tag.split('-')[1], indx, -1]
        elif tag.startswith(middle_prefix) and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks
def get_entity_bio(seq, id2label, middle_prefix='I-'):
    """Decode a BIO-tagged sequence into entity spans.

    Args:
        seq (list): sequence of tag strings, or label ids mapped via id2label.
        id2label: mapping from label id to tag string.
        middle_prefix (str): prefix of inside-of-chunk tags (default 'I-').
    Returns:
        list: list of [chunk_type, chunk_start, chunk_end] (end inclusive).
    Example:
        seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
        get_entity_bio(seq, {}) -> [['PER', 0, 1], ['LOC', 3, 3]]
    """
    spans = []
    current = [-1, -1, -1]  # [type, start, end]
    last = len(seq) - 1
    for pos, raw in enumerate(seq):
        label = raw if isinstance(raw, str) else id2label[raw]
        if label.startswith("B-"):
            # close the previous span, open a new single-token one
            if current[2] != -1:
                spans.append(current)
            current = [label.split('-')[1], pos, pos]
            if pos == last:
                spans.append(current)
        elif label.startswith(middle_prefix) and current[1] != -1:
            # extend the open span only when the entity type matches
            if label.split('-')[1] == current[0]:
                current[2] = pos
            if pos == last:
                spans.append(current)
        else:
            if current[2] != -1:
                spans.append(current)
            current = [-1, -1, -1]
    return spans
def get_entity_bioes(seq, id2label, middle_prefix='I-'):
    """Gets entities from a BIOES-tagged sequence.

    (Docstring fix: this decoder handles BIOES, not BIOS — it accepts
    'E-' as well as the middle prefix inside a chunk.)

    Args:
        seq (list): sequence of tag strings, or label ids mapped via id2label.
        id2label: mapping from label id to tag string.
        middle_prefix (str): prefix of inside-of-chunk tags (default 'I-').
    Returns:
        list: list of [chunk_type, chunk_start, chunk_end] (end inclusive).
    Example:
        seq = ['B-PER', 'I-PER', 'E-PER', 'O', 'S-LOC']
        get_entity_bioes(seq, {}) -> [['PER', 0, 2], ['LOC', 4, 4]]
    """
    chunks = []
    chunk = [-1, -1, -1]  # [type, start, end]
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            # single-token entity: close any open chunk, emit immediately
            if chunk[2] != -1:
                chunks.append(chunk)
            chunks.append([tag.split('-')[1], indx, indx])
            # Fix: reset with a list — the original reset to a tuple here,
            # a latent TypeError had any later branch mutated it in place.
            chunk = [-1, -1, -1]
        elif tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            # end stays -1 until a matching I-/E- tag confirms the span
            chunk = [tag.split('-')[1], indx, -1]
        elif (tag.startswith(middle_prefix) or tag.startswith("E-")) and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks
def get_entities(seq, id2label, markup='bio', middle_prefix='I-'):
    """Dispatch entity decoding by tagging scheme.

    :param seq: sequence of tag strings or label ids
    :param id2label: mapping from label id to tag string
    :param markup: one of 'bio', 'bios', 'bioes'
    :param middle_prefix: prefix of inside-of-chunk tags
    :return: list of [chunk_type, chunk_start, chunk_end] spans
    """
    assert markup in ['bio', 'bios', 'bioes']
    decoders = {
        'bio': get_entity_bio,
        'bios': get_entity_bios,
        'bioes': get_entity_bioes,
    }
    return decoders[markup](seq, id2label, middle_prefix)
def bert_extract_item(start_logits, end_logits):
    """Pair start/end pointer predictions into entity spans.

    Takes the argmax label at each position of the first batch element,
    drops the first and last positions ([CLS]/[SEP]), then for every
    non-zero start label finds the nearest subsequent position whose end
    label equals it.

    Returns a list of ``(label_id, start_idx, end_idx)`` tuples with
    indices relative to the trimmed sequence.
    """
    starts = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
    ends = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
    spans = []
    for i, s_lab in enumerate(starts):
        if s_lab == 0:
            # label 0 means "no entity starts here"
            continue
        for j, e_lab in enumerate(ends[i:]):
            if e_lab == s_lab:
                spans.append((s_lab, i, i + j))
                break
    return spans