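# NOTE: the following is page residue from the hosting site's file viewer, not part of the module:
# EdwardoSunny's picture | finished | 85ab89d | raw | history blame | 24.5 kB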
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import json
import logging
import os
import time
from pathlib import Path
import numpy as np
import torch
import torch.distributed as dist
import webdataset as wds
from minigpt4.common.dist_utils import (
download_cached_file,
get_rank,
get_world_size,
is_main_process,
main_process,
)
from minigpt4.common.registry import registry
from minigpt4.common.utils import is_url
from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
from minigpt4.datasets.datasets.dataloader_utils import (
IterLoader,
MultiIterLoader,
PrefetchLoader,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, default_collate
@registry.register_runner("runner_base")
class RunnerBase:
    """
    A runner class to train and evaluate a model given a task and datasets.

    The runner uses pytorch distributed data parallel by default. Future release
    will support other distributed frameworks.

    The device, wrapped model, optimizer, gradient scaler, learning-rate
    scheduler and dataloaders are created lazily by the matching property on
    first access and cached on the instance.
    """

    def __init__(self, cfg, task, model, datasets, job_id):
        """
        Store the collaborators and create the output directories.

        Args:
            cfg: configuration object; ``cfg.run_cfg`` is read throughout the class.
            task: object providing ``train_epoch``, ``before_evaluation``,
                ``evaluation`` and ``after_evaluation``.
            model: the unwrapped model to train / evaluate.
            datasets: datasets, regrouped by split when dataloaders are first built.
            job_id: last path component of the output directory.
        """
        self.config = cfg
        self.job_id = job_id

        self.task = task
        self.datasets = datasets

        self._model = model

        # lazily created components, see the properties below
        self._wrapped_model = None
        self._device = None
        self._optimizer = None
        self._scaler = None
        self._dataloaders = None
        self._lr_sched = None

        self.start_epoch = 0

        # self.setup_seeds()
        self.setup_output_dir()

    @property
    def device(self):
        """The ``torch.device`` built from ``run_cfg.device`` (cached)."""
        if self._device is None:
            self._device = torch.device(self.config.run_cfg.device)

        return self._device

    @property
    def use_distributed(self):
        """Value of ``run_cfg.distributed``."""
        return self.config.run_cfg.distributed

    @property
    def model(self):
        """
        A property to get the DDP-wrapped model on the device.

        When ``use_distributed`` is false the plain model is returned instead.
        """
        # move model to device
        if self._model.device != self.device:
            self._model = self._model.to(self.device)

        # The wrapper is set up independently of the device move so that this
        # property never returns None for a model that is already on the device.
        if self.use_distributed:
            # distributed training wrapper, created once
            if self._wrapped_model is None:
                self._wrapped_model = DDP(
                    self._model, device_ids=[self.config.run_cfg.gpu]
                )
        else:
            self._wrapped_model = self._model

        return self._wrapped_model

    @property
    def optimizer(self):
        """
        The AdamW optimizer over all trainable parameters (created once).

        Parameters with fewer than two dimensions, or whose name contains
        "bias", "ln" or "bn", are placed in a group with zero weight decay; all
        other trainable parameters use ``run_cfg.weight_decay``.
        """
        # TODO make optimizer class and configurations
        if self._optimizer is None:
            num_parameters = 0
            p_wd, p_non_wd = [], []
            for n, p in self.model.named_parameters():
                if not p.requires_grad:
                    continue  # frozen weights
                logging.debug("trainable parameter: %s", n)
                if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
                    p_non_wd.append(p)
                else:
                    p_wd.append(p)
                num_parameters += p.data.nelement()
            logging.info("number of trainable parameters: %d", num_parameters)

            weight_decay = float(self.config.run_cfg.weight_decay)
            optim_params = [
                {"params": p_wd, "weight_decay": weight_decay},
                {"params": p_non_wd, "weight_decay": 0},
            ]
            beta2 = self.config.run_cfg.get("beta2", 0.999)
            self._optimizer = torch.optim.AdamW(
                optim_params,
                lr=float(self.config.run_cfg.init_lr),
                weight_decay=weight_decay,
                betas=(0.9, beta2),
            )

        return self._optimizer

    @property
    def scaler(self):
        """A ``GradScaler`` when ``run_cfg.amp`` is truthy, otherwise ``None``."""
        amp = self.config.run_cfg.get("amp", False)

        if amp and self._scaler is None:
            self._scaler = torch.cuda.amp.GradScaler()

        return self._scaler

    @property
    def lr_scheduler(self):
        """
        A property to get and create learning rate scheduler by split just in need.

        ``iters_per_epoch`` comes from ``run_cfg.iters_per_epoch`` when present;
        otherwise it is the length of the first sub-loader of the train loader
        (or of the train loader itself), with 10000 as the last-resort fallback.
        """
        if self._lr_sched is None:
            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)

            max_epoch = self.max_epoch
            min_lr = self.min_lr
            init_lr = self.init_lr

            # optional parameters
            decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
            warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
            warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
            iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None)

            if iters_per_epoch is None:
                try:
                    train_loader = self.dataloaders["train"]
                    # a multi-loader exposes its members via `.loaders`; a single
                    # loader is measured directly
                    sub_loaders = getattr(train_loader, "loaders", None)
                    sized = sub_loaders[0] if sub_loaders else train_loader
                    iters_per_epoch = len(sized)
                except (AttributeError, TypeError):
                    logging.warning(
                        "Cannot infer iters_per_epoch from the train loader, "
                        "falling back to 10000."
                    )
                    iters_per_epoch = 10000

            self._lr_sched = lr_sched_cls(
                optimizer=self.optimizer,
                max_epoch=max_epoch,
                iters_per_epoch=iters_per_epoch,
                min_lr=min_lr,
                init_lr=init_lr,
                decay_rate=decay_rate,
                warmup_start_lr=warmup_start_lr,
                warmup_steps=warmup_steps,
            )

        return self._lr_sched

    @property
    def dataloaders(self) -> dict:
        """
        A property to get and create dataloaders by split just in need.

        Datasets are regrouped by split with ``reorg_datasets_by_split``. A split
        that holds a list/tuple of datasets gets a ``MultiIterLoader`` (see
        ``create_loaders``); any other split gets a single loader.

        Currently do not support multiple datasets for validation and test.

        Returns:
            dict: {split_name: (tuples of) dataloader}
        """
        if self._dataloaders is None:
            logging.info(
                "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
            )

            self.datasets = reorg_datasets_by_split(self.datasets)
            # self.datasets = concat_datasets(datasets)

            # print dataset statistics after concatenation/chaining
            for split_name in self.datasets:
                split_data = self.datasets[split_name]
                if isinstance(split_data, (tuple, list)):
                    # mixed wds.DataPipeline and torch.utils.data.Dataset;
                    # streaming datasets contribute 0 to the count
                    num_records = sum(
                        len(d)
                        if type(d) not in (wds.DataPipeline, ChainDataset)
                        else 0
                        for d in split_data
                    )
                elif hasattr(split_data, "__len__"):
                    # a single map-style dataset
                    num_records = len(split_data)
                else:
                    # a single wds.DataPipeline
                    num_records = -1
                    logging.info(
                        "Only a single wds.DataPipeline dataset, no __len__ attribute."
                    )

                if num_records >= 0:
                    logging.info(
                        "Loaded {} records for {} split from the dataset.".format(
                            num_records, split_name
                        )
                    )

            # create dataloaders
            split_names = sorted(self.datasets.keys())
            logging.info("split_names: %s", split_names)

            datasets = [self.datasets[split] for split in split_names]
            is_trains = [split in self.train_splits for split in split_names]

            # NOTE(review): the batch size is keyed on the literal split name
            # "train", whereas train mode above is keyed on `train_splits`.
            batch_sizes = [
                self.config.run_cfg.batch_size_train
                if split == "train"
                else self.config.run_cfg.batch_size_eval
                for split in split_names
            ]

            # each dataset may provide its own `collater`; None otherwise
            collate_fns = []
            for dataset in datasets:
                if isinstance(dataset, (tuple, list)):
                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
                else:
                    collate_fns.append(getattr(dataset, "collater", None))

            dataloaders = self.create_loaders(
                datasets=datasets,
                num_workers=self.config.run_cfg.num_workers,
                batch_sizes=batch_sizes,
                is_trains=is_trains,
                collate_fns=collate_fns,
            )

            self._dataloaders = dict(zip(split_names, dataloaders))

        return self._dataloaders

    @property
    def cuda_enabled(self):
        """Whether the configured device type is "cuda"."""
        return self.device.type == "cuda"

    @property
    def max_epoch(self):
        """``run_cfg.max_epoch`` as an int."""
        return int(self.config.run_cfg.max_epoch)

    @property
    def log_freq(self):
        """``run_cfg.log_freq`` as an int (50 when absent)."""
        log_freq = self.config.run_cfg.get("log_freq", 50)
        return int(log_freq)

    @property
    def init_lr(self):
        """``run_cfg.init_lr`` as a float."""
        return float(self.config.run_cfg.init_lr)

    @property
    def min_lr(self):
        """``run_cfg.min_lr`` as a float."""
        return float(self.config.run_cfg.min_lr)

    @property
    def accum_grad_iters(self):
        """``run_cfg.accum_grad_iters`` as an int (1 when absent)."""
        return int(self.config.run_cfg.get("accum_grad_iters", 1))

    @property
    def valid_splits(self):
        """``run_cfg.valid_splits`` (empty list when absent)."""
        valid_splits = self.config.run_cfg.get("valid_splits", [])

        if len(valid_splits) == 0:
            logging.info("No validation splits found.")

        return valid_splits

    @property
    def test_splits(self):
        """``run_cfg.test_splits`` (empty list when absent)."""
        test_splits = self.config.run_cfg.get("test_splits", [])

        return test_splits

    @property
    def train_splits(self):
        """``run_cfg.train_splits`` (empty list when absent)."""
        train_splits = self.config.run_cfg.get("train_splits", [])

        if len(train_splits) == 0:
            logging.info("Empty train splits.")

        return train_splits

    @property
    def evaluate_only(self):
        """
        Set to True to skip training.
        """
        return self.config.run_cfg.evaluate

    @property
    def use_dist_eval_sampler(self):
        """``run_cfg.use_dist_eval_sampler`` (True when absent)."""
        return self.config.run_cfg.get("use_dist_eval_sampler", True)

    @property
    def resume_ckpt_path(self):
        """``run_cfg.resume_ckpt_path`` (None when absent)."""
        return self.config.run_cfg.get("resume_ckpt_path", None)

    @property
    def train_loader(self):
        """The dataloader registered under the "train" split."""
        train_dataloader = self.dataloaders["train"]

        return train_dataloader

    def setup_output_dir(self):
        """
        Create ``<library_root>/<run_cfg.output_dir>/<job_id>`` and its
        ``result`` sub-directory, register both paths in the registry and store
        them as ``self.output_dir`` / ``self.result_dir``.
        """
        lib_root = Path(registry.get_path("library_root"))

        output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
        result_dir = output_dir / "result"

        output_dir.mkdir(parents=True, exist_ok=True)
        result_dir.mkdir(parents=True, exist_ok=True)

        registry.register_path("result_dir", str(result_dir))
        registry.register_path("output_dir", str(output_dir))

        self.result_dir = result_dir
        self.output_dir = output_dir

    def train(self):
        """
        Run the train / validate loop, then evaluate on the test splits.

        With validation splits configured, a "best" checkpoint is saved whenever
        the "val" split improves ``agg_metrics``; without them a checkpoint is
        saved after every training epoch. In evaluate-only mode a single pass
        over the validation splits is made and no checkpoint is reloaded for
        testing.
        """
        start_time = time.time()
        # Start below any real metric so the first validated epoch always
        # produces checkpoint_best.pth, which the testing phase reloads.
        best_agg_metric = float("-inf")
        best_epoch = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        # Default for the testing phase in case the loop body never runs
        # (start_epoch >= max_epoch, e.g. after resuming a finished run).
        cur_epoch = max(self.start_epoch - 1, 0)

        for cur_epoch in range(self.start_epoch, self.max_epoch):
            # training phase
            if not self.evaluate_only:
                logging.info("Start training")
                train_stats = self.train_epoch(cur_epoch)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if len(self.valid_splits) > 0:
                for split_name in self.valid_splits:
                    logging.info("Evaluating on {}.".format(split_name))

                    val_log = self.eval_epoch(
                        split_name=split_name, cur_epoch=cur_epoch
                    )
                    if val_log is not None and is_main_process():
                        assert (
                            "agg_metrics" in val_log
                        ), "No agg_metrics found in validation log."

                        agg_metrics = val_log["agg_metrics"]
                        # only the split named "val" drives best-model selection
                        if agg_metrics > best_agg_metric and split_name == "val":
                            best_epoch, best_agg_metric = cur_epoch, agg_metrics

                            self._save_checkpoint(cur_epoch, is_best=True)

                        val_log.update({"best_epoch": best_epoch})
                        self.log_stats(val_log, split_name)

            else:
                # if no validation split is provided, we just save the checkpoint at the end of each epoch.
                if not self.evaluate_only:
                    self._save_checkpoint(cur_epoch, is_best=False)

            if self.evaluate_only:
                break

            if self.use_distributed:
                dist.barrier()

        # testing phase
        test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
        self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))

    def evaluate(self, cur_epoch="best", skip_reload=False):
        """
        Evaluate on every configured test split.

        Args:
            cur_epoch: epoch label forwarded to ``eval_epoch``; "best" requests
                the best checkpoint.
            skip_reload (bool): forwarded to ``eval_epoch``.

        Returns:
            dict or None: ``{split_name: eval_epoch result}``, or ``None`` when no
            test split is configured.
        """
        test_logs = dict()

        if len(self.test_splits) > 0:
            for split_name in self.test_splits:
                test_logs[split_name] = self.eval_epoch(
                    split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
                )

            return test_logs

    def train_epoch(self, epoch):
        """Put the model in train mode and delegate one epoch to the task."""
        # train
        self.model.train()

        return self.task.train_epoch(
            epoch=epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
        )

    @torch.no_grad()
    def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
        """
        Evaluate the model on a given split.

        Args:
            split_name (str): name of the split to evaluate on.
            cur_epoch (int or str): current epoch, or "best".
            skip_reload (bool): whether to skip reloading the best checkpoint.
                The best checkpoint is reloaded only when ``cur_epoch == "best"``
                and ``skip_reload`` is false.

        Returns:
            The value of ``task.after_evaluation`` or ``None`` when the task's
            evaluation produced no results.
        """
        data_loader = self.dataloaders.get(split_name, None)
        # Identity check: truthiness would go through __len__ and misreport an
        # empty (or length-less) loader as missing.
        assert data_loader is not None, "data_loader for split {} is None.".format(
            split_name
        )

        # TODO In validation, you need to compute loss as well as metrics
        # TODO consider moving to model.before_evaluation()
        model = self.unwrap_dist_model(self.model)
        if not skip_reload and cur_epoch == "best":
            model = self._reload_best_model(model)
        model.eval()

        self.task.before_evaluation(
            model=model,
            dataset=self.datasets[split_name],
        )
        results = self.task.evaluation(model, data_loader)

        if results is None:
            return None

        return self.task.after_evaluation(
            val_result=results,
            split_name=split_name,
            epoch=cur_epoch,
        )

    def unwrap_dist_model(self, model):
        """Return ``model.module`` under distributed training, else ``model``."""
        if self.use_distributed:
            return model.module
        else:
            return model

    def create_loaders(
        self,
        datasets,
        num_workers,
        batch_sizes,
        is_trains,
        collate_fns,
        dataset_ratios=None,
    ):
        """
        Create dataloaders for training and validation.

        Args:
            datasets: one entry per split; an entry is either a dataset or a
                list/tuple of datasets.
            num_workers: worker count passed to every ``DataLoader``.
            batch_sizes: batch size per split.
            is_trains: per-split flag enabling shuffling, ``drop_last`` and the
                ``IterLoader`` wrapper.
            collate_fns: per-split collate function, or a list of them matching a
                list/tuple entry of ``datasets``.
            dataset_ratios: optional sampling ratios for multi-dataset splits.
                When omitted, the ``sample_ratio`` attributes of that split's
                datasets are used if the first dataset has one.

        Returns:
            list: one loader per split, in input order.
        """

        def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
            # create a single dataloader for each split
            if isinstance(dataset, (ChainDataset, wds.DataPipeline)):
                # wds.WebdDataset instance are chained together
                # webdataset.DataPipeline has its own sampler and collate_fn
                loader = iter(
                    DataLoader(
                        dataset,
                        batch_size=bsz,
                        num_workers=num_workers,
                        pin_memory=True,
                    )
                )
            else:
                # map-style dataset are concatenated together

                # setup distributed sampler
                if self.use_distributed:
                    sampler = DistributedSampler(
                        dataset,
                        shuffle=is_train,
                        num_replicas=get_world_size(),
                        rank=get_rank(),
                    )
                    if not self.use_dist_eval_sampler:
                        # e.g. retrieval evaluation
                        sampler = sampler if is_train else None
                else:
                    sampler = None

                loader = DataLoader(
                    dataset,
                    batch_size=bsz,
                    num_workers=num_workers,
                    pin_memory=True,
                    sampler=sampler,
                    # shuffle and sampler are mutually exclusive in DataLoader
                    shuffle=sampler is None and is_train,
                    collate_fn=collate_fn,
                    drop_last=bool(is_train),
                )
                loader = PrefetchLoader(loader)

                if is_train:
                    loader = IterLoader(loader, use_distributed=self.use_distributed)

            return loader

        loaders = []

        for dataset, bsz, is_train, collate_fn in zip(
            datasets, batch_sizes, is_trains, collate_fns
        ):
            if isinstance(dataset, (list, tuple)):
                # Resolve ratios per split; the argument itself is never
                # overwritten, so one split's ratios cannot leak into the next.
                ratios = dataset_ratios
                if ratios is None and hasattr(dataset[0], "sample_ratio"):
                    ratios = [d.sample_ratio for d in dataset]
                loader = MultiIterLoader(
                    loaders=[
                        _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
                        for i, d in enumerate(dataset)
                    ],
                    ratios=ratios,
                )
            else:
                loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)

            loaders.append(loader)

        return loaders

    @main_process
    def _save_checkpoint(self, cur_epoch, is_best=False):
        """
        Save the checkpoint at the current epoch.

        Parameters that do not require gradients are left out of the saved model
        state. The file is ``checkpoint_best.pth`` when ``is_best`` is true,
        otherwise ``checkpoint_<cur_epoch>.pth``, inside ``self.output_dir``.
        """
        model_no_ddp = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
        }
        state_dict = model_no_ddp.state_dict()
        # iterate over a copy of the keys because entries are deleted in the loop
        for k in list(state_dict.keys()):
            if k in param_grad_dic and not param_grad_dic[k]:
                # delete parameters that do not require gradient
                del state_dict[k]
        save_obj = {
            "model": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config.to_dict(),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "epoch": cur_epoch,
        }
        save_to = os.path.join(
            self.output_dir,
            "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
        )
        logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
        torch.save(save_obj, save_to)

    def _reload_best_model(self, model):
        """
        Load the best checkpoint for evaluation.

        Loads ``checkpoint_best.pth`` from ``self.output_dir`` into ``model``,
        retrying with ``strict=False`` when strict loading raises
        ``RuntimeError``. Returns the same ``model`` object.
        """
        checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")

        logging.info("Loading checkpoint from {}.".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        try:
            model.load_state_dict(checkpoint["model"])
        except RuntimeError:
            logging.warning(
                "\nKey mismatch when loading checkpoint. This is expected if only part of the model is saved.\n"
                "Trying to load the model with strict=False.\n"
            )
            model.load_state_dict(checkpoint["model"], strict=False)
        return model

    def _load_checkpoint(self, url_or_filename):
        """
        Resume from a checkpoint.

        Restores model weights (non-strict), optimizer state and, when both are
        available, the gradient scaler state, then sets ``start_epoch`` to the
        epoch after the stored one.

        Raises:
            RuntimeError: if the argument is neither a URL nor an existing file.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location=self.device)
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        self.unwrap_dist_model(self.model).load_state_dict(state_dict, strict=False)

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scaler and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.start_epoch = checkpoint["epoch"] + 1
        logging.info("Resume checkpoint from {}".format(url_or_filename))

    @main_process
    def log_stats(self, stats, split_name):
        """
        Append ``stats`` as one JSON line to ``log.txt`` in the output directory.

        Keys are prefixed with ``<split_name>_``. Anything that is not a dict
        (lists included) is ignored.
        """
        if isinstance(stats, dict):
            log_stats = {f"{split_name}_{k}": v for k, v in stats.items()}
            with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

    @main_process
    def log_config(self):
        """Append the full configuration as indented JSON to ``log.txt``."""
        with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
            f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")