cole / src /task /task_factory.py
davebulaval's picture
v1
8fa3acc
import logging
from typing import Dict, List, Union
from src.task.task import Task, TaskType
from src.task.task_names import Tasks
def tasks_factory(task_names: Union[Dict, List[str], List[Tasks]]) -> List[Task]:
"""
Factory method to create a list of Task objects from a dictionary of task names and their predictions.
"""
tasks = []
if isinstance(task_names, Dict):
task_names = task_names.get("tasks")
task_names = [list(task.keys())[0] for task in task_names]
task_names = [
Tasks(task_name) if isinstance(task_name, str) else task_name
for task_name in task_names
]
for task in task_names:
match task:
case Tasks.ALLOCINE:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.FQUAD:
tasks.append(
Task(
task_name=task.value,
metric="fquad",
task_type=TaskType.GENERATIVE,
)
)
case Tasks.GQNLI:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.PAWS_X:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.PIAF:
tasks.append(
Task(
task_name=task.value,
metric="fquad",
task_type=TaskType.GENERATIVE,
)
)
case Tasks.QFRBLIMP:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.QFRCOLA:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.SICKFR:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.STS22:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.XNLI:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.QFRCORE:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.QFRCORT:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.DACCORD:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.FRENCH_BOOLQ:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.MNLI_NINEELEVEN_FR_MT:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.RTE3_FRENCH:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.WINO_X_LM:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.WINO_X_MT:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.MULTIBLIMP:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.FRACAS:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.MMS:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case Tasks.WSD:
tasks.append(
Task(
task_name=task.value,
metric="em",
task_type=TaskType.GENERATIVE,
)
)
case Tasks.LINGNLI:
tasks.append(
Task(
task_name=task.value,
metric="accuracy",
task_type=TaskType.INFERENCE,
)
)
case _:
error = f"Unknown task {task.value}."
logging.error(error)
raise ValueError(error)
return tasks