danieldux committed on
Commit
cb3e43c
·
1 Parent(s): d726519

tests notebook

Browse files
Files changed (1) hide show
  1. tests.ipynb +275 -452
tests.ipynb CHANGED
@@ -166,6 +166,41 @@
166
  "execution_count": null,
167
  "metadata": {},
168
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  "source": [
170
  "import os\n",
171
  "from datasets import load_dataset\n",
@@ -178,34 +213,45 @@
178
  "if hf_token is None:\n",
179
  " raise ValueError(\"HF_TOKEN environment variable is not set.\")\n",
180
  "\n",
 
 
 
181
  "# Load the dataset\n",
182
  "test_data_subset = (\n",
183
- " load_dataset(\n",
184
- " \"ICILS/multilingual_parental_occupations\", split=\"test\", token=hf_token\n",
185
- " )\n",
186
- " .shuffle(seed=42)\n",
187
- " .select(range(100))\n",
188
- ")\n",
189
- "test_data = load_dataset(\n",
190
- " \"ICILS/multilingual_parental_occupations\", split=\"test\", token=hf_token\n",
191
- ")\n",
192
- "\n",
193
- "validation_data = load_dataset(\n",
194
- " \"ICILS/multilingual_parental_occupations\", split=\"validation\", token=hf_token\n",
195
  ")\n",
196
  "\n",
197
  "# Initialize the pipeline\n",
198
- "pipe = pipeline(\"text-classification\", model=\"ICILS/XLM-R-ISCO\", token=hf_token)\n",
199
- "\n",
200
- "# Define the mapping from ISCO_CODE_TITLE to ISCO codes\n",
201
- "def extract_isco_code(isco_code_title: str):\n",
202
- " # ISCO_CODE_TITLE is a string like \"7412 Electrical Mechanics and Fitters\" so we need to extract the first part for the evaluation.\n",
203
- " return isco_code_title.split()[0]\n",
204
  "\n",
205
  "# Initialize the hierarchical accuracy measure\n",
206
  "hierarchical_accuracy = evaluate.load(\"danieldux/isco_hierarchical_accuracy\")"
207
  ]
208
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  {
210
  "cell_type": "markdown",
211
  "metadata": {},
@@ -215,29 +261,35 @@
215
  },
216
  {
217
  "cell_type": "code",
218
- "execution_count": 2,
219
  "metadata": {},
220
  "outputs": [
221
  {
222
  "name": "stdout",
223
  "output_type": "stream",
224
  "text": [
225
- "Accuracy: 0.8611914401388086, Hierarchical Precision: 0.989010989010989, Hierarchical Recall: 0.9836065573770492, Hierarchical F-measure: 0.9863013698630136\n",
226
- "Evaluation results saved to isco_test_results.json\n"
227
  ]
228
  }
229
  ],
230
  "source": [
 
 
 
 
 
231
  "# Evaluate the model\n",
232
  "predictions = []\n",
233
  "references = []\n",
234
- "for example in test_data:\n",
235
  "\n",
236
  " # Predict\n",
237
  " prediction = pipe(\n",
238
  " example[\"JOB_DUTIES\"]\n",
239
  " ) # Use the key \"JOB_DUTIES\" for the text data\n",
240
- " predicted_label = extract_isco_code(prediction[0][\"label\"])\n",
 
241
  " predictions.append(predicted_label)\n",
242
  "\n",
243
  " # Reference\n",
@@ -248,10 +300,158 @@
248
  "test_results = hierarchical_accuracy.compute(predictions=predictions, references=references)\n",
249
  "\n",
250
  "# Save the results to a JSON file\n",
251
- "with open(\"isco_test_results.json\", \"w\") as f:\n",
252
  " json.dump(test_results, f)\n",
253
  "\n",
254
- "print(\"Evaluation results saved to isco_test_results.json\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  ]
256
  },
257
  {
@@ -309,9 +509,16 @@
309
  "# Inter rater agreement"
310
  ]
311
  },
 
 
 
 
 
 
 
312
  {
313
  "cell_type": "code",
314
- "execution_count": 70,
315
  "metadata": {},
316
  "outputs": [],
317
  "source": [
@@ -332,6 +539,13 @@
332
  "grouped_df = isco_rel_df.groupby('LANGUAGE')"
333
  ]
334
  },
 
 
 
 
 
 
 
335
  {
336
  "cell_type": "code",
337
  "execution_count": null,
@@ -371,19 +585,14 @@
371
  "results_df.loc[len(results_df)] = average_row\n",
372
  "\n",
373
  "\n",
374
- "results_df.to_csv('language_results.csv', index=False)"
375
  ]
376
  },
377
  {
378
- "cell_type": "code",
379
- "execution_count": null,
380
  "metadata": {},
381
- "outputs": [],
382
  "source": [
383
- "# create a dataframe with samples where ISCO and ISCO_REL the same\n",
384
- "isco_rel_df_same = isco_rel_df[isco_rel_df['ISCO'] == isco_rel_df['ISCO_REL']]\n",
385
- "\n",
386
- "isco_rel_df_same"
387
  ]
388
  },
389
  {
@@ -392,467 +601,81 @@
392
  "metadata": {},
393
  "outputs": [],
394
  "source": [
395
- "# create a dataframe with samples where ISCO and ISCO_REL are different\n",
396
- "isco_rel_df_diff = isco_rel_df[isco_rel_df['ISCO'] != isco_rel_df['ISCO_REL']]\n",
397
  "\n",
398
- "isco_rel_df_diff"
399
- ]
400
- },
401
- {
402
- "cell_type": "code",
403
- "execution_count": 64,
404
- "metadata": {},
405
- "outputs": [],
406
- "source": [
407
- "# Make a list of all values in ISCO and ISCO_REL columns\n",
408
- "coder1 = list(isco_rel_df['ISCO'])\n",
409
- "coder2 = list(isco_rel_df['ISCO_REL'])"
410
- ]
411
- },
412
- {
413
- "cell_type": "code",
414
- "execution_count": null,
415
- "metadata": {},
416
- "outputs": [],
417
- "source": [
418
- "# Compute the hierarchical accuracy\n",
419
- "reliability_results = hierarchical_accuracy.compute(predictions=coder2, references=coder1)\n",
420
  "\n",
421
- "# Save the results to a JSON file\n",
422
- "with open(\"isco_rel_results.json\", \"w\") as f:\n",
423
- " json.dump(reliability_results, f)\n",
424
  "\n",
425
- "print(\"Evaluation results saved to isco_rel_results.json\")"
 
426
  ]
427
  },
