import logging
import os
from typing import List, TextIO, Union

from conllu import parse_incr

from utils_ner import InputExample, Split, TokenClassificationTask


logger = logging.getLogger(__name__)


class NER(TokenClassificationTask):
    def __init__(self, label_idx=-1):
        # in NER datasets, the last column is usually reserved for NER label
        self.label_idx = label_idx

    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []
        with open(file_path, encoding="utf-8") as f:
            words = []
            labels = []
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    if words:
                        examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                        guid_index += 1
                        words = []
                        labels = []
                else:
                    splits = line.split(" ")
                    words.append(splits[0])
                    if len(splits) > 1:
                        labels.append(splits[self.label_idx].replace("\n", ""))
                    else:
                        # Examples could have no label for mode = "test"
                        labels.append("O")
            if words:
                examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
        return examples

    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
        example_id = 0
        for line in test_input_reader:
            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                writer.write(line)
                if not preds_list[example_id]:
                    example_id += 1
            elif preds_list[example_id]:
                output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
                writer.write(output_line)
            else:
                logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])

    def get_labels(self, path: str) -> List[str]:
        if path:
            with open(path, "r") as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]


class Chunk(NER):
    def __init__(self):
        # in CONLL2003 dataset chunk column is second-to-last
        super().__init__(label_idx=-2)

    def get_labels(self, path: str) -> List[str]:
        if path:
            with open(path, "r") as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return [
                "O",
                "B-ADVP", "B-INTJ", "B-LST", "B-PRT", "B-NP", "B-SBAR", "B-VP", "B-ADJP", "B-CONJP", "B-PP",
                "I-ADVP", "I-INTJ", "I-LST", "I-PRT", "I-NP", "I-SBAR", "I-VP", "I-ADJP", "I-CONJP", "I-PP",
            ]


class POS(TokenClassificationTask):
    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []
        with open(file_path, encoding="utf-8") as f:
            for sentence in parse_incr(f):
                words = []
                labels = []
                for token in sentence:
                    words.append(token["form"])
                    labels.append(token["upos"])
                assert len(words) == len(labels)
                if words:
                    examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                    guid_index += 1
        return examples

    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
        example_id = 0
        for sentence in parse_incr(test_input_reader):
            s_p = preds_list[example_id]
            out = ""
            for token in sentence:
                out += f'{token["form"]} ({token["upos"]}|{s_p.pop(0)}) '
            out += "\n"
            writer.write(out)
            example_id += 1

    def get_labels(self, path: str) -> List[str]:
        if path:
            with open(path, "r") as f:
                return f.read().splitlines()
        else:
            return [
                "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ",
"NOUN", "NUM", "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X", ]