danieldux commited on
Commit
e847a58
1 Parent(s): 45304f1

Add notebook with evaluation metric tests

Browse files
Files changed (2) hide show
  1. isco_results.json +1 -0
  2. tests.ipynb +265 -0
isco_results.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"accuracy": 0.8611914401388086, "hierarchical_precision": 0.989010989010989, "hierarchical_recall": 0.9836065573770492, "hierarchical_fmeasure": 0.9863013698630136}
tests.ipynb ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# ISCO-08 hierarchical accuracy measure"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 18,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "name": "stdout",
17
+ "output_type": "stream",
18
+ "text": [
19
+ "ISCO CSV file downloaded\n",
20
+ "Weighted ISCO hierarchy dictionary created as isco_hierarchy\n",
21
+ "\n",
22
+ "The ISCO-08 Hierarchical Accuracy Measure is an implementation of the measure described in [Functional Annotation of Genes Using Hierarchical Text Categorization](https://www.researchgate.net/publication/44046343_Functional_Annotation_of_Genes_Using_Hierarchical_Text_Categorization) (Kiritchenko, Svetlana and Famili, Fazel. 2005) and adapted for the ISCO-08 classification scheme by the International Labour Organization.\n",
23
+ "\n",
24
+ "The measure rewards more precise classifications that correctly identify an occupation's placement down to the specific Unit group level and applies penalties for misclassifications based on the hierarchical distance between the correct and assigned categories.\n",
25
+ "\n",
26
+ "\n"
27
+ ]
28
+ }
29
+ ],
30
+ "source": [
31
+ "import evaluate\n",
32
+ "\n",
33
+ "ham = evaluate.load(\"/home/dux/workspace/1-IEA_RnD/isco_hierarchical_accuracy\")\n",
34
+ "print(ham.description)"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 3,
40
+ "metadata": {},
41
+ "outputs": [
42
+ {
43
+ "name": "stdout",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "References: ['1111', '1112', '1113', '1114', '1120']\n",
47
+ "Predictions: ['1111', '1113', '1120', '1211', '2111']\n",
48
+ "Accuracy: 0.2, Hierarchical Precision: 0.5, Hierarchical Recall: 0.7777777777777778, Hierarchical F-measure: 0.6086956521739131\n",
49
+ "{'accuracy': 0.2, 'hierarchical_precision': 0.5, 'hierarchical_recall': 0.7777777777777778, 'hierarchical_fmeasure': 0.6086956521739131}\n"
50
+ ]
51
+ }
52
+ ],
53
+ "source": [
54
+ "references = [\"1111\", \"1112\", \"1113\", \"1114\", \"1120\"]\n",
55
+ "predictions = [\"1111\", \"1113\", \"1120\", \"1211\", \"2111\"]\n",
56
+ "\n",
57
+ "print(f\"References: {references}\")\n",
58
+ "print(f\"Predictions: {predictions}\")\n",
59
+ "print(ham.compute(references=references, predictions=predictions))"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 16,
65
+ "metadata": {},
66
+ "outputs": [
67
+ {
68
+ "name": "stdout",
69
+ "output_type": "stream",
70
+ "text": [
71
+ "TEST CASE #1\n",
72
+ "References: ['1111', '1111', '1111', '1111', '1111', '1111', '1111', '1111', '1111', '1111']\n",
73
+ "Predictions: ['1111', '1112', '1120', '1211', '1311', '2111', '111', '11', '1', '9999']\n",
74
+ "Accuracy: 0.1, Hierarchical Precision: 0.2222222222222222, Hierarchical Recall: 1.0, Hierarchical F-measure: 0.3636363636363636\n",
75
+ "{'accuracy': 0.1, 'hierarchical_precision': 0.2222222222222222, 'hierarchical_recall': 1.0, 'hierarchical_fmeasure': 0.3636363636363636}\n",
76
+ "\n",
77
+ "TEST CASE #2\n",
78
+ "References: ['1111']\n",
79
+ "Predictions: ['1111']\n",
80
+ "Accuracy: 1.0, Hierarchical Precision: 1.0, Hierarchical Recall: 1.0, Hierarchical F-measure: 1.0\n",
81
+ "{'accuracy': 1.0, 'hierarchical_precision': 1.0, 'hierarchical_recall': 1.0, 'hierarchical_fmeasure': 1.0}\n",
82
+ "\n",
83
+ "TEST CASE #3\n",
84
+ "References: ['1111']\n",
85
+ "Predictions: ['1112']\n",
86
+ "Accuracy: 0.0, Hierarchical Precision: 0.75, Hierarchical Recall: 0.75, Hierarchical F-measure: 0.75\n",
87
+ "{'accuracy': 0.0, 'hierarchical_precision': 0.75, 'hierarchical_recall': 0.75, 'hierarchical_fmeasure': 0.75}\n",
88
+ "\n",
89
+ "TEST CASE #4\n",
90
+ "References: ['1111']\n",
91
+ "Predictions: ['1120']\n",
92
+ "Accuracy: 0.0, Hierarchical Precision: 0.5, Hierarchical Recall: 0.5, Hierarchical F-measure: 0.5\n",
93
+ "{'accuracy': 0.0, 'hierarchical_precision': 0.5, 'hierarchical_recall': 0.5, 'hierarchical_fmeasure': 0.5}\n",
94
+ "\n",
95
+ "TEST CASE #5\n",
96
+ "References: ['1111']\n",
97
+ "Predictions: ['1211']\n",
98
+ "Accuracy: 0.0, Hierarchical Precision: 0.25, Hierarchical Recall: 0.25, Hierarchical F-measure: 0.25\n",
99
+ "{'accuracy': 0.0, 'hierarchical_precision': 0.25, 'hierarchical_recall': 0.25, 'hierarchical_fmeasure': 0.25}\n",
100
+ "\n",
101
+ "TEST CASE #6\n",
102
+ "References: ['1111']\n",
103
+ "Predictions: ['1311']\n",
104
+ "Accuracy: 0.0, Hierarchical Precision: 0.25, Hierarchical Recall: 0.25, Hierarchical F-measure: 0.25\n",
105
+ "{'accuracy': 0.0, 'hierarchical_precision': 0.25, 'hierarchical_recall': 0.25, 'hierarchical_fmeasure': 0.25}\n",
106
+ "\n",
107
+ "TEST CASE #7\n",
108
+ "References: ['1111']\n",
109
+ "Predictions: ['2111']\n",
110
+ "Accuracy: 0.0, Hierarchical Precision: 0.0, Hierarchical Recall: 0.0, Hierarchical F-measure: 0\n",
111
+ "{'accuracy': 0.0, 'hierarchical_precision': 0.0, 'hierarchical_recall': 0.0, 'hierarchical_fmeasure': 0}\n",
112
+ "\n",
113
+ "TEST CASE #8\n",
114
+ "References: ['1111']\n",
115
+ "Predictions: ['111']\n",
116
+ "Accuracy: 0.0, Hierarchical Precision: 1.0, Hierarchical Recall: 0.25, Hierarchical F-measure: 0.4\n",
117
+ "{'accuracy': 0.0, 'hierarchical_precision': 1.0, 'hierarchical_recall': 0.25, 'hierarchical_fmeasure': 0.4}\n",
118
+ "\n",
119
+ "TEST CASE #9\n",
120
+ "References: ['1111']\n",
121
+ "Predictions: ['11']\n",
122
+ "Accuracy: 0.0, Hierarchical Precision: 1.0, Hierarchical Recall: 0.25, Hierarchical F-measure: 0.4\n",
123
+ "{'accuracy': 0.0, 'hierarchical_precision': 1.0, 'hierarchical_recall': 0.25, 'hierarchical_fmeasure': 0.4}\n",
124
+ "\n",
125
+ "TEST CASE #10\n",
126
+ "References: ['1111']\n",
127
+ "Predictions: ['1']\n",
128
+ "Accuracy: 0.0, Hierarchical Precision: 1.0, Hierarchical Recall: 0.25, Hierarchical F-measure: 0.4\n",
129
+ "{'accuracy': 0.0, 'hierarchical_precision': 1.0, 'hierarchical_recall': 0.25, 'hierarchical_fmeasure': 0.4}\n",
130
+ "\n",
131
+ "TEST CASE #11\n",
132
+ "References: ['1111']\n",
133
+ "Predictions: ['9999']\n",
134
+ "Accuracy: 0.0, Hierarchical Precision: 0.0, Hierarchical Recall: 0.0, Hierarchical F-measure: 0\n",
135
+ "{'accuracy': 0.0, 'hierarchical_precision': 0.0, 'hierarchical_recall': 0.0, 'hierarchical_fmeasure': 0}\n",
136
+ "\n"
137
+ ]
138
+ }
139
+ ],
140
+ "source": [
141
+ "# Compute all test cases and print the results\n",
142
+ "from tests import test_cases\n",
143
+ "\n",
144
+ "test_number = 1\n",
145
+ "\n",
146
+ "for test_case in test_cases:\n",
147
+ " references = test_case[\"references\"]\n",
148
+ " predictions = test_case[\"predictions\"]\n",
149
+ " print(f\"TEST CASE #{test_number}\")\n",
150
+ " print(f\"References: {references}\")\n",
151
+ " print(f\"Predictions: {predictions}\")\n",
152
+ " print(ham.compute(references=references, predictions=predictions))\n",
153
+ " print()\n",
154
+ " test_number += 1"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "markdown",
159
+ "metadata": {},
160
+ "source": [
161
+ "# Model evaluation using the test split of the dataset"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": 17,
167
+ "metadata": {},
168
+ "outputs": [
169
+ {
170
+ "name": "stdout",
171
+ "output_type": "stream",
172
+ "text": [
173
+ "ISCO CSV file downloaded\n",
174
+ "Weighted ISCO hierarchy dictionary created\n",
175
+ "{'1111': {'111': 0.75, '11': 0.5, '1': 0.25}, '1112': {'111': 0.75, '11': 0.5, '1': 0.25}, '1113': {'111': 0.75, '11': 0.5, '1': 0.25}, '1114': {'111': 0.75, '11': 0.5, '1': 0.25}, '1120': {'112': 0.75, '11': 0.5, '1': 0.25}, '1211': {'121': 0.75, '12': 0.5, '1': 0.25}, '1212': {'121': 0.75, '12': 0.5, '1': 0.25}, '1213': {'121': 0.75, '12': 0.5, '1': 0.25}, '1219': {'121': 0.75, '12': 0.5, '1': 0.25}, '1221': {'122': 0.75, '12': 0.5, '1': 0.25}, '1222': {'122': 0.75, '12': 0.5, '1': 0.25}, '1223': {'122': 0.75, '12': 0.5, '1': 0.25}, '1311': {'131': 0.75, '13': 0.5, '1': 0.25}, '1312': {'131': 0.75, '13': 0.5, '1': 0.25}, '1321': {'132': 0.75, '13': 0.5, '1': 0.25}, '1322': {'132': 0.75, '13': 0.5, '1': 0.25}, '1323': {'132': 0.75, '13': 0.5, '1': 0.25}, '1324': {'132': 0.75, '13': 0.5, '1': 0.25}, '1330': {'133': 0.75, '13': 0.5, '1': 0.25}, '1341': {'134': 0.75, '13': 0.5, '1': 0.25}, '1342': {'134': 0.75, '13': 0.5, '1': 0.25}, '1343': {'134': 0.75, '13': 0.5, '1': 0.25}, '1344': {'134': 0.75, '13': 0.5, '1': 0.25}, '1345': {'134': 0.75, '13': 0.5, '1': 0.25}, '1346': {'134': 0.75, '13': 0.5, '1': 0.25}, '1349': {'134': 0.75, '13': 0.5, '1': 0.25}, '1411': {'141': 0.75, '14': 0.5, '1': 0.25}, '1412': {'141': 0.75, '14': 0.5, '1': 0.25}, '1420': {'142': 0.75, '14': 0.5, '1': 0.25}, '1431': {'143': 0.75, '14': 0.5, '1': 0.25}, '1439': {'143': 0.75, '14': 0.5, '1': 0.25}, '2111': {'211': 0.75, '21': 0.5, '2': 0.25}, '2112': {'211': 0.75, '21': 0.5, '2': 0.25}, '2113': {'211': 0.75, '21': 0.5, '2': 0.25}, '2114': {'211': 0.75, '21': 0.5, '2': 0.25}, '2120': {'212': 0.75, '21': 0.5, '2': 0.25}, '2131': {'213': 0.75, '21': 0.5, '2': 0.25}, '2132': {'213': 0.75, '21': 0.5, '2': 0.25}, '2133': {'213': 0.75, '21': 0.5, '2': 0.25}, '2141': {'214': 0.75, '21': 0.5, '2': 0.25}, '2142': {'214': 0.75, '21': 0.5, '2': 0.25}, '2143': {'214': 0.75, '21': 0.5, '2': 0.25}, '2144': {'214': 0.75, '21': 0.5, '2': 0.25}, '2145': {'214': 0.75, '21': 0.5, '2': 0.25}, '2146': {'214': 0.75, '21': 0.5, '2': 0.25}, '2149': {'214': 0.75, '21': 0.5, '2': 0.25}, '2151': {'215': 0.75, '21': 0.5, '2': 0.25}, '2152': {'215': 0.75, '21': 0.5, '2': 0.25}, '2153': {'215': 0.75, '21': 0.5, '2': 0.25}, '2161': {'216': 0.75, '21': 0.5, '2': 0.25}, '2162': {'216': 0.75, '21': 0.5, '2': 0.25}, '2163': {'216': 0.75, '21': 0.5, '2': 0.25}, '2164': {'216': 0.75, '21': 0.5, '2': 0.25}, '2165': {'216': 0.75, '21': 0.5, '2': 0.25}, '2166': {'216': 0.75, '21': 0.5, '2': 0.25}, '2211': {'221': 0.75, '22': 0.5, '2': 0.25}, '2212': {'221': 0.75, '22': 0.5, '2': 0.25}, '2221': {'222': 0.75, '22': 0.5, '2': 0.25}, '2222': {'222': 0.75, '22': 0.5, '2': 0.25}, '2230': {'223': 0.75, '22': 0.5, '2': 0.25}, '2240': {'224': 0.75, '22': 0.5, '2': 0.25}, '2250': {'225': 0.75, '22': 0.5, '2': 0.25}, '2261': {'226': 0.75, '22': 0.5, '2': 0.25}, '2262': {'226': 0.75, '22': 0.5, '2': 0.25}, '2263': {'226': 0.75, '22': 0.5, '2': 0.25}, '2264': {'226': 0.75, '22': 0.5, '2': 0.25}, '2265': {'226': 0.75, '22': 0.5, '2': 0.25}, '2266': {'226': 0.75, '22': 0.5, '2': 0.25}, '2267': {'226': 0.75, '22': 0.5, '2': 0.25}, '2269': {'226': 0.75, '22': 0.5, '2': 0.25}, '2310': {'231': 0.75, '23': 0.5, '2': 0.25}, '2320': {'232': 0.75, '23': 0.5, '2': 0.25}, '2330': {'233': 0.75, '23': 0.5, '2': 0.25}, '2341': {'234': 0.75, '23': 0.5, '2': 0.25}, '2342': {'234': 0.75, '23': 0.5, '2': 0.25}, '2351': {'235': 0.75, '23': 0.5, '2': 0.25}, '2352': {'235': 0.75, '23': 0.5, '2': 0.25}, '2353': {'235': 0.75, '23': 0.5, '2': 0.25}, '2354': {'235': 0.75, '23': 0.5, '2': 0.25}, '2355': {'235': 0.75, '23': 0.5, '2': 0.25}, '2356': {'235': 0.75, '23': 0.5, '2': 0.25}, '2359': {'235': 0.75, '23': 0.5, '2': 0.25}, '2411': {'241': 0.75, '24': 0.5, '2': 0.25}, '2412': {'241': 0.75, '24': 0.5, '2': 0.25}, '2413': {'241': 0.75, '24': 0.5, '2': 0.25}, '2421': {'242': 0.75, '24': 0.5, '2': 0.25}, '2422': {'242': 0.75, '24': 0.5, '2': 0.25}, '2423': {'242': 0.75, '24': 0.5, '2': 0.25}, '2424': {'242': 0.75, '24': 0.5, '2': 0.25}, '2431': {'243': 0.75, '24': 0.5, '2': 0.25}, '2432': {'243': 0.75, '24': 0.5, '2': 0.25}, '2433': {'243': 0.75, '24': 0.5, '2': 0.25}, '2434': {'243': 0.75, '24': 0.5, '2': 0.25}, '2511': {'251': 0.75, '25': 0.5, '2': 0.25}, '2512': {'251': 0.75, '25': 0.5, '2': 0.25}, '2513': {'251': 0.75, '25': 0.5, '2': 0.25}, '2514': {'251': 0.75, '25': 0.5, '2': 0.25}, '2519': {'251': 0.75, '25': 0.5, '2': 0.25}, '2521': {'252': 0.75, '25': 0.5, '2': 0.25}, '2522': {'252': 0.75, '25': 0.5, '2': 0.25}, '2523': {'252': 0.75, '25': 0.5, '2': 0.25}, '2529': {'252': 0.75, '25': 0.5, '2': 0.25}, '2611': {'261': 0.75, '26': 0.5, '2': 0.25}, '2612': {'261': 0.75, '26': 0.5, '2': 0.25}, '2619': {'261': 0.75, '26': 0.5, '2': 0.25}, '2621': {'262': 0.75, '26': 0.5, '2': 0.25}, '2622': {'262': 0.75, '26': 0.5, '2': 0.25}, '2631': {'263': 0.75, '26': 0.5, '2': 0.25}, '2632': {'263': 0.75, '26': 0.5, '2': 0.25}, '2633': {'263': 0.75, '26': 0.5, '2': 0.25}, '2634': {'263': 0.75, '26': 0.5, '2': 0.25}, '2635': {'263': 0.75, '26': 0.5, '2': 0.25}, '2636': {'263': 0.75, '26': 0.5, '2': 0.25}, '2641': {'264': 0.75, '26': 0.5, '2': 0.25}, '2642': {'264': 0.75, '26': 0.5, '2': 0.25}, '2643': {'264': 0.75, '26': 0.5, '2': 0.25}, '2651': {'265': 0.75, '26': 0.5, '2': 0.25}, '2652': {'265': 0.75, '26': 0.5, '2': 0.25}, '2653': {'265': 0.75, '26': 0.5, '2': 0.25}, '2654': {'265': 0.75, '26': 0.5, '2': 0.25}, '2655': {'265': 0.75, '26': 0.5, '2': 0.25}, '2656': {'265': 0.75, '26': 0.5, '2': 0.25}, '2659': {'265': 0.75, '26': 0.5, '2': 0.25}, '3111': {'311': 0.75, '31': 0.5, '3': 0.25}, '3112': {'311': 0.75, '31': 0.5, '3': 0.25}, '3113': {'311': 0.75, '31': 0.5, '3': 0.25}, '3114': {'311': 0.75, '31': 0.5, '3': 0.25}, '3115': {'311': 0.75, '31': 0.5, '3': 0.25}, '3116': {'311': 0.75, '31': 0.5, '3': 0.25}, '3117': {'311': 0.75, '31': 0.5, '3': 0.25}, '3118': {'311': 0.75, '31': 0.5, '3': 0.25}, '3119': {'311': 0.75, '31': 0.5, '3': 0.25}, '3121': {'312': 0.75, '31': 0.5, '3': 0.25}, '3122': {'312': 0.75, '31': 0.5, '3': 0.25}, '3123': {'312': 0.75, '31': 0.5, '3': 0.25}, '3131': {'313': 0.75, '31': 0.5, '3': 0.25}, '3132': {'313': 0.75, '31': 0.5, '3': 0.25}, '3133': {'313': 0.75, '31': 0.5, '3': 0.25}, '3134': {'313': 0.75, '31': 0.5, '3': 0.25}, '3135': {'313': 0.75, '31': 0.5, '3': 0.25}, '3139': {'313': 0.75, '31': 0.5, '3': 0.25}, '3141': {'314': 0.75, '31': 0.5, '3': 0.25}, '3142': {'314': 0.75, '31': 0.5, '3': 0.25}, '3143': {'314': 0.75, '31': 0.5, '3': 0.25}, '3151': {'315': 0.75, '31': 0.5, '3': 0.25}, '3152': {'315': 0.75, '31': 0.5, '3': 0.25}, '3153': {'315': 0.75, '31': 0.5, '3': 0.25}, '3154': {'315': 0.75, '31': 0.5, '3': 0.25}, '3155': {'315': 0.75, '31': 0.5, '3': 0.25}, '3211': {'321': 0.75, '32': 0.5, '3': 0.25}, '3212': {'321': 0.75, '32': 0.5, '3': 0.25}, '3213': {'321': 0.75, '32': 0.5, '3': 0.25}, '3214': {'321': 0.75, '32': 0.5, '3': 0.25}, '3221': {'322': 0.75, '32': 0.5, '3': 0.25}, '3222': {'322': 0.75, '32': 0.5, '3': 0.25}, '3230': {'323': 0.75, '32': 0.5, '3': 0.25}, '3240': {'324': 0.75, '32': 0.5, '3': 0.25}, '3251': {'325': 0.75, '32': 0.5, '3': 0.25}, '3252': {'325': 0.75, '32': 0.5, '3': 0.25}, '3253': {'325': 0.75, '32': 0.5, '3': 0.25}, '3254': {'325': 0.75, '32': 0.5, '3': 0.25}, '3255': {'325': 0.75, '32': 0.5, '3': 0.25}, '3256': {'325': 0.75, '32': 0.5, '3': 0.25}, '3257': {'325': 0.75, '32': 0.5, '3': 0.25}, '3258': {'325': 0.75, '32': 0.5, '3': 0.25}, '3259': {'325': 0.75, '32': 0.5, '3': 0.25}, '3311': {'331': 0.75, '33': 0.5, '3': 0.25}, '3312': {'331': 0.75, '33': 0.5, '3': 0.25}, '3313': {'331': 0.75, '33': 0.5, '3': 0.25}, '3314': {'331': 0.75, '33': 0.5, '3': 0.25}, '3315': {'331': 0.75, '33': 0.5, '3': 0.25}, '3321': {'332': 0.75, '33': 0.5, '3': 0.25}, '3322': {'332': 0.75, '33': 0.5, '3': 0.25}, '3323': {'332': 0.75, '33': 0.5, '3': 0.25}, '3324': {'332': 0.75, '33': 0.5, '3': 0.25}, '3331': {'333': 0.75, '33': 0.5, '3': 0.25}, '3332': {'333': 0.75, '33': 0.5, '3': 0.25}, '3333': {'333': 0.75, '33': 0.5, '3': 0.25}, '3334': {'333': 0.75, '33': 0.5, '3': 0.25}, '3339': {'333': 0.75, '33': 0.5, '3': 0.25}, '3341': {'334': 0.75, '33': 0.5, '3': 0.25}, '3342': {'334': 0.75, '33': 0.5, '3': 0.25}, '3343': {'334': 0.75, '33': 0.5, '3': 0.25}, '3344': {'334': 0.75, '33': 0.5, '3': 0.25}, '3351': {'335': 0.75, '33': 0.5, '3': 0.25}, '3352': {'335': 0.75, '33': 0.5, '3': 0.25}, '3353': {'335': 0.75, '33': 0.5, '3': 0.25}, '3354': {'335': 0.75, '33': 0.5, '3': 0.25}, '3355': {'335': 0.75, '33': 0.5, '3': 0.25}, '3359': {'335': 0.75, '33': 0.5, '3': 0.25}, '3411': {'341': 0.75, '34': 0.5, '3': 0.25}, '3412': {'341': 0.75, '34': 0.5, '3': 0.25}, '3413': {'341': 0.75, '34': 0.5, '3': 0.25}, '3421': {'342': 0.75, '34': 0.5, '3': 0.25}, '3422': {'342': 0.75, '34': 0.5, '3': 0.25}, '3423': {'342': 0.75, '34': 0.5, '3': 0.25}, '3431': {'343': 0.75, '34': 0.5, '3': 0.25}, '3432': {'343': 0.75, '34': 0.5, '3': 0.25}, '3433': {'343': 0.75, '34': 0.5, '3': 0.25}, '3434': {'343': 0.75, '34': 0.5, '3': 0.25}, '3435': {'343': 0.75, '34': 0.5, '3': 0.25}, '3511': {'351': 0.75, '35': 0.5, '3': 0.25}, '3512': {'351': 0.75, '35': 0.5, '3': 0.25}, '3513': {'351': 0.75, '35': 0.5, '3': 0.25}, '3514': {'351': 0.75, '35': 0.5, '3': 0.25}, '3521': {'352': 0.75, '35': 0.5, '3': 0.25}, '3522': {'352': 0.75, '35': 0.5, '3': 0.25}, '4110': {'411': 0.75, '41': 0.5, '4': 0.25}, '4120': {'412': 0.75, '41': 0.5, '4': 0.25}, '4131': {'413': 0.75, '41': 0.5, '4': 0.25}, '4132': {'413': 0.75, '41': 0.5, '4': 0.25}, '4211': {'421': 0.75, '42': 0.5, '4': 0.25}, '4212': {'421': 0.75, '42': 0.5, '4': 0.25}, '4213': {'421': 0.75, '42': 0.5, '4': 0.25}, '4214': {'421': 0.75, '42': 0.5, '4': 0.25}, '4221': {'422': 0.75, '42': 0.5, '4': 0.25}, '4222': {'422': 0.75, '42': 0.5, '4': 0.25}, '4223': {'422': 0.75, '42': 0.5, '4': 0.25}, '4224': {'422': 0.75, '42': 0.5, '4': 0.25}, '4225': {'422': 0.75, '42': 0.5, '4': 0.25}, '4226': {'422': 0.75, '42': 0.5, '4': 0.25}, '4227': {'422': 0.75, '42': 0.5, '4': 0.25}, '4229': {'422': 0.75, '42': 0.5, '4': 0.25}, '4311': {'431': 0.75, '43': 0.5, '4': 0.25}, '4312': {'431': 0.75, '43': 0.5, '4': 0.25}, '4313': {'431': 0.75, '43': 0.5, '4': 0.25}, '4321': {'432': 0.75, '43': 0.5, '4': 0.25}, '4322': {'432': 0.75, '43': 0.5, '4': 0.25}, '4323': {'432': 0.75, '43': 0.5, '4': 0.25}, '4411': {'441': 0.75, '44': 0.5, '4': 0.25}, '4412': {'441': 0.75, '44': 0.5, '4': 0.25}, '4413': {'441': 0.75, '44': 0.5, '4': 0.25}, '4414': {'441': 0.75, '44': 0.5, '4': 0.25}, '4415': {'441': 0.75, '44': 0.5, '4': 0.25}, '4416': {'441': 0.75, '44': 0.5, '4': 0.25}, '4419': {'441': 0.75, '44': 0.5, '4': 0.25}, '5111': {'511': 0.75, '51': 0.5, '5': 0.25}, '5112': {'511': 0.75, '51': 0.5, '5': 0.25}, '5113': {'511': 0.75, '51': 0.5, '5': 0.25}, '5120': {'512': 0.75, '51': 0.5, '5': 0.25}, '5131': {'513': 0.75, '51': 0.5, '5': 0.25}, '5132': {'513': 0.75, '51': 0.5, '5': 0.25}, '5141': {'514': 0.75, '51': 0.5, '5': 0.25}, '5142': {'514': 0.75, '51': 0.5, '5': 0.25}, '5151': {'515': 0.75, '51': 0.5, '5': 0.25}, '5152': {'515': 0.75, '51': 0.5, '5': 0.25}, '5153': {'515': 0.75, '51': 0.5, '5': 0.25}, '5161': {'516': 0.75, '51': 0.5, '5': 0.25}, '5162': {'516': 0.75, '51': 0.5, '5': 0.25}, '5163': {'516': 0.75, '51': 0.5, '5': 0.25}, '5164': {'516': 0.75, '51': 0.5, '5': 0.25}, '5165': {'516': 0.75, '51': 0.5, '5': 0.25}, '5169': {'516': 0.75, '51': 0.5, '5': 0.25}, '5211': {'521': 0.75, '52': 0.5, '5': 0.25}, '5212': {'521': 0.75, '52': 0.5, '5': 0.25}, '5221': {'522': 0.75, '52': 0.5, '5': 0.25}, '5222': {'522': 0.75, '52': 0.5, '5': 0.25}, '5223': {'522': 0.75, '52': 0.5, '5': 0.25}, '5230': {'523': 0.75, '52': 0.5, '5': 0.25}, '5241': {'524': 0.75, '52': 0.5, '5': 0.25}, '5242': {'524': 0.75, '52': 0.5, '5': 0.25}, '5243': {'524': 0.75, '52': 0.5, '5': 0.25}, '5244': {'524': 0.75, '52': 0.5, '5': 0.25}, '5245': {'524': 0.75, '52': 0.5, '5': 0.25}, '5246': {'524': 0.75, '52': 0.5, '5': 0.25}, '5249': {'524': 0.75, '52': 0.5, '5': 0.25}, '5311': {'531': 0.75, '53': 0.5, '5': 0.25}, '5312': {'531': 0.75, '53': 0.5, '5': 0.25}, '5321': {'532': 0.75, '53': 0.5, '5': 0.25}, '5322': {'532': 0.75, '53': 0.5, '5': 0.25}, '5329': {'532': 0.75, '53': 0.5, '5': 0.25}, '5411': {'541': 0.75, '54': 0.5, '5': 0.25}, '5412': {'541': 0.75, '54': 0.5, '5': 0.25}, '5413': {'541': 0.75, '54': 0.5, '5': 0.25}, '5414': {'541': 0.75, '54': 0.5, '5': 0.25}, '5419': {'541': 0.75, '54': 0.5, '5': 0.25}, '6111': {'611': 0.75, '61': 0.5, '6': 0.25}, '6112': {'611': 0.75, '61': 0.5, '6': 0.25}, '6113': {'611': 0.75, '61': 0.5, '6': 0.25}, '6114': {'611': 0.75, '61': 0.5, '6': 0.25}, '6121': {'612': 0.75, '61': 0.5, '6': 0.25}, '6122': {'612': 0.75, '61': 0.5, '6': 0.25}, '6123': {'612': 0.75, '61': 0.5, '6': 0.25}, '6129': {'612': 0.75, '61': 0.5, '6': 0.25}, '6130': {'613': 0.75, '61': 0.5, '6': 0.25}, '6210': {'621': 0.75, '62': 0.5, '6': 0.25}, '6221': {'622': 0.75, '62': 0.5, '6': 0.25}, '6222': {'622': 0.75, '62': 0.5, '6': 0.25}, '6223': {'622': 0.75, '62': 0.5, '6': 0.25}, '6224': {'622': 0.75, '62': 0.5, '6': 0.25}, '6310': {'631': 0.75, '63': 0.5, '6': 0.25}, '6320': {'632': 0.75, '63': 0.5, '6': 0.25}, '6330': {'633': 0.75, '63': 0.5, '6': 0.25}, '6340': {'634': 0.75, '63': 0.5, '6': 0.25}, '7111': {'711': 0.75, '71': 0.5, '7': 0.25}, '7112': {'711': 0.75, '71': 0.5, '7': 0.25}, '7113': {'711': 0.75, '71': 0.5, '7': 0.25}, '7114': {'711': 0.75, '71': 0.5, '7': 0.25}, '7115': {'711': 0.75, '71': 0.5, '7': 0.25}, '7119': {'711': 0.75, '71': 0.5, '7': 0.25}, '7121': {'712': 0.75, '71': 0.5, '7': 0.25}, '7122': {'712': 0.75, '71': 0.5, '7': 0.25}, '7123': {'712': 0.75, '71': 0.5, '7': 0.25}, '7124': {'712': 0.75, '71': 0.5, '7': 0.25}, '7125': {'712': 0.75, '71': 0.5, '7': 0.25}, '7126': {'712': 0.75, '71': 0.5, '7': 0.25}, '7127': {'712': 0.75, '71': 0.5, '7': 0.25}, '7131': {'713': 0.75, '71': 0.5, '7': 0.25}, '7132': {'713': 0.75, '71': 0.5, '7': 0.25}, '7133': {'713': 0.75, '71': 0.5, '7': 0.25}, '7211': {'721': 0.75, '72': 0.5, '7': 0.25}, '7212': {'721': 0.75, '72': 0.5, '7': 0.25}, '7213': {'721': 0.75, '72': 0.5, '7': 0.25}, '7214': {'721': 0.75, '72': 0.5, '7': 0.25}, '7215': {'721': 0.75, '72': 0.5, '7': 0.25}, '7221': {'722': 0.75, '72': 0.5, '7': 0.25}, '7222': {'722': 0.75, '72': 0.5, '7': 0.25}, '7223': {'722': 0.75, '72': 0.5, '7': 0.25}, '7224': {'722': 0.75, '72': 0.5, '7': 0.25}, '7231': {'723': 0.75, '72': 0.5, '7': 0.25}, '7232': {'723': 0.75, '72': 0.5, '7': 0.25}, '7233': {'723': 0.75, '72': 0.5, '7': 0.25}, '7234': {'723': 0.75, '72': 0.5, '7': 0.25}, '7311': {'731': 0.75, '73': 0.5, '7': 0.25}, '7312': {'731': 0.75, '73': 0.5, '7': 0.25}, '7313': {'731': 0.75, '73': 0.5, '7': 0.25}, '7314': {'731': 0.75, '73': 0.5, '7': 0.25}, '7315': {'731': 0.75, '73': 0.5, '7': 0.25}, '7316': {'731': 0.75, '73': 0.5, '7': 0.25}, '7317': {'731': 0.75, '73': 0.5, '7': 0.25}, '7318': {'731': 0.75, '73': 0.5, '7': 0.25}, '7319': {'731': 0.75, '73': 0.5, '7': 0.25}, '7321': {'732': 0.75, '73': 0.5, '7': 0.25}, '7322': {'732': 0.75, '73': 0.5, '7': 0.25}, '7323': {'732': 0.75, '73': 0.5, '7': 0.25}, '7411': {'741': 0.75, '74': 0.5, '7': 0.25}, '7412': {'741': 0.75, '74': 0.5, '7': 0.25}, '7413': {'741': 0.75, '74': 0.5, '7': 0.25}, '7421': {'742': 0.75, '74': 0.5, '7': 0.25}, '7422': {'742': 0.75, '74': 0.5, '7': 0.25}, '7511': {'751': 0.75, '75': 0.5, '7': 0.25}, '7512': {'751': 0.75, '75': 0.5, '7': 0.25}, '7513': {'751': 0.75, '75': 0.5, '7': 0.25}, '7514': {'751': 0.75, '75': 0.5, '7': 0.25}, '7515': {'751': 0.75, '75': 0.5, '7': 0.25}, '7516': {'751': 0.75, '75': 0.5, '7': 0.25}, '7521': {'752': 0.75, '75': 0.5, '7': 0.25}, '7522': {'752': 0.75, '75': 0.5, '7': 0.25}, '7523': {'752': 0.75, '75': 0.5, '7': 0.25}, '7531': {'753': 0.75, '75': 0.5, '7': 0.25}, '7532': {'753': 0.75, '75': 0.5, '7': 0.25}, '7533': {'753': 0.75, '75': 0.5, '7': 0.25}, '7534': {'753': 0.75, '75': 0.5, '7': 0.25}, '7535': {'753': 0.75, '75': 0.5, '7': 0.25}, '7536': {'753': 0.75, '75': 0.5, '7': 0.25}, '7541': {'754': 0.75, '75': 0.5, '7': 0.25}, '7542': {'754': 0.75, '75': 0.5, '7': 0.25}, '7543': {'754': 0.75, '75': 0.5, '7': 0.25}, '7544': {'754': 0.75, '75': 0.5, '7': 0.25}, '7549': {'754': 0.75, '75': 0.5, '7': 0.25}, '8111': {'811': 0.75, '81': 0.5, '8': 0.25}, '8112': {'811': 0.75, '81': 0.5, '8': 0.25}, '8113': {'811': 0.75, '81': 0.5, '8': 0.25}, '8114': {'811': 0.75, '81': 0.5, '8': 0.25}, '8121': {'812': 0.75, '81': 0.5, '8': 0.25}, '8122': {'812': 0.75, '81': 0.5, '8': 0.25}, '8131': {'813': 0.75, '81': 0.5, '8': 0.25}, '8132': {'813': 0.75, '81': 0.5, '8': 0.25}, '8141': {'814': 0.75, '81': 0.5, '8': 0.25}, '8142': {'814': 0.75, '81': 0.5, '8': 0.25}, '8143': {'814': 0.75, '81': 0.5, '8': 0.25}, '8151': {'815': 0.75, '81': 0.5, '8': 0.25}, '8152': {'815': 0.75, '81': 0.5, '8': 0.25}, '8153': {'815': 0.75, '81': 0.5, '8': 0.25}, '8154': {'815': 0.75, '81': 0.5, '8': 0.25}, '8155': {'815': 0.75, '81': 0.5, '8': 0.25}, '8156': {'815': 0.75, '81': 0.5, '8': 0.25}, '8157': {'815': 0.75, '81': 0.5, '8': 0.25}, '8159': {'815': 0.75, '81': 0.5, '8': 0.25}, '8160': {'816': 0.75, '81': 0.5, '8': 0.25}, '8171': {'817': 0.75, '81': 0.5, '8': 0.25}, '8172': {'817': 0.75, '81': 0.5, '8': 0.25}, '8181': {'818': 0.75, '81': 0.5, '8': 0.25}, '8182': {'818': 0.75, '81': 0.5, '8': 0.25}, '8183': {'818': 0.75, '81': 0.5, '8': 0.25}, '8189': {'818': 0.75, '81': 0.5, '8': 0.25}, '8211': {'821': 0.75, '82': 0.5, '8': 0.25}, '8212': {'821': 0.75, '82': 0.5, '8': 0.25}, '8219': {'821': 0.75, '82': 0.5, '8': 0.25}, '8311': {'831': 0.75, '83': 0.5, '8': 0.25}, '8312': {'831': 0.75, '83': 0.5, '8': 0.25}, '8321': {'832': 0.75, '83': 0.5, '8': 0.25}, '8322': {'832': 0.75, '83': 0.5, '8': 0.25}, '8331': {'833': 0.75, '83': 0.5, '8': 0.25}, '8332': {'833': 0.75, '83': 0.5, '8': 0.25}, '8341': {'834': 0.75, '83': 0.5, '8': 0.25}, '8342': {'834': 0.75, '83': 0.5, '8': 0.25}, '8343': {'834': 0.75, '83': 0.5, '8': 0.25}, '8344': {'834': 0.75, '83': 0.5, '8': 0.25}, '8350': {'835': 0.75, '83': 0.5, '8': 0.25}, '9111': {'911': 0.75, '91': 0.5, '9': 0.25}, '9112': {'911': 0.75, '91': 0.5, '9': 0.25}, '9121': {'912': 0.75, '91': 0.5, '9': 0.25}, '9122': {'912': 0.75, '91': 0.5, '9': 0.25}, '9123': {'912': 0.75, '91': 0.5, '9': 0.25}, '9129': {'912': 0.75, '91': 0.5, '9': 0.25}, '9211': {'921': 0.75, '92': 0.5, '9': 0.25}, '9212': {'921': 0.75, '92': 0.5, '9': 0.25}, '9213': {'921': 0.75, '92': 0.5, '9': 0.25}, '9214': {'921': 0.75, '92': 0.5, '9': 0.25}, '9215': {'921': 0.75, '92': 0.5, '9': 0.25}, '9216': {'921': 0.75, '92': 0.5, '9': 0.25}, '9311': {'931': 0.75, '93': 0.5, '9': 0.25}, '9312': {'931': 0.75, '93': 0.5, '9': 0.25}, '9313': {'931': 0.75, '93': 0.5, '9': 0.25}, '9321': {'932': 0.75, '93': 0.5, '9': 0.25}, '9329': {'932': 0.75, '93': 0.5, '9': 0.25}, '9331': {'933': 0.75, '93': 0.5, '9': 0.25}, '9332': {'933': 0.75, '93': 0.5, '9': 0.25}, '9333': {'933': 0.75, '93': 0.5, '9': 0.25}, '9334': {'933': 0.75, '93': 0.5, '9': 0.25}, '9411': {'941': 0.75, '94': 0.5, '9': 0.25}, '9412': {'941': 0.75, '94': 0.5, '9': 0.25}, '9510': {'951': 0.75, '95': 0.5, '9': 0.25}, '9520': {'952': 0.75, '95': 0.5, '9': 0.25}, '9611': {'961': 0.75, '96': 0.5, '9': 0.25}, '9612': {'961': 0.75, '96': 0.5, '9': 0.25}, '9613': {'961': 0.75, '96': 0.5, '9': 0.25}, '9621': {'962': 0.75, '96': 0.5, '9': 0.25}, '9622': {'962': 0.75, '96': 0.5, '9': 0.25}, '9623': {'962': 0.75, '96': 0.5, '9': 0.25}, '9624': {'962': 0.75, '96': 0.5, '9': 0.25}, '9629': {'962': 0.75, '96': 0.5, '9': 0.25}, '0110': {'011': 0.75, '01': 0.5, '0': 0.25}, '0210': {'021': 0.75, '02': 0.5, '0': 0.25}, '0310': {'031': 0.75, '03': 0.5, '0': 0.25}}\n",
176
+ "Accuracy: 0.8611914401388086\n",
177
+ "Hierarchical Precision: 0.989010989010989, Hierarchical Recall: 0.9836065573770492, Hierarchical F-measure: 0.9863013698630136\n",
178
+ "Evaluation results saved to isco_results.txt\n"
179
+ ]
180
+ }
181
+ ],
182
+ "source": [
183
+ "import os\n",
184
+ "from datasets import load_dataset\n",
185
+ "from transformers import pipeline\n",
186
+ "import evaluate\n",
187
+ "import json\n",
188
+ "\n",
189
+ "# Ensure that the HF_TOKEN environment variable is set\n",
190
+ "hf_token = os.getenv(\"HF_TOKEN\")\n",
191
+ "if hf_token is None:\n",
192
+ " raise ValueError(\"HF_TOKEN environment variable is not set.\")\n",
193
+ "\n",
194
+ "# Load the dataset\n",
195
+ "test_data_subset = (\n",
196
+ " load_dataset(\n",
197
+ " \"ICILS/multilingual_parental_occupations\", split=\"test\", token=hf_token\n",
198
+ " )\n",
199
+ " .shuffle(seed=42)\n",
200
+ " .select(range(100))\n",
201
+ ")\n",
202
+ "test_data = load_dataset(\n",
203
+ " \"ICILS/multilingual_parental_occupations\", split=\"test\", token=hf_token\n",
204
+ ")\n",
205
+ "\n",
206
+ "# Initialize the pipeline\n",
207
+ "pipe = pipeline(\"text-classification\", model=\"ICILS/XLM-R-ISCO\", token=hf_token)\n",
208
+ "\n",
209
+ "# Define the mapping from ISCO_CODE_TITLE to ISCO codes\n",
210
+ "def extract_isco_code(isco_code_title: str):\n",
211
+ " # ISCO_CODE_TITLE is a string like \"7412 Electrical Mechanics and Fitters\" so we need to extract the first part for the evaluation.\n",
212
+ " return isco_code_title.split()[0]\n",
213
+ "\n",
214
+ "# Evaluate the model\n",
215
+ "predictions = []\n",
216
+ "references = []\n",
217
+ "for example in test_data:\n",
218
+ "\n",
219
+ " # Predict\n",
220
+ " prediction = pipe(\n",
221
+ " example[\"JOB_DUTIES\"]\n",
222
+ " ) # Use the correct key \"JOB_DUTIES\" for the text data\n",
223
+ " predicted_label = extract_isco_code(prediction[0][\"label\"])\n",
224
+ " predictions.append(predicted_label)\n",
225
+ "\n",
226
+ " # Reference\n",
227
+ " reference_label = example[\"ISCO\"] # Use the correct key \"ISCO\" for the ISCO code\n",
228
+ " references.append(reference_label)\n",
229
+ "\n",
230
+ "# Initialize the hierarchical accuracy measure\n",
231
+ "hierarchical_accuracy = evaluate.load(\"danieldux/isco_hierarchical_accuracy\")\n",
232
+ "\n",
233
+ "# Compute the hierarchical accuracy\n",
234
+ "results = hierarchical_accuracy.compute(predictions=predictions, references=references)\n",
235
+ "\n",
236
+ "# Save the results to a JSON file\n",
237
+ "with open(\"isco_results.json\", \"w\") as f:\n",
238
+ " json.dump(results, f)\n",
239
+ "\n",
240
+ "print(\"Evaluation results saved to isco_results.json\")"
241
+ ]
242
+ }
243
+ ],
244
+ "metadata": {
245
+ "kernelspec": {
246
+ "display_name": "autogenstudio",
247
+ "language": "python",
248
+ "name": "python3"
249
+ },
250
+ "language_info": {
251
+ "codemirror_mode": {
252
+ "name": "ipython",
253
+ "version": 3
254
+ },
255
+ "file_extension": ".py",
256
+ "mimetype": "text/x-python",
257
+ "name": "python",
258
+ "nbconvert_exporter": "python",
259
+ "pygments_lexer": "ipython3",
260
+ "version": "3.11.7"
261
+ }
262
+ },
263
+ "nbformat": 4,
264
+ "nbformat_minor": 2
265
+ }