428
  {
429
  "cell_type": "markdown",
430
  "metadata": {},
431
  "source": [
432
- "## Giskard model testing"
433
  ]
434
  },
435
  {
436
- "cell_type": "code",
437
- "execution_count": 1,
438
  "metadata": {},
439
- "outputs": [],
440
  "source": [
441
- "import numpy as np\n",
442
- "import pandas as pd\n",
443
- "from scipy.special import softmax\n",
444
- "from datasets import load_dataset\n",
445
- "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
446
- "\n",
447
- "from giskard import Dataset, Model, scan, testing, GiskardClient, Suite"
448
  ]
449
  },
450
  {
451
  "cell_type": "code",
452
- "execution_count": 3,
453
  "metadata": {},
454
- "outputs": [
455
- {
456
- "data": {
457
- "text/html": [
458
- "<div>\n",
459
- "<style scoped>\n",
460
- " .dataframe tbody tr th:only-of-type {\n",
461
- " vertical-align: middle;\n",
462
- " }\n",
463
- "\n",
464
- " .dataframe tbody tr th {\n",
465
- " vertical-align: top;\n",
466
- " }\n",
467
- "\n",
468
- " .dataframe thead th {\n",
469
- " text-align: right;\n",
470
- " }\n",
471
- "</style>\n",
472
- "<table border=\"1\" class=\"dataframe\">\n",
473
- " <thead>\n",
474
- " <tr style=\"text-align: right;\">\n",
475
- " <th></th>\n",
476
- " <th>IDSTUD</th>\n",
477
- " <th>JOB_DUTIES</th>\n",
478
- " <th>ISCO</th>\n",
479
- " <th>ISCO_REL</th>\n",
480
- " <th>ISCO_TITLE</th>\n",
481
- " <th>ISCO_CODE_TITLE</th>\n",
482
- " <th>COUNTRY</th>\n",
483
- " <th>LANGUAGE</th>\n",
484
- " </tr>\n",
485
- " </thead>\n",
486
- " <tbody>\n",
487
- " <tr>\n",
488
- " <th>0</th>\n",
489
- " <td>10670109</td>\n",
490
- " <td>forældre 1: Han arbejder som med-chef sammen...</td>\n",
491
- " <td>7412</td>\n",
492
- " <td>None</td>\n",
493
- " <td>Electrical Mechanics and Fitters</td>\n",
494
- " <td>7412 Electrical Mechanics and Fitters</td>\n",
495
- " <td>DNK</td>\n",
496
- " <td>da</td>\n",
497
- " </tr>\n",
498
- " <tr>\n",
499
- " <th>1</th>\n",
500
- " <td>10130106</td>\n",
501
- " <td>asistente de parbulo y basica. ayudaba en la e...</td>\n",
502
- " <td>5312</td>\n",
503
- " <td>5312</td>\n",
504
- " <td>Teachers' Aides</td>\n",
505
- " <td>5312 Teachers' Aides</td>\n",
506
- " <td>CHL</td>\n",
507
- " <td>es</td>\n",
508
- " </tr>\n",
509
- " <tr>\n",
510
- " <th>2</th>\n",
511
- " <td>10740120</td>\n",
512
- " <td>trabajaba en el campo como capatas. aveces cui...</td>\n",
513
- " <td>6121</td>\n",
514
- " <td>None</td>\n",
515
- " <td>Livestock and Dairy Producers</td>\n",
516
- " <td>6121 Livestock and Dairy Producers</td>\n",
517
- " <td>URY</td>\n",
518
- " <td>es</td>\n",
519
- " </tr>\n",
520
- " <tr>\n",
521
- " <th>3</th>\n",
522
- " <td>10170109</td>\n",
523
- " <td>gas abastible. vende gas abastible</td>\n",
524
- " <td>9621</td>\n",
525
- " <td>5243</td>\n",
526
- " <td>Messengers, Package Deliverers and Luggage Por...</td>\n",
527
- " <td>9621 Messengers, Package Deliverers and Luggag...</td>\n",
528
- " <td>CHL</td>\n",
529
- " <td>es</td>\n",
530
- " </tr>\n",
531
- " <tr>\n",
532
- " <th>4</th>\n",
533
- " <td>11480109</td>\n",
534
- " <td>jordbruk. sår potatis tar upp potatis plogar h...</td>\n",
535
- " <td>6111</td>\n",
536
- " <td>6111</td>\n",
537
- " <td>Field Crop and Vegetable Growers</td>\n",
538
- " <td>6111 Field Crop and Vegetable Growers</td>\n",
539
- " <td>FIN</td>\n",
540
- " <td>sv</td>\n",
541
- " </tr>\n",
542
- " <tr>\n",
543
- " <th>...</th>\n",
544
- " <td>...</td>\n",
545
- " <td>...</td>\n",
546
- " <td>...</td>\n",
547
- " <td>...</td>\n",
548
- " <td>...</td>\n",
549
- " <td>...</td>\n",
550
- " <td>...</td>\n",
551
- " <td>...</td>\n",
552
- " </tr>\n",
553
- " <tr>\n",
554
- " <th>495</th>\n",
555
- " <td>11780107</td>\n",
556
- " <td>acountent mannager|she mannages calls for jobs...</td>\n",
557
- " <td>1211</td>\n",
558
- " <td>9998</td>\n",
559
- " <td>Finance Managers</td>\n",
560
- " <td>1211 Finance Managers</td>\n",
561
- " <td>AUS</td>\n",
562
- " <td>en</td>\n",
563
- " </tr>\n",
564
- " <tr>\n",
565
- " <th>496</th>\n",
566
- " <td>10850104</td>\n",
567
- " <td>geometra/muratore. proggetta case e le restaura</td>\n",
568
- " <td>3112</td>\n",
569
- " <td>3112</td>\n",
570
- " <td>Civil Engineering Technicians</td>\n",
571
- " <td>3112 Civil Engineering Technicians</td>\n",
572
- " <td>ITA</td>\n",
573
- " <td>it</td>\n",
574
- " </tr>\n",
575
- " <tr>\n",
576
- " <th>497</th>\n",
577
- " <td>11460111</td>\n",
578
- " <td>fa parte della misericordia. Trasporta i malat...</td>\n",
579
- " <td>3258</td>\n",
580
- " <td>3258</td>\n",
581
- " <td>Ambulance Workers</td>\n",
582
- " <td>3258 Ambulance Workers</td>\n",
583
- " <td>ITA</td>\n",
584
- " <td>it</td>\n",
585
- " </tr>\n",
586
- " <tr>\n",
587
- " <th>498</th>\n",
588
- " <td>10340111</td>\n",
589
- " <td>사회복지사. 회사에서 복지원 관리</td>\n",
590
- " <td>2635</td>\n",
591
- " <td>2635</td>\n",
592
- " <td>Social Work and Counselling Professionals</td>\n",
593
- " <td>2635 Social Work and Counselling Professionals</td>\n",
594
- " <td>KOR</td>\n",
595
- " <td>ko</td>\n",
596
- " </tr>\n",
597
- " <tr>\n",
598
- " <th>499</th>\n",
599
- " <td>10370105</td>\n",
600
- " <td>자영업. 가게를 운영하신다.</td>\n",
601
- " <td>5221</td>\n",
602
- " <td>None</td>\n",
603
- " <td>Shopkeepers</td>\n",
604
- " <td>5221 Shopkeepers</td>\n",
605
- " <td>KOR</td>\n",
606
- " <td>ko</td>\n",
607
- " </tr>\n",
608
- " </tbody>\n",
609
- "</table>\n",
610
- "<p>500 rows × 8 columns</p>\n",
611
- "</div>"
612
- ],
613
- "text/plain": [
614
- " IDSTUD JOB_DUTIES ISCO \\\n",
615
- "0 10670109 forældre 1: Han arbejder som med-chef sammen... 7412 \n",
616
- "1 10130106 asistente de parbulo y basica. ayudaba en la e... 5312 \n",
617
- "2 10740120 trabajaba en el campo como capatas. aveces cui... 6121 \n",
618
- "3 10170109 gas abastible. vende gas abastible 9621 \n",
619
- "4 11480109 jordbruk. sår potatis tar upp potatis plogar h... 6111 \n",
620
- ".. ... ... ... \n",
621
- "495 11780107 acountent mannager|she mannages calls for jobs... 1211 \n",
622
- "496 10850104 geometra/muratore. proggetta case e le restaura 3112 \n",
623
- "497 11460111 fa parte della misericordia. Trasporta i malat... 3258 \n",
624
- "498 10340111 사회복지사. 회사에서 복지원 관리 2635 \n",
625
- "499 10370105 자영업. 가게를 운영하신다. 5221 \n",
626
- "\n",
627
- " ISCO_REL ISCO_TITLE \\\n",
628
- "0 None Electrical Mechanics and Fitters \n",
629
- "1 5312 Teachers' Aides \n",
630
- "2 None Livestock and Dairy Producers \n",
631
- "3 5243 Messengers, Package Deliverers and Luggage Por... \n",
632
- "4 6111 Field Crop and Vegetable Growers \n",
633
- ".. ... ... \n",
634
- "495 9998 Finance Managers \n",
635
- "496 3112 Civil Engineering Technicians \n",
636
- "497 3258 Ambulance Workers \n",
637
- "498 2635 Social Work and Counselling Professionals \n",
638
- "499 None Shopkeepers \n",
639
- "\n",
640
- " ISCO_CODE_TITLE COUNTRY LANGUAGE \n",
641
- "0 7412 Electrical Mechanics and Fitters DNK da \n",
642
- "1 5312 Teachers' Aides CHL es \n",
643
- "2 6121 Livestock and Dairy Producers URY es \n",
644
- "3 9621 Messengers, Package Deliverers and Luggag... CHL es \n",
645
- "4 6111 Field Crop and Vegetable Growers FIN sv \n",
646
- ".. ... ... ... \n",
647
- "495 1211 Finance Managers AUS en \n",
648
- "496 3112 Civil Engineering Technicians ITA it \n",
649
- "497 3258 Ambulance Workers ITA it \n",
650
- "498 2635 Social Work and Counselling Professionals KOR ko \n",
651
- "499 5221 Shopkeepers KOR ko \n",
652
- "\n",
653
- "[500 rows x 8 columns]"
654
- ]
655
- },
656
- "execution_count": 3,
657
- "metadata": {},
658
- "output_type": "execute_result"
659
- }
660
- ],
661
  "source": [
662
- "MODEL_NAME = \"ICILS/XLM-R-ISCO\"\n",
663
- "# DATASET_CONFIG = {\"path\": \"tweet_eval\", \"name\": \"sentiment\", \"split\": \"validation\"}\n",
664
- "TEXT_COLUMN = \"JOB_DUTIES\"\n",
665
- "TARGET_COLUMN = \"ISCO_CODE_TITLE\"\n",
666
- "\n",
667
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
668
- "model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)\n",
669
- "\n",
670
- "label2id: dict = model.config.label2id\n",
671
- "id2label: dict = model.config.id2label\n",
672
- "# LABEL_MAPPING = id2label.items()\n",
673
- "\n",
674
- "# raw_data = load_dataset(**DATASET_CONFIG).to_pandas().iloc[:500]\n",
675
- "raw_data = load_dataset(\"ICILS/multilingual_parental_occupations\", split=\"test\").to_pandas().iloc[:500]\n",
676
- "# raw_data = raw_data.replace({\"ISCO_CODE_TITLE\": LABEL_MAPPING})\n",
677
- "raw_data[\"ISCO\"] = raw_data[\"ISCO\"].astype(str)\n",
678
- "raw_data[\"ISCO_REL\"] = raw_data[\"ISCO_REL\"].astype(str)\n",
679
  "\n",
680
- "raw_data"
681
  ]
682
  },
