Add files using upload-large-folder tool
Browse files- ChatUniVi/eval/questions/video_qa/msvd_qa.json +0 -0
- ChatUniVi/eval/questions/video_qa/temporal_qa.json +0 -0
- ChatUniVi/eval/questions/video_qa/tgif_a_list.json +1309 -0
- ChatUniVi/eval/questions/video_qa/tgif_qa.json +0 -0
- ChatUniVi/eval/table/caps_boxes_coco2014_val_80.jsonl +80 -0
- ChatUniVi/eval/table/model.jsonl +5 -0
- ChatUniVi/eval/table/question.jsonl +80 -0
- ChatUniVi/eval/table/reviewer.jsonl +4 -0
- ChatUniVi/eval/table/rule.json +11 -0
- ChatUniVi/model/__init__.py +1 -0
- ChatUniVi/model/apply_delta.py +44 -0
- ChatUniVi/model/arch.py +652 -0
- ChatUniVi/model/builder.py +118 -0
- ChatUniVi/model/cluster.py +287 -0
- ChatUniVi/model/consolidate.py +29 -0
- ChatUniVi/model/dataloader.py +67 -0
- ChatUniVi/model/language_model/language_model/configuration_phi.py +62 -0
- ChatUniVi/model/language_model/language_model/modeling_phi.py +984 -0
- ChatUniVi/model/language_model/llama.py +136 -0
- ChatUniVi/model/language_model/phi.py +142 -0
- ChatUniVi/model/make_delta.py +52 -0
- ChatUniVi/model/multimodal_encoder/builder.py +14 -0
- ChatUniVi/model/multimodal_encoder/clip_encoder.py +83 -0
- ChatUniVi/model/multimodal_encoder/eva_encoder.py +81 -0
- ChatUniVi/model/multimodal_encoder/eva_vit.py +448 -0
- ChatUniVi/model/multimodal_encoder/processor.py +68 -0
- ChatUniVi/model/multimodal_encoder/utils.py +137 -0
- ChatUniVi/model/multimodal_projector/builder.py +52 -0
- ChatUniVi/train/llama_flash_attn_monkey_patch.py +124 -0
- ChatUniVi/train/train.py +1232 -0
- ChatUniVi/train/train_mem.py +13 -0
- ChatUniVi/train/trainer.py +53 -0
- configs/__init__.py +1 -0
- configs/config.py +84 -0
- data/metadata.csv +0 -0
ChatUniVi/eval/questions/video_qa/msvd_qa.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ChatUniVi/eval/questions/video_qa/temporal_qa.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ChatUniVi/eval/questions/video_qa/tgif_a_list.json
ADDED
|
@@ -0,0 +1,1309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"cookie",
|
| 3 |
+
"? machine",
|
| 4 |
+
"two",
|
| 5 |
+
"glasses",
|
| 6 |
+
"black",
|
| 7 |
+
"tail",
|
| 8 |
+
"red",
|
| 9 |
+
"flowers",
|
| 10 |
+
"laptop",
|
| 11 |
+
"three",
|
| 12 |
+
"white",
|
| 13 |
+
"green",
|
| 14 |
+
"? boat",
|
| 15 |
+
"blue",
|
| 16 |
+
"? room",
|
| 17 |
+
"brown",
|
| 18 |
+
"cat",
|
| 19 |
+
"picture",
|
| 20 |
+
"drink",
|
| 21 |
+
"cigarette",
|
| 22 |
+
"clock",
|
| 23 |
+
"car",
|
| 24 |
+
"monkey",
|
| 25 |
+
"guitar",
|
| 26 |
+
"purple",
|
| 27 |
+
"? kitchen",
|
| 28 |
+
"? mirror",
|
| 29 |
+
"meal",
|
| 30 |
+
"four",
|
| 31 |
+
"? tank",
|
| 32 |
+
"? classroom",
|
| 33 |
+
"dog",
|
| 34 |
+
"pipe",
|
| 35 |
+
"leaf",
|
| 36 |
+
"shirt",
|
| 37 |
+
"champagne",
|
| 38 |
+
"string",
|
| 39 |
+
"sweater",
|
| 40 |
+
"? studio",
|
| 41 |
+
"tortoise",
|
| 42 |
+
"and one of them is holding ? dog",
|
| 43 |
+
"rings",
|
| 44 |
+
"vehicles",
|
| 45 |
+
"lollipop",
|
| 46 |
+
"candy",
|
| 47 |
+
"bottle",
|
| 48 |
+
"then a man is shown sitting . ? locker",
|
| 49 |
+
"parakeets",
|
| 50 |
+
"hole",
|
| 51 |
+
"tie",
|
| 52 |
+
"boat",
|
| 53 |
+
"ball",
|
| 54 |
+
"cash",
|
| 55 |
+
"chicken",
|
| 56 |
+
"? street",
|
| 57 |
+
"bird",
|
| 58 |
+
"six",
|
| 59 |
+
"? pool",
|
| 60 |
+
"window",
|
| 61 |
+
"round",
|
| 62 |
+
"instrument",
|
| 63 |
+
"puppy",
|
| 64 |
+
"doorway",
|
| 65 |
+
"juice",
|
| 66 |
+
"flamethrower",
|
| 67 |
+
"gray",
|
| 68 |
+
"dress",
|
| 69 |
+
"hat",
|
| 70 |
+
"kitten",
|
| 71 |
+
"gun",
|
| 72 |
+
"cars",
|
| 73 |
+
"paws",
|
| 74 |
+
"elephant",
|
| 75 |
+
"beam",
|
| 76 |
+
"? chair",
|
| 77 |
+
"chimp",
|
| 78 |
+
"one",
|
| 79 |
+
"butt",
|
| 80 |
+
"mascara",
|
| 81 |
+
"dogs",
|
| 82 |
+
"puppet",
|
| 83 |
+
"hamster",
|
| 84 |
+
"? bedroom",
|
| 85 |
+
"who pretends to slap him in return ? crack",
|
| 86 |
+
"machine",
|
| 87 |
+
"drops",
|
| 88 |
+
"then he removes and throws it to the ground ? hat",
|
| 89 |
+
"when two of the cyclist crash ? bicycles",
|
| 90 |
+
"cannabis",
|
| 91 |
+
"? trap",
|
| 92 |
+
"helmet",
|
| 93 |
+
"motorcycle",
|
| 94 |
+
"purses",
|
| 95 |
+
"bank",
|
| 96 |
+
"orange",
|
| 97 |
+
"guitars",
|
| 98 |
+
"? crib",
|
| 99 |
+
"hedgehog",
|
| 100 |
+
"? hallway",
|
| 101 |
+
"? car",
|
| 102 |
+
"steps",
|
| 103 |
+
"horse",
|
| 104 |
+
"? bath",
|
| 105 |
+
"drawer",
|
| 106 |
+
"cats",
|
| 107 |
+
"duck",
|
| 108 |
+
"wearing , reads a piece of paper on a desk and then raises his head ? glasses",
|
| 109 |
+
"phone",
|
| 110 |
+
"pillow",
|
| 111 |
+
"cup",
|
| 112 |
+
"he has food in front of him . ? chair",
|
| 113 |
+
"surfboard",
|
| 114 |
+
"before one of them climbs from the ring ? two",
|
| 115 |
+
"dancing , and clapping ? four",
|
| 116 |
+
"pool",
|
| 117 |
+
"motorcycles",
|
| 118 |
+
"pictures",
|
| 119 |
+
"? star",
|
| 120 |
+
"clipboard",
|
| 121 |
+
"paw",
|
| 122 |
+
"kiss ? two",
|
| 123 |
+
"turtle",
|
| 124 |
+
"when one touches the other on the shoulder ? two",
|
| 125 |
+
"? house",
|
| 126 |
+
"five",
|
| 127 |
+
"locker",
|
| 128 |
+
"tree",
|
| 129 |
+
"bat",
|
| 130 |
+
"popcorn",
|
| 131 |
+
"broom",
|
| 132 |
+
"guns",
|
| 133 |
+
"paint",
|
| 134 |
+
"seat",
|
| 135 |
+
"and then they run away ? heels",
|
| 136 |
+
"flags",
|
| 137 |
+
"dice",
|
| 138 |
+
"? library",
|
| 139 |
+
"yellow",
|
| 140 |
+
"chair",
|
| 141 |
+
"door",
|
| 142 |
+
"? warehouse",
|
| 143 |
+
"kick it and fall over ? tire",
|
| 144 |
+
"jacket",
|
| 145 |
+
"wire",
|
| 146 |
+
"crow",
|
| 147 |
+
"motions",
|
| 148 |
+
"bubbles",
|
| 149 |
+
"vehicle",
|
| 150 |
+
"wearing and speaking ? necklace",
|
| 151 |
+
"one is dressed funny , look at each other ? two",
|
| 152 |
+
"mice",
|
| 153 |
+
"clothing",
|
| 154 |
+
"bread",
|
| 155 |
+
"fireworks",
|
| 156 |
+
"microphone",
|
| 157 |
+
"mascot",
|
| 158 |
+
"? booth",
|
| 159 |
+
"wolf",
|
| 160 |
+
"? foyer",
|
| 161 |
+
"driver",
|
| 162 |
+
"cylinder",
|
| 163 |
+
"on top of his food bowl ? dog",
|
| 164 |
+
"rabbit",
|
| 165 |
+
"? office",
|
| 166 |
+
"treadmill",
|
| 167 |
+
"cap",
|
| 168 |
+
"tire",
|
| 169 |
+
"stick",
|
| 170 |
+
"is laying and opening her eyes . ? bed",
|
| 171 |
+
"stairs",
|
| 172 |
+
"drums",
|
| 173 |
+
"bar",
|
| 174 |
+
"? bed",
|
| 175 |
+
"spoons",
|
| 176 |
+
"? lab",
|
| 177 |
+
"headphones",
|
| 178 |
+
"one is . ? basket",
|
| 179 |
+
"makeup",
|
| 180 |
+
"frogs",
|
| 181 |
+
"wine",
|
| 182 |
+
"two men sit on a sofa and a man dances along a red carpet ? rectangle",
|
| 183 |
+
"sauce",
|
| 184 |
+
"airplane",
|
| 185 |
+
"and he is playing ? guitar",
|
| 186 |
+
"fox",
|
| 187 |
+
"costume",
|
| 188 |
+
"slide",
|
| 189 |
+
"stamp",
|
| 190 |
+
"butts",
|
| 191 |
+
"? window",
|
| 192 |
+
"rope",
|
| 193 |
+
"receiver",
|
| 194 |
+
"then the dog turns around crazy ? butt",
|
| 195 |
+
"and one talks to someone . ? room",
|
| 196 |
+
"? aisle",
|
| 197 |
+
"headset",
|
| 198 |
+
"horses",
|
| 199 |
+
"handgun",
|
| 200 |
+
"bear",
|
| 201 |
+
"napkin",
|
| 202 |
+
"? bottle",
|
| 203 |
+
"frog",
|
| 204 |
+
"wearing , animal print pants and pink shoes is dancing on a sidewalk ? shirt",
|
| 205 |
+
"bicycle",
|
| 206 |
+
"button",
|
| 207 |
+
"panda",
|
| 208 |
+
"turtles",
|
| 209 |
+
"but keeps flying ? airplane",
|
| 210 |
+
"? headset",
|
| 211 |
+
"lobby",
|
| 212 |
+
"pelican",
|
| 213 |
+
"dive",
|
| 214 |
+
"? cage",
|
| 215 |
+
"dishes",
|
| 216 |
+
"wagon",
|
| 217 |
+
"seven",
|
| 218 |
+
"? bag",
|
| 219 |
+
"butterfly",
|
| 220 |
+
"flask",
|
| 221 |
+
"banana",
|
| 222 |
+
"flasks",
|
| 223 |
+
"bus",
|
| 224 |
+
"device",
|
| 225 |
+
"is riding through the house ? bicycle",
|
| 226 |
+
"bright lightning ? sky",
|
| 227 |
+
"umbrellas",
|
| 228 |
+
"yawns then puts out its paw and pushes a jar off onto the floor ? cat",
|
| 229 |
+
"skateboard",
|
| 230 |
+
"cupcakes",
|
| 231 |
+
"shoe",
|
| 232 |
+
"cloak",
|
| 233 |
+
"apple",
|
| 234 |
+
"wall",
|
| 235 |
+
"horns",
|
| 236 |
+
"trick",
|
| 237 |
+
"date",
|
| 238 |
+
"he is talking to a woman ? beer",
|
| 239 |
+
"hill",
|
| 240 |
+
"? bar",
|
| 241 |
+
"pieces",
|
| 242 |
+
"stars",
|
| 243 |
+
"and the bowl disappears ? dog",
|
| 244 |
+
"bridge",
|
| 245 |
+
"box",
|
| 246 |
+
"with one of them embracing the other from behind ? two",
|
| 247 |
+
"piano",
|
| 248 |
+
"? hall",
|
| 249 |
+
"coffee",
|
| 250 |
+
"peel",
|
| 251 |
+
"cutter",
|
| 252 |
+
"circle",
|
| 253 |
+
"sunglasses",
|
| 254 |
+
"star",
|
| 255 |
+
"? pen",
|
| 256 |
+
"they move slowly ? stairs",
|
| 257 |
+
"kitty",
|
| 258 |
+
"pen",
|
| 259 |
+
"owl",
|
| 260 |
+
"puppies",
|
| 261 |
+
"fish",
|
| 262 |
+
"keyboard",
|
| 263 |
+
"underwear",
|
| 264 |
+
"? gym",
|
| 265 |
+
"pigeon",
|
| 266 |
+
"retriever",
|
| 267 |
+
"masks",
|
| 268 |
+
"kangaroo",
|
| 269 |
+
"close",
|
| 270 |
+
"shorts",
|
| 271 |
+
"band",
|
| 272 |
+
"swimming",
|
| 273 |
+
"? plate",
|
| 274 |
+
"then another man reaches for it ? gun",
|
| 275 |
+
"face",
|
| 276 |
+
"ferret",
|
| 277 |
+
"drug",
|
| 278 |
+
"clothes",
|
| 279 |
+
"spoon",
|
| 280 |
+
"hurdle",
|
| 281 |
+
"grass",
|
| 282 |
+
"? paint",
|
| 283 |
+
"airplanes",
|
| 284 |
+
"talks",
|
| 285 |
+
"whose lights flash on ? flower",
|
| 286 |
+
"with one drumming ? instruments",
|
| 287 |
+
"? bowl",
|
| 288 |
+
"burger",
|
| 289 |
+
"llama",
|
| 290 |
+
"it licks its lips ? horse",
|
| 291 |
+
"? holder",
|
| 292 |
+
"camel",
|
| 293 |
+
"dancing",
|
| 294 |
+
"umbrella",
|
| 295 |
+
"pants",
|
| 296 |
+
"ducklings",
|
| 297 |
+
"mug",
|
| 298 |
+
"necklace",
|
| 299 |
+
"track",
|
| 300 |
+
"smoking and turning her head ? cigarette",
|
| 301 |
+
"ladder",
|
| 302 |
+
"cliff",
|
| 303 |
+
"shirts",
|
| 304 |
+
"shark",
|
| 305 |
+
"is playing ? ukulele",
|
| 306 |
+
"turns",
|
| 307 |
+
"? ball",
|
| 308 |
+
"scooter",
|
| 309 |
+
"? box",
|
| 310 |
+
"? road",
|
| 311 |
+
"cover",
|
| 312 |
+
". ? cage",
|
| 313 |
+
"backhoe",
|
| 314 |
+
"bed",
|
| 315 |
+
"and she is holding up ? puppet",
|
| 316 |
+
"? two",
|
| 317 |
+
"goblet",
|
| 318 |
+
"is using and smoking a cigarette ? phone",
|
| 319 |
+
"wearing coats , is hugging . ? hallway",
|
| 320 |
+
"but he misses ? ball",
|
| 321 |
+
"diver",
|
| 322 |
+
"? nightclub",
|
| 323 |
+
"they both smile ? round",
|
| 324 |
+
"medic",
|
| 325 |
+
"? stick",
|
| 326 |
+
"train",
|
| 327 |
+
"? microphone",
|
| 328 |
+
"cigar",
|
| 329 |
+
"wearing , comes through a door held open by another man ? suit",
|
| 330 |
+
"wheel",
|
| 331 |
+
"lions",
|
| 332 |
+
"tights",
|
| 333 |
+
"racetrack",
|
| 334 |
+
"one picks up the other and carries him ? two",
|
| 335 |
+
"sun",
|
| 336 |
+
"? floor",
|
| 337 |
+
"beer",
|
| 338 |
+
"berries",
|
| 339 |
+
"mask",
|
| 340 |
+
"heels",
|
| 341 |
+
"decorator",
|
| 342 |
+
"cub",
|
| 343 |
+
"breakfast",
|
| 344 |
+
". ? chair",
|
| 345 |
+
"then looks away ? monkey",
|
| 346 |
+
"? bucket",
|
| 347 |
+
"snack",
|
| 348 |
+
"girl",
|
| 349 |
+
"suspenders",
|
| 350 |
+
"toy",
|
| 351 |
+
"elephants",
|
| 352 |
+
"boar",
|
| 353 |
+
"bubble",
|
| 354 |
+
"falls off and he grabs it ? hat",
|
| 355 |
+
"trunk",
|
| 356 |
+
"and one of them climbs from one to the other ? frogs",
|
| 357 |
+
"floor",
|
| 358 |
+
"belt",
|
| 359 |
+
"octopus",
|
| 360 |
+
"? dish",
|
| 361 |
+
"truck",
|
| 362 |
+
"snowmobile",
|
| 363 |
+
"standing in the dark , wears ? dress",
|
| 364 |
+
"? bathtub",
|
| 365 |
+
"trees",
|
| 366 |
+
"? mall",
|
| 367 |
+
"bow",
|
| 368 |
+
"beat to the rhythm ? sticks",
|
| 369 |
+
"? store",
|
| 370 |
+
"but stops him ? rope",
|
| 371 |
+
"pug",
|
| 372 |
+
"headgear",
|
| 373 |
+
"tubes",
|
| 374 |
+
"dance",
|
| 375 |
+
"pandas",
|
| 376 |
+
"iguana",
|
| 377 |
+
"concert",
|
| 378 |
+
"dandelion",
|
| 379 |
+
"? garden",
|
| 380 |
+
"queen",
|
| 381 |
+
"instruments",
|
| 382 |
+
"tricycle",
|
| 383 |
+
"racing",
|
| 384 |
+
"? garage",
|
| 385 |
+
"horn",
|
| 386 |
+
"entrance",
|
| 387 |
+
"can",
|
| 388 |
+
"chimpanzee",
|
| 389 |
+
"but the bear cub does ? bear",
|
| 390 |
+
"glass",
|
| 391 |
+
"birds",
|
| 392 |
+
"screaming and pointing ? two",
|
| 393 |
+
"robot",
|
| 394 |
+
"sky",
|
| 395 |
+
"egg",
|
| 396 |
+
"moth",
|
| 397 |
+
"backpack",
|
| 398 |
+
"beverages",
|
| 399 |
+
"bouquet",
|
| 400 |
+
"trumpet",
|
| 401 |
+
"carpet",
|
| 402 |
+
"? apartment",
|
| 403 |
+
"pony",
|
| 404 |
+
"goat",
|
| 405 |
+
"headdress",
|
| 406 |
+
"and he is removing ? hat",
|
| 407 |
+
"house",
|
| 408 |
+
"suit",
|
| 409 |
+
"gum",
|
| 410 |
+
"curb",
|
| 411 |
+
"and then leaves it ? car",
|
| 412 |
+
"snake",
|
| 413 |
+
"he looks at his passenger who is sleeping ? car",
|
| 414 |
+
"? bow-tie",
|
| 415 |
+
"wig",
|
| 416 |
+
"raising a cloud of dust ? car",
|
| 417 |
+
"freezer",
|
| 418 |
+
"delivering , and signing ? flowers",
|
| 419 |
+
"skis",
|
| 420 |
+
"road",
|
| 421 |
+
"deal",
|
| 422 |
+
"ship",
|
| 423 |
+
"? bathroom",
|
| 424 |
+
"bills",
|
| 425 |
+
"piece",
|
| 426 |
+
"items fall out and she makes a face ? door",
|
| 427 |
+
"drinks",
|
| 428 |
+
"dives , . ? cafeteria",
|
| 429 |
+
"goggles",
|
| 430 |
+
"? wagon",
|
| 431 |
+
"man",
|
| 432 |
+
"cups",
|
| 433 |
+
"dolphin",
|
| 434 |
+
"card",
|
| 435 |
+
"building",
|
| 436 |
+
"trunks",
|
| 437 |
+
"liquor",
|
| 438 |
+
"scarf",
|
| 439 |
+
"squash",
|
| 440 |
+
"cheese",
|
| 441 |
+
"then the snake kisses her ? snake",
|
| 442 |
+
"dances seductively ? dress",
|
| 443 |
+
"sword",
|
| 444 |
+
"kiss",
|
| 445 |
+
"possum",
|
| 446 |
+
"stockings",
|
| 447 |
+
"? tray",
|
| 448 |
+
"the one man yells ? two",
|
| 449 |
+
"and she is playing ? guitar",
|
| 450 |
+
"? alley",
|
| 451 |
+
"also wearing ? helmet",
|
| 452 |
+
"beverage",
|
| 453 |
+
"weapon",
|
| 454 |
+
"rodent",
|
| 455 |
+
"beach",
|
| 456 |
+
"? cereals",
|
| 457 |
+
"bench",
|
| 458 |
+
"with two holding glass bottles with colored liquid ? five",
|
| 459 |
+
"holding , jumps in the air and then moves to the back of stage ? guitar",
|
| 460 |
+
"transportation",
|
| 461 |
+
"shampoo",
|
| 462 |
+
"caps",
|
| 463 |
+
"hook",
|
| 464 |
+
"squirrel",
|
| 465 |
+
"scenery",
|
| 466 |
+
"playing",
|
| 467 |
+
"? wheelchair",
|
| 468 |
+
"performer",
|
| 469 |
+
"cake",
|
| 470 |
+
"dancing and playing ? instruments",
|
| 471 |
+
"boxes",
|
| 472 |
+
"leash",
|
| 473 |
+
"? bouquet",
|
| 474 |
+
"but only one arm is . ? sleeve",
|
| 475 |
+
"rifles",
|
| 476 |
+
"lenses",
|
| 477 |
+
"the girl watches him . ? building",
|
| 478 |
+
"almonds",
|
| 479 |
+
"tank",
|
| 480 |
+
"pot",
|
| 481 |
+
"bracelet",
|
| 482 |
+
"knife",
|
| 483 |
+
"mouse",
|
| 484 |
+
"who then catches it ? bottle",
|
| 485 |
+
"exercise",
|
| 486 |
+
"and he is turning around ? wand",
|
| 487 |
+
"purse",
|
| 488 |
+
"stones",
|
| 489 |
+
"show",
|
| 490 |
+
"bag",
|
| 491 |
+
"stocking",
|
| 492 |
+
"balloon",
|
| 493 |
+
"stops , and its tongue remains stuck out ? cat",
|
| 494 |
+
"scythe",
|
| 495 |
+
"creature",
|
| 496 |
+
"cello",
|
| 497 |
+
"and ends up on its back ? bird",
|
| 498 |
+
"pup",
|
| 499 |
+
"? container",
|
| 500 |
+
"and one blows a kiss ? two",
|
| 501 |
+
"animal",
|
| 502 |
+
"trampoline",
|
| 503 |
+
"before they turn and walk away ? two",
|
| 504 |
+
"cloaks",
|
| 505 |
+
"blackjack",
|
| 506 |
+
"as they hit fist to fist ? two",
|
| 507 |
+
"bicycles",
|
| 508 |
+
"watch",
|
| 509 |
+
"corgi",
|
| 510 |
+
"spider",
|
| 511 |
+
"earring",
|
| 512 |
+
"bull",
|
| 513 |
+
"? wheel",
|
| 514 |
+
"? stadium",
|
| 515 |
+
"looking at each other ? two",
|
| 516 |
+
"foxes",
|
| 517 |
+
"mammal",
|
| 518 |
+
"sheep",
|
| 519 |
+
"chases",
|
| 520 |
+
"? armchair",
|
| 521 |
+
". ? room",
|
| 522 |
+
"dancing , and playing ? instruments",
|
| 523 |
+
"which then falls backwards ? cat",
|
| 524 |
+
"dancer",
|
| 525 |
+
"boots",
|
| 526 |
+
"rotors",
|
| 527 |
+
"? ranch",
|
| 528 |
+
"? shower",
|
| 529 |
+
"paper , scissors as they stand by the door ? two",
|
| 530 |
+
"laying and crying on her pillow . ? bed",
|
| 531 |
+
"pencil",
|
| 532 |
+
"when one side scores a goal ? two",
|
| 533 |
+
"food",
|
| 534 |
+
"one with an arm on the other ? two",
|
| 535 |
+
"sheets",
|
| 536 |
+
"rabbits",
|
| 537 |
+
"pizza",
|
| 538 |
+
"? glove",
|
| 539 |
+
"table",
|
| 540 |
+
"scratched",
|
| 541 |
+
"syrup",
|
| 542 |
+
"cone",
|
| 543 |
+
"while the larger man breaks up the fight ? two",
|
| 544 |
+
"drives",
|
| 545 |
+
"luggage",
|
| 546 |
+
"? vehicle",
|
| 547 |
+
"lift",
|
| 548 |
+
"frame",
|
| 549 |
+
"shoes",
|
| 550 |
+
"opens the door , and the cat and four dogs enter through the door ? building",
|
| 551 |
+
"blinks",
|
| 552 |
+
"crotch",
|
| 553 |
+
"dishwasher",
|
| 554 |
+
"skills",
|
| 555 |
+
"sleeves",
|
| 556 |
+
"model",
|
| 557 |
+
"ties",
|
| 558 |
+
"modeling",
|
| 559 |
+
"bath",
|
| 560 |
+
"jet",
|
| 561 |
+
"tortillas",
|
| 562 |
+
"teapot",
|
| 563 |
+
"barbel",
|
| 564 |
+
"cartwheel",
|
| 565 |
+
"musician",
|
| 566 |
+
"rhino",
|
| 567 |
+
"exits",
|
| 568 |
+
"pole",
|
| 569 |
+
"ski",
|
| 570 |
+
"pajama",
|
| 571 |
+
"woodchucks",
|
| 572 |
+
"lanes",
|
| 573 |
+
"candle",
|
| 574 |
+
"tag",
|
| 575 |
+
"gloves",
|
| 576 |
+
"dinosaur",
|
| 577 |
+
"surface",
|
| 578 |
+
"? tub",
|
| 579 |
+
"snowboard",
|
| 580 |
+
"wearing , hops around her couch while pointing at her face ? glasses",
|
| 581 |
+
"donut",
|
| 582 |
+
"mustard",
|
| 583 |
+
"? tunnel",
|
| 584 |
+
"? theater",
|
| 585 |
+
"wheels",
|
| 586 |
+
"rat",
|
| 587 |
+
"and one talks to someone ? two",
|
| 588 |
+
"bungee",
|
| 589 |
+
"but then suddenly takes off again ? jet",
|
| 590 |
+
"? rink",
|
| 591 |
+
"face shown . ? mirror",
|
| 592 |
+
"shell",
|
| 593 |
+
"costumes",
|
| 594 |
+
"? shield",
|
| 595 |
+
"confetti",
|
| 596 |
+
"flower",
|
| 597 |
+
"gesture",
|
| 598 |
+
"portfolio",
|
| 599 |
+
"and moves from under him ? ball",
|
| 600 |
+
"violin",
|
| 601 |
+
"photographs",
|
| 602 |
+
"uniforms",
|
| 603 |
+
"money",
|
| 604 |
+
"bomb",
|
| 605 |
+
"? rv",
|
| 606 |
+
"claws",
|
| 607 |
+
"lands",
|
| 608 |
+
"turnstile",
|
| 609 |
+
"bot",
|
| 610 |
+
"hose",
|
| 611 |
+
"suitcase",
|
| 612 |
+
"sitting on a table , reaches out and pushes a glass off the table ? paw",
|
| 613 |
+
"mountain",
|
| 614 |
+
"tools",
|
| 615 |
+
"headsets",
|
| 616 |
+
"the streets crumble below it ? airplane",
|
| 617 |
+
"t-shirt",
|
| 618 |
+
"doors",
|
| 619 |
+
"wearing , hugs another person and smiles ? glasses",
|
| 620 |
+
"one of them is shaking his head . ? car",
|
| 621 |
+
"octopuses",
|
| 622 |
+
"performs",
|
| 623 |
+
"cases",
|
| 624 |
+
"deer",
|
| 625 |
+
"? wall",
|
| 626 |
+
"and holding a lighter underneath , it explodes in flames ? balloon",
|
| 627 |
+
"blanket",
|
| 628 |
+
"coat",
|
| 629 |
+
"knives",
|
| 630 |
+
"? frame",
|
| 631 |
+
"trolley",
|
| 632 |
+
"noodles",
|
| 633 |
+
"one cries and holds a handkerchief to his nose , the other tries to comfort him ? two",
|
| 634 |
+
"wrap",
|
| 635 |
+
"? cart",
|
| 636 |
+
"inside of the car get scared ? two",
|
| 637 |
+
"animals",
|
| 638 |
+
"tails",
|
| 639 |
+
"? drawer",
|
| 640 |
+
"? cigarette",
|
| 641 |
+
"? barbel",
|
| 642 |
+
"room",
|
| 643 |
+
"? building",
|
| 644 |
+
"using as a weapon , hits a zombie in the head ? bat",
|
| 645 |
+
"trucks",
|
| 646 |
+
"boxers",
|
| 647 |
+
"drum",
|
| 648 |
+
"challenge",
|
| 649 |
+
"? toilet",
|
| 650 |
+
"llamas",
|
| 651 |
+
"then watches the smoke rise ? cat",
|
| 652 |
+
"mouths from across a room ? two",
|
| 653 |
+
"and it is pushed by a cat ? box",
|
| 654 |
+
"but the bear cub does ? bird",
|
| 655 |
+
"? skateboard",
|
| 656 |
+
"lifts up to her mouth , ? microphone",
|
| 657 |
+
"wearing , talks and bends his head forward ? cap",
|
| 658 |
+
"? doorway",
|
| 659 |
+
"which causes that cat to attack another cat ? cat",
|
| 660 |
+
"giraffe",
|
| 661 |
+
"cam",
|
| 662 |
+
"microphones",
|
| 663 |
+
"losing balance as it tries to walk forward ? cat",
|
| 664 |
+
"groove",
|
| 665 |
+
"tricks",
|
| 666 |
+
"spins , and lands on another ramp ? car",
|
| 667 |
+
"dumbbell",
|
| 668 |
+
"with their arms out , while laughing ? three",
|
| 669 |
+
"sea",
|
| 670 |
+
"carrot",
|
| 671 |
+
"chips",
|
| 672 |
+
"gift",
|
| 673 |
+
"ropes",
|
| 674 |
+
"singer",
|
| 675 |
+
"rocket",
|
| 676 |
+
"? net",
|
| 677 |
+
"blows",
|
| 678 |
+
"? zipper",
|
| 679 |
+
"sticks",
|
| 680 |
+
"tambourine",
|
| 681 |
+
"and he is laughing at a puppet talking ? cookie",
|
| 682 |
+
"? train",
|
| 683 |
+
"boats",
|
| 684 |
+
"across a road , and into the path of a car before being hit ? bicycle",
|
| 685 |
+
"penguins",
|
| 686 |
+
"song",
|
| 687 |
+
"antlers",
|
| 688 |
+
"feather",
|
| 689 |
+
"handcuffs",
|
| 690 |
+
"insect",
|
| 691 |
+
"gratings",
|
| 692 |
+
"milk",
|
| 693 |
+
"blackbird",
|
| 694 |
+
"scaffolding",
|
| 695 |
+
"sheet",
|
| 696 |
+
"seal",
|
| 697 |
+
"which bursts as the car approaches it ? car",
|
| 698 |
+
"? locker",
|
| 699 |
+
"towels",
|
| 700 |
+
"? highway",
|
| 701 |
+
"? lane",
|
| 702 |
+
"? rope",
|
| 703 |
+
"wearing , is singing with a microphone ? dress",
|
| 704 |
+
"vegetables",
|
| 705 |
+
"rag",
|
| 706 |
+
"? hoop",
|
| 707 |
+
"? hospital",
|
| 708 |
+
"keys",
|
| 709 |
+
"and he is raising his arm ? crotch",
|
| 710 |
+
"otter",
|
| 711 |
+
"? corridor",
|
| 712 |
+
"tires",
|
| 713 |
+
"they see it from looking up ? window",
|
| 714 |
+
"trainer",
|
| 715 |
+
"groundhog",
|
| 716 |
+
"gorilla",
|
| 717 |
+
"is sitting on the steps and eating ? shirt",
|
| 718 |
+
"oar",
|
| 719 |
+
"nugget",
|
| 720 |
+
"? cellphone",
|
| 721 |
+
"hamsters",
|
| 722 |
+
"walls",
|
| 723 |
+
"? cup",
|
| 724 |
+
"and then starts wracking it FRAMEQAeatedly ? wand",
|
| 725 |
+
"concoction",
|
| 726 |
+
"computer",
|
| 727 |
+
"hall",
|
| 728 |
+
"one is licking the other ones ear ? cats",
|
| 729 |
+
"earphone",
|
| 730 |
+
"hallway",
|
| 731 |
+
"trailer",
|
| 732 |
+
"magazine",
|
| 733 |
+
"and pointing at it ? laptop",
|
| 734 |
+
"elevator",
|
| 735 |
+
"river",
|
| 736 |
+
"pig",
|
| 737 |
+
"is also using ? earring",
|
| 738 |
+
"case",
|
| 739 |
+
"cape",
|
| 740 |
+
"? tablet",
|
| 741 |
+
"beanie",
|
| 742 |
+
"penguin",
|
| 743 |
+
"race",
|
| 744 |
+
"? excitedly",
|
| 745 |
+
"groomed each other ? cats",
|
| 746 |
+
"carriage",
|
| 747 |
+
"with long hair , open her mouth . ? room",
|
| 748 |
+
"parakeet",
|
| 749 |
+
"call",
|
| 750 |
+
"? tire",
|
| 751 |
+
"windshield",
|
| 752 |
+
"nose",
|
| 753 |
+
"? capsule",
|
| 754 |
+
"woman",
|
| 755 |
+
"snowball",
|
| 756 |
+
"look at one another , and fall to the ground laughing ? three",
|
| 757 |
+
"wing",
|
| 758 |
+
"bowl",
|
| 759 |
+
"lipstick",
|
| 760 |
+
"who is looking upset ? one",
|
| 761 |
+
"balls",
|
| 762 |
+
"cage",
|
| 763 |
+
"sunroof",
|
| 764 |
+
"? shop",
|
| 765 |
+
"shining and wearing a yellow outfit ? microphone",
|
| 766 |
+
"then two of them wave goodbye ? three",
|
| 767 |
+
"? sunglasses",
|
| 768 |
+
"kittens",
|
| 769 |
+
"? lingerie",
|
| 770 |
+
"colors",
|
| 771 |
+
"crying and eating a sandwich . ? bed",
|
| 772 |
+
"? lapel",
|
| 773 |
+
"corn",
|
| 774 |
+
"twirl",
|
| 775 |
+
"dough",
|
| 776 |
+
"dock",
|
| 777 |
+
"taxi",
|
| 778 |
+
"singing",
|
| 779 |
+
"stares",
|
| 780 |
+
"skate",
|
| 781 |
+
"chick",
|
| 782 |
+
"is visiting another guy . ? hospital",
|
| 783 |
+
"comb",
|
| 784 |
+
"roll",
|
| 785 |
+
"runway",
|
| 786 |
+
"statue",
|
| 787 |
+
"rides a skateboard up and launches himself through the air ? ramp",
|
| 788 |
+
"bleachers",
|
| 789 |
+
"? pot",
|
| 790 |
+
"butter",
|
| 791 |
+
"and it bounces off of a wall onto a table ? cat",
|
| 792 |
+
"? basement",
|
| 793 |
+
"eyeliner",
|
| 794 |
+
"wearing , is waving his hand ? shirt",
|
| 795 |
+
"opens the door , and the cat and four dogs enter the building through the door ? cat",
|
| 796 |
+
"right",
|
| 797 |
+
"flashlights",
|
| 798 |
+
"pet",
|
| 799 |
+
"pastry",
|
| 800 |
+
"but then the trailing car is shown a weapon and the car falls back ? car",
|
| 801 |
+
"tuxedo",
|
| 802 |
+
"begins to flip over and over ? car",
|
| 803 |
+
"curtain",
|
| 804 |
+
"fork",
|
| 805 |
+
"he looks away ? guitar",
|
| 806 |
+
"roof",
|
| 807 |
+
"? restroom",
|
| 808 |
+
"who jumps away . ? box",
|
| 809 |
+
"? rag",
|
| 810 |
+
"wearing , talks and raises on eyebrow ? headband",
|
| 811 |
+
"? cloak",
|
| 812 |
+
"then the rider lands on top ? motorcycle",
|
| 813 |
+
"toys",
|
| 814 |
+
"are talking to each other ? two",
|
| 815 |
+
"rats",
|
| 816 |
+
"telephone",
|
| 817 |
+
"bananas",
|
| 818 |
+
"user",
|
| 819 |
+
"stops and gets in ? taxi",
|
| 820 |
+
"cane",
|
| 821 |
+
"bucket",
|
| 822 |
+
"popsicle",
|
| 823 |
+
"? tent",
|
| 824 |
+
"? oven",
|
| 825 |
+
"and the fired a shot ? flower",
|
| 826 |
+
"? broom",
|
| 827 |
+
"? pan",
|
| 828 |
+
"design",
|
| 829 |
+
"hippopotamus",
|
| 830 |
+
"they move to the left ? sky",
|
| 831 |
+
"trying not to laugh ? two",
|
| 832 |
+
"torch",
|
| 833 |
+
"they look at one another , and the woman exits the car . ? car",
|
| 834 |
+
"his head nods to the left . ? chair",
|
| 835 |
+
"and he had a bandage on his head . ? car",
|
| 836 |
+
"vegetable",
|
| 837 |
+
"and everyone celebrates ? star",
|
| 838 |
+
"balloons",
|
| 839 |
+
"men",
|
| 840 |
+
"circles",
|
| 841 |
+
"graffiti",
|
| 842 |
+
"racer",
|
| 843 |
+
"jump",
|
| 844 |
+
"kissing , and spinning around ? two",
|
| 845 |
+
"works",
|
| 846 |
+
"castle",
|
| 847 |
+
"while they are sitting down ? two",
|
| 848 |
+
"sandwich",
|
| 849 |
+
"earpiece",
|
| 850 |
+
"then lift ? shirt",
|
| 851 |
+
"motors",
|
| 852 |
+
"burrito",
|
| 853 |
+
"? singlet",
|
| 854 |
+
"180",
|
| 855 |
+
"? dryer",
|
| 856 |
+
"torches",
|
| 857 |
+
"? pullover",
|
| 858 |
+
"wearing , slides open a door and dances through while carrying a walking tick and radio ? glasses",
|
| 859 |
+
"straw",
|
| 860 |
+
"wearing , pushes a melting ice cream into his mouth as some drops from his hand ? cap",
|
| 861 |
+
"clown",
|
| 862 |
+
"smiles , and turns away . ? classroom",
|
| 863 |
+
"figure",
|
| 864 |
+
"white doll ? two",
|
| 865 |
+
"signs",
|
| 866 |
+
"? airplane",
|
| 867 |
+
"cannon",
|
| 868 |
+
"cloth",
|
| 869 |
+
"serviette",
|
| 870 |
+
"toast",
|
| 871 |
+
"? kit",
|
| 872 |
+
"bats",
|
| 873 |
+
"bobcat",
|
| 874 |
+
"griddle",
|
| 875 |
+
"leaves",
|
| 876 |
+
"pass",
|
| 877 |
+
"? door",
|
| 878 |
+
"ramp",
|
| 879 |
+
"porpoise",
|
| 880 |
+
"scissors",
|
| 881 |
+
"fighter",
|
| 882 |
+
"bandannas",
|
| 883 |
+
"bases",
|
| 884 |
+
"hug each other ? two",
|
| 885 |
+
"duckling",
|
| 886 |
+
"but grabs on and takes a drink ? monkey",
|
| 887 |
+
"winks",
|
| 888 |
+
"? jeep",
|
| 889 |
+
"twirls",
|
| 890 |
+
"harp",
|
| 891 |
+
"one points and talks and the other laughs ? two",
|
| 892 |
+
"then a redhead grabs ? hat",
|
| 893 |
+
"? zoo",
|
| 894 |
+
"tender",
|
| 895 |
+
"disc",
|
| 896 |
+
"fly",
|
| 897 |
+
"wash",
|
| 898 |
+
"harness",
|
| 899 |
+
"opening",
|
| 900 |
+
"brick",
|
| 901 |
+
"watermelon",
|
| 902 |
+
"plate",
|
| 903 |
+
"they bring it closer to their body ? stick",
|
| 904 |
+
"lake",
|
| 905 |
+
"sledgehammer",
|
| 906 |
+
"leaning backward , and waving their arms back and forth ? two",
|
| 907 |
+
"ocean",
|
| 908 |
+
"while spectators watch ? two",
|
| 909 |
+
"shuttle",
|
| 910 |
+
"loop",
|
| 911 |
+
"balcony",
|
| 912 |
+
"? closet",
|
| 913 |
+
"but falls off a table ? cat",
|
| 914 |
+
"anchor",
|
| 915 |
+
"? plaid",
|
| 916 |
+
"terrapins",
|
| 917 |
+
"pop",
|
| 918 |
+
"tool",
|
| 919 |
+
"hay",
|
| 920 |
+
"panther",
|
| 921 |
+
"smiling and laughing ? three",
|
| 922 |
+
"and it lands on his head ? hat",
|
| 923 |
+
"? fountain",
|
| 924 |
+
"photograph",
|
| 925 |
+
"it has a double yolk ? egg",
|
| 926 |
+
"one is in a basket ? dogs",
|
| 927 |
+
"but does ? cub",
|
| 928 |
+
"strips",
|
| 929 |
+
"jeep",
|
| 930 |
+
"when the toaster pops out toast the cat gets scared and jumps off ? cat",
|
| 931 |
+
"then turns around crazy ? dog",
|
| 932 |
+
"goldfish",
|
| 933 |
+
"? elevator",
|
| 934 |
+
"sedan",
|
| 935 |
+
"? pocket",
|
| 936 |
+
"planet",
|
| 937 |
+
"drill",
|
| 938 |
+
"two of them spinning around ? cars",
|
| 939 |
+
"baboon",
|
| 940 |
+
"mirror",
|
| 941 |
+
"? flowers",
|
| 942 |
+
"chairs",
|
| 943 |
+
"make in the air with a wand ? float",
|
| 944 |
+
"jewelry",
|
| 945 |
+
"fabric",
|
| 946 |
+
"coins",
|
| 947 |
+
"handset",
|
| 948 |
+
"jets",
|
| 949 |
+
"bulldog",
|
| 950 |
+
"black hair wearing and raising their hand up to their mouth ? shirt",
|
| 951 |
+
"sweatshirt",
|
| 952 |
+
"workout",
|
| 953 |
+
"rounds",
|
| 954 |
+
"? bench",
|
| 955 |
+
"? piece",
|
| 956 |
+
"sparklers",
|
| 957 |
+
"waterfall",
|
| 958 |
+
"lettuce",
|
| 959 |
+
"crashes",
|
| 960 |
+
"tomato",
|
| 961 |
+
"cheeseburger",
|
| 962 |
+
"strawberry",
|
| 963 |
+
"and another one appears to be . ? garden",
|
| 964 |
+
"flag",
|
| 965 |
+
"eight",
|
| 966 |
+
"toothpick",
|
| 967 |
+
"and disappears ? bowl",
|
| 968 |
+
"? lipstick",
|
| 969 |
+
"and she is smiling ? cat",
|
| 970 |
+
"? alleyway",
|
| 971 |
+
"shield",
|
| 972 |
+
"tuxedos",
|
| 973 |
+
"talking , smiling and waving his hand . ? chair",
|
| 974 |
+
"cheetah",
|
| 975 |
+
"and one player kicks into the goal ? ball",
|
| 976 |
+
"letters",
|
| 977 |
+
"? basket",
|
| 978 |
+
"pill",
|
| 979 |
+
"which trips another man who does a flip and lands on a recycle bin ? peel",
|
| 980 |
+
"human",
|
| 981 |
+
"fence",
|
| 982 |
+
"? sink",
|
| 983 |
+
"black leather trench coat ? star",
|
| 984 |
+
"divers",
|
| 985 |
+
"couch",
|
| 986 |
+
"buttons",
|
| 987 |
+
"shot",
|
| 988 |
+
"rodents",
|
| 989 |
+
"swords",
|
| 990 |
+
"gown",
|
| 991 |
+
"both speeding down the road ? car",
|
| 992 |
+
"people watch them . ? house",
|
| 993 |
+
"belts",
|
| 994 |
+
"catapult",
|
| 995 |
+
"ammunition",
|
| 996 |
+
"potatoes",
|
| 997 |
+
"lemur",
|
| 998 |
+
"while a third moves forward and dances ? two",
|
| 999 |
+
"then their hand and a slogan appears ? towel",
|
| 1000 |
+
"firecrackers",
|
| 1001 |
+
"ribs",
|
| 1002 |
+
"briefcase",
|
| 1003 |
+
"the man spills milk over his face . ? car",
|
| 1004 |
+
"? workshop",
|
| 1005 |
+
"is sitting down and smoking ? cigarette",
|
| 1006 |
+
"dressed in a suit and carrying ? cane",
|
| 1007 |
+
"and she is dancing in a field . ? mirror",
|
| 1008 |
+
"? ashtray",
|
| 1009 |
+
"looking sad . ? hallway",
|
| 1010 |
+
"noodle",
|
| 1011 |
+
"missiles",
|
| 1012 |
+
"? helicopter",
|
| 1013 |
+
"catfish",
|
| 1014 |
+
"toothbrush",
|
| 1015 |
+
"have taken ? pictures",
|
| 1016 |
+
"pane",
|
| 1017 |
+
"he dances on the stage ? headset",
|
| 1018 |
+
"scooters",
|
| 1019 |
+
"then he does the splits . ? hallway",
|
| 1020 |
+
"and it is pushed by a cat ? mouse",
|
| 1021 |
+
"desks",
|
| 1022 |
+
"hills",
|
| 1023 |
+
"stairway",
|
| 1024 |
+
"whisk",
|
| 1025 |
+
"with",
|
| 1026 |
+
"while one of them sings into a microphone ? two",
|
| 1027 |
+
"bottles",
|
| 1028 |
+
"but grabs her leg ? panda",
|
| 1029 |
+
"sled",
|
| 1030 |
+
"nut",
|
| 1031 |
+
"feathers",
|
| 1032 |
+
"dresses",
|
| 1033 |
+
"sink",
|
| 1034 |
+
"wristband",
|
| 1035 |
+
"then jumps up to celebrate ? pool",
|
| 1036 |
+
"drumsticks",
|
| 1037 |
+
"opens her mouth and smiles ? one",
|
| 1038 |
+
"suits",
|
| 1039 |
+
"sculpture",
|
| 1040 |
+
"are fighting for control of the soccer ball ? two",
|
| 1041 |
+
"and he is throwing ? napkin",
|
| 1042 |
+
"pets",
|
| 1043 |
+
"bin",
|
| 1044 |
+
"jockey",
|
| 1045 |
+
"backwards",
|
| 1046 |
+
"spiky , walk across the pavement ? heels",
|
| 1047 |
+
"chainsaw",
|
| 1048 |
+
"? guitar",
|
| 1049 |
+
"with just head and tail exposed ? cat",
|
| 1050 |
+
"when one pins the other one down for a three count ? two",
|
| 1051 |
+
"shore",
|
| 1052 |
+
"chicks",
|
| 1053 |
+
"dancing and laughing ? two",
|
| 1054 |
+
"looking sideways and singing ? guitar",
|
| 1055 |
+
"? turns",
|
| 1056 |
+
"lamp",
|
| 1057 |
+
"paper , scissors ? two",
|
| 1058 |
+
"chocolate",
|
| 1059 |
+
"bra",
|
| 1060 |
+
"blonde woman wearing a back top and matching ? piece",
|
| 1061 |
+
"holding hands ? two",
|
| 1062 |
+
"while the man next to him talks and moves his hands around ? one",
|
| 1063 |
+
"cubs",
|
| 1064 |
+
"having cake . ? restaurant",
|
| 1065 |
+
"figurine",
|
| 1066 |
+
"hood",
|
| 1067 |
+
"lens",
|
| 1068 |
+
"groomed each other ? two",
|
| 1069 |
+
"sabers",
|
| 1070 |
+
"before jumping in the pool ? dog",
|
| 1071 |
+
"mattress",
|
| 1072 |
+
"sidewalk",
|
| 1073 |
+
"landing",
|
| 1074 |
+
"rocks",
|
| 1075 |
+
"avocado",
|
| 1076 |
+
"? bear",
|
| 1077 |
+
"and a man spills , crouches , and cowers ? coffee",
|
| 1078 |
+
"disks",
|
| 1079 |
+
"mountainside",
|
| 1080 |
+
"lips",
|
| 1081 |
+
"chest",
|
| 1082 |
+
"wan",
|
| 1083 |
+
"glove",
|
| 1084 |
+
"? beer",
|
| 1085 |
+
"tortilla",
|
| 1086 |
+
"? stable",
|
| 1087 |
+
"meteor",
|
| 1088 |
+
"expression",
|
| 1089 |
+
"? kayak",
|
| 1090 |
+
"biscuit",
|
| 1091 |
+
"ukulele",
|
| 1092 |
+
"at something ? two",
|
| 1093 |
+
"convertible",
|
| 1094 |
+
"climber",
|
| 1095 |
+
"is using the pay phone and smoking ? cigarette",
|
| 1096 |
+
"wearing , looks mad ? jacket",
|
| 1097 |
+
"mike",
|
| 1098 |
+
"sleeping and stretching on the person 's stomach ? cat",
|
| 1099 |
+
"denim",
|
| 1100 |
+
"lantern",
|
| 1101 |
+
"breaks the branch its sitting on in the tree , and falls to the ground ? panda",
|
| 1102 |
+
"so that she 's almost laying down . ? car",
|
| 1103 |
+
"smears",
|
| 1104 |
+
"hair",
|
| 1105 |
+
"bones",
|
| 1106 |
+
"blade",
|
| 1107 |
+
"unicycle",
|
| 1108 |
+
"? cone",
|
| 1109 |
+
"wallet",
|
| 1110 |
+
"blouse",
|
| 1111 |
+
"trousers",
|
| 1112 |
+
"buds",
|
| 1113 |
+
"spill",
|
| 1114 |
+
"rib",
|
| 1115 |
+
"porcupine",
|
| 1116 |
+
"tray",
|
| 1117 |
+
"map",
|
| 1118 |
+
"sad ? dog",
|
| 1119 |
+
"socks",
|
| 1120 |
+
"automobile",
|
| 1121 |
+
"parallel",
|
| 1122 |
+
"skyscraper",
|
| 1123 |
+
"classroom",
|
| 1124 |
+
"catwalk",
|
| 1125 |
+
"the bike crashes ? bicycle",
|
| 1126 |
+
"stare , and look shocked ? four",
|
| 1127 |
+
"towel",
|
| 1128 |
+
"whilst another one is sitting down ? guitar",
|
| 1129 |
+
"lion",
|
| 1130 |
+
"cargo",
|
| 1131 |
+
"grabs",
|
| 1132 |
+
"and then starts wracking it FRAMEQAeatedly ? cat",
|
| 1133 |
+
"vest",
|
| 1134 |
+
"spits",
|
| 1135 |
+
"wearing is walking and waving ? dress",
|
| 1136 |
+
"poker",
|
| 1137 |
+
"robe",
|
| 1138 |
+
"bandanna",
|
| 1139 |
+
"little fingers ? two",
|
| 1140 |
+
"person",
|
| 1141 |
+
"doves",
|
| 1142 |
+
"container",
|
| 1143 |
+
"wearing , uses gymnastic rings to lift herself to a seated position then into a handstand ? clothes",
|
| 1144 |
+
"forklift",
|
| 1145 |
+
"buildings",
|
| 1146 |
+
"wearing ? blouse",
|
| 1147 |
+
"making a crack big enough for the rest to get in ? cat",
|
| 1148 |
+
"carrots",
|
| 1149 |
+
"lizard",
|
| 1150 |
+
"beakers",
|
| 1151 |
+
"blower",
|
| 1152 |
+
"and another woman is running in black shorts ? pants",
|
| 1153 |
+
"marks",
|
| 1154 |
+
"spaceship",
|
| 1155 |
+
"when one man lays the other man down ? two",
|
| 1156 |
+
"are dancing on a stage while the crowd cheers ? two",
|
| 1157 |
+
"they start to head bang . ? car",
|
| 1158 |
+
"then one blows confetti into the air ? two",
|
| 1159 |
+
"sitting down , when someone else steps up and spins the chair around . ? chair",
|
| 1160 |
+
"puppets",
|
| 1161 |
+
"garage",
|
| 1162 |
+
"lemon",
|
| 1163 |
+
"wearing , is sitting and doing something with her foot ? clothes",
|
| 1164 |
+
"and two men with lighting swords want to fight with him ? door",
|
| 1165 |
+
"treat",
|
| 1166 |
+
"lamb",
|
| 1167 |
+
"ways",
|
| 1168 |
+
"and one man throws ? hat",
|
| 1169 |
+
"pick",
|
| 1170 |
+
"product",
|
| 1171 |
+
"is throwing around the room ? clothes",
|
| 1172 |
+
"the clothes of the people catch on fire ? horses",
|
| 1173 |
+
"all , have the same type of hair style ? three",
|
| 1174 |
+
"whip",
|
| 1175 |
+
"mop",
|
| 1176 |
+
"pointing his fingers and nodding ? bow",
|
| 1177 |
+
"bags",
|
| 1178 |
+
"machines",
|
| 1179 |
+
"seeds",
|
| 1180 |
+
"symbol",
|
| 1181 |
+
"layer",
|
| 1182 |
+
"opens ? door",
|
| 1183 |
+
"dark sunglasses , and cigar ? two",
|
| 1184 |
+
"the man smashes the head of a zombie ? bat",
|
| 1185 |
+
"extinguisher",
|
| 1186 |
+
"candles",
|
| 1187 |
+
", looking out ? window",
|
| 1188 |
+
"group",
|
| 1189 |
+
"drop",
|
| 1190 |
+
"is riding , into the swimming pool ? bicycle",
|
| 1191 |
+
"stake",
|
| 1192 |
+
"block",
|
| 1193 |
+
"and he is singing into a microphone ? guitar",
|
| 1194 |
+
"ornament",
|
| 1195 |
+
"spins as he bends over . ? chair",
|
| 1196 |
+
"? shirts",
|
| 1197 |
+
"? colors",
|
| 1198 |
+
"hookah",
|
| 1199 |
+
"? courtyard",
|
| 1200 |
+
"cactus",
|
| 1201 |
+
"are having taken while on stage ? picture",
|
| 1202 |
+
"an orange ? shell",
|
| 1203 |
+
"and he is talking ? sunglasses",
|
| 1204 |
+
"veil",
|
| 1205 |
+
"then rolling around in the mud ? horse",
|
| 1206 |
+
"? pillow",
|
| 1207 |
+
"drugs",
|
| 1208 |
+
"? couch",
|
| 1209 |
+
"bun",
|
| 1210 |
+
"koala",
|
| 1211 |
+
"one wearing brown shoes and the other has no footwear ? two",
|
| 1212 |
+
"and he is falling in the water ? dog",
|
| 1213 |
+
"is smoking ? cigarette",
|
| 1214 |
+
"rooster",
|
| 1215 |
+
"submarine",
|
| 1216 |
+
"wand",
|
| 1217 |
+
"helicopter",
|
| 1218 |
+
"wearing , smiles as her hair blows in the wind ? hat",
|
| 1219 |
+
"and fails , to jump into the window ? cat",
|
| 1220 |
+
"tram",
|
| 1221 |
+
"and then is knocked down when it hits him in the head ? bag",
|
| 1222 |
+
"curve",
|
| 1223 |
+
"handrail",
|
| 1224 |
+
"bulldozer",
|
| 1225 |
+
"stops a taxi . ? street",
|
| 1226 |
+
"speedometer",
|
| 1227 |
+
"? necklace",
|
| 1228 |
+
"curbs",
|
| 1229 |
+
"over multiple vehicles and lands on another ramp ? bicycle",
|
| 1230 |
+
"wolves",
|
| 1231 |
+
"laundry",
|
| 1232 |
+
"holding , laughs into a microphone and then puts her fingers up to her lips ? guitar",
|
| 1233 |
+
"peeking . ? room",
|
| 1234 |
+
"cigarettes",
|
| 1235 |
+
"bells",
|
| 1236 |
+
"sill",
|
| 1237 |
+
"raspberry",
|
| 1238 |
+
"suited",
|
| 1239 |
+
"shawl",
|
| 1240 |
+
"wakes",
|
| 1241 |
+
"applying the brake , and applying the gas as needed . ? car",
|
| 1242 |
+
"poodle",
|
| 1243 |
+
"and he 's ? candles",
|
| 1244 |
+
"then skids on the ground ? motorcycle",
|
| 1245 |
+
"office",
|
| 1246 |
+
"outdoors",
|
| 1247 |
+
"it stops at the edge ? car",
|
| 1248 |
+
"as she puts it all on top of her head ? two",
|
| 1249 |
+
"but his reflection is doing something different . ? mirror",
|
| 1250 |
+
"holding , are walking together ? bear",
|
| 1251 |
+
"hats",
|
| 1252 |
+
"mat",
|
| 1253 |
+
"then the team mate scores a goal ? ball",
|
| 1254 |
+
"one with a guitar are behind him ? one",
|
| 1255 |
+
"? looks",
|
| 1256 |
+
"grenade",
|
| 1257 |
+
"coin",
|
| 1258 |
+
"toasting each other with their liquor bottles ? two",
|
| 1259 |
+
"saxophone",
|
| 1260 |
+
"capes",
|
| 1261 |
+
"lounges",
|
| 1262 |
+
"? scissors",
|
| 1263 |
+
"hoop",
|
| 1264 |
+
"rack",
|
| 1265 |
+
"frisbee",
|
| 1266 |
+
"then jumps in the air and runs away ? cat",
|
| 1267 |
+
"wearing , is hugging in the hallway ? coats",
|
| 1268 |
+
"? lobby",
|
| 1269 |
+
"corridor",
|
| 1270 |
+
"who they push to the ground ? two",
|
| 1271 |
+
"worms",
|
| 1272 |
+
"tablet",
|
| 1273 |
+
"who turns and causes the kitten to raise its paw ? kitten",
|
| 1274 |
+
"chariot",
|
| 1275 |
+
"lock",
|
| 1276 |
+
"tongs",
|
| 1277 |
+
"game",
|
| 1278 |
+
"s head while he is trying to eat ? cat",
|
| 1279 |
+
"pie",
|
| 1280 |
+
"feline",
|
| 1281 |
+
"and then are shown ? pictures",
|
| 1282 |
+
"parasol",
|
| 1283 |
+
"pumpkins",
|
| 1284 |
+
"notebook",
|
| 1285 |
+
"the horse leans its head around her ? horse",
|
| 1286 |
+
"spaghetti",
|
| 1287 |
+
"outside",
|
| 1288 |
+
"? bib",
|
| 1289 |
+
"gold",
|
| 1290 |
+
"cart",
|
| 1291 |
+
"the trees are being passed by , and the clouds are above ? sun",
|
| 1292 |
+
"the other elephant pulls it closer ? elephant",
|
| 1293 |
+
"most of them wearing ? sunglasses",
|
| 1294 |
+
"and are falling down on top of him ? balloons",
|
| 1295 |
+
"nods his head and blinks ? one",
|
| 1296 |
+
"with long brown hair , wink and raises to her face ? two",
|
| 1297 |
+
"uncontrollably",
|
| 1298 |
+
"wearing , raises two fingers to her face ? cap",
|
| 1299 |
+
"swinging its hips from side to side ? turtle",
|
| 1300 |
+
"skates",
|
| 1301 |
+
"they look at one another , and the woman exits ? car",
|
| 1302 |
+
"his friends join in the background . ? chair",
|
| 1303 |
+
"store",
|
| 1304 |
+
"donuts",
|
| 1305 |
+
"then sticks its tongue out ? dog",
|
| 1306 |
+
"and then a massive explosion occurs ? container",
|
| 1307 |
+
"then kisses her ? snake",
|
| 1308 |
+
"brakes"
|
| 1309 |
+
]
|
ChatUniVi/eval/questions/video_qa/tgif_qa.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ChatUniVi/eval/table/caps_boxes_coco2014_val_80.jsonl
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"id": "000000296284", "image": "COCO_val2014_000000296284.jpg", "captions": ["A donut shop is full of different flavors of donuts.", "Fruit flavored donuts lined up in a glass fronted cabinet", "A rack with some doughnuts in a glass case.", "A display case in a bakery filled with donuts.", "An assortment of doughnuts are arranged in a display case."], "instances": [{"category": "donut", "bbox": [0.37, 0.584, 0.504, 0.709]}, {"category": "donut", "bbox": [0.369, 0.22, 0.492, 0.317]}, {"category": "donut", "bbox": [0.471, 0.587, 0.639, 0.706]}, {"category": "donut", "bbox": [0.544, 0.213, 0.679, 0.316]}, {"category": "donut", "bbox": [0.035, 0.22, 0.196, 0.328]}, {"category": "donut", "bbox": [0.054, 0.608, 0.221, 0.711]}, {"category": "donut", "bbox": [0.283, 0.586, 0.429, 0.708]}, {"category": "donut", "bbox": [0.466, 0.226, 0.585, 0.32]}, {"category": "donut", "bbox": [0.28, 0.232, 0.393, 0.322]}, {"category": "donut", "bbox": [0.0, 0.609, 0.097, 0.722]}]}
|
| 2 |
+
{"id": "000000151358", "image": "COCO_val2014_000000151358.jpg", "captions": ["A newspaper that has sunglasses on top of it sitting in front of books.", "an apple sunglasses books and a teddy bear", "A folded newspaper and sunglasses are on a table with an apple, books, and teddy bear behind.", "An apple sitting on a table next to sunglasses and a news paper.", "There are sunglasses laying on the folded newspaper."], "instances": [{"category": "tie", "bbox": [0.258, 0.074, 0.527, 0.589]}, {"category": "apple", "bbox": [0.621, 0.482, 0.853, 0.645]}, {"category": "book", "bbox": [0.154, 0.107, 0.275, 0.59]}, {"category": "book", "bbox": [0.535, 0.09, 0.735, 0.583]}, {"category": "book", "bbox": [0.051, 0.112, 0.159, 0.6]}, {"category": "teddy bear", "bbox": [0.753, 0.084, 1.0, 0.517]}, {"category": "book", "bbox": [0.681, 0.097, 0.796, 0.483]}, {"category": "book", "bbox": [0.443, 0.099, 0.574, 0.588]}, {"category": "book", "bbox": [0.267, 0.337, 0.386, 0.579]}]}
|
| 3 |
+
{"id": "000000052312", "image": "COCO_val2014_000000052312.jpg", "captions": ["The old man literally has a toothbrush mustache.", "An old man with a tooth brush head under his nose, mimicking Hitler", "A man wearing a toothbrush for a moustache.", "A man with the head of a toothbrush under his nose like a mustache", "An elderly man wearing the head of a toothbrush as a moustache."], "instances": [{"category": "toothbrush", "bbox": [0.345, 0.59, 0.594, 0.679]}, {"category": "person", "bbox": [0.0, 0.03, 1.0, 0.99]}]}
|
| 4 |
+
{"id": "000000473210", "image": "COCO_val2014_000000473210.jpg", "captions": ["two people taking apart their wii controllers to replace batteries", "People taking apart video game remote controls on a table", "People handling a couple of remotes taking them apart.", "two sets of hands a wooden table and two controllers", "Two people who are taking apart a video game controller."], "instances": [{"category": "person", "bbox": [0.002, 0.334, 0.453, 0.986]}, {"category": "remote", "bbox": [0.407, 0.207, 0.727, 0.604]}, {"category": "remote", "bbox": [0.088, 0.344, 0.313, 0.547]}, {"category": "laptop", "bbox": [0.001, 0.049, 0.1, 0.197]}, {"category": "person", "bbox": [0.484, 0.254, 0.998, 0.985]}, {"category": "dining table", "bbox": [0.0, 0.003, 1.0, 0.956]}]}
|
| 5 |
+
{"id": "000000097131", "image": "COCO_val2014_000000097131.jpg", "captions": ["A car parked by a parking meter in front of a building.", "A car is sitting parked at a curb in front of a parking meter.", "A black car on the street next to a parking meter.", "A gray car parked in front of two parking meters.", "A black car parked on the side of the road."], "instances": [{"category": "car", "bbox": [0.227, 0.362, 0.946, 0.761]}, {"category": "car", "bbox": [0.793, 0.322, 0.88, 0.4]}, {"category": "car", "bbox": [0.0, 0.447, 0.028, 0.726]}, {"category": "parking meter", "bbox": [0.156, 0.35, 0.186, 0.453]}, {"category": "truck", "bbox": [0.907, 0.331, 1.0, 0.408]}, {"category": "parking meter", "bbox": [0.188, 0.349, 0.218, 0.448]}]}
|
| 6 |
+
{"id": "000000543364", "image": "COCO_val2014_000000543364.jpg", "captions": ["There is a table in the middle of the room.", "A room with a couch, table, lamp and a chaise.", "A living room with couch, chaise, track lighting, and a large window.", "A room with large windows, a couch and a table.", "A living room with lots of furniture and a large window."], "instances": [{"category": "dining table", "bbox": [0.388, 0.644, 0.636, 0.879]}, {"category": "couch", "bbox": [0.194, 0.531, 0.552, 0.777]}, {"category": "couch", "bbox": [0.568, 0.488, 0.907, 0.783]}, {"category": "remote", "bbox": [0.524, 0.651, 0.556, 0.675]}, {"category": "chair", "bbox": [0.661, 0.478, 0.802, 0.604]}]}
|
| 7 |
+
{"id": "000000217181", "image": "COCO_val2014_000000217181.jpg", "captions": ["They are standing next to some stylish motorcycles.", "Three men are standing around looking at sports motorcycles.", "A small group of men are standing around a motorcycle.", "Two men surrounding a blue motorcycle and others", "A few blue motorcycles are parked in a lot."], "instances": [{"category": "car", "bbox": [0.011, 0.177, 0.2, 0.336]}, {"category": "motorcycle", "bbox": [0.032, 0.139, 0.907, 0.982]}, {"category": "motorcycle", "bbox": [0.0, 0.239, 0.148, 0.613]}, {"category": "motorcycle", "bbox": [0.0, 0.301, 0.106, 0.45]}, {"category": "person", "bbox": [0.775, 0.043, 0.93, 0.463]}, {"category": "person", "bbox": [0.717, 0.116, 0.81, 0.509]}, {"category": "person", "bbox": [0.296, 0.008, 0.472, 0.325]}, {"category": "person", "bbox": [0.115, 0.19, 0.164, 0.269]}, {"category": "truck", "bbox": [0.63, 0.227, 0.731, 0.335]}]}
|
| 8 |
+
{"id": "000000140289", "image": "COCO_val2014_000000140289.jpg", "captions": ["Two born bears walking though a forest surrounded by trees.", "Two full grown brown bears in a habitat.", "Two bears are roaming around in the woods.", "Two bears around logs in front of a large rock.", "Two big bears wandering through the woods together"], "instances": [{"category": "bear", "bbox": [0.131, 0.269, 0.375, 0.65]}, {"category": "bear", "bbox": [0.568, 0.193, 0.809, 0.827]}]}
|
| 9 |
+
{"id": "000000460149", "image": "COCO_val2014_000000460149.jpg", "captions": ["A clock hosted on a pole on a pavement next to a building", "Street clock on quiet street with trees and bicycles.", "A tall clock stands on an empty sidewalk.", "A pole that has a clock on the top of it.", "a clock on a short tower and potted plants along the sidewalk"], "instances": [{"category": "potted plant", "bbox": [0.14, 0.71, 0.338, 0.856]}, {"category": "bicycle", "bbox": [0.65, 0.671, 0.766, 0.733]}, {"category": "car", "bbox": [0.38, 0.608, 0.488, 0.656]}, {"category": "clock", "bbox": [0.468, 0.048, 0.699, 0.216]}, {"category": "bicycle", "bbox": [0.669, 0.662, 0.719, 0.67]}, {"category": "car", "bbox": [0.786, 0.625, 0.86, 0.668]}, {"category": "potted plant", "bbox": [0.756, 0.637, 0.819, 0.682]}, {"category": "person", "bbox": [0.942, 0.615, 0.954, 0.641]}, {"category": "bicycle", "bbox": [0.648, 0.68, 0.714, 0.747]}, {"category": "car", "bbox": [0.837, 0.619, 0.88, 0.659]}, {"category": "potted plant", "bbox": [0.017, 0.197, 0.443, 0.686]}]}
|
| 10 |
+
{"id": "000000225738", "image": "COCO_val2014_000000225738.jpg", "captions": ["A group of giraffes standing up in their natural habitat.", "A group of giraffe standing in a grass field.", "A group of four giraffes near the same tree.", "there are four giraffes standing among some dry brush", "A herd of giraffe standing on top of a grass field."], "instances": [{"category": "giraffe", "bbox": [0.648, 0.231, 0.855, 0.915]}, {"category": "giraffe", "bbox": [0.33, 0.136, 0.521, 0.93]}, {"category": "giraffe", "bbox": [0.406, 0.261, 0.515, 1.0]}, {"category": "giraffe", "bbox": [0.347, 0.194, 0.583, 0.922]}]}
|
| 11 |
+
{"id": "000000109532", "image": "COCO_val2014_000000109532.jpg", "captions": ["An adorable husky dog sleeping in a dog bed next to a fan.", "A dark room with a dog sleeping on a dog bed.", "A dog is sleeping in a dark room.", "a large dog laying in a dog bed in a living room", "A dog sleeping on a dog bed in a room."], "instances": [{"category": "dog", "bbox": [0.426, 0.661, 0.582, 0.925]}, {"category": "potted plant", "bbox": [0.603, 0.261, 0.781, 0.613]}, {"category": "chair", "bbox": [0.67, 0.515, 0.899, 0.801]}, {"category": "potted plant", "bbox": [0.671, 0.439, 0.763, 0.612]}, {"category": "chair", "bbox": [0.852, 0.653, 0.948, 0.818]}]}
|
| 12 |
+
{"id": "000000118606", "image": "COCO_val2014_000000118606.jpg", "captions": ["A man riding skis on top of a rail.", "a person riding a pair of skis on a rail", "Someone on a pair of skis on a ramp at the ski slope", "Person with skis in the air above the snow.", "A man performing a trick on a rail while skiing."], "instances": [{"category": "person", "bbox": [0.444, 0.361, 0.537, 0.633]}, {"category": "skis", "bbox": [0.413, 0.554, 0.539, 0.664]}, {"category": "person", "bbox": [0.342, 0.585, 0.352, 0.62]}, {"category": "person", "bbox": [0.439, 0.565, 0.446, 0.58]}]}
|
| 13 |
+
{"id": "000000385873", "image": "COCO_val2014_000000385873.jpg", "captions": ["Three pizzas sitting next to each other in boxes.", "Two smaller pizzas sit beside a large pizza topped with tortilla chips.", "Three pizzas inside their delivery boxes, one with two side orders of sauce.", "One pizza is larger than two other pizzas.", "Three pizza boxes with pizza in them are open."], "instances": [{"category": "bowl", "bbox": [0.634, 0.624, 0.736, 0.752]}, {"category": "pizza", "bbox": [0.3, 0.382, 0.615, 0.733]}, {"category": "pizza", "bbox": [0.0, 0.4, 0.287, 0.745]}, {"category": "pizza", "bbox": [0.624, 0.279, 0.999, 0.753]}, {"category": "bowl", "bbox": [0.94, 0.247, 1.0, 0.352]}]}
|
| 14 |
+
{"id": "000000092109", "image": "COCO_val2014_000000092109.jpg", "captions": ["A giraffe's head is pictured in this clear, colorful photo.", "A giraffe is standing tall in the middle of several bright green trees", "The face of a giraffe looking to the side.", "the close up head shot of a giraffe", "this is a giraffe chewing on some leaves"], "instances": [{"category": "giraffe", "bbox": [0.236, 0.122, 1.0, 0.987]}]}
|
| 15 |
+
{"id": "000000163076", "image": "COCO_val2014_000000163076.jpg", "captions": ["There's an outdoor dining area featuring a fountain.", "A table sitting next to a water fountain covered by an umbrella.", "An empty restaurant patio with tables and umbrellas.", "An outdoor restaurant with a fountain at night", "A fountain bubbles in the plaza of an outdoor cafe."], "instances": [{"category": "umbrella", "bbox": [0.064, 0.069, 0.95, 0.844]}, {"category": "chair", "bbox": [0.198, 0.574, 0.355, 0.704]}, {"category": "chair", "bbox": [0.42, 0.571, 0.55, 0.738]}, {"category": "dining table", "bbox": [0.066, 0.741, 0.766, 0.925]}, {"category": "dining table", "bbox": [0.059, 0.584, 0.27, 0.659]}, {"category": "chair", "bbox": [0.432, 0.567, 0.52, 0.624]}, {"category": "chair", "bbox": [0.433, 0.555, 0.504, 0.6]}, {"category": "chair", "bbox": [0.109, 0.673, 0.374, 0.796]}]}
|
| 16 |
+
{"id": "000000560371", "image": "COCO_val2014_000000560371.jpg", "captions": ["Street signs from the corner of 8th ave. and 22 3/4 st.", "A two way street sign with one sign that changes from one name to another.", "A street sign is pointing towards 8th avenue and the other is pointing towards 22 3/4 street in the middle of the forest.", "A street sign standing in front of some trees.", "Peculiar street sign showing intersection of 23 3/4 St and 8th Ave/CTH D."], "instances": []}
|
| 17 |
+
{"id": "000000367571", "image": "COCO_val2014_000000367571.jpg", "captions": ["A couple of different doughnuts in a box", "There are four donuts in a box, and some are cake donuts and a doughnut with nuts and coconut on top.", "A box of glazed doughnuts on a table.", "Three donuts with toppings on them sitting inside a box.", "A box that is filled with different kinds of doughnuts."], "instances": [{"category": "donut", "bbox": [0.412, 0.335, 0.711, 0.681]}, {"category": "donut", "bbox": [0.093, 0.493, 0.486, 0.922]}, {"category": "donut", "bbox": [0.713, 0.423, 0.957, 0.874]}, {"category": "donut", "bbox": [0.13, 0.331, 0.397, 0.55]}]}
|
| 18 |
+
{"id": "000000580197", "image": "COCO_val2014_000000580197.jpg", "captions": ["Two men in bow ties standing next to steel rafter.", "Several men in suits talking together in a room.", "An older man in a tuxedo standing next to a younger man in a tuxedo wearing glasses.", "Two men wearing tuxedos glance at each other.", "Older man in tuxedo sitting next to another younger man in tuxedo."], "instances": [{"category": "tie", "bbox": [0.914, 0.46, 0.984, 0.512]}, {"category": "person", "bbox": [0.297, 0.638, 0.71, 0.989]}, {"category": "person", "bbox": [0.77, 0.177, 1.0, 0.971]}, {"category": "tie", "bbox": [0.281, 0.481, 0.368, 0.519]}, {"category": "person", "bbox": [0.103, 0.204, 0.497, 1.0]}]}
|
| 19 |
+
{"id": "000000506095", "image": "COCO_val2014_000000506095.jpg", "captions": ["A cat is staring at a laptop computer.", "a cat on a desk with a laptop and a mouse", "A cat that is sitting at a desk next to a laptop.", "A kitten sitting on a laptop computer sitting on top of a wooden desk.", "A kitten sits facing an open black laptop."], "instances": [{"category": "cat", "bbox": [0.658, 0.207, 1.0, 0.754]}, {"category": "laptop", "bbox": [0.108, 0.135, 0.766, 0.69]}, {"category": "book", "bbox": [0.836, 0.239, 0.954, 0.273]}, {"category": "book", "bbox": [0.0, 0.556, 0.128, 0.685]}, {"category": "book", "bbox": [0.039, 0.574, 0.257, 0.691]}, {"category": "book", "bbox": [0.825, 0.214, 0.962, 0.254]}, {"category": "book", "bbox": [0.892, 0.275, 0.958, 0.308]}, {"category": "book", "bbox": [0.922, 0.318, 0.986, 0.353]}, {"category": "book", "bbox": [0.87, 0.267, 0.951, 0.291]}, {"category": "book", "bbox": [0.949, 0.102, 0.976, 0.114]}, {"category": "book", "bbox": [0.936, 0.161, 0.958, 0.168]}]}
|
| 20 |
+
{"id": "000000024996", "image": "COCO_val2014_000000024996.jpg", "captions": ["A bathroom with a glass door and a sink.", "A blue lined bathroom with an open glass door.", "A nice bathroom with a sink, toilet, and tiled shower.", "A bathroom that is clean and shiny in the day.", "a bathroom with a sink and a mirror and a window"], "instances": [{"category": "toilet", "bbox": [0.842, 0.934, 0.95, 1.0]}, {"category": "sink", "bbox": [0.506, 0.724, 0.683, 0.834]}]}
|
| 21 |
+
{"id": "000000457882", "image": "COCO_val2014_000000457882.jpg", "captions": ["a girl in a bikini and a brown and white dog and a few other people", "A woman with a swimsuit on sitting with a dog.", "A woman is sitting with a dog on her lap.", "A dog sitting next to a woman in her swimsuit.", "WOMAN SITTING WITH HER DOG, AND OTHER WOMEN ARE AROUND"], "instances": [{"category": "dog", "bbox": [0.202, 0.409, 0.54, 0.81]}, {"category": "dog", "bbox": [0.61, 0.428, 0.729, 0.723]}, {"category": "boat", "bbox": [0.003, 0.705, 0.939, 0.974]}, {"category": "person", "bbox": [0.236, 0.001, 0.558, 0.784]}, {"category": "person", "bbox": [0.681, 0.001, 0.957, 0.798]}, {"category": "person", "bbox": [0.849, 0.478, 1.0, 0.946]}, {"category": "person", "bbox": [0.345, 0.187, 0.634, 0.828]}, {"category": "person", "bbox": [0.033, 0.345, 0.109, 0.434]}]}
|
| 22 |
+
{"id": "000000081552", "image": "COCO_val2014_000000081552.jpg", "captions": ["A cat sitting and curled up on a red couch", "A cat laying on a red couch sleeping.", "a tan and black cat curled up asleep on a red velvet seat", "A cat is curled up on a red sofa.", "Cat curled up, sleeping on a red plush couch."], "instances": [{"category": "cat", "bbox": [0.412, 0.237, 0.634, 0.482]}, {"category": "couch", "bbox": [0.003, 0.005, 1.0, 0.99]}]}
|
| 23 |
+
{"id": "000000273450", "image": "COCO_val2014_000000273450.jpg", "captions": ["A person flipping of a parking meter on the side of a road.", "A man holds up his middle finger to a parking meter.", "Person giving the middle finger to a parking meter.", "a black silver white blue red an orange parking meter and a hand flipping it off", "A person is flipping off a parking meter."], "instances": [{"category": "person", "bbox": [0.0, 0.475, 0.565, 0.987]}, {"category": "car", "bbox": [0.0, 0.0, 0.531, 0.734]}, {"category": "parking meter", "bbox": [0.0, 0.0, 1.0, 0.987]}]}
|
| 24 |
+
{"id": "000000203879", "image": "COCO_val2014_000000203879.jpg", "captions": ["There is a small cellphone displayed between a set of ear buds and two paper weights.", "a cell phone lays next to some diamonds", "a close up of a cell phone on a table near earbuds", "A cell phone sits on a table next to some jewels.", "A cell phone, ear buds, and two jewels laying near each other."], "instances": [{"category": "cell phone", "bbox": [0.322, 0.233, 0.62, 0.79]}]}
|
| 25 |
+
{"id": "000000346875", "image": "COCO_val2014_000000346875.jpg", "captions": ["two zebras in a field near one another", "A couple of zebra walking across a green field.", "Two zebra are walking near a gravel road.", "two zebras in a green field of grass and some trees", "A zebra follows another zebra through a park."], "instances": [{"category": "zebra", "bbox": [0.591, 0.263, 0.82, 0.466]}, {"category": "zebra", "bbox": [0.293, 0.243, 0.561, 0.45]}]}
|
| 26 |
+
{"id": "000000525439", "image": "COCO_val2014_000000525439.jpg", "captions": ["a man stands in front of a flipped skate boarder", "A man standing next to a skateboard that is laying on the ground wheels pointed up.", "Skateboard laying upside down on cement with someone standing next to it.", "A boy in camo shorts stands before an overturned skateboard.", "a person with an upside down skate board"], "instances": [{"category": "person", "bbox": [0.307, 0.001, 0.63, 0.739]}, {"category": "skateboard", "bbox": [0.0, 0.592, 0.626, 0.969]}]}
|
| 27 |
+
{"id": "000000304749", "image": "COCO_val2014_000000304749.jpg", "captions": ["The woman is taking a picture in the bathroom mirror.", "A picture of a woman in a mirror.", "A woman's midsection reflected in a round mirror.", "A circular mirror reflecting a woman's stomach in turquoise shirt.", "A selfie taken of a person from the neck down."], "instances": [{"category": "person", "bbox": [0.092, 0.001, 0.646, 0.496]}]}
|
| 28 |
+
{"id": "000000323760", "image": "COCO_val2014_000000323760.jpg", "captions": ["A toilet is shown in a bare room.", "A ugly bathroom with a section of the wall missing.", "A toilet in a stripped bathroom with studs, bricks and plaster showing", "A bathroom with no walls and a toilet bowl", "A white toilet next to some torn out walls."], "instances": [{"category": "toilet", "bbox": [0.167, 0.585, 0.714, 1.0]}]}
|
| 29 |
+
{"id": "000000066144", "image": "COCO_val2014_000000066144.jpg", "captions": ["A woman standing in front of window next to a bug and a stop sign.", "A car parked on the street next to a tree and stop sign.", "A lone Volkswagen is parked by a stop sign.", "A window view of a small car near a street stop sign.", "An old VW Bug standing at a stop sign."], "instances": [{"category": "stop sign", "bbox": [0.501, 0.328, 0.569, 0.428]}, {"category": "car", "bbox": [0.242, 0.488, 0.56, 0.726]}, {"category": "car", "bbox": [0.279, 0.325, 0.33, 0.363]}, {"category": "car", "bbox": [0.153, 0.333, 0.29, 0.405]}, {"category": "car", "bbox": [0.11, 0.339, 0.177, 0.373]}, {"category": "car", "bbox": [0.0, 0.654, 0.082, 0.826]}, {"category": "car", "bbox": [0.0, 0.322, 0.064, 0.364]}, {"category": "car", "bbox": [0.451, 0.333, 0.51, 0.392]}]}
|
| 30 |
+
{"id": "000000455772", "image": "COCO_val2014_000000455772.jpg", "captions": ["A person in a field jumping to catch a Frisbee.", "A guy jumping to catch a frisbee in mid-air.", "A person that is trying to get a frisbee.", "Nice reach, but the Frisbee flies on, victorious.", "A man playing frisbee in a grassy yard."], "instances": [{"category": "car", "bbox": [0.148, 0.339, 0.201, 0.476]}, {"category": "car", "bbox": [0.376, 0.396, 0.424, 0.476]}, {"category": "person", "bbox": [0.547, 0.122, 0.698, 0.904]}, {"category": "frisbee", "bbox": [0.479, 0.154, 0.555, 0.231]}, {"category": "car", "bbox": [0.001, 0.299, 0.085, 0.394]}]}
|
| 31 |
+
{"id": "000000511117", "image": "COCO_val2014_000000511117.jpg", "captions": ["A couple of kids standing on top of a grass covered field.", "A little boy wearing a baseball uniform stands by a little girl.", "A young boy in a baseball uniform and a young girl are standing in front of a chain link fence.", "A little boy and girl standing on a baseball field. The boy has a uniform on.", "A young baseball player is standing next to a young girl."], "instances": [{"category": "person", "bbox": [0.514, 0.178, 0.776, 0.774]}, {"category": "baseball glove", "bbox": [0.468, 0.462, 0.593, 0.609]}, {"category": "person", "bbox": [0.174, 0.051, 0.598, 0.839]}, {"category": "bench", "bbox": [0.558, 0.125, 1.0, 0.315]}]}
|
| 32 |
+
{"id": "000000207151", "image": "COCO_val2014_000000207151.jpg", "captions": ["A vegetarian pizza is half eaten on a pizza holder.", "A couple of pieces of pizza with vegetable slices on them.", "A wooden pan serving tray with a pizza on it.", "A pizza on a cutting board is half gone.", "A Pizza is nearly finished with only three pieces left."], "instances": [{"category": "bottle", "bbox": [0.001, 0.001, 0.121, 0.231]}, {"category": "cup", "bbox": [0.0, 0.002, 0.121, 0.238]}, {"category": "pizza", "bbox": [0.17, 0.472, 0.526, 0.82]}, {"category": "pizza", "bbox": [0.398, 0.106, 0.962, 0.679]}, {"category": "dining table", "bbox": [0.0, 0.001, 1.0, 0.988]}]}
|
| 33 |
+
{"id": "000000431165", "image": "COCO_val2014_000000431165.jpg", "captions": ["A baby elephant standing in front of a brick building.", "An elephant is standing near a dirt mount in an exhibit.", "Grey elephant standing next to a large sand dune in a pen.", "An elephant standing alone inside of an enclosure.", "The baby elephant is alone in the pen."], "instances": [{"category": "elephant", "bbox": [0.303, 0.399, 0.638, 0.78]}]}
|
| 34 |
+
{"id": "000000378545", "image": "COCO_val2014_000000378545.jpg", "captions": ["A pole that has a clock on top of it.", "A clock mounted on an outdoor post with Roman numerals.", "a clock on a pole saying it is 12:45", "An ornamental standing clock is at the foreground of a row of houses.", "A black and gold clock on a pole in front of a building."], "instances": [{"category": "clock", "bbox": [0.216, 0.249, 0.749, 0.658]}]}
|
| 35 |
+
{"id": "000000555904", "image": "COCO_val2014_000000555904.jpg", "captions": ["A man sitting at a bar filled with liquor.", "People sitting a a take near several bottles of wine on shelves.", "Several people are sitting at a table drinking.", "Several people in a bar sitting at a long table.", "People eating in a restaurant near wine bottles."], "instances": [{"category": "dining table", "bbox": [0.123, 0.663, 0.317, 0.811]}, {"category": "person", "bbox": [0.715, 0.239, 1.0, 0.998]}, {"category": "person", "bbox": [0.142, 0.528, 0.281, 0.742]}, {"category": "person", "bbox": [0.529, 0.53, 0.606, 0.69]}, {"category": "person", "bbox": [0.705, 0.518, 0.796, 0.673]}, {"category": "wine glass", "bbox": [0.247, 0.669, 0.27, 0.718]}, {"category": "person", "bbox": [0.281, 0.524, 0.534, 1.0]}, {"category": "bottle", "bbox": [0.168, 0.346, 0.189, 0.425]}, {"category": "bottle", "bbox": [0.379, 0.264, 0.431, 0.433]}, {"category": "bottle", "bbox": [0.252, 0.313, 0.277, 0.429]}, {"category": "bottle", "bbox": [0.294, 0.295, 0.326, 0.43]}, {"category": "bottle", "bbox": [0.589, 0.35, 0.613, 0.444]}, {"category": "bottle", "bbox": [0.433, 0.281, 0.473, 0.437]}, {"category": "bottle", "bbox": [0.478, 0.289, 0.513, 0.44]}, {"category": "wine glass", "bbox": [0.688, 0.615, 0.709, 0.69]}, {"category": "cup", "bbox": [0.589, 0.647, 0.612, 0.693]}, {"category": "person", "bbox": [0.732, 0.356, 0.953, 0.806]}, {"category": "bottle", "bbox": [0.555, 0.337, 0.585, 0.438]}, {"category": "bottle", "bbox": [0.337, 0.29, 0.378, 0.432]}, {"category": "bottle", "bbox": [0.21, 0.333, 0.232, 0.426]}, {"category": "bottle", "bbox": [0.134, 0.36, 0.148, 0.422]}, {"category": "bottle", "bbox": [0.516, 0.312, 0.557, 0.439]}, {"category": "cup", "bbox": [0.231, 0.718, 0.26, 0.763]}, {"category": "chair", "bbox": [0.517, 0.828, 0.65, 0.999]}, {"category": "chair", "bbox": [0.643, 0.804, 0.738, 0.841]}, {"category": "chair", "bbox": [0.347, 0.908, 0.519, 1.0]}, {"category": "chair", "bbox": [0.64, 0.806, 0.74, 0.998]}, {"category": "cup", "bbox": [0.205, 0.692, 0.232, 0.767]}, {"category": "dining table", "bbox": [0.536, 0.676, 0.743, 0.838]}, {"category": "person", "bbox": [0.002, 0.501, 0.263, 0.987]}, {"category": "bottle", "bbox": [0.531, 0.461, 0.542, 0.526]}, {"category": "bottle", "bbox": [0.237, 0.354, 0.702, 0.629]}]}
|
| 36 |
+
{"id": "000000415393", "image": "COCO_val2014_000000415393.jpg", "captions": ["a man on a skate board looks like he is falling", "A man does a skateboard trick on a skateboard ramp", "Guy falling off a skateboard in a room.", "A man riding a skateboard on top of a table.", "a man skating on part of a ramp with his skateboard"], "instances": [{"category": "person", "bbox": [0.361, 0.016, 0.809, 0.888]}, {"category": "skateboard", "bbox": [0.606, 0.809, 0.889, 0.901]}, {"category": "person", "bbox": [0.479, 0.091, 0.576, 0.386]}, {"category": "person", "bbox": [0.047, 0.441, 0.197, 0.759]}, {"category": "person", "bbox": [0.038, 0.453, 0.076, 0.545]}, {"category": "person", "bbox": [0.249, 0.307, 0.311, 0.591]}]}
|
| 37 |
+
{"id": "000000161011", "image": "COCO_val2014_000000161011.jpg", "captions": ["Three skiers posing for a picture on the slope.", "Three skiers pause for a photo at the top of a mountain.", "Three people standing on a mountain taking a picture as they ski.", "A woman and two men on skis on a snowy hillside surrounded by trees", "Three skiers have stopped to pose for a picture."], "instances": [{"category": "person", "bbox": [0.36, 0.321, 0.509, 0.82]}, {"category": "person", "bbox": [0.179, 0.281, 0.349, 0.795]}, {"category": "person", "bbox": [0.611, 0.292, 0.751, 0.809]}, {"category": "skis", "bbox": [0.595, 0.743, 0.732, 0.961]}, {"category": "skis", "bbox": [0.341, 0.724, 0.621, 0.907]}, {"category": "skis", "bbox": [0.212, 0.705, 0.398, 0.905]}]}
|
| 38 |
+
{"id": "000000284296", "image": "COCO_val2014_000000284296.jpg", "captions": ["Three giraffe's leaning over to get a sip of water.", "an image of a herd of giraffes in the water", "three giraffes banding down to drink water with trees in the background", "Three giraffe drinking from a pond with brush in back.", "Giraffes leaning down to drink at a watering hole"], "instances": [{"category": "giraffe", "bbox": [0.624, 0.387, 0.822, 0.635]}, {"category": "giraffe", "bbox": [0.4, 0.326, 0.561, 0.58]}, {"category": "giraffe", "bbox": [0.152, 0.291, 0.343, 0.551]}]}
|
| 39 |
+
{"id": "000000056013", "image": "COCO_val2014_000000056013.jpg", "captions": ["a number of luggage bags on a cart in a lobby", "Wheeled cart with luggage at lobby of commercial business.", "Trolley used for transporting personal luggage to guests rooms.", "A luggage cart topped with lots of luggage.", "a cart filled with suitcases and bags"], "instances": [{"category": "backpack", "bbox": [0.276, 0.52, 0.456, 0.678]}, {"category": "suitcase", "bbox": [0.41, 0.58, 0.597, 0.827]}, {"category": "suitcase", "bbox": [0.173, 0.645, 0.363, 0.836]}, {"category": "person", "bbox": [0.959, 0.297, 1.0, 0.478]}, {"category": "suitcase", "bbox": [0.526, 0.519, 0.712, 0.706]}, {"category": "person", "bbox": [0.762, 0.253, 0.871, 0.46]}, {"category": "backpack", "bbox": [0.517, 0.514, 0.694, 0.698]}, {"category": "handbag", "bbox": [0.316, 0.181, 0.431, 0.426]}, {"category": "suitcase", "bbox": [0.747, 0.453, 0.858, 0.557]}]}
|
| 40 |
+
{"id": "000000293505", "image": "COCO_val2014_000000293505.jpg", "captions": ["A person on a motor bike next to a cow.", "A woman riding a motorcycle down a dirt road.", "there is a woman riding a scooter down a dirt road", "A woman on a moped, two men and animals walking down the road.", "A woman on a motorcycle is next to a man walking a dog along with other people going down a dirt road."], "instances": [{"category": "cow", "bbox": [0.602, 0.472, 0.721, 0.816]}, {"category": "motorcycle", "bbox": [0.402, 0.512, 0.516, 0.788]}, {"category": "person", "bbox": [0.408, 0.4, 0.514, 0.639]}, {"category": "person", "bbox": [0.754, 0.301, 1.0, 1.0]}, {"category": "person", "bbox": [0.705, 0.415, 0.789, 0.714]}, {"category": "cow", "bbox": [0.347, 0.44, 0.373, 0.509]}, {"category": "cow", "bbox": [0.361, 0.436, 0.381, 0.501]}]}
|
| 41 |
+
{"id": "000000305873", "image": "COCO_val2014_000000305873.jpg", "captions": ["A little girl holding a red black dotted umbrella.", "A little girl with rain boots and a rain jacket on and an open umbrella to match her jacket.", "a little girl holding onto a lady bug pattern umbrella", "The child wears a labybug rain coat with a matching umbrella.", "A little girl wearing a ladybug raincoat and green rubber boots holding a ladybug umbrella"], "instances": [{"category": "umbrella", "bbox": [0.246, 0.002, 0.992, 0.415]}, {"category": "person", "bbox": [0.35, 0.132, 0.699, 0.791]}, {"category": "car", "bbox": [0.614, 0.0, 1.0, 0.465]}]}
|
| 42 |
+
{"id": "000000034096", "image": "COCO_val2014_000000034096.jpg", "captions": ["A house being built with lots of wood.", "A big pile of building material is placed on the floor in the wooden structure.", "A partially-built house with wooden studs and staircase in view.", "A house full of wood getting built at the moment.", "The beginning stages of a home still being made."], "instances": [{"category": "bed", "bbox": [0.505, 0.42, 0.721, 0.59]}, {"category": "tv", "bbox": [0.192, 0.441, 0.335, 0.606]}]}
|
| 43 |
+
{"id": "000000165257", "image": "COCO_val2014_000000165257.jpg", "captions": ["A large black counter top sitting next to a sink.", "a clean kitchen counter with a clean sink", "A kitchen with a sink, dishwasher and some boxes on the counter.", "A kitchen with a sink, dishwasher and boxes on the counter.", "a black counter on a wood cabinet in a kitchen", "a new kitchen cabinet with a sink being installed"], "instances": [{"category": "sink", "bbox": [0.513, 0.243, 0.718, 0.314]}]}
|
| 44 |
+
{"id": "000000431026", "image": "COCO_val2014_000000431026.jpg", "captions": ["a street sign on a city street near some tall bushes", "street signs on a metal pole lining a sidewalk lined with shrubbery.", "a large hedge of bushes on a corner near a street sign.", "Two street signs on sidewalk next to bushes and trees.", "Street signs along a well manicured street with large houses."], "instances": []}
|
| 45 |
+
{"id": "000000524575", "image": "COCO_val2014_000000524575.jpg", "captions": ["Three giraffe and a wildebeest in a field.", "A moose and several giraffes are grazing in the field.", "Zebras in the wild with a wildebeest behind them", "Two giraffe and a ox standing in a field eating grass.", "Giraffes and other safari animals graze in a sunlit field."], "instances": [{"category": "cow", "bbox": [0.46, 0.716, 0.643, 0.999]}, {"category": "giraffe", "bbox": [0.285, 0.5, 0.401, 0.826]}, {"category": "giraffe", "bbox": [0.083, 0.554, 0.179, 0.821]}, {"category": "giraffe", "bbox": [0.887, 0.481, 0.968, 0.715]}]}
|
| 46 |
+
{"id": "000000326550", "image": "COCO_val2014_000000326550.jpg", "captions": ["Black and white photograph of a person holding a surfboard by water.", "A person with a surfboard standing next to the water.", "A surfer stands on the rocks watching a wave crash.", "A man standing on a beach holding a surfboard.", "a person looking at the waves ready to surf"], "instances": [{"category": "person", "bbox": [0.327, 0.461, 0.492, 0.897]}, {"category": "surfboard", "bbox": [0.282, 0.56, 0.606, 0.741]}, {"category": "person", "bbox": [0.924, 0.352, 0.933, 0.362]}, {"category": "person", "bbox": [0.912, 0.348, 0.919, 0.36]}]}
|
| 47 |
+
{"id": "000000018476", "image": "COCO_val2014_000000018476.jpg", "captions": ["A tie that is sitting on top of a shirt.", "This photograph appears to be looking truly wonderful.", "a uniform complete with shoes laying on a bed", "Suit laid out with a red tie, white shirt and black shoes.", "a white shirt a red tie and some black shoes"], "instances": [{"category": "tie", "bbox": [0.457, 0.09, 0.853, 0.984]}, {"category": "bed", "bbox": [0.005, 0.005, 1.0, 0.379]}]}
|
| 48 |
+
{"id": "000000480652", "image": "COCO_val2014_000000480652.jpg", "captions": ["These suitcases are sitting next to a chair.", "An assortment of luggage bags stacked by a kitchen chair.", "A stack of luggage by a chair and table.", "a table and chair with several pieces of luggage nearby", "A pile of luggage sitting on the floor."], "instances": [{"category": "chair", "bbox": [0.483, 0.192, 1.0, 0.769]}, {"category": "backpack", "bbox": [0.433, 0.429, 0.742, 0.856]}, {"category": "suitcase", "bbox": [0.059, 0.414, 0.453, 0.841]}, {"category": "handbag", "bbox": [0.19, 0.184, 0.779, 0.475]}, {"category": "suitcase", "bbox": [0.175, 0.204, 0.583, 0.462]}]}
|
| 49 |
+
{"id": "000000012748", "image": "COCO_val2014_000000012748.jpg", "captions": ["A man and child next to a horse.", "a little boy touching the nose of a brown horse", "A man holding a baby whose petting a horse.", "a man letting his baby pet a horse", "man holding a baby and petting a horse"], "instances": [{"category": "horse", "bbox": [0.003, 0.079, 0.504, 0.868]}, {"category": "person", "bbox": [0.452, 0.294, 1.0, 0.989]}, {"category": "person", "bbox": [0.46, 0.217, 1.0, 0.988]}]}
|
| 50 |
+
{"id": "000000247840", "image": "COCO_val2014_000000247840.jpg", "captions": ["Large group of people standing outside a restaurant together.", "A dairy queen has people standing outside waiting", "an image of people standing outside and ice cream store", "Several people are lined up outside of a store.", "The front of a Dairy Queen restaurant with people entering the side."], "instances": [{"category": "fire hydrant", "bbox": [0.774, 0.674, 0.83, 0.807]}, {"category": "person", "bbox": [0.741, 0.465, 0.824, 0.755]}, {"category": "person", "bbox": [0.806, 0.471, 0.839, 0.722]}, {"category": "person", "bbox": [0.831, 0.499, 0.866, 0.726]}, {"category": "bench", "bbox": [0.061, 0.69, 0.219, 0.768]}, {"category": "handbag", "bbox": [0.859, 0.558, 0.877, 0.603]}, {"category": "person", "bbox": [0.719, 0.504, 0.75, 0.626]}, {"category": "potted plant", "bbox": [0.7, 0.648, 0.764, 0.743]}, {"category": "handbag", "bbox": [0.827, 0.548, 0.837, 0.577]}, {"category": "sandwich", "bbox": [0.359, 0.618, 0.417, 0.694]}]}
|
| 51 |
+
{"id": "000000399452", "image": "COCO_val2014_000000399452.jpg", "captions": ["a sandwhich sitting on a plate next to a glass of tea, bowl of soup", "a sandwich on a white plate a drink on a brown table", "A sandwich and chips sit on a white plate.", "a large plate of food with a glass of soda by it", "A sandwich sitting on top of a white plate next to a cup of coffee."], "instances": [{"category": "sandwich", "bbox": [0.175, 0.326, 0.605, 0.71]}, {"category": "cup", "bbox": [0.504, 0.024, 0.687, 0.419]}, {"category": "knife", "bbox": [0.742, 0.283, 0.857, 0.376]}, {"category": "spoon", "bbox": [0.618, 0.46, 0.797, 0.809]}, {"category": "fork", "bbox": [0.684, 0.254, 0.805, 0.395]}, {"category": "bowl", "bbox": [0.782, 0.366, 1.0, 0.62]}, {"category": "chair", "bbox": [0.202, 0.0, 0.671, 0.148]}, {"category": "dining table", "bbox": [0.002, 0.126, 0.996, 0.987]}]}
|
| 52 |
+
{"id": "000000515716", "image": "COCO_val2014_000000515716.jpg", "captions": ["A couple of women standing on either side of a man wearing glasses.", "Two women and a man are holding glasses up at a wine tasting.", "Three young adults holding wine glasses while standing at a bar.", "A group of people sit holding glasses and smiling at a table with several bottles.", "A group of people at a celebration having a taste of wine."], "instances": [{"category": "bottle", "bbox": [0.529, 0.604, 0.637, 0.908]}, {"category": "bottle", "bbox": [0.379, 0.398, 0.481, 0.892]}, {"category": "bottle", "bbox": [0.942, 0.464, 0.988, 0.653]}, {"category": "person", "bbox": [0.0, 0.126, 0.136, 0.811]}, {"category": "person", "bbox": [0.05, 0.093, 0.211, 0.471]}, {"category": "person", "bbox": [0.401, 0.031, 0.678, 0.683]}, {"category": "person", "bbox": [0.617, 0.191, 0.94, 0.858]}, {"category": "person", "bbox": [0.723, 0.098, 0.947, 0.564]}, {"category": "wine glass", "bbox": [0.634, 0.434, 0.697, 0.628]}, {"category": "wine glass", "bbox": [0.285, 0.346, 0.372, 0.558]}, {"category": "wine glass", "bbox": [0.522, 0.422, 0.583, 0.544]}, {"category": "handbag", "bbox": [0.704, 0.601, 1.0, 0.916]}, {"category": "person", "bbox": [0.944, 0.319, 0.999, 0.604]}, {"category": "bottle", "bbox": [0.921, 0.46, 0.953, 0.636]}, {"category": "person", "bbox": [0.116, 0.171, 0.41, 0.829]}]}
|
| 53 |
+
{"id": "000000116173", "image": "COCO_val2014_000000116173.jpg", "captions": ["The boy is on his surfboard in the water riding it.", "a young boy riding a boogie board in the water", "A boy riding surf board in the ocean.", "A young boy is riding a surfboard on a small wave.", "A young boy is surfing in the ocean."], "instances": [{"category": "person", "bbox": [0.485, 0.238, 0.702, 0.821]}, {"category": "person", "bbox": [0.866, 0.223, 0.921, 0.29]}, {"category": "person", "bbox": [0.752, 0.146, 0.775, 0.188]}, {"category": "surfboard", "bbox": [0.239, 0.758, 0.782, 0.846]}, {"category": "surfboard", "bbox": [0.853, 0.277, 0.981, 0.29]}, {"category": "surfboard", "bbox": [0.727, 0.169, 0.801, 0.198]}, {"category": "person", "bbox": [0.637, 0.194, 0.677, 0.261]}]}
|
| 54 |
+
{"id": "000000186013", "image": "COCO_val2014_000000186013.jpg", "captions": ["A beach scene includes many different kites flying in a cloudy sky.", "Kites being flown at the beach at twilight.", "A beach with flags in the ground and kites overhead in the sky.", "A beach with rows of flags in the sand and kites flying overhead.", "A beach filled with kites and wind sails next to the ocean."], "instances": [{"category": "kite", "bbox": [0.174, 0.4, 0.351, 0.483]}, {"category": "kite", "bbox": [0.144, 0.13, 0.273, 0.17]}, {"category": "kite", "bbox": [0.236, 0.269, 0.268, 0.294]}, {"category": "kite", "bbox": [0.464, 0.204, 0.598, 0.271]}, {"category": "kite", "bbox": [0.61, 0.304, 0.659, 0.342]}, {"category": "kite", "bbox": [0.545, 0.435, 0.565, 0.452]}, {"category": "kite", "bbox": [0.027, 0.558, 0.151, 0.59]}, {"category": "kite", "bbox": [0.93, 0.429, 0.973, 0.536]}, {"category": "kite", "bbox": [0.684, 0.36, 0.697, 0.374]}, {"category": "surfboard", "bbox": [0.393, 0.627, 0.446, 0.934]}, {"category": "person", "bbox": [0.959, 0.685, 0.984, 0.713]}, {"category": "person", "bbox": [0.919, 0.681, 0.94, 0.725]}, {"category": "person", "bbox": [0.8, 0.597, 0.805, 0.61]}, {"category": "person", "bbox": [0.079, 0.928, 0.116, 0.975]}, {"category": "kite", "bbox": [0.743, 0.307, 0.755, 0.319]}, {"category": "kite", "bbox": [0.78, 0.322, 0.795, 0.335]}, {"category": "kite", "bbox": [0.536, 0.526, 0.597, 0.617]}, {"category": "person", "bbox": [0.941, 0.694, 0.961, 0.726]}, {"category": "kite", "bbox": [0.575, 0.446, 0.594, 0.471]}]}
|
| 55 |
+
{"id": "000000015029", "image": "COCO_val2014_000000015029.jpg", "captions": ["A man holding a white frisbee standing on top of a field.", "A man is playing frisbee next to a tent.", "Guy at the park holding a frisbee with people in the back under a tent", "A man is holding a Frisbee standing in the grass.", "Young adult male holding a frisbee at an event."], "instances": [{"category": "frisbee", "bbox": [0.138, 0.359, 0.215, 0.587]}, {"category": "person", "bbox": [0.16, 0.002, 0.726, 0.995]}, {"category": "person", "bbox": [0.81, 0.73, 0.852, 0.825]}, {"category": "person", "bbox": [0.786, 0.749, 0.833, 0.814]}, {"category": "person", "bbox": [0.847, 0.743, 0.89, 0.804]}, {"category": "person", "bbox": [0.614, 0.749, 0.706, 0.936]}]}
|
| 56 |
+
{"id": "000000500565", "image": "COCO_val2014_000000500565.jpg", "captions": ["A woman holding a child wrapped in a towel brushing her teeth.", "A woman is holding a baby who is wrapped in a towel and holding a toothbrush", "A woman holding a little boy who is brushing his teeth.", "A baby with a toothbrush in his mouth while being held by a woman", "a close up of an adult holding a child brushing their teeth"], "instances": [{"category": "toothbrush", "bbox": [0.586, 0.66, 0.754, 0.821]}, {"category": "person", "bbox": [0.002, 0.007, 0.637, 0.991]}, {"category": "person", "bbox": [0.357, 0.196, 0.998, 0.984]}]}
|
| 57 |
+
{"id": "000000297323", "image": "COCO_val2014_000000297323.jpg", "captions": ["Two buses are parked against a curb in front of a building.", "Two automobiles parked on the side of a building.", "two tourist buses parked on street in front of old industrial building", "Two unique city buses stopped at a stop sign.", "Buses parked outside by a building and stop sign."], "instances": [{"category": "bus", "bbox": [0.7, 0.711, 0.92, 0.881]}, {"category": "person", "bbox": [0.936, 0.771, 0.972, 0.833]}, {"category": "stop sign", "bbox": [0.237, 0.666, 0.285, 0.728]}, {"category": "bus", "bbox": [0.334, 0.71, 0.678, 0.935]}, {"category": "truck", "bbox": [0.335, 0.72, 0.683, 0.934]}, {"category": "person", "bbox": [0.34, 0.791, 0.367, 0.834]}]}
|
| 58 |
+
{"id": "000000441147", "image": "COCO_val2014_000000441147.jpg", "captions": ["Two antique suitcases sit stacked one on top of the other.", "Two suitcases are stacked on each other and one is black while the other is brown and yellow.", "a close up of two luggage suit cases stacked on each other", "A stack of antique luggage is displayed with price tags.", "two suitcases made of leather and stacked on top of each other"], "instances": [{"category": "suitcase", "bbox": [0.167, 0.025, 0.989, 0.445]}, {"category": "suitcase", "bbox": [0.002, 0.31, 0.994, 0.996]}]}
|
| 59 |
+
{"id": "000000353536", "image": "COCO_val2014_000000353536.jpg", "captions": ["A table topped with plates and glasses with eating utensils..", "a fork is laying on a small white plate", "dirty dishes on a table, and a bottle of something.", "a table top with some dishes on top of it", "A table full of dirty dishes is pictured in this image."], "instances": [{"category": "dining table", "bbox": [0.0, 0.007, 0.998, 0.988]}, {"category": "bottle", "bbox": [0.554, 0.002, 0.768, 0.411]}, {"category": "cup", "bbox": [0.372, 0.011, 0.544, 0.427]}, {"category": "fork", "bbox": [0.442, 0.464, 0.818, 0.572]}, {"category": "fork", "bbox": [0.089, 0.233, 0.272, 0.456]}, {"category": "spoon", "bbox": [0.144, 0.218, 0.326, 0.413]}, {"category": "cup", "bbox": [0.688, 0.056, 0.812, 0.361]}]}
|
| 60 |
+
{"id": "000000416256", "image": "COCO_val2014_000000416256.jpg", "captions": ["A cat laying on the floor next to a keyboard.", "an orange and white cat is laying next to a keyboard and some wires", "A cat is laying next to a computer keyboard.", "a cat laying on a floor next to a keyboard", "A CAT LAYING ON THE FLOOR AMIDST A COMPUTER,SPEAKERS,CORDS"], "instances": [{"category": "cat", "bbox": [0.235, 0.23, 0.737, 0.639]}, {"category": "keyboard", "bbox": [0.243, 0.562, 0.631, 0.836]}, {"category": "keyboard", "bbox": [0.058, 0.33, 0.277, 0.608]}]}
|
| 61 |
+
{"id": "000000214367", "image": "COCO_val2014_000000214367.jpg", "captions": ["Wood shading on the side of a window with brick siding.", "A tree filled with lots of red fruit near a building.", "By the window outside is a apple tree, where the apples are ready to be picked.", "Some very nice looking red fruity by a window,", "A shuttered window has a fruit tree outside it."], "instances": [{"category": "apple", "bbox": [0.214, 0.112, 0.408, 0.266]}, {"category": "apple", "bbox": [0.472, 0.166, 0.618, 0.293]}, {"category": "apple", "bbox": [0.055, 0.592, 0.172, 0.686]}, {"category": "apple", "bbox": [0.126, 0.661, 0.236, 0.739]}, {"category": "apple", "bbox": [0.52, 0.09, 0.609, 0.143]}, {"category": "apple", "bbox": [0.226, 0.354, 0.285, 0.409]}, {"category": "apple", "bbox": [0.0, 0.698, 0.096, 0.771]}, {"category": "apple", "bbox": [0.001, 0.646, 0.042, 0.713]}, {"category": "apple", "bbox": [0.258, 0.719, 0.329, 0.778]}]}
|
| 62 |
+
{"id": "000000210299", "image": "COCO_val2014_000000210299.jpg", "captions": ["A little boy riding his bike and wearing a helmet", "A little boy raveling down a road on a bike, with a yellow helmet on.", "The boy wears a helmet while riding his bicycle.", "a small child wearing a helmet and riding a bike", "A little boy wearing a helmet and riding a bike."], "instances": [{"category": "person", "bbox": [0.198, 0.259, 0.399, 0.679]}, {"category": "bicycle", "bbox": [0.213, 0.383, 0.408, 0.835]}]}
|
| 63 |
+
{"id": "000000088218", "image": "COCO_val2014_000000088218.jpg", "captions": ["Signs proclaim the famous Haight Ashbury intersection and district.", "a pole with street lights, signs and wires attached to it", "A traffic light at the intersection of Haight and Ashbury", "A traffic sign is shown with traffic signs above it.", "The street signs and traffic signal are below wires attached to the pole."], "instances": [{"category": "traffic light", "bbox": [0.443, 0.435, 0.658, 0.721]}]}
|
| 64 |
+
{"id": "000000020650", "image": "COCO_val2014_000000020650.jpg", "captions": ["Burger with broccoli, pickle, and fork on orange plate", "On a plate is kept a burger and a bowl of broccoli and a fork.", "There is half a sandwich on an orange plate with a pickle and a bowl of broccoli", "A A bowl and a sandwich on an orange plate on a table.", "A plate has a sandwich, broccoli, and a pickle."], "instances": [{"category": "sandwich", "bbox": [0.436, 0.155, 0.805, 0.859]}, {"category": "sandwich", "bbox": [0.311, 0.006, 0.748, 0.293]}, {"category": "fork", "bbox": [0.0, 0.665, 0.578, 0.876]}, {"category": "bowl", "bbox": [0.002, 0.263, 0.487, 0.744]}, {"category": "bowl", "bbox": [0.708, 0.003, 0.828, 0.03]}, {"category": "broccoli", "bbox": [0.185, 0.288, 0.366, 0.546]}, {"category": "broccoli", "bbox": [0.017, 0.344, 0.384, 0.654]}, {"category": "broccoli", "bbox": [0.31, 0.191, 0.466, 0.463]}, {"category": "broccoli", "bbox": [0.104, 0.107, 0.285, 0.342]}, {"category": "broccoli", "bbox": [0.092, 0.276, 0.242, 0.442]}, {"category": "dining table", "bbox": [0.002, 0.0, 0.999, 0.987]}]}
|
| 65 |
+
{"id": "000000514915", "image": "COCO_val2014_000000514915.jpg", "captions": ["A large black dog laying on a kitchen floor.", "A dog is laying down on the floor in the home.", "Black dog laying down on the kitchen floor next to it's bowls and toy", "A black dog with a red collar laying on a tiled floor.", "A black dog that is laying on the floor."], "instances": [{"category": "dog", "bbox": [0.087, 0.276, 0.812, 0.792]}, {"category": "bowl", "bbox": [0.437, 0.09, 0.533, 0.213]}, {"category": "bowl", "bbox": [0.537, 0.035, 0.665, 0.141]}]}
|
| 66 |
+
{"id": "000000205183", "image": "COCO_val2014_000000205183.jpg", "captions": ["A duck walking along a paved road next to a patch of grass.", "A close up of a duck walking on a path.", "a duck walks along a cement patch while looking down", "A white duck out of water, walking on the ground.", "A goose standing in the road, looking at the ground."], "instances": [{"category": "bird", "bbox": [0.291, 0.235, 0.859, 0.889]}]}
|
| 67 |
+
{"id": "000000534270", "image": "COCO_val2014_000000534270.jpg", "captions": ["Man and woman with umbrella hats sitting on top of a bridge.", "A couple equipped with umbrella hats taking a break from walking their dog on a bridge on a rainy day.", "Two people in ridiculous looking umbrella hats.", "two people with umbrella hats near one another", "A couple of people wearing umbrella hats next to the ocean."], "instances": [{"category": "dog", "bbox": [0.456, 0.832, 0.6, 0.983]}, {"category": "person", "bbox": [0.433, 0.464, 0.636, 0.975]}, {"category": "person", "bbox": [0.263, 0.321, 0.459, 0.978]}, {"category": "boat", "bbox": [0.912, 0.4, 0.978, 0.433]}, {"category": "boat", "bbox": [0.211, 0.236, 0.478, 0.304]}, {"category": "boat", "bbox": [0.144, 0.328, 0.189, 0.361]}, {"category": "umbrella", "bbox": [0.443, 0.402, 0.607, 0.473]}, {"category": "umbrella", "bbox": [0.325, 0.311, 0.483, 0.432]}, {"category": "umbrella", "bbox": [0.207, 0.738, 0.284, 0.778]}, {"category": "umbrella", "bbox": [0.489, 0.713, 0.649, 0.83]}]}
|
| 68 |
+
{"id": "000000408439", "image": "COCO_val2014_000000408439.jpg", "captions": ["Cliffs rise on the edge of a placid lake.", "A scenic view of a river with a train on the edge of it in the distance.", "A large lake surrounded by beautiful tree covered mountains.", "a landscape scene with water, mountains and trees", "A train on a waterfront track surrounded by mountains."], "instances": [{"category": "train", "bbox": [0.008, 0.591, 0.562, 0.644]}]}
|
| 69 |
+
{"id": "000000474253", "image": "COCO_val2014_000000474253.jpg", "captions": ["A man riding on the back of a horse through a river.", "A person is riding a horse through water.", "Horse and rider crossing waterway during competitive event.", "A woman riding a horse splashes through a large puddle.", "A young man riding a horse through some water."], "instances": [{"category": "horse", "bbox": [0.385, 0.235, 0.651, 0.814]}, {"category": "person", "bbox": [0.396, 0.06, 0.576, 0.675]}, {"category": "person", "bbox": [0.29, 0.148, 0.355, 0.333]}, {"category": "person", "bbox": [0.129, 0.163, 0.212, 0.349]}, {"category": "person", "bbox": [0.005, 0.014, 0.038, 0.165]}, {"category": "person", "bbox": [0.144, 0.011, 0.193, 0.155]}, {"category": "person", "bbox": [0.089, 0.007, 0.133, 0.162]}]}
|
| 70 |
+
{"id": "000000098029", "image": "COCO_val2014_000000098029.jpg", "captions": ["a table with many plates on it with a bread basket", "A table set for four has many foods and fruits on it.", "Several objects displayed on a kitchen table including bread, oranges and plating.", "Several dishes and food items sit on a table.", "An assortment of foods sitting on a round brown table."], "instances": [{"category": "refrigerator", "bbox": [0.013, 0.004, 0.37, 0.317]}, {"category": "bottle", "bbox": [0.467, 0.517, 0.555, 0.638]}, {"category": "bottle", "bbox": [0.602, 0.536, 0.658, 0.609]}, {"category": "chair", "bbox": [0.747, 0.367, 1.0, 0.592]}, {"category": "chair", "bbox": [0.044, 0.368, 0.358, 0.544]}, {"category": "cup", "bbox": [0.296, 0.465, 0.359, 0.54]}, {"category": "cup", "bbox": [0.709, 0.67, 0.782, 0.736]}, {"category": "cup", "bbox": [0.213, 0.684, 0.294, 0.753]}, {"category": "knife", "bbox": [0.787, 0.699, 0.922, 0.797]}, {"category": "knife", "bbox": [0.161, 0.539, 0.265, 0.584]}, {"category": "spoon", "bbox": [0.813, 0.674, 0.922, 0.759]}, {"category": "spoon", "bbox": [0.156, 0.555, 0.233, 0.587]}, {"category": "spoon", "bbox": [0.596, 0.467, 0.613, 0.509]}, {"category": "bowl", "bbox": [0.241, 0.753, 0.505, 0.935]}, {"category": "banana", "bbox": [0.632, 0.138, 0.718, 0.161]}, {"category": "apple", "bbox": [0.701, 0.152, 0.758, 0.191]}, {"category": "orange", "bbox": [0.607, 0.66, 0.692, 0.716]}, {"category": "orange", "bbox": [0.565, 0.636, 0.611, 0.667]}, {"category": "orange", "bbox": [0.526, 0.624, 0.572, 0.652]}, {"category": "orange", "bbox": [0.61, 0.628, 0.656, 0.657]}, {"category": "orange", "bbox": [0.599, 0.649, 0.643, 0.677]}, {"category": "dining table", "bbox": [0.013, 0.439, 0.964, 0.986]}, {"category": "cup", "bbox": [0.612, 0.489, 0.669, 0.548]}, {"category": "knife", "bbox": [0.605, 0.457, 0.638, 0.53]}, {"category": "apple", "bbox": [0.502, 0.137, 0.537, 0.159]}, {"category": "orange", "bbox": [0.54, 0.135, 0.563, 0.151]}, {"category": "orange", "bbox": [0.527, 0.129, 0.554, 0.142]}, {"category": "orange", "bbox": [0.611, 0.155, 0.641, 0.171]}, {"category": "chair", "bbox": [0.0, 0.843, 0.29, 0.989]}, {"category": "cup", "bbox": [0.353, 0.469, 0.411, 0.511]}, {"category": "cup", "bbox": [0.609, 0.716, 0.682, 0.786]}, {"category": "orange", "bbox": [0.638, 0.158, 0.679, 0.177]}, {"category": "cake", "bbox": [0.38, 0.821, 0.481, 0.895]}, {"category": "chair", "bbox": [0.79, 0.747, 1.0, 1.0]}, {"category": "bottle", "bbox": [0.719, 0.55, 0.769, 0.616]}, {"category": "bottle", "bbox": [0.795, 0.546, 0.873, 0.613]}, {"category": "knife", "bbox": [0.17, 0.799, 0.264, 0.88]}, {"category": "cup", "bbox": [0.317, 0.695, 0.391, 0.752]}]}
|
| 71 |
+
{"id": "000000294073", "image": "COCO_val2014_000000294073.jpg", "captions": ["A woman and a man standing between two brown horses.", "A COUPLE WEARING YELLOW DRESS STANDING NEAR TWO HORSES.", "An older couple stands between two horses.", "A man and a woman standing with two horses", "A man and a woman stand in between two horses."], "instances": [{"category": "horse", "bbox": [0.0, 0.052, 0.49, 0.989]}, {"category": "horse", "bbox": [0.632, 0.23, 1.0, 0.989]}, {"category": "person", "bbox": [0.425, 0.326, 0.696, 0.987]}, {"category": "person", "bbox": [0.627, 0.203, 0.828, 0.986]}, {"category": "book", "bbox": [0.525, 0.597, 0.644, 0.833]}]}
|
| 72 |
+
{"id": "000000203629", "image": "COCO_val2014_000000203629.jpg", "captions": ["A man on a cell phone in a public area holding his thumb up.", "A group of people gathered inside of a room.", "A man on his cellphone posing for a picture.", "A man giving a thumbs up while on a cell phone.", "The man is giving a thumbs up while on his phone."], "instances": [{"category": "cell phone", "bbox": [0.43, 0.459, 0.449, 0.503]}, {"category": "cup", "bbox": [0.756, 0.838, 0.865, 0.98]}, {"category": "person", "bbox": [0.232, 0.317, 0.603, 0.98]}, {"category": "person", "bbox": [0.602, 0.405, 1.0, 0.999]}, {"category": "person", "bbox": [0.003, 0.339, 0.313, 0.987]}, {"category": "person", "bbox": [0.164, 0.379, 0.258, 0.733]}, {"category": "person", "bbox": [0.564, 0.36, 0.673, 0.645]}, {"category": "person", "bbox": [0.241, 0.379, 0.336, 0.512]}, {"category": "person", "bbox": [0.682, 0.372, 0.736, 0.502]}, {"category": "person", "bbox": [0.654, 0.428, 0.734, 0.536]}, {"category": "person", "bbox": [0.718, 0.368, 0.787, 0.508]}, {"category": "person", "bbox": [0.148, 0.362, 0.205, 0.529]}, {"category": "person", "bbox": [0.001, 0.431, 0.044, 0.564]}, {"category": "cup", "bbox": [0.901, 0.808, 0.995, 0.982]}]}
|
| 73 |
+
{"id": "000000119876", "image": "COCO_val2014_000000119876.jpg", "captions": ["A man dressed loudly is using his cell phone.", "A man talking on the phone while he walks down the street.", "A man with pink hair talking on a cell phone.", "A man in a purple shirt and tie and purple hair.", "a man colored his hair in purple walking on the road"], "instances": [{"category": "bicycle", "bbox": [0.525, 0.222, 0.924, 0.608]}, {"category": "bicycle", "bbox": [0.895, 0.249, 1.0, 0.642]}, {"category": "person", "bbox": [0.0, 0.0, 0.738, 1.0]}, {"category": "tie", "bbox": [0.319, 0.255, 0.423, 0.638]}, {"category": "cell phone", "bbox": [0.411, 0.13, 0.426, 0.161]}, {"category": "handbag", "bbox": [0.369, 0.205, 0.575, 0.839]}]}
|
| 74 |
+
{"id": "000000164255", "image": "COCO_val2014_000000164255.jpg", "captions": ["An umbrella that is standing in the sand.", "An umbrella is stuck in the sand on the beach.", "a colorful striped umbrella on the beach near the ocean", "A colorful umbrella is set up at the beach.", "The colorful umbrella is sitting by the beach,"], "instances": [{"category": "umbrella", "bbox": [0.0, 0.101, 0.567, 0.575]}]}
|
| 75 |
+
{"id": "000000192817", "image": "COCO_val2014_000000192817.jpg", "captions": ["A view from a window high up in the sky.", "A bunch of mountains seen from a plane window.", "The window from a plane overlooking the ground.", "The view of a mountain area from an airplane window.", "An aerial view of mountains and lakes from an airplane window."], "instances": []}
|
| 76 |
+
{"id": "000000258285", "image": "COCO_val2014_000000258285.jpg", "captions": ["Two large passenger jets flying over a beach filled with birds.", "A plane is flying over a bird filed lake", "Two airplanes are in the sky over blue water.", "An airplane landing over an airplane on the ground.", "A photo of two plans with water and birds surrounding it , one plane in the air one one the ground."], "instances": [{"category": "bird", "bbox": [0.507, 0.941, 0.536, 0.973]}, {"category": "bird", "bbox": [0.304, 0.933, 0.315, 0.95]}, {"category": "bird", "bbox": [0.129, 0.885, 0.143, 0.912]}, {"category": "bird", "bbox": [0.158, 0.851, 0.165, 0.87]}, {"category": "bird", "bbox": [0.404, 0.839, 0.429, 0.864]}, {"category": "bird", "bbox": [0.498, 0.833, 0.513, 0.861]}, {"category": "airplane", "bbox": [0.276, 0.085, 0.825, 0.316]}, {"category": "airplane", "bbox": [0.478, 0.252, 0.983, 0.495]}, {"category": "bird", "bbox": [0.552, 0.828, 0.564, 0.844]}, {"category": "bird", "bbox": [0.789, 0.812, 0.798, 0.836]}, {"category": "bird", "bbox": [0.927, 0.82, 0.936, 0.838]}, {"category": "bird", "bbox": [0.65, 0.828, 0.664, 0.849]}, {"category": "bird", "bbox": [0.752, 0.81, 0.763, 0.83]}, {"category": "bird", "bbox": [0.841, 0.817, 0.852, 0.828]}, {"category": "bird", "bbox": [0.292, 0.849, 0.311, 0.868]}, {"category": "bird", "bbox": [0.005, 0.727, 0.981, 0.998]}]}
|
| 77 |
+
{"id": "000000506483", "image": "COCO_val2014_000000506483.jpg", "captions": ["An art installation is placed by a street.", "People sit near a display of large artworks including an oversize bench and painted feline heads.", "Looking down on a giant rocking bench and large animal heads.", "An over sized wooden bench next to two massive animal art sculptures.", "artistic sculptures and images on a city street"], "instances": [{"category": "car", "bbox": [0.656, 0.939, 0.933, 1.0]}, {"category": "person", "bbox": [0.08, 0.664, 0.147, 0.805]}, {"category": "person", "bbox": [0.154, 0.646, 0.217, 0.821]}, {"category": "bench", "bbox": [0.316, 0.124, 0.951, 0.635]}, {"category": "backpack", "bbox": [0.062, 0.701, 0.097, 0.769]}, {"category": "person", "bbox": [0.0, 0.132, 0.031, 0.197]}]}
|
| 78 |
+
{"id": "000000502168", "image": "COCO_val2014_000000502168.jpg", "captions": ["a fleet of naval ships in the ocean", "A group of men on aircraft carrier with other boats in the distance.", "A large ship floating in the ocean next to other ships.", "Several men on a boat looking over the side.", "The men wear hardhats as they work on the aircraft carrier."], "instances": [{"category": "boat", "bbox": [0.634, 0.292, 1.0, 0.982]}, {"category": "person", "bbox": [0.675, 0.507, 0.736, 0.731]}, {"category": "person", "bbox": [0.684, 0.737, 0.817, 1.0]}, {"category": "person", "bbox": [0.803, 0.691, 0.883, 0.932]}, {"category": "person", "bbox": [0.741, 0.56, 0.798, 0.767]}, {"category": "person", "bbox": [0.924, 0.269, 0.951, 0.367]}, {"category": "boat", "bbox": [0.079, 0.171, 0.172, 0.231]}, {"category": "boat", "bbox": [0.863, 0.131, 0.961, 0.239]}, {"category": "boat", "bbox": [0.435, 0.288, 0.46, 0.313]}, {"category": "boat", "bbox": [0.591, 0.186, 0.605, 0.222]}, {"category": "person", "bbox": [0.451, 0.289, 0.455, 0.296]}, {"category": "person", "bbox": [0.446, 0.29, 0.451, 0.296]}, {"category": "person", "bbox": [0.872, 0.627, 0.957, 0.966]}, {"category": "person", "bbox": [0.44, 0.288, 0.446, 0.3]}]}
|
| 79 |
+
{"id": "000000319432", "image": "COCO_val2014_000000319432.jpg", "captions": ["Man holding two shirts with luggage and window", "A man holding clothes on a hanger with a suitcase in front of him.", "A man show a red and a white clothing hangers.", "A man holding his garment bags in both hands", "A man holding up some clothes in some hanger bags."], "instances": [{"category": "person", "bbox": [0.0, 0.092, 0.776, 0.852]}, {"category": "suitcase", "bbox": [0.153, 0.798, 0.587, 1.0]}]}
|
| 80 |
+
{"id": "000000131019", "image": "COCO_val2014_000000131019.jpg", "captions": ["Two zebras and two monkeys walking on the grass.", "Two giraffes and another animal are on green grass.", "A baboon and two zebras grazing on the savannah.", "A baboon and its baby eat by two zebras in the grass", "Monkey standing behind two zebras as they graze."], "instances": [{"category": "zebra", "bbox": [0.367, 0.258, 0.834, 0.646]}, {"category": "zebra", "bbox": [0.161, 0.13, 0.396, 0.375]}, {"category": "bird", "bbox": [0.309, 0.138, 0.34, 0.163]}]}
|
ChatUniVi/eval/table/model.jsonl
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"}
|
| 2 |
+
{"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"}
|
| 3 |
+
{"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"}
|
| 4 |
+
{"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"}
|
| 5 |
+
{"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"}
|
ChatUniVi/eval/table/question.jsonl
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"question_id": 1, "text": "How can I improve my time management skills?", "category": "generic"}
|
| 2 |
+
{"question_id": 2, "text": "What are the most effective ways to deal with stress?", "category": "generic"}
|
| 3 |
+
{"question_id": 3, "text": "What are the main differences between Python and JavaScript programming languages?", "category": "generic"}
|
| 4 |
+
{"question_id": 4, "text": "How can I increase my productivity while working from home?", "category": "generic"}
|
| 5 |
+
{"question_id": 5, "text": "Can you explain the basics of quantum computing?", "category": "generic"}
|
| 6 |
+
{"question_id": 6, "text": "What are the differences between plant-based and animal-based protein sources?", "category": "generic"}
|
| 7 |
+
{"question_id": 7, "text": "How can I develop my critical thinking skills?", "category": "generic"}
|
| 8 |
+
{"question_id": 8, "text": "What are the major challenges faced by the education sector today?", "category": "generic"}
|
| 9 |
+
{"question_id": 9, "text": "What are the primary factors that influence consumer behavior?", "category": "generic"}
|
| 10 |
+
{"question_id": 10, "text": "What are the most effective strategies for conflict resolution in the workplace?", "category": "generic"}
|
| 11 |
+
{"question_id": 11, "text": "What are some potential implications of using a single-use plastic bottle versus a reusable bottle on both the environment and human health?", "category": "knowledge"}
|
| 12 |
+
{"question_id": 12, "text": "What factors would you consider when designing an inclusive and accessible public transportation system?", "category": "knowledge"}
|
| 13 |
+
{"question_id": 13, "text": "How can governments utilize fiscal and monetary policies to combat economic recessions?", "category": "knowledge"}
|
| 14 |
+
{"question_id": 14, "text": "How do language and cultural barriers affect the way people communicate and form relationships in multicultural societies?", "category": "knowledge"}
|
| 15 |
+
{"question_id": 15, "text": "Describe a scenario where artificial intelligence could be used to improve the quality and efficiency of healthcare delivery.", "category": "knowledge"}
|
| 16 |
+
{"question_id": 16, "text": "Explain the process of gene editing using CRISPR-Cas9 technology, and discuss its potential applications and ethical implications.", "category": "knowledge"}
|
| 17 |
+
{"question_id": 17, "text": "How do vaccinations work to protect individuals and communities from infectious diseases, and what is herd immunity?", "category": "knowledge"}
|
| 18 |
+
{"question_id": 18, "text": "How do social media platforms influence the way people consume and share news, and what are the potential implications for the spread of misinformation?", "category": "knowledge"}
|
| 19 |
+
{"question_id": 19, "text": "How do cultural, social, and economic factors influence people's food choices, and how can this knowledge be used to promote healthier diets?", "category": "knowledge"}
|
| 20 |
+
{"question_id": 20, "text": "Explain the process of natural selection and how it contributes to the evolution and adaptation of species.", "category": "knowledge"}
|
| 21 |
+
{"question_id": 21, "text": "How would you introduce yourself as a medieval knight at a royal banquet?", "category": "roleplay"}
|
| 22 |
+
{"question_id": 22, "text": "As a pirate captain, what would you say to your crew to motivate them to search for hidden treasure?", "category": "roleplay"}
|
| 23 |
+
{"question_id": 23, "text": "If you were a Shakespearean character, how would you declare your love for someone in a soliloquy?", "category": "roleplay"}
|
| 24 |
+
{"question_id": 24, "text": "As a superhero, how would you explain your origin story to a curious child?", "category": "roleplay"}
|
| 25 |
+
{"question_id": 25, "text": "Imagine you are a time traveler from the year 3000. What technological advancements would you tell people about?", "category": "roleplay"}
|
| 26 |
+
{"question_id": 26, "text": "As a sports commentator, describe the winning play in the final seconds of a championship game.", "category": "roleplay"}
|
| 27 |
+
{"question_id": 27, "text": "Pretend to be a world-famous chef. How would you describe your signature dish to a panel of judges?", "category": "roleplay"}
|
| 28 |
+
{"question_id": 28, "text": "You are a mountain climber reaching the summit of Mount Everest. Describe your emotions and the view from the top.", "category": "roleplay"}
|
| 29 |
+
{"question_id": 29, "text": "As a space colonist on Mars, describe your daily life and the challenges you face living on another planet.", "category": "roleplay"}
|
| 30 |
+
{"question_id": 30, "text": "Pretend to be a character in a post-apocalyptic world. Describe how you survive and the allies you encounter.", "category": "roleplay"}
|
| 31 |
+
{"question_id": 31, "text": "How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?", "category": "common-sense"}
|
| 32 |
+
{"question_id": 32, "text": "What are some subtle clues that suggest someone is pretending to understand a topic or conversation when they are actually confused or uninformed?", "category": "common-sense"}
|
| 33 |
+
{"question_id": 33, "text": "Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app?", "category": "common-sense"}
|
| 34 |
+
{"question_id": 34, "text": "How can you determine if a person is genuinely interested in a conversation or simply being polite?", "category": "common-sense"}
|
| 35 |
+
{"question_id": 35, "text": "Why might someone prefer to shop at a small, locally-owned business instead of a large chain store, even if the prices are higher?", "category": "common-sense"}
|
| 36 |
+
{"question_id": 36, "text": "How can you assess the credibility of a source of information, such as a news article or blog post, without relying solely on the reputation of the author or publisher?", "category": "common-sense"}
|
| 37 |
+
{"question_id": 37, "text": "Why do some people enjoy the sensation of being scared, such as by watching horror movies or going on roller coasters, while others avoid these experiences?", "category": "common-sense"}
|
| 38 |
+
{"question_id": 38, "text": "How can observing the behavior of other people in a social situation provide clues about cultural norms and expectations?", "category": "common-sense"}
|
| 39 |
+
{"question_id": 39, "text": "Do we have a moral obligation to explore space, or should we focus on solving Earth's problems first?", "category": "common-sense"}
|
| 40 |
+
{"question_id": 40, "text": "In a world where automation is becoming increasingly prevalent, is it more important to prioritize job creation or technological progress?", "category": "common-sense"}
|
| 41 |
+
{"question_id": 41, "text": "How many times does the average human blink in a lifetime? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 42 |
+
{"question_id": 42, "text": "How many atoms are in a grain of salt? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 43 |
+
{"question_id": 43, "text": "How many lightning strikes occur on Earth each day? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 44 |
+
{"question_id": 44, "text": "How many balloons would it take to lift a house like in the movie \"Up\"? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 45 |
+
{"question_id": 45, "text": "How many text messages are sent globally in a minute? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 46 |
+
{"question_id": 46, "text": "How many words are spoken daily on Earth? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 47 |
+
{"question_id": 47, "text": "How many snowflakes fall during a typical winter? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 48 |
+
{"question_id": 48, "text": "How many pages are in all the books ever written? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 49 |
+
{"question_id": 49, "text": "How many times has the Earth orbited the Sun since the beginning of life? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 50 |
+
{"question_id": 50, "text": "How many songs have been recorded throughout history? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
|
| 51 |
+
{"question_id": 51, "text": "What if the Internet had been invented during the Renaissance period?", "category": "counterfactual"}
|
| 52 |
+
{"question_id": 52, "text": "What if the Aztecs had successfully repelled the Spanish conquistadors?", "category": "counterfactual"}
|
| 53 |
+
{"question_id": 53, "text": "What if the Black Death had not occurred in the 14th century?", "category": "counterfactual"}
|
| 54 |
+
{"question_id": 54, "text": "What if Isaac Newton had focused on biology instead of physics?", "category": "counterfactual"}
|
| 55 |
+
{"question_id": 55, "text": "What if the Beatles had never formed as a band?", "category": "counterfactual"}
|
| 56 |
+
{"question_id": 56, "text": "What if Alan Turing had not cracked the Enigma code during World War II?", "category": "counterfactual"}
|
| 57 |
+
{"question_id": 57, "text": "What if the Suez Canal had never been constructed?", "category": "counterfactual"}
|
| 58 |
+
{"question_id": 58, "text": "What if the Maya civilization had never mysteriously collapsed?", "category": "counterfactual"}
|
| 59 |
+
{"question_id": 59, "text": "What if Christopher Columbus had not discovered the Americas?", "category": "counterfactual"}
|
| 60 |
+
{"question_id": 60, "text": "What if Vincent van Gogh had been a successful artist during his lifetime?", "category": "counterfactual"}
|
| 61 |
+
{"question_id": 61, "text": "Develop a C++ program that reads a text file line by line and counts the number of occurrences of a specific word in the file.", "category": "coding"}
|
| 62 |
+
{"question_id": 62, "text": "Implement a Python function to find the longest common subsequence of two input strings using dynamic programming.", "category": "coding"}
|
| 63 |
+
{"question_id": 63, "text": "Implement a regular expression in Python to validate an email address.", "category": "coding"}
|
| 64 |
+
{"question_id": 64, "text": "Write a program to find the nth Fibonacci number using dynamic programming.", "category": "coding"}
|
| 65 |
+
{"question_id": 65, "text": "Implement a binary search algorithm to find a specific element in a sorted array.", "category": "coding"}
|
| 66 |
+
{"question_id": 66, "text": "Implement a queue data structure using two stacks in Python.", "category": "coding"}
|
| 67 |
+
{"question_id": 67, "text": "Implement a program to find the common elements in two arrays without using any extra data structures.", "category": "coding"}
|
| 68 |
+
{"question_id": 68, "text": "Given that f(x) = 5x^3 - 2x + 3, find the value of f(2).", "category": "math"}
|
| 69 |
+
{"question_id": 69, "text": "Solve for x in the equation 3x + 10 = 5(x - 2).", "category": "math"}
|
| 70 |
+
{"question_id": 70, "text": "If the endpoints of a line segment are (2, -2) and (10, 4), what is the length of the segment?", "category": "math"}
|
| 71 |
+
{"question_id": 71, "text": "Can you help me write a formal email to a potential business partner proposing a joint venture?", "category": "writing"}
|
| 72 |
+
{"question_id": 72, "text": "Can you help me write a resignation letter to my current employer, while leaving on good terms and expressing gratitude for the opportunities provided?", "category": "writing"}
|
| 73 |
+
{"question_id": 73, "text": "Use an appropriate format to structure a formal letter of recommendation for a student applying to a prestigious graduate program in computer science.", "category": "writing"}
|
| 74 |
+
{"question_id": 74, "text": "Write a compelling product launch announcement email to inform our customers of our new software solution.", "category": "writing"}
|
| 75 |
+
{"question_id": 75, "text": "Draft an apology email to a customer who experienced a delay in their order, and provide reassurance that the issue has been resolved.", "category": "writing"}
|
| 76 |
+
{"question_id": 76, "text": "Write a script for a YouTube video exploring the history and cultural significance of jazz.", "category": "writing"}
|
| 77 |
+
{"question_id": 77, "text": "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", "category": "writing"}
|
| 78 |
+
{"question_id": 78, "text": "Write a captivating movie review for a recently released science fiction film, discussing its plot, characters, and special effects.", "category": "writing"}
|
| 79 |
+
{"question_id": 79, "text": "Structure a podcast script for an episode discussing the influence of streaming platforms on the music industry.", "category": "writing"}
|
| 80 |
+
{"question_id": 80, "text": "Write a symphony concert review, discussing the orchestra's performance and overall audience experience.", "category": "writing"}
|
ChatUniVi/eval/table/reviewer.jsonl
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"}
|
| 2 |
+
{"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"}
|
| 3 |
+
{"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
|
| 4 |
+
{"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
|
ChatUniVi/eval/table/rule.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"coding": {"role": "Assistant", "prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."},
|
| 3 |
+
"math": {"role": "Assistant", "prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."},
|
| 4 |
+
"default": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 5 |
+
"conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 6 |
+
"detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 7 |
+
"complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 8 |
+
"llava_bench_conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 9 |
+
"llava_bench_detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
|
| 10 |
+
"llava_bench_complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}
|
| 11 |
+
}
|
ChatUniVi/model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .language_model.llama import ChatUniViLlamaForCausalLM, ChatUniViConfig
|
ChatUniVi/model/apply_delta.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 6 |
+
from ChatUniVi import ChatUniViLlamaForCausalLM
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def apply_delta(base_model_path, target_model_path, delta_path):
|
| 10 |
+
print("Loading base model")
|
| 11 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 12 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
| 13 |
+
|
| 14 |
+
print("Loading delta")
|
| 15 |
+
delta = ChatUniViLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
| 16 |
+
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
|
| 17 |
+
|
| 18 |
+
print("Applying delta")
|
| 19 |
+
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
|
| 20 |
+
if name not in base.state_dict():
|
| 21 |
+
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
|
| 22 |
+
continue
|
| 23 |
+
if param.data.shape == base.state_dict()[name].shape:
|
| 24 |
+
param.data += base.state_dict()[name]
|
| 25 |
+
else:
|
| 26 |
+
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
|
| 27 |
+
f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
|
| 28 |
+
bparam = base.state_dict()[name]
|
| 29 |
+
param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
|
| 30 |
+
|
| 31 |
+
print("Saving target model")
|
| 32 |
+
delta.save_pretrained(target_model_path)
|
| 33 |
+
delta_tokenizer.save_pretrained(target_model_path)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if __name__ == "__main__":
|
| 37 |
+
parser = argparse.ArgumentParser()
|
| 38 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
| 39 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
| 40 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
| 41 |
+
|
| 42 |
+
args = parser.parse_args()
|
| 43 |
+
|
| 44 |
+
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
ChatUniVi/model/arch.py
ADDED
|
@@ -0,0 +1,652 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from PIL.ImImagePlugin import split
|
| 5 |
+
|
| 6 |
+
from .multimodal_encoder.builder import build_vision_tower
|
| 7 |
+
from ChatUniVi.constants import *
|
| 8 |
+
from .cluster import CTM, TCBlock
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
from .multimodal_projector.builder import build_vision_projector
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MetaModel:
|
| 14 |
+
def __init__(self, config):
|
| 15 |
+
super(MetaModel, self).__init__(config)
|
| 16 |
+
|
| 17 |
+
if hasattr(config, "mm_vision_tower"):
|
| 18 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
| 19 |
+
self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
|
| 20 |
+
|
| 21 |
+
if hasattr(config, "config"):
|
| 22 |
+
self.use_cluster = config.config["use_cluster"]
|
| 23 |
+
if self.use_cluster:
|
| 24 |
+
self.ctm0 = CTM(sample_ratio=config.config["spatial_cluster_rate0"], embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=5)
|
| 25 |
+
self.block0 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
|
| 26 |
+
|
| 27 |
+
self.ctm1 = CTM(sample_ratio=config.config["spatial_cluster_rate1"], embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=3)
|
| 28 |
+
self.block1 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
|
| 29 |
+
|
| 30 |
+
self.ctm2 = CTM(sample_ratio=config.config["spatial_cluster_rate2"], embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=3)
|
| 31 |
+
self.block2 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
|
| 32 |
+
|
| 33 |
+
self.ctm3 = CTM(sample_ratio=config.config["temporal_cluster_rate"], embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=5)
|
| 34 |
+
self.block3 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
|
| 35 |
+
else:
|
| 36 |
+
self.use_cluster = False
|
| 37 |
+
|
| 38 |
+
def get_vision_tower(self):
|
| 39 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
| 40 |
+
if type(vision_tower) is list:
|
| 41 |
+
vision_tower = vision_tower[0]
|
| 42 |
+
return vision_tower
|
| 43 |
+
|
| 44 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
| 45 |
+
vision_tower = model_args.vision_tower
|
| 46 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
| 47 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
| 48 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
| 49 |
+
|
| 50 |
+
self.config.mm_vision_tower = vision_tower
|
| 51 |
+
|
| 52 |
+
vision_tower = build_vision_tower(model_args)
|
| 53 |
+
|
| 54 |
+
self.config.use_mm_proj = True
|
| 55 |
+
self.config.mm_hidden_size = vision_tower.hidden_size
|
| 56 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
| 57 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
| 58 |
+
|
| 59 |
+
if fsdp is not None and len(fsdp) > 0:
|
| 60 |
+
self.vision_tower = [vision_tower]
|
| 61 |
+
else:
|
| 62 |
+
self.vision_tower = vision_tower
|
| 63 |
+
|
| 64 |
+
if not hasattr(self, 'mm_projector'):
|
| 65 |
+
self.mm_projector = build_vision_projector(self.config)
|
| 66 |
+
|
| 67 |
+
if pretrain_mm_mlp_adapter is not None:
|
| 68 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
| 69 |
+
def get_w(weights, keyword):
|
| 70 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
| 71 |
+
|
| 72 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
| 73 |
+
|
| 74 |
+
def initialize_cluster_modules(self, model_args):
|
| 75 |
+
self.use_cluster = model_args.use_cluster
|
| 76 |
+
|
| 77 |
+
if self.use_cluster and not hasattr(self, 'ctm0'):
|
| 78 |
+
self.ctm0 = CTM(sample_ratio=model_args.spatial_cluster_rate0, embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=5)
|
| 79 |
+
self.block0 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
|
| 80 |
+
|
| 81 |
+
self.ctm1 = CTM(sample_ratio=model_args.spatial_cluster_rate1, embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=3)
|
| 82 |
+
self.block1 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
|
| 83 |
+
|
| 84 |
+
self.ctm2 = CTM(sample_ratio=model_args.spatial_cluster_rate2, embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=3)
|
| 85 |
+
self.block2 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
|
| 86 |
+
|
| 87 |
+
self.ctm3 = CTM(sample_ratio=model_args.temporal_cluster_rate, embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=5)
|
| 88 |
+
self.block3 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ChatUniViMetaForCausalLM(ABC):
|
| 92 |
+
@abstractmethod
|
| 93 |
+
def get_model(self):
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
def get_vision_tower(self):
|
| 97 |
+
return self.get_model().get_vision_tower()
|
| 98 |
+
|
| 99 |
+
def encode_images(self, images):
|
| 100 |
+
image_features = self.get_model().get_vision_tower()(images, select_feature="patch")
|
| 101 |
+
return image_features
|
| 102 |
+
|
| 103 |
+
def positional_encoding(self, x, num_features=1024, max_len=64):
|
| 104 |
+
p = torch.zeros((1, max_len, num_features))
|
| 105 |
+
_x = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000,
|
| 106 |
+
torch.arange(0, num_features, 2, dtype=torch.float32) / num_features)
|
| 107 |
+
|
| 108 |
+
p[:, :, 0::2] = torch.sin(_x)
|
| 109 |
+
p[:, :, 1::2] = torch.cos(_x)
|
| 110 |
+
x = x + p[:, :x.shape[1], :].to(x.device).to(x.dtype)
|
| 111 |
+
return x
|
| 112 |
+
|
| 113 |
+
def project(self, image_features, input_type="image"):
|
| 114 |
+
if self.get_model().use_cluster:
|
| 115 |
+
if input_type == "image":
|
| 116 |
+
cluster_image_features = []
|
| 117 |
+
token_dict = {'x': image_features,
|
| 118 |
+
'token_num': image_features.size(1),
|
| 119 |
+
'idx_token': torch.arange(image_features.size(1))[None, :].repeat(
|
| 120 |
+
image_features.size(0), 1),
|
| 121 |
+
'agg_weight': image_features.new_ones(image_features.size(0), image_features.size(1),
|
| 122 |
+
1),
|
| 123 |
+
'mask': None}
|
| 124 |
+
|
| 125 |
+
token_dict = self.get_model().block0(self.get_model().ctm0(token_dict))
|
| 126 |
+
cluster_image_features.append(token_dict["x"])
|
| 127 |
+
|
| 128 |
+
token_dict = self.get_model().block1(self.get_model().ctm1(token_dict))
|
| 129 |
+
cluster_image_features.append(token_dict["x"])
|
| 130 |
+
|
| 131 |
+
token_dict = self.get_model().block2(self.get_model().ctm2(token_dict))
|
| 132 |
+
cluster_image_features.append(token_dict["x"])
|
| 133 |
+
|
| 134 |
+
image_features = torch.cat(cluster_image_features, dim=1)
|
| 135 |
+
image_features = image_features.to(self.get_model().mm_projector.weight.dtype)
|
| 136 |
+
else:
|
| 137 |
+
cls_features = torch.mean(image_features, dim=1, keepdim=False).unsqueeze(0).clone()
|
| 138 |
+
token_dict = {'x': cls_features,
|
| 139 |
+
'token_num': cls_features.size(1),
|
| 140 |
+
'idx_token': torch.arange(cls_features.size(1))[None, :].repeat(
|
| 141 |
+
cls_features.size(0), 1),
|
| 142 |
+
'agg_weight': cls_features.new_ones(cls_features.size(0), cls_features.size(1),
|
| 143 |
+
1),
|
| 144 |
+
'mask': None}
|
| 145 |
+
|
| 146 |
+
down_dict, token_dict = self.get_model().ctm3(token_dict)
|
| 147 |
+
events = OrderedDict()
|
| 148 |
+
|
| 149 |
+
max_len = 0
|
| 150 |
+
for id, i in enumerate(down_dict["idx_token"][0].tolist()):
|
| 151 |
+
if i not in events:
|
| 152 |
+
events[i] = [id]
|
| 153 |
+
else:
|
| 154 |
+
events[i].append(id)
|
| 155 |
+
max_len = len(events[i]) if max_len < len(events[i]) else max_len
|
| 156 |
+
|
| 157 |
+
cluster_image_features = []
|
| 158 |
+
token_dict = {'x': image_features,
|
| 159 |
+
'token_num': image_features.size(1),
|
| 160 |
+
'idx_token': torch.arange(image_features.size(1))[None, :].repeat(
|
| 161 |
+
image_features.size(0), 1),
|
| 162 |
+
'agg_weight': image_features.new_ones(image_features.size(0), image_features.size(1),
|
| 163 |
+
1),
|
| 164 |
+
'mask': None}
|
| 165 |
+
|
| 166 |
+
token_dict0 = self.get_model().block0(self.get_model().ctm0(token_dict))
|
| 167 |
+
token_dict1 = self.get_model().block1(self.get_model().ctm1(token_dict0))
|
| 168 |
+
token_dict2 = self.get_model().block2(self.get_model().ctm2(token_dict1))
|
| 169 |
+
|
| 170 |
+
for id, key in enumerate(events):
|
| 171 |
+
cur_image_features0 = torch.cat([token_dict0["x"][i] for i in events[key]], dim=0).unsqueeze(0)
|
| 172 |
+
token_dict = {'x': cur_image_features0,
|
| 173 |
+
'token_num': cur_image_features0.size(1),
|
| 174 |
+
'idx_token': torch.arange(cur_image_features0.size(1))[None, :].repeat(
|
| 175 |
+
cur_image_features0.size(0), 1),
|
| 176 |
+
'agg_weight': cur_image_features0.new_ones(cur_image_features0.size(0),
|
| 177 |
+
cur_image_features0.size(1),
|
| 178 |
+
1),
|
| 179 |
+
'mask': None}
|
| 180 |
+
|
| 181 |
+
cur_token_dict0 = self.get_model().block0(self.get_model().ctm0(token_dict))
|
| 182 |
+
cluster_image_features.append(cur_token_dict0["x"])
|
| 183 |
+
|
| 184 |
+
cur_image_features1 = torch.cat([token_dict1["x"][i] for i in events[key]], dim=0).unsqueeze(0)
|
| 185 |
+
token_dict = {'x': cur_image_features1,
|
| 186 |
+
'token_num': cur_image_features1.size(1),
|
| 187 |
+
'idx_token': torch.arange(cur_image_features1.size(1))[None, :].repeat(
|
| 188 |
+
cur_image_features1.size(0), 1),
|
| 189 |
+
'agg_weight': cur_image_features1.new_ones(cur_image_features1.size(0),
|
| 190 |
+
cur_image_features1.size(1),
|
| 191 |
+
1),
|
| 192 |
+
'mask': None}
|
| 193 |
+
|
| 194 |
+
cur_token_dict1 = self.get_model().block1(self.get_model().ctm1(token_dict))
|
| 195 |
+
cluster_image_features.append(cur_token_dict1["x"])
|
| 196 |
+
|
| 197 |
+
cur_image_features2 = torch.cat([token_dict2["x"][i] for i in events[key]], dim=0).unsqueeze(0)
|
| 198 |
+
token_dict = {'x': cur_image_features2,
|
| 199 |
+
'token_num': cur_image_features2.size(1),
|
| 200 |
+
'idx_token': torch.arange(cur_image_features2.size(1))[None, :].repeat(
|
| 201 |
+
cur_image_features2.size(0), 1),
|
| 202 |
+
'agg_weight': cur_image_features2.new_ones(cur_image_features2.size(0),
|
| 203 |
+
cur_image_features2.size(1),
|
| 204 |
+
1),
|
| 205 |
+
'mask': None}
|
| 206 |
+
|
| 207 |
+
cur_token_dict2 = self.get_model().block2(self.get_model().ctm2(token_dict))
|
| 208 |
+
cluster_image_features.append(cur_token_dict2["x"])
|
| 209 |
+
|
| 210 |
+
image_features = torch.cat(cluster_image_features, dim=1)
|
| 211 |
+
image_features = image_features.to(self.get_model().mm_projector.weight.dtype)
|
| 212 |
+
|
| 213 |
+
else:
|
| 214 |
+
if input_type == "video":
|
| 215 |
+
image_features, cls_features = torch.mean(image_features, dim=0, keepdim=False).unsqueeze(
|
| 216 |
+
0), torch.mean(image_features, dim=1, keepdim=False).unsqueeze(0)
|
| 217 |
+
image_features = torch.cat([image_features, cls_features], dim=1)
|
| 218 |
+
|
| 219 |
+
image_features = self.get_model().mm_projector(image_features)
|
| 220 |
+
return image_features # 不同的type形状相同
|
| 221 |
+
|
| 222 |
+
def prepare_inputs_labels_for_multimodal(
|
| 223 |
+
self, input_ids, attention_mask, past_key_values, labels, images, audio_features=None, target_frame=0, ref_ids=None
|
| 224 |
+
):
|
| 225 |
+
IMAGE_TOKEN_INDEX = -200
|
| 226 |
+
AUDIO_TOKEN_INDEX = -300
|
| 227 |
+
# print("\n调用prepare_inputs_labels_for_multimodal")
|
| 228 |
+
vision_tower = self.get_vision_tower()
|
| 229 |
+
# print("获取vision_tower")
|
| 230 |
+
num_frames = images[0].shape[0] # T
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
| 234 |
+
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
|
| 235 |
+
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
|
| 236 |
+
return input_ids, attention_mask, past_key_values, None, labels
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
if ref_ids is not None:
|
| 240 |
+
ref_embeds = []
|
| 241 |
+
for ref_id in ref_ids:
|
| 242 |
+
ref_embed = self.get_model().embed_tokens(ref_id) #[L, 4096]
|
| 243 |
+
ref_embeds.append(ref_embed)
|
| 244 |
+
# list[B]: [len_ref, 4096]
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
if type(images) is list or images.ndim == 5:
|
| 251 |
+
# print("先concat列表中的图像")
|
| 252 |
+
concat_images = torch.cat([image for image in images], dim=0) # [BT, 3, H, W]
|
| 253 |
+
org_image_features = self.encode_images(concat_images) # [BT, 256, 1024]
|
| 254 |
+
|
| 255 |
+
# if audio_features is not None and hasattr(self, "audio_adapter"):
|
| 256 |
+
if True:
|
| 257 |
+
# image_features = self.audio_adapter(org_image_features, audio_features, ref_embeds_T)
|
| 258 |
+
# image_features = self.token_compressor(org_image_features, ref_embeds)
|
| 259 |
+
# print("image_features after compress:", image_features.shape)
|
| 260 |
+
image_features = org_image_features
|
| 261 |
+
|
| 262 |
+
else:
|
| 263 |
+
image_features = org_image_features
|
| 264 |
+
# split_sizes = [image.shape[0] for image in images]
|
| 265 |
+
split_sizes = 1
|
| 266 |
+
image_features = torch.split(image_features, split_sizes, dim=0) # list[BT]: [1, 256,1024]
|
| 267 |
+
image_features = [x.flatten(0, 1) for x in image_features] # list[BT]: [256,1024]
|
| 268 |
+
|
| 269 |
+
org_image_features = torch.split(org_image_features, split_sizes, dim=0)
|
| 270 |
+
org_image_features = [x.flatten(0, 1) for x in org_image_features]
|
| 271 |
+
|
| 272 |
+
else:
|
| 273 |
+
# print("直接获取image_feature")
|
| 274 |
+
image_features = self.encode_images(images)
|
| 275 |
+
org_image_features = image_features
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
new_input_embeds = []
|
| 280 |
+
new_labels = [] if labels is not None else None
|
| 281 |
+
cur_image_idx = 0
|
| 282 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
| 283 |
+
# cur_image_idx += 1
|
| 284 |
+
|
| 285 |
+
# 判断当前input_id中有没有图像token
|
| 286 |
+
# print("cur_input_ids shape:", cur_input_ids.shape)
|
| 287 |
+
# print("cur_input_ids:", cur_input_ids)
|
| 288 |
+
if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
|
| 289 |
+
# print("input_ids中没有 IMAGE token")
|
| 290 |
+
# multimodal LLM, but the current sample is not multimodal
|
| 291 |
+
# 直接把input_ids进行text embed
|
| 292 |
+
cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
|
| 293 |
+
cur_input_embeds = cur_input_embeds + (
|
| 294 |
+
0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
|
| 295 |
+
new_input_embeds.append(cur_input_embeds)
|
| 296 |
+
if labels is not None:
|
| 297 |
+
new_labels.append(labels[batch_idx])
|
| 298 |
+
cur_image_idx += 1
|
| 299 |
+
continue
|
| 300 |
+
|
| 301 |
+
image_token_indices = torch.where((cur_input_ids == IMAGE_TOKEN_INDEX)|(cur_input_ids == AUDIO_TOKEN_INDEX))[0]
|
| 302 |
+
audio_token_indices = torch.where(cur_input_ids == AUDIO_TOKEN_INDEX)[0]
|
| 303 |
+
# print("audio indices:", audio_token_indices)
|
| 304 |
+
# print("image and audio indices:", image_token_indices)
|
| 305 |
+
|
| 306 |
+
cur_new_input_embeds = []
|
| 307 |
+
if labels is not None:
|
| 308 |
+
cur_labels = labels[batch_idx]
|
| 309 |
+
cur_new_labels = []
|
| 310 |
+
assert cur_labels.shape == cur_input_ids.shape
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# 有多个image token---------------------------------------------
|
| 314 |
+
if len(image_token_indices) > 1:
|
| 315 |
+
# print("有多个image token")
|
| 316 |
+
# return 0
|
| 317 |
+
|
| 318 |
+
temp = []
|
| 319 |
+
|
| 320 |
+
cur, pre = image_token_indices[0], image_token_indices[0]
|
| 321 |
+
# 这里是把连续的<image>的位置放到一个list中存储 分割开的<image>
|
| 322 |
+
for i in image_token_indices:
|
| 323 |
+
cur = i
|
| 324 |
+
# 如果下一个<image>就在上一个<image>之后
|
| 325 |
+
if cur - pre == 1:
|
| 326 |
+
temp[-1] = temp[-1] + [cur]
|
| 327 |
+
else:
|
| 328 |
+
temp.append([cur])
|
| 329 |
+
pre = cur
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
pre_image_token_end = 0
|
| 333 |
+
cur_frames = 0
|
| 334 |
+
for i in temp:
|
| 335 |
+
# 第一个以及最后一个<image>的位置
|
| 336 |
+
image_token_start = i[0]
|
| 337 |
+
image_token_end = i[-1]
|
| 338 |
+
cur_image_features = []
|
| 339 |
+
|
| 340 |
+
if len(i) >= 2: # 处理T个image组成的视频特征
|
| 341 |
+
for frame_idx in range(num_frames):
|
| 342 |
+
cur_image_features.append(org_image_features[batch_idx*num_frames+frame_idx])
|
| 343 |
+
# print(batch_idx*num_frames+frame_idx)
|
| 344 |
+
elif i[0] not in audio_token_indices:
|
| 345 |
+
cur_image_features.append(org_image_features[batch_idx * num_frames + target_frame])
|
| 346 |
+
# print(batch_idx * num_frames + target_frame)
|
| 347 |
+
else:
|
| 348 |
+
cur_image_features.append(audio_features[batch_idx])
|
| 349 |
+
# print(f"audio{batch_idx}")
|
| 350 |
+
# ------------------------------------------------------------------
|
| 351 |
+
# # i是每组<image>的indices 根据其数量从image_features中拿特征
|
| 352 |
+
# for _ in i:
|
| 353 |
+
# # 表示处理的是<image>
|
| 354 |
+
# if _ not in audio_token_indices:
|
| 355 |
+
# # 单个image
|
| 356 |
+
# if cur_frames == num_frames:
|
| 357 |
+
# # cur_image_features.append(org_image_features[cur_image_idx-num_frames+target_frame])
|
| 358 |
+
# cur_image_features.append(org_image_features[batch_idx*num_frames+target_frame])
|
| 359 |
+
# # print(cur_image_idx-num_frames+target_frame)
|
| 360 |
+
# # 多个image
|
| 361 |
+
# else:
|
| 362 |
+
# cur_image_features.append(image_features[cur_image_idx])
|
| 363 |
+
# # print(cur_image_idx)
|
| 364 |
+
# cur_image_idx += 1
|
| 365 |
+
# cur_frames += 1
|
| 366 |
+
# # 处理<audio>
|
| 367 |
+
# else:
|
| 368 |
+
# # cur_image_features.append(self.audio_feature_layer(audio_features[batch_idx]))
|
| 369 |
+
# cur_image_features.append(audio_features[batch_idx])
|
| 370 |
+
# # print("audio:", batch_idx)
|
| 371 |
+
# # cur_image_features list[len(i)] : [256,1024]
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
# 如果当前分组是多个image 代表video
|
| 376 |
+
if len(i) >= 2:
|
| 377 |
+
if not self.compress:
|
| 378 |
+
|
| 379 |
+
# 对拿到的多个image_features进行压缩 并投影
|
| 380 |
+
cur_image_features = torch.stack(cur_image_features, dim=0) # [len(i), 256, 1024]
|
| 381 |
+
cur_image_features = self.project(cur_image_features, input_type="video")
|
| 382 |
+
t, l, n = cur_image_features.size()
|
| 383 |
+
cur_image_features = cur_image_features.contiguous().view(t * l, n) #[112, 4096]
|
| 384 |
+
# print(f"no compression, cur_image_features{cur_image_features.shape}")
|
| 385 |
+
|
| 386 |
+
else:
|
| 387 |
+
|
| 388 |
+
compressed_frames = []
|
| 389 |
+
for cur_image_feature in cur_image_features:
|
| 390 |
+
cur_image_feature = self.project(cur_image_feature.unsqueeze(0), input_type="image") # [1, 256, 1024]
|
| 391 |
+
t, l, n = cur_image_feature.size()
|
| 392 |
+
cur_image_feature = cur_image_feature.contiguous().view(t * l, n) # [112, 4096]
|
| 393 |
+
|
| 394 |
+
compressed_frames.append(cur_image_feature.mean(dim=0).unsqueeze(0)) # [1, 4096]
|
| 395 |
+
compressed_frames = torch.cat(compressed_frames, dim=0) # [T, 4096]
|
| 396 |
+
|
| 397 |
+
cur_image_features = torch.stack(cur_image_features, dim=0) # [len(i), 256, 1024]
|
| 398 |
+
cur_image_features = self.project(cur_image_features, input_type="video")
|
| 399 |
+
t, l, n = cur_image_features.size()
|
| 400 |
+
cur_image_features = cur_image_features.contiguous().view(t * l, n) # [112, 4096]
|
| 401 |
+
|
| 402 |
+
# cur_image_features = torch.cat([cur_image_features, compressed_frames], dim=0) # [122, 4096]
|
| 403 |
+
cur_image_features = torch.cat([compressed_frames, cur_image_features], dim=0) # [122, 4096]
|
| 404 |
+
|
| 405 |
+
# 对于单个的特殊 token 如果是<image>
|
| 406 |
+
elif i[0] not in audio_token_indices:
|
| 407 |
+
cur_image_features = torch.stack(cur_image_features, dim=0)
|
| 408 |
+
cur_image_features = self.project(cur_image_features, input_type="image")
|
| 409 |
+
t, l, n = cur_image_features.size()
|
| 410 |
+
cur_image_features = cur_image_features.contiguous().view(t * l, n) # [112, 4093]
|
| 411 |
+
else:
|
| 412 |
+
cur_image_features = cur_image_features[0] #[10, 4096]
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
| 416 |
+
# 把im_start前的文字进行embeds
|
| 417 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[pre_image_token_end:image_token_start - 1]).detach())
|
| 418 |
+
# 把im_start进行embeds
|
| 419 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start - 1:image_token_start]))
|
| 420 |
+
# 图像特征
|
| 421 |
+
cur_new_input_embeds.append(cur_image_features)
|
| 422 |
+
# im_end
|
| 423 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_end + 1:image_token_end + 2]))
|
| 424 |
+
if labels is not None:
|
| 425 |
+
cur_new_labels.append(cur_labels[pre_image_token_end:image_token_start])
|
| 426 |
+
# cur_new_labels填充
|
| 427 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
| 428 |
+
cur_new_labels.append(cur_labels[image_token_end:image_token_end + 1])
|
| 429 |
+
|
| 430 |
+
# cur_labels设置为剩余的cur_labels
|
| 431 |
+
# cur_labels = cur_labels[image_token_end + 2:]
|
| 432 |
+
else:
|
| 433 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[pre_image_token_end:image_token_start]))
|
| 434 |
+
cur_new_input_embeds.append(cur_image_features)
|
| 435 |
+
if labels is not None:
|
| 436 |
+
cur_new_labels.append(cur_labels[pre_image_token_end:image_token_start])
|
| 437 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
| 438 |
+
# cur_labels = cur_labels[image_token_end + 1:]
|
| 439 |
+
|
| 440 |
+
pre_image_token_end = image_token_end + 1
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
# cur_input_ids设置为剩余的cur_input_ids
|
| 444 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end',
|
| 445 |
+
False):
|
| 446 |
+
cur_input_ids = cur_input_ids[image_token_end + 2:]
|
| 447 |
+
cur_labels = cur_labels[image_token_end + 2:]
|
| 448 |
+
else:
|
| 449 |
+
cur_input_ids = cur_input_ids[image_token_end + 1:]
|
| 450 |
+
cur_labels = cur_labels[image_token_end + 1:]
|
| 451 |
+
|
| 452 |
+
# 结合上面大于1 此处就是只有一个image token
|
| 453 |
+
elif image_token_indices.numel() > 0:
|
| 454 |
+
# print("只有一个image token")
|
| 455 |
+
|
| 456 |
+
cur_image_features = []
|
| 457 |
+
image_token_start = image_token_indices[0]
|
| 458 |
+
image_token_end = image_token_indices[-1]
|
| 459 |
+
# print("image_token_start:", image_token_start, " image_token_end:", image_token_end)
|
| 460 |
+
|
| 461 |
+
# 根据image token数量 把image feature加入到cur_image_features
|
| 462 |
+
for _ in image_token_indices:
|
| 463 |
+
cur_image_features.append(image_features[cur_image_idx])
|
| 464 |
+
cur_image_idx += 1
|
| 465 |
+
# print("cur_image_features length:", len(cur_image_features))
|
| 466 |
+
|
| 467 |
+
# 对image features进行维度上拼接
|
| 468 |
+
cur_image_features = torch.stack(cur_image_features, dim=0)
|
| 469 |
+
# print("cur_image_features_stacked shape:", cur_image_features.shape)
|
| 470 |
+
cur_image_features = self.project(cur_image_features, input_type="image")
|
| 471 |
+
# print("cur_image_features_projected shape:", cur_image_features.shape)
|
| 472 |
+
|
| 473 |
+
# 获取 图像特征的维度 nums, len, dim
|
| 474 |
+
t, l, n = cur_image_features.size()
|
| 475 |
+
cur_image_features = cur_image_features.contiguous().view(t * l, n)
|
| 476 |
+
# print("cur_image_features_viewed shape:", cur_image_features.shape)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
| 481 |
+
print("no tune_mm_mlp_adapter and no mm_use_im_start_end")
|
| 482 |
+
# 把imagetoken之前的text进行embedding 这两行
|
| 483 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
|
| 484 |
+
# 这里加入的 image——strat——token
|
| 485 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
|
| 486 |
+
print("cur_new_input_embeds length:", len(cur_new_input_embeds))
|
| 487 |
+
print("cur_new_input_embeds shape:", cur_new_input_embeds[0].shape)
|
| 488 |
+
print("cur_new_input_embeds shape:", cur_new_input_embeds[1].shape)
|
| 489 |
+
|
| 490 |
+
# 在图像token位置上加入image feature
|
| 491 |
+
cur_new_input_embeds.append(cur_image_features)
|
| 492 |
+
print("cur_new_input_embeds length:", len(cur_new_input_embeds))
|
| 493 |
+
# print("cur_new_input_embeds shape:", cur_new_input_embeds[2].shape)
|
| 494 |
+
|
| 495 |
+
# 把图像token之后的img-end-token加入
|
| 496 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_end+1:image_token_end+2]))
|
| 497 |
+
print("cur_new_input_embeds length:", len(cur_new_input_embeds))
|
| 498 |
+
|
| 499 |
+
if labels is not None:
|
| 500 |
+
# 把image token前面的label加入
|
| 501 |
+
cur_new_labels.append(cur_labels[:image_token_start])
|
| 502 |
+
# 根据图像特征形状加入 多个ignore index
|
| 503 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
| 504 |
+
# 把img-end-token加入
|
| 505 |
+
cur_new_labels.append(cur_labels[image_token_end:image_token_end+1])
|
| 506 |
+
# 把剩下的text label加入
|
| 507 |
+
cur_labels = cur_labels[image_token_end+2:]
|
| 508 |
+
|
| 509 |
+
else:
|
| 510 |
+
# print("tune_mm_mlp_adapter / mm_use_im_start_end")
|
| 511 |
+
# 对图像token之前的text token 进行embedding
|
| 512 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
|
| 513 |
+
cur_new_input_embeds.append(cur_image_features)
|
| 514 |
+
# print("cur_new_input_embeds length:", len(cur_new_input_embeds))
|
| 515 |
+
|
| 516 |
+
if labels is not None:
|
| 517 |
+
# 把图像前面的labels进行复制
|
| 518 |
+
cur_new_labels.append(cur_labels[:image_token_start])
|
| 519 |
+
# 根据图像特征形状 加入shape[0]个 IGNORE_INDEX
|
| 520 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
|
| 521 |
+
# 加入剩下的labels
|
| 522 |
+
# print("cur_new_labels length:", len(cur_new_labels))
|
| 523 |
+
# print("cur_new_labels:", cur_new_labels)
|
| 524 |
+
# print(cur_new_labels[0].shape, ' ',cur_new_labels[1].shape)
|
| 525 |
+
|
| 526 |
+
# 将cur_labels保留为剩下的未处理过的lables
|
| 527 |
+
cur_labels = cur_labels[image_token_end+1:]
|
| 528 |
+
# print("labels after image:", cur_labels)
|
| 529 |
+
# print(len(cur_labels))
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
# 将 cur_input_ids替换为剩下的 没有处理的 (img之后的) ids
|
| 533 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
| 534 |
+
cur_input_ids = cur_input_ids[image_token_end+2:]
|
| 535 |
+
else:
|
| 536 |
+
cur_input_ids = cur_input_ids[image_token_end+1:]
|
| 537 |
+
# print("input_ids after image :", cur_input_ids)
|
| 538 |
+
|
| 539 |
+
# 如果图像token之后还有text token
|
| 540 |
+
if cur_input_ids.numel() > 0:
|
| 541 |
+
# print("image token 之后还有 text token")
|
| 542 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
| 543 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
|
| 544 |
+
else:
|
| 545 |
+
# 把剩下的input_id进行embedding
|
| 546 |
+
|
| 547 |
+
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
|
| 548 |
+
|
| 549 |
+
# print("cur_new_input_embeds length:", len(cur_new_input_embeds))
|
| 550 |
+
# print("cur_new_input_embeds shape:", cur_new_input_embeds[0].shape, cur_new_input_embeds[1].shape, cur_new_input_embeds[2].shape)
|
| 551 |
+
|
| 552 |
+
if labels is not None:
|
| 553 |
+
# 把剩下的labels加入
|
| 554 |
+
cur_new_labels.append(cur_labels)
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
cur_new_input_embeds = [x.to(device='cuda') for x in cur_new_input_embeds]
|
| 558 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
|
| 559 |
+
|
| 560 |
+
new_input_embeds.append(cur_new_input_embeds)
|
| 561 |
+
if labels is not None:
|
| 562 |
+
cur_new_labels = torch.cat(cur_new_labels, dim=0)
|
| 563 |
+
|
| 564 |
+
new_labels.append(cur_new_labels)
|
| 565 |
+
|
| 566 |
+
# 如果一个batch内部embedd inputs长度不一致
|
| 567 |
+
if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
|
| 568 |
+
print("batch 内部长度不一致")
|
| 569 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
| 570 |
+
|
| 571 |
+
new_input_embeds_align = []
|
| 572 |
+
for cur_new_embed in new_input_embeds:
|
| 573 |
+
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
|
| 574 |
+
new_input_embeds_align.append(cur_new_embed)
|
| 575 |
+
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
|
| 576 |
+
|
| 577 |
+
if labels is not None:
|
| 578 |
+
new_labels_align = []
|
| 579 |
+
_new_labels = new_labels
|
| 580 |
+
for cur_new_label in new_labels:
|
| 581 |
+
cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
|
| 582 |
+
new_labels_align.append(cur_new_label)
|
| 583 |
+
new_labels = torch.stack(new_labels_align, dim=0)
|
| 584 |
+
|
| 585 |
+
if attention_mask is not None:
|
| 586 |
+
new_attention_mask = []
|
| 587 |
+
for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
|
| 588 |
+
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
| 589 |
+
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
|
| 590 |
+
cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
|
| 591 |
+
new_attention_mask.append(cur_new_attention_mask)
|
| 592 |
+
attention_mask = torch.stack(new_attention_mask, dim=0)
|
| 593 |
+
assert attention_mask.shape == new_labels.shape
|
| 594 |
+
|
| 595 |
+
# 内部长度一致
|
| 596 |
+
else:
|
| 597 |
+
# 将一个batch的数据 拼接成 [B, token_len, dim]
|
| 598 |
+
new_input_embeds = torch.stack(new_input_embeds, dim=0)
|
| 599 |
+
if labels is not None:
|
| 600 |
+
new_labels = torch.stack(new_labels, dim=0)
|
| 601 |
+
|
| 602 |
+
if attention_mask is not None:
|
| 603 |
+
new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
|
| 604 |
+
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
|
| 605 |
+
assert attention_mask.shape == new_input_embeds.shape[:2]
|
| 606 |
+
|
| 607 |
+
return None, attention_mask, past_key_values, new_input_embeds, new_labels
|
| 608 |
+
|
| 609 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
| 610 |
+
if model_args.mm_use_im_patch_token:
|
| 611 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
| 612 |
+
tokenizer.add_tokens([DEFAULT_VIDEO_PATCH_TOKEN], special_tokens=True)
|
| 613 |
+
self.resize_token_embeddings(len(tokenizer))
|
| 614 |
+
|
| 615 |
+
if model_args.mm_use_im_start_end:
|
| 616 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN], special_tokens=True)
|
| 617 |
+
self.resize_token_embeddings(len(tokenizer))
|
| 618 |
+
|
| 619 |
+
if num_new_tokens > 0:
|
| 620 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
| 621 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
| 622 |
+
|
| 623 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
| 624 |
+
dim=0, keepdim=True)
|
| 625 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
| 626 |
+
dim=0, keepdim=True)
|
| 627 |
+
|
| 628 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
| 629 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
| 630 |
+
|
| 631 |
+
if model_args.tune_mm_mlp_adapter:
|
| 632 |
+
for p in self.get_input_embeddings().parameters():
|
| 633 |
+
p.requires_grad = True
|
| 634 |
+
for p in self.get_output_embeddings().parameters():
|
| 635 |
+
p.requires_grad = False
|
| 636 |
+
|
| 637 |
+
if model_args.pretrain_mm_mlp_adapter:
|
| 638 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
| 639 |
+
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
| 640 |
+
assert num_new_tokens == 2
|
| 641 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
| 642 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
| 643 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
| 644 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
| 645 |
+
else:
|
| 646 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
| 647 |
+
elif model_args.mm_use_im_patch_token:
|
| 648 |
+
if model_args.tune_mm_mlp_adapter:
|
| 649 |
+
for p in self.get_input_embeddings().parameters():
|
| 650 |
+
p.requires_grad = False
|
| 651 |
+
for p in self.get_output_embeddings().parameters():
|
| 652 |
+
p.requires_grad = False
|
ChatUniVi/model/builder.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
| 4 |
+
import torch
|
| 5 |
+
from ChatUniVi.model import *
|
| 6 |
+
from ChatUniVi.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
| 7 |
+
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
| 8 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
|
| 12 |
+
kwargs = {"device_map": device_map}
|
| 13 |
+
|
| 14 |
+
if load_8bit:
|
| 15 |
+
kwargs['load_in_8bit'] = True
|
| 16 |
+
elif load_4bit:
|
| 17 |
+
kwargs['load_in_4bit'] = True
|
| 18 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
| 19 |
+
load_in_4bit=True,
|
| 20 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 21 |
+
bnb_4bit_use_double_quant=True,
|
| 22 |
+
bnb_4bit_quant_type='nf4'
|
| 23 |
+
)
|
| 24 |
+
else:
|
| 25 |
+
kwargs['torch_dtype'] = torch.float16
|
| 26 |
+
|
| 27 |
+
if 'chatunivi' in model_name.lower():
|
| 28 |
+
# Load ChatUniVi model
|
| 29 |
+
if 'lora' in model_name.lower() and model_base is not None:
|
| 30 |
+
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
| 31 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
| 32 |
+
print('Loading ChatUniVi from base model...')
|
| 33 |
+
model = ChatUniViLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
|
| 34 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
| 35 |
+
if model.lm_head.weight.shape[0] != token_num:
|
| 36 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
| 37 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
| 38 |
+
|
| 39 |
+
print('Loading additional ChatUniVi weights...')
|
| 40 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
| 41 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
| 42 |
+
else:
|
| 43 |
+
# this is probably from HF Hub
|
| 44 |
+
from huggingface_hub import hf_hub_download
|
| 45 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
| 46 |
+
cache_file = hf_hub_download(
|
| 47 |
+
repo_id=repo_id,
|
| 48 |
+
filename=filename,
|
| 49 |
+
subfolder=subfolder)
|
| 50 |
+
return torch.load(cache_file, map_location='cpu')
|
| 51 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
| 52 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
| 53 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
| 54 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
| 55 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
| 56 |
+
|
| 57 |
+
from peft import PeftModel
|
| 58 |
+
print('Loading LoRA weights...')
|
| 59 |
+
model = PeftModel.from_pretrained(model, model_path)
|
| 60 |
+
print('Merging LoRA weights...')
|
| 61 |
+
model = model.merge_and_unload()
|
| 62 |
+
print('Model is loaded...')
|
| 63 |
+
elif model_base is not None:
|
| 64 |
+
# this may be mm projector only
|
| 65 |
+
print('Loading ChatUniVi from base model...')
|
| 66 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
| 67 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
| 68 |
+
model = ChatUniViLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
| 69 |
+
|
| 70 |
+
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
| 71 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
| 72 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
| 73 |
+
else:
|
| 74 |
+
#
|
| 75 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
| 76 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
| 77 |
+
else:
|
| 78 |
+
# Load language model
|
| 79 |
+
if model_base is not None:
|
| 80 |
+
# PEFT model
|
| 81 |
+
from peft import PeftModel
|
| 82 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
| 83 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
|
| 84 |
+
print(f"Loading LoRA weights from {model_path}")
|
| 85 |
+
model = PeftModel.from_pretrained(model, model_path)
|
| 86 |
+
print(f"Merging weights")
|
| 87 |
+
model = model.merge_and_unload()
|
| 88 |
+
print('Convert to FP16...')
|
| 89 |
+
model.to(torch.float16)
|
| 90 |
+
else:
|
| 91 |
+
use_fast = False
|
| 92 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
| 93 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
| 94 |
+
|
| 95 |
+
image_processor = None
|
| 96 |
+
|
| 97 |
+
if 'chatunivi' in model_name.lower():
|
| 98 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
| 99 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
| 100 |
+
if mm_use_im_patch_token:
|
| 101 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
| 102 |
+
if mm_use_im_start_end:
|
| 103 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
| 104 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 105 |
+
|
| 106 |
+
vision_tower = model.get_vision_tower()
|
| 107 |
+
if not vision_tower.is_loaded:
|
| 108 |
+
vision_tower.load_model()
|
| 109 |
+
vision_tower.to(device='cuda', dtype=torch.float16)
|
| 110 |
+
|
| 111 |
+
image_processor = vision_tower.image_eval_processor
|
| 112 |
+
|
| 113 |
+
if hasattr(model.config, "max_sequence_length"):
|
| 114 |
+
context_len = model.config.max_sequence_length
|
| 115 |
+
else:
|
| 116 |
+
context_len = 2048
|
| 117 |
+
|
| 118 |
+
return tokenizer, model, image_processor, context_len
|
ChatUniVi/model/cluster.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 8 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 9 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 10 |
+
def norm_cdf(x):
|
| 11 |
+
# Computes standard normal cumulative distribution function
|
| 12 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 13 |
+
|
| 14 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 15 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 16 |
+
"The distribution of values may be incorrect.",
|
| 17 |
+
stacklevel=2)
|
| 18 |
+
|
| 19 |
+
with torch.no_grad():
|
| 20 |
+
# Values are generated by using a truncated uniform distribution and
|
| 21 |
+
# then using the inverse CDF for the normal distribution.
|
| 22 |
+
# Get upper and lower cdf values
|
| 23 |
+
l = norm_cdf((a - mean) / std)
|
| 24 |
+
u = norm_cdf((b - mean) / std)
|
| 25 |
+
|
| 26 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 27 |
+
# [2l-1, 2u-1].
|
| 28 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 29 |
+
|
| 30 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 31 |
+
# standard normal
|
| 32 |
+
tensor.erfinv_()
|
| 33 |
+
|
| 34 |
+
# Transform to proper mean, std
|
| 35 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 36 |
+
tensor.add_(mean)
|
| 37 |
+
|
| 38 |
+
# Clamp to ensure it's in the proper range
|
| 39 |
+
tensor.clamp_(min=a, max=b)
|
| 40 |
+
return tensor
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
| 44 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
| 45 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
| 46 |
+
normal distribution. The values are effectively drawn from the
|
| 47 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
| 48 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
| 49 |
+
the bounds. The method used for generating the random values works
|
| 50 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
| 51 |
+
Args:
|
| 52 |
+
tensor: an n-dimensional `torch.Tensor`
|
| 53 |
+
mean: the mean of the normal distribution
|
| 54 |
+
std: the standard deviation of the normal distribution
|
| 55 |
+
a: the minimum cutoff value
|
| 56 |
+
b: the maximum cutoff value
|
| 57 |
+
Examples:
|
| 58 |
+
>>> w = torch.empty(3, 5)
|
| 59 |
+
>>> nn.init.trunc_normal_(w)
|
| 60 |
+
"""
|
| 61 |
+
try:
|
| 62 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 63 |
+
except:
|
| 64 |
+
return tensor
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
| 68 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 69 |
+
"""
|
| 70 |
+
if drop_prob == 0. or not training:
|
| 71 |
+
return x
|
| 72 |
+
keep_prob = 1 - drop_prob
|
| 73 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 74 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 75 |
+
random_tensor.floor_() # binarize
|
| 76 |
+
output = x.div(keep_prob) * random_tensor
|
| 77 |
+
return output
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class DropPath(nn.Module):
|
| 81 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 82 |
+
"""
|
| 83 |
+
def __init__(self, drop_prob=None):
|
| 84 |
+
super(DropPath, self).__init__()
|
| 85 |
+
self.drop_prob = drop_prob
|
| 86 |
+
|
| 87 |
+
def forward(self, x):
|
| 88 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def index_points(points, idx):
|
| 92 |
+
"""Sample features following the index.
|
| 93 |
+
Returns:
|
| 94 |
+
new_points:, indexed points data, [B, S, C]
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
points: input points data, [B, N, C]
|
| 98 |
+
idx: sample index data, [B, S]
|
| 99 |
+
"""
|
| 100 |
+
device = points.device
|
| 101 |
+
B = points.shape[0]
|
| 102 |
+
view_shape = list(idx.shape)
|
| 103 |
+
view_shape[1:] = [1] * (len(view_shape) - 1)
|
| 104 |
+
repeat_shape = list(idx.shape)
|
| 105 |
+
repeat_shape[0] = 1
|
| 106 |
+
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
|
| 107 |
+
new_points = points[batch_indices, idx, :]
|
| 108 |
+
return new_points
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def cluster_dpc_knn(token_dict, cluster_num, k=5, token_mask=None):
|
| 112 |
+
"""Cluster tokens with DPC-KNN algorithm.
|
| 113 |
+
Return:
|
| 114 |
+
idx_cluster (Tensor[B, N]): cluster index of each token.
|
| 115 |
+
cluster_num (int): actual cluster number. The same with
|
| 116 |
+
input cluster number
|
| 117 |
+
Args:
|
| 118 |
+
token_dict (dict): dict for token information
|
| 119 |
+
cluster_num (int): cluster number
|
| 120 |
+
k (int): number of the nearest neighbor used for local density.
|
| 121 |
+
token_mask (Tensor[B, N]): mask indicate the whether the token is
|
| 122 |
+
padded empty token. Non-zero value means the token is meaningful,
|
| 123 |
+
zero value means the token is an empty token. If set to None, all
|
| 124 |
+
tokens are regarded as meaningful.
|
| 125 |
+
"""
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
x = token_dict["x"]
|
| 128 |
+
B, N, C = x.shape
|
| 129 |
+
|
| 130 |
+
dist_matrix = torch.cdist(x.float(), x.float()) / (C ** 0.5)
|
| 131 |
+
|
| 132 |
+
if token_mask is not None:
|
| 133 |
+
token_mask = token_mask > 0
|
| 134 |
+
# in order to not affect the local density, the distance between empty tokens
|
| 135 |
+
# and any other tokens should be the maximal distance.
|
| 136 |
+
dist_matrix = dist_matrix * token_mask[:, None, :] + \
|
| 137 |
+
(dist_matrix.max() + 1) * (~token_mask[:, None, :])
|
| 138 |
+
|
| 139 |
+
# get local density
|
| 140 |
+
|
| 141 |
+
dist_nearest, index_nearest = torch.topk(dist_matrix, k=k, dim=-1, largest=False)
|
| 142 |
+
density = (-(dist_nearest ** 2).mean(dim=-1)).exp()
|
| 143 |
+
# add a little noise to ensure no tokens have the same density.
|
| 144 |
+
density = density + torch.rand(
|
| 145 |
+
density.shape, device=density.device, dtype=density.dtype) * 1e-6
|
| 146 |
+
|
| 147 |
+
if token_mask is not None:
|
| 148 |
+
# the density of empty token should be 0
|
| 149 |
+
density = density * token_mask
|
| 150 |
+
|
| 151 |
+
# get distance indicator
|
| 152 |
+
mask = density[:, None, :] > density[:, :, None]
|
| 153 |
+
mask = mask.type(x.dtype)
|
| 154 |
+
dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None]
|
| 155 |
+
dist, index_parent = (dist_matrix * mask + dist_max * (1 - mask)).min(dim=-1)
|
| 156 |
+
|
| 157 |
+
# select clustering center according to score
|
| 158 |
+
score = dist * density
|
| 159 |
+
_, index_down = torch.topk(score, k=cluster_num, dim=-1)
|
| 160 |
+
|
| 161 |
+
# assign tokens to the nearest center
|
| 162 |
+
dist_matrix = index_points(dist_matrix, index_down)
|
| 163 |
+
|
| 164 |
+
idx_cluster = dist_matrix.argmin(dim=1)
|
| 165 |
+
|
| 166 |
+
# make sure cluster center merge to itself
|
| 167 |
+
idx_batch = torch.arange(B, device=x.device)[:, None].expand(B, cluster_num)
|
| 168 |
+
idx_tmp = torch.arange(cluster_num, device=x.device)[None, :].expand(B, cluster_num)
|
| 169 |
+
idx_cluster[idx_batch.reshape(-1), index_down.reshape(-1)] = idx_tmp.reshape(-1)
|
| 170 |
+
|
| 171 |
+
return idx_cluster, cluster_num
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def merge_tokens(token_dict, idx_cluster, cluster_num, token_weight=None):
|
| 175 |
+
"""Merge tokens in the same cluster to a single cluster.
|
| 176 |
+
Implemented by torch.index_add(). Flops: B*N*(C+2)
|
| 177 |
+
Return:
|
| 178 |
+
out_dict (dict): dict for output token information
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
token_dict (dict): dict for input token information
|
| 182 |
+
idx_cluster (Tensor[B, N]): cluster index of each token.
|
| 183 |
+
cluster_num (int): cluster number
|
| 184 |
+
token_weight (Tensor[B, N, 1]): weight for each token.
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
x = token_dict['x']
|
| 188 |
+
idx_token = token_dict['idx_token']
|
| 189 |
+
agg_weight = token_dict['agg_weight']
|
| 190 |
+
|
| 191 |
+
B, N, C = x.shape
|
| 192 |
+
if token_weight is None:
|
| 193 |
+
token_weight = x.new_ones(B, N, 1)
|
| 194 |
+
|
| 195 |
+
idx_batch = torch.arange(B, device=x.device)[:, None]
|
| 196 |
+
idx = idx_cluster + idx_batch * cluster_num
|
| 197 |
+
|
| 198 |
+
all_weight = token_weight.new_zeros(B * cluster_num, 1)
|
| 199 |
+
all_weight.index_add_(dim=0, index=idx.reshape(B * N),
|
| 200 |
+
source=token_weight.reshape(B * N, 1))
|
| 201 |
+
all_weight = all_weight + 1e-6
|
| 202 |
+
norm_weight = token_weight / all_weight[idx]
|
| 203 |
+
|
| 204 |
+
# average token features
|
| 205 |
+
x_merged = x.new_zeros(B * cluster_num, C)
|
| 206 |
+
source = x * norm_weight
|
| 207 |
+
|
| 208 |
+
x_merged.index_add_(dim=0, index=idx.reshape(B * N),
|
| 209 |
+
source=source.reshape(B * N, C).type(x.dtype))
|
| 210 |
+
x_merged = x_merged.reshape(B, cluster_num, C)
|
| 211 |
+
|
| 212 |
+
idx_token_new = index_points(idx_cluster[..., None], idx_token).squeeze(-1)
|
| 213 |
+
weight_t = index_points(norm_weight, idx_token)
|
| 214 |
+
agg_weight_new = agg_weight * weight_t
|
| 215 |
+
agg_weight_new / agg_weight_new.max(dim=1, keepdim=True)[0]
|
| 216 |
+
|
| 217 |
+
out_dict = {}
|
| 218 |
+
out_dict['x'] = x_merged
|
| 219 |
+
out_dict['token_num'] = cluster_num
|
| 220 |
+
out_dict['idx_token'] = idx_token_new
|
| 221 |
+
out_dict['agg_weight'] = agg_weight_new
|
| 222 |
+
out_dict['mask'] = None
|
| 223 |
+
return out_dict
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class CTM(nn.Module):
|
| 227 |
+
def __init__(self, sample_ratio, embed_dim, dim_out, k=5):
|
| 228 |
+
super().__init__()
|
| 229 |
+
self.sample_ratio = sample_ratio
|
| 230 |
+
self.dim_out = dim_out
|
| 231 |
+
self.k = k
|
| 232 |
+
|
| 233 |
+
def forward(self, token_dict, sample_ratio=None):
|
| 234 |
+
x = token_dict["x"]
|
| 235 |
+
B, N, C = x.shape
|
| 236 |
+
|
| 237 |
+
token_weight = x.new_ones(B, N)
|
| 238 |
+
|
| 239 |
+
if token_dict["mask"] is not None:
|
| 240 |
+
token_weight.masked_fill_((1 - token_dict["mask"]).to(torch.bool), float("-inf"))
|
| 241 |
+
token_weight = token_weight.unsqueeze(2)
|
| 242 |
+
token_dict['x'] = x
|
| 243 |
+
|
| 244 |
+
if sample_ratio is not None:
|
| 245 |
+
cluster_num = max(math.ceil(N * sample_ratio), 1)
|
| 246 |
+
elif self.sample_ratio > 1:
|
| 247 |
+
cluster_num = max(math.ceil(self.sample_ratio), 1)
|
| 248 |
+
else:
|
| 249 |
+
cluster_num = max(math.ceil(N * self.sample_ratio), 1)
|
| 250 |
+
|
| 251 |
+
k = min(3, max(cluster_num//2, 1)) if self.k > cluster_num else self.k
|
| 252 |
+
idx_cluster, cluster_num = cluster_dpc_knn(
|
| 253 |
+
token_dict, cluster_num, k, token_mask=token_dict["mask"])
|
| 254 |
+
|
| 255 |
+
down_dict = merge_tokens(token_dict, idx_cluster, cluster_num, token_weight)
|
| 256 |
+
return down_dict, token_dict
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class TCBlock(nn.Module):
|
| 260 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
| 261 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, use_sr_layer=False):
|
| 262 |
+
super().__init__()
|
| 263 |
+
self.apply(self._init_weights)
|
| 264 |
+
|
| 265 |
+
def _init_weights(self, m):
|
| 266 |
+
if isinstance(m, nn.Linear):
|
| 267 |
+
trunc_normal_(m.weight, std=.02)
|
| 268 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 269 |
+
nn.init.constant_(m.bias, 0)
|
| 270 |
+
elif isinstance(m, nn.LayerNorm):
|
| 271 |
+
nn.init.constant_(m.bias, 0)
|
| 272 |
+
nn.init.constant_(m.weight, 1.0)
|
| 273 |
+
elif isinstance(m, nn.Conv2d):
|
| 274 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 275 |
+
fan_out //= m.groups
|
| 276 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 277 |
+
if m.bias is not None:
|
| 278 |
+
m.bias.data.zero_()
|
| 279 |
+
|
| 280 |
+
def forward(self, inputs):
|
| 281 |
+
if isinstance(inputs, tuple) or isinstance(inputs, list):
|
| 282 |
+
q_dict, kv_dict = inputs
|
| 283 |
+
else:
|
| 284 |
+
q_dict, kv_dict = inputs, None
|
| 285 |
+
|
| 286 |
+
x = q_dict['x']
|
| 287 |
+
return q_dict
|
ChatUniVi/model/consolidate.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
|
| 4 |
+
"""
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 9 |
+
from llava.model import *
|
| 10 |
+
from llava.model.utils import auto_upgrade
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def consolidate_ckpt(src_path, dst_path):
|
| 14 |
+
print("Loading model")
|
| 15 |
+
auto_upgrade(src_path)
|
| 16 |
+
src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
| 17 |
+
src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
|
| 18 |
+
src_model.save_pretrained(dst_path)
|
| 19 |
+
src_tokenizer.save_pretrained(dst_path)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
parser = argparse.ArgumentParser()
|
| 24 |
+
parser.add_argument("--src", type=str, required=True)
|
| 25 |
+
parser.add_argument("--dst", type=str, required=True)
|
| 26 |
+
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
consolidate_ckpt(args.src, args.dst)
|
ChatUniVi/model/dataloader.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import math
|
| 3 |
+
from decord import VideoReader, cpu
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _get_rawvideo_dec(video_path, image_processor, max_frames=64, image_resolution=224, video_framerate=1, s=None, e=None):
|
| 10 |
+
# speed up video decode via decord.
|
| 11 |
+
video_mask = np.zeros(max_frames, dtype=np.int64)
|
| 12 |
+
max_video_length = 0
|
| 13 |
+
|
| 14 |
+
# T x 3 x H x W
|
| 15 |
+
video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64)
|
| 16 |
+
|
| 17 |
+
if s is None:
|
| 18 |
+
start_time, end_time = None, None
|
| 19 |
+
else:
|
| 20 |
+
start_time = int(s)
|
| 21 |
+
end_time = int(e)
|
| 22 |
+
start_time = start_time if start_time >= 0. else 0.
|
| 23 |
+
end_time = end_time if end_time >= 0. else 0.
|
| 24 |
+
if start_time > end_time:
|
| 25 |
+
start_time, end_time = end_time, start_time
|
| 26 |
+
elif start_time == end_time:
|
| 27 |
+
end_time = start_time + 1
|
| 28 |
+
|
| 29 |
+
if os.path.exists(video_path):
|
| 30 |
+
vreader = VideoReader(video_path, ctx=cpu(0))
|
| 31 |
+
else:
|
| 32 |
+
print(video_path)
|
| 33 |
+
raise FileNotFoundError
|
| 34 |
+
|
| 35 |
+
fps = vreader.get_avg_fps()
|
| 36 |
+
f_start = 0 if start_time is None else int(start_time * fps)
|
| 37 |
+
f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
|
| 38 |
+
num_frames = f_end - f_start + 1
|
| 39 |
+
if num_frames > 0:
|
| 40 |
+
# T x 3 x H x W
|
| 41 |
+
sample_fps = int(video_framerate)
|
| 42 |
+
t_stride = int(round(float(fps) / sample_fps))
|
| 43 |
+
|
| 44 |
+
all_pos = list(range(f_start, f_end + 1, t_stride))
|
| 45 |
+
if len(all_pos) > max_frames:
|
| 46 |
+
sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
|
| 47 |
+
else:
|
| 48 |
+
sample_pos = all_pos
|
| 49 |
+
|
| 50 |
+
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
|
| 51 |
+
|
| 52 |
+
patch_images = [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images]
|
| 53 |
+
slice_len = len(patch_images)
|
| 54 |
+
return patch_images, slice_len
|
| 55 |
+
max_video_length = max_video_length if max_video_length > slice_len else slice_len
|
| 56 |
+
if slice_len < 1:
|
| 57 |
+
pass
|
| 58 |
+
else:
|
| 59 |
+
while len(patch_images) < max_frames:
|
| 60 |
+
patch_images.append(torch.zeros((3, image_resolution, image_resolution)))
|
| 61 |
+
# video[:slice_len, ...] = patch_images
|
| 62 |
+
else:
|
| 63 |
+
print("video path: {} error.".format(video_path))
|
| 64 |
+
|
| 65 |
+
video_mask[:max_video_length] = [1] * max_video_length
|
| 66 |
+
|
| 67 |
+
return patch_images, video_mask
|
ChatUniVi/model/language_model/language_model/configuration_phi.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from transformers import PretrainedConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class PhiConfig(PretrainedConfig):
|
| 11 |
+
"""Phi configuration."""
|
| 12 |
+
|
| 13 |
+
model_type = "phi-msft"
|
| 14 |
+
attribute_map = {
|
| 15 |
+
"max_position_embeddings": "n_positions",
|
| 16 |
+
"hidden_size": "n_embd",
|
| 17 |
+
"num_attention_heads": "n_head",
|
| 18 |
+
"num_hidden_layers": "n_layer",
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
vocab_size: int = 50304,
|
| 24 |
+
n_positions: int = 2048,
|
| 25 |
+
n_embd: int = 1024,
|
| 26 |
+
n_layer: int = 20,
|
| 27 |
+
n_inner: Optional[int] = None,
|
| 28 |
+
n_head: int = 16,
|
| 29 |
+
n_head_kv: Optional[int] = None,
|
| 30 |
+
rotary_dim: Optional[int] = 32,
|
| 31 |
+
activation_function: Optional[str] = "gelu_new",
|
| 32 |
+
flash_attn: bool = False,
|
| 33 |
+
flash_rotary: bool = False,
|
| 34 |
+
fused_dense: bool = False,
|
| 35 |
+
attn_pdrop: float = 0.0,
|
| 36 |
+
embd_pdrop: float = 0.0,
|
| 37 |
+
resid_pdrop: float = 0.0,
|
| 38 |
+
layer_norm_epsilon: float = 1e-5,
|
| 39 |
+
initializer_range: float = 0.02,
|
| 40 |
+
tie_word_embeddings: bool = False,
|
| 41 |
+
pad_vocab_size_multiple: int = 64,
|
| 42 |
+
**kwargs
|
| 43 |
+
) -> None:
|
| 44 |
+
self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
| 45 |
+
self.n_positions = n_positions
|
| 46 |
+
self.n_embd = n_embd
|
| 47 |
+
self.n_layer = n_layer
|
| 48 |
+
self.n_inner = n_inner
|
| 49 |
+
self.n_head = n_head
|
| 50 |
+
self.n_head_kv = n_head_kv
|
| 51 |
+
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
| 52 |
+
self.activation_function = activation_function
|
| 53 |
+
self.flash_attn = flash_attn
|
| 54 |
+
self.flash_rotary = flash_rotary
|
| 55 |
+
self.fused_dense = fused_dense
|
| 56 |
+
self.attn_pdrop = attn_pdrop
|
| 57 |
+
self.embd_pdrop = embd_pdrop
|
| 58 |
+
self.resid_pdrop = resid_pdrop
|
| 59 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
| 60 |
+
self.initializer_range = initializer_range
|
| 61 |
+
|
| 62 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
ChatUniVi/model/language_model/language_model/modeling_phi.py
ADDED
|
@@ -0,0 +1,984 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
#
|
| 4 |
+
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
|
| 5 |
+
# Licensed under the BSD 3-Clause License.
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from einops import rearrange, repeat
|
| 16 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
| 17 |
+
from transformers.activations import ACT2FN
|
| 18 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 19 |
+
|
| 20 |
+
from .configuration_phi import PhiConfig
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 24 |
+
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
| 25 |
+
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
|
| 26 |
+
# from flash_attn.ops.fused_dense import FusedDense
|
| 27 |
+
except:
|
| 28 |
+
pad_input, unpad_input = None, None
|
| 29 |
+
FlashRotaryEmbedding = None
|
| 30 |
+
FlashSelfAttention, FlashCrossAttention = None, None
|
| 31 |
+
FusedDense = None
|
| 32 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 33 |
+
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
| 34 |
+
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class InferenceParams:
|
| 38 |
+
"""Inference parameters passed to model to efficiently calculate
|
| 39 |
+
and store context during inference.
|
| 40 |
+
|
| 41 |
+
Reference:
|
| 42 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
max_seqlen: Maximum sequence length.
|
| 46 |
+
max_batch_size: Maximum batch size.
|
| 47 |
+
seqlen_offset: Sequence length offset.
|
| 48 |
+
batch_size_offset: Batch size offset.
|
| 49 |
+
key_value_memory_dict: Key value memory dictionary.
|
| 50 |
+
lengths_per_sample: Lengths per sample.
|
| 51 |
+
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
|
| 55 |
+
|
| 56 |
+
max_batch_size: int = field(metadata={"help": "Maximum batch size."})
|
| 57 |
+
|
| 58 |
+
seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
|
| 59 |
+
|
| 60 |
+
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
|
| 61 |
+
|
| 62 |
+
key_value_memory_dict: Dict[str, Any] = field(
|
| 63 |
+
default_factory=dict, metadata={"help": "Key value memory dictionary."}
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class Embedding(nn.Module):
|
| 70 |
+
"""Token embedding with dropout."""
|
| 71 |
+
|
| 72 |
+
def __init__(self, config: PretrainedConfig) -> None:
|
| 73 |
+
super().__init__()
|
| 74 |
+
|
| 75 |
+
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
| 76 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
| 77 |
+
|
| 78 |
+
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
|
| 79 |
+
input_shape = input_ids.size()
|
| 80 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 81 |
+
|
| 82 |
+
hidden_states = self.wte(input_ids)
|
| 83 |
+
hidden_states = self.drop(hidden_states)
|
| 84 |
+
|
| 85 |
+
return hidden_states
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _apply_rotary_emb(
|
| 89 |
+
x: torch.FloatTensor,
|
| 90 |
+
cos: torch.FloatTensor,
|
| 91 |
+
sin: torch.FloatTensor,
|
| 92 |
+
) -> torch.FloatTensor:
|
| 93 |
+
_, seqlen, _, _ = x.shape
|
| 94 |
+
_, rotary_dim = cos.shape
|
| 95 |
+
rotary_dim *= 2
|
| 96 |
+
|
| 97 |
+
x_rot = x[:, :, :, :rotary_dim]
|
| 98 |
+
x_pass = x[:, :, :, rotary_dim:]
|
| 99 |
+
|
| 100 |
+
x1, x2 = x_rot.chunk(2, dim=-1)
|
| 101 |
+
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
| 102 |
+
x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
|
| 103 |
+
|
| 104 |
+
x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
|
| 105 |
+
|
| 106 |
+
return torch.cat([x_rot, x_pass], axis=-1)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _apply_rotary_emb_kv(
|
| 110 |
+
kv: torch.FloatTensor,
|
| 111 |
+
cos: torch.FloatTensor,
|
| 112 |
+
sin: torch.FloatTensor,
|
| 113 |
+
cos_k: Optional[torch.FloatTensor] = None,
|
| 114 |
+
sin_k: Optional[torch.FloatTensor] = None,
|
| 115 |
+
) -> torch.FloatTensor:
|
| 116 |
+
_, seqlen, _, _, _ = kv.shape
|
| 117 |
+
_, rotary_dim = cos.shape
|
| 118 |
+
rotary_dim *= 2
|
| 119 |
+
|
| 120 |
+
k_rot = kv[:, :, 0, :, :rotary_dim]
|
| 121 |
+
k_pass = kv[:, :, 0, :, rotary_dim:]
|
| 122 |
+
|
| 123 |
+
k1, k2 = k_rot.chunk(2, dim=-1)
|
| 124 |
+
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
| 125 |
+
k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
|
| 126 |
+
|
| 127 |
+
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
|
| 128 |
+
|
| 129 |
+
return torch.cat(
|
| 130 |
+
[
|
| 131 |
+
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
| 132 |
+
kv[:, :, 1:2, :, :],
|
| 133 |
+
],
|
| 134 |
+
axis=2,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _apply_rotary_emb_qkv(
|
| 139 |
+
qkv: torch.FloatTensor,
|
| 140 |
+
cos: torch.FloatTensor,
|
| 141 |
+
sin: torch.FloatTensor,
|
| 142 |
+
cos_k: Optional[torch.FloatTensor] = None,
|
| 143 |
+
sin_k: Optional[torch.FloatTensor] = None,
|
| 144 |
+
) -> torch.FloatTensor:
|
| 145 |
+
_, seqlen, _, _, _ = qkv.shape
|
| 146 |
+
_, rotary_dim = cos.shape
|
| 147 |
+
rotary_dim *= 2
|
| 148 |
+
|
| 149 |
+
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
| 150 |
+
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
| 151 |
+
|
| 152 |
+
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
| 153 |
+
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
| 154 |
+
|
| 155 |
+
q1, q2 = q_rot.chunk(2, dim=-1)
|
| 156 |
+
k1, k2 = k_rot.chunk(2, dim=-1)
|
| 157 |
+
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
| 158 |
+
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
|
| 159 |
+
|
| 160 |
+
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
| 161 |
+
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
| 162 |
+
|
| 163 |
+
return torch.cat(
|
| 164 |
+
[
|
| 165 |
+
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
| 166 |
+
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
| 167 |
+
qkv[:, :, 2:3, :, :],
|
| 168 |
+
],
|
| 169 |
+
axis=2,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class RotaryEmbedding(nn.Module):
|
| 174 |
+
"""Rotary positional embedding (RoPE).
|
| 175 |
+
|
| 176 |
+
Reference:
|
| 177 |
+
RoFormer: Enhanced Transformer with Rotary Position Embedding.
|
| 178 |
+
https://arxiv.org/pdf/2104.09864.pdf.
|
| 179 |
+
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(
|
| 183 |
+
self,
|
| 184 |
+
dim: int,
|
| 185 |
+
base: int = 10000,
|
| 186 |
+
scale_base: Optional[float] = None,
|
| 187 |
+
pos_idx_in_fp32: bool = True,
|
| 188 |
+
max_position_embeddings: int = 2048,
|
| 189 |
+
device: Optional[str] = None,
|
| 190 |
+
**kwargs,
|
| 191 |
+
) -> None:
|
| 192 |
+
super().__init__()
|
| 193 |
+
|
| 194 |
+
if scale_base is not None:
|
| 195 |
+
raise NotImplementedError
|
| 196 |
+
|
| 197 |
+
self.dim = dim
|
| 198 |
+
self.base = float(base)
|
| 199 |
+
self.scale_base = scale_base
|
| 200 |
+
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
| 201 |
+
self.max_position_embeddings = max_position_embeddings
|
| 202 |
+
self.device = device
|
| 203 |
+
|
| 204 |
+
# Generate and save the inverse frequency buffer (non-trainable)
|
| 205 |
+
inv_freq = self._compute_inv_freq(device)
|
| 206 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 207 |
+
|
| 208 |
+
# Generate and save the scale buffer (non-trainable)
|
| 209 |
+
scale = (
|
| 210 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
| 211 |
+
if scale_base is not None
|
| 212 |
+
else None
|
| 213 |
+
)
|
| 214 |
+
self.register_buffer("scale", scale, persistent=False)
|
| 215 |
+
|
| 216 |
+
# Initialize cached attributes since ONNX can't rely on dynamic initialization
|
| 217 |
+
self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
|
| 218 |
+
|
| 219 |
+
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
|
| 220 |
+
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
| 221 |
+
|
| 222 |
+
def _update_cos_sin_cache(
|
| 223 |
+
self,
|
| 224 |
+
seqlen: int,
|
| 225 |
+
device: Optional[str] = None,
|
| 226 |
+
dtype: Optional[torch.dtype] = None,
|
| 227 |
+
) -> None:
|
| 228 |
+
self._seq_len_cached = seqlen
|
| 229 |
+
|
| 230 |
+
# fp32 is preferred since the output of `torch.arange` can be quite large
|
| 231 |
+
# and bf16 would lose a lot of precision
|
| 232 |
+
if self.pos_idx_in_fp32:
|
| 233 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
| 234 |
+
if self.inv_freq.dtype != torch.float32:
|
| 235 |
+
inv_freq = self._compute_inv_freq(device=device)
|
| 236 |
+
else:
|
| 237 |
+
inv_freq = self.inv_freq
|
| 238 |
+
else:
|
| 239 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
| 240 |
+
inv_freq = self.inv_freq
|
| 241 |
+
|
| 242 |
+
# `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
|
| 243 |
+
freqs = torch.outer(t, inv_freq)
|
| 244 |
+
if self.scale is None:
|
| 245 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
| 246 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
| 247 |
+
else:
|
| 248 |
+
power = (
|
| 249 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
| 250 |
+
) / self.scale_base
|
| 251 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
| 252 |
+
|
| 253 |
+
# Force the scale multiplication to happen in fp32
|
| 254 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
| 255 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
| 256 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
| 257 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
| 258 |
+
|
| 259 |
+
def forward(
|
| 260 |
+
self,
|
| 261 |
+
qkv: torch.Tensor,
|
| 262 |
+
kv: Optional[torch.Tensor] = None,
|
| 263 |
+
seqlen_offset: int = 0,
|
| 264 |
+
**kwargs,
|
| 265 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 266 |
+
if (
|
| 267 |
+
self._seq_len_cached < qkv.shape[1] + seqlen_offset
|
| 268 |
+
or self._cos_cached.device != qkv.device
|
| 269 |
+
or self._cos_cached.dtype != qkv.dtype
|
| 270 |
+
or (self.training and self._cos_cached.is_inference())
|
| 271 |
+
):
|
| 272 |
+
self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
| 273 |
+
|
| 274 |
+
if kv is None:
|
| 275 |
+
return _apply_rotary_emb_qkv(
|
| 276 |
+
qkv,
|
| 277 |
+
self._cos_cached[seqlen_offset:],
|
| 278 |
+
self._sin_cached[seqlen_offset:],
|
| 279 |
+
)
|
| 280 |
+
else:
|
| 281 |
+
q = _apply_rotary_emb(
|
| 282 |
+
qkv,
|
| 283 |
+
self._cos_cached[seqlen_offset:],
|
| 284 |
+
self._sin_cached[seqlen_offset:],
|
| 285 |
+
)
|
| 286 |
+
kv = _apply_rotary_emb_kv(
|
| 287 |
+
kv,
|
| 288 |
+
self._cos_cached[seqlen_offset:],
|
| 289 |
+
self._sin_cached[seqlen_offset:],
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
return q, kv
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class MLP(nn.Module):
|
| 296 |
+
"""Multi-Layer Perceptron.
|
| 297 |
+
|
| 298 |
+
Reference:
|
| 299 |
+
Attention Is All You Need.
|
| 300 |
+
https://arxiv.org/pdf/1706.03762.pdf.
|
| 301 |
+
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
def __init__(
|
| 305 |
+
self,
|
| 306 |
+
config: PretrainedConfig,
|
| 307 |
+
n_inner: Optional[int] = None,
|
| 308 |
+
act_fn: Optional[str] = None,
|
| 309 |
+
) -> None:
|
| 310 |
+
super().__init__()
|
| 311 |
+
|
| 312 |
+
act_fn = config.activation_function if act_fn is None else act_fn
|
| 313 |
+
|
| 314 |
+
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
| 315 |
+
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
| 316 |
+
|
| 317 |
+
self.fc1 = nn.Linear(config.n_embd, n_inner)
|
| 318 |
+
self.fc2 = nn.Linear(n_inner, config.n_embd)
|
| 319 |
+
self.act = ACT2FN[act_fn]
|
| 320 |
+
|
| 321 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 322 |
+
hidden_states = self.fc1(hidden_states)
|
| 323 |
+
hidden_states = self.act(hidden_states)
|
| 324 |
+
hidden_states = self.fc2(hidden_states)
|
| 325 |
+
|
| 326 |
+
return hidden_states
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class SelfAttention(nn.Module):
|
| 330 |
+
"""Self-attention layer (compatible with PyTorch).
|
| 331 |
+
|
| 332 |
+
Reference:
|
| 333 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
| 334 |
+
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
def __init__(
|
| 338 |
+
self,
|
| 339 |
+
causal: bool = True,
|
| 340 |
+
softmax_scale: Optional[float] = None,
|
| 341 |
+
attention_dropout: float = 0.0,
|
| 342 |
+
) -> None:
|
| 343 |
+
super().__init__()
|
| 344 |
+
|
| 345 |
+
self.causal = causal
|
| 346 |
+
self.softmax_scale = softmax_scale
|
| 347 |
+
self.drop = nn.Dropout(attention_dropout)
|
| 348 |
+
|
| 349 |
+
@torch.autocast("cpu", enabled=False)
|
| 350 |
+
@torch.autocast("cuda", enabled=False)
|
| 351 |
+
def forward(
|
| 352 |
+
self,
|
| 353 |
+
qkv: torch.FloatTensor,
|
| 354 |
+
causal: bool = None,
|
| 355 |
+
key_padding_mask: Optional[torch.BoolTensor] = None,
|
| 356 |
+
**kwargs,
|
| 357 |
+
) -> torch.FloatTensor:
|
| 358 |
+
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
| 359 |
+
q, k, v = qkv.unbind(dim=2)
|
| 360 |
+
|
| 361 |
+
q = q.to(torch.float32)
|
| 362 |
+
k = k.to(torch.float32)
|
| 363 |
+
|
| 364 |
+
causal = self.causal if causal is None else causal
|
| 365 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
| 366 |
+
|
| 367 |
+
# Autocast is manually disabled to avoid `torch.einsum` performing the operation
|
| 368 |
+
# using float16, which might lead to overflow
|
| 369 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
| 370 |
+
|
| 371 |
+
if key_padding_mask is not None:
|
| 372 |
+
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
|
| 373 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
| 374 |
+
|
| 375 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
| 376 |
+
|
| 377 |
+
if causal:
|
| 378 |
+
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
| 379 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
| 380 |
+
|
| 381 |
+
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
| 382 |
+
attention = self.drop(attention)
|
| 383 |
+
|
| 384 |
+
output = torch.einsum("bhts,bshd->bthd", attention, v)
|
| 385 |
+
|
| 386 |
+
return output
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class CrossAttention(nn.Module):
|
| 390 |
+
"""Cross-attention layer (compatible with PyTorch).
|
| 391 |
+
|
| 392 |
+
Reference:
|
| 393 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
| 394 |
+
|
| 395 |
+
"""
|
| 396 |
+
|
| 397 |
+
def __init__(
|
| 398 |
+
self,
|
| 399 |
+
causal: bool = True,
|
| 400 |
+
softmax_scale: Optional[float] = None,
|
| 401 |
+
attention_dropout: float = 0.0,
|
| 402 |
+
) -> None:
|
| 403 |
+
super().__init__()
|
| 404 |
+
|
| 405 |
+
self.causal = causal
|
| 406 |
+
self.softmax_scale = softmax_scale
|
| 407 |
+
self.drop = nn.Dropout(attention_dropout)
|
| 408 |
+
|
| 409 |
+
@torch.autocast("cpu", enabled=False)
|
| 410 |
+
@torch.autocast("cuda", enabled=False)
|
| 411 |
+
def forward(
|
| 412 |
+
self,
|
| 413 |
+
q: torch.FloatTensor,
|
| 414 |
+
kv: torch.FloatTensor,
|
| 415 |
+
causal: bool = None,
|
| 416 |
+
key_padding_mask: Optional[torch.BoolTensor] = None,
|
| 417 |
+
**kwargs,
|
| 418 |
+
) -> torch.FloatTensor:
|
| 419 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
| 420 |
+
seqlen_k = kv.shape[1]
|
| 421 |
+
|
| 422 |
+
if kv.shape[3] != q.shape[2]:
|
| 423 |
+
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
|
| 424 |
+
k, v = kv.unbind(dim=2)
|
| 425 |
+
|
| 426 |
+
q = q.to(torch.float32)
|
| 427 |
+
k = k.to(torch.float32)
|
| 428 |
+
|
| 429 |
+
causal = self.causal if causal is None else causal
|
| 430 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
| 431 |
+
|
| 432 |
+
# Autocast is manually disabled to avoid `torch.einsum` performing the operation
|
| 433 |
+
# using float16, which might lead to overflow
|
| 434 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
| 435 |
+
|
| 436 |
+
if key_padding_mask is not None:
|
| 437 |
+
padding_mask = torch.full(
|
| 438 |
+
(batch_size, seqlen_k),
|
| 439 |
+
-10000.0,
|
| 440 |
+
dtype=scores.dtype,
|
| 441 |
+
device=scores.device,
|
| 442 |
+
)
|
| 443 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
| 444 |
+
|
| 445 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
| 446 |
+
|
| 447 |
+
if causal:
|
| 448 |
+
rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
|
| 449 |
+
cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
|
| 450 |
+
causal_mask = cols > rows + seqlen_k - seqlen_q
|
| 451 |
+
|
| 452 |
+
scores = scores.masked_fill(causal_mask, -10000.0)
|
| 453 |
+
|
| 454 |
+
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
| 455 |
+
attention = self.drop(attention)
|
| 456 |
+
|
| 457 |
+
output = torch.einsum("bhts,bshd->bthd", attention, v)
|
| 458 |
+
|
| 459 |
+
return output
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def _find_mha_dims(
|
| 463 |
+
config: PretrainedConfig,
|
| 464 |
+
n_head: Optional[int] = None,
|
| 465 |
+
n_head_kv: Optional[int] = None,
|
| 466 |
+
head_dim: Optional[int] = None,
|
| 467 |
+
) -> Tuple[int, int]:
|
| 468 |
+
if n_head is None and head_dim is None:
|
| 469 |
+
head_dim = config.n_embd // config.n_head
|
| 470 |
+
n_head = config.n_head
|
| 471 |
+
elif n_head is None or head_dim is None:
|
| 472 |
+
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
| 473 |
+
|
| 474 |
+
if n_head_kv is None:
|
| 475 |
+
n_head_kv = getattr(config, "n_head_kv", None) or n_head
|
| 476 |
+
|
| 477 |
+
return n_head, n_head_kv, head_dim
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
|
| 481 |
+
num_heads, head_dim = kv.shape[-2:]
|
| 482 |
+
|
| 483 |
+
if layer_idx not in inference_params.key_value_memory_dict:
|
| 484 |
+
inference_params.key_value_memory_dict[layer_idx] = torch.empty(
|
| 485 |
+
inference_params.max_batch_size,
|
| 486 |
+
inference_params.max_seqlen,
|
| 487 |
+
2,
|
| 488 |
+
num_heads,
|
| 489 |
+
head_dim,
|
| 490 |
+
dtype=kv.dtype,
|
| 491 |
+
device=kv.device,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
batch_start = inference_params.batch_size_offset
|
| 495 |
+
batch_end = batch_start + kv.shape[0]
|
| 496 |
+
|
| 497 |
+
sequence_start = inference_params.seqlen_offset
|
| 498 |
+
sequence_end = sequence_start + kv.shape[1]
|
| 499 |
+
|
| 500 |
+
# When the current sequence length is equal to or larger than the maximum sequence length,
|
| 501 |
+
# we need to concatenate the current `kv` with the cached `kv` to expand its length
|
| 502 |
+
if sequence_end >= inference_params.max_seqlen:
|
| 503 |
+
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
|
| 504 |
+
|
| 505 |
+
inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
| 506 |
+
kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
|
| 507 |
+
|
| 508 |
+
return kv
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
class MHA(nn.Module):
|
| 512 |
+
"""Multi-head attention layer."""
|
| 513 |
+
|
| 514 |
+
def __init__(
|
| 515 |
+
self,
|
| 516 |
+
config: PretrainedConfig,
|
| 517 |
+
dtype: Optional[torch.dtype] = None,
|
| 518 |
+
device: Optional[str] = None,
|
| 519 |
+
rotary_dim: Optional[int] = None,
|
| 520 |
+
rotary_base: float = 10000.0,
|
| 521 |
+
rotary_scale_base: Optional[float] = None,
|
| 522 |
+
n_head: Optional[int] = None,
|
| 523 |
+
n_head_kv: Optional[int] = None,
|
| 524 |
+
head_dim: Optional[int] = None,
|
| 525 |
+
bias: bool = True,
|
| 526 |
+
causal: bool = True,
|
| 527 |
+
softmax_scale: Optional[float] = None,
|
| 528 |
+
layer_idx: Optional[int] = None,
|
| 529 |
+
return_residual: bool = False,
|
| 530 |
+
checkpointing: bool = False,
|
| 531 |
+
) -> None:
|
| 532 |
+
super().__init__()
|
| 533 |
+
|
| 534 |
+
# Rotary embedding
|
| 535 |
+
self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
|
| 536 |
+
if self.rotary_dim > 0:
|
| 537 |
+
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
|
| 538 |
+
if rotary_cls is None:
|
| 539 |
+
rotary_cls = RotaryEmbedding
|
| 540 |
+
|
| 541 |
+
rotary_kwargs = {}
|
| 542 |
+
if rotary_cls is RotaryEmbedding:
|
| 543 |
+
rotary_kwargs["max_position_embeddings"] = config.n_positions
|
| 544 |
+
|
| 545 |
+
self.rotary_emb = rotary_cls(
|
| 546 |
+
self.rotary_dim,
|
| 547 |
+
base=rotary_base,
|
| 548 |
+
scale_base=rotary_scale_base,
|
| 549 |
+
device=device,
|
| 550 |
+
**rotary_kwargs,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
# MLP
|
| 554 |
+
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
|
| 555 |
+
config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
|
| 556 |
+
)
|
| 557 |
+
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
| 558 |
+
hidden_size = config.n_embd
|
| 559 |
+
|
| 560 |
+
linear_cls = FusedDense if config.fused_dense else nn.Linear
|
| 561 |
+
if linear_cls is None:
|
| 562 |
+
linear_cls = nn.Linear
|
| 563 |
+
|
| 564 |
+
self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
|
| 565 |
+
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
| 566 |
+
|
| 567 |
+
# Attention
|
| 568 |
+
# attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
|
| 569 |
+
attn_cls = FlashSelfAttention
|
| 570 |
+
if attn_cls is None:
|
| 571 |
+
attn_cls = SelfAttention
|
| 572 |
+
|
| 573 |
+
# cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
|
| 574 |
+
cross_attn_cls = FlashCrossAttention
|
| 575 |
+
if cross_attn_cls is None:
|
| 576 |
+
cross_attn_cls = CrossAttention
|
| 577 |
+
|
| 578 |
+
self.inner_attn = attn_cls(
|
| 579 |
+
causal=causal,
|
| 580 |
+
softmax_scale=softmax_scale,
|
| 581 |
+
attention_dropout=config.attn_pdrop,
|
| 582 |
+
)
|
| 583 |
+
self.inner_cross_attn = cross_attn_cls(
|
| 584 |
+
causal=causal,
|
| 585 |
+
softmax_scale=softmax_scale,
|
| 586 |
+
attention_dropout=config.attn_pdrop,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
# self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
|
| 590 |
+
self.flash_attn = True
|
| 591 |
+
self.layer_idx = layer_idx
|
| 592 |
+
self.return_residual = return_residual
|
| 593 |
+
self.checkpointing = checkpointing
|
| 594 |
+
|
| 595 |
+
def _forward_self_attn(
|
| 596 |
+
self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
|
| 597 |
+
) -> torch.FloatTensor:
|
| 598 |
+
qkv = self.Wqkv(x)
|
| 599 |
+
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
| 600 |
+
|
| 601 |
+
if self.rotary_dim > 0:
|
| 602 |
+
qkv = self.rotary_emb(qkv)
|
| 603 |
+
|
| 604 |
+
if self.flash_attn:
|
| 605 |
+
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
| 606 |
+
|
| 607 |
+
cu_seqlens, max_seqlen = None, None
|
| 608 |
+
if key_padding_mask is not None:
|
| 609 |
+
# If `key_padding_mask` is supplied, we need to unpad the input and retrieve
|
| 610 |
+
# the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
|
| 611 |
+
qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
|
| 612 |
+
|
| 613 |
+
if self.checkpointing:
|
| 614 |
+
attn_output = torch.utils.checkpoint.checkpoint(
|
| 615 |
+
self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
|
| 616 |
+
)
|
| 617 |
+
else:
|
| 618 |
+
attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
|
| 619 |
+
|
| 620 |
+
# If `key_padding_mask` is supplied, we need to pad the output back to the original shape
|
| 621 |
+
return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
|
| 622 |
+
|
| 623 |
+
if self.checkpointing:
|
| 624 |
+
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
|
| 625 |
+
|
| 626 |
+
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
|
| 627 |
+
|
| 628 |
+
def _forward_cross_attn(
|
| 629 |
+
self,
|
| 630 |
+
x: torch.FloatTensor,
|
| 631 |
+
past_key_values: Optional[InferenceParams],
|
| 632 |
+
key_padding_mask: Optional[torch.BoolTensor],
|
| 633 |
+
) -> torch.FloatTensor:
|
| 634 |
+
batch_size = x.shape[0]
|
| 635 |
+
|
| 636 |
+
qkv = self.Wqkv(x)
|
| 637 |
+
|
| 638 |
+
q = qkv[..., : self.n_head * self.head_dim]
|
| 639 |
+
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
| 640 |
+
|
| 641 |
+
kv = qkv[..., self.n_head * self.head_dim :]
|
| 642 |
+
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
| 643 |
+
|
| 644 |
+
seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
|
| 645 |
+
causal = None if seqlen_offset == 0 else False
|
| 646 |
+
if self.rotary_dim > 0:
|
| 647 |
+
q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
|
| 648 |
+
|
| 649 |
+
if past_key_values is not None:
|
| 650 |
+
kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
|
| 651 |
+
|
| 652 |
+
if self.flash_attn:
|
| 653 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
| 654 |
+
seqlen_k = kv.shape[1]
|
| 655 |
+
|
| 656 |
+
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
|
| 657 |
+
None,
|
| 658 |
+
None,
|
| 659 |
+
None,
|
| 660 |
+
None,
|
| 661 |
+
)
|
| 662 |
+
if key_padding_mask is not None:
|
| 663 |
+
kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
|
| 664 |
+
|
| 665 |
+
if seqlen_q == 1:
|
| 666 |
+
key_padding_mask = torch.ones(batch_size, 1, device=q.device)
|
| 667 |
+
elif seqlen_q != seqlen_k:
|
| 668 |
+
key_padding_mask = key_padding_mask[:, -seqlen_q:]
|
| 669 |
+
|
| 670 |
+
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
|
| 671 |
+
|
| 672 |
+
if self.checkpointing:
|
| 673 |
+
attn_output = torch.utils.checkpoint.checkpoint(
|
| 674 |
+
self.inner_cross_attn,
|
| 675 |
+
q,
|
| 676 |
+
kv,
|
| 677 |
+
causal=causal,
|
| 678 |
+
cu_seqlens=cu_seqlens_q,
|
| 679 |
+
max_seqlen=max_seqlen_q,
|
| 680 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 681 |
+
max_seqlen_k=max_seqlen_k,
|
| 682 |
+
)
|
| 683 |
+
else:
|
| 684 |
+
attn_output = self.inner_cross_attn(
|
| 685 |
+
q,
|
| 686 |
+
kv,
|
| 687 |
+
causal=causal,
|
| 688 |
+
cu_seqlens=cu_seqlens_q,
|
| 689 |
+
max_seqlen=max_seqlen_q,
|
| 690 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 691 |
+
max_seqlen_k=max_seqlen_k,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
return (
|
| 695 |
+
pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
|
| 696 |
+
if key_padding_mask is not None
|
| 697 |
+
else attn_output
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
if self.checkpointing:
|
| 701 |
+
return torch.utils.checkpoint.checkpoint(
|
| 702 |
+
self.inner_cross_attn,
|
| 703 |
+
q,
|
| 704 |
+
kv,
|
| 705 |
+
key_padding_mask=key_padding_mask,
|
| 706 |
+
causal=causal,
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
|
| 710 |
+
|
| 711 |
+
def forward(
|
| 712 |
+
self,
|
| 713 |
+
x: torch.FloatTensor,
|
| 714 |
+
past_key_values: Optional[InferenceParams] = None,
|
| 715 |
+
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
| 716 |
+
**kwargs,
|
| 717 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
| 718 |
+
if attention_mask is not None:
|
| 719 |
+
attention_mask = attention_mask.bool()
|
| 720 |
+
else:
|
| 721 |
+
attention_mask = None
|
| 722 |
+
|
| 723 |
+
# MHA
|
| 724 |
+
if self.n_head == self.n_head_kv:
|
| 725 |
+
if past_key_values is None:
|
| 726 |
+
# If `past_key_values` are not supplied, we run self-attention
|
| 727 |
+
attn_output = self._forward_self_attn(x, attention_mask)
|
| 728 |
+
else:
|
| 729 |
+
# If `past_key_values` are supplied, it means that we might have cached values and
|
| 730 |
+
# could take advantage of cross-attention
|
| 731 |
+
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
|
| 732 |
+
# MQA / GQA
|
| 733 |
+
else:
|
| 734 |
+
# Regardless of `past_key_values` being supplied or not, it always use cross-attention
|
| 735 |
+
# because `q` and `kv` lengths might be different
|
| 736 |
+
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
|
| 737 |
+
|
| 738 |
+
output = rearrange(attn_output, "... h d -> ... (h d)")
|
| 739 |
+
output = self.out_proj(output)
|
| 740 |
+
|
| 741 |
+
return output if not self.return_residual else (output, x)
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
class ParallelBlock(nn.Module):
|
| 745 |
+
"""Parallel block.
|
| 746 |
+
|
| 747 |
+
This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
|
| 748 |
+
|
| 749 |
+
"""
|
| 750 |
+
|
| 751 |
+
def __init__(
|
| 752 |
+
self,
|
| 753 |
+
config: PretrainedConfig,
|
| 754 |
+
block_idx: Optional[int] = None,
|
| 755 |
+
) -> None:
|
| 756 |
+
super().__init__()
|
| 757 |
+
|
| 758 |
+
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
| 759 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 760 |
+
self.block_idx = block_idx
|
| 761 |
+
|
| 762 |
+
self.mixer = MHA(config, layer_idx=block_idx)
|
| 763 |
+
self.mlp = MLP(config)
|
| 764 |
+
|
| 765 |
+
def forward(
|
| 766 |
+
self,
|
| 767 |
+
hidden_states: torch.FloatTensor,
|
| 768 |
+
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 769 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
| 770 |
+
**kwargs,
|
| 771 |
+
) -> torch.FloatTensor:
|
| 772 |
+
residual = hidden_states
|
| 773 |
+
hidden_states = self.ln(hidden_states)
|
| 774 |
+
|
| 775 |
+
attn_outputs = self.mixer(
|
| 776 |
+
hidden_states,
|
| 777 |
+
past_key_values=past_key_values,
|
| 778 |
+
attention_mask=attention_mask,
|
| 779 |
+
)
|
| 780 |
+
if isinstance(attn_outputs, tuple):
|
| 781 |
+
attn_outputs = attn_outputs[0]
|
| 782 |
+
|
| 783 |
+
attn_outputs = self.resid_dropout(attn_outputs)
|
| 784 |
+
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
| 785 |
+
|
| 786 |
+
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
| 787 |
+
|
| 788 |
+
return hidden_states
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
class CausalLMHead(nn.Module):
|
| 792 |
+
"""Causal Language Modeling head.
|
| 793 |
+
|
| 794 |
+
Reference:
|
| 795 |
+
Improving Language Understanding by Generative Pre-Training.
|
| 796 |
+
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
| 797 |
+
|
| 798 |
+
"""
|
| 799 |
+
|
| 800 |
+
def __init__(self, config: PretrainedConfig) -> None:
|
| 801 |
+
super().__init__()
|
| 802 |
+
|
| 803 |
+
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
| 804 |
+
self.linear = nn.Linear(config.n_embd, config.vocab_size)
|
| 805 |
+
|
| 806 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 807 |
+
hidden_states = self.ln(hidden_states)
|
| 808 |
+
logits = self.linear(hidden_states).to(torch.float32)
|
| 809 |
+
|
| 810 |
+
return logits
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
class CausalLMLoss(nn.Module):
|
| 814 |
+
"""Causal Language Modeling loss.
|
| 815 |
+
|
| 816 |
+
Reference:
|
| 817 |
+
Improving Language Understanding by Generative Pre-Training.
|
| 818 |
+
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
| 819 |
+
|
| 820 |
+
"""
|
| 821 |
+
|
| 822 |
+
def __init__(self, shift_labels: bool = True) -> None:
|
| 823 |
+
super().__init__()
|
| 824 |
+
|
| 825 |
+
self.shift_labels = shift_labels
|
| 826 |
+
self.loss_fct = nn.CrossEntropyLoss()
|
| 827 |
+
|
| 828 |
+
def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
|
| 829 |
+
if self.shift_labels:
|
| 830 |
+
logits = logits[..., :-1, :].contiguous()
|
| 831 |
+
labels = labels[..., 1:].contiguous()
|
| 832 |
+
|
| 833 |
+
loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
|
| 834 |
+
|
| 835 |
+
return loss
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
class PhiPreTrainedModel(PreTrainedModel):
|
| 839 |
+
"""Phi pre-trained model."""
|
| 840 |
+
|
| 841 |
+
config_class = PhiConfig
|
| 842 |
+
base_model_prefix = "transformer"
|
| 843 |
+
supports_gradient_checkpointing = False
|
| 844 |
+
_no_split_modules = ["ParallelBlock"]
|
| 845 |
+
|
| 846 |
+
def __init__(self, *inputs, **kwargs) -> None:
|
| 847 |
+
super().__init__(*inputs, **kwargs)
|
| 848 |
+
|
| 849 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 850 |
+
if isinstance(module, (nn.Linear,)):
|
| 851 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 852 |
+
if module.bias is not None:
|
| 853 |
+
module.bias.data.zero_()
|
| 854 |
+
elif isinstance(module, nn.Embedding):
|
| 855 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 856 |
+
if module.padding_idx is not None:
|
| 857 |
+
module.weight.data[module.padding_idx].zero_()
|
| 858 |
+
elif isinstance(module, nn.LayerNorm):
|
| 859 |
+
if module.bias is not None:
|
| 860 |
+
module.bias.data.zero_()
|
| 861 |
+
module.weight.data.fill_(1.0)
|
| 862 |
+
|
| 863 |
+
def prepare_inputs_for_generation(
|
| 864 |
+
self,
|
| 865 |
+
input_ids: torch.LongTensor,
|
| 866 |
+
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 867 |
+
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
| 868 |
+
**kwargs,
|
| 869 |
+
) -> Dict[str, Any]:
|
| 870 |
+
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
| 871 |
+
past_key_values = InferenceParams(
|
| 872 |
+
max_seqlen=self.config.n_positions,
|
| 873 |
+
max_batch_size=input_ids.shape[0],
|
| 874 |
+
seqlen_offset=0,
|
| 875 |
+
batch_size_offset=0,
|
| 876 |
+
key_value_memory_dict={},
|
| 877 |
+
lengths_per_sample=None,
|
| 878 |
+
)
|
| 879 |
+
else:
|
| 880 |
+
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
| 881 |
+
past_key_values.seqlen_offset = input_ids.shape[1] - 1
|
| 882 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 883 |
+
|
| 884 |
+
return {
|
| 885 |
+
"input_ids": input_ids,
|
| 886 |
+
"past_key_values": past_key_values,
|
| 887 |
+
"attention_mask": attention_mask,
|
| 888 |
+
}
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
class PhiModel(PhiPreTrainedModel):
|
| 892 |
+
"""Phi model."""
|
| 893 |
+
|
| 894 |
+
_keys_to_ignore_on_load_missing = [""]
|
| 895 |
+
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
|
| 896 |
+
|
| 897 |
+
def __init__(self, config: PhiConfig) -> None:
|
| 898 |
+
super().__init__(config)
|
| 899 |
+
|
| 900 |
+
self.embd = Embedding(config)
|
| 901 |
+
self.embed_tokens = self.embd
|
| 902 |
+
self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
|
| 903 |
+
self.gradient_checkpointing = False
|
| 904 |
+
self.post_init()
|
| 905 |
+
|
| 906 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
| 907 |
+
return self.embd.wte
|
| 908 |
+
|
| 909 |
+
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
| 910 |
+
self.embd.wte = new_embeddings
|
| 911 |
+
|
| 912 |
+
def forward(
|
| 913 |
+
self,
|
| 914 |
+
input_ids: torch.LongTensor,
|
| 915 |
+
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 916 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
| 917 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 918 |
+
) -> torch.FloatTensor:
|
| 919 |
+
if inputs_embeds is None:
|
| 920 |
+
hidden_states = self.embd(input_ids)
|
| 921 |
+
else:
|
| 922 |
+
hidden_states = inputs_embeds
|
| 923 |
+
|
| 924 |
+
for layer in self.h:
|
| 925 |
+
hidden_states = layer(
|
| 926 |
+
hidden_states,
|
| 927 |
+
past_key_values=past_key_values,
|
| 928 |
+
attention_mask=attention_mask,
|
| 929 |
+
)
|
| 930 |
+
|
| 931 |
+
return hidden_states
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
class PhiForCausalLM(PhiPreTrainedModel):
|
| 935 |
+
"""Phi for Causal Language Modeling."""
|
| 936 |
+
|
| 937 |
+
_keys_to_ignore_on_load_missing = [""]
|
| 938 |
+
_keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
|
| 939 |
+
|
| 940 |
+
def __init__(self, config: PhiConfig) -> None:
|
| 941 |
+
super().__init__(config)
|
| 942 |
+
|
| 943 |
+
self.transformer = PhiModel(config)
|
| 944 |
+
self.lm_head = CausalLMHead(config)
|
| 945 |
+
self.loss = CausalLMLoss()
|
| 946 |
+
|
| 947 |
+
self.post_init()
|
| 948 |
+
|
| 949 |
+
def set_input_embeddings(self, value):
|
| 950 |
+
self.transformer.embd = value
|
| 951 |
+
|
| 952 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
|
| 953 |
+
def set_decoder(self, decoder):
|
| 954 |
+
self.transformer = decoder
|
| 955 |
+
|
| 956 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
|
| 957 |
+
def get_decoder(self):
|
| 958 |
+
return self.transformer
|
| 959 |
+
|
| 960 |
+
def get_input_embeddings(self):
|
| 961 |
+
return self.transformer.embd
|
| 962 |
+
|
| 963 |
+
def get_output_embeddings(self) -> nn.Linear:
|
| 964 |
+
return self.lm_head.linear
|
| 965 |
+
|
| 966 |
+
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
|
| 967 |
+
self.lm_head.linear = new_embeddings
|
| 968 |
+
|
| 969 |
+
def forward(
|
| 970 |
+
self,
|
| 971 |
+
input_ids: torch.LongTensor,
|
| 972 |
+
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
| 973 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
| 974 |
+
labels: Optional[torch.LongTensor] = None,
|
| 975 |
+
**kwargs,
|
| 976 |
+
) -> CausalLMOutputWithPast:
|
| 977 |
+
hidden_states = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask)
|
| 978 |
+
lm_logits = self.lm_head(hidden_states)
|
| 979 |
+
|
| 980 |
+
loss = None
|
| 981 |
+
if labels is not None:
|
| 982 |
+
loss = self.loss(lm_logits, labels)
|
| 983 |
+
|
| 984 |
+
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
|
ChatUniVi/model/language_model/llama.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch.nn import CrossEntropyLoss
|
| 5 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
| 6 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM
|
| 7 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 8 |
+
from models.tf.modeling_outputs import CausalLMOutputWithPastAndLabel
|
| 9 |
+
|
| 10 |
+
from ChatUniVi.model.arch import MetaModel, ChatUniViMetaForCausalLM
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ChatUniViConfig(LlamaConfig):
|
| 14 |
+
model_type = "ChatUniVi"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ChatUniViLlamaModel(MetaModel, LlamaModel):
|
| 18 |
+
config_class = ChatUniViConfig
|
| 19 |
+
|
| 20 |
+
def __init__(self, config: LlamaConfig):
|
| 21 |
+
super(ChatUniViLlamaModel, self).__init__(config)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ChatUniViLlamaForCausalLM(LlamaForCausalLM, ChatUniViMetaForCausalLM):
|
| 25 |
+
config_class = ChatUniViConfig
|
| 26 |
+
|
| 27 |
+
def __init__(self, config):
|
| 28 |
+
super(LlamaForCausalLM, self).__init__(config)
|
| 29 |
+
self.model = ChatUniViLlamaModel(config)
|
| 30 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 31 |
+
# Initialize weights and apply final processing
|
| 32 |
+
self.post_init()
|
| 33 |
+
|
| 34 |
+
def get_model(self):
|
| 35 |
+
return self.model
|
| 36 |
+
|
| 37 |
+
def forward(
|
| 38 |
+
self,
|
| 39 |
+
input_ids: torch.LongTensor = None,
|
| 40 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 41 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 42 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 43 |
+
labels: Optional[torch.LongTensor] = None,
|
| 44 |
+
use_cache: Optional[bool] = None,
|
| 45 |
+
output_attentions: Optional[bool] = None,
|
| 46 |
+
output_hidden_states: Optional[bool] = None,
|
| 47 |
+
images: Optional[torch.FloatTensor] = None,
|
| 48 |
+
return_dict: Optional[bool] = None,
|
| 49 |
+
|
| 50 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 51 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 52 |
+
output_hidden_states = (
|
| 53 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 54 |
+
)
|
| 55 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 56 |
+
|
| 57 |
+
# print(use_cache, output_attentions, return_dict)
|
| 58 |
+
# return 0
|
| 59 |
+
if inputs_embeds is None:
|
| 60 |
+
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
|
| 61 |
+
# else:
|
| 62 |
+
# print("不调用 prepare_inputs_labels_for_multimodal")
|
| 63 |
+
|
| 64 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 65 |
+
|
| 66 |
+
outputs = self.model(
|
| 67 |
+
input_ids=input_ids,
|
| 68 |
+
attention_mask=attention_mask,
|
| 69 |
+
past_key_values=past_key_values,
|
| 70 |
+
inputs_embeds=inputs_embeds,
|
| 71 |
+
use_cache=use_cache,
|
| 72 |
+
output_attentions=output_attentions,
|
| 73 |
+
output_hidden_states=output_hidden_states,
|
| 74 |
+
return_dict=return_dict
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
hidden_states = outputs[0]
|
| 78 |
+
logits = self.lm_head(hidden_states)
|
| 79 |
+
|
| 80 |
+
loss = None
|
| 81 |
+
if labels is not None:
|
| 82 |
+
# Shift so that tokens < n predict n
|
| 83 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 84 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 85 |
+
# Flatten the tokens
|
| 86 |
+
loss_fct = CrossEntropyLoss()
|
| 87 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 88 |
+
shift_labels = shift_labels.view(-1)
|
| 89 |
+
# Enable model/pipeline parallelism
|
| 90 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 91 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 92 |
+
|
| 93 |
+
if not return_dict:
|
| 94 |
+
output = (logits,) + outputs[1:]
|
| 95 |
+
return (loss,) + output if loss is not None else output
|
| 96 |
+
|
| 97 |
+
# return CausalLMOutputWithPast(
|
| 98 |
+
# loss=loss,
|
| 99 |
+
# logits=logits,
|
| 100 |
+
# past_key_values=outputs.past_key_values,
|
| 101 |
+
# hidden_states=outputs.hidden_states,
|
| 102 |
+
# attentions=outputs.attentions,
|
| 103 |
+
# )
|
| 104 |
+
return CausalLMOutputWithPastAndLabel(
|
| 105 |
+
loss=loss,
|
| 106 |
+
labels = labels,
|
| 107 |
+
logits=logits,
|
| 108 |
+
past_key_values=outputs.past_key_values,
|
| 109 |
+
hidden_states=outputs.hidden_states,
|
| 110 |
+
attentions=outputs.attentions,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def prepare_inputs_for_generation(
|
| 114 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 115 |
+
):
|
| 116 |
+
if past_key_values:
|
| 117 |
+
input_ids = input_ids[:, -1:]
|
| 118 |
+
|
| 119 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 120 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 121 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 122 |
+
else:
|
| 123 |
+
model_inputs = {"input_ids": input_ids}
|
| 124 |
+
|
| 125 |
+
model_inputs.update(
|
| 126 |
+
{
|
| 127 |
+
"past_key_values": past_key_values,
|
| 128 |
+
"use_cache": kwargs.get("use_cache"),
|
| 129 |
+
"attention_mask": attention_mask,
|
| 130 |
+
"images": kwargs.get("images", None),
|
| 131 |
+
}
|
| 132 |
+
)
|
| 133 |
+
return model_inputs
|
| 134 |
+
|
| 135 |
+
AutoConfig.register("ChatUniVi", ChatUniViConfig)
|
| 136 |
+
AutoModelForCausalLM.register(ChatUniViConfig, ChatUniViLlamaForCausalLM)
|
ChatUniVi/model/language_model/phi.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Haotian Liu
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from typing import List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from torch.nn import CrossEntropyLoss
|
| 21 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 22 |
+
from .modeling_phi.modeling_phi import PhiModel, PhiForCausalLM, CausalLMHead, CausalLMLoss
|
| 23 |
+
from .modeling_phi.configuration_phi import PhiConfig
|
| 24 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 25 |
+
|
| 26 |
+
from ChatUniVi.model.arch import MetaModel, ChatUniViMetaForCausalLM
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ChatUniViConfig(PhiConfig):
|
| 30 |
+
model_type = "ChatUniViPhi2"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ChatUniViPhiModel(MetaModel, PhiModel):
|
| 34 |
+
config_class = ChatUniViConfig
|
| 35 |
+
|
| 36 |
+
def __init__(self, config: PhiConfig):
|
| 37 |
+
super(ChatUniViPhiModel, self).__init__(config)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ChatUniViPhiForCausalLM(PhiForCausalLM, ChatUniViMetaForCausalLM):
|
| 41 |
+
config_class = ChatUniViConfig
|
| 42 |
+
supports_gradient_checkpointing = True
|
| 43 |
+
|
| 44 |
+
def __init__(self, config):
|
| 45 |
+
super(PhiForCausalLM, self).__init__(config)
|
| 46 |
+
self.config = config
|
| 47 |
+
self.transformer = ChatUniViPhiModel(config)
|
| 48 |
+
self.lm_head = CausalLMHead(config)
|
| 49 |
+
self.loss = CausalLMLoss()
|
| 50 |
+
|
| 51 |
+
self.post_init()
|
| 52 |
+
|
| 53 |
+
def get_model(self):
|
| 54 |
+
return self.transformer
|
| 55 |
+
|
| 56 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 57 |
+
module.gradient_checkpointing = value
|
| 58 |
+
|
| 59 |
+
def forward(
|
| 60 |
+
self,
|
| 61 |
+
input_ids: torch.LongTensor = None,
|
| 62 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 63 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 64 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 65 |
+
labels: Optional[torch.LongTensor] = None,
|
| 66 |
+
use_cache: Optional[bool] = None,
|
| 67 |
+
output_attentions: Optional[bool] = None,
|
| 68 |
+
output_hidden_states: Optional[bool] = None,
|
| 69 |
+
images: Optional[torch.FloatTensor] = None,
|
| 70 |
+
return_dict: Optional[bool] = None,
|
| 71 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 72 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 73 |
+
output_hidden_states = (
|
| 74 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 75 |
+
)
|
| 76 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 77 |
+
|
| 78 |
+
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
|
| 79 |
+
|
| 80 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 81 |
+
|
| 82 |
+
outputs = self.transformer(
|
| 83 |
+
input_ids=input_ids,
|
| 84 |
+
attention_mask=attention_mask,
|
| 85 |
+
past_key_values=past_key_values,
|
| 86 |
+
inputs_embeds=inputs_embeds,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
hidden_states = outputs
|
| 90 |
+
logits = self.lm_head(hidden_states)
|
| 91 |
+
|
| 92 |
+
loss = None
|
| 93 |
+
if labels is not None:
|
| 94 |
+
# Shift so that tokens < n predict n
|
| 95 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 96 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 97 |
+
# Flatten the tokens
|
| 98 |
+
loss_fct = CrossEntropyLoss()
|
| 99 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 100 |
+
shift_labels = shift_labels.view(-1)
|
| 101 |
+
# Enable model/pipeline parallelism
|
| 102 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 103 |
+
try:
|
| 104 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 105 |
+
except:
|
| 106 |
+
loss = torch.nn.Parameter(torch.zeros(1), requires_grad=True)
|
| 107 |
+
|
| 108 |
+
if not return_dict:
|
| 109 |
+
output = (logits,) + outputs
|
| 110 |
+
return (loss,) + output if loss is not None else output
|
| 111 |
+
|
| 112 |
+
return CausalLMOutputWithPast(
|
| 113 |
+
loss=loss,
|
| 114 |
+
logits=logits,
|
| 115 |
+
hidden_states=outputs,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def prepare_inputs_for_generation(
|
| 119 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 120 |
+
):
|
| 121 |
+
if past_key_values:
|
| 122 |
+
input_ids = input_ids[:, -1:]
|
| 123 |
+
|
| 124 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 125 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 126 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 127 |
+
else:
|
| 128 |
+
model_inputs = {"input_ids": input_ids}
|
| 129 |
+
|
| 130 |
+
model_inputs.update(
|
| 131 |
+
{
|
| 132 |
+
"past_key_values": past_key_values,
|
| 133 |
+
"use_cache": kwargs.get("use_cache"),
|
| 134 |
+
"attention_mask": attention_mask,
|
| 135 |
+
"images": kwargs.get("images", None),
|
| 136 |
+
}
|
| 137 |
+
)
|
| 138 |
+
return model_inputs
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
AutoConfig.register("ChatUniViPhi2", ChatUniViConfig)
|
| 142 |
+
AutoModelForCausalLM.register(ChatUniViConfig, ChatUniViPhiForCausalLM)
|
ChatUniVi/model/make_delta.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
|
| 4 |
+
"""
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 10 |
+
from llava.model.utils import auto_upgrade
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
|
| 14 |
+
print("Loading base model")
|
| 15 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 16 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
| 17 |
+
|
| 18 |
+
print("Loading target model")
|
| 19 |
+
auto_upgrade(target_model_path)
|
| 20 |
+
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
| 21 |
+
|
| 22 |
+
print("Calculating delta")
|
| 23 |
+
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
|
| 24 |
+
if name not in base.state_dict():
|
| 25 |
+
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
|
| 26 |
+
continue
|
| 27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
| 28 |
+
param.data -= base.state_dict()[name]
|
| 29 |
+
else:
|
| 30 |
+
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
|
| 31 |
+
bparam = base.state_dict()[name]
|
| 32 |
+
param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
|
| 33 |
+
|
| 34 |
+
print("Saving delta")
|
| 35 |
+
if hub_repo_id:
|
| 36 |
+
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
|
| 37 |
+
else:
|
| 38 |
+
kwargs = {}
|
| 39 |
+
target.save_pretrained(delta_path, **kwargs)
|
| 40 |
+
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
|
| 41 |
+
target_tokenizer.save_pretrained(delta_path, **kwargs)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
parser = argparse.ArgumentParser()
|
| 46 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
| 47 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
| 48 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
| 49 |
+
parser.add_argument("--hub-repo-id", type=str, default=None)
|
| 50 |
+
args = parser.parse_args()
|
| 51 |
+
|
| 52 |
+
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
|
ChatUniVi/model/multimodal_encoder/builder.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .clip_encoder import CLIPVisionTower
|
| 2 |
+
from .eva_encoder import EVAVisionTower
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
| 6 |
+
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
|
| 7 |
+
# if vision_tower.startswith("openai") or vision_tower.startswith("laion"):
|
| 8 |
+
# return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
| 9 |
+
#
|
| 10 |
+
# elif vision_tower.startswith("eva_vit_g"):
|
| 11 |
+
# return EVAVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
| 12 |
+
#
|
| 13 |
+
# raise ValueError(f'Unknown vision tower: {vision_tower}')
|
| 14 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
ChatUniVi/model/multimodal_encoder/clip_encoder.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CLIPVisionTower(nn.Module):
|
| 8 |
+
def __init__(self, vision_tower, args=None, delay_load=False):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
self.is_loaded = False
|
| 12 |
+
|
| 13 |
+
self.vision_tower_name = vision_tower
|
| 14 |
+
if args is None:
|
| 15 |
+
self.select_layer = -2
|
| 16 |
+
self.select_feature = 'patch'
|
| 17 |
+
else:
|
| 18 |
+
self.select_layer = args.mm_vision_select_layer
|
| 19 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 20 |
+
|
| 21 |
+
if not delay_load:
|
| 22 |
+
self.load_model()
|
| 23 |
+
else:
|
| 24 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
| 25 |
+
|
| 26 |
+
def load_model(self):
|
| 27 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
| 28 |
+
self.image_eval_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
| 29 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
|
| 30 |
+
self.vision_tower.requires_grad_(False)
|
| 31 |
+
|
| 32 |
+
self.is_loaded = True
|
| 33 |
+
|
| 34 |
+
def feature_select(self, image_forward_outs, select_feature='patch'):
|
| 35 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
| 36 |
+
if select_feature == 'patch':
|
| 37 |
+
image_features = image_features[:, 1:]
|
| 38 |
+
elif select_feature == 'cls_patch':
|
| 39 |
+
image_features = image_features
|
| 40 |
+
else:
|
| 41 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
| 42 |
+
return image_features
|
| 43 |
+
|
| 44 |
+
@torch.no_grad()
|
| 45 |
+
def forward(self, images, select_feature='patch'):
|
| 46 |
+
if type(images) is list:
|
| 47 |
+
image_features = []
|
| 48 |
+
for image in images:
|
| 49 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
| 50 |
+
image_feature = self.feature_select(image_forward_out, select_feature).to(image.dtype)
|
| 51 |
+
image_features.append(image_feature)
|
| 52 |
+
else:
|
| 53 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
| 54 |
+
image_features = self.feature_select(image_forward_outs, select_feature).to(images.dtype)
|
| 55 |
+
|
| 56 |
+
return image_features
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def dummy_feature(self):
|
| 60 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 61 |
+
|
| 62 |
+
@property
|
| 63 |
+
def dtype(self):
|
| 64 |
+
return self.vision_tower.dtype
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def device(self):
|
| 68 |
+
return self.vision_tower.device
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def config(self):
|
| 72 |
+
if self.is_loaded:
|
| 73 |
+
return self.vision_tower.config
|
| 74 |
+
else:
|
| 75 |
+
return self.cfg_only
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def hidden_size(self):
|
| 79 |
+
return self.config.hidden_size
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def num_patches(self):
|
| 83 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
ChatUniVi/model/multimodal_encoder/eva_encoder.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .eva_vit import create_eva_vit_g, _cfg
|
| 4 |
+
from .processor import ImageTrainProcessor, ImageEvalProcessor
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EVAVisionTower(nn.Module):
|
| 8 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
self.is_loaded = False
|
| 12 |
+
|
| 13 |
+
self.vision_tower_name = vision_tower
|
| 14 |
+
self.select_layer = args.mm_vision_select_layer
|
| 15 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 16 |
+
|
| 17 |
+
if not delay_load:
|
| 18 |
+
self.load_model()
|
| 19 |
+
else:
|
| 20 |
+
self.cfg_only = _cfg()
|
| 21 |
+
|
| 22 |
+
def load_model(self):
|
| 23 |
+
self.image_processor = ImageTrainProcessor()
|
| 24 |
+
self.image_eval_processor = ImageEvalProcessor()
|
| 25 |
+
self.vision_tower = create_eva_vit_g(
|
| 26 |
+
img_size=224, drop_path_rate=0, use_checkpoint=False, precision="fp16"
|
| 27 |
+
)
|
| 28 |
+
# self.vision_tower.requires_grad_(False)
|
| 29 |
+
|
| 30 |
+
self.is_loaded = True
|
| 31 |
+
|
| 32 |
+
def feature_select(self, image_forward_outs, select_feature='patch'):
|
| 33 |
+
image_features = image_forward_outs[self.select_layer]
|
| 34 |
+
if select_feature == 'patch':
|
| 35 |
+
image_features = image_features[:, 1:]
|
| 36 |
+
elif select_feature == 'cls_patch':
|
| 37 |
+
image_features = image_features
|
| 38 |
+
else:
|
| 39 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
| 40 |
+
return image_features
|
| 41 |
+
|
| 42 |
+
@torch.no_grad()
|
| 43 |
+
def forward(self, images, select_feature='patch'):
|
| 44 |
+
if type(images) is list:
|
| 45 |
+
image_features = []
|
| 46 |
+
for image in images:
|
| 47 |
+
image_forward_out = self.vision_tower.get_intermediate_layers(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),)
|
| 48 |
+
image_feature = self.feature_select(image_forward_out, select_feature).to(image.dtype)
|
| 49 |
+
image_features.append(image_feature)
|
| 50 |
+
else:
|
| 51 |
+
image_forward_outs = self.vision_tower.get_intermediate_layers(images.to(device=self.device, dtype=self.dtype))
|
| 52 |
+
image_features = self.feature_select(image_forward_outs, select_feature).to(images.dtype)
|
| 53 |
+
|
| 54 |
+
return image_features
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def dummy_feature(self):
|
| 58 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def dtype(self):
|
| 62 |
+
return self.vision_tower.cls_token.dtype
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def device(self):
|
| 66 |
+
return self.vision_tower.cls_token.device
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def config(self):
|
| 70 |
+
if self.is_loaded:
|
| 71 |
+
return self.vision_tower.config
|
| 72 |
+
else:
|
| 73 |
+
return self.cfg_only
|
| 74 |
+
|
| 75 |
+
@property
|
| 76 |
+
def hidden_size(self):
|
| 77 |
+
return self.vision_tower.num_features
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def num_patches(self):
|
| 81 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
ChatUniVi/model/multimodal_encoder/eva_vit.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Based on EVA, BEIT, timm and DeiT code bases
|
| 2 |
+
# https://github.com/baaivision/EVA
|
| 3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 4 |
+
# https://github.com/microsoft/unilm/tree/master/beit
|
| 5 |
+
# https://github.com/facebookresearch/deit/
|
| 6 |
+
# https://github.com/facebookresearch/dino
|
| 7 |
+
# --------------------------------------------------------'
|
| 8 |
+
import math
|
| 9 |
+
from functools import partial
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import torch.utils.checkpoint as checkpoint
|
| 15 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
| 16 |
+
from timm.models.registry import register_model
|
| 17 |
+
|
| 18 |
+
from .utils import download_cached_file
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _cfg(url='', **kwargs):
|
| 22 |
+
return {
|
| 23 |
+
'url': url,
|
| 24 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
| 25 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
| 26 |
+
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
| 27 |
+
**kwargs
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DropPath(nn.Module):
|
| 32 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, drop_prob=None):
|
| 36 |
+
super(DropPath, self).__init__()
|
| 37 |
+
self.drop_prob = drop_prob
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 41 |
+
|
| 42 |
+
def extra_repr(self) -> str:
|
| 43 |
+
return 'p={}'.format(self.drop_prob)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Mlp(nn.Module):
|
| 47 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 48 |
+
super().__init__()
|
| 49 |
+
out_features = out_features or in_features
|
| 50 |
+
hidden_features = hidden_features or in_features
|
| 51 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 52 |
+
self.act = act_layer()
|
| 53 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 54 |
+
self.drop = nn.Dropout(drop)
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
x = self.fc1(x)
|
| 58 |
+
x = self.act(x)
|
| 59 |
+
# x = self.drop(x)
|
| 60 |
+
# commit this for the orignal BERT implement
|
| 61 |
+
x = self.fc2(x)
|
| 62 |
+
x = self.drop(x)
|
| 63 |
+
return x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Attention(nn.Module):
|
| 67 |
+
def __init__(
|
| 68 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
| 69 |
+
proj_drop=0., window_size=None, attn_head_dim=None):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.num_heads = num_heads
|
| 72 |
+
head_dim = dim // num_heads
|
| 73 |
+
if attn_head_dim is not None:
|
| 74 |
+
head_dim = attn_head_dim
|
| 75 |
+
all_head_dim = head_dim * self.num_heads
|
| 76 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 77 |
+
|
| 78 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
| 79 |
+
if qkv_bias:
|
| 80 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 81 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 82 |
+
else:
|
| 83 |
+
self.q_bias = None
|
| 84 |
+
self.v_bias = None
|
| 85 |
+
|
| 86 |
+
if window_size:
|
| 87 |
+
self.window_size = window_size
|
| 88 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
| 89 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 90 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 91 |
+
# cls to token & token 2 cls & cls to cls
|
| 92 |
+
|
| 93 |
+
# get pair-wise relative position index for each token inside the window
|
| 94 |
+
coords_h = torch.arange(window_size[0])
|
| 95 |
+
coords_w = torch.arange(window_size[1])
|
| 96 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 97 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 98 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 99 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 100 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
| 101 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
| 102 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
| 103 |
+
relative_position_index = \
|
| 104 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
| 105 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 106 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
| 107 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
| 108 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
| 109 |
+
|
| 110 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
| 111 |
+
else:
|
| 112 |
+
self.window_size = None
|
| 113 |
+
self.relative_position_bias_table = None
|
| 114 |
+
self.relative_position_index = None
|
| 115 |
+
|
| 116 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 117 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
| 118 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 119 |
+
|
| 120 |
+
def forward(self, x, rel_pos_bias=None):
|
| 121 |
+
B, N, C = x.shape
|
| 122 |
+
qkv_bias = None
|
| 123 |
+
if self.q_bias is not None:
|
| 124 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
| 125 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 126 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
| 127 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
| 128 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
| 129 |
+
|
| 130 |
+
q = q * self.scale
|
| 131 |
+
attn = (q @ k.transpose(-2, -1))
|
| 132 |
+
|
| 133 |
+
if self.relative_position_bias_table is not None:
|
| 134 |
+
relative_position_bias = \
|
| 135 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
| 136 |
+
self.window_size[0] * self.window_size[1] + 1,
|
| 137 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
| 138 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 139 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 140 |
+
|
| 141 |
+
if rel_pos_bias is not None:
|
| 142 |
+
attn = attn + rel_pos_bias
|
| 143 |
+
|
| 144 |
+
attn = attn.softmax(dim=-1)
|
| 145 |
+
attn = self.attn_drop(attn)
|
| 146 |
+
|
| 147 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
| 148 |
+
x = self.proj(x)
|
| 149 |
+
x = self.proj_drop(x)
|
| 150 |
+
return x
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class Block(nn.Module):
|
| 154 |
+
|
| 155 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 156 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
| 157 |
+
window_size=None, attn_head_dim=None):
|
| 158 |
+
super().__init__()
|
| 159 |
+
self.norm1 = norm_layer(dim)
|
| 160 |
+
self.attn = Attention(
|
| 161 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 162 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
|
| 163 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 164 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 165 |
+
self.norm2 = norm_layer(dim)
|
| 166 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 167 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 168 |
+
|
| 169 |
+
if init_values is not None and init_values > 0:
|
| 170 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
| 171 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
| 172 |
+
else:
|
| 173 |
+
self.gamma_1, self.gamma_2 = None, None
|
| 174 |
+
|
| 175 |
+
def forward(self, x, rel_pos_bias=None):
|
| 176 |
+
if self.gamma_1 is None:
|
| 177 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
| 178 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 179 |
+
else:
|
| 180 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
| 181 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
| 182 |
+
return x
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class PatchEmbed(nn.Module):
|
| 186 |
+
""" Image to Patch Embedding
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
| 190 |
+
super().__init__()
|
| 191 |
+
img_size = to_2tuple(img_size)
|
| 192 |
+
patch_size = to_2tuple(patch_size)
|
| 193 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
| 194 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 195 |
+
self.img_size = img_size
|
| 196 |
+
self.patch_size = patch_size
|
| 197 |
+
self.num_patches = num_patches
|
| 198 |
+
|
| 199 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 200 |
+
|
| 201 |
+
def forward(self, x, **kwargs):
|
| 202 |
+
B, C, H, W = x.shape
|
| 203 |
+
# FIXME look at relaxing size constraints
|
| 204 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
| 205 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
| 206 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
| 207 |
+
return x
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class RelativePositionBias(nn.Module):
|
| 211 |
+
|
| 212 |
+
def __init__(self, window_size, num_heads):
|
| 213 |
+
super().__init__()
|
| 214 |
+
self.window_size = window_size
|
| 215 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
| 216 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 217 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 218 |
+
# cls to token & token 2 cls & cls to cls
|
| 219 |
+
|
| 220 |
+
# get pair-wise relative position index for each token inside the window
|
| 221 |
+
coords_h = torch.arange(window_size[0])
|
| 222 |
+
coords_w = torch.arange(window_size[1])
|
| 223 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 224 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 225 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 226 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 227 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
| 228 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
| 229 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
| 230 |
+
relative_position_index = \
|
| 231 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
| 232 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 233 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
| 234 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
| 235 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
| 236 |
+
|
| 237 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
| 238 |
+
|
| 239 |
+
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 240 |
+
|
| 241 |
+
def forward(self):
|
| 242 |
+
relative_position_bias = \
|
| 243 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
| 244 |
+
self.window_size[0] * self.window_size[1] + 1,
|
| 245 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
| 246 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class VisionTransformer(nn.Module):
|
| 250 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
| 251 |
+
"""
|
| 252 |
+
|
| 253 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
| 254 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
| 255 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
|
| 256 |
+
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
|
| 257 |
+
use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
|
| 258 |
+
super().__init__()
|
| 259 |
+
self.image_size = img_size
|
| 260 |
+
self.num_classes = num_classes
|
| 261 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 262 |
+
|
| 263 |
+
self.patch_embed = PatchEmbed(
|
| 264 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 265 |
+
num_patches = self.patch_embed.num_patches
|
| 266 |
+
|
| 267 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 268 |
+
if use_abs_pos_emb:
|
| 269 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 270 |
+
else:
|
| 271 |
+
self.pos_embed = None
|
| 272 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 273 |
+
|
| 274 |
+
if use_shared_rel_pos_bias:
|
| 275 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
| 276 |
+
else:
|
| 277 |
+
self.rel_pos_bias = None
|
| 278 |
+
self.use_checkpoint = use_checkpoint
|
| 279 |
+
|
| 280 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 281 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
| 282 |
+
self.blocks = nn.ModuleList([
|
| 283 |
+
Block(
|
| 284 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 285 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
| 286 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
|
| 287 |
+
for i in range(depth)])
|
| 288 |
+
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
| 289 |
+
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
| 290 |
+
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 291 |
+
|
| 292 |
+
if self.pos_embed is not None:
|
| 293 |
+
trunc_normal_(self.pos_embed, std=.02)
|
| 294 |
+
trunc_normal_(self.cls_token, std=.02)
|
| 295 |
+
# trunc_normal_(self.mask_token, std=.02)
|
| 296 |
+
# if isinstance(self.head, nn.Linear):
|
| 297 |
+
# trunc_normal_(self.head.weight, std=.02)
|
| 298 |
+
self.apply(self._init_weights)
|
| 299 |
+
self.fix_init_weight()
|
| 300 |
+
|
| 301 |
+
# if isinstance(self.head, nn.Linear):
|
| 302 |
+
# self.head.weight.data.mul_(init_scale)
|
| 303 |
+
# self.head.bias.data.mul_(init_scale)
|
| 304 |
+
|
| 305 |
+
def fix_init_weight(self):
|
| 306 |
+
def rescale(param, layer_id):
|
| 307 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
| 308 |
+
|
| 309 |
+
for layer_id, layer in enumerate(self.blocks):
|
| 310 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
| 311 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
| 312 |
+
|
| 313 |
+
def _init_weights(self, m):
|
| 314 |
+
if isinstance(m, nn.Linear):
|
| 315 |
+
trunc_normal_(m.weight, std=.02)
|
| 316 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 317 |
+
nn.init.constant_(m.bias, 0)
|
| 318 |
+
elif isinstance(m, nn.LayerNorm):
|
| 319 |
+
nn.init.constant_(m.bias, 0)
|
| 320 |
+
nn.init.constant_(m.weight, 1.0)
|
| 321 |
+
|
| 322 |
+
def get_classifier(self):
|
| 323 |
+
return self.head
|
| 324 |
+
|
| 325 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
| 326 |
+
self.num_classes = num_classes
|
| 327 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 328 |
+
|
| 329 |
+
def forward_features(self, x):
|
| 330 |
+
x = self.patch_embed(x)
|
| 331 |
+
batch_size, seq_len, _ = x.size()
|
| 332 |
+
|
| 333 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
| 334 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 335 |
+
if self.pos_embed is not None:
|
| 336 |
+
x = x + self.pos_embed
|
| 337 |
+
x = self.pos_drop(x)
|
| 338 |
+
|
| 339 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
| 340 |
+
for blk in self.blocks:
|
| 341 |
+
if self.use_checkpoint:
|
| 342 |
+
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
|
| 343 |
+
else:
|
| 344 |
+
x = blk(x, rel_pos_bias)
|
| 345 |
+
return x
|
| 346 |
+
|
| 347 |
+
# x = self.norm(x)
|
| 348 |
+
|
| 349 |
+
# if self.fc_norm is not None:
|
| 350 |
+
# t = x[:, 1:, :]
|
| 351 |
+
# return self.fc_norm(t.mean(1))
|
| 352 |
+
# else:
|
| 353 |
+
# return x[:, 0]
|
| 354 |
+
|
| 355 |
+
def forward(self, x):
|
| 356 |
+
x = self.forward_features(x)
|
| 357 |
+
# x = self.head(x)
|
| 358 |
+
return x
|
| 359 |
+
|
| 360 |
+
def get_intermediate_layers(self, x):
|
| 361 |
+
x = self.patch_embed(x)
|
| 362 |
+
batch_size, seq_len, _ = x.size()
|
| 363 |
+
|
| 364 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
| 365 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 366 |
+
if self.pos_embed is not None:
|
| 367 |
+
x = x + self.pos_embed
|
| 368 |
+
x = self.pos_drop(x)
|
| 369 |
+
|
| 370 |
+
features = []
|
| 371 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
| 372 |
+
for blk in self.blocks:
|
| 373 |
+
x = blk(x, rel_pos_bias)
|
| 374 |
+
features.append(x)
|
| 375 |
+
|
| 376 |
+
return features
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
| 380 |
+
if 'pos_embed' in checkpoint_model:
|
| 381 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
|
| 382 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
| 383 |
+
num_patches = model.patch_embed.num_patches
|
| 384 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
| 385 |
+
# height (== width) for the checkpoint position embedding
|
| 386 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
| 387 |
+
# height (== width) for the new position embedding
|
| 388 |
+
new_size = int(num_patches ** 0.5)
|
| 389 |
+
# class_token and dist_token are kept unchanged
|
| 390 |
+
if orig_size != new_size:
|
| 391 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
| 392 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
| 393 |
+
# only the position tokens are interpolated
|
| 394 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
| 395 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
| 396 |
+
pos_tokens = torch.nn.functional.interpolate(
|
| 397 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
| 398 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
| 399 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
| 400 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def convert_weights_to_fp16(model: nn.Module):
|
| 404 |
+
"""Convert applicable model parameters to fp16"""
|
| 405 |
+
|
| 406 |
+
def _convert_weights_to_fp16(l):
|
| 407 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
| 408 |
+
l.weight.data = l.weight.data.half()
|
| 409 |
+
if l.bias is not None:
|
| 410 |
+
l.bias.data = l.bias.data.half()
|
| 411 |
+
|
| 412 |
+
# if isinstance(l, (nn.MultiheadAttention, Attention)):
|
| 413 |
+
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
| 414 |
+
# tensor = getattr(l, attr)
|
| 415 |
+
# if tensor is not None:
|
| 416 |
+
# tensor.data = tensor.data.half()
|
| 417 |
+
|
| 418 |
+
model.apply(_convert_weights_to_fp16)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
|
| 422 |
+
model = VisionTransformer(
|
| 423 |
+
img_size=img_size,
|
| 424 |
+
patch_size=14,
|
| 425 |
+
use_mean_pooling=False,
|
| 426 |
+
embed_dim=1408,
|
| 427 |
+
depth=39,
|
| 428 |
+
num_heads=1408 // 88,
|
| 429 |
+
mlp_ratio=4.3637,
|
| 430 |
+
qkv_bias=True,
|
| 431 |
+
drop_path_rate=drop_path_rate,
|
| 432 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 433 |
+
use_checkpoint=use_checkpoint,
|
| 434 |
+
)
|
| 435 |
+
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
|
| 436 |
+
cached_file = download_cached_file(
|
| 437 |
+
url, check_hash=False, progress=True
|
| 438 |
+
)
|
| 439 |
+
state_dict = torch.load(cached_file, map_location="cpu")
|
| 440 |
+
interpolate_pos_embed(model, state_dict)
|
| 441 |
+
|
| 442 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
| 443 |
+
# print(incompatible_keys)
|
| 444 |
+
|
| 445 |
+
if precision == "fp16":
|
| 446 |
+
# model.to("cuda")
|
| 447 |
+
convert_weights_to_fp16(model)
|
| 448 |
+
return model
|
ChatUniVi/model/multimodal_encoder/processor.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class BaseProcessor:
|
| 7 |
+
def __init__(self, mean=None, std=None):
|
| 8 |
+
if mean is None:
|
| 9 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
| 10 |
+
if std is None:
|
| 11 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
| 12 |
+
|
| 13 |
+
self.normalize = transforms.Normalize(mean, std)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ImageTrainProcessor(BaseProcessor):
|
| 17 |
+
def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
|
| 18 |
+
super().__init__(mean=mean, std=std)
|
| 19 |
+
|
| 20 |
+
self.transform = transforms.Compose(
|
| 21 |
+
[
|
| 22 |
+
transforms.Resize(
|
| 23 |
+
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
|
| 24 |
+
),
|
| 25 |
+
transforms.ToTensor(),
|
| 26 |
+
self.normalize,
|
| 27 |
+
]
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def preprocess(self, item, return_tensors):
|
| 31 |
+
return {'pixel_values': [self.transform(item)]}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ImageEvalProcessor(BaseProcessor):
|
| 35 |
+
def __init__(self, image_size=224, mean=None, std=None):
|
| 36 |
+
super().__init__(mean=mean, std=std)
|
| 37 |
+
|
| 38 |
+
self.transform = transforms.Compose(
|
| 39 |
+
[
|
| 40 |
+
transforms.Resize(
|
| 41 |
+
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
|
| 42 |
+
),
|
| 43 |
+
transforms.ToTensor(),
|
| 44 |
+
self.normalize,
|
| 45 |
+
]
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def preprocess(self, item, return_tensors):
|
| 49 |
+
return {'pixel_values': [self.transform(item)]}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class QWenImageProcessor(BaseProcessor):
|
| 53 |
+
def __init__(self, image_size=224, mean=None, std=None):
|
| 54 |
+
super().__init__(mean=mean, std=std)
|
| 55 |
+
|
| 56 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
| 57 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
| 58 |
+
self.transform = transforms.Compose([
|
| 59 |
+
transforms.Resize(
|
| 60 |
+
(448, 448),
|
| 61 |
+
interpolation=InterpolationMode.BICUBIC
|
| 62 |
+
),
|
| 63 |
+
transforms.ToTensor(),
|
| 64 |
+
transforms.Normalize(mean=mean, std=std),
|
| 65 |
+
])
|
| 66 |
+
|
| 67 |
+
def preprocess(self, item, return_tensors):
|
| 68 |
+
return {'pixel_values': [self.transform(item)]}
|
ChatUniVi/model/multimodal_encoder/utils.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
All rights reserved.
|
| 4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import datetime
|
| 9 |
+
import functools
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
import timm.models.hub as timm_hub
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def setup_for_distributed(is_master):
|
| 18 |
+
"""
|
| 19 |
+
This function disables printing when not in master process
|
| 20 |
+
"""
|
| 21 |
+
import builtins as __builtin__
|
| 22 |
+
|
| 23 |
+
builtin_print = __builtin__.print
|
| 24 |
+
|
| 25 |
+
def print(*args, **kwargs):
|
| 26 |
+
force = kwargs.pop("force", False)
|
| 27 |
+
if is_master or force:
|
| 28 |
+
builtin_print(*args, **kwargs)
|
| 29 |
+
|
| 30 |
+
__builtin__.print = print
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def is_dist_avail_and_initialized():
|
| 34 |
+
if not dist.is_available():
|
| 35 |
+
return False
|
| 36 |
+
if not dist.is_initialized():
|
| 37 |
+
return False
|
| 38 |
+
return True
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_world_size():
|
| 42 |
+
if not is_dist_avail_and_initialized():
|
| 43 |
+
return 1
|
| 44 |
+
return dist.get_world_size()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_rank():
|
| 48 |
+
if not is_dist_avail_and_initialized():
|
| 49 |
+
return 0
|
| 50 |
+
return dist.get_rank()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def is_main_process():
|
| 54 |
+
return get_rank() == 0
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def init_distributed_mode(args):
|
| 58 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 59 |
+
args.rank = int(os.environ["RANK"])
|
| 60 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
| 61 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
| 62 |
+
elif "SLURM_PROCID" in os.environ:
|
| 63 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
| 64 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
| 65 |
+
else:
|
| 66 |
+
print("Not using distributed mode")
|
| 67 |
+
args.distributed = False
|
| 68 |
+
return
|
| 69 |
+
|
| 70 |
+
args.distributed = True
|
| 71 |
+
|
| 72 |
+
torch.cuda.set_device(args.gpu)
|
| 73 |
+
args.dist_backend = "nccl"
|
| 74 |
+
print(
|
| 75 |
+
"| distributed init (rank {}, world {}): {}".format(
|
| 76 |
+
args.rank, args.world_size, args.dist_url
|
| 77 |
+
),
|
| 78 |
+
flush=True,
|
| 79 |
+
)
|
| 80 |
+
torch.distributed.init_process_group(
|
| 81 |
+
backend=args.dist_backend,
|
| 82 |
+
init_method=args.dist_url,
|
| 83 |
+
world_size=args.world_size,
|
| 84 |
+
rank=args.rank,
|
| 85 |
+
timeout=datetime.timedelta(
|
| 86 |
+
days=365
|
| 87 |
+
), # allow auto-downloading and de-compressing
|
| 88 |
+
)
|
| 89 |
+
torch.distributed.barrier()
|
| 90 |
+
setup_for_distributed(args.rank == 0)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_dist_info():
|
| 94 |
+
if torch.__version__ < "1.0":
|
| 95 |
+
initialized = dist._initialized
|
| 96 |
+
else:
|
| 97 |
+
initialized = dist.is_initialized()
|
| 98 |
+
if initialized:
|
| 99 |
+
rank = dist.get_rank()
|
| 100 |
+
world_size = dist.get_world_size()
|
| 101 |
+
else: # non-distributed training
|
| 102 |
+
rank = 0
|
| 103 |
+
world_size = 1
|
| 104 |
+
return rank, world_size
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def main_process(func):
|
| 108 |
+
@functools.wraps(func)
|
| 109 |
+
def wrapper(*args, **kwargs):
|
| 110 |
+
rank, _ = get_dist_info()
|
| 111 |
+
if rank == 0:
|
| 112 |
+
return func(*args, **kwargs)
|
| 113 |
+
|
| 114 |
+
return wrapper
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def download_cached_file(url, check_hash=True, progress=False):
|
| 118 |
+
"""
|
| 119 |
+
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
| 120 |
+
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
def get_cached_file_path():
|
| 124 |
+
# a hack to sync the file path across processes
|
| 125 |
+
parts = torch.hub.urlparse(url)
|
| 126 |
+
filename = os.path.basename(parts.path)
|
| 127 |
+
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
|
| 128 |
+
|
| 129 |
+
return cached_file
|
| 130 |
+
|
| 131 |
+
if is_main_process():
|
| 132 |
+
timm_hub.download_cached_file(url, check_hash, progress)
|
| 133 |
+
|
| 134 |
+
if is_dist_avail_and_initialized():
|
| 135 |
+
dist.barrier()
|
| 136 |
+
|
| 137 |
+
return get_cached_file_path()
|
ChatUniVi/model/multimodal_projector/builder.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class IdentityMap(nn.Module):
|
| 7 |
+
def __init__(self):
|
| 8 |
+
super().__init__()
|
| 9 |
+
|
| 10 |
+
def forward(self, x, *args, **kwargs):
|
| 11 |
+
return x
|
| 12 |
+
|
| 13 |
+
@property
|
| 14 |
+
def config(self):
|
| 15 |
+
return {"mm_projector_type": 'identity'}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SimpleResBlock(nn.Module):
|
| 19 |
+
def __init__(self, channels):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.pre_norm = nn.LayerNorm(channels)
|
| 22 |
+
|
| 23 |
+
self.proj = nn.Sequential(
|
| 24 |
+
nn.Linear(channels, channels),
|
| 25 |
+
nn.GELU(),
|
| 26 |
+
nn.Linear(channels, channels)
|
| 27 |
+
)
|
| 28 |
+
def forward(self, x):
|
| 29 |
+
x = self.pre_norm(x)
|
| 30 |
+
return x + self.proj(x)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
| 34 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
| 35 |
+
|
| 36 |
+
if projector_type == 'linear':
|
| 37 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
| 38 |
+
|
| 39 |
+
print("projector_type:", projector_type)
|
| 40 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
| 41 |
+
if mlp_gelu_match:
|
| 42 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
| 43 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
| 44 |
+
for _ in range(1, mlp_depth):
|
| 45 |
+
modules.append(nn.GELU())
|
| 46 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
| 47 |
+
return nn.Sequential(*modules)
|
| 48 |
+
|
| 49 |
+
if projector_type == 'identity':
|
| 50 |
+
return IdentityMap()
|
| 51 |
+
|
| 52 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
ChatUniVi/train/llama_flash_attn_monkey_patch.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
import transformers
|
| 8 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
| 9 |
+
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
| 14 |
+
except ImportError:
|
| 15 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
| 16 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def forward(
|
| 20 |
+
self,
|
| 21 |
+
hidden_states: torch.Tensor,
|
| 22 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 23 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 24 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 25 |
+
output_attentions: bool = False,
|
| 26 |
+
use_cache: bool = False,
|
| 27 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 28 |
+
"""Input shape: Batch x Time x Channel
|
| 29 |
+
|
| 30 |
+
attention_mask: [bsz, q_len]
|
| 31 |
+
"""
|
| 32 |
+
bsz, q_len, _ = hidden_states.size()
|
| 33 |
+
|
| 34 |
+
query_states = (
|
| 35 |
+
self.q_proj(hidden_states)
|
| 36 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
| 37 |
+
.transpose(1, 2)
|
| 38 |
+
)
|
| 39 |
+
key_states = (
|
| 40 |
+
self.k_proj(hidden_states)
|
| 41 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
| 42 |
+
.transpose(1, 2)
|
| 43 |
+
)
|
| 44 |
+
value_states = (
|
| 45 |
+
self.v_proj(hidden_states)
|
| 46 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
| 47 |
+
.transpose(1, 2)
|
| 48 |
+
)
|
| 49 |
+
# [bsz, q_len, nh, hd]
|
| 50 |
+
# [bsz, nh, q_len, hd]
|
| 51 |
+
|
| 52 |
+
kv_seq_len = key_states.shape[-2]
|
| 53 |
+
assert past_key_value is None, "past_key_value is not supported"
|
| 54 |
+
|
| 55 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 56 |
+
query_states, key_states = apply_rotary_pos_emb(
|
| 57 |
+
query_states, key_states, cos, sin, position_ids
|
| 58 |
+
)
|
| 59 |
+
# [bsz, nh, t, hd]
|
| 60 |
+
assert not output_attentions, "output_attentions is not supported"
|
| 61 |
+
assert not use_cache, "use_cache is not supported"
|
| 62 |
+
|
| 63 |
+
# Flash attention codes from
|
| 64 |
+
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
| 65 |
+
|
| 66 |
+
# transform the data into the format required by flash attention
|
| 67 |
+
qkv = torch.stack(
|
| 68 |
+
[query_states, key_states, value_states], dim=2
|
| 69 |
+
) # [bsz, nh, 3, q_len, hd]
|
| 70 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
| 71 |
+
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
| 72 |
+
# the attention_mask should be the same as the key_padding_mask
|
| 73 |
+
key_padding_mask = attention_mask
|
| 74 |
+
|
| 75 |
+
if key_padding_mask is None:
|
| 76 |
+
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
| 77 |
+
max_s = q_len
|
| 78 |
+
cu_q_lens = torch.arange(
|
| 79 |
+
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
|
| 80 |
+
)
|
| 81 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
| 82 |
+
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
| 83 |
+
)
|
| 84 |
+
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
| 85 |
+
else:
|
| 86 |
+
nheads = qkv.shape[-2]
|
| 87 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
| 88 |
+
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
| 89 |
+
x_unpad = rearrange(
|
| 90 |
+
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
| 91 |
+
)
|
| 92 |
+
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
| 93 |
+
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
| 94 |
+
)
|
| 95 |
+
output = rearrange(
|
| 96 |
+
pad_input(
|
| 97 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
|
| 98 |
+
),
|
| 99 |
+
"b s (h d) -> b s h d",
|
| 100 |
+
h=nheads,
|
| 101 |
+
)
|
| 102 |
+
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
| 106 |
+
# requires the attention mask to be the same as the key_padding_mask
|
| 107 |
+
def _prepare_decoder_attention_mask(
|
| 108 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
| 109 |
+
):
|
| 110 |
+
# [bsz, seq_len]
|
| 111 |
+
return attention_mask
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def replace_llama_attn_with_flash_attn():
|
| 115 |
+
cuda_major, cuda_minor = torch.cuda.get_device_capability()
|
| 116 |
+
if cuda_major < 8:
|
| 117 |
+
logging.warning(
|
| 118 |
+
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
|
| 119 |
+
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
|
| 120 |
+
)
|
| 121 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
| 122 |
+
_prepare_decoder_attention_mask
|
| 123 |
+
)
|
| 124 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
ChatUniVi/train/train.py
ADDED
|
@@ -0,0 +1,1232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
|
| 2 |
+
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
|
| 3 |
+
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import copy
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import pathlib
|
| 23 |
+
from typing import Dict, Optional, Sequence, List
|
| 24 |
+
import torch
|
| 25 |
+
import transformers
|
| 26 |
+
from ChatUniVi.constants import *
|
| 27 |
+
from torch.utils.data import Dataset
|
| 28 |
+
from ChatUniVi.train.trainer import ChatUniViTrainer
|
| 29 |
+
from ChatUniVi import conversation as conversation_lib
|
| 30 |
+
from ChatUniVi.model import *
|
| 31 |
+
from ChatUniVi.mm_utils import tokenizer_image_token
|
| 32 |
+
from ChatUniVi.config import ModelConfig, DataConfig
|
| 33 |
+
from PIL import Image
|
| 34 |
+
import random
|
| 35 |
+
import numpy as np
|
| 36 |
+
from ChatUniVi.model.dataloader import _get_rawvideo_dec
|
| 37 |
+
|
| 38 |
+
local_rank = None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def rank0_print(*args):
|
| 42 |
+
if local_rank == 0:
|
| 43 |
+
print(*args)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class ModelArguments:
|
| 48 |
+
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
| 49 |
+
version: Optional[str] = field(default="v0")
|
| 50 |
+
freeze_backbone: bool = field(default=False)
|
| 51 |
+
tune_mm_mlp_adapter: bool = field(default=False)
|
| 52 |
+
vision_tower: Optional[str] = field(default=None)
|
| 53 |
+
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
|
| 54 |
+
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
|
| 55 |
+
mm_use_im_start_end: bool = field(default=False)
|
| 56 |
+
mm_use_im_patch_token: bool = field(default=True)
|
| 57 |
+
mm_vision_select_feature: Optional[str] = field(default="patch")
|
| 58 |
+
|
| 59 |
+
mm_projector_type: Optional[str] = field(default='linear')
|
| 60 |
+
model_use: str = field(default="BASE")
|
| 61 |
+
mm_use_box_start_end: bool = field(default=False)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
|
| 65 |
+
class DataArguments:
|
| 66 |
+
lazy_preprocess: bool = False
|
| 67 |
+
is_multimodal: bool = False
|
| 68 |
+
image_aspect_ratio: str = 'square'
|
| 69 |
+
image_grid_pinpoints: Optional[str] = field(default=None)
|
| 70 |
+
|
| 71 |
+
dataset_use: str = field(default="Pretrain")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
|
| 75 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 76 |
+
cache_dir: Optional[str] = field(default=None)
|
| 77 |
+
optim: str = field(default="adamw_torch")
|
| 78 |
+
remove_unused_columns: bool = field(default=False)
|
| 79 |
+
freeze_mm_mlp_adapter: bool = field(default=False)
|
| 80 |
+
mpt_attn_impl: Optional[str] = field(default="triton")
|
| 81 |
+
model_max_length: int = field(
|
| 82 |
+
default=512,
|
| 83 |
+
metadata={
|
| 84 |
+
"help":
|
| 85 |
+
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 86 |
+
},
|
| 87 |
+
)
|
| 88 |
+
double_quant: bool = field(
|
| 89 |
+
default=True,
|
| 90 |
+
metadata={"help": "Compress the quantization statistics through double quantization."}
|
| 91 |
+
)
|
| 92 |
+
quant_type: str = field(
|
| 93 |
+
default="nf4",
|
| 94 |
+
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
| 95 |
+
)
|
| 96 |
+
bits: int = field(
|
| 97 |
+
default=16,
|
| 98 |
+
metadata={"help": "How many bits to use."}
|
| 99 |
+
)
|
| 100 |
+
lora_enable: bool = False
|
| 101 |
+
lora_r: int = 64
|
| 102 |
+
lora_alpha: int = 16
|
| 103 |
+
lora_dropout: float = 0.05
|
| 104 |
+
lora_weight_path: str = ""
|
| 105 |
+
lora_bias: str = "none"
|
| 106 |
+
|
| 107 |
+
seed = 42
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
| 111 |
+
from deepspeed import zero
|
| 112 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
| 113 |
+
if hasattr(param, "ds_id"):
|
| 114 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
|
| 115 |
+
if not ignore_status:
|
| 116 |
+
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
|
| 117 |
+
with zero.GatheredParameters([param]):
|
| 118 |
+
param = param.data.detach().cpu().clone()
|
| 119 |
+
else:
|
| 120 |
+
param = param.detach().cpu().clone()
|
| 121 |
+
return param
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# Borrowed from peft.utils.get_peft_model_state_dict
|
| 125 |
+
def get_peft_state_maybe_zero_3(named_params, bias):
|
| 126 |
+
if bias == "none":
|
| 127 |
+
to_return = {k: t for k, t in named_params if "lora_" in k}
|
| 128 |
+
elif bias == "all":
|
| 129 |
+
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
|
| 130 |
+
elif bias == "lora_only":
|
| 131 |
+
to_return = {}
|
| 132 |
+
maybe_lora_bias = {}
|
| 133 |
+
lora_bias_names = set()
|
| 134 |
+
for k, t in named_params:
|
| 135 |
+
if "lora_" in k:
|
| 136 |
+
to_return[k] = t
|
| 137 |
+
bias_name = k.split("lora_")[0] + "bias"
|
| 138 |
+
lora_bias_names.add(bias_name)
|
| 139 |
+
elif "bias" in k:
|
| 140 |
+
maybe_lora_bias[k] = t
|
| 141 |
+
for k, t in maybe_lora_bias:
|
| 142 |
+
if bias_name in lora_bias_names:
|
| 143 |
+
to_return[bias_name] = t
|
| 144 |
+
else:
|
| 145 |
+
raise NotImplementedError
|
| 146 |
+
to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
|
| 147 |
+
return to_return
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
|
| 151 |
+
to_return = {k: t for k, t in named_params if "lora_" not in k}
|
| 152 |
+
if require_grad_only:
|
| 153 |
+
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
|
| 154 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
|
| 155 |
+
return to_return
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
| 159 |
+
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
|
| 160 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
|
| 161 |
+
return to_return
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def find_all_linear_names(model):
|
| 165 |
+
cls = torch.nn.Linear
|
| 166 |
+
lora_module_names = set()
|
| 167 |
+
for name, module in model.named_modules():
|
| 168 |
+
if isinstance(module, cls):
|
| 169 |
+
names = name.split('.')
|
| 170 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
| 171 |
+
|
| 172 |
+
if 'lm_head' in lora_module_names: # needed for 16-bit
|
| 173 |
+
lora_module_names.remove('lm_head')
|
| 174 |
+
return list(lora_module_names)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
|
| 178 |
+
output_dir: str):
|
| 179 |
+
"""Collects the state dict and dump to disk."""
|
| 180 |
+
|
| 181 |
+
if getattr(trainer.args, "tune_mm_mlp_adapter", False):
|
| 182 |
+
# Only save Adapter
|
| 183 |
+
keys_to_match = ['mm_projector', "ctm", "block"]
|
| 184 |
+
if getattr(trainer.args, "use_im_start_end", False):
|
| 185 |
+
keys_to_match.extend(['embed_tokens', 'embed_in'])
|
| 186 |
+
|
| 187 |
+
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
|
| 188 |
+
trainer.model.config.save_pretrained(output_dir)
|
| 189 |
+
|
| 190 |
+
current_folder = output_dir.split('/')[-1]
|
| 191 |
+
parent_folder = os.path.dirname(output_dir)
|
| 192 |
+
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
|
| 193 |
+
if current_folder.startswith('checkpoint-'):
|
| 194 |
+
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
|
| 195 |
+
os.makedirs(mm_projector_folder, exist_ok=True)
|
| 196 |
+
torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
|
| 197 |
+
else:
|
| 198 |
+
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
|
| 199 |
+
|
| 200 |
+
if trainer.deepspeed:
|
| 201 |
+
torch.cuda.synchronize()
|
| 202 |
+
trainer.save_model(output_dir)
|
| 203 |
+
return
|
| 204 |
+
|
| 205 |
+
state_dict = trainer.model.state_dict()
|
| 206 |
+
if trainer.args.should_save:
|
| 207 |
+
cpu_state_dict = {
|
| 208 |
+
key: value.cpu()
|
| 209 |
+
for key, value in state_dict.items()
|
| 210 |
+
}
|
| 211 |
+
del state_dict
|
| 212 |
+
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def smart_tokenizer_and_embedding_resize(
|
| 216 |
+
special_tokens_dict: Dict,
|
| 217 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 218 |
+
model: transformers.PreTrainedModel,
|
| 219 |
+
):
|
| 220 |
+
"""Resize tokenizer and embedding.
|
| 221 |
+
|
| 222 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
| 223 |
+
"""
|
| 224 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
| 225 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 226 |
+
|
| 227 |
+
if num_new_tokens > 0:
|
| 228 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
| 229 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
| 230 |
+
|
| 231 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
| 232 |
+
dim=0, keepdim=True)
|
| 233 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
| 234 |
+
dim=0, keepdim=True)
|
| 235 |
+
|
| 236 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
| 237 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def _tokenize_fn(strings: Sequence[str],
|
| 241 |
+
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
| 242 |
+
"""Tokenize a list of strings."""
|
| 243 |
+
tokenized_list = [
|
| 244 |
+
tokenizer(
|
| 245 |
+
text,
|
| 246 |
+
return_tensors="pt",
|
| 247 |
+
padding="longest",
|
| 248 |
+
max_length=tokenizer.model_max_length,
|
| 249 |
+
truncation=True,
|
| 250 |
+
) for text in strings
|
| 251 |
+
]
|
| 252 |
+
input_ids = labels = [
|
| 253 |
+
tokenized.input_ids[0] for tokenized in tokenized_list
|
| 254 |
+
]
|
| 255 |
+
input_ids_lens = labels_lens = [
|
| 256 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
| 257 |
+
for tokenized in tokenized_list
|
| 258 |
+
]
|
| 259 |
+
return dict(
|
| 260 |
+
input_ids=input_ids,
|
| 261 |
+
labels=labels,
|
| 262 |
+
input_ids_lens=input_ids_lens,
|
| 263 |
+
labels_lens=labels_lens,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _mask_targets(target, tokenized_lens, speakers):
|
| 268 |
+
# cur_idx = 0
|
| 269 |
+
cur_idx = tokenized_lens[0]
|
| 270 |
+
tokenized_lens = tokenized_lens[1:]
|
| 271 |
+
target[:cur_idx] = IGNORE_INDEX
|
| 272 |
+
for tokenized_len, speaker in zip(tokenized_lens, speakers):
|
| 273 |
+
if speaker == "human":
|
| 274 |
+
target[cur_idx + 2:cur_idx + tokenized_len] = IGNORE_INDEX
|
| 275 |
+
cur_idx += tokenized_len
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _add_speaker_and_signal(header, source, get_conversation=True):
|
| 279 |
+
"""Add speaker and start/end signal on each round."""
|
| 280 |
+
BEGIN_SIGNAL = "### "
|
| 281 |
+
END_SIGNAL = "\n"
|
| 282 |
+
conversation = header
|
| 283 |
+
for sentence in source:
|
| 284 |
+
from_str = sentence["from"]
|
| 285 |
+
if from_str.lower() == "human":
|
| 286 |
+
from_str = conversation_lib.default_conversation.roles[0]
|
| 287 |
+
elif from_str.lower() == "gpt":
|
| 288 |
+
from_str = conversation_lib.default_conversation.roles[1]
|
| 289 |
+
else:
|
| 290 |
+
from_str = 'unknown'
|
| 291 |
+
sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
|
| 292 |
+
sentence["value"] + END_SIGNAL)
|
| 293 |
+
if get_conversation:
|
| 294 |
+
conversation += sentence["value"]
|
| 295 |
+
conversation += BEGIN_SIGNAL
|
| 296 |
+
return conversation
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def preprocess_multimodal(
|
| 300 |
+
sources: Sequence[str],
|
| 301 |
+
data_args: DataArguments,
|
| 302 |
+
image_token_num=1
|
| 303 |
+
) -> Dict:
|
| 304 |
+
is_multimodal = data_args.is_multimodal
|
| 305 |
+
if not is_multimodal:
|
| 306 |
+
return sources
|
| 307 |
+
|
| 308 |
+
for source in sources:
|
| 309 |
+
for sentence in source:
|
| 310 |
+
if DEFAULT_IMAGE_TOKEN in sentence['value'] or DEFAULT_VIDEO_TOKEN in sentence['value']:
|
| 311 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN).strip()
|
| 312 |
+
sentence['value'] = sentence['value'].replace('\n' + DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN).strip()
|
| 313 |
+
if sentence['value'].endswith(DEFAULT_IMAGE_TOKEN):
|
| 314 |
+
IMAGE_TOKEN_NUM = sentence['value'].count(DEFAULT_IMAGE_TOKEN)
|
| 315 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM, '').strip()
|
| 316 |
+
sentence['value'] = DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM + sentence['value']
|
| 317 |
+
sentence['value'] = sentence['value'].strip()
|
| 318 |
+
if sentence['value'].endswith(DEFAULT_VIDEO_TOKEN):
|
| 319 |
+
VIDEO_TOKEN_NUM = sentence['value'].count(DEFAULT_VIDEO_TOKEN)
|
| 320 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM, '').strip()
|
| 321 |
+
sentence['value'] = DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM + sentence['value']
|
| 322 |
+
sentence['value'] = sentence['value'].strip()
|
| 323 |
+
|
| 324 |
+
if "mmtag" in conversation_lib.default_conversation.version:
|
| 325 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN,
|
| 326 |
+
'<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
|
| 327 |
+
|
| 328 |
+
IMAGE_TOKEN_NUM = sentence['value'].count(DEFAULT_IMAGE_TOKEN)
|
| 329 |
+
if IMAGE_TOKEN_NUM > MAX_IMAGE_LENGTH:
|
| 330 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM,
|
| 331 |
+
DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH).strip()
|
| 332 |
+
|
| 333 |
+
replace_token, vid_replace_token = DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN * image_token_num
|
| 334 |
+
if data_args.mm_use_im_start_end:
|
| 335 |
+
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
| 336 |
+
vid_replace_token = DEFAULT_VID_START_TOKEN + vid_replace_token + DEFAULT_VID_END_TOKEN
|
| 337 |
+
|
| 338 |
+
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token + '\n')
|
| 339 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_VIDEO_TOKEN, vid_replace_token + '\n')
|
| 340 |
+
sentence['value'] = sentence['value'].replace('\n\n', '\n')
|
| 341 |
+
|
| 342 |
+
return sources
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def preprocess_llama_2(
|
| 346 |
+
sources,
|
| 347 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 348 |
+
has_image: bool = False
|
| 349 |
+
) -> Dict:
|
| 350 |
+
conv = conversation_lib.default_conversation.copy()
|
| 351 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 352 |
+
|
| 353 |
+
# Apply prompt templates
|
| 354 |
+
conversations = []
|
| 355 |
+
for i, source in enumerate(sources):
|
| 356 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 357 |
+
# Skip the first one if it is not from human
|
| 358 |
+
source = source[1:]
|
| 359 |
+
|
| 360 |
+
conv.messages = []
|
| 361 |
+
for j, sentence in enumerate(source):
|
| 362 |
+
role = roles[sentence["from"]]
|
| 363 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 364 |
+
conv.append_message(role, sentence["value"])
|
| 365 |
+
conversations.append(conv.get_prompt())
|
| 366 |
+
|
| 367 |
+
# Tokenize conversations
|
| 368 |
+
|
| 369 |
+
if has_image:
|
| 370 |
+
input_ids = torch.stack(
|
| 371 |
+
[tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
| 372 |
+
else:
|
| 373 |
+
input_ids = tokenizer(
|
| 374 |
+
conversations,
|
| 375 |
+
return_tensors="pt",
|
| 376 |
+
padding="longest",
|
| 377 |
+
max_length=tokenizer.model_max_length,
|
| 378 |
+
truncation=True,
|
| 379 |
+
).input_ids
|
| 380 |
+
|
| 381 |
+
targets = input_ids.clone()
|
| 382 |
+
|
| 383 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
|
| 384 |
+
|
| 385 |
+
# Mask targets
|
| 386 |
+
sep = "[/INST] "
|
| 387 |
+
for conversation, target in zip(conversations, targets):
|
| 388 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
| 389 |
+
|
| 390 |
+
rounds = conversation.split(conv.sep2)
|
| 391 |
+
|
| 392 |
+
cur_len = 1
|
| 393 |
+
target[:cur_len] = IGNORE_INDEX
|
| 394 |
+
|
| 395 |
+
for i, rou in enumerate(rounds):
|
| 396 |
+
if rou == "":
|
| 397 |
+
break
|
| 398 |
+
|
| 399 |
+
parts = rou.split(sep)
|
| 400 |
+
if len(parts) != 2:
|
| 401 |
+
break
|
| 402 |
+
parts[0] += sep
|
| 403 |
+
|
| 404 |
+
if has_image:
|
| 405 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
| 406 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
| 407 |
+
else:
|
| 408 |
+
round_len = len(tokenizer(rou).input_ids)
|
| 409 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
| 410 |
+
|
| 411 |
+
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
|
| 412 |
+
|
| 413 |
+
cur_len += round_len
|
| 414 |
+
|
| 415 |
+
if tokenizer.eos_token == tokenizer.pad_token:
|
| 416 |
+
cur_len += 1
|
| 417 |
+
|
| 418 |
+
target[cur_len:] = IGNORE_INDEX
|
| 419 |
+
|
| 420 |
+
if cur_len < tokenizer.model_max_length:
|
| 421 |
+
if cur_len != total_len:
|
| 422 |
+
target[:] = IGNORE_INDEX
|
| 423 |
+
print(
|
| 424 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 425 |
+
f" (ignored)"
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
return dict(
|
| 429 |
+
input_ids=input_ids,
|
| 430 |
+
labels=targets,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def preprocess_v1(
|
| 435 |
+
sources,
|
| 436 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 437 |
+
has_image: bool = False
|
| 438 |
+
) -> Dict:
|
| 439 |
+
conv = conversation_lib.default_conversation.copy()
|
| 440 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 441 |
+
|
| 442 |
+
# Apply prompt templates
|
| 443 |
+
conversations = []
|
| 444 |
+
for i, source in enumerate(sources):
|
| 445 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 446 |
+
# Skip the first one if it is not from human
|
| 447 |
+
source = source[1:]
|
| 448 |
+
|
| 449 |
+
conv.messages = []
|
| 450 |
+
for j, sentence in enumerate(source):
|
| 451 |
+
role = roles[sentence["from"]]
|
| 452 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 453 |
+
conv.append_message(role, sentence["value"])
|
| 454 |
+
conversations.append(conv.get_prompt())
|
| 455 |
+
|
| 456 |
+
# Tokenize conversations
|
| 457 |
+
if has_image:
|
| 458 |
+
input_ids = torch.stack(
|
| 459 |
+
[tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
| 460 |
+
else:
|
| 461 |
+
input_ids = tokenizer(
|
| 462 |
+
conversations,
|
| 463 |
+
return_tensors="pt",
|
| 464 |
+
padding="longest",
|
| 465 |
+
max_length=tokenizer.model_max_length,
|
| 466 |
+
truncation=True,
|
| 467 |
+
).input_ids
|
| 468 |
+
|
| 469 |
+
targets = input_ids.clone()
|
| 470 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
| 471 |
+
|
| 472 |
+
# Mask targets
|
| 473 |
+
round_len_list = []
|
| 474 |
+
sep = conv.sep + conv.roles[1] + ": "
|
| 475 |
+
for conversation, target in zip(conversations, targets):
|
| 476 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
| 477 |
+
|
| 478 |
+
rounds = conversation.split(conv.sep2)
|
| 479 |
+
cur_len = 1
|
| 480 |
+
target[:cur_len] = IGNORE_INDEX
|
| 481 |
+
for i, rou in enumerate(rounds):
|
| 482 |
+
if rou == "":
|
| 483 |
+
break
|
| 484 |
+
|
| 485 |
+
parts = rou.split(sep)
|
| 486 |
+
if len(parts) != 2:
|
| 487 |
+
break
|
| 488 |
+
parts[0] += sep
|
| 489 |
+
|
| 490 |
+
if has_image:
|
| 491 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
| 492 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
| 493 |
+
else:
|
| 494 |
+
round_len = len(tokenizer(rou).input_ids)
|
| 495 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
| 496 |
+
|
| 497 |
+
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
|
| 498 |
+
# print("rou:", rou)
|
| 499 |
+
# print(round_len, instruction_len)
|
| 500 |
+
# print(len(tokenizer(rou).input_ids), len(tokenizer_image_token(rou, tokenizer)))
|
| 501 |
+
cur_len += round_len
|
| 502 |
+
round_len_list.append(round_len)
|
| 503 |
+
target[cur_len:] = IGNORE_INDEX
|
| 504 |
+
|
| 505 |
+
if cur_len < tokenizer.model_max_length:
|
| 506 |
+
if cur_len != total_len:
|
| 507 |
+
# print(conversations, target, round_len_list)
|
| 508 |
+
target[:] = IGNORE_INDEX
|
| 509 |
+
print(
|
| 510 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 511 |
+
f" (ignored)"
|
| 512 |
+
)
|
| 513 |
+
# exit()
|
| 514 |
+
# print("ok", conversations, target, round_len_list)
|
| 515 |
+
return dict(
|
| 516 |
+
input_ids=input_ids,
|
| 517 |
+
labels=targets,
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def preprocess_mpt(
|
| 522 |
+
sources,
|
| 523 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 524 |
+
) -> Dict:
|
| 525 |
+
conv = conversation_lib.default_conversation.copy()
|
| 526 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 527 |
+
|
| 528 |
+
# Apply prompt templates
|
| 529 |
+
conversations = []
|
| 530 |
+
for i, source in enumerate(sources):
|
| 531 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 532 |
+
# Skip the first one if it is not from human
|
| 533 |
+
source = source[1:]
|
| 534 |
+
|
| 535 |
+
conv.messages = []
|
| 536 |
+
for j, sentence in enumerate(source):
|
| 537 |
+
role = roles[sentence["from"]]
|
| 538 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 539 |
+
conv.append_message(role, sentence["value"])
|
| 540 |
+
conversations.append(conv.get_prompt())
|
| 541 |
+
|
| 542 |
+
# Tokenize conversations
|
| 543 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations],
|
| 544 |
+
dim=0)
|
| 545 |
+
targets = input_ids.clone()
|
| 546 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
|
| 547 |
+
|
| 548 |
+
# Mask targets
|
| 549 |
+
sep = conv.sep + conv.roles[1]
|
| 550 |
+
for conversation, target in zip(conversations, targets):
|
| 551 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
| 552 |
+
|
| 553 |
+
rounds = conversation.split(conv.sep)
|
| 554 |
+
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
|
| 555 |
+
for conv_idx in range(3, len(rounds), 2):
|
| 556 |
+
re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2])) # user + gpt
|
| 557 |
+
cur_len = 0
|
| 558 |
+
target[:cur_len] = IGNORE_INDEX
|
| 559 |
+
for i, rou in enumerate(re_rounds):
|
| 560 |
+
if rou == "":
|
| 561 |
+
break
|
| 562 |
+
|
| 563 |
+
parts = rou.split(sep)
|
| 564 |
+
if len(parts) != 2:
|
| 565 |
+
break
|
| 566 |
+
parts[0] += sep
|
| 567 |
+
round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
|
| 568 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
|
| 569 |
+
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
|
| 570 |
+
|
| 571 |
+
cur_len += round_len
|
| 572 |
+
target[cur_len:] = IGNORE_INDEX
|
| 573 |
+
|
| 574 |
+
if cur_len < tokenizer.model_max_length:
|
| 575 |
+
if cur_len != total_len:
|
| 576 |
+
target[:] = IGNORE_INDEX
|
| 577 |
+
print(
|
| 578 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 579 |
+
f" (ignored)"
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
return dict(
|
| 583 |
+
input_ids=input_ids,
|
| 584 |
+
labels=targets,
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def preprocess_plain(
|
| 589 |
+
sources: Sequence[str],
|
| 590 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 591 |
+
) -> Dict:
|
| 592 |
+
# add end signal and concatenate together
|
| 593 |
+
conversations = []
|
| 594 |
+
for source in sources:
|
| 595 |
+
assert len(source) == 2
|
| 596 |
+
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
|
| 597 |
+
source[0]['value'] = DEFAULT_IMAGE_TOKEN
|
| 598 |
+
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
|
| 599 |
+
conversations.append(conversation)
|
| 600 |
+
# tokenize conversations
|
| 601 |
+
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
| 602 |
+
targets = copy.deepcopy(input_ids)
|
| 603 |
+
for target, source in zip(targets, sources):
|
| 604 |
+
tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
|
| 605 |
+
target[:tokenized_len] = IGNORE_INDEX
|
| 606 |
+
|
| 607 |
+
return dict(input_ids=input_ids, labels=targets)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def preprocess_phi(
|
| 611 |
+
sources,
|
| 612 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 613 |
+
has_image: bool = False
|
| 614 |
+
) -> Dict:
|
| 615 |
+
conv = conversation_lib.default_conversation.copy()
|
| 616 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 617 |
+
|
| 618 |
+
# Apply prompt templates
|
| 619 |
+
conversations = []
|
| 620 |
+
for i, source in enumerate(sources):
|
| 621 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
| 622 |
+
# Skip the first one if it is not from human
|
| 623 |
+
source = source[1:]
|
| 624 |
+
|
| 625 |
+
conv.messages = []
|
| 626 |
+
for j, sentence in enumerate(source):
|
| 627 |
+
role = roles[sentence["from"]]
|
| 628 |
+
assert role == conv.roles[j % 2], f"{i}"
|
| 629 |
+
conv.append_message(role, sentence["value"])
|
| 630 |
+
conversations.append(conv.get_prompt())
|
| 631 |
+
|
| 632 |
+
# Tokenize conversations
|
| 633 |
+
if has_image:
|
| 634 |
+
input_ids = torch.stack(
|
| 635 |
+
[tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
| 636 |
+
else:
|
| 637 |
+
input_ids = tokenizer(
|
| 638 |
+
conversations,
|
| 639 |
+
return_tensors="pt",
|
| 640 |
+
padding="longest",
|
| 641 |
+
max_length=tokenizer.model_max_length,
|
| 642 |
+
truncation=True,
|
| 643 |
+
).input_ids
|
| 644 |
+
|
| 645 |
+
targets = input_ids.clone()
|
| 646 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
| 647 |
+
|
| 648 |
+
# Mask targets
|
| 649 |
+
round_len_list = []
|
| 650 |
+
sep = conv.sep + conv.roles[1] + ": "
|
| 651 |
+
for conversation, target in zip(conversations, targets):
|
| 652 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
| 653 |
+
|
| 654 |
+
rounds = conversation.split(conv.sep2)
|
| 655 |
+
cur_len = 0
|
| 656 |
+
pre_len = 0
|
| 657 |
+
for i, rou in enumerate(rounds):
|
| 658 |
+
if rou == "":
|
| 659 |
+
break
|
| 660 |
+
|
| 661 |
+
parts = rou.split(sep)
|
| 662 |
+
if len(parts) != 2:
|
| 663 |
+
break
|
| 664 |
+
parts[0] += sep
|
| 665 |
+
|
| 666 |
+
cur_len += 1
|
| 667 |
+
target[pre_len: cur_len] = IGNORE_INDEX
|
| 668 |
+
|
| 669 |
+
if has_image:
|
| 670 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
| 671 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
| 672 |
+
else:
|
| 673 |
+
round_len = len(tokenizer(rou).input_ids)
|
| 674 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
| 675 |
+
|
| 676 |
+
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
|
| 677 |
+
# print("rou:", rou)
|
| 678 |
+
# print(round_len, instruction_len)
|
| 679 |
+
# print(len(tokenizer(rou).input_ids), len(tokenizer_image_token(rou, tokenizer)))
|
| 680 |
+
cur_len += round_len
|
| 681 |
+
pre_len = cur_len
|
| 682 |
+
round_len_list.append(round_len)
|
| 683 |
+
target[cur_len:] = IGNORE_INDEX
|
| 684 |
+
|
| 685 |
+
if cur_len < tokenizer.model_max_length:
|
| 686 |
+
if cur_len != total_len + len(rounds) - 1:
|
| 687 |
+
# print(conversations, target, round_len_list)
|
| 688 |
+
target[:] = IGNORE_INDEX
|
| 689 |
+
print(
|
| 690 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
| 691 |
+
f" (ignored)"
|
| 692 |
+
)
|
| 693 |
+
# exit()
|
| 694 |
+
# print("ok", conversations, target, round_len_list)
|
| 695 |
+
return dict(
|
| 696 |
+
input_ids=input_ids,
|
| 697 |
+
labels=targets,
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def preprocess(
|
| 702 |
+
sources: Sequence[str],
|
| 703 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 704 |
+
has_image: bool = False
|
| 705 |
+
) -> Dict:
|
| 706 |
+
"""
|
| 707 |
+
Given a list of sources, each is a conversation list. This transform:
|
| 708 |
+
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
| 709 |
+
2. Concatenate conversations together;
|
| 710 |
+
3. Tokenize the concatenated conversation;
|
| 711 |
+
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
| 712 |
+
"""
|
| 713 |
+
if conversation_lib.default_conversation.version.startswith("phi"):
|
| 714 |
+
return preprocess_phi(sources, tokenizer, has_image=has_image)
|
| 715 |
+
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
|
| 716 |
+
return preprocess_plain(sources, tokenizer)
|
| 717 |
+
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
|
| 718 |
+
return preprocess_llama_2(sources, tokenizer, has_image=has_image)
|
| 719 |
+
if conversation_lib.default_conversation.version.startswith("v1"):
|
| 720 |
+
return preprocess_v1(sources, tokenizer, has_image=has_image)
|
| 721 |
+
if conversation_lib.default_conversation.version == "mpt":
|
| 722 |
+
return preprocess_mpt(sources, tokenizer)
|
| 723 |
+
# add end signal and concatenate together
|
| 724 |
+
conversations = []
|
| 725 |
+
for source in sources:
|
| 726 |
+
header = f"{conversation_lib.default_conversation.system}\n\n"
|
| 727 |
+
conversation = _add_speaker_and_signal(header, source)
|
| 728 |
+
conversations.append(conversation)
|
| 729 |
+
|
| 730 |
+
# tokenize conversations
|
| 731 |
+
def get_tokenize_len(prompts):
|
| 732 |
+
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
|
| 733 |
+
|
| 734 |
+
if has_image:
|
| 735 |
+
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
| 736 |
+
else:
|
| 737 |
+
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
|
| 738 |
+
input_ids = conversations_tokenized["input_ids"]
|
| 739 |
+
|
| 740 |
+
targets = copy.deepcopy(input_ids)
|
| 741 |
+
for target, source in zip(targets, sources):
|
| 742 |
+
if has_image:
|
| 743 |
+
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
|
| 744 |
+
else:
|
| 745 |
+
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
|
| 746 |
+
speakers = [sentence["from"] for sentence in source]
|
| 747 |
+
_mask_targets(target, tokenized_lens, speakers)
|
| 748 |
+
|
| 749 |
+
return dict(input_ids=input_ids, labels=targets)
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
class LazySupervisedDataset(Dataset):
|
| 753 |
+
"""Dataset for supervised fine-tuning."""
|
| 754 |
+
|
| 755 |
+
def __init__(self, tokenizer: transformers.PreTrainedTokenizer,
|
| 756 |
+
data_args: DataArguments):
|
| 757 |
+
super(LazySupervisedDataset, self).__init__()
|
| 758 |
+
|
| 759 |
+
dataset_list = DataConfig[str(data_args.dataset_use)]
|
| 760 |
+
print(dataset_list)
|
| 761 |
+
|
| 762 |
+
self.max_length = MAX_IMAGE_LENGTH
|
| 763 |
+
list_data_dict = []
|
| 764 |
+
self.folder_dict = {}
|
| 765 |
+
for i in dataset_list:
|
| 766 |
+
list_data_dict += json.load(open(i["chat_path"], "r"))
|
| 767 |
+
|
| 768 |
+
image_folder = [folder for folder in i if folder is not "chat_path"]
|
| 769 |
+
|
| 770 |
+
for folder in image_folder:
|
| 771 |
+
if folder not in self.folder_dict:
|
| 772 |
+
self.folder_dict[folder] = i[folder]
|
| 773 |
+
|
| 774 |
+
random.shuffle(list_data_dict)
|
| 775 |
+
|
| 776 |
+
rank0_print("Formatting inputs...Skip in lazy mode")
|
| 777 |
+
self.tokenizer = tokenizer
|
| 778 |
+
self.list_data_dict = list_data_dict
|
| 779 |
+
self.data_args = data_args
|
| 780 |
+
|
| 781 |
+
def __len__(self):
|
| 782 |
+
return len(self.list_data_dict)
|
| 783 |
+
|
| 784 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
| 785 |
+
sources = self.list_data_dict[i]
|
| 786 |
+
if isinstance(i, int):
|
| 787 |
+
sources = [sources]
|
| 788 |
+
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
|
| 789 |
+
if 'image' in sources[0]:
|
| 790 |
+
image_file = self.list_data_dict[i]['image']
|
| 791 |
+
|
| 792 |
+
file = image_file[0] if type(image_file) is list else image_file
|
| 793 |
+
|
| 794 |
+
if "llava_image" in file:
|
| 795 |
+
image_folder = self.folder_dict['llava']
|
| 796 |
+
elif "\\" in file:
|
| 797 |
+
image_folder = self.folder_dict['ScienceQA']
|
| 798 |
+
elif "CGD" in file:
|
| 799 |
+
image_folder = self.folder_dict['CDG']
|
| 800 |
+
elif "DC" in file:
|
| 801 |
+
image_folder = self.folder_dict['DC']
|
| 802 |
+
elif "LA" in file:
|
| 803 |
+
image_folder = self.folder_dict['LA']
|
| 804 |
+
elif "SD" in file:
|
| 805 |
+
image_folder = self.folder_dict['SD']
|
| 806 |
+
elif "SN" in file:
|
| 807 |
+
image_folder = self.folder_dict['SN']
|
| 808 |
+
elif "TVC" in file:
|
| 809 |
+
image_folder = self.folder_dict['TVC']
|
| 810 |
+
elif "VST" in file:
|
| 811 |
+
image_folder = self.folder_dict['VST']
|
| 812 |
+
elif "GCC" in file:
|
| 813 |
+
image_folder = self.folder_dict['CC3M']
|
| 814 |
+
elif "COCO_train2014" in file:
|
| 815 |
+
image_folder = self.folder_dict['COCO2014']
|
| 816 |
+
else:
|
| 817 |
+
image_folder = self.folder_dict['COCO2017']
|
| 818 |
+
|
| 819 |
+
processor = self.data_args.image_processor
|
| 820 |
+
|
| 821 |
+
if type(image_file) is list:
|
| 822 |
+
image = [Image.open(os.path.join(image_folder, file.replace("\\", "/"))).convert('RGB') for file in
|
| 823 |
+
image_file]
|
| 824 |
+
if self.data_args.image_aspect_ratio == 'pad':
|
| 825 |
+
def expand2square(pil_img, background_color):
|
| 826 |
+
width, height = pil_img.size
|
| 827 |
+
if width == height:
|
| 828 |
+
return pil_img
|
| 829 |
+
elif width > height:
|
| 830 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 831 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 832 |
+
return result
|
| 833 |
+
else:
|
| 834 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 835 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 836 |
+
return result
|
| 837 |
+
|
| 838 |
+
image = [expand2square(i, tuple(int(x * 255) for x in processor.image_mean)) for i in image]
|
| 839 |
+
image = [processor.preprocess(i, return_tensors='pt')['pixel_values'][0] for i in image]
|
| 840 |
+
else:
|
| 841 |
+
image = [processor.preprocess(i, return_tensors='pt')['pixel_values'][0] for i in image]
|
| 842 |
+
else:
|
| 843 |
+
image = Image.open(os.path.join(image_folder, image_file.replace("\\", "/"))).convert('RGB')
|
| 844 |
+
if self.data_args.image_aspect_ratio == 'pad':
|
| 845 |
+
def expand2square(pil_img, background_color):
|
| 846 |
+
width, height = pil_img.size
|
| 847 |
+
if width == height:
|
| 848 |
+
return pil_img
|
| 849 |
+
elif width > height:
|
| 850 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 851 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 852 |
+
return result
|
| 853 |
+
else:
|
| 854 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 855 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 856 |
+
return result
|
| 857 |
+
|
| 858 |
+
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
|
| 859 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
| 860 |
+
else:
|
| 861 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
| 862 |
+
|
| 863 |
+
sources = preprocess_multimodal(
|
| 864 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
| 865 |
+
self.data_args)
|
| 866 |
+
|
| 867 |
+
data_dict = preprocess(
|
| 868 |
+
sources,
|
| 869 |
+
self.tokenizer,
|
| 870 |
+
has_image=True)
|
| 871 |
+
|
| 872 |
+
elif "video" in sources[0]:
|
| 873 |
+
video_file = self.list_data_dict[i]['video']
|
| 874 |
+
|
| 875 |
+
if "valley" in video_file:
|
| 876 |
+
video_folder = self.folder_dict['valley']
|
| 877 |
+
else:
|
| 878 |
+
video_folder = self.folder_dict['VIDEO']
|
| 879 |
+
processor = self.data_args.image_processor
|
| 880 |
+
|
| 881 |
+
if os.path.exists(os.path.join(video_folder, video_file)):
|
| 882 |
+
image, image_token_num = _get_rawvideo_dec(os.path.join(video_folder, video_file), processor,
|
| 883 |
+
max_frames=MAX_IMAGE_LENGTH)
|
| 884 |
+
flag = 0
|
| 885 |
+
else:
|
| 886 |
+
crop_size = self.data_args.image_processor.crop_size
|
| 887 |
+
image, image_token_num = torch.zeros(3, crop_size['height'], crop_size['width']), 1
|
| 888 |
+
flag = 1
|
| 889 |
+
|
| 890 |
+
sources = preprocess_multimodal(
|
| 891 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
| 892 |
+
self.data_args, image_token_num=image_token_num)
|
| 893 |
+
|
| 894 |
+
data_dict = preprocess(
|
| 895 |
+
sources,
|
| 896 |
+
self.tokenizer,
|
| 897 |
+
has_image=True)
|
| 898 |
+
|
| 899 |
+
if flag:
|
| 900 |
+
data_dict["labels"][:] = IGNORE_INDEX
|
| 901 |
+
print(
|
| 902 |
+
f"WARNING: video load failed: {os.path.join(video_folder, video_file)}."
|
| 903 |
+
f" (ignored)"
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
else:
|
| 907 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
| 908 |
+
|
| 909 |
+
data_dict = preprocess(
|
| 910 |
+
sources,
|
| 911 |
+
self.tokenizer,
|
| 912 |
+
has_image=False)
|
| 913 |
+
|
| 914 |
+
if isinstance(i, int):
|
| 915 |
+
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
| 916 |
+
labels=data_dict["labels"][0])
|
| 917 |
+
|
| 918 |
+
# image exist in the data
|
| 919 |
+
if 'image' in self.list_data_dict[i] or 'video' in self.list_data_dict[i]:
|
| 920 |
+
data_dict['image'] = image
|
| 921 |
+
elif self.data_args.is_multimodal:
|
| 922 |
+
# image does not exist in the data, but the model is multimodal
|
| 923 |
+
crop_size = self.data_args.image_processor.crop_size
|
| 924 |
+
data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
|
| 925 |
+
return data_dict
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
@dataclass
|
| 929 |
+
class DataCollatorForSupervisedDataset(object):
|
| 930 |
+
"""Collate examples for supervised fine-tuning."""
|
| 931 |
+
|
| 932 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 933 |
+
|
| 934 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 935 |
+
input_ids, labels = tuple([instance[key] for instance in instances]
|
| 936 |
+
for key in ("input_ids", "labels"))
|
| 937 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
| 938 |
+
input_ids,
|
| 939 |
+
batch_first=True,
|
| 940 |
+
padding_value=self.tokenizer.pad_token_id)
|
| 941 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
| 942 |
+
batch_first=True,
|
| 943 |
+
padding_value=IGNORE_INDEX)
|
| 944 |
+
input_ids = input_ids[:, :self.tokenizer.model_max_length]
|
| 945 |
+
labels = labels[:, :self.tokenizer.model_max_length]
|
| 946 |
+
batch = dict(
|
| 947 |
+
input_ids=input_ids,
|
| 948 |
+
labels=labels,
|
| 949 |
+
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
| 950 |
+
)
|
| 951 |
+
|
| 952 |
+
if 'image' in instances[0]:
|
| 953 |
+
images = [instance['image'] for instance in instances]
|
| 954 |
+
|
| 955 |
+
new_images = []
|
| 956 |
+
for image in images:
|
| 957 |
+
if type(image) is list:
|
| 958 |
+
for i in image:
|
| 959 |
+
new_images.append(i)
|
| 960 |
+
else:
|
| 961 |
+
new_images.append(image)
|
| 962 |
+
images = new_images
|
| 963 |
+
|
| 964 |
+
if all(x is not None and x.shape == images[0].shape for x in images):
|
| 965 |
+
batch['images'] = torch.stack(images)
|
| 966 |
+
else:
|
| 967 |
+
batch['images'] = images
|
| 968 |
+
|
| 969 |
+
return batch
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
|
| 973 |
+
data_args) -> Dict:
|
| 974 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 975 |
+
train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)
|
| 976 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
| 977 |
+
return dict(train_dataset=train_dataset,
|
| 978 |
+
eval_dataset=None,
|
| 979 |
+
data_collator=data_collator)
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
def train():
|
| 983 |
+
global local_rank
|
| 984 |
+
|
| 985 |
+
parser = transformers.HfArgumentParser(
|
| 986 |
+
(ModelArguments, DataArguments, TrainingArguments))
|
| 987 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 988 |
+
local_rank = training_args.local_rank
|
| 989 |
+
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 990 |
+
|
| 991 |
+
random.seed(training_args.seed)
|
| 992 |
+
os.environ['PYTHONHASHSEED'] = str(training_args.seed)
|
| 993 |
+
np.random.seed(training_args.seed)
|
| 994 |
+
torch.manual_seed(training_args.seed)
|
| 995 |
+
torch.cuda.manual_seed(training_args.seed)
|
| 996 |
+
torch.cuda.manual_seed_all(training_args.seed) # if you are using multi-GPU.
|
| 997 |
+
torch.backends.cudnn.benchmark = False
|
| 998 |
+
torch.backends.cudnn.deterministic = True
|
| 999 |
+
|
| 1000 |
+
bnb_model_from_pretrained_args = {}
|
| 1001 |
+
if training_args.bits in [4, 8]:
|
| 1002 |
+
from transformers import BitsAndBytesConfig
|
| 1003 |
+
bnb_model_from_pretrained_args.update(dict(
|
| 1004 |
+
device_map={"": training_args.device},
|
| 1005 |
+
load_in_4bit=training_args.bits == 4,
|
| 1006 |
+
load_in_8bit=training_args.bits == 8,
|
| 1007 |
+
quantization_config=BitsAndBytesConfig(
|
| 1008 |
+
load_in_4bit=training_args.bits == 4,
|
| 1009 |
+
load_in_8bit=training_args.bits == 8,
|
| 1010 |
+
llm_int8_threshold=6.0,
|
| 1011 |
+
llm_int8_has_fp16_weight=False,
|
| 1012 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
| 1013 |
+
bnb_4bit_use_double_quant=training_args.double_quant,
|
| 1014 |
+
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
|
| 1015 |
+
)
|
| 1016 |
+
))
|
| 1017 |
+
|
| 1018 |
+
if model_args.vision_tower is not None:
|
| 1019 |
+
if "phi" in model_args.model_name_or_path.lower():
|
| 1020 |
+
from ChatUniVi.model.language_model.phi import ChatUniViPhiForCausalLM
|
| 1021 |
+
model = ChatUniViPhiForCausalLM.from_pretrained(
|
| 1022 |
+
model_args.model_name_or_path,
|
| 1023 |
+
cache_dir=training_args.cache_dir,
|
| 1024 |
+
**bnb_model_from_pretrained_args
|
| 1025 |
+
)
|
| 1026 |
+
else:
|
| 1027 |
+
model = ChatUniViLlamaForCausalLM.from_pretrained(
|
| 1028 |
+
model_args.model_name_or_path,
|
| 1029 |
+
cache_dir=training_args.cache_dir,
|
| 1030 |
+
**bnb_model_from_pretrained_args
|
| 1031 |
+
)
|
| 1032 |
+
else:
|
| 1033 |
+
model = transformers.LlamaForCausalLM.from_pretrained(
|
| 1034 |
+
model_args.model_name_or_path,
|
| 1035 |
+
cache_dir=training_args.cache_dir,
|
| 1036 |
+
**bnb_model_from_pretrained_args
|
| 1037 |
+
)
|
| 1038 |
+
model.config.use_cache = False
|
| 1039 |
+
|
| 1040 |
+
if model_args.freeze_backbone:
|
| 1041 |
+
model.model.requires_grad_(False)
|
| 1042 |
+
|
| 1043 |
+
if training_args.bits in [4, 8]:
|
| 1044 |
+
from peft import prepare_model_for_kbit_training
|
| 1045 |
+
model.config.torch_dtype = (
|
| 1046 |
+
torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 1047 |
+
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
|
| 1048 |
+
|
| 1049 |
+
if training_args.gradient_checkpointing:
|
| 1050 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 1051 |
+
model.enable_input_require_grads()
|
| 1052 |
+
else:
|
| 1053 |
+
def make_inputs_require_grad(module, input, output):
|
| 1054 |
+
output.requires_grad_(True)
|
| 1055 |
+
|
| 1056 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 1057 |
+
|
| 1058 |
+
if training_args.lora_enable:
|
| 1059 |
+
from peft import LoraConfig, get_peft_model
|
| 1060 |
+
lora_config = LoraConfig(
|
| 1061 |
+
r=training_args.lora_r,
|
| 1062 |
+
lora_alpha=training_args.lora_alpha,
|
| 1063 |
+
target_modules=find_all_linear_names(model),
|
| 1064 |
+
lora_dropout=training_args.lora_dropout,
|
| 1065 |
+
bias=training_args.lora_bias,
|
| 1066 |
+
task_type="CAUSAL_LM",
|
| 1067 |
+
)
|
| 1068 |
+
if training_args.bits == 16:
|
| 1069 |
+
if training_args.bf16:
|
| 1070 |
+
model.to(torch.bfloat16)
|
| 1071 |
+
if training_args.fp16:
|
| 1072 |
+
model.to(torch.float16)
|
| 1073 |
+
rank0_print("Adding LoRA adapters...")
|
| 1074 |
+
model = get_peft_model(model, lora_config)
|
| 1075 |
+
|
| 1076 |
+
if 'mpt' in model_args.model_name_or_path:
|
| 1077 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 1078 |
+
model_args.model_name_or_path,
|
| 1079 |
+
cache_dir=training_args.cache_dir,
|
| 1080 |
+
model_max_length=training_args.model_max_length,
|
| 1081 |
+
padding_side="right"
|
| 1082 |
+
)
|
| 1083 |
+
else:
|
| 1084 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 1085 |
+
model_args.model_name_or_path,
|
| 1086 |
+
cache_dir=training_args.cache_dir,
|
| 1087 |
+
model_max_length=training_args.model_max_length,
|
| 1088 |
+
padding_side="right",
|
| 1089 |
+
use_fast=True,
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
if model_args.version == "v0":
|
| 1093 |
+
if tokenizer.pad_token is None:
|
| 1094 |
+
smart_tokenizer_and_embedding_resize(
|
| 1095 |
+
special_tokens_dict=dict(pad_token="[PAD]"),
|
| 1096 |
+
tokenizer=tokenizer,
|
| 1097 |
+
model=model,
|
| 1098 |
+
)
|
| 1099 |
+
if "llama" in model_args.model_name_or_path.lower():
|
| 1100 |
+
tokenizer.add_special_tokens({
|
| 1101 |
+
"eos_token": "</s>",
|
| 1102 |
+
"bos_token": "<s>",
|
| 1103 |
+
"unk_token": "<unk>",
|
| 1104 |
+
})
|
| 1105 |
+
elif model_args.version == "v0.5":
|
| 1106 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 1107 |
+
elif model_args.version == "phi":
|
| 1108 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 1109 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
|
| 1110 |
+
else:
|
| 1111 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 1112 |
+
if model_args.version in conversation_lib.conv_templates:
|
| 1113 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
|
| 1114 |
+
else:
|
| 1115 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
|
| 1116 |
+
|
| 1117 |
+
if model_args.vision_tower is not None:
|
| 1118 |
+
model.get_model().initialize_vision_modules(
|
| 1119 |
+
model_args=model_args,
|
| 1120 |
+
fsdp=training_args.fsdp
|
| 1121 |
+
)
|
| 1122 |
+
|
| 1123 |
+
vision_tower = model.get_vision_tower()
|
| 1124 |
+
vision_tower.to(dtype=torch.float16, device=training_args.device)
|
| 1125 |
+
|
| 1126 |
+
data_args.image_processor = vision_tower.image_processor
|
| 1127 |
+
data_args.is_multimodal = True
|
| 1128 |
+
|
| 1129 |
+
model.config.image_aspect_ratio = data_args.image_aspect_ratio
|
| 1130 |
+
model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
|
| 1131 |
+
|
| 1132 |
+
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
|
| 1133 |
+
if model_args.tune_mm_mlp_adapter:
|
| 1134 |
+
model.requires_grad_(False)
|
| 1135 |
+
for p in model.get_model().mm_projector.parameters():
|
| 1136 |
+
p.requires_grad = True
|
| 1137 |
+
|
| 1138 |
+
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
|
| 1139 |
+
if training_args.freeze_mm_mlp_adapter:
|
| 1140 |
+
for p in model.get_model().mm_projector.parameters():
|
| 1141 |
+
p.requires_grad = False
|
| 1142 |
+
|
| 1143 |
+
if training_args.bits in [4, 8]:
|
| 1144 |
+
model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
|
| 1145 |
+
|
| 1146 |
+
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
|
| 1147 |
+
training_args.use_im_start_end = model_args.mm_use_im_start_end
|
| 1148 |
+
|
| 1149 |
+
model.config.mm_use_box_start_end = data_args.mm_use_box_start_end = model_args.mm_use_box_start_end
|
| 1150 |
+
training_args.use_im_start_end = model_args.mm_use_box_start_end
|
| 1151 |
+
|
| 1152 |
+
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
|
| 1153 |
+
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
|
| 1154 |
+
|
| 1155 |
+
model_config = ModelConfig[str(model_args.model_use)]
|
| 1156 |
+
model.config.aarchitectures = "LlavaLlamaForCausalLM"
|
| 1157 |
+
|
| 1158 |
+
model.config.config = model_config
|
| 1159 |
+
model_args.use_cluster = model_config["use_cluster"]
|
| 1160 |
+
model_args.spatial_cluster_rate0 = model_config["spatial_cluster_rate0"]
|
| 1161 |
+
model_args.spatial_cluster_rate1 = model_config["spatial_cluster_rate1"]
|
| 1162 |
+
model_args.spatial_cluster_rate2 = model_config["spatial_cluster_rate2"]
|
| 1163 |
+
model_args.temporal_cluster_rate = model_config.get("temporal_cluster_rate", 1 / 16)
|
| 1164 |
+
model.get_model().initialize_cluster_modules(model_args)
|
| 1165 |
+
|
| 1166 |
+
if model_args.use_cluster:
|
| 1167 |
+
for n, p in model.named_parameters():
|
| 1168 |
+
if "block" in n or "ctm" in n:
|
| 1169 |
+
p.requires_grad = True
|
| 1170 |
+
|
| 1171 |
+
if model.config.config["freeze"]:
|
| 1172 |
+
for n, p in model.named_parameters():
|
| 1173 |
+
if "block" not in n and "ctm" not in n:
|
| 1174 |
+
p.requires_grad = False
|
| 1175 |
+
|
| 1176 |
+
if model.config.config["mm_tune"]:
|
| 1177 |
+
for p in model.get_model().mm_projector.parameters():
|
| 1178 |
+
p.requires_grad = True
|
| 1179 |
+
|
| 1180 |
+
model_args.vision_tune = model_config["vision_tune"]
|
| 1181 |
+
for p in model.get_vision_tower().parameters():
|
| 1182 |
+
p.requires_grad = model_args.vision_tune
|
| 1183 |
+
|
| 1184 |
+
params_need_grad = [n for n, p in model.named_parameters() if p.requires_grad]
|
| 1185 |
+
print("Parameters require gradients: {}".format(params_need_grad))
|
| 1186 |
+
|
| 1187 |
+
if training_args.bits in [4, 8]:
|
| 1188 |
+
from peft.tuners.lora import LoraLayer
|
| 1189 |
+
for name, module in model.named_modules():
|
| 1190 |
+
if isinstance(module, LoraLayer):
|
| 1191 |
+
if training_args.bf16:
|
| 1192 |
+
module = module.to(torch.bfloat16)
|
| 1193 |
+
if 'norm' in name:
|
| 1194 |
+
module = module.to(torch.float32)
|
| 1195 |
+
if 'lm_head' in name or 'embed_tokens' in name:
|
| 1196 |
+
if hasattr(module, 'weight'):
|
| 1197 |
+
if training_args.bf16 and module.weight.dtype == torch.float32:
|
| 1198 |
+
module = module.to(torch.bfloat16)
|
| 1199 |
+
|
| 1200 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer,
|
| 1201 |
+
data_args=data_args)
|
| 1202 |
+
|
| 1203 |
+
trainer = ChatUniViTrainer(model=model,
|
| 1204 |
+
tokenizer=tokenizer,
|
| 1205 |
+
args=training_args,
|
| 1206 |
+
**data_module)
|
| 1207 |
+
|
| 1208 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 1209 |
+
trainer.train(resume_from_checkpoint=True)
|
| 1210 |
+
else:
|
| 1211 |
+
trainer.train()
|
| 1212 |
+
|
| 1213 |
+
model.config.use_cache = True
|
| 1214 |
+
|
| 1215 |
+
if training_args.lora_enable:
|
| 1216 |
+
state_dict = get_peft_state_maybe_zero_3(
|
| 1217 |
+
model.named_parameters(), training_args.lora_bias
|
| 1218 |
+
)
|
| 1219 |
+
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
|
| 1220 |
+
model.named_parameters()
|
| 1221 |
+
)
|
| 1222 |
+
if training_args.local_rank == 0 or training_args.local_rank == -1:
|
| 1223 |
+
model.config.save_pretrained(training_args.output_dir)
|
| 1224 |
+
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
|
| 1225 |
+
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
|
| 1226 |
+
else:
|
| 1227 |
+
safe_save_model_for_hf_trainer(trainer=trainer,
|
| 1228 |
+
output_dir=training_args.output_dir)
|
| 1229 |
+
|
| 1230 |
+
|
| 1231 |
+
if __name__ == "__main__":
|
| 1232 |
+
train()
|
ChatUniVi/train/train_mem.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
|
| 2 |
+
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
|
| 3 |
+
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
|
| 4 |
+
|
| 5 |
+
# Need to call this before importing transformers.
|
| 6 |
+
from ChatUniVi.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
| 7 |
+
|
| 8 |
+
replace_llama_attn_with_flash_attn()
|
| 9 |
+
|
| 10 |
+
from ChatUniVi.train.train import train
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
train()
|
ChatUniVi/train/trainer.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Trainer
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
| 8 |
+
from deepspeed import zero
|
| 9 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
| 10 |
+
if hasattr(param, "ds_id"):
|
| 11 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
|
| 12 |
+
if not ignore_status:
|
| 13 |
+
print(name, 'no ignore status')
|
| 14 |
+
with zero.GatheredParameters([param]):
|
| 15 |
+
param = param.data.detach().cpu().clone()
|
| 16 |
+
else:
|
| 17 |
+
param = param.detach().cpu().clone()
|
| 18 |
+
return param
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
| 22 |
+
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
|
| 23 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
|
| 24 |
+
return to_return
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ChatUniViTrainer(Trainer):
|
| 28 |
+
def _save_checkpoint(self, model, trial, metrics=None):
|
| 29 |
+
if 0 and getattr(self.args, 'tune_mm_mlp_adapter', False):
|
| 30 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 31 |
+
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
| 32 |
+
|
| 33 |
+
run_dir = self._get_output_dir(trial=trial)
|
| 34 |
+
output_dir = os.path.join(run_dir, checkpoint_folder)
|
| 35 |
+
|
| 36 |
+
# Only save Adapter
|
| 37 |
+
keys_to_match = ['mm_projector', "ctm", "block"]
|
| 38 |
+
if getattr(self.args, "use_im_start_end", False):
|
| 39 |
+
keys_to_match.extend(['embed_tokens', 'embed_in'])
|
| 40 |
+
|
| 41 |
+
weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
|
| 42 |
+
|
| 43 |
+
if self.args.local_rank == 0 or self.args.local_rank == -1:
|
| 44 |
+
self.model.config.save_pretrained(output_dir)
|
| 45 |
+
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
|
| 46 |
+
else:
|
| 47 |
+
super(ChatUniViTrainer, self)._save_checkpoint(model, trial, metrics)
|
| 48 |
+
|
| 49 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
| 50 |
+
if 0 and getattr(self.args, 'tune_mm_mlp_adapter', False):
|
| 51 |
+
pass
|
| 52 |
+
else:
|
| 53 |
+
super(ChatUniViTrainer, self)._save(output_dir, state_dict)
|
configs/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .config import args
|
configs/config.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from email.policy import default
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
+
sys.path.append(BASE_DIR)
|
| 7 |
+
|
| 8 |
+
import cv2 # type: ignore
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
from typing import Any, Dict, List
|
| 14 |
+
|
| 15 |
+
# 数据集结构
|
| 16 |
+
file_arch = """
|
| 17 |
+
./REFAVS/data
|
| 18 |
+
- /media
|
| 19 |
+
- /gt_mask
|
| 20 |
+
- /metadata.csv
|
| 21 |
+
- /audio_embed
|
| 22 |
+
- /image_embed
|
| 23 |
+
"""
|
| 24 |
+
# print(f">>> File arch: {file_arch}")
|
| 25 |
+
|
| 26 |
+
parser = argparse.ArgumentParser(
|
| 27 |
+
description=(
|
| 28 |
+
"SimToken"
|
| 29 |
+
)
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
parser.add_argument("--vision_pretrained",type=str,default='/workspace/SimToken/models/segment_anything/sam_vit_h_4b8939.pth')
|
| 35 |
+
parser.add_argument("--vision_tower",type=str,default='openai/clip-vit-large-patch14')
|
| 36 |
+
parser.add_argument("--mllm",type=str,default='Chat-UniVi/Chat-UniVi-7B-v1.5')
|
| 37 |
+
|
| 38 |
+
parser.add_argument("--conv_template",type=int,default=1)
|
| 39 |
+
parser.add_argument("--ct_weight",type=float,default=0.1)
|
| 40 |
+
parser.add_argument("--input_type",type=str,default='refer')
|
| 41 |
+
parser.add_argument("--compress",action='store_false',default=True)
|
| 42 |
+
parser.add_argument("--start",type=int,default=0)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
parser.add_argument("--name",type=str,default='testrun')
|
| 46 |
+
# path to ref-avs dataset
|
| 47 |
+
parser.add_argument("--data_dir",type=str,default='/workspace/SimToken/data',help=f"The data paranet dir. File arch should be: {file_arch}")
|
| 48 |
+
# path to pretrained checkpoints
|
| 49 |
+
parser.add_argument("--saved_model",type=str,default='/workspace/SimToken/checkpoints/simtoken_pretrained.pth', help="the pretrained simtoken pth")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
parser.add_argument("--log_root",type=str,default='log', help="where to save log during training")
|
| 53 |
+
parser.add_argument("--checkpoint_root",type=str,default='checkpoints', help="where to save trained checkpoints during training")
|
| 54 |
+
|
| 55 |
+
parser.add_argument("--visualization_root",type=str,default='visualization', help="where to save visualization result during test")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# parser.add_argument("--show_params", action='store_true', help=f"Show params names with Requires_grad==True.")
|
| 61 |
+
|
| 62 |
+
# learning rate
|
| 63 |
+
parser.add_argument("--lr", type=float, default=5e-5, help='lr to fine tuning adapters.')
|
| 64 |
+
# epochs
|
| 65 |
+
parser.add_argument("--epochs", type=int, default=10, help='epochs to fine tuning adapters.')
|
| 66 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
parser.add_argument("--gpu_id", type=str, default="0", help="The GPU device to run generation on.")
|
| 70 |
+
|
| 71 |
+
parser.add_argument("--run", type=str, default='train', help="train, test")
|
| 72 |
+
|
| 73 |
+
parser.add_argument("--frame_n", type=int, default=10, help="Frame num of each video. Fixed to 10.")
|
| 74 |
+
parser.add_argument("--text_max_len", type=int, default=25, help="Maximum textual reference length.")
|
| 75 |
+
parser.add_argument("--max_eval_rows", type=int, default=-1, help="Max samples per split during eval; -1 = all.")
|
| 76 |
+
parser.add_argument("--eval_split", type=str, default="test_u", help="Which split to evaluate: test_s, test_u, test_n.")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
args = parser.parse_args()
|
| 81 |
+
|
| 82 |
+
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 83 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
|
| 84 |
+
# print(f'>>> Sys: set "CUDA_VISIBLE_DEVICES" - GPU: {args.gpu_id}')
|
data/metadata.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|