# marcai/process.py
import argparse
import concurrent.futures
import csv
import time
from multiprocessing import get_context

import numpy as np
import pandas as pd
from more_itertools import chunked

import marcai.processing.comparisons as comps
import marcai.processing.normalizations as norms
from marcai.utils.parsing import load_records, record_dict


def multiprocess_pairs(
records_df,
pair_indices,
chunksize=50000,
processes=1,
):
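    """Compare pairs of records in parallel.

    ``pair_indices`` is an iterable of (left, right) row-index pairs into
    ``records_df``. Pairs are split into chunks of ``chunksize``, each chunk
    is submitted to a pool of ``processes`` worker processes, and the
    resulting feature DataFrames are yielded as jobs complete.
    """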
# Create chunked iterator
pairs_chunked = chunked(pair_indices, chunksize)
    # Limit the number of in-flight jobs to twice the worker count so the
    # pool stays busy without materializing every chunk up front
    max_jobs = processes * 2
context = get_context("fork")
with concurrent.futures.ProcessPoolExecutor(
max_workers=processes, mp_context=context
) as executor:
futures = set()
done = set()
first_spawn = True
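        # Keep the pool topped up: wait for completions, yield their results,
        # then submit new chunks until the pair iterator is exhausted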
while futures or first_spawn:
if first_spawn:
spawn_count = max_jobs
first_spawn = False
else:
# Wait for a job to complete
done, futures = concurrent.futures.wait(
futures, return_when=concurrent.futures.FIRST_COMPLETED
)
spawn_count = max_jobs - len(futures)
for future in done:
# Get job's output
df = future.result()
# Yield output
yield df
# Spawn jobs
for _ in range(spawn_count):
pairs_chunk = next(pairs_chunked, None)
if pairs_chunk is None:
break
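                # Column 0 holds the left record's row position, column 1 the right's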
indices = np.array(pairs_chunk).astype(int)
left_indices = indices[:, 0]
right_indices = indices[:, 1]
left_records = records_df.iloc[left_indices].reset_index(drop=True)
right_records = records_df.iloc[right_indices].reset_index(drop=True)
futures.add(executor.submit(process, left_records, right_records))


def process(df0, df1):
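    """Compute similarity features for aligned record DataFrames.

    ``df0`` and ``df1`` hold the left and right records of each pair, row by
    row. Text fields are normalized in place, then compared field by field;
    the result has one row per pair and one column per similarity feature.
    """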
normalize_fields = [
"author_names",
"corporate_names",
"meeting_names",
"publisher",
"title",
"title_a",
"title_b",
"title_c",
"title_p",
]
# Normalize text fields
for field in normalize_fields:
df0[field] = norms.lowercase(df0[field])
df1[field] = norms.lowercase(df1[field])
df0[field] = norms.remove_punctuation(df0[field])
df1[field] = norms.remove_punctuation(df1[field])
df0[field] = norms.remove_diacritics(df0[field])
df1[field] = norms.remove_diacritics(df1[field])
df0[field] = norms.normalize_whitespace(df0[field])
df1[field] = norms.normalize_whitespace(df1[field])
# Compare fields
result_df = pd.DataFrame()
result_df["id_0"] = df0["id"]
result_df["id_1"] = df1["id"]
result_df["raw_tokenset"] = comps.token_set_similarity(
df0["raw"], df1["raw"], null_value=0.5
)
# Token sort ratio
result_df["publisher"] = comps.token_sort_similarity(
df0["publisher"], df1["publisher"], null_value=0.5
)
author_names = comps.token_sort_similarity(
df0["author_names"], df1["author_names"], null_value=np.nan
)
corporate_names = comps.token_sort_similarity(
df0["corporate_names"], df1["corporate_names"], null_value=np.nan
)
meeting_names = comps.token_sort_similarity(
df0["meeting_names"], df1["meeting_names"], null_value=np.nan
)
authors = pd.concat([author_names, corporate_names, meeting_names], axis=1)
# Take max of author comparisons
result_df["author"] = comps.maximum(authors, null_value=0.5)
# Weighted title comparison
weights = {"title_a": 1, "raw": 0, "title_p": 1}
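    # title_a and title_p contribute equally; raw is included with zero weight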
result_df["title_agg"] = comps.column_aggregate_similarity(
df0[weights.keys()], df1[weights.keys()], weights.values(), null_value=0
)
# Length difference
result_df["title_length"] = comps.length_similarity(
df0["title"], df1["title"], null_value=0.5
)
# Token set similarity
result_df["title_tokenset"] = comps.token_set_similarity(
df0["title"], df1["title"], null_value=0
)
# Token sort ratio
result_df["title_tokensort"] = comps.token_sort_similarity(
df0["title"], df1["title"], null_value=0
)
# Levenshtein
result_df["title_levenshtein"] = comps.levenshtein_similarity(
df0["title"], df1["title"], null_value=0
)
# Jaro
result_df["title_jaro"] = comps.jaro_similarity(
df0["title"], df1["title"], null_value=0
)
# Jaro Winkler
result_df["title_jaro_winkler"] = comps.jaro_winkler_similarity(
df0["title"], df1["title"], null_value=0
)
# Pagination
result_df["pagination"] = comps.pagination_match(
df0["pagination"], df1["pagination"], null_value=0.5
)
# Dates
result_df["pub_date"] = comps.year_similarity(
df0["pub_date"], df1["pub_date"], null_value=0.5, exp_coeff=0.15
)
# Pub place
result_df["pub_place"] = comps.equal(
df0["pub_place"], df1["pub_place"], null_value=0.5
)
# CID/Label
result_df["cid"] = comps.equal(df0["cid"], df1["cid"], null_value=0.5)
return result_df


def args_parser():
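    """Build the command-line argument parser."""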
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
required = parser.add_argument_group("required arguments")
required.add_argument("-i", "--inputs", nargs="+", help="MARC files", required=True)
required.add_argument("-o", "--output", help="Output file", required=True)
parser.add_argument(
"-C",
"--chunksize",
type=int,
help="Number of comparisons per job",
default=50000,
)
    parser.add_argument(
        "-p", "--pair-indices", help="CSV file of record index pairs to compare"
    )
parser.add_argument(
"-P",
"--processes",
type=int,
help="Number of processes to run in parallel.",
default=1,
)
return parser


def main(args):
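    """Load MARC records, compare the requested pairs, and stream features to CSV."""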
start = time.time()
# Load records
print("Loading records...")
records = []
for path in args.inputs:
records.extend([record_dict(r) for r in load_records(path)])
records_df = pd.DataFrame(records)
print(f"Loaded {len(records)} records.")
print("Processing records...")
# Process records
written = False
with open(args.pair_indices, "r") as indices_file:
reader = csv.reader(indices_file)
for df in multiprocess_pairs(
records_df, reader, args.chunksize, args.processes
):
if not written:
# Write header
df.to_csv(args.output, mode="w", header=True, index=False)
written = True
else:
# Write rows of df to output CSV
df.to_csv(args.output, mode="a", header=False, index=False)
end = time.time()
print(f"Processed {len(records)} records.")
print(f"Time elapsed: {end - start:.2f} seconds.")


if __name__ == "__main__":
args = args_parser().parse_args()
main(args)