683
  {
684
  "cell_type": "code",
685
- "execution_count": 4,
686
  "metadata": {},
687
- "outputs": [
688
- {
689
- "name": "stdout",
690
- "output_type": "stream",
691
- "text": [
692
- "2024-03-15 01:07:06,923 pid:166193 MainThread giskard.datasets.base INFO Your 'pandas.DataFrame' is successfully wrapped by Giskard's 'Dataset' wrapper class.\n",
693
- "2024-03-15 01:07:06,925 pid:166193 MainThread giskard.models.automodel INFO Your 'prediction_function' is successfully wrapped by Giskard's 'PredictionFunctionModel' wrapper class.\n"
694
- ]
695
- },
696
- {
697
- "name": "stderr",
698
- "output_type": "stream",
699
- "text": [
700
- "/home/dux/miniconda3/envs/autogenstudio/lib/python3.11/site-packages/giskard/datasets/base/__init__.py:466: UserWarning: The column ISCO is declared as numeric but has 'object' as data type. To avoid potential future issues, make sure to cast this column to the correct data type.\n",
701
- " warning(\n"
702
- ]
703
- }
704
- ],
705
  "source": [
706
- "giskard_dataset = Dataset(\n",
707
- " df=raw_data, # A pandas.DataFrame that contains the raw data (before all the pre-processing steps) and the actual ground truth variable (target).\n",
708
- " target=TARGET_COLUMN, # Ground truth variable.\n",
709
- " name=\"ISCO-08 Parental Occupation Corpus\", # Optional.\n",
710
- ")\n",
711
- "\n",
712
- "def prediction_function(df: pd.DataFrame) -> np.ndarray:\n",
713
- " encoded_input = tokenizer(list(df[TEXT_COLUMN]), padding=True, return_tensors=\"pt\")\n",
714
- " output = model(**encoded_input)\n",
715
- " return softmax(output[\"logits\"].detach().numpy(), axis=1)\n",
716
- "\n",
717
  "\n",
718
- "giskard_model = Model(\n",
719
- " model=prediction_function, # A prediction function that encapsulates all the data pre-processing steps and that\n",
720
- " model_type=\"classification\", # Either regression, classification or text_generation.\n",
721
- " name=\"XLM-R ISCO\", # Optional\n",
722
- " classification_labels=list(label2id.keys()), # Their order MUST be identical to the prediction_function's\n",
723
- " feature_names=[TEXT_COLUMN], # Default: all columns of your dataset\n",
724
- ")"
725
  ]
726
  },
