mazpie's picture
Initial commit
2d9a728
raw
history blame
9.75 kB
import copy
import logging
import os
import os.path as osp
import io
try:
import deepspeed
except Exception as e:
print(e)
print("deepspeed is not installed!!!")
from os.path import join
try:
from petrel_client.client import Client
except:
Client = None
import torch
from torch.utils.data import ConcatDataset, DataLoader
from dataset.resample_concat_dataset import ResampleConcatDataset
from models.backbones.internvideo2.pos_embed import interpolate_pos_embed_internvideo2_new
from models.backbones.bert.tokenization_bert import BertTokenizer
from utils.optimizer import create_optimizer
from utils.scheduler import create_scheduler
from utils.distributed import get_rank
logger = logging.getLogger(__name__)
def get_media_types(datasources):
"""get the media types for for all the dataloaders.
Args:
datasources (List): List of dataloaders or datasets.
Returns: List. The media_types.
"""
if isinstance(datasources[0], DataLoader):
datasets = [dataloader.dataset for dataloader in datasources]
else:
datasets = datasources
media_types = [
dataset.datasets[0].media_type
if isinstance(dataset, ConcatDataset) or isinstance(dataset, ResampleConcatDataset)
else dataset.media_type
for dataset in datasets
]
return media_types
def setup_model(
config, model_cls, add_decoder=False, pretrain=False, find_unused_parameters=False
):
logger.info("Creating model")
config = copy.deepcopy(config)
if "bert" in config.model.text_encoder.name:
logger.info(f"Using BertTokenizer: {config.model.text_encoder.pretrained}!")
tokenizer = BertTokenizer.from_pretrained(config.model.text_encoder.pretrained, local_files_only=True)
model = model_cls(config=config, tokenizer=tokenizer, is_pretrain=pretrain)
else:
model = model_cls(config=config, is_pretrain=pretrain)
tokenizer = model.tokenizer
logger.info(f"Using model.tokenizer: {tokenizer}!")
if config.get('compile_model', False):
torch.set_float32_matmul_precision('high')
model = torch.compile(model)
model = model.to(torch.device(config.device))
model_without_ddp = model
if hasattr(config, "deepspeed") and config.deepspeed.enable:
# We move this to the back
optimizer_params = create_optimizer(config.optimizer, model, return_group=True)
scheduler = None
scaler = None
else:
if config.distributed:
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[config.gpu],
find_unused_parameters=find_unused_parameters, # `False` for image-only task
)
optimizer = create_optimizer(config.optimizer, model)
scaler = torch.cuda.amp.GradScaler(enabled=config.use_half_precision) # This is never used actually if we fixed bf16
scheduler = create_scheduler(config.scheduler, optimizer)
start_epoch = 0
global_step = 0
# auto resume the latest checkpoint
if config.get("auto_resume", False):
logger.info("Auto resuming")
model_latest = join(config.output_dir, "ckpt_latest.pth")
model_best = join(config.output_dir, "ckpt_best.pth")
large_num = -1
for p in os.listdir(config.output_dir):
if 'ckpt' in p:
num = p.split('_')[1].split('.')[0]
if str.isnumeric(num):
if int(num) > large_num:
large_num = int(num)
if large_num != -1:
model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
if osp.isfile(model_latest):
config.pretrained_path = model_latest
config.resume = True
elif osp.isfile(model_best):
config.pretrained_path = model_best
config.resume = True
else:
logger.info(f"Not found checkpoint in {config.output_dir}")
if (config.pretrained_path.strip() and (osp.isfile(config.pretrained_path)) or "s3://" in config.pretrained_path):
if Client is not None:
client = Client()
with io.BytesIO(client.get(config.pretrained_path)) as buffer:
checkpoint = torch.load(buffer, map_location="cpu")
else:
checkpoint = torch.load(config.pretrained_path, map_location="cpu")
logger.info(f"Loading checkpoint from {config.pretrained_path}")
try:
if "model" in checkpoint.keys():
state_dict = checkpoint["model"]
else:
state_dict = checkpoint["module"] # This is a deepspeed stage 1 model
except:
state_dict = checkpoint
if config.get('origin_num_frames', None) is not None:
logger.info(f"interpolate_pos_embed_internvideo2 (origin_num_frames={config.origin_num_frames})!!!")
a = len(state_dict)
interpolate_pos_embed_internvideo2_new(state_dict, model_without_ddp.vision_encoder, orig_t_size=config.origin_num_frames)
assert a == len(state_dict), state_dict.keys()
if config.resume:
assert not (hasattr(config, "deepspeed") and config.deepspeed.enable), "Deepspeed should run here!!!"
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
scaler.load_state_dict(checkpoint["scaler"])
if 'local_step' in checkpoint.keys():
start_epoch = checkpoint['epoch']
else:
start_epoch = checkpoint["epoch"] + 1
global_step = checkpoint["global_step"]
elif not pretrain: # downstream init from pretrained ckpt
if not config.evaluate or config.get("zero_shot", False): # finetuning from a pretrained weights.
if add_decoder:
logger.info("Init new decoder with encoder!!!")
for key in list(state_dict.keys()):
if "text_encoder.bert" in key:
encoder_key = key.replace("bert.", "")
state_dict[encoder_key] = state_dict[key]
if not add_decoder:
del state_dict[key]
# init text decoder as multimodal encoder (last 6 layers of model.text_encoder)
# only for generation tasks like VQA
if add_decoder and "text_encoder.bert" in key:
if "layer" in key:
encoder_keys = key.split(".")
layer_num = int(encoder_keys[4])
if layer_num < config.model.text_encoder.fusion_layer:
del state_dict[key]
continue
else:
decoder_layer_num = layer_num - config.model.text_encoder.fusion_layer
encoder_keys[4] = str(decoder_layer_num)
encoder_key = ".".join(encoder_keys)
else:
encoder_key = key
decoder_key = encoder_key.replace("text_encoder", "text_decoder")
state_dict[decoder_key] = state_dict[key]
del state_dict[key]
msg = model_without_ddp.load_state_dict(state_dict, strict=False)
logger.info(msg)
logger.info(f"Loaded checkpoint from {config.pretrained_path}")
else:
if not config.resume:
assert not config.evaluate, "No available pretrained checkpoint provided!!!"
assert config.pretrained_path == "", config.pretrained_path
logger.warning("No available pretrained checkpoint provided, training from scratch.")
if hasattr(config, "deepspeed") and config.deepspeed.enable:
logger.info(f'Use deepspeed to initialize model (resume={config.resume}) !!!')
model = model_without_ddp
model, optimizer, _, _ = deepspeed.initialize(
args=config, model=model, model_parameters=optimizer_params, dist_init_required=not config.distributed,
lr_scheduler=lambda opt: create_scheduler(config.scheduler, opt)
)
if config.resume:
logger.info(f'Resume deepspeed ckpt from {config.output_dir}, tag={config.pretrained_path}, load_module_strict={config.get("load_module_strict", True)}, load_lr_scheduler_states={config.get("load_lr_scheduler_states", True)}!!!')
_, client_states = model.load_checkpoint(config.output_dir, tag=config.pretrained_path, load_module_strict=config.get("load_module_strict", True), load_lr_scheduler_states=config.get("load_lr_scheduler_states", True))
logger.info(client_states)
if 'local_step' in client_states.keys():
start_epoch = client_states['epoch']
else:
start_epoch = client_states['epoch'] + 1
global_step = client_states['global_step']
logger.info(f"Cuda memory after create model: {torch.cuda.memory_allocated() // 1024**2}M, Max mem: {torch.cuda.max_memory_allocated() // 1024**2}M start_epoch={start_epoch}, global_step={global_step}")
print(f"\033[31m Cuda memory after create model: {torch.cuda.memory_allocated() // 1024**2}M, Max mem: {torch.cuda.max_memory_allocated() // 1024**2}M start_epoch={start_epoch}, global_step={global_step}\033[0m")
return (
model,
model_without_ddp,
optimizer,
scheduler,
scaler,
tokenizer,
start_epoch,
global_step,
)