#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Coreference resolution over groups of sentences held in a pandas Series.

Created on Mon Sep 11 09:46:51 2023

@author: peter
"""
from itertools import accumulate, chain

import pandas

# Pretrained SpanBERT coreference model used when no other model is requested.
DEFAULT_MODEL_URL = "https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz"


def clean(sentence):
    """
    Ensure sentence ends with full stop

    Parameters
    ----------
    sentence : str
        Sentence to be cleaned

    Returns
    -------
    str
        The sentence unchanged if, ignoring surrounding whitespace, it
        already ends with a full stop; otherwise the sentence with a full
        stop appended.
    """
    return sentence if sentence.strip().endswith('.') else sentence + '.'


class CoreferenceResolver:
    """Replace every coreferent mention with the longest mention of its cluster."""

    def __init__(self, model_url=DEFAULT_MODEL_URL):
        """
        Creates the Coreference resolver

        Parameters
        ----------
        model_url : str, optional
            Path or URL of the model archive handed to
            ``Predictor.from_path``. Defaults to ``DEFAULT_MODEL_URL``.

        Returns
        -------
        None.
        """
        # Imported here rather than at module level so that this module can
        # be imported (e.g. to use ``clean`` or to inject a predictor in
        # tests) without pulling in the AllenNLP stack.
        from allennlp.predictors.predictor import Predictor

        self.predictor = Predictor.from_path(model_url)

    def __call__(self, group):
        """
        Resolve coreferences across a group of sentences.

        The sentences are cleaned with ``clean``, split on whitespace and
        concatenated into one token list, which is passed to
        ``self.predictor.predict_tokenized``. Within each returned cluster
        the longest mention (the first one on ties) is taken as canonical,
        and every mention of the cluster is replaced by it.

        Parameters
        ----------
        group : pandas.Series
            Sentences on which to perform coreference resolution

        Returns
        -------
        pandas.Series
            Sentences with coreferences resolved, carrying the same index as
            ``group``. If a mention runs across a sentence boundary, its
            replacement is placed in the sentence where the mention starts;
            a sentence whose tokens were all consumed by such a mention
            becomes an empty string.
        """
        if len(group) == 0:
            # Nothing to resolve; do not send an empty document to the model.
            return pandas.Series([], index=group.index, dtype=object)

        tokenized = [clean(sentence).split() for sentence in group]
        # line_ends[i] is the document position just past sentence i.
        line_ends = list(accumulate(len(tokens) for tokens in tokenized))
        doc = list(chain.from_iterable(tokenized))

        clusters = self.predictor.predict_tokenized(doc)['clusters']

        # Map mention start -> (exclusive end, canonical tokens). Cluster
        # spans use inclusive end positions.
        resolutions = {}
        for cluster in clusters:
            if not cluster:
                continue
            # max() keeps the first of several equally long mentions.
            first, last = max(cluster, key=lambda span: span[1] - span[0])
            canonical = doc[first:last + 1]
            for start_pos, end_pos in cluster:
                resolutions[start_pos] = (end_pos + 1, canonical)

        results = []
        current = []
        line = 0
        doc_pos = 0
        while doc_pos < len(doc):
            if doc_pos in resolutions:
                end, canonical = resolutions[doc_pos]
                current.extend(canonical)
                doc_pos = end
            else:
                current.append(doc[doc_pos])
                doc_pos += 1
            # A replaced mention may jump past several sentence ends at
            # once; close every sentence that is now complete so that the
            # result keeps exactly one entry per input sentence.
            while line < len(line_ends) and doc_pos >= line_ends[line]:
                results.append(' '.join(current))
                current = []
                line += 1
        return pandas.Series(results, index=group.index)