727
  {
728
  "cell_type": "code",
729
- "execution_count": 5,
730
  "metadata": {},
731
- "outputs": [
732
- {
733
- "name": "stdout",
734
- "output_type": "stream",
735
- "text": [
736
- "2024-03-15 01:07:10,228 pid:166193 MainThread giskard.datasets.base INFO Casting dataframe columns from {'JOB_DUTIES': 'object'} to {'JOB_DUTIES': 'object'}\n"
737
- ]
738
- },
739
- {
740
- "name": "stdout",
741
- "output_type": "stream",
742
- "text": [
743
- "2024-03-15 01:07:12,838 pid:166193 MainThread giskard.utils.logging_utils INFO Predicted dataset with shape (10, 8) executed in 0:00:02.617399\n",
744
- "2024-03-15 01:07:12,848 pid:166193 MainThread giskard.datasets.base INFO Casting dataframe columns from {'JOB_DUTIES': 'object'} to {'JOB_DUTIES': 'object'}\n",
745
- "2024-03-15 01:07:13,007 pid:166193 MainThread giskard.utils.logging_utils INFO Predicted dataset with shape (1, 8) executed in 0:00:00.166843\n",
746
- "2024-03-15 01:07:13,015 pid:166193 MainThread giskard.datasets.base INFO Casting dataframe columns from {'JOB_DUTIES': 'object'} to {'JOB_DUTIES': 'object'}\n",
747
- "2024-03-15 01:07:13,017 pid:166193 MainThread giskard.utils.logging_utils INFO Predicted dataset with shape (10, 8) executed in 0:00:00.009517\n",
748
- "2024-03-15 01:07:13,029 pid:166193 MainThread giskard.datasets.base INFO Casting dataframe columns from {'JOB_DUTIES': 'object'} to {'JOB_DUTIES': 'object'}\n"
749
- ]
750
- },
751
- {
752
- "ename": "",
753
- "evalue": "",
754
- "output_type": "error",
755
- "traceback": [
756
- "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
757
- "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
758
- "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
759
- "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
760
- ]
761
- }
762
- ],
763
  "source": [
764
- "results = scan(giskard_model, giskard_dataset)"
 
 
765
  ]
766
  },
