|
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig |
|
|
|
class EasyDict(dict):
    """Dictionary whose string keys can also be read and written as attributes.

    Every assignment, whether ``obj.key = v`` or ``obj["key"] = v``, stores the
    value both as an instance attribute and as a dict item.  Nested ``dict``
    values are converted to instances of the same class; ``list``/``tuple``
    values are stored as a ``list`` whose ``dict`` elements are converted too.
    """

    def __init__(self, d=None, **kwargs):
        """Populate the mapping from ``d`` and ``kwargs``.

        Args:
            d: Optional initial entries (a mapping or anything ``dict()``
                accepts).  It is copied, never modified.
            **kwargs: Extra entries; they override same-named keys of ``d``.
        """
        # Work on a copy so merging kwargs cannot leak into the caller's dict.
        entries = {} if d is None else dict(d)
        entries.update(kwargs)
        for key, value in entries.items():
            setattr(self, key, value)

        # Mirror class-level attributes (other than dunders and the overridden
        # dict methods) onto the instance so they also show up as items.
        for key in self.__class__.__dict__:
            is_dunder = key.startswith("__") and key.endswith("__")
            if not is_dunder and key not in ("update", "pop"):
                setattr(self, key, getattr(self, key))

    def __setattr__(self, name, value):
        """Store ``value`` under ``name`` as both an attribute and an item."""
        if isinstance(value, (list, tuple)):
            # Sequences are normalised to a list; dict elements get wrapped.
            value = [
                self.__class__(x) if isinstance(x, dict) else x
                for x in value
            ]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super().__setattr__(name, value)
        super().__setitem__(name, value)

    # Item assignment goes through the same conversion and dual storage.
    __setitem__ = __setattr__

    def update(self, e=None, **f):
        """Merge entries from ``e`` and ``f`` into this mapping.

        Args:
            e: Optional mapping (or anything ``dict()`` accepts) of entries.
                It is copied, never modified; a falsy value means no entries.
            **f: Extra entries; they override same-named keys of ``e``.
        """
        merged = dict(e) if e else {}
        merged.update(f)
        for key, value in merged.items():
            setattr(self, key, value)

    def pop(self, k, d=None):
        """Remove key ``k`` and return its value, or ``d`` if it is absent.

        Unlike ``dict.pop``, a missing key never raises: ``d`` defaults to
        ``None`` and is returned instead.
        """
        # Only delete attributes that live on the instance.  ``hasattr`` would
        # also match inherited dict methods such as ``keys`` and make
        # ``delattr`` raise AttributeError for them.
        if k in self.__dict__:
            delattr(self, k)
        return super().pop(k, d)
