"""Select documents that mention places and export their entities as JSONL.

The script reads the first column of a CSV file, runs each value through the
spaCy ``en_core_web_trf`` pipeline, keeps the documents whose entity labels
include at least one of ``INCLUDE_ENTS`` and none of ``EXCLUDE_ENTS``, then
writes a random subset of at most ``params["max_docs"]`` documents to a JSONL
file.
"""

import csv
import random

import spacy
import srsly
import tqdm
import yaml

# Load the parameters with a context manager so the file handle is closed
# deterministically. This script reads the ``max_docs`` key from it.
with open("params.yaml", encoding="utf-8") as params_file:
    params = yaml.safe_load(params_file)

nlp = spacy.load("en_core_web_trf")

INPUT_FILE = "data/processed/wellcome_grant_descriptions.csv"
OUTPUT_FILE = "data/processed/entities.jsonl"

# A document is kept when it has at least one entity with a label in
# INCLUDE_ENTS and no entity with a label in EXCLUDE_ENTS.
INCLUDE_ENTS = {"GPE", "LOC"}
EXCLUDE_ENTS = {"PERSON"}


def _read_first_column(input_file):
    """Return the first cell of every non-empty data row of a CSV file.

    The first row is treated as a header and skipped. An empty file yields
    an empty list instead of raising ``StopIteration``.
    """
    # newline="" lets the csv module handle line breaks inside quoted fields;
    # the explicit encoding keeps decoding independent of the platform locale.
    with open(input_file, "r", newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row, tolerating an empty file
        # csv.reader yields [] for blank lines; skip them to avoid IndexError.
        return [row[0] for row in reader if row]


def _entity_records(doc):
    """Return one plain dict per entity in a processed spaCy document."""
    return [
        {
            "text": ent.text,
            "label": ent.label_,
            "start": ent.start_char,
            "end": ent.end_char,
        }
        for ent in doc.ents
    ]


def _has_wanted_labels(ents):
    """Return True if the labels hit INCLUDE_ENTS and avoid EXCLUDE_ENTS."""
    labels = {ent["label"] for ent in ents}
    return not labels.isdisjoint(INCLUDE_ENTS) and labels.isdisjoint(EXCLUDE_ENTS)


def process_documents(input_file: str, output_file: str):
    """Extract entities from a CSV of documents and write a sample to JSONL.

    Args:
        input_file: Path to a CSV file with a header row; the first column
            of each following row is used as the document text.
        output_file: Path of the JSONL file to write. Each line holds a dict
            with the document ``text`` and its list of ``ents`` (``text``,
            ``label``, ``start`` and ``end`` character offsets).

    The kept documents are shuffled in place with ``random.shuffle`` and
    truncated to ``params["max_docs"]`` entries before being written.
    """
    print(f"Reading data from {input_file}...")
    data = _read_first_column(input_file)

    print(f"Processing {len(data)} documents...")
    entities = []
    for text in tqdm.tqdm(data):
        doc = nlp(text)
        ents = _entity_records(doc)
        # A document without entities has no labels, so it can never match
        # INCLUDE_ENTS and is dropped by the same check.
        if _has_wanted_labels(ents):
            entities.append(
                {
                    "text": doc.text,
                    "ents": ents,
                }
            )

    print(f"Randomly selecting {params['max_docs']} documents...")
    random.shuffle(entities)
    entities = entities[: params["max_docs"]]

    print(f"Writing {len(entities)} documents to {output_file}...")
    srsly.write_jsonl(output_file, entities)


if __name__ == "__main__":
    process_documents(INPUT_FILE, OUTPUT_FILE)