767
  {
768
  "cell_type": "code",
769
  "execution_count": null,
770
  "metadata": {},
771
- "outputs": [
772
- {
773
- "ename": "NameError",
774
- "evalue": "name 'results' is not defined",
775
- "output_type": "error",
776
- "traceback": [
777
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
778
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
779
- "Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m display(\u001b[43mresults\u001b[49m)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# Save it to a file\u001b[39;00m\n\u001b[1;32m 4\u001b[0m results\u001b[38;5;241m.\u001b[39mto_html(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscan_report.html\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
780
- "\u001b[0;31mNameError\u001b[0m: name 'results' is not defined"
781
- ]
782
- }
783
- ],
784
- "source": [
785
- "display(results)\n",
786
- "\n",
787
- "# Save it to a file\n",
788
- "results.to_html(\"scan_report.html\")"
789
- ]
790
- },
791
- {
792
- "cell_type": "code",
793
- "execution_count": 2,
794
- "metadata": {},
795
- "outputs": [
796
- {
797
- "ename": "GiskardError",
798
- "evalue": "No details or messages available.",
799
- "output_type": "error",
800
- "traceback": [
801
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
802
- "\u001b[0;31mGiskardError\u001b[0m Traceback (most recent call last)",
803
- "Cell \u001b[0;32mIn[2], line 10\u001b[0m\n\u001b[1;32m 7\u001b[0m project_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mxlmr_isco\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m# Create a giskard client to communicate with Giskard\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m client \u001b[38;5;241m=\u001b[39m \u001b[43mGiskardClient\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n",
804
- "File \u001b[0;32m~/miniconda3/envs/autogenstudio/lib/python3.11/site-packages/giskard/client/giskard_client.py:153\u001b[0m, in \u001b[0;36mGiskardClient.__init__\u001b[0;34m(self, url, key, hf_token)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hf_token:\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_session\u001b[38;5;241m.\u001b[39mcookies[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mspaces-jwt\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m hf_token\n\u001b[0;32m--> 153\u001b[0m server_settings: ServerInfo \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_server_info\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m server_settings\u001b[38;5;241m.\u001b[39mserverVersion \u001b[38;5;241m!=\u001b[39m giskard\u001b[38;5;241m.\u001b[39m__version__:\n\u001b[1;32m 156\u001b[0m warning(\n\u001b[1;32m 157\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYour giskard client version (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgiskard\u001b[38;5;241m.\u001b[39m__version__\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) does not match the hub version \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 158\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mserver_settings\u001b[38;5;241m.\u001b[39mserverVersion\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m). \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease upgrade your client to the latest version. 
\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpip install \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgiskard[hub]>=2.0.0b\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m -U\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 161\u001b[0m )\n",
805
- "File \u001b[0;32m~/miniconda3/envs/autogenstudio/lib/python3.11/site-packages/giskard/client/giskard_client.py:417\u001b[0m, in \u001b[0;36mGiskardClient.get_server_info\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 416\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_server_info\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ServerInfo:\n\u001b[0;32m--> 417\u001b[0m resp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_session\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m/public-api/ml-worker-connect\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 418\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 419\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ServerInfo\u001b[38;5;241m.\u001b[39mparse_obj(resp\u001b[38;5;241m.\u001b[39mjson())\n",
806
- "File \u001b[0;32m~/miniconda3/envs/autogenstudio/lib/python3.11/site-packages/requests/sessions.py:602\u001b[0m, in \u001b[0;36mSession.get\u001b[0;34m(self, url, **kwargs)\u001b[0m\n\u001b[1;32m 594\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Sends a GET request. Returns :class:`Response` object.\u001b[39;00m\n\u001b[1;32m 595\u001b[0m \n\u001b[1;32m 596\u001b[0m \u001b[38;5;124;03m:param url: URL for the new :class:`Request` object.\u001b[39;00m\n\u001b[1;32m 597\u001b[0m \u001b[38;5;124;03m:param \\*\\*kwargs: Optional arguments that ``request`` takes.\u001b[39;00m\n\u001b[1;32m 598\u001b[0m \u001b[38;5;124;03m:rtype: requests.Response\u001b[39;00m\n\u001b[1;32m 599\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 601\u001b[0m kwargs\u001b[38;5;241m.\u001b[39msetdefault(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mallow_redirects\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m--> 602\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mGET\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
807
- "File \u001b[0;32m~/miniconda3/envs/autogenstudio/lib/python3.11/site-packages/requests_toolbelt/sessions.py:76\u001b[0m, in \u001b[0;36mBaseUrlSession.request\u001b[0;34m(self, method, url, *args, **kwargs)\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Send the request after generating the complete URL.\"\"\"\u001b[39;00m\n\u001b[1;32m 75\u001b[0m url \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcreate_url(url)\n\u001b[0;32m---> 76\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mBaseUrlSession\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
808
- "File \u001b[0;32m~/miniconda3/envs/autogenstudio/lib/python3.11/site-packages/requests/sessions.py:589\u001b[0m, in \u001b[0;36mSession.request\u001b[0;34m(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)\u001b[0m\n\u001b[1;32m 584\u001b[0m send_kwargs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 585\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtimeout\u001b[39m\u001b[38;5;124m\"\u001b[39m: timeout,\n\u001b[1;32m 586\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mallow_redirects\u001b[39m\u001b[38;5;124m\"\u001b[39m: allow_redirects,\n\u001b[1;32m 587\u001b[0m }\n\u001b[1;32m 588\u001b[0m send_kwargs\u001b[38;5;241m.\u001b[39mupdate(settings)\n\u001b[0;32m--> 589\u001b[0m resp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprep\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msend_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 591\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m resp\n",
809
- "File \u001b[0;32m~/miniconda3/envs/autogenstudio/lib/python3.11/site-packages/requests/sessions.py:703\u001b[0m, in \u001b[0;36mSession.send\u001b[0;34m(self, request, **kwargs)\u001b[0m\n\u001b[1;32m 700\u001b[0m start \u001b[38;5;241m=\u001b[39m preferred_clock()\n\u001b[1;32m 702\u001b[0m \u001b[38;5;66;03m# Send the request\u001b[39;00m\n\u001b[0;32m--> 703\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43madapter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 705\u001b[0m \u001b[38;5;66;03m# Total elapsed time of the request (approximately)\u001b[39;00m\n\u001b[1;32m 706\u001b[0m elapsed \u001b[38;5;241m=\u001b[39m preferred_clock() \u001b[38;5;241m-\u001b[39m start\n",
810
- "File \u001b[0;32m~/miniconda3/envs/autogenstudio/lib/python3.11/site-packages/requests/adapters.py:538\u001b[0m, in \u001b[0;36mHTTPAdapter.send\u001b[0;34m(self, request, stream, timeout, verify, cert, proxies)\u001b[0m\n\u001b[1;32m 535\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbuild_response\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresp\u001b[49m\u001b[43m)\u001b[49m\n",
811
- "File \u001b[0;32m~/miniconda3/envs/autogenstudio/lib/python3.11/site-packages/giskard/client/giskard_client.py:107\u001b[0m, in \u001b[0;36mErrorHandlingAdapter.build_response\u001b[0;34m(self, req, resp)\u001b[0m\n\u001b[1;32m 105\u001b[0m resp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msuper\u001b[39m(ErrorHandlingAdapter, \u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39mbuild_response(req, resp)\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _get_status(resp) \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m400\u001b[39m:\n\u001b[0;32m--> 107\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m explain_error(resp)\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m resp\n",
812
- "\u001b[0;31mGiskardError\u001b[0m: No details or messages available."
813
- ]
814
- }
815
- ],
816
  "source": [
817
- "import giskard\n",
818
- "from datasets import load_dataset\n",
819
- "\n",
820
- "dataset = load_dataset(\"ICILS/multilingual_parental_occupations\", split=\"test\")\n",
821
- "\n",
822
- "# Replace this with your own data & model creation.\n",
823
- "# df = giskard.demo.titanic_df()\n",
824
- "df = dataset\n",
825
- "demo_data_preprocessing_function, demo_sklearn_model = giskard.demo.titanic_pipeline()\n",
826
- "\n",
827
- "# Wrap your Pandas DataFrame\n",
828
- "giskard_dataset = giskard.Dataset(df=df,\n",
829
- " target=\"ISCO_CODE_TITLE\",\n",
830
- " name=\"ISCO-08 Parental Occupation Corpus\",\n",
831
- " cat_columns=['LANGUAGE', 'COUNTRY'])\n",
832
- "\n",
833
- "# Wrap your model\n",
834
- "def prediction_function(df):\n",
835
- " preprocessed_df = demo_data_preprocessing_function(df)\n",
836
- " return demo_sklearn_model.predict_proba(preprocessed_df)\n",
837
- "\n",
838
- "giskard_model = giskard.Model(model=prediction_function,\n",
839
- " model_type=\"classification\",\n",
840
- " name=\"Titanic model\",\n",
841
- " classification_labels=demo_sklearn_model.classes_,\n",
842
- " feature_names=['PassengerId', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'])\n",
843
- "\n",
844
- "# Then apply the scan\n",
845
- "results = giskard.scan(giskard_model, giskard_dataset)\n",
846
- "\n",
847
- "\n",
848
- "# Create a Giskard client\n",
849
- "client = giskard.GiskardClient(\n",
850
- " url=\"https://danieldux-giskard.hf.space\", # URL of your Giskard instance\n",
851
- " key=\"<Generate your API Key on the Giskard Hub settings page first>\")\n",
852
  "\n",
 
 
 
853
  "\n",
854
- "# Upload an automatically created test suite to the current project ✉️\n",
855
- "results.generate_test_suite(\"Test suite created by scan\").upload(client, \"xlmr_isco\")\n"
856
  ]
857
  }