|
|
|
class InternVideo2Config(PretrainedConfig):
    """``PretrainedConfig`` subclass registered under model type ``internvideo2``.

    Scalar arguments are stored on the instance unchanged.  Structured
    arguments (``test_file``, ``inputs``, ``model``, ``criterion``,
    ``optimizer``, ``scheduler``, ``evaluation``, ``wandb``, ``deepspeed``)
    are wrapped in ``EasyDict``.  Whenever one of the optional arguments is
    falsy (``None``, empty dict, empty list, empty string) the built-in
    default shown in ``__init__`` is used instead.
    """

    model_type = "internvideo2"

    def __init__(self,
                 tokenizer=None,
                 train_file=None,
                 test_file=None,
                 test_types=None,
                 num_workers=6,
                 best_key=None,
                 num_frames=8,
                 num_frames_test=8,
                 batch_size=64,
                 batch_size_test=4,
                 max_txt_l=32,
                 inputs=None,
                 text_enc="bert_large",
                 model=None,
                 criterion=None,
                 optimizer=None,
                 scheduler=None,
                 evaluate=False,
                 deep_fusion=False,
                 evaluation=None,
                 use_half_precision=False,
                 use_bf16=True,
                 gradient_checkpointing=True,
                 use_flash_sdp=False,
                 use_mem_efficient_sdp=False,
                 compile_model=False,
                 wandb=None,
                 dist_url="env://",
                 device="cuda",
                 mode="pt",
                 output_dir=None,
                 resume=False,
                 debug=False,
                 log_freq=100,
                 seed=42,
                 save_latest=True,
                 auto_resume=False,
                 jump_evaluate=False,
                 pretrained_path="",
                 save_ckpt_iter=None,
                 delete_ds_optim_states=True,
                 deepspeed=None,
                 **kwargs):
        """Store every argument on the config, filling in defaults as needed.

        The default ``inputs`` and ``model`` sections copy ``num_frames``,
        ``num_frames_test``, ``max_txt_l``, ``batch_size``,
        ``batch_size_test`` and ``text_enc`` at construction time; a
        caller-supplied ``inputs`` or ``model`` is used as given.  Remaining
        ``**kwargs`` are forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        # ------------------------------------------------------------------
        # Step 1: replace every falsy optional argument with its default.
        # ------------------------------------------------------------------
        if not train_file:
            train_file = "available_corpus[\"pretrain_example_data_1B\"]"

        if not test_file:
            test_file = {
                "msrvtt_1k_test": "available_corpus[\"msrvtt_1k_test\"]",
                "didemo_ret_test": "available_corpus[\"didemo_ret_test\"]",
            }

        if not test_types:
            test_types = ["msrvtt_1k_test", "didemo_ret_test"]

        if not best_key:
            best_key = ["msrvtt_1k_test_match", "t2v_r1"]

        if not inputs:
            default_video_input = EasyDict({
                "num_frames": num_frames,
                "sample_type": "rand",
                "num_frames_test": num_frames_test,
                "sample_type_test": "middle",
                "random_aug": False,
            })
            inputs = {
                "image_res": 224,
                "video_input": default_video_input,
                "max_txt_l": EasyDict(
                    {"image": max_txt_l, "video": max_txt_l}
                ),
                "batch_size": EasyDict(
                    {"image": batch_size, "video": batch_size}
                ),
                "batch_size_test": EasyDict(
                    {"image": batch_size_test, "video": batch_size_test}
                ),
            }

        if not model:
            default_vision_encoder = EasyDict({
                "name": "pretrain_internvideo2_1b_patch14_224",
                "img_size": 224,
                "num_frames": num_frames,
                "tubelet_size": 1,
                "patch_size": 14,
                "d_model": 1408,
                "clip_embed_dim": 768,
                "clip_teacher_embed_dim": 3200,
                "clip_teacher_final_dim": 768,
                "clip_norm_type": "l2",
                "clip_return_layer": 6,
                "clip_student_return_interval": 1,
                "pretrained": None,
                "use_checkpoint": False,
                "checkpoint_num": 40,
                "use_flash_attn": True,
                "use_fused_rmsnorm": True,
                "use_fused_mlp": True,
                "clip_teacher": None,
                "clip_input_resolution": 224,
                "clip_teacher_return_interval": 1,
                "video_mask_type": "random",
                "video_mask_ratio": 0.8,
                "image_mask_type": "random",
                "image_mask_ratio": 0.5,
                "sep_image_video_pos_embed": True,
                "keep_temporal": False,
                "only_mask": True,
            })
            model = {
                "model_cls": "InternVideo2_Stage2",
                "vision_encoder": default_vision_encoder,
                "text_encoder": text_enc,
                "multimodal": EasyDict({"enable": True}),
                "embed_dim": 512,
                "temp": 0.07,
                "find_unused_parameters": False,
            }

        if not criterion:
            default_loss_weight = EasyDict({
                "vtc": 1.0,
                "mlm": 1.0,
                "vtm": 1.0,
                "mvm": 0.0,
                "uta": 0.0,
            })
            criterion = {
                "loss_weight": default_loss_weight,
                "vtm_hard_neg": True,
                "mlm_masking_prob": 0.5,
                "distill_final_features": True,
                "clip_loss_ratio": [1.0, 1.0],
            }

        if not optimizer:
            default_different_lr = EasyDict(
                {"enable": False, "module_names": [], "lr": 1e-3}
            )
            optimizer = {
                "opt": "adamW",
                "lr": 5e-5,
                "opt_betas": [0.9, 0.98],
                "weight_decay": 0.05,
                "max_grad_norm": 3.0,
                "different_lr": default_different_lr,
            }

        if not scheduler:
            scheduler = {
                "sched": "cosine",
                "epochs": 10,
                "min_lr_multi": 0.01,
                "warmup_epochs": 1,
            }

        if not evaluation:
            evaluation = {
                "eval_frame_ensemble": "concat",
                "eval_x_only": False,
                "k_test": 128,
                "eval_offload": True,
            }

        if not wandb:
            wandb = {
                "enable": False,
                "entity": "opengvlab",
                "project": "InternVideo2-Stage2",
            }

        if not deepspeed:
            deepspeed = {
                "enable": True,
                "stage": 1,
            }

        # ------------------------------------------------------------------
        # Step 2: publish everything as attributes.  The assignment order is
        # kept stable so the instance ``__dict__`` ordering does not change.
        # ------------------------------------------------------------------
        self.tokenizer = tokenizer

        # Data files and loading.
        self.train_file = train_file
        self.test_file = EasyDict(test_file)
        self.test_types = test_types
        self.num_workers = num_workers
        self.best_key = best_key

        # Input shapes and batching.
        self.num_frames = num_frames
        self.num_frames_test = num_frames_test
        self.batch_size = batch_size
        self.batch_size_test = batch_size_test
        self.max_txt_l = max_txt_l
        self.inputs = EasyDict(inputs)

        # Model definition.
        self.text_enc = text_enc
        self.model = EasyDict(model)

        # Training objective and optimisation.
        self.criterion = EasyDict(criterion)
        self.optimizer = EasyDict(optimizer)
        self.scheduler = EasyDict(scheduler)

        # Evaluation.
        self.evaluate = evaluate
        self.deep_fusion = deep_fusion
        self.evaluation = EasyDict(evaluation)

        # Precision / runtime switches.
        self.use_half_precision = use_half_precision
        self.use_bf16 = use_bf16
        self.gradient_checkpointing = gradient_checkpointing
        self.use_flash_sdp = use_flash_sdp
        self.use_mem_efficient_sdp = use_mem_efficient_sdp
        self.compile_model = compile_model

        # Experiment tracking.
        self.wandb = EasyDict(wandb)

        # Run environment.
        self.dist_url = dist_url
        self.device = device
        self.mode = mode
        self.output_dir = output_dir
        self.resume = resume
        self.debug = debug
        self.log_freq = log_freq
        self.seed = seed

        # Checkpointing.
        self.save_latest = save_latest
        self.auto_resume = auto_resume
        self.jump_evaluate = jump_evaluate
        self.pretrained_path = pretrained_path
        self.save_ckpt_iter = save_ckpt_iter
        self.delete_ds_optim_states = delete_ds_optim_states

        self.deepspeed = EasyDict(deepspeed)

    def set_num_frames(self, num_frames):
        """Set ``num_frames`` on the config and on the two nested copies.

        Updates ``self.num_frames``, ``self.inputs.video_input.num_frames``
        and ``self.model.vision_encoder.num_frames``, in that order.
        """
        self.num_frames = num_frames

        video_input = self.inputs.video_input
        video_input.num_frames = num_frames

        vision_encoder = self.model.vision_encoder
        vision_encoder.num_frames = num_frames