import ast
import string

import joblib
import nltk
import pandas as pd
from nltk import pos_tag
from nltk.stem import PorterStemmer, WordNetLemmatizer

nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('averaged_perceptron_tagger')

# Build the lemmatizer and stemmer once at module level instead of
# re-instantiating them for every token inside word2features.
wordnet_lemmatizer = WordNetLemmatizer()
porter_stemmer = PorterStemmer()


class getsentence(object):
    '''
    Groups the token-per-row dataset back into sentences.
    Converts from BIO (token-per-row) format to sentences using their
    sentence numbers.
    '''

    def __init__(self, data):
        self.n_sent = 1.0
        self.data = data
        self.empty = False
        self.grouped = self.data.groupby("sentence_num").apply(self._agg_func)
        self.sentences = [s for s in self.grouped]

    def _agg_func(self, s):
        # Pair each token with its POS tag: one (token, pos_tag) tuple
        # per word in the sentence.
        return [(w, p) for w, p in zip(s["token"].values.tolist(),
                                       s["pos_tag"].values.tolist())]
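
# A minimal sketch of what getsentence produces, using a made-up dataframe
# (the values here are illustrative, not from the dataset):
#     sentence_num  token      pos_tag
#     1             "Acme"     "NNP"
#     1             "settled"  "VBD"
#     2             "Courts"   "NNS"
# getsentence(df).sentences ->
#     [[("Acme", "NNP"), ("settled", "VBD")], [("Courts", "NNS")]]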


def word2features(sent, i):
    '''
    Extracts features for the word at position i in the sentence.
    The main features are:
    - word.lower(): the word in lowercase
    - word.isdigit(): whether the word is a digit
    - word.punct(): whether the word is punctuation
    - postag: the POS tag of the word
    - word.lemma(): the lemma of the word
    - word.stem(): the stem of the word
    A subset of these features is also extracted for the four previous
    and the four next words.
    '''
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word.isdigit()': word.isdigit(),
        'word.punct()': word in string.punctuation,
        'postag': postag,
        'word.lemma()': wordnet_lemmatizer.lemmatize(word),
        'word.stem()': porter_stemmer.stem(word),
    }
    # Context features for up to four preceding words; the sentence-initial
    # word is flagged with BOS instead.
    if i > 0:
        for offset in range(1, 5):
            if i - offset < 0:
                break
            prev_word, prev_postag = sent[i - offset]
            features.update({
                f'-{offset}:word.lower()': prev_word.lower(),
                f'-{offset}:word.isdigit()': prev_word.isdigit(),
                f'-{offset}:word.punct()': prev_word in string.punctuation,
                f'-{offset}:postag': prev_postag,
            })
    else:
        features['BOS'] = True

    # Context features for up to four following words; the sentence-final
    # word is flagged with EOS instead.
    if i < len(sent) - 1:
        for offset in range(1, 5):
            if i + offset >= len(sent):
                break
            next_word, next_postag = sent[i + offset]
            features.update({
                f'+{offset}:word.lower()': next_word.lower(),
                f'+{offset}:word.isdigit()': next_word.isdigit(),
                f'+{offset}:word.punct()': next_word in string.punctuation,
                f'+{offset}:postag': next_postag,
            })
    else:
        features['EOS'] = True

    return features
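
# For illustration, with a hypothetical sentence (not from the dataset):
#     sent = [("The", "DT"), ("fine", "NN"), ("was", "VBD"), ("$5", "CD")]
# word2features(sent, 0) would contain, among other keys:
#     {'bias': 1.0, 'word.lower()': 'the', 'word.isdigit()': False,
#      'word.punct()': False, 'postag': 'DT', 'BOS': True, ...}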


def sent2features(sent):
    '''
    Extracts the feature dicts for every word in the sentence.
    '''
    return [word2features(sent, i) for i in range(len(sent))]
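
# Usage sketch (the tokens here are illustrative): a sentence of n
# (token, pos_tag) pairs maps to a list of n feature dicts, the
# per-sentence input format consumed by crf.predict below.
#     feats = sent2features([("Acme", "NNP"), ("Inc.", "NNP")])
#     len(feats) == 2 and feats[0]['postag'] == 'NNP'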


print("Evaluating the model...")

df_eval = pd.read_excel("testset_NER_LegalLens.xlsx")
print("Read the evaluation dataset.")

# The "tokens" column stores each token list as a string, so parse it
# back into a real Python list, then POS-tag the tokens.
df_eval["tokens"] = df_eval["tokens"].apply(ast.literal_eval)
df_eval["pos_tags"] = df_eval["tokens"].apply(
    lambda x: [tag[1] for tag in pos_tag(x)])

# Flatten to one row per token, keeping the sentence number so that
# getsentence can regroup the tokens into sentences.
data_eval = []
for i in range(len(df_eval)):
    for j in range(len(df_eval["tokens"][i])):
        data_eval.append(
            {
                "sentence_num": i + 1,
                "id": df_eval["id"][i],
                "token": df_eval["tokens"][i][j],
                "pos_tag": df_eval["pos_tags"][i][j],
            }
        )
data_eval = pd.DataFrame(data_eval)
print("Dataframe created.")

getter = getsentence(data_eval)
sentences_eval = getter.sentences
X_eval = [sent2features(s) for s in sentences_eval]
print("Predicting the NER tags...")

# Load the trained CRF model (presumably an sklearn_crfsuite.CRF, given
# the list-of-dicts feature format) and predict one tag sequence per
# sentence.
crf = joblib.load("../models/crf.pkl")
y_pred_eval = crf.predict(X_eval)
print("NER tags predicted.")
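
# Sanity check: groupby("sentence_num") yields one sentence per dataframe
# row (sentence_num was assigned as i + 1 above), so the CRF should return
# exactly one tag sequence per row before they are written back.
assert len(y_pred_eval) == len(df_eval)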

df_eval["ner_tags"] = y_pred_eval
df_eval.drop(columns=["pos_tags"], inplace=True)
print("Saving the predictions...")
df_eval.to_csv("predictions_NERLens.csv", index=False)
print("Predictions saved.")