File size: 1,554 Bytes
12f2e48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import transformers
import logging

from multi_token.training import (
    TrainingArguments,
    ModelArguments,
    train_for_modalities,
)
from multi_token.training_data import (
    DataArguments,
    TrainDataArguments,
    EvaluationDataArguments,
)

from multi_token.model_utils import MultiTaskType
from multi_token.language_models import LANGUAGE_MODEL_NAME_TO_CLASS
from multi_token.modalities import MODALITY_BUILDERS

if __name__ == "__main__":
    # Script entry point: parse the command line into the argument
    # dataclasses, build the modalities, and hand everything over to
    # train_for_modalities.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    dataclass_types = (
        TrainingArguments,
        ModelArguments,
        TrainDataArguments,
        EvaluationDataArguments,
    )
    parser = transformers.HfArgumentParser(dataclass_types)

    # With return_remaining_strings=True the result carries one extra trailing
    # element after the four dataclasses; it is discarded here.
    parsed = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    training_args, model_args, train_data_args, evaluation_data_args, _ = parsed

    # Copy the train / evaluation dataset paths onto two fresh DataArguments
    # objects, which is what train_for_modalities receives below.
    train_data = DataArguments()
    eval_data = DataArguments()
    train_data.dataset_path = train_data_args.train_dataset_path
    eval_data.dataset_path = evaluation_data_args.evaluation_dataset_path

    # The multi-task keyword arguments are passed to the modality builder only
    # when a multi-task mode is selected; otherwise it is called with none.
    multi_task_type = MultiTaskType(model_args.use_multi_task)
    build_modalities = MODALITY_BUILDERS[model_args.modality_builder]
    builder_kwargs = {}
    if multi_task_type != MultiTaskType.NO_MULTI_TASK:
        builder_kwargs["use_multi_task"] = multi_task_type
        builder_kwargs["tasks_config"] = model_args.tasks_config
    modalities = build_modalities(**builder_kwargs)

    model_cls = LANGUAGE_MODEL_NAME_TO_CLASS[model_args.model_cls]

    train_for_modalities(
        model_cls,
        training_args,
        model_args,
        train_data,
        eval_data,
        modalities,
    )