858
  ],
 
166
  "execution_count": null,
167
  "metadata": {},
168
  "outputs": [],
169
+ "source": [
170
+ "from datasets import load_dataset, get_dataset_config_names, get_dataset_infos, get_dataset_split_names\n",
171
+ "\n",
172
+ "dataset = load_dataset(\"ICILS/multilingual_parental_occupations\", \"ilo\")\n",
173
+ "dataset"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": 2,
179
+ "metadata": {},
180
+ "outputs": [
181
+ {
182
+ "data": {
183
+ "application/vnd.jupyter.widget-view+json": {
184
+ "model_id": "4634a4a344384ef28d182adeea1f5afc",
185
+ "version_major": 2,
186
+ "version_minor": 0
187
+ },
188
+ "text/plain": [
189
+ "Downloading builder script: 0%| | 0.00/13.4k [00:00<?, ?B/s]"
190
+ ]
191
+ },
192
+ "metadata": {},
193
+ "output_type": "display_data"
194
+ },
195
+ {
196
+ "name": "stdout",
197
+ "output_type": "stream",
198
+ "text": [
199
+ "ISCO CSV file downloaded\n",
200
+ "Weighted ISCO hierarchy dictionary created as isco_hierarchy\n"
201
+ ]
202
+ }
203
+ ],
204
  "source": [
205
  "import os\n",
206
  "from datasets import load_dataset\n",
 
213
  "if hf_token is None:\n",
214
  " raise ValueError(\"HF_TOKEN environment variable is not set.\")\n",
215
  "\n",
216
+ "test_split = load_dataset(\"ICILS/multilingual_parental_occupations\", \"icils\", split=\"test\", token=hf_token)\n",
217
+ "validation_split = load_dataset(\"ICILS/multilingual_parental_occupations\", \"icils\", split=\"validation\", token=hf_token)\n",
218
+ "\n",
219
  "# Load the dataset\n",
220
  "test_data_subset = (\n",
221
+ " test_split.shuffle(seed=42).select(range(100))\n",
 
 
 
 
 
 
 
 
 
 
 
222
  ")\n",
223
  "\n",
224
  "# Initialize the pipeline\n",
225
+ "model = \"danieldux/XLM-R-ISCO-v2\" # previously used checkpoint: ICILS/XLM-R-ISCO\n",
226
+ "pipe = pipeline(\"text-classification\", model=model, token=hf_token)\n",
 
 
 
 
227
  "\n",
228
  "# Initialize the hierarchical accuracy measure\n",
229
  "hierarchical_accuracy = evaluate.load(\"danieldux/isco_hierarchical_accuracy\")"
230
  ]
231
  },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 3,
235
+ "metadata": {},
236
+ "outputs": [
237
+ {
238
+ "data": {
239
+ "text/plain": [
240
+ "Dataset({\n",
241
+ " features: ['IDSTUD', 'JOB_DUTIES', 'ISCO', 'ISCO_REL', 'ISCO_TITLE', 'ISCO_CODE_TITLE', 'COUNTRY', 'LANGUAGE'],\n",
242
+ " num_rows: 100\n",
243
+ "})"
244
+ ]
245
+ },
246
+ "execution_count": 3,
247
+ "metadata": {},
248
+ "output_type": "execute_result"
249
+ }
250
+ ],
251
+ "source": [
252
+ "test_data_subset"
253
+ ]
254
+ },
255
  {
256
  "cell_type": "markdown",
257
  "metadata": {},
 
261
  },
