Reself committed
Commit 0f2dd1c
Parent: 536f09b

Upload folder using huggingface_hub

llm/pytorch_model-00001-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f7f508240dba71e8272ef2de0331db16b0aad7b90b6de3a414f5251d2d514967
+oid sha256:f08ed91f9640a98601c3605499a586c3410914e611cf0ee85cc1f75d7061add2
 size 1947779738
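All of the .bin and .safetensors entries in this commit are Git LFS pointer files, so each diff only touches the three pointer lines (spec version, sha256 oid, byte size); a changed oid means the shard's contents were re-uploaded even when the size stays identical. A minimal sketch, assuming the shard has already been downloaded locally, for checking a file against the oid recorded in its pointer:

```python
# Minimal sketch (not part of the repo): verify a downloaded LFS shard
# against the sha256 oid recorded in its pointer file.
import hashlib

def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Expected oid taken from the new pointer above; local path is an assumption.
expected = "f08ed91f9640a98601c3605499a586c3410914e611cf0ee85cc1f75d7061add2"
assert sha256_of_file("llm/pytorch_model-00001-of-00014.bin") == expected
```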
llm/pytorch_model-00002-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:098d87c2b14c05b788077baf97bf0c6d85873546ec09f94e29714802133d3aba
+oid sha256:ad177ce1be7bc0f3d41f56c762e7d2d246b4dd32ab8cf10a095fdf9e81e14b89
 size 1903236688
llm/pytorch_model-00003-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:aed829f676a1c69618e42023a0d543f757076b507a51406a14b3b6a0c08e0738
+oid sha256:cc698a0828a2ca8d5c3783c938e9a78442db3d2bf649af0915c4bc914bc8c215
 size 1903236688
llm/pytorch_model-00004-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:af90eb90d789cd5897aa9d6567a5fa4b6417ffc8d6d467a0c73df2d5bfa9fa79
+oid sha256:d3beb2415facca8bb85c0bcb45f94dda38afdc72bc4f8ea0f1ffeec729618e32
 size 1903236688
llm/pytorch_model-00005-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:93c201b2cb75f6271c87e9dcae83dcf654c46fbd108698d01930174e173db4e6
+oid sha256:d107db54fe677de4296079bbcdf7e9b7a255cf2849d908e344bfe03a6dfcdd6a
 size 1903236752
llm/pytorch_model-00006-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8caa1e97eba055aff5bef97344f231bac6f3c1af7d59b5cd892c9559d897f857
+oid sha256:e273b5053a19ec8794f324272b19f659ad171a13df4ccbac4d69e80380c2a68e
 size 1903236752
llm/pytorch_model-00007-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9968ee2634d9e1be745ad3094804938d73f52759babaa0fecba92dd6e429d7ea
+oid sha256:bd1fb9b44a08ce42901f385d774d7139eedd3df5cda692b6fd14a174d51e09e6
 size 1903236752
llm/pytorch_model-00008-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:20d5aff66a5250c276b640b235ac1127d57ce0f99fae22f67a8ab3fca0b48b18
+oid sha256:410b638a1d006ebc402a8b39f8b8a6befab94bfc9a015762659b97d7d9cc0f2d
 size 1903236752
llm/pytorch_model-00009-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:70f37c7c0355b7f8946880f5a4d00827a9590f1cce334455ab84dcacc454c519
+oid sha256:0372d10309ef62b762a7f983e011162d897bb54fd4846836483befe04c1741f6
 size 1903236752
llm/pytorch_model-00010-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:10263095ba406a983567fbeeb9c9eef9be441d71d231d230cbd6cd0544f9e7b9
+oid sha256:1348d915b61f874177bf4bc3c419d3e9790bf12d5cad827a8970cc2aea58c768
 size 1903236752
llm/pytorch_model-00011-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:340a1bf3dc346cf46118cd9e3d4323bcf43ba36a4223d4bb2a54706c0c799795
+oid sha256:150ec73a71f07471e2a2a2c7165febca406c81400e5ed8e9801c9241e8d0bda1
 size 1903236752
llm/pytorch_model-00012-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a8936e2cd54590913d572390a199e24ef8e2545182848cf4488b1ec072fb4b62
+oid sha256:2e853bbc594dd3bfd4b7df0395719876478a7a86efa14e3f0e6ee8c4fa37db50
 size 1903236752
llm/pytorch_model-00013-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7bf405e29337f52f4fb5858db29aa1a5e4b9c518bdda14f693ccfa7f4ba16b1b
+oid sha256:216f2f90e253a2b172b183b5ed9104647a7f5bdd68598caba086e6617c1c4483
 size 1903236752
llm/pytorch_model-00014-of-00014.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:53aa958d86ba409b763bfec650e727e83b700bcc521e6fa937ff8ada5c1e2224
+oid sha256:9dd3358d15dac64ee9bac9ee571605e5b26b4ce891ed0aaf5e43b14e3d8d68a9
 size 1245241080
llm_adapter/adapter_config.json CHANGED
@@ -19,13 +19,13 @@
   "rank_pattern": {},
   "revision": null,
   "target_modules": [
+    "gate_proj",
     "v_proj",
-    "up_proj",
+    "q_proj",
     "o_proj",
     "k_proj",
-    "q_proj",
-    "down_proj",
-    "gate_proj"
+    "up_proj",
+    "down_proj"
   ],
   "task_type": "CAUSAL_LM"
 }
llm_adapter/adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a17317836e0db05dfee44bad2c4f2890207421c2b7151a695afd077ed05cc567
-size 4005637552
+oid sha256:6ce89af9bf87941d625643d64c8e807fcdd16f00149f458e82b695732b667d64
+size 8011198024
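In the llm_adapter diff the target_modules list is only reordered (the same seven projection layers), while adapter_model.safetensors roughly doubles in size, so the LoRA weights themselves were re-exported. A hedged sketch, not code from this repo, of attaching the adapter to the base LLM named in xtuner_config.py (local paths and dtype are assumptions):

```python
# Hedged sketch: attach the LoRA weights in llm_adapter/ to the base LLM
# named in xtuner_config.py. Paths and dtype are assumptions, not repo code.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-13b-v1.5-16k", torch_dtype=torch.float16)
llm = PeftModel.from_pretrained(base, "llm_adapter")  # reads adapter_config.json
llm = llm.merge_and_unload()  # optionally fold the LoRA deltas into the base weights
```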
projector/config.json CHANGED
@@ -12,6 +12,6 @@
   "llm_hidden_size": 5120,
   "model_type": "projector",
   "torch_dtype": "float32",
-  "transformers_version": "4.37.2",
+  "transformers_version": "4.36.0",
   "visual_hidden_size": 1280
 }
projector/configuration_projector.py CHANGED
@@ -3,15 +3,15 @@ from transformers import PretrainedConfig
 
 
 class ProjectorConfig(PretrainedConfig):
-    model_type = "projector"
-    _auto_class = "AutoConfig"
+    model_type = 'projector'
+    _auto_class = 'AutoConfig'
 
     def __init__(
         self,
         visual_hidden_size=4096,
         llm_hidden_size=4096,
         depth=2,
-        hidden_act="gelu",
+        hidden_act='gelu',
         bias=True,
         **kwargs,
     ):
projector/model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0c2126a4ee2ac250fe6f7e67c5eb00167494e852dee344949e0f43b1c4dfe7b2
+oid sha256:53a362f3b6e223ec295c65441f7af338c113046c0367bd6111eab5b9c6d7c668
 size 131113328
projector/modeling_projector.py CHANGED
@@ -8,22 +8,32 @@ from .configuration_projector import ProjectorConfig
 
 
 class ProjectorModel(PreTrainedModel):
-    _auto_class = "AutoModel"
+    _auto_class = 'AutoModel'
     config_class = ProjectorConfig
-    base_model_prefix = "model"
+    base_model_prefix = 'model'
     supports_gradient_checkpointing = True
 
     def __init__(self, config: ProjectorConfig) -> None:
         super().__init__(config)
         self.gradient_checkpointing = False
 
-        modules = [nn.Linear(config.visual_hidden_size, config.llm_hidden_size, bias=config.bias)]
+        modules = [
+            nn.Linear(
+                config.visual_hidden_size,
+                config.llm_hidden_size,
+                bias=config.bias)
+        ]
         for _ in range(1, config.depth):
             modules.append(ACT2FN[config.hidden_act])
-            modules.append(nn.Linear(config.llm_hidden_size, config.llm_hidden_size, bias=config.bias))
+            modules.append(
+                nn.Linear(
+                    config.llm_hidden_size,
+                    config.llm_hidden_size,
+                    bias=config.bias))
         self.model = nn.Sequential(*modules)
 
     def enable_input_require_grads(self):
+
         def make_inputs_require_grad(module, input, output):
             output.requires_grad_(True)
 
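The projector changes above are purely stylistic (quote style and line wrapping); the module it builds is unchanged: a Linear layer from visual_hidden_size to llm_hidden_size, then depth - 1 blocks of activation plus Linear, matching the 1280 to 5120 mapping in projector/config.json. A rough usage sketch, assuming the projector/ folder is available locally and its config.json carries the auto_map entries implied by the _auto_class attributes above:

```python
# Rough sketch, assuming projector/ is a local folder whose config.json registers
# ProjectorConfig/ProjectorModel for AutoConfig/AutoModel via auto_map.
import torch
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("projector", trust_remote_code=True)
projector = AutoModel.from_pretrained("projector", trust_remote_code=True)

# Call the inner nn.Sequential built in __init__: Linear(1280 -> 5120),
# then (depth - 1) x [activation, Linear(5120 -> 5120)].
visual_tokens = torch.randn(1, 257, config.visual_hidden_size)  # dummy ViT features
llm_tokens = projector.model(visual_tokens)  # -> (1, 257, config.llm_hidden_size)
```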
visual_encoder_adapter/adapter_config.json CHANGED
@@ -1,8 +1,8 @@
 {
   "alpha_pattern": {},
   "auto_mapping": {
-    "base_model_class": "PikaVidEncoder",
-    "parent_library": "xtuner.model.video_encoder"
+    "base_model_class": "CLIPVisionModel",
+    "parent_library": "transformers.models.clip.modeling_clip"
   },
   "base_model_name_or_path": "apple/DFN5B-CLIP-ViT-H-14-378",
   "bias": "none",
@@ -22,12 +22,12 @@
   "rank_pattern": {},
   "revision": null,
   "target_modules": [
-    "out_proj",
+    "fc1",
     "v_proj",
+    "q_proj",
     "k_proj",
-    "fc1",
     "fc2",
-    "q_proj"
+    "out_proj"
   ],
   "task_type": null
 }
visual_encoder_adapter/adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0915794d1b139877e627831e1d36d2ef378872141a4c3712c220d00f6e326a74
+oid sha256:43d1f05a283041e0ba93f16f0af1c20b3c3302f2a5305d3dba36aad45671f306
 size 188800496
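With base_model_class switched from PikaVidEncoder to a plain CLIPVisionModel, the vision adapter can be attached with stock peft APIs, mirroring the visual_encoder entry in the xtuner_config.py diff below. A minimal sketch, assuming the visual_encoder_adapter/ folder is available locally:

```python
# Minimal sketch (local adapter path assumed): load the CLIP vision tower and
# wrap it with the LoRA adapter whose config now points at CLIPVisionModel.
from peft import PeftModel
from transformers import CLIPVisionModel

vision_tower = CLIPVisionModel.from_pretrained("apple/DFN5B-CLIP-ViT-H-14-378")
vision_tower = PeftModel.from_pretrained(vision_tower, "visual_encoder_adapter")
```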
vit/pytorch_model-00001-of-00002.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4611f601856556b5a54e0631251c9279a2610b42825a51519430efc37d1a64ce
+oid sha256:3b52fb32d49c7403927a82093851cb3efbd2deb98885356c8d068b6843cc5a10
 size 1994332295
vit/pytorch_model-00002-of-00002.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d0a9038c66b6348e1b86889c0a96bcb942b95e53c950ec85ca51345cd9318e51
+oid sha256:85a66172a48b3dfd33504ba0a2e4ee44f5f063644a262c7e64d8e0ed4cb17b58
 size 531341514
xtuner_config.py CHANGED
@@ -7,47 +7,44 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           BitsAndBytesConfig,
                           CLIPImageProcessor, CLIPVisionModel)
 from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+
 from peft import LoraConfig
-from math import sqrt
 from torch.optim import AdamW
-from xtuner.dataset import VideoDataset, PikaDataset, ConcatDataset, ShareGPTVideoDataset
+from xtuner.dataset import PikaDataset, ConcatDataset
 from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_video_map_fn, llava_map_fn, pika_map_fn, template_map_fn_factory
+from xtuner.dataset.map_fns import llava_map_fn, m3it_map_fn, template_map_fn_factory
 from xtuner.dataset.samplers import LengthGroupedSampler
 from xtuner.engine import DatasetInfoHook, EvaluateChatHook
 from xtuner.model import PikaModel, PikaVidEncoder
 from xtuner.utils import PROMPT_TEMPLATE
 
 
 #######################################################################
 # PART 1 Settings #
 #######################################################################
 # Model
 llm_name_or_path = 'lmsys/vicuna-13b-v1.5-16k'
 visual_encoder_name_or_path = 'apple/DFN5B-CLIP-ViT-H-14-378'
-# Specify the s3 pretrained pth
-pretrained_pth = 'work_dirs/13b_16k_s5/iter_400.pth'
-prompt_template = PROMPT_TEMPLATE.vicuna
+# Specify the s2 pretrained pth
+pretrained_pth = 'work_dirs/13b_16k_s2/epoch_1.pth'
 
+prompt_template = PROMPT_TEMPLATE.vicuna
+max_length = 4096
 size = 378
-# None for sampling all the video frames
-n_sample_frames = 32
-visual_token_merge_ratio = 0.1
-accumulative_counts = 32
-lr = 1e-4
-batch_size = 1 # per_device can only be set to 1 to support image and video mix training
 
-max_length = 4096
+batch_size = 16 # per_device
+accumulative_counts = 1
+lr = 2e-4
 dataloader_num_workers = 0
 max_epochs = 1
 optim_type = AdamW
 betas = (0.9, 0.999)
-weight_decay = 0.1
+weight_decay = 0
 max_norm = 1 # grad clip
 warmup_ratio = 0.03
 
 # Save
-save_steps = 500
+save_steps = 100
 save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
 
 #######################################################################
@@ -93,21 +90,21 @@ model = dict(
         bias='none',
         task_type='CAUSAL_LM'),
     visual_encoder=dict(
-        # type=CLIPVisionModel.from_pretrained,
-        type=PikaVidEncoder.from_pretrained,
-        pretrained_model_name_or_path=visual_encoder_name_or_path,
-        visual_token_merge_ratio=visual_token_merge_ratio),
+        type=CLIPVisionModel.from_pretrained,
+        pretrained_model_name_or_path=visual_encoder_name_or_path),
     visual_encoder_lora=dict(
-        type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none'),
-)
-
+        type=LoraConfig,
+        r=64,
+        lora_alpha=16,
+        lora_dropout=0.05,
+        bias='none'))
 
 #######################################################################
 # PART 3 Dataset & Dataloader #
 #######################################################################
-allava_image_caption_dataset = dict(
+llava_dataset = dict(
     type=PikaDataset,
-    data_path='./data/image_finetune/ALLaVA-Caption-LAION-4V',
+    data_path='./data/image_finetune/llava_v1_5_mix665k',
     image_folder='./data/image_data',
     tokenizer=tokenizer,
     image_processor=image_processor,
@@ -115,40 +112,35 @@ allava_image_caption_dataset = dict(
     template_map_fn=dict(
         type=template_map_fn_factory, template=prompt_template),
     max_length=max_length,
-    pad_image_to_square=False,
-    keep_aspect_ratio=True,)
+    pad_image_to_square=True)
 
-sharegpt4v_video_caption_dataset = dict(
-    type=ShareGPTVideoDataset,
-    data_path='./data/video_finetune/sharegptvideo_caption_full_frame',
-    image_folder='./data/video_data/sharegptvideo_900k',
+train_dataset = dict(
+    type=PikaDataset,
+    data_path='./data/stage_3_part2',
+    image_folder='./data/image_data',
     tokenizer=tokenizer,
    image_processor=image_processor,
-    dataset_map_fn=llava_video_map_fn,
+    dataset_map_fn=llava_map_fn,
     template_map_fn=dict(
         type=template_map_fn_factory, template=prompt_template),
     max_length=max_length,
-    pad_image_to_square=False,
-    frame_number=n_sample_frames,
-    keep_aspect_ratio=True,)
+    pad_image_to_square=True)
 
-# mix video and image
 train_dataset = dict(
     type=ConcatDataset,
     datasets=[
-        allava_image_caption_dataset,
-        sharegpt4v_video_caption_dataset,
-    ])
+        llava_dataset,
+        train_dataset])
 
 train_dataloader = dict(
     batch_size=batch_size,
     num_workers=dataloader_num_workers,
     dataset=train_dataset,
+    sampler=dict(type=DefaultSampler, shuffle=True),
     # sampler=dict(
     #     type=LengthGroupedSampler,
     #     length_property='modality_length',
     #     per_device_batch_size=batch_size * accumulative_counts),
-    sampler=dict(type=DefaultSampler, shuffle=True),
     collate_fn=dict(type=default_collate_fn))
 
 #######################################################################
@@ -190,7 +182,7 @@ train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
 # PART 5 Runtime #
 #######################################################################
 # Evaluate the generation performance during the training
-evaluation_freq = 500
+evaluation_freq = 100
 SYSTEM = ''
 evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
 evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
@@ -219,12 +211,7 @@ default_hooks = dict(
     # enable the parameter scheduler.
     param_scheduler=dict(type=ParamSchedulerHook),
     # save checkpoint per epoch.
-    # checkpoint=dict(type=CheckpointHook, interval=1),
-    checkpoint=dict(
-        type=CheckpointHook,
-        by_epoch=False,
-        interval=save_steps,
-        max_keep_ckpts=save_total_limit),
+    checkpoint=dict(type=CheckpointHook, interval=1),
     # set sampler seed in distributed evrionment.
     sampler_seed=dict(type=DistSamplerSeedHook),
 )