Spaces:
Runtime error
Runtime error
import csv | |
import random | |
import spacy | |
import srsly | |
import tqdm | |
import yaml | |
# Pipeline parameters; process_documents reads params["max_docs"].
# A context manager closes the file handle as soon as it has been parsed.
with open("params.yaml", encoding="utf-8") as _params_file:
    params = yaml.safe_load(_params_file)

# spaCy pipeline loaded once at import time and shared by process_documents.
nlp = spacy.load("en_core_web_trf")

INPUT_FILE = "data/processed/wellcome_grant_descriptions.csv"
OUTPUT_FILE = "data/processed/entities.jsonl"

# A document is kept only if it contains at least one entity whose label is
# in INCLUDE_ENTS and no entity whose label is in EXCLUDE_ENTS.
INCLUDE_ENTS = {"GPE", "LOC"}
EXCLUDE_ENTS = {"PERSON"}
def _read_texts(input_file: str) -> list:
    """Return the first column of every data row of the CSV *input_file*.

    The first row is treated as a header and skipped. An empty file yields an
    empty list, and blank lines (which ``csv.reader`` returns as ``[]``) are
    ignored instead of raising ``IndexError``.
    """
    # newline="" is what the csv module requires so that quoted fields
    # containing line breaks are parsed correctly.
    with open(input_file, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; tolerate an empty file
        return [row[0] for row in reader if row]


def _extract_entities(doc) -> list:
    """Return one dict per entity in *doc*: text, label and char offsets."""
    return [
        {
            "text": ent.text,
            "label": ent.label_,
            "start": ent.start_char,
            "end": ent.end_char,
        }
        for ent in doc.ents
    ]


def _is_wanted(ents: list) -> bool:
    """Return True if *ents* has an INCLUDE_ENTS label and no EXCLUDE_ENTS label.

    An empty *ents* list is never wanted (it has no INCLUDE_ENTS label).
    """
    labels = {ent["label"] for ent in ents}
    return bool(labels & INCLUDE_ENTS) and not (labels & EXCLUDE_ENTS)


def process_documents(input_file: str, output_file: str):
    """Extract entities from the texts in *input_file* and write a sample.

    Each text (first CSV column, header skipped) is run through the module
    level ``nlp`` pipeline. Documents passing the INCLUDE_ENTS / EXCLUDE_ENTS
    filter are shuffled, truncated to ``params["max_docs"]`` and written to
    *output_file* as JSON lines of ``{"text": ..., "ents": [...]}``.

    Args:
        input_file: Path of the CSV file to read.
        output_file: Path of the JSONL file to write.
    """
    print(f"Reading data from {input_file}...")
    data = _read_texts(input_file)

    print(f"Processing {len(data)} documents...")
    entities = []
    # nlp.pipe processes the texts as a batched stream, which is more
    # efficient than calling nlp() once per text. It is a generator, so
    # tqdm needs an explicit total.
    for doc in tqdm.tqdm(nlp.pipe(data), total=len(data)):
        ents = _extract_entities(doc)
        if _is_wanted(ents):
            entities.append({"text": doc.text, "ents": ents})

    print(f"Randomly selecting {params['max_docs']} documents...")
    random.shuffle(entities)
    entities = entities[: params["max_docs"]]

    print(f"Writing {len(entities)} documents to {output_file}...")
    srsly.write_jsonl(output_file, entities)
if __name__ == "__main__":
    # Script entry point: run the extraction on the default module-level paths.
    process_documents(input_file=INPUT_FILE, output_file=OUTPUT_FILE)