|
"""
|
|
Copyright (c) 2022, salesforce.com, inc.
|
|
All rights reserved.
|
|
SPDX-License-Identifier: BSD-3-Clause
|
|
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import random
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.backends.cudnn as cudnn
|
|
|
|
import lavis.tasks as tasks
|
|
from lavis.common.config import Config
|
|
from lavis.common.dist_utils import get_rank, init_distributed_mode
|
|
from lavis.common.logger import setup_logger
|
|
from lavis.common.optims import (
|
|
LinearWarmupCosineLRScheduler,
|
|
LinearWarmupStepLRScheduler,
|
|
)
|
|
from lavis.common.registry import registry
|
|
from lavis.common.utils import now
|
|
|
|
|
|
from lavis.datasets.builders import *
|
|
from lavis.models import *
|
|
from lavis.processors import *
|
|
from lavis.runners import *
|
|
from lavis.tasks import *
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line arguments of the training script.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse falls
            back to ``sys.argv[1:]``, so existing ``parse_args()`` callers
            behave exactly as before.

    Returns:
        argparse.Namespace with two attributes:
            ``cfg_path``: value of the required ``--cfg-path`` flag.
            ``options``: list of strings given to ``--options``, or ``None``
            when the flag is omitted.

    Raises:
        SystemExit: raised by argparse when ``--cfg-path`` is missing or the
            arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    # NOTE(review): the help text below points users to --cfg-options, but this
    # parser only defines --cfg-path and --options; confirm the intended flag.
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    The seed used is ``config.run_cfg.seed`` plus the value returned by
    ``get_rank()``, so different rank values produce different seeds.

    Args:
        config: object exposing ``run_cfg.seed``.
    """
    base_seed = config.run_cfg.seed
    rank_seed = base_seed + get_rank()

    # Apply the same per-rank seed to every generator seeded here.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(rank_seed)

    # Disable cuDNN benchmarking and request deterministic cuDNN behaviour.
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
|
|
def get_runner_class(cfg):
    """
    Look up the runner class named in the config.

    The name is read from the "runner" entry of ``cfg.run_cfg`` and falls
    back to "runner_base" when that entry is absent (described by the
    original author as the epoch-based runner).
    """
    runner_name = cfg.run_cfg.get("runner", "runner_base")
    return registry.get_runner_class(runner_name)
|
|
|
|
|
|
def main():
    """Run one training job: parse config, set up the process, build, train.

    The steps run in a fixed order: take the job id, load the config from
    the command line, initialise distributed mode, seed the RNGs, configure
    logging, print the config, build task/datasets/model, then construct the
    runner and call its ``train()`` method.
    """
    # The job id is taken before distributed initialisation.
    # NOTE(review): presumably so the id is fixed as early as possible for
    # the whole job -- confirm against the runner's use of job_id.
    job_id = now()

    cli_args = parse_args()
    cfg = Config(cli_args)

    init_distributed_mode(cfg.run_cfg)

    # Seeding happens after distributed init because the seed depends on
    # get_rank() (see setup_seeds).
    setup_seeds(cfg)

    # Logger setup is done after distributed init.
    setup_logger()

    cfg.pretty_print()

    active_task = tasks.setup_task(cfg)
    built_datasets = active_task.build_datasets(cfg)
    built_model = active_task.build_model(cfg)

    runner_cls = get_runner_class(cfg)
    runner = runner_cls(
        cfg=cfg,
        job_id=job_id,
        task=active_task,
        model=built_model,
        datasets=built_datasets,
    )
    runner.train()
|
|
|
|
|
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()