pllava-7b-demo / tasks /train /config_pllava_nframe.py
cathyxl
added
f239efc
raw
history blame contribute delete
No virus
3.91 kB
from tasks.train.instruction_data import *
# ========================= data ==========================
# Key of the training corpus; it is substituted into `train_file` below.
# Alternative corpus key: "videochat2_instruction"
train_corpus = "videochat2_instruction_full"
# Kept as an interpolation string so that it is evaluated lazily.
# `available_corpus` is not defined in this file -- presumably it is provided
# by the star import at the top of the file; verify.
train_file = "${available_corpus[${train_corpus}]}"

# No test data is configured by default.
test_file = {}
test_types = []

num_workers = 8
save_steps = 10000
ckpt_steps = 1000
stop_key = None
deepspeed = False
# ========================= input ==========================
# Scalars below are referenced from the nested dicts through "${name}"
# interpolation strings, so each value only has to be changed in one place.
num_frames = 16  # frames sampled per video for training
num_frames_test = 1  # frames sampled per video for testing
batch_size = 1
gradient_accumulation_steps = 16
max_txt_l = 512
max_train_steps = None
pre_text = False
# NOTE: `gradient_checkpointing` is deliberately not assigned in this section.
# The file assigns it further down (next to `fp16`); that assignment is
# evaluated last and was always the effective one, so the `False` previously
# set here was dead and contradicted the value actually used.

inputs = dict(
    image_res=336,
    video_input=dict(
        num_frames="${num_frames}",
        sample_type="rand",
        num_frames_test="${num_frames_test}",
        sample_type_test="middle",
        random_aug=False,
    ),
    # The same text-length limit is applied to image and video samples.
    max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"),
    batch_size=dict(image="${batch_size}", video="${batch_size}"),
    # The test batch size reuses the training `batch_size` value.
    batch_size_test=dict(image="${batch_size}", video="${batch_size}"),
)
# ========================= model ==========================
model = dict(
    repo_id="llava-hf/llava-v1.6-vicuna-7b-hf",
    pretrained_path=None,
    load_from_origin=False,
    origin_vision="",
    origin_llm="",
    vision_encoder=dict(
        # Needed so the dataset knows the mean/std of the pretrained model.
        name="vit_l14",
    ),
    torch_dtype="bfloat16",
    # Freezing: only the projector is left unfrozen here.
    freeze_projector=False,
    freeze_lm=True,
    freeze_vision_tower=True,
    # LoRA settings; these target module names are for llama/mistral/gemma.
    lora_target_modules=["q_proj", "v_proj"],
    use_lora=True,
    lora_r=128,
    lora_alpha=32,
    lora_dropout=0.05,
    num_frames="${num_frames}",
    # Pooling settings.
    # NOTE(review): the leading 16 in `pooling_shape` equals the file's
    # `num_frames` value -- presumably the two must be changed together;
    # confirm against the model code.
    pooling_method="avg",
    use_pooling=True,
    frame_shape=(24, 24),
    pooling_shape=(16, 8, 8),
)
# Settings for turning dataset samples into prompts / model inputs.
preprocess = dict(
    system="",
    mm_alone=True,
    random_shuffle=True,
    add_second_msg=True,
    # Conversation formatting.
    roles=["USER:", "ASSISTANT:"],
    end_signal=(" ", "</s>"),
    begin_signal="",
    # Placeholder strings as they appear in the dataset text.
    dataset_image_placeholder="<Image></Image>",
    dataset_video_placeholder="<Video></Video>",
    # NOTE(review): presumably has to match the image token id of the
    # checkpoint named in `model.repo_id` -- verify when switching models.
    image_token_index=32000,
    max_txt_l="${max_txt_l}",
    # -100 is the default `ignore_index` of torch's cross-entropy loss.
    ignore_index=-100,
    center_pad=False,
    # NOTE(review): 762 is not a multiple of `shortest_edge` (336) whereas
    # 672 would be -- confirm this value is intended.
    longest_edge=762,
    shortest_edge=336,
    clip_transform=False,
    num_frames="${num_frames}",
)
# Optimizer hyper-parameters.
optimizer = dict(
    opt="adamW",
    lr=2e-5,
    opt_betas=[0.9, 0.999],  # default
    weight_decay=0.02,
    # Expects a positive float; -1 turns gradient-norm clipping off.
    max_grad_norm=-1,
    # Optional separate learning rate for selected modules, e.g. a larger
    # lr for newly added modules. Disabled here.
    different_lr=dict(enable=False, module_names=[], lr=1e-3),
)
# Learning-rate schedule.
# Previous setting, kept for reference:
# scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6)
scheduler = dict(
    is_videochat2_custom=False,
    sched="cosine",
    epochs=2,
    warmup_ratio=0.2,
    min_lr_multi=0.25,
)
evaluate = False
deep_fusion = False

evaluation = dict(
    # One of: concat, max, mean, lse.
    eval_frame_ensemble="concat",
    eval_x_only=False,
    k_test=128,
    # Move GPU tensors to the CPU to save memory.
    eval_offload=True,
)

# NOTE(review): the model section of this file sets torch_dtype to
# 'bfloat16' while fp16 is True here -- confirm how the trainer reconciles
# the two.
fp16 = True
# This is the last assignment to `gradient_checkpointing` in the file, so
# it is the value that takes effect.
gradient_checkpointing = True
# ========================= wandb ==========================
wandb = dict(
    enable=False,
    # Username or team name under which runs are stored,
    # see https://docs.wandb.ai/ref/python/init
    entity="user",
    # Set this up in your command line.
    project="videochat2",
)

dist_url = "env://"
device = "cuda"
mode = "it"
# ========================= others ==========================
# Output directory; left unset in this file.
output_dir = None
# When True, optimizer and scheduler states are loaded as well.
resume = False
debug = False
log_freq = 5
# Window size used for metrics.
metric_window_size = 10
seed = 42
report_to = "tensorboard"
save_latest = True
auto_resume = True
# Path to pretrained model weights -- possibly only used for resuming
# (the original author was unsure as well).
pretrained_path = ""