Spaces:
Running
Running
| import logging | |
| from typing import Dict, List, Union | |
| from src.task.task import Task, TaskType | |
| from src.task.task_names import Tasks | |
# (metric, task type) used to evaluate each supported task. Any ``Tasks``
# member missing from this table is rejected by ``tasks_factory``.
_TASK_SPECS = {
    Tasks.ALLOCINE: ("accuracy", TaskType.INFERENCE),
    Tasks.FQUAD: ("fquad", TaskType.GENERATIVE),
    Tasks.GQNLI: ("accuracy", TaskType.INFERENCE),
    Tasks.PAWS_X: ("accuracy", TaskType.INFERENCE),
    Tasks.PIAF: ("fquad", TaskType.GENERATIVE),
    Tasks.QFRBLIMP: ("accuracy", TaskType.INFERENCE),
    Tasks.QFRCOLA: ("accuracy", TaskType.INFERENCE),
    Tasks.SICKFR: ("accuracy", TaskType.INFERENCE),
    Tasks.STS22: ("accuracy", TaskType.INFERENCE),
    Tasks.XNLI: ("accuracy", TaskType.INFERENCE),
    Tasks.QFRCORE: ("accuracy", TaskType.INFERENCE),
    Tasks.QFRCORT: ("accuracy", TaskType.INFERENCE),
    Tasks.DACCORD: ("accuracy", TaskType.INFERENCE),
    Tasks.FRENCH_BOOLQ: ("accuracy", TaskType.INFERENCE),
    Tasks.MNLI_NINEELEVEN_FR_MT: ("accuracy", TaskType.INFERENCE),
    Tasks.RTE3_FRENCH: ("accuracy", TaskType.INFERENCE),
    Tasks.WINO_X_LM: ("accuracy", TaskType.INFERENCE),
    Tasks.WINO_X_MT: ("accuracy", TaskType.INFERENCE),
    Tasks.MULTIBLIMP: ("accuracy", TaskType.INFERENCE),
    Tasks.FRACAS: ("accuracy", TaskType.INFERENCE),
    Tasks.MMS: ("accuracy", TaskType.INFERENCE),
    Tasks.WSD: ("em", TaskType.GENERATIVE),
    Tasks.LINGNLI: ("accuracy", TaskType.INFERENCE),
}


def tasks_factory(task_names: Union[Dict, List[str], List[Tasks]]) -> List[Task]:
    """
    Build a list of Task objects from task identifiers.

    Args:
        task_names: Either
            - a dict with a "tasks" entry holding a list of mappings, where the
              first key of each mapping is a task name, or
            - a list whose items are task-name strings or ``Tasks`` members.

    Returns:
        One ``Task`` per input item, in input order. Each ``Task`` is configured
        with the metric and ``TaskType`` registered for it in ``_TASK_SPECS``.

    Raises:
        ValueError: if the dict form has no "tasks" entry, if a string is not a
            valid ``Tasks`` value, or if a task has no registered configuration.
    """
    if isinstance(task_names, dict):
        entries = task_names.get("tasks")
        if entries is None:
            error = 'Missing "tasks" entry in task configuration.'
            logging.error(error)
            raise ValueError(error)
        # Only the first key of each entry names the task; its value is unused here.
        task_names = [list(entry.keys())[0] for entry in entries]

    # Convert every name up front so an invalid string fails before any Task is built.
    resolved = [
        Tasks(task_name) if isinstance(task_name, str) else task_name
        for task_name in task_names
    ]

    tasks = []
    for task in resolved:
        # The isinstance guard keeps unhashable or foreign objects out of the dict lookup.
        spec = _TASK_SPECS.get(task) if isinstance(task, Tasks) else None
        if spec is None:
            # Non-Tasks objects have no ``.value``, so fall back to the object itself.
            error = f"Unknown task {getattr(task, 'value', task)}."
            logging.error(error)
            raise ValueError(error)
        metric, task_type = spec
        tasks.append(Task(task_name=task.value, metric=metric, task_type=task_type))
    return tasks