262
  {
263
  "cell_type": "code",
264
+ "execution_count": 4,
265
  "metadata": {},
266
  "outputs": [
267
  {
268
  "name": "stdout",
269
  "output_type": "stream",
270
  "text": [
271
+ "2024-03-31--01:29\n",
272
+ "Evaluation results saved to test_split_results-2024-03-31--01:29.json\n"
273
  ]
274
  }
275
  ],
276
  "source": [
277
+ "import datetime\n",
278
+ "\n",
279
+ "stamp = datetime.datetime.now().strftime(\"%Y-%m-%d--%H:%M\")\n",
280
+ "print(stamp)\n",
281
+ "\n",
282
  "# Evaluate the model\n",
283
  "predictions = []\n",
284
  "references = []\n",
285
+ "for example in test_data_subset:\n",
286
  "\n",
287
  " # Predict\n",
288
  " prediction = pipe(\n",
289
  " example[\"JOB_DUTIES\"]\n",
290
  " ) # Use the key \"JOB_DUTIES\" for the text data\n",
291
+ " # predicted_label = extract_isco_code(prediction[0][\"label\"])\n",
292
+ " predicted_label = prediction[0][\"label\"]\n",
293
  " predictions.append(predicted_label)\n",
294
  "\n",
295
  " # Reference\n",
 
300
  "test_results = hierarchical_accuracy.compute(predictions=predictions, references=references)\n",
301
  "\n",
302
  "# Save the results to a JSON file\n",
303
+ "with open(f\"test_split_results-{stamp}.json\", \"w\") as f:\n",
304
  " json.dump(test_results, f)\n",
305
  "\n",
306
+ "print(f\"Evaluation results saved to test_split_results-{stamp}.json\")\n",
307
+ "\n"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": 5,
313
+ "metadata": {},
314
+ "outputs": [
315
+ {
316
+ "data": {
317
+ "text/plain": [
318
+ "{'accuracy': 0.82,\n",
319
+ " 'hierarchical_precision': 0.9090909090909091,\n",
320
+ " 'hierarchical_recall': 0.8839779005524862,\n",
321
+ " 'hierarchical_fmeasure': 0.8963585434173669}"
322
+ ]
323
+ },
324
+ "execution_count": 5,
325
+ "metadata": {},
326
+ "output_type": "execute_result"
327
+ }
328
+ ],
329
+ "source": [
330
+ "test_results"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 6,
336
+ "metadata": {},
337
+ "outputs": [
338
+ {
339
+ "name": "stdout",
340
+ "output_type": "stream",
341
+ "text": [
342
+ "Accuracy: 0.8523316062176166, Hierarchical Precision: 0.9711751662971175, Hierarchical Recall: 0.9733333333333334, Hierarchical F-measure: 0.9722530521642619\n"
343
+ ]
344
+ },
345
+ {
346
+ "name": "stderr",
347
+ "output_type": "stream",
348
+ "text": [
349
+ "/tmp/ipykernel_376175/1380879571.py:30: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n",
350
+ " results_df = pd.concat(\n"
351
+ ]
352
+ },
353
+ {
354
+ "name": "stdout",
355
+ "output_type": "stream",
356
+ "text": [
357
+ "Accuracy: 0.8549323017408124, Hierarchical Precision: 0.9425981873111783, Hierarchical Recall: 0.96, Hierarchical F-measure: 0.9512195121951218\n",
358
+ "Accuracy: 0.817351598173516, Hierarchical Precision: 0.9076305220883534, Hierarchical Recall: 0.9377593360995851, Hierarchical F-measure: 0.9224489795918367\n",
359
+ "Accuracy: 0.8160919540229885, Hierarchical Precision: 0.9140893470790378, Hierarchical Recall: 0.9204152249134948, Hierarchical F-measure: 0.9172413793103448\n",
360
+ "Accuracy: 0.7801724137931034, Hierarchical Precision: 0.8776978417266187, Hierarchical Recall: 0.9207547169811321, Hierarchical F-measure: 0.8987108655616942\n",
361
+ "Accuracy: 0.8200836820083682, Hierarchical Precision: 0.9007352941176471, Hierarchical Recall: 0.9176029962546817, Hierarchical F-measure: 0.9090909090909092\n",
362
+ "Accuracy: 0.5149253731343284, Hierarchical Precision: 0.7487684729064039, Hierarchical Recall: 0.8, Hierarchical F-measure: 0.7735368956743003\n",
363
+ "Accuracy: 0.9, Hierarchical Precision: 0.9244444444444444, Hierarchical Recall: 0.9285714285714286, Hierarchical F-measure: 0.9265033407572383\n",
364
+ "Accuracy: 0.9030612244897959, Hierarchical Precision: 0.9509803921568627, Hierarchical Recall: 0.9603960396039604, Hierarchical F-measure: 0.9556650246305418\n",
365
+ "Accuracy: 0.7836538461538461, Hierarchical Precision: 0.9047619047619048, Hierarchical Recall: 0.8916967509025271, Hierarchical F-measure: 0.8981818181818182\n",
366
+ "Accuracy: 0.8707865168539326, Hierarchical Precision: 0.9269406392694064, Hierarchical Recall: 0.9441860465116279, Hierarchical F-measure: 0.9354838709677419\n",
367
+ "Accuracy: 0.9230769230769231, Hierarchical Precision: 0.9, Hierarchical Recall: 0.9473684210526315, Hierarchical F-measure: 0.9230769230769231\n",
368
+ " Language Accuracy Hierarchical Precision Hierarchical Recall \\\n",
369
+ "0 sv 0.923077 0.900000 0.947368 \n",
370
+ "1 ko 0.870787 0.926941 0.944186 \n",
371
+ "2 pt 0.783654 0.904762 0.891697 \n",
372
+ "3 kk 0.903061 0.950980 0.960396 \n",
373
+ "4 ru 0.900000 0.924444 0.928571 \n",
374
+ "5 de 0.514925 0.748768 0.800000 \n",
375
+ "6 fi 0.820084 0.900735 0.917603 \n",
376
+ "7 da 0.780172 0.877698 0.920755 \n",
377
+ "8 fr 0.816092 0.914089 0.920415 \n",
378
+ "9 it 0.817352 0.907631 0.937759 \n",
379
+ "10 es 0.854932 0.942598 0.960000 \n",
380
+ "11 en 0.852332 0.971175 0.973333 \n",
381
+ "\n",
382
+ " Hierarchical F1 \n",
383
+ "0 0.923077 \n",
384
+ "1 0.935484 \n",
385
+ "2 0.898182 \n",
386
+ "3 0.955665 \n",
387
+ "4 0.926503 \n",
388
+ "5 0.773537 \n",
389
+ "6 0.909091 \n",
390
+ "7 0.898711 \n",
391
+ "8 0.917241 \n",
392
+ "9 0.922449 \n",
393
+ "10 0.951220 \n",
394
+ "11 0.972253 \n"
395
+ ]
396
+ }
397
+ ],
398
+ "source": [
399
+ "import pandas as pd\n",
400
+ "\n",
401
+ "test_data_df = test_data.to_pandas()\n",
402
+ "results_df = pd.DataFrame(columns=['Language', 'Accuracy', 'Hierarchical Precision', 'Hierarchical Recall', 'Hierarchical F1'])\n",
403
+ "\n",
404
+ "# Iterate over unique languages\n",
405
+ "for language in test_data_df[\"LANGUAGE\"].unique():\n",
406
+ " # Filter test data for the current language\n",
407
+ " test_data_subset = test_data_df[test_data_df[\"LANGUAGE\"] == language]\n",
408
+ "\n",
409
+ " # Evaluate the model for the current language\n",
410
+ " predictions = []\n",
411
+ " references = []\n",
412
+ " for example in test_data_subset.to_dict(\"records\"):\n",
413
+ " # Predict\n",
414
+ " prediction = pipe(example[\"JOB_DUTIES\"])\n",
415
+ " predicted_label = extract_isco_code(prediction[0][\"label\"])\n",
416
+ " predictions.append(predicted_label)\n",
417
+ "\n",
418
+ " # Reference\n",
419
+ " reference_label = example[\"ISCO\"]\n",
420
+ " references.append(reference_label)\n",
421
+ "\n",
422
+ " # Compute the hierarchical accuracy for the current language\n",
423
+ " test_results = hierarchical_accuracy.compute(\n",
424
+ " predictions=predictions, references=references\n",
425
+ " )\n",
426
+ "\n",
427
+ " # Prepend this language's metrics to the accumulated results DataFrame\n",
428
+ " results_df = pd.concat(\n",
429
+ " [\n",
430
+ " pd.DataFrame(\n",
431
+ " {\n",
432
+ " \"Language\": [language],\n",
433
+ " \"Accuracy\": [test_results[\"accuracy\"]],\n",
434
+ " \"Hierarchical Precision\": [test_results[\"hierarchical_precision\"]],\n",
435
+ " \"Hierarchical Recall\": [test_results[\"hierarchical_recall\"]],\n",
436
+ " \"Hierarchical F1\": [test_results[\"hierarchical_fmeasure\"]],\n",
437
+ " }\n",
438
+ " ),\n",
439
+ " results_df,\n",
440
+ " ],\n",
441
+ " ignore_index=True\n",
442
+ " )\n",
443
+ "\n",
444
+ "# Print the evaluation results\n",
445
+ "print(results_df)"
446
+ ]
447
+ },
448
+ {
449
+ "cell_type": "code",
450
+ "execution_count": 7,
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "results_df.to_csv('model_language_results.csv', index=False)"
455
  ]
456
  },
457
  {
 
509
  "# Inter rater agreement"
510
  ]
511
  },
