""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import datetime
import logging
import os
import time

import torch
import torch.distributed as dist
import webdataset as wds
from lavis.common.dist_utils import download_cached_file, is_main_process, main_process
from lavis.common.registry import registry
from lavis.common.utils import is_url
from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
from lavis.runners.runner_base import RunnerBase
from torch.utils.data.dataset import ChainDataset


@registry.register_runner("runner_iter")
class RunnerIter(RunnerBase):
""" | |
Run training based on the number of iterations. This is common when | |
the training dataset size is large. Underhood logic is similar to | |
epoch-based training by considering every #iters_per_inner_epoch as an | |
inner epoch. | |
In iter-based runner, after every #iters_per_inner_epoch steps, we | |
1) do a validation epoch; | |
2) schedule the learning rate; | |
3) save the checkpoint. | |
We refer every #iters_per_inner_epoch steps as an inner epoch. | |
""" | |

    def __init__(self, cfg, task, model, datasets, job_id):
        super().__init__(cfg, task, model, datasets, job_id)

        self.start_iters = 0

        self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
        assert self.max_iters > 0, "max_iters must be greater than 0."

        self.iters_per_inner_epoch = int(
            self.config.run_cfg.get("iters_per_inner_epoch", -1)
        )
        assert (
            self.iters_per_inner_epoch > 0
        ), "iters_per_inner_epoch must be greater than 0."

    @property
    def max_epoch(self):
        return int(self.max_iters / self.iters_per_inner_epoch)

    @property
    def cur_epoch(self):
        try:
            return self.train_loader.epoch
        except AttributeError:
            # pipeline data (e.g. LAION) is streaming and has no concept of epoch
            return 0

    def _progress(self, cur_iters):
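        # e.g. cur_epoch=0, cur_iters=5000 -> "0_iters=5000"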
return "{}_iters={}".format(self.cur_epoch, cur_iters) | |

    def train(self):
        start_time = time.time()

        best_agg_metric = 0
        best_iters = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        for start_iters in range(
            self.start_iters, self.max_iters, self.iters_per_inner_epoch
        ):
            end_iters = start_iters + self.iters_per_inner_epoch
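            # each inner epoch covers iterations [start_iters, end_iters)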

            # training phase
            if not self.evaluate_only:
                logging.info(
                    "Start training, max_iters={}, in total {} inner epochs.".format(
                        self.max_iters, int(self.max_iters / self.iters_per_inner_epoch)
                    )
                )

                train_stats = self.train_iters(self.cur_epoch, start_iters)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if len(self.valid_splits) > 0:
                for split_name in self.valid_splits:
                    logging.info("Evaluating on {}.".format(split_name))

                    val_log = self.eval_epoch(
                        split_name=split_name, cur_epoch=self._progress(end_iters)
                    )
                    if val_log is not None:
                        if is_main_process():
                            assert (
                                "agg_metrics" in val_log
                            ), "No agg_metrics found in validation log."

                            agg_metrics = val_log["agg_metrics"]
                            if agg_metrics > best_agg_metric and split_name == "val":
                                best_iters, best_agg_metric = end_iters, agg_metrics

                                self._save_checkpoint(end_iters, is_best=True)

                            val_log.update({"best_iters": best_iters})
                            self.log_stats(val_log, split_name)
            else:
                # if no validation split is provided, we just save the checkpoint at the end of each inner epoch.
                if not self.evaluate_only:
                    self._save_checkpoint(end_iters, is_best=False)

            if self.evaluate_only:
                break
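
            # wait for all ranks to finish this inner epoch before moving on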
            dist.barrier()

        # testing phase
        self.evaluate(cur_epoch=self.cur_epoch)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))

    def train_iters(self, epoch, start_iters):
        # train by iterations
        self.model.train()

        return self.task.train_iters(
            epoch=epoch,
            start_iters=start_iters,
            iters_per_inner_epoch=self.iters_per_inner_epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
        )

    @main_process
    def _save_checkpoint(self, cur_iters, is_best=False):
        save_obj = {
            "model": self.unwrap_dist_model(self.model).state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "config": self.config.to_dict(),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "iters": cur_iters,
        }
        save_to = os.path.join(
            self.output_dir,
            "checkpoint_{}.pth".format("best" if is_best else cur_iters),
        )
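        # e.g. "checkpoint_5000.pth" for a periodic save, "checkpoint_best.pth"
        # for the best validation checkpoint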
logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to)) | |
torch.save(save_obj, save_to) | |

    def _load_checkpoint(self, url_or_filename):
        """
        Resume from a checkpoint.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location=self.device)
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        self.unwrap_dist_model(self.model).load_state_dict(state_dict)

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scaler and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.start_iters = checkpoint["iters"] + 1
        logging.info("Resume checkpoint from {}".format(url_or_filename))

    @property
    def dataloaders(self) -> dict:
        """
        A property that creates dataloaders on first access, one per split.

        If no train_dataset_ratios is provided, map-style datasets are
        concatenated and wds.DataPipeline datasets are chained separately.
        The training set becomes a tuple (ConcatDataset, ChainDataset); both
        are optional, but at least one of them is required. The resulting
        ConcatDataset and ChainDataset will be sampled evenly.

        If train_dataset_ratios is provided, a MultiIterLoader is created to
        sample each dataset by the given ratios during training.

        Multiple datasets for validation and test are currently not supported.

        Returns:
            dict: {split_name: (tuples of) dataloader}
        """
        if self._dataloaders is None:
            # reorganize datasets by split and concatenate/chain if necessary
            dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)

            if dataset_ratios is None:
                # concatenate map-style datasets and chain wds.DataPipe datasets separately
                # training set becomes a tuple (ConcatDataset, ChainDataset), both are
                # optional but at least one of them is required. The resultant ConcatDataset
                # and ChainDataset will be sampled evenly.
                logging.info(
                    "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
                )

                datasets = reorg_datasets_by_split(self.datasets)
                self.datasets = concat_datasets(datasets)
            else:
                # create multi-loader with the provided ratios, without concatenating or chaining
                missing_keys = [k for k in dataset_ratios if k not in self.datasets]
                if len(missing_keys) > 0:
                    raise ValueError(
                        "Datasets with the following split names are not found: {}".format(
                            missing_keys
                        )
                    )

                unexpected_keys = [k for k in self.datasets if k not in dataset_ratios]
                if len(unexpected_keys) > 0:
                    raise ValueError(
                        "Datasets with the following split names are not expected: {}".format(
                            unexpected_keys
                        )
                    )
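
                # order the ratios to match the iteration order of self.datasets
                # so they can be zipped with the datasets downstream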
                dataset_ratios = [float(dataset_ratios[k]) for k in self.datasets]
                self.datasets = reorg_datasets_by_split(self.datasets)
                # to keep the same structure as return value of concat_datasets
                self.datasets = {
                    k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
                }

            # print dataset statistics after concatenation/chaining
            for split_name in self.datasets:
                if isinstance(self.datasets[split_name], (tuple, list)):
                    # mixed wds.DataPipeline and torch.utils.data.Dataset
                    num_records = sum(
                        [
                            len(d)
                            if not isinstance(d, (wds.DataPipeline, ChainDataset))
                            else 0
                            for d in self.datasets[split_name]
                        ]
                    )
                else:
                    try:
                        # a single map-style dataset
                        num_records = len(self.datasets[split_name])
                    except TypeError:
                        # a single wds.DataPipeline or ChainDataset
                        num_records = -1
                        logging.info(
                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
                        )

                if num_records >= 0:
                    logging.info(
                        "Loaded {} records for {} split from the dataset.".format(
                            num_records, split_name
                        )
                    )

            # create dataloaders
            split_names = sorted(self.datasets.keys())

            datasets = [self.datasets[split] for split in split_names]
            is_trains = [split in self.train_splits for split in split_names]

            batch_sizes = [
                self.config.run_cfg.batch_size_train
                if split == "train"
                else self.config.run_cfg.batch_size_eval
                for split in split_names
            ]

            collate_fns = []
            for dataset in datasets:
                if isinstance(dataset, (tuple, list)):
                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
                else:
                    collate_fns.append(getattr(dataset, "collater", None))
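
            # a collate_fn of None presumably lets create_loaders fall back to
            # the default collate function for datasets without a collater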
            dataloaders = self.create_loaders(
                datasets=datasets,
                num_workers=self.config.run_cfg.num_workers,
                batch_sizes=batch_sizes,
                is_trains=is_trains,
                collate_fns=collate_fns,
                dataset_ratios=dataset_ratios,
            )

            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}

        return self._dataloaders
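

# A hypothetical run_cfg for this runner (field names are those read above;
# values are illustrative only, not defaults):
#
#   run:
#     runner: runner_iter
#     max_iters: 60000
#     iters_per_inner_epoch: 5000
#     batch_size_train: 64
#     batch_size_eval: 32
#     num_workers: 4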