from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoTokenizer

from xtuner.dataset import ConcatDataset
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.dataset.map_fns import template_map_fn_factory

from third_parts.mmdet.models.losses import DiceLoss, CrossEntropyLoss

from peft import LoraConfig

from projects.llava_sam2.models.internvl import InternVL_Slowfast
from projects.llava_sam2.models import VideoLLaVASAMModel, SAM2TrainRunner, VideoLLaVASAMModel_zero3
from projects.llava_sam2.datasets import VideoReVOSDataset, VideoMeVISDataset, VideoRefYoutubeVOSDataset, video_lisa_collate_fn, VideoSAM2Dataset
from projects.llava_sam2.datasets import VideoChatUniViDataset
from projects.llava_sam2.datasets import RefCOCOgGCGDataset, OpenPsgGCGDataset, FlickrGCGDataset, GranDfGCGDataset, OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
from projects.llava_sam2.datasets import LLaVADataset
from projects.llava_sam2.datasets import ReferSegmDataset
from projects.llava_sam2.models.preprocess.image_resize import DirectResize

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
path = './pretrained/InternVL2_5-4B'
pretrained_pth = None

# Data
prompt_template = PROMPT_TEMPLATE.phi3_chat
max_length = 8192

# Scheduler & Optimizer
batch_size = 2  # per device
accumulative_counts = 4
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
# official recipe: global batch size 1024 -> lr 4e-5
# lr = 1e-6
lr = 4e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.05
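# Note: the effective global batch size is per-device batch_size x
# accumulative_counts x world size. A minimal sanity check (the 8-GPU world
# size below is an assumption; substitute your actual launch setup):
#   2 * 4 * 8 = 64 samples per optimizer step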
# Save
save_steps = 1000
save_total_limit = 2  # maximum checkpoints to keep (-1 means unlimited)

special_tokens = ['[SEG]', '<p>', '</p>', '<vp>', '</vp>']

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True,
    padding_side='right')
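# These `dict(type=...)` entries are lazy constructor specs: the runner
# builds them at start-up, roughly equivalent to calling
#   AutoTokenizer.from_pretrained(path, trust_remote_code=True,
#                                 padding_side='right')
# The model is then expected to register `special_tokens` and resize the LLM
# embeddings; a sketch of the usual pattern (not necessarily the exact code
# path in this repo):
#   tokenizer.add_tokens(special_tokens, special_tokens=True)
#   llm.resize_token_embeddings(len(tokenizer))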
extra_image_processor = dict(
    type=DirectResize,
    target_length=1024,
)
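# DirectResize rescales the raw image to a 1024-px target, matching the input
# resolution SAM-style image encoders expect (an assumption from the name
# `target_length`; see the DirectResize implementation for exact semantics).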
#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
model = dict(
    type=VideoLLaVASAMModel_zero3,
    special_tokens=special_tokens,
    frozen_sam2_decoder=False,
    mllm=dict(
        type=InternVL_Slowfast,
        model_path=path,
        freeze_llm=True,
        freeze_visual_encoder=True,
        llm_lora=dict(
            type=LoraConfig,
            r=128,
            lora_alpha=256,
            lora_dropout=0.05,
            bias='none',
            task_type='CAUSAL_LM'),
        special_tokens=special_tokens,
    ),
    tokenizer=tokenizer,
    grounding_encoder=dict(
        type=SAM2TrainRunner,
    ),
    loss_mask=dict(
        type=CrossEntropyLoss,
        use_sigmoid=True,
        reduction='mean',
        loss_weight=2.0),
    loss_dice=dict(
        type=DiceLoss,
        use_sigmoid=True,
        activate=True,
        reduction='mean',
        naive_dice=True,
        eps=1.0,
        loss_weight=0.5),
    pretrained_pth=pretrained_pth,
    loss_sample_points=True,
    # loss_sample_points=False,
    bs=batch_size,
)
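# LoRA sketch: with r=128 and lora_alpha=256, each adapted weight behaves as
#   W' = W + (lora_alpha / r) * B @ A   # scaling factor 256 / 128 = 2
# where A (r x in_features) and B (out_features x r) are the only trainable
# LLM parameters; freeze_llm=True keeps the base weights fixed, and
# freeze_visual_encoder=True does the same for the vision tower.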
#######################################################################
#                   PART 3  Dataset & Dataloader                      #
#######################################################################
VIDEO_DATAS = './data/video_datas/'
IMG_DATAS = './data/image_datas/'

############### video res (referring expression segmentation)
data_root_revos = './data/video_datas/revos/'
video_revos_image_folder = data_root_revos
video_revos_expression_file = data_root_revos + 'meta_expressions_train_.json'
video_revos_mask_file = data_root_revos + 'mask_dict.json'

data_root_mevis = './data/video_datas/mevis/train/'
video_mevis_image_folder = data_root_mevis + 'JPEGImages'
video_mevis_expression_file = data_root_mevis + 'meta_expressions.json'
video_mevis_mask_file = data_root_mevis + 'mask_dict.json'

data_root_refytvos = './data/video_datas/rvos/'
video_refytvos_image_folder = data_root_refytvos + 'train/JPEGImages/'
video_refytvos_expression_file = data_root_refytvos + 'meta_expressions/train/meta_expressions.json'
video_refytvos_mask_file = data_root_refytvos + 'mask_dict.pkl'

video_revos_dataset = dict(
    type=VideoReVOSDataset,
    image_folder=video_revos_image_folder,
    expression_file=video_revos_expression_file,
    mask_file=video_revos_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=10,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)
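# `repeats` oversamples a dataset inside the final concatenation: the
# effective epoch length is roughly len(dataset) * repeats (a reading of the
# intent, assuming the repeats-style wrapper common to these dataset classes).
# ReVOS is repeated 10x; MeViS and Ref-YouTube-VOS below are repeated 4x.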
video_mevis_dataset = dict(
    type=VideoMeVISDataset,
    image_folder=video_mevis_image_folder,
    expression_file=video_mevis_expression_file,
    mask_file=video_mevis_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

video_refytvos_dataset = dict(
    type=VideoRefYoutubeVOSDataset,
    image_folder=video_refytvos_image_folder,
    expression_file=video_refytvos_expression_file,
    mask_file=video_refytvos_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

################### video chat
data_root_video_chatunivi = VIDEO_DATAS + 'video_vlm/video_chat/'
video_chatunivi_image_folder = data_root_video_chatunivi + 'Activity_Videos/'
video_chatunivi_json_file = data_root_video_chatunivi + 'video_chat.json'

video_qa_dataset = dict(
    type=VideoChatUniViDataset,
    image_folder=video_chatunivi_image_folder,
    json_file=video_chatunivi_json_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

################## image chat
llava_vqa_dataset = dict(
    type=LLaVADataset,
    tokenizer=tokenizer,
    data_path='data/llava_data/LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
    prompt_template=prompt_template,
    special_tokens=special_tokens,
    image_folder='data/llava_data/llava_images/',
)

################## image res
refcoco_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcoco',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(unc).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
refcoco_plus_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcoco+',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(unc).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
refcocog_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcocog',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(umd).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
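# Split-file convention: RefCOCO and RefCOCO+ use the UNC splits
# (refs(unc).p) while RefCOCOg uses the UMD split (refs(umd).p), matching the
# standard referring-segmentation benchmark setup.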
# image gcg data
glamm_data_root = './data/glamm_data/'

refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'

grandf_image_path = glamm_data_root + 'images/grandf/train/'
grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'

flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'

psg_image_path = glamm_data_root + 'images/coco2017/'
psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'

glamm_refcocog_dataset = dict(
    type=RefCOCOgGCGDataset,
    image_folder=refcocog_image_path,
    data_path=refcocog_ann_file,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

glamm_grandf_dataset = dict(
    type=GranDfGCGDataset,
    data_path=grandf_ann_file,
    image_folder=grandf_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=10,
)

glamm_psg_dataset = dict(
    type=OpenPsgGCGDataset,
    data_path=psg_ann_file,
    image_folder=psg_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

glamm_flickr_dataset = dict(
    type=FlickrGCGDataset,
    data_path=flickr_ann_file,
    image_folder=flickr_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

# sam2 data
data_sam2_folder = VIDEO_DATAS + 'segmentation_datasets/sam_v_full/'
data_sam2_expression_file = './whole_pesudo_cap_v3/sam_v_final_v3.json'

video_sam2_dataset = dict(
    type=VideoSAM2Dataset,
    sam2_folder=data_sam2_folder,
    expression_file=data_sam2_expression_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
    select_number=5,
)
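# Each SAM-2 sample draws `sampled_frames` frames per clip and up to
# `select_number` referred objects/expressions per video (an assumption from
# the parameter names; see VideoSAM2Dataset for the exact meaning).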
# osprey
data_osprey_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_conversation.json'
data_osprey_image_folders = [
    IMG_DATAS + 'coco/train2014/',
    IMG_DATAS + 'coco/val2014/',
    IMG_DATAS + 'coco/train2017/',
    IMG_DATAS + 'coco/val2017/',
]
image_osprey_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_detail_description_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_detail_description.json'
image_osprey_description_dataset = dict(
    type=OspreyDescriptionDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_detail_description_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_short_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_short_form.json'
image_osprey_short_dataset = dict(
    type=OspreyShortDescriptionDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_short_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_part_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_part_level.json'
image_osprey_part_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_part_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_positive_neg_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_lvis_positive_negative.json'
image_osprey_positive_neg_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_positive_neg_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

train_dataset = dict(
    type=ConcatDataset, datasets=[
        # sem seg
        # semantic_seg_ade20k_dataset,
        # ref seg (listed four times to up-weight image referring segmentation)
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        # image qa
        llava_vqa_dataset,
        # video res
        video_mevis_dataset, video_revos_dataset, video_refytvos_dataset,
        # video chat
        video_qa_dataset,
        # sam2 pseudo-labeled data
        video_sam2_dataset,
        # gcg data
        glamm_psg_dataset,
        glamm_grandf_dataset,
        glamm_flickr_dataset,
        glamm_refcocog_dataset,
        # visual prompt
        image_osprey_dataset, image_osprey_description_dataset,
        image_osprey_part_dataset, image_osprey_short_dataset,
        image_osprey_positive_neg_dataset,
    ]
)
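# ConcatDataset chains the datasets end to end, so listing the three
# ReferSegm datasets four times multiplies their contribution by 4x.
# Length semantics, in brief:
#   len(concat) == sum(len(d) for d in datasets)  # duplicates count again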
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=video_lisa_collate_fn)
)
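# LengthGroupedSampler groups samples with similar `modality_length` so
# batches need little padding; passing per_device_batch_size =
# batch_size * accumulative_counts makes each gradient-accumulation window
# see similarly sized samples (a reading of this setting, not verified
# against the sampler source).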
#######################################################################
#                  PART 4  Scheduler & Optimizer                      #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='bfloat16'
)
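# AmpOptimWrapper autocasts forward/backward to bfloat16 (with the dynamic
# loss scale configured above) and only steps the optimizer every
# `accumulative_counts` iterations. The reference recipe noted in PART 1 used
# a global batch of 1024 with lr 4e-5; this config reuses that lr as-is
# rather than rescaling it to the smaller effective batch.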
# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]
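# Resulting schedule over the single epoch (warmup_ratio=0.05): the first 5%
# of iterations warm the LR linearly from lr * 1e-5 up to lr, then cosine
# annealing decays it to eta_min=0. Piecewise, with T total iterations and t
# the current iteration (a sketch of the standard LinearLR/CosineAnnealingLR
# behavior):
#   lr(t) = lr * (1e-5 + (1 - 1e-5) * t / (0.05 * T))           for t <= 0.05*T
#   lr(t) = lr * 0.5 * (1 + cos(pi * (t - 0.05*T) / (0.95*T)))  otherwise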
# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Custom hooks, e.g. logging sample dialogues periodically during training;
# optional and disabled here.
custom_hooks = [
    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print a log entry every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save a checkpoint every `save_steps` iterations.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set the sampler seed in a distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)
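# With by_epoch=False the CheckpointHook counts iterations: a checkpoint is
# written every 1000 iterations and only the two most recent are kept.
# save_optimizer=False drops optimizer state, giving smaller files at the
# cost of exact optimizer-state resumption.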
# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi-process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# defaults to a random seed with `deterministic` disabled
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
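# A typical launch for a config like this (an assumption based on the XTuner
# CLI and the ZeRO-3 model type above; the config filename is illustrative):
#   NPROC_PER_NODE=8 xtuner train projects/llava_sam2/configs/<this_config>.py \
#       --deepspeed deepspeed_zero3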