512
+ {
513
+ "cell_type": "markdown",
514
+ "metadata": {},
515
+ "source": [
516
+ "## All ICILS 2018 data"
517
+ ]
518
+ },
519
  {
520
  "cell_type": "code",
521
+ "execution_count": 8,
522
  "metadata": {},
523
  "outputs": [],
524
  "source": [
 
539
  "grouped_df = isco_rel_df.groupby('LANGUAGE')"
540
  ]
541
  },
542
+ {
543
+ "cell_type": "markdown",
544
+ "metadata": {},
545
+ "source": [
546
+ "### By language"
547
+ ]
548
+ },
549
  {
550
  "cell_type": "code",
551
  "execution_count": null,
 
585
  "results_df.loc[len(results_df)] = average_row\n",
586
  "\n",
587
  "\n",
588
+ "results_df.to_csv('inter-rater_language_results.csv', index=False)"
589
  ]
590
  },
591
  {
592
+ "cell_type": "markdown",
 
593
  "metadata": {},
 
594
  "source": [
595
+ "## Training data"
 
 
 
596
  ]
597
  },
598
  {
 
601
  "metadata": {},
602
  "outputs": [],
603
  "source": [
604
+ "import pandas as pd\n",
 
605
  "\n",
606
+ "test_data_df = test_data.to_pandas()\n",
607
+ "unknown_reliability_samples = test_data_df[test_data_df['ISCO_REL'].isna() | test_data_df['ISCO_REL'].isin([\"9998\", \"9999\"])]\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
  "\n",
609
+ "# Exclude unknown reliability samples from test_data_df\n",
610
+ "test_split_rel_df = test_data_df[~test_data_df['ISCO_REL'].isna() & ~test_data_df['ISCO_REL'].isin([\"9998\", \"9999\"])]\n",
 
611
  "\n",
612
+ "# Group the DataFrame by LANGUAGE column\n",
613
+ "test_split_rel_grouped_df = test_split_rel_df.groupby('LANGUAGE')"
614
  ]
615
  },
616
  {
617
  "cell_type": "markdown",
618
  "metadata": {},
619
  "source": [
620
+ "## Validation data"
621
  ]
622
  },
623
  {
624
+ "cell_type": "markdown",
 
625
  "metadata": {},
 
626
  "source": [
627
+ "## Test data"
 
 
 
 
 
 
628
  ]
629
  },
630
  {
631
  "cell_type": "code",
632
+ "execution_count": null,
633
  "metadata": {},
634
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
  "source": [
636
+ "# Create a dataframe with the samples where ISCO and ISCO_REL are the same\n",
637
+ "isco_rel_df_same = isco_rel_df[isco_rel_df['ISCO'] == isco_rel_df['ISCO_REL']]\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
  "\n",
639
+ "isco_rel_df_same"
640
  ]
641
  },
642
  {
643
  "cell_type": "code",
644
+ "execution_count": null,
645
  "metadata": {},
646
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647
  "source": [
648
+ "# create a dataframe with samples where ISCO and ISCO_REL are different\n",
649
+ "isco_rel_df_diff = isco_rel_df[isco_rel_df['ISCO'] != isco_rel_df['ISCO_REL']]\n",
 
 
 
 
 
 
 
 
 
650
  "\n",
651
+ "isco_rel_df_diff"
 
 
 
 
 
 
652
  ]
653
  },
654
  {
655
  "cell_type": "code",
656
+ "execution_count": 64,
657
  "metadata": {},
658
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
  "source": [
660
+ "# Make a list of all values in ISCO and ISCO_REL columns\n",
661
+ "coder1 = list(isco_rel_df['ISCO'])\n",
662
+ "coder2 = list(isco_rel_df['ISCO_REL'])"
663
  ]
664
  },
665
  {
666
  "cell_type": "code",
667
  "execution_count": null,
668
  "metadata": {},
669
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
670
  "source": [
671
+ "# Compute the hierarchical accuracy\n",
672
+ "reliability_results = hierarchical_accuracy.compute(predictions=coder2, references=coder1)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
673
  "\n",
674
+ "# Save the results to a JSON file\n",
675
+ "with open(\"isco_rel_results.json\", \"w\") as f:\n",
676
+ " json.dump(reliability_results, f)\n",
677
  "\n",
678
+ "print(\"Evaluation results saved to isco_rel_results.json\")"
 
679
  ]
680
  }
681
  ],