Spaces:
Runtime error
add application files

- .gitignore +5 -0
- app.py +100 -0
- data/merged_dataset/dataset_dict.json +1 -0
- data/merged_dataset/orig_test/cache-eeafde0b6770e328.arrow +3 -0
- data/merged_dataset/orig_test/data-00000-of-00001.arrow +3 -0
- data/merged_dataset/orig_test/dataset_info.json +34 -0
- data/merged_dataset/orig_test/state.json +13 -0
- data/merged_dataset/orig_train/cache-45d1543dc33c36be.arrow +3 -0
- data/merged_dataset/orig_train/data-00000-of-00001.arrow +3 -0
- data/merged_dataset/orig_train/dataset_info.json +34 -0
- data/merged_dataset/orig_train/state.json +13 -0
- data/merged_dataset/orig_validation/cache-afff9bbc07b5bee3.arrow +3 -0
- data/merged_dataset/orig_validation/data-00000-of-00001.arrow +3 -0
- data/merged_dataset/orig_validation/dataset_info.json +34 -0
- data/merged_dataset/orig_validation/state.json +13 -0
- data/merged_dataset/test/cache-3a6709085dd0f520.arrow +3 -0
- data/merged_dataset/test/cache-50fbc051d6b536f8.arrow +3 -0
- data/merged_dataset/test/cache-7344e423192cdf30.arrow +3 -0
- data/merged_dataset/test/cache-861a0fd50d74bfe1.arrow +3 -0
- data/merged_dataset/test/data-00000-of-00001.arrow +3 -0
- data/merged_dataset/test/dataset_info.json +34 -0
- data/merged_dataset/test/state.json +13 -0
- data/merged_dataset/train/cache-f8f6a910898e33f3.arrow +3 -0
- data/merged_dataset/train/data-00000-of-00001.arrow +3 -0
- data/merged_dataset/train/dataset_info.json +34 -0
- data/merged_dataset/train/state.json +13 -0
- data/merged_dataset/validation/cache-a70cdc1f600f2440.arrow +3 -0
- data/merged_dataset/validation/cache-c442280565074102.arrow +3 -0
- data/merged_dataset/validation/data-00000-of-00001.arrow +3 -0
- data/merged_dataset/validation/dataset_info.json +34 -0
- data/merged_dataset/validation/state.json +13 -0
- data/ner_feature.pickle +3 -0
- data/sample_data.json +0 -0
- evaluate_model.py +62 -0
- metrics.py +78 -0
- requirements.txt +7 -0
- utils.py +83 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+.ipynb_checkpoints/
+Untitled.ipynb
+__pycache__/
+evaluate_trf.ipynb
+test.json
app.py
ADDED
@@ -0,0 +1,100 @@
+from metrics import calc_metrics
+import gradio as gr
+from openai import OpenAI
+import os
+
+from transformers import pipeline
+# from dotenv import load_dotenv, find_dotenv
+import huggingface_hub
+import json
+# from simcse import SimCSE  # use for gpt
+from evaluate_data import store_sample_data, get_metrics_trf
+
+store_sample_data()
+
+
+
+with open('./data/sample_data.json', 'r') as f:
+    # sample_data = [
+    #     {'id': "", 'text': "", 'orgs': ["", ""]}
+    # ]
+    sample_data = json.load(f)
+
+# _ = load_dotenv(find_dotenv())  # read local .env file
+hf_token = os.environ['HF_TOKEN']
+huggingface_hub.login(hf_token)
+
+pipe = pipeline("token-classification", model="elshehawy/finer-ord-transformers", aggregation_strategy="first")
+
+
+llm_model = 'gpt-3.5-turbo-0125'
+# openai.api_key = os.environ['OPENAI_API_KEY']
+
+client = OpenAI(
+    api_key=os.environ.get("OPENAI_API_KEY"),
+)
+
+
+def get_completion(prompt, model=llm_model):
+    messages = [{"role": "user", "content": prompt}]
+    response = client.chat.completions.create(
+        messages=messages,
+        model=model,
+        temperature=0,
+    )
+    return response.choices[0].message.content
+
+
+
+def find_orgs_gpt(sentence):
+    prompt = f"""
+    In the context of named entity recognition (NER), find all organizations in the text delimited by triple backticks.
+
+    text:
+    ```
+    {sentence}
+    ```
+    You should output only a list of organizations and follow this output format exactly: ["org_1", "org_2", "org_3"]
+    """
+
+    sent_orgs_str = get_completion(prompt)
+    sent_orgs = json.loads(sent_orgs_str)
+
+    return sent_orgs
+
+
+
+# def find_orgs_trf(sentence):
+#     org_list = []
+#     for ent in pipe(sentence):
+#         if ent['entity_group'] == 'ORG':
+#             # message += f'\n- {ent["word"]} \t- score: {ent["score"]}'
+#             # message += f'\n- {ent["word"]}'  # \t- score: {ent["score"]}'
+#             org_list.append(ent['word'])
+#     return list(set(org_list))
+
+
+true_orgs = [sent['orgs'] for sent in sample_data]
+
+predicted_orgs_gpt = [find_orgs_gpt(sent['text']) for sent in sample_data]
+# predicted_orgs_trf = [find_orgs_trf(sent['text']) for sent in sample_data]
+
+all_metrics = {}
+
+# sim_model = SimCSE('sentence-transformers/all-MiniLM-L6-v2')
+# all_metrics['gpt'] = calc_metrics(true_orgs, predicted_orgs_gpt, sim_model)
+all_metrics['trf'] = get_metrics_trf()
+
+
+
+example = """
+My latest exclusive for The Hill : Conservative frustration over Republican efforts to force a House vote on reauthorizing the Export - Import Bank boiled over Wednesday during a contentious GOP meeting.
+
+"""
+def find_orgs(sentence, choice):  # currently returns the precomputed metrics regardless of input
+    return all_metrics
+radio_btn = gr.Radio(choices=['GPT', 'iSemantics'], value='iSemantics', label='Available models', show_label=True)
+textbox = gr.Textbox(label="Enter your text", placeholder=str(all_metrics), lines=8)
+
+iface = gr.Interface(fn=find_orgs, inputs=[textbox, radio_btn], outputs="text", examples=[[example, 'iSemantics']])  # one example value per input component
+iface.launch(share=True)
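One fragile point in this app: find_orgs_gpt passes the raw chat-completion reply straight to json.loads, and all of the predictions above run at module import, so a single malformed reply stops the Space. A minimal defensive sketch, assuming nothing beyond the standard library (parse_org_list is an illustrative helper, not part of this commit):

import json

def parse_org_list(raw: str) -> list:
    """Parse a model reply expected to look like ["org_1", "org_2"]."""
    try:
        orgs = json.loads(raw)
        # Only accept a JSON list; anything else counts as no entities.
        return orgs if isinstance(orgs, list) else []
    except json.JSONDecodeError:
        # Fall back to an empty list rather than crashing the app.
        return []

find_orgs_gpt could then return parse_org_list(get_completion(prompt)) instead of calling json.loads directly.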
data/merged_dataset/dataset_dict.json
ADDED
@@ -0,0 +1 @@
+{"splits": ["train", "validation", "test", "orig_train", "orig_validation", "orig_test"]}
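This dataset_dict.json is what the datasets library writes when a DatasetDict is saved to disk, so the whole directory reloads in one call. A short sketch, assuming datasets==2.18.0 from requirements.txt:

from datasets import load_from_disk

# Reload every split listed in dataset_dict.json
# ("train", "validation", "test", "orig_train", "orig_validation", "orig_test").
ds = load_from_disk("data/merged_dataset")
print(ds["test"][0]["tokens"])  # tokens is a sequence of strings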
data/merged_dataset/orig_test/cache-eeafde0b6770e328.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4b20ac141827d2e067e67afe6bb6efe6fdabf3d227c33b0764aff545c15ee6c
+size 953224
data/merged_dataset/orig_test/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc88467edb4babd0a7fc480903eed43b359e3755b5eecc87780fb33864530237
+size 437856
data/merged_dataset/orig_test/dataset_info.json
ADDED
@@ -0,0 +1,34 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "id": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "tokens": {
+      "feature": {
+        "dtype": "string",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "ner_tags": {
+      "feature": {
+        "names": [
+          "O",
+          "B-PER",
+          "I-PER",
+          "B-LOC",
+          "I-LOC",
+          "B-ORG",
+          "I-ORG"
+        ],
+        "_type": "ClassLabel"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
data/merged_dataset/orig_test/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "0b33bf3dd398a19a",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
data/merged_dataset/orig_train/cache-45d1543dc33c36be.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83332abe5b5e7d05e3ae4376018429896530b916ab3ff74eb8ca7aef94497961
+size 3009552
data/merged_dataset/orig_train/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac320be565428a08ad7a3c43d03ad14810775cb0620b47659321228b17a22148
+size 1371040
data/merged_dataset/orig_train/dataset_info.json
ADDED
@@ -0,0 +1,34 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "id": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "tokens": {
+      "feature": {
+        "dtype": "string",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "ner_tags": {
+      "feature": {
+        "names": [
+          "O",
+          "B-PER",
+          "I-PER",
+          "B-LOC",
+          "I-LOC",
+          "B-ORG",
+          "I-ORG"
+        ],
+        "_type": "ClassLabel"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
data/merged_dataset/orig_train/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "74ec65c2b682826d",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
data/merged_dataset/orig_validation/cache-afff9bbc07b5bee3.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94d972da8072255a5df65632859af45a7ce025dd587dc066106ea8e7224b0a1f
+size 387592
data/merged_dataset/orig_validation/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:758e0d9bfceefd51b3cef856c8e15786ce0493da10bdf231f27e067b6b66caec
+size 174712
data/merged_dataset/orig_validation/dataset_info.json
ADDED
@@ -0,0 +1,34 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "id": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "tokens": {
+      "feature": {
+        "dtype": "string",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "ner_tags": {
+      "feature": {
+        "names": [
+          "O",
+          "B-PER",
+          "I-PER",
+          "B-LOC",
+          "I-LOC",
+          "B-ORG",
+          "I-ORG"
+        ],
+        "_type": "ClassLabel"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
data/merged_dataset/orig_validation/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "2b90f959ed79ba44",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
data/merged_dataset/test/cache-3a6709085dd0f520.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28be4e971c5d73f16b208f4b15d5965cfc81fb8936ce1c711fedc6fff5b3479a
+size 953224
data/merged_dataset/test/cache-50fbc051d6b536f8.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:141c98d5d03e91f39e198897f147e0c2c6fa2c7a4c55174993392ec512599b34
+size 953224
data/merged_dataset/test/cache-7344e423192cdf30.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:855fbd5d5477353be5930ee9ed4435238d847ef0971abe8106056e8d93639cd8
+size 953240
data/merged_dataset/test/cache-861a0fd50d74bfe1.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:004f623f011a702e9eff454113818978e90f497b8ad806a8f86fa011868a0831
+size 12304024
data/merged_dataset/test/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc107fe614a3ee59fd5b302dc0a56896e63f2a3106fd88b5c52d4fd88b77a0fe
+size 437856
data/merged_dataset/test/dataset_info.json
ADDED
@@ -0,0 +1,34 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "id": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "tokens": {
+      "feature": {
+        "dtype": "string",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "ner_tags": {
+      "feature": {
+        "names": [
+          "O",
+          "B-PER",
+          "I-PER",
+          "B-LOC",
+          "I-LOC",
+          "B-ORG",
+          "I-ORG"
+        ],
+        "_type": "ClassLabel"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
data/merged_dataset/test/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "538471187ad5b763",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
data/merged_dataset/train/cache-f8f6a910898e33f3.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1060a3d1000579dc92f65efb200632efdafa80f5d750f0c2298d82193e648f3e
+size 3009552
data/merged_dataset/train/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1546be0dd9960d920988ce2bb6883fc567db03c2c80d0f8678d4bf95001a1a5f
+size 1371040
data/merged_dataset/train/dataset_info.json
ADDED
@@ -0,0 +1,34 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "id": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "tokens": {
+      "feature": {
+        "dtype": "string",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "ner_tags": {
+      "feature": {
+        "names": [
+          "O",
+          "B-PER",
+          "I-PER",
+          "B-LOC",
+          "I-LOC",
+          "B-ORG",
+          "I-ORG"
+        ],
+        "_type": "ClassLabel"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
data/merged_dataset/train/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "13b20c4adf67dcf4",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
data/merged_dataset/validation/cache-a70cdc1f600f2440.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0f9fa1d8fcd428b779d15f9f386d8b463c9f542dadb0056a16e3eb6b817cb5a
+size 387592
data/merged_dataset/validation/cache-c442280565074102.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca8b2d2e4610afe0cf6daf0ca37d02414e1df7e6a486c80e0ea2b25bf7808807
+size 387592
data/merged_dataset/validation/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fb30ccc26dc3d3172ffd54077d11217436ed169738251bf51fdb82908497868
+size 174712
data/merged_dataset/validation/dataset_info.json
ADDED
@@ -0,0 +1,34 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "id": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "tokens": {
+      "feature": {
+        "dtype": "string",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    },
+    "ner_tags": {
+      "feature": {
+        "names": [
+          "O",
+          "B-PER",
+          "I-PER",
+          "B-LOC",
+          "I-LOC",
+          "B-ORG",
+          "I-ORG"
+        ],
+        "_type": "ClassLabel"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
data/merged_dataset/validation/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "f95fe8e7a800be97",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
data/ner_feature.pickle
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ecfd61f261845f22d0b83a72263f7326514d78d71d3c52534ede75671dacc70
+size 286
data/sample_data.json
ADDED
The diff for this file is too large to render. See raw diff.
evaluate_model.py
ADDED
@@ -0,0 +1,62 @@
+import evaluate
+import numpy as np
+import pickle
+
+metric = evaluate.load("seqeval")
+with open('./data/ner_feature.pickle', 'rb') as f:
+    ner_feature = pickle.load(f)
+
+label_names = ner_feature.feature.names
+# label2id = {label: ner_feature.feature.str2int(label) for label in label_names}
+# id2label = {v: k for k, v in label2id.items()}
+
+def compute_metrics(eval_preds):
+    """
+    This compute_metrics() function first takes the argmax of the logits to convert them to predictions
+    (as usual, the logits and the probabilities are in the same order,
+    so we don't need to apply the softmax).
+    Then we convert both labels and predictions from integers to strings,
+    remove all the values where the label is -100, and pass the results to the metric.compute() method.
+    """
+
+    logits, labels = eval_preds
+    predictions = np.argmax(logits, axis=-1)
+
+    # Remove ignored index (special tokens) and convert to labels
+    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
+    true_predictions = [
+        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
+        for prediction, label in zip(predictions, labels)
+    ]
+    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
+
+    # return all_metrics
+    # return {
+    #     "precision": all_metrics["overall_precision"],
+    #     "recall": all_metrics["overall_recall"],
+    #     "f1": all_metrics["overall_f1"],
+    #     "accuracy": all_metrics["overall_accuracy"],
+    # }
+
+    return {
+        # organization metrics
+        'org_precision': all_metrics['ORG']['precision'],
+        'org_recall': all_metrics['ORG']['recall'],
+        'org_f1': all_metrics['ORG']['f1'],
+
+        # person metrics
+        'per_precision': all_metrics['PER']['precision'],
+        'per_recall': all_metrics['PER']['recall'],
+        'per_f1': all_metrics['PER']['f1'],
+
+        # location metrics
+        'loc_precision': all_metrics['LOC']['precision'],
+        'loc_recall': all_metrics['LOC']['recall'],
+        'loc_f1': all_metrics['LOC']['f1'],
+
+        # overall metrics
+        'precision': all_metrics['overall_precision'],
+        'recall': all_metrics['overall_recall'],
+        'f1': all_metrics['overall_f1'],
+        'accuracy': all_metrics['overall_accuracy']
+    }
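compute_metrics follows the (logits, labels) contract that transformers' Trainer uses for its compute_metrics hook, so it can be smoke-tested directly. A tiny check, assuming data/ner_feature.pickle is in place (the module loads it at import):

import numpy as np
from evaluate_model import compute_metrics

# One sentence containing a PER, a LOC and an ORG span; -100 marks an ignored token.
labels = np.array([[1, 2, 0, 3, 4, 0, 5, 6, -100]])
logits = np.eye(7)[labels.clip(min=0)]  # one-hot "logits" that agree with the labels
print(compute_metrics((logits, labels)))  # every precision/recall/f1 should be 1.0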
metrics.py
ADDED
@@ -0,0 +1,78 @@
+def calc_recall(true_pos, false_neg, eps=1e-8):
+    return true_pos / (true_pos + false_neg + eps)
+
+
+
+def calc_precision(true_pos, false_pos, eps=1e-8):
+    return true_pos / (true_pos + false_pos + eps)
+
+
+
+def calc_f1_score(precision, recall, eps=1e-8):
+    return (2*precision*recall) / (precision + recall + eps)
+
+
+
+def calc_metrics(true, predicted, model, threshold=0.95, eps=1e-8):
+    true_pos = 0
+    false_pos = 0
+    false_neg = 0
+
+    false_pos_ids = []
+    false_neg_ids = []
+
+    i = 0
+    total = len(true)
+    for j, (true_ents, pred_ents) in enumerate(zip(true, predicted)):
+        i += 1
+        # print(f'{i}/{total}')
+        # print('----------------------------')
+
+        if len(true_ents) == 0:
+            false_pos += len(pred_ents)
+
+            if len(pred_ents) > 0:
+                false_pos_ids.append(j)
+
+            continue
+
+        if len(pred_ents) == 0:
+            false_neg += len(true_ents)
+
+            if len(true_ents) > 0:
+                # print('False Negative')
+                false_neg_ids.append(j)
+
+            continue
+
+        similarities = model.similarity(true_ents, pred_ents, device='cuda')
+
+        for row in similarities:
+            if (row >= threshold).any():
+                true_pos += 1
+            else:
+                false_neg += 1
+                # print('False Negative 2222222')
+                false_neg_ids.append(j)
+
+        for row in similarities.T:
+            if (row >= threshold).any():
+                continue
+            else:
+                false_pos += 1
+                false_pos_ids.append(j)
+
+    recall = calc_recall(true_pos, false_neg)
+    precision = calc_precision(true_pos, false_pos)
+    f1_score = calc_f1_score(precision, recall, eps=eps)
+
+    return {
+        # 'true_pos': true_pos,
+        # 'false_pos': false_pos,
+        # 'false_neg': false_neg,
+        'recall': recall,
+        'precision': precision,
+        'f1': f1_score,
+        # 'false_pos_ids': list(set(false_pos_ids)),
+        # 'false_neg_ids': list(set(false_neg_ids))
+    }
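calc_metrics counts a predicted entity as correct when its similarity to some gold entity clears threshold, so any object exposing a similarity(queries, keys, device=...) method that returns a 2-D array will do. A toy run with an exact-match stand-in (DummySim is illustrative, not part of this commit):

import numpy as np
from metrics import calc_metrics

class DummySim:
    # Stand-in for a SimCSE-style model: similarity is 1.0 iff the strings are equal.
    def similarity(self, queries, keys, device=None):
        return np.array([[float(q == k) for k in keys] for q in queries])

true = [["Export-Import Bank"], []]
pred = [["Export-Import Bank", "GOP"], []]
print(calc_metrics(true, pred, DummySim()))  # recall ~1.0, precision ~0.5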
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+openai
+transformers[torch]
+tqdm==4.66.1
+datasets==2.18.0
+evaluate
+seqeval
+rich
utils.py
ADDED
@@ -0,0 +1,83 @@
+def find_broken_examples(data):
+    splits = list(data.keys())
+    broken = []
+
+    for s in splits:
+        for i, tokens in enumerate(data[s]['tokens']):
+            for token in tokens:
+                if not token.isprintable():
+                    broken.append(s + '-' + str(i))
+
+    return broken
+
+
+def update_data(examples, split, broken_ids):
+    new_tags = []
+    new_tokens = []
+    for id_ in examples['id']:
+        sent_id = split + '-' + id_
+        if sent_id in broken_ids:
+            continue
+
+        new_tokens.append(examples['tokens'][int(id_)])
+        new_tags.append(examples['ner_tags'][int(id_)])
+
+    assert len(new_tokens) == len(new_tags)
+    assert len(new_tokens[-1]) == len(new_tags[-1])
+
+    return {
+        'id': [str(i) for i in range(len(new_tokens))],
+        'tokens': new_tokens,
+        'ner_tags': new_tags
+    }
+
+
+def align_labels_with_tokens(labels, word_ids):
+    new_labels = []
+    current_word = None
+    for word_id in word_ids:
+        if word_id != current_word:
+            # Start of a new word!
+            current_word = word_id
+            label = -100 if word_id is None else labels[word_id]
+            new_labels.append(label)
+        elif word_id is None:
+            # Special token
+            new_labels.append(-100)
+        else:
+            # Same word as previous token
+            # label = labels[word_id]
+            # If the label is B-XXX we change it to I-XXX
+            # if label % 2 == 1:
+            #     label += 1
+            label = -100
+            new_labels.append(label)
+
+    return new_labels
+
+
+def tokenize_and_align_labels(examples, tokenizer):
+    tokenized_inputs = tokenizer(
+        examples["tokens"], truncation=True, is_split_into_words=True, padding='max_length'
+    )
+    all_labels = examples["ner_tags"]
+    new_labels = []
+    word_ids = []
+    for i, labels in enumerate(all_labels):
+        word_ids.append(tokenized_inputs.word_ids(i))
+        new_labels.append(align_labels_with_tokens(labels, word_ids[i]))
+
+    tokenized_inputs["labels"] = new_labels
+    tokenized_inputs['word_ids'] = word_ids
+
+    return tokenized_inputs
+
+
+# def model_init(checkpoint, id2label, label2id):
+#     model = AutoModelForTokenClassification.from_pretrained(
+#         checkpoint,
+#         id2label=id2label,
+#         label2id=label2id
+#     )
+
+#     return model