danieldux committed on
Commit
d726519
1 Parent(s): 5af5762

Refactor ISCO_Hierarchical_Accuracy class to improve code readability and add input validation

Browse files
Files changed (1) hide show
  1. isco_hierarchical_accuracy.py +37 -24
isco_hierarchical_accuracy.py CHANGED
@@ -16,9 +16,7 @@
16
  from typing import List, Set, Dict, Tuple
17
  import evaluate
18
  import datasets
19
-
20
- # import ham
21
- # import isco
22
 
23
 
24
  # TODO: Add BibTeX citation
@@ -264,40 +262,55 @@ class ISCO_Hierarchical_Accuracy(evaluate.Metric):
264
  print("Weighted ISCO hierarchy dictionary created as isco_hierarchy")
265
  # print(self.isco_hierarchy)
266
 
267
- # Define the mapping from ISCO_CODE_TITLE to ISCO codes
268
- def _extract_isco_code(isco_code_title: str):
269
- # ISCO_CODE_TITLE is a string like "7412 Electrical Mechanics and Fitters" so we need to extract the first part for the evaluation.
270
- return isco_code_title.split()[0]
 
 
 
 
 
 
 
 
 
 
271
 
272
  def _compute(self, predictions, references):
273
- """Returns the accuracy scores."""
274
- # Convert the inputs to strings
275
- if len(predictions[0]) > 4:
276
- extracted_predictions = []
277
- extracted_references = []
278
- for p in predictions:
279
- extracted_predictions.append(self._extract_isco_code(p))
280
- for r in references:
281
- extracted_references.append(self._extract_isco_code(r))
282
- predictions = extracted_predictions
283
- references = extracted_references
 
284
  predictions = [str(p) for p in predictions]
285
  references = [str(r) for r in references]
 
 
 
 
 
 
 
 
 
286
 
287
  # Calculate accuracy
288
  accuracy = sum(i == j for i, j in zip(predictions, references)) / len(
289
  predictions
290
  )
291
-
292
  # Calculate hierarchical precision, recall and f-measure
293
- hierarchy = self.isco_hierarchy
294
  hP, hR = self.calculate_hierarchical_precision_recall(
295
- references, predictions, hierarchy
296
  )
297
  hF = self.hierarchical_f_measure(hP, hR)
298
- print(
299
- f"Accuracy: {accuracy}, Hierarchical Precision: {hP}, Hierarchical Recall: {hR}, Hierarchical F-measure: {hF}"
300
- )
301
 
302
  return {
303
  "accuracy": accuracy,
 
16
  from typing import List, Set, Dict, Tuple
17
  import evaluate
18
  import datasets
19
+ import re
 
 
20
 
21
 
22
  # TODO: Add BibTeX citation
 
262
  print("Weighted ISCO hierarchy dictionary created as isco_hierarchy")
263
  # print(self.isco_hierarchy)
264
 
265
+ # Function to check if a string matches the 4-digit code pattern
266
+ def _is_valid_code(self, code: str):
267
+ # Regular expression pattern for a 4-digit code
268
+ pattern = r"^\d{4}$"
269
+ if re.match(pattern, code):
270
+ return True
271
+ else:
272
+ return False
273
+
274
+ def _validate_codes(self, codes: list, code_type):
275
+ if not all(self._is_valid_code(code) for code in codes):
276
+ raise ValueError(
277
+ f"All {code_type} labels must start with a 4-digit ISCO-08 code string."
278
+ )
279
 
280
  def _compute(self, predictions, references):
281
+ """
282
+ Computes the accuracy scores, hierarchical precision, recall, and f-measure.
283
+
284
+ Args:
285
+ predictions (List[str]): A list of 4-digit ISCO-08 prediction label strings.
286
+ references (List[str]): A list of 4-digit ISCO-08 reference label strings.
287
+
288
+ Returns:
289
+ dict: A dictionary containing the accuracy, hierarchical precision, hierarchical recall,
290
+ and hierarchical f-measure scores.
291
+ """
292
+ # Cast all prediction and reference labels to strings
293
  predictions = [str(p) for p in predictions]
294
  references = [str(r) for r in references]
295
+ # Check if the first prediction label is longer than 4 characters
296
+ if len(predictions[0]) > 4:
297
+ # Keep only the first whitespace-delimited token (the leading code) of each prediction label
298
+ predictions = [str(p.split()[0]) for p in predictions]
299
+ # Check if all prediction labels are 4-digit strings
300
+ self._validate_codes(predictions, "prediction")
301
+ # Repeat for reference labels
302
+ references = [str(r.split()[0]) for r in references]
303
+ self._validate_codes(references, "reference")
304
 
305
  # Calculate accuracy
306
  accuracy = sum(i == j for i, j in zip(predictions, references)) / len(
307
  predictions
308
  )
 
309
  # Calculate hierarchical precision, recall and f-measure
 
310
  hP, hR = self.calculate_hierarchical_precision_recall(
311
+ references, predictions, self.isco_hierarchy
312
  )
313
  hF = self.hierarchical_f_measure(hP, hR)
 
 
 
314
 
315
  return {
316
  "accuracy": accuracy,