annabeth97c's picture
Initial commit
12f2e48 verified
import transformers
import logging
from multi_token.training import (
TrainingArguments,
ModelArguments,
train_for_modalities,
)
from multi_token.training_data import (
DataArguments,
TrainDataArguments,
EvaluationDataArguments,
)
from multi_token.model_utils import MultiTaskType
from multi_token.language_models import LANGUAGE_MODEL_NAME_TO_CLASS
from multi_token.modalities import MODALITY_BUILDERS
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
parser = transformers.HfArgumentParser(
(TrainingArguments, ModelArguments, TrainDataArguments, EvaluationDataArguments)
)
training_args, model_args, train_data_args, evaluation_data_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
_train_data_args = DataArguments()
_evaluation_data_args = DataArguments()
_train_data_args.dataset_path = train_data_args.train_dataset_path
_evaluation_data_args.dataset_path = evaluation_data_args.evaluation_dataset_path
if MultiTaskType(model_args.use_multi_task) != MultiTaskType.NO_MULTI_TASK:
modalities = MODALITY_BUILDERS[model_args.modality_builder](use_multi_task = MultiTaskType(model_args.use_multi_task), tasks_config = model_args.tasks_config)
else:
modalities = MODALITY_BUILDERS[model_args.modality_builder]()
model_cls = LANGUAGE_MODEL_NAME_TO_CLASS[model_args.model_cls]
train_for_modalities(model_cls, training_args, model_args, _train_data_args, _evaluation_data_args, modalities)