yfan07 committed
Commit a47e733 · verified · Parent: f1106d1

Add files using upload-large-folder tool

Files changed (35)
  1. ChatUniVi/eval/questions/video_qa/msvd_qa.json +0 -0
  2. ChatUniVi/eval/questions/video_qa/temporal_qa.json +0 -0
  3. ChatUniVi/eval/questions/video_qa/tgif_a_list.json +1309 -0
  4. ChatUniVi/eval/questions/video_qa/tgif_qa.json +0 -0
  5. ChatUniVi/eval/table/caps_boxes_coco2014_val_80.jsonl +80 -0
  6. ChatUniVi/eval/table/model.jsonl +5 -0
  7. ChatUniVi/eval/table/question.jsonl +80 -0
  8. ChatUniVi/eval/table/reviewer.jsonl +4 -0
  9. ChatUniVi/eval/table/rule.json +11 -0
  10. ChatUniVi/model/__init__.py +1 -0
  11. ChatUniVi/model/apply_delta.py +44 -0
  12. ChatUniVi/model/arch.py +652 -0
  13. ChatUniVi/model/builder.py +118 -0
  14. ChatUniVi/model/cluster.py +287 -0
  15. ChatUniVi/model/consolidate.py +29 -0
  16. ChatUniVi/model/dataloader.py +67 -0
  17. ChatUniVi/model/language_model/language_model/configuration_phi.py +62 -0
  18. ChatUniVi/model/language_model/language_model/modeling_phi.py +984 -0
  19. ChatUniVi/model/language_model/llama.py +136 -0
  20. ChatUniVi/model/language_model/phi.py +142 -0
  21. ChatUniVi/model/make_delta.py +52 -0
  22. ChatUniVi/model/multimodal_encoder/builder.py +14 -0
  23. ChatUniVi/model/multimodal_encoder/clip_encoder.py +83 -0
  24. ChatUniVi/model/multimodal_encoder/eva_encoder.py +81 -0
  25. ChatUniVi/model/multimodal_encoder/eva_vit.py +448 -0
  26. ChatUniVi/model/multimodal_encoder/processor.py +68 -0
  27. ChatUniVi/model/multimodal_encoder/utils.py +137 -0
  28. ChatUniVi/model/multimodal_projector/builder.py +52 -0
  29. ChatUniVi/train/llama_flash_attn_monkey_patch.py +124 -0
  30. ChatUniVi/train/train.py +1232 -0
  31. ChatUniVi/train/train_mem.py +13 -0
  32. ChatUniVi/train/trainer.py +53 -0
  33. configs/__init__.py +1 -0
  34. configs/config.py +84 -0
  35. data/metadata.csv +0 -0
ChatUniVi/eval/questions/video_qa/msvd_qa.json ADDED
The diff for this file is too large to render. See raw diff
 
ChatUniVi/eval/questions/video_qa/temporal_qa.json ADDED
The diff for this file is too large to render. See raw diff
 
ChatUniVi/eval/questions/video_qa/tgif_a_list.json ADDED
@@ -0,0 +1,1309 @@
1
+ [
2
+ "cookie",
3
+ "? machine",
4
+ "two",
5
+ "glasses",
6
+ "black",
7
+ "tail",
8
+ "red",
9
+ "flowers",
10
+ "laptop",
11
+ "three",
12
+ "white",
13
+ "green",
14
+ "? boat",
15
+ "blue",
16
+ "? room",
17
+ "brown",
18
+ "cat",
19
+ "picture",
20
+ "drink",
21
+ "cigarette",
22
+ "clock",
23
+ "car",
24
+ "monkey",
25
+ "guitar",
26
+ "purple",
27
+ "? kitchen",
28
+ "? mirror",
29
+ "meal",
30
+ "four",
31
+ "? tank",
32
+ "? classroom",
33
+ "dog",
34
+ "pipe",
35
+ "leaf",
36
+ "shirt",
37
+ "champagne",
38
+ "string",
39
+ "sweater",
40
+ "? studio",
41
+ "tortoise",
42
+ "and one of them is holding ? dog",
43
+ "rings",
44
+ "vehicles",
45
+ "lollipop",
46
+ "candy",
47
+ "bottle",
48
+ "then a man is shown sitting . ? locker",
49
+ "parakeets",
50
+ "hole",
51
+ "tie",
52
+ "boat",
53
+ "ball",
54
+ "cash",
55
+ "chicken",
56
+ "? street",
57
+ "bird",
58
+ "six",
59
+ "? pool",
60
+ "window",
61
+ "round",
62
+ "instrument",
63
+ "puppy",
64
+ "doorway",
65
+ "juice",
66
+ "flamethrower",
67
+ "gray",
68
+ "dress",
69
+ "hat",
70
+ "kitten",
71
+ "gun",
72
+ "cars",
73
+ "paws",
74
+ "elephant",
75
+ "beam",
76
+ "? chair",
77
+ "chimp",
78
+ "one",
79
+ "butt",
80
+ "mascara",
81
+ "dogs",
82
+ "puppet",
83
+ "hamster",
84
+ "? bedroom",
85
+ "who pretends to slap him in return ? crack",
86
+ "machine",
87
+ "drops",
88
+ "then he removes and throws it to the ground ? hat",
89
+ "when two of the cyclist crash ? bicycles",
90
+ "cannabis",
91
+ "? trap",
92
+ "helmet",
93
+ "motorcycle",
94
+ "purses",
95
+ "bank",
96
+ "orange",
97
+ "guitars",
98
+ "? crib",
99
+ "hedgehog",
100
+ "? hallway",
101
+ "? car",
102
+ "steps",
103
+ "horse",
104
+ "? bath",
105
+ "drawer",
106
+ "cats",
107
+ "duck",
108
+ "wearing , reads a piece of paper on a desk and then raises his head ? glasses",
109
+ "phone",
110
+ "pillow",
111
+ "cup",
112
+ "he has food in front of him . ? chair",
113
+ "surfboard",
114
+ "before one of them climbs from the ring ? two",
115
+ "dancing , and clapping ? four",
116
+ "pool",
117
+ "motorcycles",
118
+ "pictures",
119
+ "? star",
120
+ "clipboard",
121
+ "paw",
122
+ "kiss ? two",
123
+ "turtle",
124
+ "when one touches the other on the shoulder ? two",
125
+ "? house",
126
+ "five",
127
+ "locker",
128
+ "tree",
129
+ "bat",
130
+ "popcorn",
131
+ "broom",
132
+ "guns",
133
+ "paint",
134
+ "seat",
135
+ "and then they run away ? heels",
136
+ "flags",
137
+ "dice",
138
+ "? library",
139
+ "yellow",
140
+ "chair",
141
+ "door",
142
+ "? warehouse",
143
+ "kick it and fall over ? tire",
144
+ "jacket",
145
+ "wire",
146
+ "crow",
147
+ "motions",
148
+ "bubbles",
149
+ "vehicle",
150
+ "wearing and speaking ? necklace",
151
+ "one is dressed funny , look at each other ? two",
152
+ "mice",
153
+ "clothing",
154
+ "bread",
155
+ "fireworks",
156
+ "microphone",
157
+ "mascot",
158
+ "? booth",
159
+ "wolf",
160
+ "? foyer",
161
+ "driver",
162
+ "cylinder",
163
+ "on top of his food bowl ? dog",
164
+ "rabbit",
165
+ "? office",
166
+ "treadmill",
167
+ "cap",
168
+ "tire",
169
+ "stick",
170
+ "is laying and opening her eyes . ? bed",
171
+ "stairs",
172
+ "drums",
173
+ "bar",
174
+ "? bed",
175
+ "spoons",
176
+ "? lab",
177
+ "headphones",
178
+ "one is . ? basket",
179
+ "makeup",
180
+ "frogs",
181
+ "wine",
182
+ "two men sit on a sofa and a man dances along a red carpet ? rectangle",
183
+ "sauce",
184
+ "airplane",
185
+ "and he is playing ? guitar",
186
+ "fox",
187
+ "costume",
188
+ "slide",
189
+ "stamp",
190
+ "butts",
191
+ "? window",
192
+ "rope",
193
+ "receiver",
194
+ "then the dog turns around crazy ? butt",
195
+ "and one talks to someone . ? room",
196
+ "? aisle",
197
+ "headset",
198
+ "horses",
199
+ "handgun",
200
+ "bear",
201
+ "napkin",
202
+ "? bottle",
203
+ "frog",
204
+ "wearing , animal print pants and pink shoes is dancing on a sidewalk ? shirt",
205
+ "bicycle",
206
+ "button",
207
+ "panda",
208
+ "turtles",
209
+ "but keeps flying ? airplane",
210
+ "? headset",
211
+ "lobby",
212
+ "pelican",
213
+ "dive",
214
+ "? cage",
215
+ "dishes",
216
+ "wagon",
217
+ "seven",
218
+ "? bag",
219
+ "butterfly",
220
+ "flask",
221
+ "banana",
222
+ "flasks",
223
+ "bus",
224
+ "device",
225
+ "is riding through the house ? bicycle",
226
+ "bright lightning ? sky",
227
+ "umbrellas",
228
+ "yawns then puts out its paw and pushes a jar off onto the floor ? cat",
229
+ "skateboard",
230
+ "cupcakes",
231
+ "shoe",
232
+ "cloak",
233
+ "apple",
234
+ "wall",
235
+ "horns",
236
+ "trick",
237
+ "date",
238
+ "he is talking to a woman ? beer",
239
+ "hill",
240
+ "? bar",
241
+ "pieces",
242
+ "stars",
243
+ "and the bowl disappears ? dog",
244
+ "bridge",
245
+ "box",
246
+ "with one of them embracing the other from behind ? two",
247
+ "piano",
248
+ "? hall",
249
+ "coffee",
250
+ "peel",
251
+ "cutter",
252
+ "circle",
253
+ "sunglasses",
254
+ "star",
255
+ "? pen",
256
+ "they move slowly ? stairs",
257
+ "kitty",
258
+ "pen",
259
+ "owl",
260
+ "puppies",
261
+ "fish",
262
+ "keyboard",
263
+ "underwear",
264
+ "? gym",
265
+ "pigeon",
266
+ "retriever",
267
+ "masks",
268
+ "kangaroo",
269
+ "close",
270
+ "shorts",
271
+ "band",
272
+ "swimming",
273
+ "? plate",
274
+ "then another man reaches for it ? gun",
275
+ "face",
276
+ "ferret",
277
+ "drug",
278
+ "clothes",
279
+ "spoon",
280
+ "hurdle",
281
+ "grass",
282
+ "? paint",
283
+ "airplanes",
284
+ "talks",
285
+ "whose lights flash on ? flower",
286
+ "with one drumming ? instruments",
287
+ "? bowl",
288
+ "burger",
289
+ "llama",
290
+ "it licks its lips ? horse",
291
+ "? holder",
292
+ "camel",
293
+ "dancing",
294
+ "umbrella",
295
+ "pants",
296
+ "ducklings",
297
+ "mug",
298
+ "necklace",
299
+ "track",
300
+ "smoking and turning her head ? cigarette",
301
+ "ladder",
302
+ "cliff",
303
+ "shirts",
304
+ "shark",
305
+ "is playing ? ukulele",
306
+ "turns",
307
+ "? ball",
308
+ "scooter",
309
+ "? box",
310
+ "? road",
311
+ "cover",
312
+ ". ? cage",
313
+ "backhoe",
314
+ "bed",
315
+ "and she is holding up ? puppet",
316
+ "? two",
317
+ "goblet",
318
+ "is using and smoking a cigarette ? phone",
319
+ "wearing coats , is hugging . ? hallway",
320
+ "but he misses ? ball",
321
+ "diver",
322
+ "? nightclub",
323
+ "they both smile ? round",
324
+ "medic",
325
+ "? stick",
326
+ "train",
327
+ "? microphone",
328
+ "cigar",
329
+ "wearing , comes through a door held open by another man ? suit",
330
+ "wheel",
331
+ "lions",
332
+ "tights",
333
+ "racetrack",
334
+ "one picks up the other and carries him ? two",
335
+ "sun",
336
+ "? floor",
337
+ "beer",
338
+ "berries",
339
+ "mask",
340
+ "heels",
341
+ "decorator",
342
+ "cub",
343
+ "breakfast",
344
+ ". ? chair",
345
+ "then looks away ? monkey",
346
+ "? bucket",
347
+ "snack",
348
+ "girl",
349
+ "suspenders",
350
+ "toy",
351
+ "elephants",
352
+ "boar",
353
+ "bubble",
354
+ "falls off and he grabs it ? hat",
355
+ "trunk",
356
+ "and one of them climbs from one to the other ? frogs",
357
+ "floor",
358
+ "belt",
359
+ "octopus",
360
+ "? dish",
361
+ "truck",
362
+ "snowmobile",
363
+ "standing in the dark , wears ? dress",
364
+ "? bathtub",
365
+ "trees",
366
+ "? mall",
367
+ "bow",
368
+ "beat to the rhythm ? sticks",
369
+ "? store",
370
+ "but stops him ? rope",
371
+ "pug",
372
+ "headgear",
373
+ "tubes",
374
+ "dance",
375
+ "pandas",
376
+ "iguana",
377
+ "concert",
378
+ "dandelion",
379
+ "? garden",
380
+ "queen",
381
+ "instruments",
382
+ "tricycle",
383
+ "racing",
384
+ "? garage",
385
+ "horn",
386
+ "entrance",
387
+ "can",
388
+ "chimpanzee",
389
+ "but the bear cub does ? bear",
390
+ "glass",
391
+ "birds",
392
+ "screaming and pointing ? two",
393
+ "robot",
394
+ "sky",
395
+ "egg",
396
+ "moth",
397
+ "backpack",
398
+ "beverages",
399
+ "bouquet",
400
+ "trumpet",
401
+ "carpet",
402
+ "? apartment",
403
+ "pony",
404
+ "goat",
405
+ "headdress",
406
+ "and he is removing ? hat",
407
+ "house",
408
+ "suit",
409
+ "gum",
410
+ "curb",
411
+ "and then leaves it ? car",
412
+ "snake",
413
+ "he looks at his passenger who is sleeping ? car",
414
+ "? bow-tie",
415
+ "wig",
416
+ "raising a cloud of dust ? car",
417
+ "freezer",
418
+ "delivering , and signing ? flowers",
419
+ "skis",
420
+ "road",
421
+ "deal",
422
+ "ship",
423
+ "? bathroom",
424
+ "bills",
425
+ "piece",
426
+ "items fall out and she makes a face ? door",
427
+ "drinks",
428
+ "dives , . ? cafeteria",
429
+ "goggles",
430
+ "? wagon",
431
+ "man",
432
+ "cups",
433
+ "dolphin",
434
+ "card",
435
+ "building",
436
+ "trunks",
437
+ "liquor",
438
+ "scarf",
439
+ "squash",
440
+ "cheese",
441
+ "then the snake kisses her ? snake",
442
+ "dances seductively ? dress",
443
+ "sword",
444
+ "kiss",
445
+ "possum",
446
+ "stockings",
447
+ "? tray",
448
+ "the one man yells ? two",
449
+ "and she is playing ? guitar",
450
+ "? alley",
451
+ "also wearing ? helmet",
452
+ "beverage",
453
+ "weapon",
454
+ "rodent",
455
+ "beach",
456
+ "? cereals",
457
+ "bench",
458
+ "with two holding glass bottles with colored liquid ? five",
459
+ "holding , jumps in the air and then moves to the back of stage ? guitar",
460
+ "transportation",
461
+ "shampoo",
462
+ "caps",
463
+ "hook",
464
+ "squirrel",
465
+ "scenery",
466
+ "playing",
467
+ "? wheelchair",
468
+ "performer",
469
+ "cake",
470
+ "dancing and playing ? instruments",
471
+ "boxes",
472
+ "leash",
473
+ "? bouquet",
474
+ "but only one arm is . ? sleeve",
475
+ "rifles",
476
+ "lenses",
477
+ "the girl watches him . ? building",
478
+ "almonds",
479
+ "tank",
480
+ "pot",
481
+ "bracelet",
482
+ "knife",
483
+ "mouse",
484
+ "who then catches it ? bottle",
485
+ "exercise",
486
+ "and he is turning around ? wand",
487
+ "purse",
488
+ "stones",
489
+ "show",
490
+ "bag",
491
+ "stocking",
492
+ "balloon",
493
+ "stops , and its tongue remains stuck out ? cat",
494
+ "scythe",
495
+ "creature",
496
+ "cello",
497
+ "and ends up on its back ? bird",
498
+ "pup",
499
+ "? container",
500
+ "and one blows a kiss ? two",
501
+ "animal",
502
+ "trampoline",
503
+ "before they turn and walk away ? two",
504
+ "cloaks",
505
+ "blackjack",
506
+ "as they hit fist to fist ? two",
507
+ "bicycles",
508
+ "watch",
509
+ "corgi",
510
+ "spider",
511
+ "earring",
512
+ "bull",
513
+ "? wheel",
514
+ "? stadium",
515
+ "looking at each other ? two",
516
+ "foxes",
517
+ "mammal",
518
+ "sheep",
519
+ "chases",
520
+ "? armchair",
521
+ ". ? room",
522
+ "dancing , and playing ? instruments",
523
+ "which then falls backwards ? cat",
524
+ "dancer",
525
+ "boots",
526
+ "rotors",
527
+ "? ranch",
528
+ "? shower",
529
+ "paper , scissors as they stand by the door ? two",
530
+ "laying and crying on her pillow . ? bed",
531
+ "pencil",
532
+ "when one side scores a goal ? two",
533
+ "food",
534
+ "one with an arm on the other ? two",
535
+ "sheets",
536
+ "rabbits",
537
+ "pizza",
538
+ "? glove",
539
+ "table",
540
+ "scratched",
541
+ "syrup",
542
+ "cone",
543
+ "while the larger man breaks up the fight ? two",
544
+ "drives",
545
+ "luggage",
546
+ "? vehicle",
547
+ "lift",
548
+ "frame",
549
+ "shoes",
550
+ "opens the door , and the cat and four dogs enter through the door ? building",
551
+ "blinks",
552
+ "crotch",
553
+ "dishwasher",
554
+ "skills",
555
+ "sleeves",
556
+ "model",
557
+ "ties",
558
+ "modeling",
559
+ "bath",
560
+ "jet",
561
+ "tortillas",
562
+ "teapot",
563
+ "barbel",
564
+ "cartwheel",
565
+ "musician",
566
+ "rhino",
567
+ "exits",
568
+ "pole",
569
+ "ski",
570
+ "pajama",
571
+ "woodchucks",
572
+ "lanes",
573
+ "candle",
574
+ "tag",
575
+ "gloves",
576
+ "dinosaur",
577
+ "surface",
578
+ "? tub",
579
+ "snowboard",
580
+ "wearing , hops around her couch while pointing at her face ? glasses",
581
+ "donut",
582
+ "mustard",
583
+ "? tunnel",
584
+ "? theater",
585
+ "wheels",
586
+ "rat",
587
+ "and one talks to someone ? two",
588
+ "bungee",
589
+ "but then suddenly takes off again ? jet",
590
+ "? rink",
591
+ "face shown . ? mirror",
592
+ "shell",
593
+ "costumes",
594
+ "? shield",
595
+ "confetti",
596
+ "flower",
597
+ "gesture",
598
+ "portfolio",
599
+ "and moves from under him ? ball",
600
+ "violin",
601
+ "photographs",
602
+ "uniforms",
603
+ "money",
604
+ "bomb",
605
+ "? rv",
606
+ "claws",
607
+ "lands",
608
+ "turnstile",
609
+ "bot",
610
+ "hose",
611
+ "suitcase",
612
+ "sitting on a table , reaches out and pushes a glass off the table ? paw",
613
+ "mountain",
614
+ "tools",
615
+ "headsets",
616
+ "the streets crumble below it ? airplane",
617
+ "t-shirt",
618
+ "doors",
619
+ "wearing , hugs another person and smiles ? glasses",
620
+ "one of them is shaking his head . ? car",
621
+ "octopuses",
622
+ "performs",
623
+ "cases",
624
+ "deer",
625
+ "? wall",
626
+ "and holding a lighter underneath , it explodes in flames ? balloon",
627
+ "blanket",
628
+ "coat",
629
+ "knives",
630
+ "? frame",
631
+ "trolley",
632
+ "noodles",
633
+ "one cries and holds a handkerchief to his nose , the other tries to comfort him ? two",
634
+ "wrap",
635
+ "? cart",
636
+ "inside of the car get scared ? two",
637
+ "animals",
638
+ "tails",
639
+ "? drawer",
640
+ "? cigarette",
641
+ "? barbel",
642
+ "room",
643
+ "? building",
644
+ "using as a weapon , hits a zombie in the head ? bat",
645
+ "trucks",
646
+ "boxers",
647
+ "drum",
648
+ "challenge",
649
+ "? toilet",
650
+ "llamas",
651
+ "then watches the smoke rise ? cat",
652
+ "mouths from across a room ? two",
653
+ "and it is pushed by a cat ? box",
654
+ "but the bear cub does ? bird",
655
+ "? skateboard",
656
+ "lifts up to her mouth , ? microphone",
657
+ "wearing , talks and bends his head forward ? cap",
658
+ "? doorway",
659
+ "which causes that cat to attack another cat ? cat",
660
+ "giraffe",
661
+ "cam",
662
+ "microphones",
663
+ "losing balance as it tries to walk forward ? cat",
664
+ "groove",
665
+ "tricks",
666
+ "spins , and lands on another ramp ? car",
667
+ "dumbbell",
668
+ "with their arms out , while laughing ? three",
669
+ "sea",
670
+ "carrot",
671
+ "chips",
672
+ "gift",
673
+ "ropes",
674
+ "singer",
675
+ "rocket",
676
+ "? net",
677
+ "blows",
678
+ "? zipper",
679
+ "sticks",
680
+ "tambourine",
681
+ "and he is laughing at a puppet talking ? cookie",
682
+ "? train",
683
+ "boats",
684
+ "across a road , and into the path of a car before being hit ? bicycle",
685
+ "penguins",
686
+ "song",
687
+ "antlers",
688
+ "feather",
689
+ "handcuffs",
690
+ "insect",
691
+ "gratings",
692
+ "milk",
693
+ "blackbird",
694
+ "scaffolding",
695
+ "sheet",
696
+ "seal",
697
+ "which bursts as the car approaches it ? car",
698
+ "? locker",
699
+ "towels",
700
+ "? highway",
701
+ "? lane",
702
+ "? rope",
703
+ "wearing , is singing with a microphone ? dress",
704
+ "vegetables",
705
+ "rag",
706
+ "? hoop",
707
+ "? hospital",
708
+ "keys",
709
+ "and he is raising his arm ? crotch",
710
+ "otter",
711
+ "? corridor",
712
+ "tires",
713
+ "they see it from looking up ? window",
714
+ "trainer",
715
+ "groundhog",
716
+ "gorilla",
717
+ "is sitting on the steps and eating ? shirt",
718
+ "oar",
719
+ "nugget",
720
+ "? cellphone",
721
+ "hamsters",
722
+ "walls",
723
+ "? cup",
724
+ "and then starts wracking it FRAMEQAeatedly ? wand",
725
+ "concoction",
726
+ "computer",
727
+ "hall",
728
+ "one is licking the other ones ear ? cats",
729
+ "earphone",
730
+ "hallway",
731
+ "trailer",
732
+ "magazine",
733
+ "and pointing at it ? laptop",
734
+ "elevator",
735
+ "river",
736
+ "pig",
737
+ "is also using ? earring",
738
+ "case",
739
+ "cape",
740
+ "? tablet",
741
+ "beanie",
742
+ "penguin",
743
+ "race",
744
+ "? excitedly",
745
+ "groomed each other ? cats",
746
+ "carriage",
747
+ "with long hair , open her mouth . ? room",
748
+ "parakeet",
749
+ "call",
750
+ "? tire",
751
+ "windshield",
752
+ "nose",
753
+ "? capsule",
754
+ "woman",
755
+ "snowball",
756
+ "look at one another , and fall to the ground laughing ? three",
757
+ "wing",
758
+ "bowl",
759
+ "lipstick",
760
+ "who is looking upset ? one",
761
+ "balls",
762
+ "cage",
763
+ "sunroof",
764
+ "? shop",
765
+ "shining and wearing a yellow outfit ? microphone",
766
+ "then two of them wave goodbye ? three",
767
+ "? sunglasses",
768
+ "kittens",
769
+ "? lingerie",
770
+ "colors",
771
+ "crying and eating a sandwich . ? bed",
772
+ "? lapel",
773
+ "corn",
774
+ "twirl",
775
+ "dough",
776
+ "dock",
777
+ "taxi",
778
+ "singing",
779
+ "stares",
780
+ "skate",
781
+ "chick",
782
+ "is visiting another guy . ? hospital",
783
+ "comb",
784
+ "roll",
785
+ "runway",
786
+ "statue",
787
+ "rides a skateboard up and launches himself through the air ? ramp",
788
+ "bleachers",
789
+ "? pot",
790
+ "butter",
791
+ "and it bounces off of a wall onto a table ? cat",
792
+ "? basement",
793
+ "eyeliner",
794
+ "wearing , is waving his hand ? shirt",
795
+ "opens the door , and the cat and four dogs enter the building through the door ? cat",
796
+ "right",
797
+ "flashlights",
798
+ "pet",
799
+ "pastry",
800
+ "but then the trailing car is shown a weapon and the car falls back ? car",
801
+ "tuxedo",
802
+ "begins to flip over and over ? car",
803
+ "curtain",
804
+ "fork",
805
+ "he looks away ? guitar",
806
+ "roof",
807
+ "? restroom",
808
+ "who jumps away . ? box",
809
+ "? rag",
810
+ "wearing , talks and raises on eyebrow ? headband",
811
+ "? cloak",
812
+ "then the rider lands on top ? motorcycle",
813
+ "toys",
814
+ "are talking to each other ? two",
815
+ "rats",
816
+ "telephone",
817
+ "bananas",
818
+ "user",
819
+ "stops and gets in ? taxi",
820
+ "cane",
821
+ "bucket",
822
+ "popsicle",
823
+ "? tent",
824
+ "? oven",
825
+ "and the fired a shot ? flower",
826
+ "? broom",
827
+ "? pan",
828
+ "design",
829
+ "hippopotamus",
830
+ "they move to the left ? sky",
831
+ "trying not to laugh ? two",
832
+ "torch",
833
+ "they look at one another , and the woman exits the car . ? car",
834
+ "his head nods to the left . ? chair",
835
+ "and he had a bandage on his head . ? car",
836
+ "vegetable",
837
+ "and everyone celebrates ? star",
838
+ "balloons",
839
+ "men",
840
+ "circles",
841
+ "graffiti",
842
+ "racer",
843
+ "jump",
844
+ "kissing , and spinning around ? two",
845
+ "works",
846
+ "castle",
847
+ "while they are sitting down ? two",
848
+ "sandwich",
849
+ "earpiece",
850
+ "then lift ? shirt",
851
+ "motors",
852
+ "burrito",
853
+ "? singlet",
854
+ "180",
855
+ "? dryer",
856
+ "torches",
857
+ "? pullover",
858
+ "wearing , slides open a door and dances through while carrying a walking tick and radio ? glasses",
859
+ "straw",
860
+ "wearing , pushes a melting ice cream into his mouth as some drops from his hand ? cap",
861
+ "clown",
862
+ "smiles , and turns away . ? classroom",
863
+ "figure",
864
+ "white doll ? two",
865
+ "signs",
866
+ "? airplane",
867
+ "cannon",
868
+ "cloth",
869
+ "serviette",
870
+ "toast",
871
+ "? kit",
872
+ "bats",
873
+ "bobcat",
874
+ "griddle",
875
+ "leaves",
876
+ "pass",
877
+ "? door",
878
+ "ramp",
879
+ "porpoise",
880
+ "scissors",
881
+ "fighter",
882
+ "bandannas",
883
+ "bases",
884
+ "hug each other ? two",
885
+ "duckling",
886
+ "but grabs on and takes a drink ? monkey",
887
+ "winks",
888
+ "? jeep",
889
+ "twirls",
890
+ "harp",
891
+ "one points and talks and the other laughs ? two",
892
+ "then a redhead grabs ? hat",
893
+ "? zoo",
894
+ "tender",
895
+ "disc",
896
+ "fly",
897
+ "wash",
898
+ "harness",
899
+ "opening",
900
+ "brick",
901
+ "watermelon",
902
+ "plate",
903
+ "they bring it closer to their body ? stick",
904
+ "lake",
905
+ "sledgehammer",
906
+ "leaning backward , and waving their arms back and forth ? two",
907
+ "ocean",
908
+ "while spectators watch ? two",
909
+ "shuttle",
910
+ "loop",
911
+ "balcony",
912
+ "? closet",
913
+ "but falls off a table ? cat",
914
+ "anchor",
915
+ "? plaid",
916
+ "terrapins",
917
+ "pop",
918
+ "tool",
919
+ "hay",
920
+ "panther",
921
+ "smiling and laughing ? three",
922
+ "and it lands on his head ? hat",
923
+ "? fountain",
924
+ "photograph",
925
+ "it has a double yolk ? egg",
926
+ "one is in a basket ? dogs",
927
+ "but does ? cub",
928
+ "strips",
929
+ "jeep",
930
+ "when the toaster pops out toast the cat gets scared and jumps off ? cat",
931
+ "then turns around crazy ? dog",
932
+ "goldfish",
933
+ "? elevator",
934
+ "sedan",
935
+ "? pocket",
936
+ "planet",
937
+ "drill",
938
+ "two of them spinning around ? cars",
939
+ "baboon",
940
+ "mirror",
941
+ "? flowers",
942
+ "chairs",
943
+ "make in the air with a wand ? float",
944
+ "jewelry",
945
+ "fabric",
946
+ "coins",
947
+ "handset",
948
+ "jets",
949
+ "bulldog",
950
+ "black hair wearing and raising their hand up to their mouth ? shirt",
951
+ "sweatshirt",
952
+ "workout",
953
+ "rounds",
954
+ "? bench",
955
+ "? piece",
956
+ "sparklers",
957
+ "waterfall",
958
+ "lettuce",
959
+ "crashes",
960
+ "tomato",
961
+ "cheeseburger",
962
+ "strawberry",
963
+ "and another one appears to be . ? garden",
964
+ "flag",
965
+ "eight",
966
+ "toothpick",
967
+ "and disappears ? bowl",
968
+ "? lipstick",
969
+ "and she is smiling ? cat",
970
+ "? alleyway",
971
+ "shield",
972
+ "tuxedos",
973
+ "talking , smiling and waving his hand . ? chair",
974
+ "cheetah",
975
+ "and one player kicks into the goal ? ball",
976
+ "letters",
977
+ "? basket",
978
+ "pill",
979
+ "which trips another man who does a flip and lands on a recycle bin ? peel",
980
+ "human",
981
+ "fence",
982
+ "? sink",
983
+ "black leather trench coat ? star",
984
+ "divers",
985
+ "couch",
986
+ "buttons",
987
+ "shot",
988
+ "rodents",
989
+ "swords",
990
+ "gown",
991
+ "both speeding down the road ? car",
992
+ "people watch them . ? house",
993
+ "belts",
994
+ "catapult",
995
+ "ammunition",
996
+ "potatoes",
997
+ "lemur",
998
+ "while a third moves forward and dances ? two",
999
+ "then their hand and a slogan appears ? towel",
1000
+ "firecrackers",
1001
+ "ribs",
1002
+ "briefcase",
1003
+ "the man spills milk over his face . ? car",
1004
+ "? workshop",
1005
+ "is sitting down and smoking ? cigarette",
1006
+ "dressed in a suit and carrying ? cane",
1007
+ "and she is dancing in a field . ? mirror",
1008
+ "? ashtray",
1009
+ "looking sad . ? hallway",
1010
+ "noodle",
1011
+ "missiles",
1012
+ "? helicopter",
1013
+ "catfish",
1014
+ "toothbrush",
1015
+ "have taken ? pictures",
1016
+ "pane",
1017
+ "he dances on the stage ? headset",
1018
+ "scooters",
1019
+ "then he does the splits . ? hallway",
1020
+ "and it is pushed by a cat ? mouse",
1021
+ "desks",
1022
+ "hills",
1023
+ "stairway",
1024
+ "whisk",
1025
+ "with",
1026
+ "while one of them sings into a microphone ? two",
1027
+ "bottles",
1028
+ "but grabs her leg ? panda",
1029
+ "sled",
1030
+ "nut",
1031
+ "feathers",
1032
+ "dresses",
1033
+ "sink",
1034
+ "wristband",
1035
+ "then jumps up to celebrate ? pool",
1036
+ "drumsticks",
1037
+ "opens her mouth and smiles ? one",
1038
+ "suits",
1039
+ "sculpture",
1040
+ "are fighting for control of the soccer ball ? two",
1041
+ "and he is throwing ? napkin",
1042
+ "pets",
1043
+ "bin",
1044
+ "jockey",
1045
+ "backwards",
1046
+ "spiky , walk across the pavement ? heels",
1047
+ "chainsaw",
1048
+ "? guitar",
1049
+ "with just head and tail exposed ? cat",
1050
+ "when one pins the other one down for a three count ? two",
1051
+ "shore",
1052
+ "chicks",
1053
+ "dancing and laughing ? two",
1054
+ "looking sideways and singing ? guitar",
1055
+ "? turns",
1056
+ "lamp",
1057
+ "paper , scissors ? two",
1058
+ "chocolate",
1059
+ "bra",
1060
+ "blonde woman wearing a back top and matching ? piece",
1061
+ "holding hands ? two",
1062
+ "while the man next to him talks and moves his hands around ? one",
1063
+ "cubs",
1064
+ "having cake . ? restaurant",
1065
+ "figurine",
1066
+ "hood",
1067
+ "lens",
1068
+ "groomed each other ? two",
1069
+ "sabers",
1070
+ "before jumping in the pool ? dog",
1071
+ "mattress",
1072
+ "sidewalk",
1073
+ "landing",
1074
+ "rocks",
1075
+ "avocado",
1076
+ "? bear",
1077
+ "and a man spills , crouches , and cowers ? coffee",
1078
+ "disks",
1079
+ "mountainside",
1080
+ "lips",
1081
+ "chest",
1082
+ "wan",
1083
+ "glove",
1084
+ "? beer",
1085
+ "tortilla",
1086
+ "? stable",
1087
+ "meteor",
1088
+ "expression",
1089
+ "? kayak",
1090
+ "biscuit",
1091
+ "ukulele",
1092
+ "at something ? two",
1093
+ "convertible",
1094
+ "climber",
1095
+ "is using the pay phone and smoking ? cigarette",
1096
+ "wearing , looks mad ? jacket",
1097
+ "mike",
1098
+ "sleeping and stretching on the person 's stomach ? cat",
1099
+ "denim",
1100
+ "lantern",
1101
+ "breaks the branch its sitting on in the tree , and falls to the ground ? panda",
1102
+ "so that she 's almost laying down . ? car",
1103
+ "smears",
1104
+ "hair",
1105
+ "bones",
1106
+ "blade",
1107
+ "unicycle",
1108
+ "? cone",
1109
+ "wallet",
1110
+ "blouse",
1111
+ "trousers",
1112
+ "buds",
1113
+ "spill",
1114
+ "rib",
1115
+ "porcupine",
1116
+ "tray",
1117
+ "map",
1118
+ "sad ? dog",
1119
+ "socks",
1120
+ "automobile",
1121
+ "parallel",
1122
+ "skyscraper",
+ "classroom",
+ "catwalk",
+ "the bike crashes ? bicycle",
+ "stare , and look shocked ? four",
+ "towel",
+ "whilst another one is sitting down ? guitar",
+ "lion",
+ "cargo",
+ "grabs",
+ "and then starts wracking it FRAMEQAeatedly ? cat",
+ "vest",
+ "spits",
+ "wearing is walking and waving ? dress",
+ "poker",
+ "robe",
+ "bandanna",
+ "little fingers ? two",
+ "person",
+ "doves",
+ "container",
+ "wearing , uses gymnastic rings to lift herself to a seated position then into a handstand ? clothes",
+ "forklift",
+ "buildings",
+ "wearing ? blouse",
+ "making a crack big enough for the rest to get in ? cat",
+ "carrots",
+ "lizard",
+ "beakers",
+ "blower",
+ "and another woman is running in black shorts ? pants",
+ "marks",
+ "spaceship",
+ "when one man lays the other man down ? two",
+ "are dancing on a stage while the crowd cheers ? two",
+ "they start to head bang . ? car",
+ "then one blows confetti into the air ? two",
+ "sitting down , when someone else steps up and spins the chair around . ? chair",
+ "puppets",
+ "garage",
+ "lemon",
+ "wearing , is sitting and doing something with her foot ? clothes",
+ "and two men with lighting swords want to fight with him ? door",
+ "treat",
+ "lamb",
+ "ways",
+ "and one man throws ? hat",
+ "pick",
+ "product",
+ "is throwing around the room ? clothes",
+ "the clothes of the people catch on fire ? horses",
+ "all , have the same type of hair style ? three",
+ "whip",
+ "mop",
+ "pointing his fingers and nodding ? bow",
+ "bags",
+ "machines",
+ "seeds",
+ "symbol",
+ "layer",
+ "opens ? door",
+ "dark sunglasses , and cigar ? two",
+ "the man smashes the head of a zombie ? bat",
+ "extinguisher",
+ "candles",
+ ", looking out ? window",
+ "group",
+ "drop",
+ "is riding , into the swimming pool ? bicycle",
+ "stake",
+ "block",
+ "and he is singing into a microphone ? guitar",
+ "ornament",
+ "spins as he bends over . ? chair",
+ "? shirts",
+ "? colors",
+ "hookah",
+ "? courtyard",
+ "cactus",
+ "are having taken while on stage ? picture",
+ "an orange ? shell",
+ "and he is talking ? sunglasses",
+ "veil",
+ "then rolling around in the mud ? horse",
+ "? pillow",
+ "drugs",
+ "? couch",
+ "bun",
+ "koala",
+ "one wearing brown shoes and the other has no footwear ? two",
+ "and he is falling in the water ? dog",
+ "is smoking ? cigarette",
+ "rooster",
+ "submarine",
+ "wand",
+ "helicopter",
+ "wearing , smiles as her hair blows in the wind ? hat",
+ "and fails , to jump into the window ? cat",
+ "tram",
+ "and then is knocked down when it hits him in the head ? bag",
+ "curve",
+ "handrail",
+ "bulldozer",
+ "stops a taxi . ? street",
+ "speedometer",
+ "? necklace",
+ "curbs",
+ "over multiple vehicles and lands on another ramp ? bicycle",
+ "wolves",
+ "laundry",
+ "holding , laughs into a microphone and then puts her fingers up to her lips ? guitar",
+ "peeking . ? room",
+ "cigarettes",
+ "bells",
+ "sill",
+ "raspberry",
+ "suited",
+ "shawl",
+ "wakes",
+ "applying the brake , and applying the gas as needed . ? car",
+ "poodle",
+ "and he 's ? candles",
+ "then skids on the ground ? motorcycle",
+ "office",
+ "outdoors",
+ "it stops at the edge ? car",
+ "as she puts it all on top of her head ? two",
+ "but his reflection is doing something different . ? mirror",
+ "holding , are walking together ? bear",
+ "hats",
+ "mat",
+ "then the team mate scores a goal ? ball",
+ "one with a guitar are behind him ? one",
+ "? looks",
+ "grenade",
+ "coin",
+ "toasting each other with their liquor bottles ? two",
+ "saxophone",
+ "capes",
+ "lounges",
+ "? scissors",
+ "hoop",
+ "rack",
+ "frisbee",
+ "then jumps in the air and runs away ? cat",
+ "wearing , is hugging in the hallway ? coats",
+ "? lobby",
+ "corridor",
+ "who they push to the ground ? two",
+ "worms",
+ "tablet",
+ "who turns and causes the kitten to raise its paw ? kitten",
+ "chariot",
+ "lock",
+ "tongs",
+ "game",
+ "s head while he is trying to eat ? cat",
+ "pie",
+ "feline",
+ "and then are shown ? pictures",
+ "parasol",
+ "pumpkins",
+ "notebook",
+ "the horse leans its head around her ? horse",
+ "spaghetti",
+ "outside",
+ "? bib",
+ "gold",
+ "cart",
+ "the trees are being passed by , and the clouds are above ? sun",
+ "the other elephant pulls it closer ? elephant",
+ "most of them wearing ? sunglasses",
+ "and are falling down on top of him ? balloons",
+ "nods his head and blinks ? one",
+ "with long brown hair , wink and raises to her face ? two",
+ "uncontrollably",
+ "wearing , raises two fingers to her face ? cap",
+ "swinging its hips from side to side ? turtle",
+ "skates",
+ "they look at one another , and the woman exits ? car",
+ "his friends join in the background . ? chair",
+ "store",
+ "donuts",
+ "then sticks its tongue out ? dog",
+ "and then a massive explosion occurs ? container",
+ "then kisses her ? snake",
+ "brakes"
+ ]
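The block above is the tail of `tgif_a_list.json`, which appears to be the flat JSON array of candidate answer strings used when scoring TGIF-QA predictions. A minimal sketch of how such a list can be loaded and turned into answer/index mappings; the inline sample stands in for the full file, and all variable names are illustrative:

```python
import json

# tgif_a_list.json is a plain JSON array of answer strings.
# An inline sample stands in for the full file here:
raw = '["skyscraper", "classroom", "catwalk", "brakes"]'
answers = json.loads(raw)

# Map answers to class indices and back, a common first step when
# matching open-ended QA predictions against a fixed answer list.
ans2id = {a: i for i, a in enumerate(answers)}
id2ans = {i: a for i, a in enumerate(answers)}
```

With the real file, `raw` would instead come from reading `tgif_a_list.json` from disk.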
ChatUniVi/eval/questions/video_qa/tgif_qa.json ADDED
The diff for this file is too large to render. See raw diff
 
ChatUniVi/eval/table/caps_boxes_coco2014_val_80.jsonl ADDED
@@ -0,0 +1,80 @@
+ {"id": "000000296284", "image": "COCO_val2014_000000296284.jpg", "captions": ["A donut shop is full of different flavors of donuts.", "Fruit flavored donuts lined up in a glass fronted cabinet", "A rack with some doughnuts in a glass case.", "A display case in a bakery filled with donuts.", "An assortment of doughnuts are arranged in a display case."], "instances": [{"category": "donut", "bbox": [0.37, 0.584, 0.504, 0.709]}, {"category": "donut", "bbox": [0.369, 0.22, 0.492, 0.317]}, {"category": "donut", "bbox": [0.471, 0.587, 0.639, 0.706]}, {"category": "donut", "bbox": [0.544, 0.213, 0.679, 0.316]}, {"category": "donut", "bbox": [0.035, 0.22, 0.196, 0.328]}, {"category": "donut", "bbox": [0.054, 0.608, 0.221, 0.711]}, {"category": "donut", "bbox": [0.283, 0.586, 0.429, 0.708]}, {"category": "donut", "bbox": [0.466, 0.226, 0.585, 0.32]}, {"category": "donut", "bbox": [0.28, 0.232, 0.393, 0.322]}, {"category": "donut", "bbox": [0.0, 0.609, 0.097, 0.722]}]}
+ {"id": "000000151358", "image": "COCO_val2014_000000151358.jpg", "captions": ["A newspaper that has sunglasses on top of it sitting in front of books.", "an apple sunglasses books and a teddy bear", "A folded newspaper and sunglasses are on a table with an apple, books, and teddy bear behind.", "An apple sitting on a table next to sunglasses and a news paper.", "There are sunglasses laying on the folded newspaper."], "instances": [{"category": "tie", "bbox": [0.258, 0.074, 0.527, 0.589]}, {"category": "apple", "bbox": [0.621, 0.482, 0.853, 0.645]}, {"category": "book", "bbox": [0.154, 0.107, 0.275, 0.59]}, {"category": "book", "bbox": [0.535, 0.09, 0.735, 0.583]}, {"category": "book", "bbox": [0.051, 0.112, 0.159, 0.6]}, {"category": "teddy bear", "bbox": [0.753, 0.084, 1.0, 0.517]}, {"category": "book", "bbox": [0.681, 0.097, 0.796, 0.483]}, {"category": "book", "bbox": [0.443, 0.099, 0.574, 0.588]}, {"category": "book", "bbox": [0.267, 0.337, 0.386, 0.579]}]}
+ {"id": "000000052312", "image": "COCO_val2014_000000052312.jpg", "captions": ["The old man literally has a toothbrush mustache.", "An old man with a tooth brush head under his nose, mimicking Hitler", "A man wearing a toothbrush for a moustache.", "A man with the head of a toothbrush under his nose like a mustache", "An elderly man wearing the head of a toothbrush as a moustache."], "instances": [{"category": "toothbrush", "bbox": [0.345, 0.59, 0.594, 0.679]}, {"category": "person", "bbox": [0.0, 0.03, 1.0, 0.99]}]}
+ {"id": "000000473210", "image": "COCO_val2014_000000473210.jpg", "captions": ["two people taking apart their wii controllers to replace batteries", "People taking apart video game remote controls on a table", "People handling a couple of remotes taking them apart.", "two sets of hands a wooden table and two controllers", "Two people who are taking apart a video game controller."], "instances": [{"category": "person", "bbox": [0.002, 0.334, 0.453, 0.986]}, {"category": "remote", "bbox": [0.407, 0.207, 0.727, 0.604]}, {"category": "remote", "bbox": [0.088, 0.344, 0.313, 0.547]}, {"category": "laptop", "bbox": [0.001, 0.049, 0.1, 0.197]}, {"category": "person", "bbox": [0.484, 0.254, 0.998, 0.985]}, {"category": "dining table", "bbox": [0.0, 0.003, 1.0, 0.956]}]}
+ {"id": "000000097131", "image": "COCO_val2014_000000097131.jpg", "captions": ["A car parked by a parking meter in front of a building.", "A car is sitting parked at a curb in front of a parking meter.", "A black car on the street next to a parking meter.", "A gray car parked in front of two parking meters.", "A black car parked on the side of the road."], "instances": [{"category": "car", "bbox": [0.227, 0.362, 0.946, 0.761]}, {"category": "car", "bbox": [0.793, 0.322, 0.88, 0.4]}, {"category": "car", "bbox": [0.0, 0.447, 0.028, 0.726]}, {"category": "parking meter", "bbox": [0.156, 0.35, 0.186, 0.453]}, {"category": "truck", "bbox": [0.907, 0.331, 1.0, 0.408]}, {"category": "parking meter", "bbox": [0.188, 0.349, 0.218, 0.448]}]}
+ {"id": "000000543364", "image": "COCO_val2014_000000543364.jpg", "captions": ["There is a table in the middle of the room.", "A room with a couch, table, lamp and a chaise.", "A living room with couch, chaise, track lighting, and a large window.", "A room with large windows, a couch and a table.", "A living room with lots of furniture and a large window."], "instances": [{"category": "dining table", "bbox": [0.388, 0.644, 0.636, 0.879]}, {"category": "couch", "bbox": [0.194, 0.531, 0.552, 0.777]}, {"category": "couch", "bbox": [0.568, 0.488, 0.907, 0.783]}, {"category": "remote", "bbox": [0.524, 0.651, 0.556, 0.675]}, {"category": "chair", "bbox": [0.661, 0.478, 0.802, 0.604]}]}
+ {"id": "000000217181", "image": "COCO_val2014_000000217181.jpg", "captions": ["They are standing next to some stylish motorcycles.", "Three men are standing around looking at sports motorcycles.", "A small group of men are standing around a motorcycle.", "Two men surrounding a blue motorcycle and others", "A few blue motorcycles are parked in a lot."], "instances": [{"category": "car", "bbox": [0.011, 0.177, 0.2, 0.336]}, {"category": "motorcycle", "bbox": [0.032, 0.139, 0.907, 0.982]}, {"category": "motorcycle", "bbox": [0.0, 0.239, 0.148, 0.613]}, {"category": "motorcycle", "bbox": [0.0, 0.301, 0.106, 0.45]}, {"category": "person", "bbox": [0.775, 0.043, 0.93, 0.463]}, {"category": "person", "bbox": [0.717, 0.116, 0.81, 0.509]}, {"category": "person", "bbox": [0.296, 0.008, 0.472, 0.325]}, {"category": "person", "bbox": [0.115, 0.19, 0.164, 0.269]}, {"category": "truck", "bbox": [0.63, 0.227, 0.731, 0.335]}]}
+ {"id": "000000140289", "image": "COCO_val2014_000000140289.jpg", "captions": ["Two born bears walking though a forest surrounded by trees.", "Two full grown brown bears in a habitat.", "Two bears are roaming around in the woods.", "Two bears around logs in front of a large rock.", "Two big bears wandering through the woods together"], "instances": [{"category": "bear", "bbox": [0.131, 0.269, 0.375, 0.65]}, {"category": "bear", "bbox": [0.568, 0.193, 0.809, 0.827]}]}
+ {"id": "000000460149", "image": "COCO_val2014_000000460149.jpg", "captions": ["A clock hosted on a pole on a pavement next to a building", "Street clock on quiet street with trees and bicycles.", "A tall clock stands on an empty sidewalk.", "A pole that has a clock on the top of it.", "a clock on a short tower and potted plants along the sidewalk"], "instances": [{"category": "potted plant", "bbox": [0.14, 0.71, 0.338, 0.856]}, {"category": "bicycle", "bbox": [0.65, 0.671, 0.766, 0.733]}, {"category": "car", "bbox": [0.38, 0.608, 0.488, 0.656]}, {"category": "clock", "bbox": [0.468, 0.048, 0.699, 0.216]}, {"category": "bicycle", "bbox": [0.669, 0.662, 0.719, 0.67]}, {"category": "car", "bbox": [0.786, 0.625, 0.86, 0.668]}, {"category": "potted plant", "bbox": [0.756, 0.637, 0.819, 0.682]}, {"category": "person", "bbox": [0.942, 0.615, 0.954, 0.641]}, {"category": "bicycle", "bbox": [0.648, 0.68, 0.714, 0.747]}, {"category": "car", "bbox": [0.837, 0.619, 0.88, 0.659]}, {"category": "potted plant", "bbox": [0.017, 0.197, 0.443, 0.686]}]}
+ {"id": "000000225738", "image": "COCO_val2014_000000225738.jpg", "captions": ["A group of giraffes standing up in their natural habitat.", "A group of giraffe standing in a grass field.", "A group of four giraffes near the same tree.", "there are four giraffes standing among some dry brush", "A herd of giraffe standing on top of a grass field."], "instances": [{"category": "giraffe", "bbox": [0.648, 0.231, 0.855, 0.915]}, {"category": "giraffe", "bbox": [0.33, 0.136, 0.521, 0.93]}, {"category": "giraffe", "bbox": [0.406, 0.261, 0.515, 1.0]}, {"category": "giraffe", "bbox": [0.347, 0.194, 0.583, 0.922]}]}
+ {"id": "000000109532", "image": "COCO_val2014_000000109532.jpg", "captions": ["An adorable husky dog sleeping in a dog bed next to a fan.", "A dark room with a dog sleeping on a dog bed.", "A dog is sleeping in a dark room.", "a large dog laying in a dog bed in a living room", "A dog sleeping on a dog bed in a room."], "instances": [{"category": "dog", "bbox": [0.426, 0.661, 0.582, 0.925]}, {"category": "potted plant", "bbox": [0.603, 0.261, 0.781, 0.613]}, {"category": "chair", "bbox": [0.67, 0.515, 0.899, 0.801]}, {"category": "potted plant", "bbox": [0.671, 0.439, 0.763, 0.612]}, {"category": "chair", "bbox": [0.852, 0.653, 0.948, 0.818]}]}
+ {"id": "000000118606", "image": "COCO_val2014_000000118606.jpg", "captions": ["A man riding skis on top of a rail.", "a person riding a pair of skis on a rail", "Someone on a pair of skis on a ramp at the ski slope", "Person with skis in the air above the snow.", "A man performing a trick on a rail while skiing."], "instances": [{"category": "person", "bbox": [0.444, 0.361, 0.537, 0.633]}, {"category": "skis", "bbox": [0.413, 0.554, 0.539, 0.664]}, {"category": "person", "bbox": [0.342, 0.585, 0.352, 0.62]}, {"category": "person", "bbox": [0.439, 0.565, 0.446, 0.58]}]}
+ {"id": "000000385873", "image": "COCO_val2014_000000385873.jpg", "captions": ["Three pizzas sitting next to each other in boxes.", "Two smaller pizzas sit beside a large pizza topped with tortilla chips.", "Three pizzas inside their delivery boxes, one with two side orders of sauce.", "One pizza is larger than two other pizzas.", "Three pizza boxes with pizza in them are open."], "instances": [{"category": "bowl", "bbox": [0.634, 0.624, 0.736, 0.752]}, {"category": "pizza", "bbox": [0.3, 0.382, 0.615, 0.733]}, {"category": "pizza", "bbox": [0.0, 0.4, 0.287, 0.745]}, {"category": "pizza", "bbox": [0.624, 0.279, 0.999, 0.753]}, {"category": "bowl", "bbox": [0.94, 0.247, 1.0, 0.352]}]}
+ {"id": "000000092109", "image": "COCO_val2014_000000092109.jpg", "captions": ["A giraffe's head is pictured in this clear, colorful photo.", "A giraffe is standing tall in the middle of several bright green trees", "The face of a giraffe looking to the side.", "the close up head shot of a giraffe", "this is a giraffe chewing on some leaves"], "instances": [{"category": "giraffe", "bbox": [0.236, 0.122, 1.0, 0.987]}]}
+ {"id": "000000163076", "image": "COCO_val2014_000000163076.jpg", "captions": ["There's an outdoor dining area featuring a fountain.", "A table sitting next to a water fountain covered by an umbrella.", "An empty restaurant patio with tables and umbrellas.", "An outdoor restaurant with a fountain at night", "A fountain bubbles in the plaza of an outdoor cafe."], "instances": [{"category": "umbrella", "bbox": [0.064, 0.069, 0.95, 0.844]}, {"category": "chair", "bbox": [0.198, 0.574, 0.355, 0.704]}, {"category": "chair", "bbox": [0.42, 0.571, 0.55, 0.738]}, {"category": "dining table", "bbox": [0.066, 0.741, 0.766, 0.925]}, {"category": "dining table", "bbox": [0.059, 0.584, 0.27, 0.659]}, {"category": "chair", "bbox": [0.432, 0.567, 0.52, 0.624]}, {"category": "chair", "bbox": [0.433, 0.555, 0.504, 0.6]}, {"category": "chair", "bbox": [0.109, 0.673, 0.374, 0.796]}]}
+ {"id": "000000560371", "image": "COCO_val2014_000000560371.jpg", "captions": ["Street signs from the corner of 8th ave. and 22 3/4 st.", "A two way street sign with one sign that changes from one name to another.", "A street sign is pointing towards 8th avenue and the other is pointing towards 22 3/4 street in the middle of the forest.", "A street sign standing in front of some trees.", "Peculiar street sign showing intersection of 23 3/4 St and 8th Ave/CTH D."], "instances": []}
+ {"id": "000000367571", "image": "COCO_val2014_000000367571.jpg", "captions": ["A couple of different doughnuts in a box", "There are four donuts in a box, and some are cake donuts and a doughnut with nuts and coconut on top.", "A box of glazed doughnuts on a table.", "Three donuts with toppings on them sitting inside a box.", "A box that is filled with different kinds of doughnuts."], "instances": [{"category": "donut", "bbox": [0.412, 0.335, 0.711, 0.681]}, {"category": "donut", "bbox": [0.093, 0.493, 0.486, 0.922]}, {"category": "donut", "bbox": [0.713, 0.423, 0.957, 0.874]}, {"category": "donut", "bbox": [0.13, 0.331, 0.397, 0.55]}]}
+ {"id": "000000580197", "image": "COCO_val2014_000000580197.jpg", "captions": ["Two men in bow ties standing next to steel rafter.", "Several men in suits talking together in a room.", "An older man in a tuxedo standing next to a younger man in a tuxedo wearing glasses.", "Two men wearing tuxedos glance at each other.", "Older man in tuxedo sitting next to another younger man in tuxedo."], "instances": [{"category": "tie", "bbox": [0.914, 0.46, 0.984, 0.512]}, {"category": "person", "bbox": [0.297, 0.638, 0.71, 0.989]}, {"category": "person", "bbox": [0.77, 0.177, 1.0, 0.971]}, {"category": "tie", "bbox": [0.281, 0.481, 0.368, 0.519]}, {"category": "person", "bbox": [0.103, 0.204, 0.497, 1.0]}]}
+ {"id": "000000506095", "image": "COCO_val2014_000000506095.jpg", "captions": ["A cat is staring at a laptop computer.", "a cat on a desk with a laptop and a mouse", "A cat that is sitting at a desk next to a laptop.", "A kitten sitting on a laptop computer sitting on top of a wooden desk.", "A kitten sits facing an open black laptop."], "instances": [{"category": "cat", "bbox": [0.658, 0.207, 1.0, 0.754]}, {"category": "laptop", "bbox": [0.108, 0.135, 0.766, 0.69]}, {"category": "book", "bbox": [0.836, 0.239, 0.954, 0.273]}, {"category": "book", "bbox": [0.0, 0.556, 0.128, 0.685]}, {"category": "book", "bbox": [0.039, 0.574, 0.257, 0.691]}, {"category": "book", "bbox": [0.825, 0.214, 0.962, 0.254]}, {"category": "book", "bbox": [0.892, 0.275, 0.958, 0.308]}, {"category": "book", "bbox": [0.922, 0.318, 0.986, 0.353]}, {"category": "book", "bbox": [0.87, 0.267, 0.951, 0.291]}, {"category": "book", "bbox": [0.949, 0.102, 0.976, 0.114]}, {"category": "book", "bbox": [0.936, 0.161, 0.958, 0.168]}]}
+ {"id": "000000024996", "image": "COCO_val2014_000000024996.jpg", "captions": ["A bathroom with a glass door and a sink.", "A blue lined bathroom with an open glass door.", "A nice bathroom with a sink, toilet, and tiled shower.", "A bathroom that is clean and shiny in the day.", "a bathroom with a sink and a mirror and a window"], "instances": [{"category": "toilet", "bbox": [0.842, 0.934, 0.95, 1.0]}, {"category": "sink", "bbox": [0.506, 0.724, 0.683, 0.834]}]}
+ {"id": "000000457882", "image": "COCO_val2014_000000457882.jpg", "captions": ["a girl in a bikini and a brown and white dog and a few other people", "A woman with a swimsuit on sitting with a dog.", "A woman is sitting with a dog on her lap.", "A dog sitting next to a woman in her swimsuit.", "WOMAN SITTING WITH HER DOG, AND OTHER WOMEN ARE AROUND"], "instances": [{"category": "dog", "bbox": [0.202, 0.409, 0.54, 0.81]}, {"category": "dog", "bbox": [0.61, 0.428, 0.729, 0.723]}, {"category": "boat", "bbox": [0.003, 0.705, 0.939, 0.974]}, {"category": "person", "bbox": [0.236, 0.001, 0.558, 0.784]}, {"category": "person", "bbox": [0.681, 0.001, 0.957, 0.798]}, {"category": "person", "bbox": [0.849, 0.478, 1.0, 0.946]}, {"category": "person", "bbox": [0.345, 0.187, 0.634, 0.828]}, {"category": "person", "bbox": [0.033, 0.345, 0.109, 0.434]}]}
+ {"id": "000000081552", "image": "COCO_val2014_000000081552.jpg", "captions": ["A cat sitting and curled up on a red couch", "A cat laying on a red couch sleeping.", "a tan and black cat curled up asleep on a red velvet seat", "A cat is curled up on a red sofa.", "Cat curled up, sleeping on a red plush couch."], "instances": [{"category": "cat", "bbox": [0.412, 0.237, 0.634, 0.482]}, {"category": "couch", "bbox": [0.003, 0.005, 1.0, 0.99]}]}
+ {"id": "000000273450", "image": "COCO_val2014_000000273450.jpg", "captions": ["A person flipping of a parking meter on the side of a road.", "A man holds up his middle finger to a parking meter.", "Person giving the middle finger to a parking meter.", "a black silver white blue red an orange parking meter and a hand flipping it off", "A person is flipping off a parking meter."], "instances": [{"category": "person", "bbox": [0.0, 0.475, 0.565, 0.987]}, {"category": "car", "bbox": [0.0, 0.0, 0.531, 0.734]}, {"category": "parking meter", "bbox": [0.0, 0.0, 1.0, 0.987]}]}
+ {"id": "000000203879", "image": "COCO_val2014_000000203879.jpg", "captions": ["There is a small cellphone displayed between a set of ear buds and two paper weights.", "a cell phone lays next to some diamonds", "a close up of a cell phone on a table near earbuds", "A cell phone sits on a table next to some jewels.", "A cell phone, ear buds, and two jewels laying near each other."], "instances": [{"category": "cell phone", "bbox": [0.322, 0.233, 0.62, 0.79]}]}
+ {"id": "000000346875", "image": "COCO_val2014_000000346875.jpg", "captions": ["two zebras in a field near one another", "A couple of zebra walking across a green field.", "Two zebra are walking near a gravel road.", "two zebras in a green field of grass and some trees", "A zebra follows another zebra through a park."], "instances": [{"category": "zebra", "bbox": [0.591, 0.263, 0.82, 0.466]}, {"category": "zebra", "bbox": [0.293, 0.243, 0.561, 0.45]}]}
+ {"id": "000000525439", "image": "COCO_val2014_000000525439.jpg", "captions": ["a man stands in front of a flipped skate boarder", "A man standing next to a skateboard that is laying on the ground wheels pointed up.", "Skateboard laying upside down on cement with someone standing next to it.", "A boy in camo shorts stands before an overturned skateboard.", "a person with an upside down skate board"], "instances": [{"category": "person", "bbox": [0.307, 0.001, 0.63, 0.739]}, {"category": "skateboard", "bbox": [0.0, 0.592, 0.626, 0.969]}]}
+ {"id": "000000304749", "image": "COCO_val2014_000000304749.jpg", "captions": ["The woman is taking a picture in the bathroom mirror.", "A picture of a woman in a mirror.", "A woman's midsection reflected in a round mirror.", "A circular mirror reflecting a woman's stomach in turquoise shirt.", "A selfie taken of a person from the neck down."], "instances": [{"category": "person", "bbox": [0.092, 0.001, 0.646, 0.496]}]}
+ {"id": "000000323760", "image": "COCO_val2014_000000323760.jpg", "captions": ["A toilet is shown in a bare room.", "A ugly bathroom with a section of the wall missing.", "A toilet in a stripped bathroom with studs, bricks and plaster showing", "A bathroom with no walls and a toilet bowl", "A white toilet next to some torn out walls."], "instances": [{"category": "toilet", "bbox": [0.167, 0.585, 0.714, 1.0]}]}
+ {"id": "000000066144", "image": "COCO_val2014_000000066144.jpg", "captions": ["A woman standing in front of window next to a bug and a stop sign.", "A car parked on the street next to a tree and stop sign.", "A lone Volkswagen is parked by a stop sign.", "A window view of a small car near a street stop sign.", "An old VW Bug standing at a stop sign."], "instances": [{"category": "stop sign", "bbox": [0.501, 0.328, 0.569, 0.428]}, {"category": "car", "bbox": [0.242, 0.488, 0.56, 0.726]}, {"category": "car", "bbox": [0.279, 0.325, 0.33, 0.363]}, {"category": "car", "bbox": [0.153, 0.333, 0.29, 0.405]}, {"category": "car", "bbox": [0.11, 0.339, 0.177, 0.373]}, {"category": "car", "bbox": [0.0, 0.654, 0.082, 0.826]}, {"category": "car", "bbox": [0.0, 0.322, 0.064, 0.364]}, {"category": "car", "bbox": [0.451, 0.333, 0.51, 0.392]}]}
+ {"id": "000000455772", "image": "COCO_val2014_000000455772.jpg", "captions": ["A person in a field jumping to catch a Frisbee.", "A guy jumping to catch a frisbee in mid-air.", "A person that is trying to get a frisbee.", "Nice reach, but the Frisbee flies on, victorious.", "A man playing frisbee in a grassy yard."], "instances": [{"category": "car", "bbox": [0.148, 0.339, 0.201, 0.476]}, {"category": "car", "bbox": [0.376, 0.396, 0.424, 0.476]}, {"category": "person", "bbox": [0.547, 0.122, 0.698, 0.904]}, {"category": "frisbee", "bbox": [0.479, 0.154, 0.555, 0.231]}, {"category": "car", "bbox": [0.001, 0.299, 0.085, 0.394]}]}
+ {"id": "000000511117", "image": "COCO_val2014_000000511117.jpg", "captions": ["A couple of kids standing on top of a grass covered field.", "A little boy wearing a baseball uniform stands by a little girl.", "A young boy in a baseball uniform and a young girl are standing in front of a chain link fence.", "A little boy and girl standing on a baseball field. The boy has a uniform on.", "A young baseball player is standing next to a young girl."], "instances": [{"category": "person", "bbox": [0.514, 0.178, 0.776, 0.774]}, {"category": "baseball glove", "bbox": [0.468, 0.462, 0.593, 0.609]}, {"category": "person", "bbox": [0.174, 0.051, 0.598, 0.839]}, {"category": "bench", "bbox": [0.558, 0.125, 1.0, 0.315]}]}
+ {"id": "000000207151", "image": "COCO_val2014_000000207151.jpg", "captions": ["A vegetarian pizza is half eaten on a pizza holder.", "A couple of pieces of pizza with vegetable slices on them.", "A wooden pan serving tray with a pizza on it.", "A pizza on a cutting board is half gone.", "A Pizza is nearly finished with only three pieces left."], "instances": [{"category": "bottle", "bbox": [0.001, 0.001, 0.121, 0.231]}, {"category": "cup", "bbox": [0.0, 0.002, 0.121, 0.238]}, {"category": "pizza", "bbox": [0.17, 0.472, 0.526, 0.82]}, {"category": "pizza", "bbox": [0.398, 0.106, 0.962, 0.679]}, {"category": "dining table", "bbox": [0.0, 0.001, 1.0, 0.988]}]}
+ {"id": "000000431165", "image": "COCO_val2014_000000431165.jpg", "captions": ["A baby elephant standing in front of a brick building.", "An elephant is standing near a dirt mount in an exhibit.", "Grey elephant standing next to a large sand dune in a pen.", "An elephant standing alone inside of an enclosure.", "The baby elephant is alone in the pen."], "instances": [{"category": "elephant", "bbox": [0.303, 0.399, 0.638, 0.78]}]}
+ {"id": "000000378545", "image": "COCO_val2014_000000378545.jpg", "captions": ["A pole that has a clock on top of it.", "A clock mounted on an outdoor post with Roman numerals.", "a clock on a pole saying it is 12:45", "An ornamental standing clock is at the foreground of a row of houses.", "A black and gold clock on a pole in front of a building."], "instances": [{"category": "clock", "bbox": [0.216, 0.249, 0.749, 0.658]}]}
+ {"id": "000000555904", "image": "COCO_val2014_000000555904.jpg", "captions": ["A man sitting at a bar filled with liquor.", "People sitting a a take near several bottles of wine on shelves.", "Several people are sitting at a table drinking.", "Several people in a bar sitting at a long table.", "People eating in a restaurant near wine bottles."], "instances": [{"category": "dining table", "bbox": [0.123, 0.663, 0.317, 0.811]}, {"category": "person", "bbox": [0.715, 0.239, 1.0, 0.998]}, {"category": "person", "bbox": [0.142, 0.528, 0.281, 0.742]}, {"category": "person", "bbox": [0.529, 0.53, 0.606, 0.69]}, {"category": "person", "bbox": [0.705, 0.518, 0.796, 0.673]}, {"category": "wine glass", "bbox": [0.247, 0.669, 0.27, 0.718]}, {"category": "person", "bbox": [0.281, 0.524, 0.534, 1.0]}, {"category": "bottle", "bbox": [0.168, 0.346, 0.189, 0.425]}, {"category": "bottle", "bbox": [0.379, 0.264, 0.431, 0.433]}, {"category": "bottle", "bbox": [0.252, 0.313, 0.277, 0.429]}, {"category": "bottle", "bbox": [0.294, 0.295, 0.326, 0.43]}, {"category": "bottle", "bbox": [0.589, 0.35, 0.613, 0.444]}, {"category": "bottle", "bbox": [0.433, 0.281, 0.473, 0.437]}, {"category": "bottle", "bbox": [0.478, 0.289, 0.513, 0.44]}, {"category": "wine glass", "bbox": [0.688, 0.615, 0.709, 0.69]}, {"category": "cup", "bbox": [0.589, 0.647, 0.612, 0.693]}, {"category": "person", "bbox": [0.732, 0.356, 0.953, 0.806]}, {"category": "bottle", "bbox": [0.555, 0.337, 0.585, 0.438]}, {"category": "bottle", "bbox": [0.337, 0.29, 0.378, 0.432]}, {"category": "bottle", "bbox": [0.21, 0.333, 0.232, 0.426]}, {"category": "bottle", "bbox": [0.134, 0.36, 0.148, 0.422]}, {"category": "bottle", "bbox": [0.516, 0.312, 0.557, 0.439]}, {"category": "cup", "bbox": [0.231, 0.718, 0.26, 0.763]}, {"category": "chair", "bbox": [0.517, 0.828, 0.65, 0.999]}, {"category": "chair", "bbox": [0.643, 0.804, 0.738, 0.841]}, {"category": "chair", "bbox": [0.347, 0.908, 0.519, 1.0]}, {"category": "chair", "bbox": [0.64, 0.806, 0.74, 0.998]}, {"category": "cup", "bbox": [0.205, 0.692, 0.232, 0.767]}, {"category": "dining table", "bbox": [0.536, 0.676, 0.743, 0.838]}, {"category": "person", "bbox": [0.002, 0.501, 0.263, 0.987]}, {"category": "bottle", "bbox": [0.531, 0.461, 0.542, 0.526]}, {"category": "bottle", "bbox": [0.237, 0.354, 0.702, 0.629]}]}
+ {"id": "000000415393", "image": "COCO_val2014_000000415393.jpg", "captions": ["a man on a skate board looks like he is falling", "A man does a skateboard trick on a skateboard ramp", "Guy falling off a skateboard in a room.", "A man riding a skateboard on top of a table.", "a man skating on part of a ramp with his skateboard"], "instances": [{"category": "person", "bbox": [0.361, 0.016, 0.809, 0.888]}, {"category": "skateboard", "bbox": [0.606, 0.809, 0.889, 0.901]}, {"category": "person", "bbox": [0.479, 0.091, 0.576, 0.386]}, {"category": "person", "bbox": [0.047, 0.441, 0.197, 0.759]}, {"category": "person", "bbox": [0.038, 0.453, 0.076, 0.545]}, {"category": "person", "bbox": [0.249, 0.307, 0.311, 0.591]}]}
+ {"id": "000000161011", "image": "COCO_val2014_000000161011.jpg", "captions": ["Three skiers posing for a picture on the slope.", "Three skiers pause for a photo at the top of a mountain.", "Three people standing on a mountain taking a picture as they ski.", "A woman and two men on skis on a snowy hillside surrounded by trees", "Three skiers have stopped to pose for a picture."], "instances": [{"category": "person", "bbox": [0.36, 0.321, 0.509, 0.82]}, {"category": "person", "bbox": [0.179, 0.281, 0.349, 0.795]}, {"category": "person", "bbox": [0.611, 0.292, 0.751, 0.809]}, {"category": "skis", "bbox": [0.595, 0.743, 0.732, 0.961]}, {"category": "skis", "bbox": [0.341, 0.724, 0.621, 0.907]}, {"category": "skis", "bbox": [0.212, 0.705, 0.398, 0.905]}]}
+ {"id": "000000284296", "image": "COCO_val2014_000000284296.jpg", "captions": ["Three giraffe's leaning over to get a sip of water.", "an image of a herd of giraffes in the water", "three giraffes banding down to drink water with trees in the background", "Three giraffe drinking from a pond with brush in back.", "Giraffes leaning down to drink at a watering hole"], "instances": [{"category": "giraffe", "bbox": [0.624, 0.387, 0.822, 0.635]}, {"category": "giraffe", "bbox": [0.4, 0.326, 0.561, 0.58]}, {"category": "giraffe", "bbox": [0.152, 0.291, 0.343, 0.551]}]}
+ {"id": "000000056013", "image": "COCO_val2014_000000056013.jpg", "captions": ["a number of luggage bags on a cart in a lobby", "Wheeled cart with luggage at lobby of commercial business.", "Trolley used for transporting personal luggage to guests rooms.", "A luggage cart topped with lots of luggage.", "a cart filled with suitcases and bags"], "instances": [{"category": "backpack", "bbox": [0.276, 0.52, 0.456, 0.678]}, {"category": "suitcase", "bbox": [0.41, 0.58, 0.597, 0.827]}, {"category": "suitcase", "bbox": [0.173, 0.645, 0.363, 0.836]}, {"category": "person", "bbox": [0.959, 0.297, 1.0, 0.478]}, {"category": "suitcase", "bbox": [0.526, 0.519, 0.712, 0.706]}, {"category": "person", "bbox": [0.762, 0.253, 0.871, 0.46]}, {"category": "backpack", "bbox": [0.517, 0.514, 0.694, 0.698]}, {"category": "handbag", "bbox": [0.316, 0.181, 0.431, 0.426]}, {"category": "suitcase", "bbox": [0.747, 0.453, 0.858, 0.557]}]}
+ {"id": "000000293505", "image": "COCO_val2014_000000293505.jpg", "captions": ["A person on a motor bike next to a cow.", "A woman riding a motorcycle down a dirt road.", "there is a woman riding a scooter down a dirt road", "A woman on a moped, two men and animals walking down the road.", "A woman on a motorcycle is next to a man walking a dog along with other people going down a dirt road."], "instances": [{"category": "cow", "bbox": [0.602, 0.472, 0.721, 0.816]}, {"category": "motorcycle", "bbox": [0.402, 0.512, 0.516, 0.788]}, {"category": "person", "bbox": [0.408, 0.4, 0.514, 0.639]}, {"category": "person", "bbox": [0.754, 0.301, 1.0, 1.0]}, {"category": "person", "bbox": [0.705, 0.415, 0.789, 0.714]}, {"category": "cow", "bbox": [0.347, 0.44, 0.373, 0.509]}, {"category": "cow", "bbox": [0.361, 0.436, 0.381, 0.501]}]}
+ {"id": "000000305873", "image": "COCO_val2014_000000305873.jpg", "captions": ["A little girl holding a red black dotted umbrella.", "A little girl with rain boots and a rain jacket on and an open umbrella to match her jacket.", "a little girl holding onto a lady bug pattern umbrella", "The child wears a labybug rain coat with a matching umbrella.", "A little girl wearing a ladybug raincoat and green rubber boots holding a ladybug umbrella"], "instances": [{"category": "umbrella", "bbox": [0.246, 0.002, 0.992, 0.415]}, {"category": "person", "bbox": [0.35, 0.132, 0.699, 0.791]}, {"category": "car", "bbox": [0.614, 0.0, 1.0, 0.465]}]}
+ {"id": "000000034096", "image": "COCO_val2014_000000034096.jpg", "captions": ["A house being built with lots of wood.", "A big pile of building material is placed on the floor in the wooden structure.", "A partially-built house with wooden studs and staircase in view.", "A house full of wood getting built at the moment.", "The beginning stages of a home still being made."], "instances": [{"category": "bed", "bbox": [0.505, 0.42, 0.721, 0.59]}, {"category": "tv", "bbox": [0.192, 0.441, 0.335, 0.606]}]}
+ {"id": "000000165257", "image": "COCO_val2014_000000165257.jpg", "captions": ["A large black counter top sitting next to a sink.", "a clean kitchen counter with a clean sink", "A kitchen with a sink, dishwasher and some boxes on the counter.", "A kitchen with a sink, dishwasher and boxes on the counter.", "a black counter on a wood cabinet in a kitchen", "a new kitchen cabinet with a sink being installed"], "instances": [{"category": "sink", "bbox": [0.513, 0.243, 0.718, 0.314]}]}
+ {"id": "000000431026", "image": "COCO_val2014_000000431026.jpg", "captions": ["a street sign on a city street near some tall bushes", "street signs on a metal pole lining a sidewalk lined with shrubbery.", "a large hedge of bushes on a corner near a street sign.", "Two street signs on sidewalk next to bushes and trees.", "Street signs along a well manicured street with large houses."], "instances": []}
+ {"id": "000000524575", "image": "COCO_val2014_000000524575.jpg", "captions": ["Three giraffe and a wildebeest in a field.", "A moose and several giraffes are grazing in the field.", "Zebras in the wild with a wildebeest behind them", "Two giraffe and a ox standing in a field eating grass.", "Giraffes and other safari animals graze in a sunlit field."], "instances": [{"category": "cow", "bbox": [0.46, 0.716, 0.643, 0.999]}, {"category": "giraffe", "bbox": [0.285, 0.5, 0.401, 0.826]}, {"category": "giraffe", "bbox": [0.083, 0.554, 0.179, 0.821]}, {"category": "giraffe", "bbox": [0.887, 0.481, 0.968, 0.715]}]}
+ {"id": "000000326550", "image": "COCO_val2014_000000326550.jpg", "captions": ["Black and white photograph of a person holding a surfboard by water.", "A person with a surfboard standing next to the water.", "A surfer stands on the rocks watching a wave crash.", "A man standing on a beach holding a surfboard.", "a person looking at the waves ready to surf"], "instances": [{"category": "person", "bbox": [0.327, 0.461, 0.492, 0.897]}, {"category": "surfboard", "bbox": [0.282, 0.56, 0.606, 0.741]}, {"category": "person", "bbox": [0.924, 0.352, 0.933, 0.362]}, {"category": "person", "bbox": [0.912, 0.348, 0.919, 0.36]}]}
+ {"id": "000000018476", "image": "COCO_val2014_000000018476.jpg", "captions": ["A tie that is sitting on top of a shirt.", "This photograph appears to be looking truly wonderful.", "a uniform complete with shoes laying on a bed", "Suit laid out with a red tie, white shirt and black shoes.", "a white shirt a red tie and some black shoes"], "instances": [{"category": "tie", "bbox": [0.457, 0.09, 0.853, 0.984]}, {"category": "bed", "bbox": [0.005, 0.005, 1.0, 0.379]}]}
+ {"id": "000000480652", "image": "COCO_val2014_000000480652.jpg", "captions": ["These suitcases are sitting next to a chair.", "An assortment of luggage bags stacked by a kitchen chair.", "A stack of luggage by a chair and table.", "a table and chair with several pieces of luggage nearby", "A pile of luggage sitting on the floor."], "instances": [{"category": "chair", "bbox": [0.483, 0.192, 1.0, 0.769]}, {"category": "backpack", "bbox": [0.433, 0.429, 0.742, 0.856]}, {"category": "suitcase", "bbox": [0.059, 0.414, 0.453, 0.841]}, {"category": "handbag", "bbox": [0.19, 0.184, 0.779, 0.475]}, {"category": "suitcase", "bbox": [0.175, 0.204, 0.583, 0.462]}]}
+ {"id": "000000012748", "image": "COCO_val2014_000000012748.jpg", "captions": ["A man and child next to a horse.", "a little boy touching the nose of a brown horse", "A man holding a baby whose petting a horse.", "a man letting his baby pet a horse", "man holding a baby and petting a horse"], "instances": [{"category": "horse", "bbox": [0.003, 0.079, 0.504, 0.868]}, {"category": "person", "bbox": [0.452, 0.294, 1.0, 0.989]}, {"category": "person", "bbox": [0.46, 0.217, 1.0, 0.988]}]}
+ {"id": "000000247840", "image": "COCO_val2014_000000247840.jpg", "captions": ["Large group of people standing outside a restaurant together.", "A dairy queen has people standing outside waiting", "an image of people standing outside and ice cream store", "Several people are lined up outside of a store.", "The front of a Dairy Queen restaurant with people entering the side."], "instances": [{"category": "fire hydrant", "bbox": [0.774, 0.674, 0.83, 0.807]}, {"category": "person", "bbox": [0.741, 0.465, 0.824, 0.755]}, {"category": "person", "bbox": [0.806, 0.471, 0.839, 0.722]}, {"category": "person", "bbox": [0.831, 0.499, 0.866, 0.726]}, {"category": "bench", "bbox": [0.061, 0.69, 0.219, 0.768]}, {"category": "handbag", "bbox": [0.859, 0.558, 0.877, 0.603]}, {"category": "person", "bbox": [0.719, 0.504, 0.75, 0.626]}, {"category": "potted plant", "bbox": [0.7, 0.648, 0.764, 0.743]}, {"category": "handbag", "bbox": [0.827, 0.548, 0.837, 0.577]}, {"category": "sandwich", "bbox": [0.359, 0.618, 0.417, 0.694]}]}
+ {"id": "000000399452", "image": "COCO_val2014_000000399452.jpg", "captions": ["a sandwhich sitting on a plate next to a glass of tea, bowl of soup", "a sandwich on a white plate a drink on a brown table", "A sandwich and chips sit on a white plate.", "a large plate of food with a glass of soda by it", "A sandwich sitting on top of a white plate next to a cup of coffee."], "instances": [{"category": "sandwich", "bbox": [0.175, 0.326, 0.605, 0.71]}, {"category": "cup", "bbox": [0.504, 0.024, 0.687, 0.419]}, {"category": "knife", "bbox": [0.742, 0.283, 0.857, 0.376]}, {"category": "spoon", "bbox": [0.618, 0.46, 0.797, 0.809]}, {"category": "fork", "bbox": [0.684, 0.254, 0.805, 0.395]}, {"category": "bowl", "bbox": [0.782, 0.366, 1.0, 0.62]}, {"category": "chair", "bbox": [0.202, 0.0, 0.671, 0.148]}, {"category": "dining table", "bbox": [0.002, 0.126, 0.996, 0.987]}]}
+ {"id": "000000515716", "image": "COCO_val2014_000000515716.jpg", "captions": ["A couple of women standing on either side of a man wearing glasses.", "Two women and a man are holding glasses up at a wine tasting.", "Three young adults holding wine glasses while standing at a bar.", "A group of people sit holding glasses and smiling at a table with several bottles.", "A group of people at a celebration having a taste of wine."], "instances": [{"category": "bottle", "bbox": [0.529, 0.604, 0.637, 0.908]}, {"category": "bottle", "bbox": [0.379, 0.398, 0.481, 0.892]}, {"category": "bottle", "bbox": [0.942, 0.464, 0.988, 0.653]}, {"category": "person", "bbox": [0.0, 0.126, 0.136, 0.811]}, {"category": "person", "bbox": [0.05, 0.093, 0.211, 0.471]}, {"category": "person", "bbox": [0.401, 0.031, 0.678, 0.683]}, {"category": "person", "bbox": [0.617, 0.191, 0.94, 0.858]}, {"category": "person", "bbox": [0.723, 0.098, 0.947, 0.564]}, {"category": "wine glass", "bbox": [0.634, 0.434, 0.697, 0.628]}, {"category": "wine glass", "bbox": [0.285, 0.346, 0.372, 0.558]}, {"category": "wine glass", "bbox": [0.522, 0.422, 0.583, 0.544]}, {"category": "handbag", "bbox": [0.704, 0.601, 1.0, 0.916]}, {"category": "person", "bbox": [0.944, 0.319, 0.999, 0.604]}, {"category": "bottle", "bbox": [0.921, 0.46, 0.953, 0.636]}, {"category": "person", "bbox": [0.116, 0.171, 0.41, 0.829]}]}
+ {"id": "000000116173", "image": "COCO_val2014_000000116173.jpg", "captions": ["The boy is on his surfboard in the water riding it.", "a young boy riding a boogie board in the water", "A boy riding surf board in the ocean.", "A young boy is riding a surfboard on a small wave.", "A young boy is surfing in the ocean."], "instances": [{"category": "person", "bbox": [0.485, 0.238, 0.702, 0.821]}, {"category": "person", "bbox": [0.866, 0.223, 0.921, 0.29]}, {"category": "person", "bbox": [0.752, 0.146, 0.775, 0.188]}, {"category": "surfboard", "bbox": [0.239, 0.758, 0.782, 0.846]}, {"category": "surfboard", "bbox": [0.853, 0.277, 0.981, 0.29]}, {"category": "surfboard", "bbox": [0.727, 0.169, 0.801, 0.198]}, {"category": "person", "bbox": [0.637, 0.194, 0.677, 0.261]}]}
+ {"id": "000000186013", "image": "COCO_val2014_000000186013.jpg", "captions": ["A beach scene includes many different kites flying in a cloudy sky.", "Kites being flown at the beach at twilight.", "A beach with flags in the ground and kites overhead in the sky.", "A beach with rows of flags in the sand and kites flying overhead.", "A beach filled with kites and wind sails next to the ocean."], "instances": [{"category": "kite", "bbox": [0.174, 0.4, 0.351, 0.483]}, {"category": "kite", "bbox": [0.144, 0.13, 0.273, 0.17]}, {"category": "kite", "bbox": [0.236, 0.269, 0.268, 0.294]}, {"category": "kite", "bbox": [0.464, 0.204, 0.598, 0.271]}, {"category": "kite", "bbox": [0.61, 0.304, 0.659, 0.342]}, {"category": "kite", "bbox": [0.545, 0.435, 0.565, 0.452]}, {"category": "kite", "bbox": [0.027, 0.558, 0.151, 0.59]}, {"category": "kite", "bbox": [0.93, 0.429, 0.973, 0.536]}, {"category": "kite", "bbox": [0.684, 0.36, 0.697, 0.374]}, {"category": "surfboard", "bbox": [0.393, 0.627, 0.446, 0.934]}, {"category": "person", "bbox": [0.959, 0.685, 0.984, 0.713]}, {"category": "person", "bbox": [0.919, 0.681, 0.94, 0.725]}, {"category": "person", "bbox": [0.8, 0.597, 0.805, 0.61]}, {"category": "person", "bbox": [0.079, 0.928, 0.116, 0.975]}, {"category": "kite", "bbox": [0.743, 0.307, 0.755, 0.319]}, {"category": "kite", "bbox": [0.78, 0.322, 0.795, 0.335]}, {"category": "kite", "bbox": [0.536, 0.526, 0.597, 0.617]}, {"category": "person", "bbox": [0.941, 0.694, 0.961, 0.726]}, {"category": "kite", "bbox": [0.575, 0.446, 0.594, 0.471]}]}
+ {"id": "000000015029", "image": "COCO_val2014_000000015029.jpg", "captions": ["A man holding a white frisbee standing on top of a field.", "A man is playing frisbee next to a tent.", "Guy at the park holding a frisbee with people in the back under a tent", "A man is holding a Frisbee standing in the grass.", "Young adult male holding a frisbee at an event."], "instances": [{"category": "frisbee", "bbox": [0.138, 0.359, 0.215, 0.587]}, {"category": "person", "bbox": [0.16, 0.002, 0.726, 0.995]}, {"category": "person", "bbox": [0.81, 0.73, 0.852, 0.825]}, {"category": "person", "bbox": [0.786, 0.749, 0.833, 0.814]}, {"category": "person", "bbox": [0.847, 0.743, 0.89, 0.804]}, {"category": "person", "bbox": [0.614, 0.749, 0.706, 0.936]}]}
+ {"id": "000000500565", "image": "COCO_val2014_000000500565.jpg", "captions": ["A woman holding a child wrapped in a towel brushing her teeth.", "A woman is holding a baby who is wrapped in a towel and holding a toothbrush", "A woman holding a little boy who is brushing his teeth.", "A baby with a toothbrush in his mouth while being held by a woman", "a close up of an adult holding a child brushing their teeth"], "instances": [{"category": "toothbrush", "bbox": [0.586, 0.66, 0.754, 0.821]}, {"category": "person", "bbox": [0.002, 0.007, 0.637, 0.991]}, {"category": "person", "bbox": [0.357, 0.196, 0.998, 0.984]}]}
+ {"id": "000000297323", "image": "COCO_val2014_000000297323.jpg", "captions": ["Two buses are parked against a curb in front of a building.", "Two automobiles parked on the side of a building.", "two tourist buses parked on street in front of old industrial building", "Two unique city buses stopped at a stop sign.", "Buses parked outside by a building and stop sign."], "instances": [{"category": "bus", "bbox": [0.7, 0.711, 0.92, 0.881]}, {"category": "person", "bbox": [0.936, 0.771, 0.972, 0.833]}, {"category": "stop sign", "bbox": [0.237, 0.666, 0.285, 0.728]}, {"category": "bus", "bbox": [0.334, 0.71, 0.678, 0.935]}, {"category": "truck", "bbox": [0.335, 0.72, 0.683, 0.934]}, {"category": "person", "bbox": [0.34, 0.791, 0.367, 0.834]}]}
+ {"id": "000000441147", "image": "COCO_val2014_000000441147.jpg", "captions": ["Two antique suitcases sit stacked one on top of the other.", "Two suitcases are stacked on each other and one is black while the other is brown and yellow.", "a close up of two luggage suit cases stacked on each other", "A stack of antique luggage is displayed with price tags.", "two suitcases made of leather and stacked on top of each other"], "instances": [{"category": "suitcase", "bbox": [0.167, 0.025, 0.989, 0.445]}, {"category": "suitcase", "bbox": [0.002, 0.31, 0.994, 0.996]}]}
+ {"id": "000000353536", "image": "COCO_val2014_000000353536.jpg", "captions": ["A table topped with plates and glasses with eating utensils..", "a fork is laying on a small white plate", "dirty dishes on a table, and a bottle of something.", "a table top with some dishes on top of it", "A table full of dirty dishes is pictured in this image."], "instances": [{"category": "dining table", "bbox": [0.0, 0.007, 0.998, 0.988]}, {"category": "bottle", "bbox": [0.554, 0.002, 0.768, 0.411]}, {"category": "cup", "bbox": [0.372, 0.011, 0.544, 0.427]}, {"category": "fork", "bbox": [0.442, 0.464, 0.818, 0.572]}, {"category": "fork", "bbox": [0.089, 0.233, 0.272, 0.456]}, {"category": "spoon", "bbox": [0.144, 0.218, 0.326, 0.413]}, {"category": "cup", "bbox": [0.688, 0.056, 0.812, 0.361]}]}
+ {"id": "000000416256", "image": "COCO_val2014_000000416256.jpg", "captions": ["A cat laying on the floor next to a keyboard.", "an orange and white cat is laying next to a keyboard and some wires", "A cat is laying next to a computer keyboard.", "a cat laying on a floor next to a keyboard", "A CAT LAYING ON THE FLOOR AMIDST A COMPUTER,SPEAKERS,CORDS"], "instances": [{"category": "cat", "bbox": [0.235, 0.23, 0.737, 0.639]}, {"category": "keyboard", "bbox": [0.243, 0.562, 0.631, 0.836]}, {"category": "keyboard", "bbox": [0.058, 0.33, 0.277, 0.608]}]}
+ {"id": "000000214367", "image": "COCO_val2014_000000214367.jpg", "captions": ["Wood shading on the side of a window with brick siding.", "A tree filled with lots of red fruit near a building.", "By the window outside is a apple tree, where the apples are ready to be picked.", "Some very nice looking red fruity by a window,", "A shuttered window has a fruit tree outside it."], "instances": [{"category": "apple", "bbox": [0.214, 0.112, 0.408, 0.266]}, {"category": "apple", "bbox": [0.472, 0.166, 0.618, 0.293]}, {"category": "apple", "bbox": [0.055, 0.592, 0.172, 0.686]}, {"category": "apple", "bbox": [0.126, 0.661, 0.236, 0.739]}, {"category": "apple", "bbox": [0.52, 0.09, 0.609, 0.143]}, {"category": "apple", "bbox": [0.226, 0.354, 0.285, 0.409]}, {"category": "apple", "bbox": [0.0, 0.698, 0.096, 0.771]}, {"category": "apple", "bbox": [0.001, 0.646, 0.042, 0.713]}, {"category": "apple", "bbox": [0.258, 0.719, 0.329, 0.778]}]}
+ {"id": "000000210299", "image": "COCO_val2014_000000210299.jpg", "captions": ["A little boy riding his bike and wearing a helmet", "A little boy raveling down a road on a bike, with a yellow helmet on.", "The boy wears a helmet while riding his bicycle.", "a small child wearing a helmet and riding a bike", "A little boy wearing a helmet and riding a bike."], "instances": [{"category": "person", "bbox": [0.198, 0.259, 0.399, 0.679]}, {"category": "bicycle", "bbox": [0.213, 0.383, 0.408, 0.835]}]}
+ {"id": "000000088218", "image": "COCO_val2014_000000088218.jpg", "captions": ["Signs proclaim the famous Haight Ashbury intersection and district.", "a pole with street lights, signs and wires attached to it", "A traffic light at the intersection of Haight and Ashbury", "A traffic sign is shown with traffic signs above it.", "The street signs and traffic signal are below wires attached to the pole."], "instances": [{"category": "traffic light", "bbox": [0.443, 0.435, 0.658, 0.721]}]}
+ {"id": "000000020650", "image": "COCO_val2014_000000020650.jpg", "captions": ["Burger with broccoli, pickle, and fork on orange plate", "On a plate is kept a burger and a bowl of broccoli and a fork.", "There is half a sandwich on an orange plate with a pickle and a bowl of broccoli", "A A bowl and a sandwich on an orange plate on a table.", "A plate has a sandwich, broccoli, and a pickle."], "instances": [{"category": "sandwich", "bbox": [0.436, 0.155, 0.805, 0.859]}, {"category": "sandwich", "bbox": [0.311, 0.006, 0.748, 0.293]}, {"category": "fork", "bbox": [0.0, 0.665, 0.578, 0.876]}, {"category": "bowl", "bbox": [0.002, 0.263, 0.487, 0.744]}, {"category": "bowl", "bbox": [0.708, 0.003, 0.828, 0.03]}, {"category": "broccoli", "bbox": [0.185, 0.288, 0.366, 0.546]}, {"category": "broccoli", "bbox": [0.017, 0.344, 0.384, 0.654]}, {"category": "broccoli", "bbox": [0.31, 0.191, 0.466, 0.463]}, {"category": "broccoli", "bbox": [0.104, 0.107, 0.285, 0.342]}, {"category": "broccoli", "bbox": [0.092, 0.276, 0.242, 0.442]}, {"category": "dining table", "bbox": [0.002, 0.0, 0.999, 0.987]}]}
+ {"id": "000000514915", "image": "COCO_val2014_000000514915.jpg", "captions": ["A large black dog laying on a kitchen floor.", "A dog is laying down on the floor in the home.", "Black dog laying down on the kitchen floor next to it's bowls and toy", "A black dog with a red collar laying on a tiled floor.", "A black dog that is laying on the floor."], "instances": [{"category": "dog", "bbox": [0.087, 0.276, 0.812, 0.792]}, {"category": "bowl", "bbox": [0.437, 0.09, 0.533, 0.213]}, {"category": "bowl", "bbox": [0.537, 0.035, 0.665, 0.141]}]}
+ {"id": "000000205183", "image": "COCO_val2014_000000205183.jpg", "captions": ["A duck walking along a paved road next to a patch of grass.", "A close up of a duck walking on a path.", "a duck walks along a cement patch while looking down", "A white duck out of water, walking on the ground.", "A goose standing in the road, looking at the ground."], "instances": [{"category": "bird", "bbox": [0.291, 0.235, 0.859, 0.889]}]}
67
+ {"id": "000000534270", "image": "COCO_val2014_000000534270.jpg", "captions": ["Man and woman with umbrella hats sitting on top of a bridge.", "A couple equipped with umbrella hats taking a break from walking their dog on a bridge on a rainy day.", "Two people in ridiculous looking umbrella hats.", "two people with umbrella hats near one another", "A couple of people wearing umbrella hats next to the ocean."], "instances": [{"category": "dog", "bbox": [0.456, 0.832, 0.6, 0.983]}, {"category": "person", "bbox": [0.433, 0.464, 0.636, 0.975]}, {"category": "person", "bbox": [0.263, 0.321, 0.459, 0.978]}, {"category": "boat", "bbox": [0.912, 0.4, 0.978, 0.433]}, {"category": "boat", "bbox": [0.211, 0.236, 0.478, 0.304]}, {"category": "boat", "bbox": [0.144, 0.328, 0.189, 0.361]}, {"category": "umbrella", "bbox": [0.443, 0.402, 0.607, 0.473]}, {"category": "umbrella", "bbox": [0.325, 0.311, 0.483, 0.432]}, {"category": "umbrella", "bbox": [0.207, 0.738, 0.284, 0.778]}, {"category": "umbrella", "bbox": [0.489, 0.713, 0.649, 0.83]}]}
+ {"id": "000000408439", "image": "COCO_val2014_000000408439.jpg", "captions": ["Cliffs rise on the edge of a placid lake.", "A scenic view of a river with a train on the edge of it in the distance.", "A large lake surrounded by beautiful tree covered mountains.", "a landscape scene with water, mountains and trees", "A train on a waterfront track surrounded by mountains."], "instances": [{"category": "train", "bbox": [0.008, 0.591, 0.562, 0.644]}]}
+ {"id": "000000474253", "image": "COCO_val2014_000000474253.jpg", "captions": ["A man riding on the back of a horse through a river.", "A person is riding a horse through water.", "Horse and rider crossing waterway during competitive event.", "A woman riding a horse splashes through a large puddle.", "A young man riding a horse through some water."], "instances": [{"category": "horse", "bbox": [0.385, 0.235, 0.651, 0.814]}, {"category": "person", "bbox": [0.396, 0.06, 0.576, 0.675]}, {"category": "person", "bbox": [0.29, 0.148, 0.355, 0.333]}, {"category": "person", "bbox": [0.129, 0.163, 0.212, 0.349]}, {"category": "person", "bbox": [0.005, 0.014, 0.038, 0.165]}, {"category": "person", "bbox": [0.144, 0.011, 0.193, 0.155]}, {"category": "person", "bbox": [0.089, 0.007, 0.133, 0.162]}]}
+ {"id": "000000098029", "image": "COCO_val2014_000000098029.jpg", "captions": ["a table with many plates on it with a bread basket", "A table set for four has many foods and fruits on it.", "Several objects displayed on a kitchen table including bread, oranges and plating.", "Several dishes and food items sit on a table.", "An assortment of foods sitting on a round brown table."], "instances": [{"category": "refrigerator", "bbox": [0.013, 0.004, 0.37, 0.317]}, {"category": "bottle", "bbox": [0.467, 0.517, 0.555, 0.638]}, {"category": "bottle", "bbox": [0.602, 0.536, 0.658, 0.609]}, {"category": "chair", "bbox": [0.747, 0.367, 1.0, 0.592]}, {"category": "chair", "bbox": [0.044, 0.368, 0.358, 0.544]}, {"category": "cup", "bbox": [0.296, 0.465, 0.359, 0.54]}, {"category": "cup", "bbox": [0.709, 0.67, 0.782, 0.736]}, {"category": "cup", "bbox": [0.213, 0.684, 0.294, 0.753]}, {"category": "knife", "bbox": [0.787, 0.699, 0.922, 0.797]}, {"category": "knife", "bbox": [0.161, 0.539, 0.265, 0.584]}, {"category": "spoon", "bbox": [0.813, 0.674, 0.922, 0.759]}, {"category": "spoon", "bbox": [0.156, 0.555, 0.233, 0.587]}, {"category": "spoon", "bbox": [0.596, 0.467, 0.613, 0.509]}, {"category": "bowl", "bbox": [0.241, 0.753, 0.505, 0.935]}, {"category": "banana", "bbox": [0.632, 0.138, 0.718, 0.161]}, {"category": "apple", "bbox": [0.701, 0.152, 0.758, 0.191]}, {"category": "orange", "bbox": [0.607, 0.66, 0.692, 0.716]}, {"category": "orange", "bbox": [0.565, 0.636, 0.611, 0.667]}, {"category": "orange", "bbox": [0.526, 0.624, 0.572, 0.652]}, {"category": "orange", "bbox": [0.61, 0.628, 0.656, 0.657]}, {"category": "orange", "bbox": [0.599, 0.649, 0.643, 0.677]}, {"category": "dining table", "bbox": [0.013, 0.439, 0.964, 0.986]}, {"category": "cup", "bbox": [0.612, 0.489, 0.669, 0.548]}, {"category": "knife", "bbox": [0.605, 0.457, 0.638, 0.53]}, {"category": "apple", "bbox": [0.502, 0.137, 0.537, 0.159]}, {"category": "orange", "bbox": [0.54, 0.135, 0.563, 0.151]}, {"category": "orange", "bbox": [0.527, 0.129, 0.554, 0.142]}, {"category": "orange", "bbox": [0.611, 0.155, 0.641, 0.171]}, {"category": "chair", "bbox": [0.0, 0.843, 0.29, 0.989]}, {"category": "cup", "bbox": [0.353, 0.469, 0.411, 0.511]}, {"category": "cup", "bbox": [0.609, 0.716, 0.682, 0.786]}, {"category": "orange", "bbox": [0.638, 0.158, 0.679, 0.177]}, {"category": "cake", "bbox": [0.38, 0.821, 0.481, 0.895]}, {"category": "chair", "bbox": [0.79, 0.747, 1.0, 1.0]}, {"category": "bottle", "bbox": [0.719, 0.55, 0.769, 0.616]}, {"category": "bottle", "bbox": [0.795, 0.546, 0.873, 0.613]}, {"category": "knife", "bbox": [0.17, 0.799, 0.264, 0.88]}, {"category": "cup", "bbox": [0.317, 0.695, 0.391, 0.752]}]}
+ {"id": "000000294073", "image": "COCO_val2014_000000294073.jpg", "captions": ["A woman and a man standing between two brown horses.", "A COUPLE WEARING YELLOW DRESS STANDING NEAR TWO HORSES.", "An older couple stands between two horses.", "A man and a woman standing with two horses", "A man and a woman stand in between two horses."], "instances": [{"category": "horse", "bbox": [0.0, 0.052, 0.49, 0.989]}, {"category": "horse", "bbox": [0.632, 0.23, 1.0, 0.989]}, {"category": "person", "bbox": [0.425, 0.326, 0.696, 0.987]}, {"category": "person", "bbox": [0.627, 0.203, 0.828, 0.986]}, {"category": "book", "bbox": [0.525, 0.597, 0.644, 0.833]}]}
+ {"id": "000000203629", "image": "COCO_val2014_000000203629.jpg", "captions": ["A man on a cell phone in a public area holding his thumb up.", "A group of people gathered inside of a room.", "A man on his cellphone posing for a picture.", "A man giving a thumbs up while on a cell phone.", "The man is giving a thumbs up while on his phone."], "instances": [{"category": "cell phone", "bbox": [0.43, 0.459, 0.449, 0.503]}, {"category": "cup", "bbox": [0.756, 0.838, 0.865, 0.98]}, {"category": "person", "bbox": [0.232, 0.317, 0.603, 0.98]}, {"category": "person", "bbox": [0.602, 0.405, 1.0, 0.999]}, {"category": "person", "bbox": [0.003, 0.339, 0.313, 0.987]}, {"category": "person", "bbox": [0.164, 0.379, 0.258, 0.733]}, {"category": "person", "bbox": [0.564, 0.36, 0.673, 0.645]}, {"category": "person", "bbox": [0.241, 0.379, 0.336, 0.512]}, {"category": "person", "bbox": [0.682, 0.372, 0.736, 0.502]}, {"category": "person", "bbox": [0.654, 0.428, 0.734, 0.536]}, {"category": "person", "bbox": [0.718, 0.368, 0.787, 0.508]}, {"category": "person", "bbox": [0.148, 0.362, 0.205, 0.529]}, {"category": "person", "bbox": [0.001, 0.431, 0.044, 0.564]}, {"category": "cup", "bbox": [0.901, 0.808, 0.995, 0.982]}]}
+ {"id": "000000119876", "image": "COCO_val2014_000000119876.jpg", "captions": ["A man dressed loudly is using his cell phone.", "A man talking on the phone while he walks down the street.", "A man with pink hair talking on a cell phone.", "A man in a purple shirt and tie and purple hair.", "a man colored his hair in purple walking on the road"], "instances": [{"category": "bicycle", "bbox": [0.525, 0.222, 0.924, 0.608]}, {"category": "bicycle", "bbox": [0.895, 0.249, 1.0, 0.642]}, {"category": "person", "bbox": [0.0, 0.0, 0.738, 1.0]}, {"category": "tie", "bbox": [0.319, 0.255, 0.423, 0.638]}, {"category": "cell phone", "bbox": [0.411, 0.13, 0.426, 0.161]}, {"category": "handbag", "bbox": [0.369, 0.205, 0.575, 0.839]}]}
+ {"id": "000000164255", "image": "COCO_val2014_000000164255.jpg", "captions": ["An umbrella that is standing in the sand.", "An umbrella is stuck in the sand on the beach.", "a colorful striped umbrella on the beach near the ocean", "A colorful umbrella is set up at the beach.", "The colorful umbrella is sitting by the beach,"], "instances": [{"category": "umbrella", "bbox": [0.0, 0.101, 0.567, 0.575]}]}
+ {"id": "000000192817", "image": "COCO_val2014_000000192817.jpg", "captions": ["A view from a window high up in the sky.", "A bunch of mountains seen from a plane window.", "The window from a plane overlooking the ground.", "The view of a mountain area from an airplane window.", "An aerial view of mountains and lakes from an airplane window."], "instances": []}
+ {"id": "000000258285", "image": "COCO_val2014_000000258285.jpg", "captions": ["Two large passenger jets flying over a beach filled with birds.", "A plane is flying over a bird filed lake", "Two airplanes are in the sky over blue water.", "An airplane landing over an airplane on the ground.", "A photo of two plans with water and birds surrounding it , one plane in the air one one the ground."], "instances": [{"category": "bird", "bbox": [0.507, 0.941, 0.536, 0.973]}, {"category": "bird", "bbox": [0.304, 0.933, 0.315, 0.95]}, {"category": "bird", "bbox": [0.129, 0.885, 0.143, 0.912]}, {"category": "bird", "bbox": [0.158, 0.851, 0.165, 0.87]}, {"category": "bird", "bbox": [0.404, 0.839, 0.429, 0.864]}, {"category": "bird", "bbox": [0.498, 0.833, 0.513, 0.861]}, {"category": "airplane", "bbox": [0.276, 0.085, 0.825, 0.316]}, {"category": "airplane", "bbox": [0.478, 0.252, 0.983, 0.495]}, {"category": "bird", "bbox": [0.552, 0.828, 0.564, 0.844]}, {"category": "bird", "bbox": [0.789, 0.812, 0.798, 0.836]}, {"category": "bird", "bbox": [0.927, 0.82, 0.936, 0.838]}, {"category": "bird", "bbox": [0.65, 0.828, 0.664, 0.849]}, {"category": "bird", "bbox": [0.752, 0.81, 0.763, 0.83]}, {"category": "bird", "bbox": [0.841, 0.817, 0.852, 0.828]}, {"category": "bird", "bbox": [0.292, 0.849, 0.311, 0.868]}, {"category": "bird", "bbox": [0.005, 0.727, 0.981, 0.998]}]}
+ {"id": "000000506483", "image": "COCO_val2014_000000506483.jpg", "captions": ["An art installation is placed by a street.", "People sit near a display of large artworks including an oversize bench and painted feline heads.", "Looking down on a giant rocking bench and large animal heads.", "An over sized wooden bench next to two massive animal art sculptures.", "artistic sculptures and images on a city street"], "instances": [{"category": "car", "bbox": [0.656, 0.939, 0.933, 1.0]}, {"category": "person", "bbox": [0.08, 0.664, 0.147, 0.805]}, {"category": "person", "bbox": [0.154, 0.646, 0.217, 0.821]}, {"category": "bench", "bbox": [0.316, 0.124, 0.951, 0.635]}, {"category": "backpack", "bbox": [0.062, 0.701, 0.097, 0.769]}, {"category": "person", "bbox": [0.0, 0.132, 0.031, 0.197]}]}
+ {"id": "000000502168", "image": "COCO_val2014_000000502168.jpg", "captions": ["a fleet of naval ships in the ocean", "A group of men on aircraft carrier with other boats in the distance.", "A large ship floating in the ocean next to other ships.", "Several men on a boat looking over the side.", "The men wear hardhats as they work on the aircraft carrier."], "instances": [{"category": "boat", "bbox": [0.634, 0.292, 1.0, 0.982]}, {"category": "person", "bbox": [0.675, 0.507, 0.736, 0.731]}, {"category": "person", "bbox": [0.684, 0.737, 0.817, 1.0]}, {"category": "person", "bbox": [0.803, 0.691, 0.883, 0.932]}, {"category": "person", "bbox": [0.741, 0.56, 0.798, 0.767]}, {"category": "person", "bbox": [0.924, 0.269, 0.951, 0.367]}, {"category": "boat", "bbox": [0.079, 0.171, 0.172, 0.231]}, {"category": "boat", "bbox": [0.863, 0.131, 0.961, 0.239]}, {"category": "boat", "bbox": [0.435, 0.288, 0.46, 0.313]}, {"category": "boat", "bbox": [0.591, 0.186, 0.605, 0.222]}, {"category": "person", "bbox": [0.451, 0.289, 0.455, 0.296]}, {"category": "person", "bbox": [0.446, 0.29, 0.451, 0.296]}, {"category": "person", "bbox": [0.872, 0.627, 0.957, 0.966]}, {"category": "person", "bbox": [0.44, 0.288, 0.446, 0.3]}]}
+ {"id": "000000319432", "image": "COCO_val2014_000000319432.jpg", "captions": ["Man holding two shirts with luggage and window", "A man holding clothes on a hanger with a suitcase in front of him.", "A man show a red and a white clothing hangers.", "A man holding his garment bags in both hands", "A man holding up some clothes in some hanger bags."], "instances": [{"category": "person", "bbox": [0.0, 0.092, 0.776, 0.852]}, {"category": "suitcase", "bbox": [0.153, 0.798, 0.587, 1.0]}]}
+ {"id": "000000131019", "image": "COCO_val2014_000000131019.jpg", "captions": ["Two zebras and two monkeys walking on the grass.", "Two giraffes and another animal are on green grass.", "A baboon and two zebras grazing on the savannah.", "A baboon and its baby eat by two zebras in the grass", "Monkey standing behind two zebras as they graze."], "instances": [{"category": "zebra", "bbox": [0.367, 0.258, 0.834, 0.646]}, {"category": "zebra", "bbox": [0.161, 0.13, 0.396, 0.375]}, {"category": "bird", "bbox": [0.309, 0.138, 0.34, 0.163]}]}
ChatUniVi/eval/table/model.jsonl ADDED
@@ -0,0 +1,5 @@
+ {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"}
+ {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"}
+ {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"}
+ {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"}
+ {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"}
ChatUniVi/eval/table/question.jsonl ADDED
@@ -0,0 +1,80 @@
+ {"question_id": 1, "text": "How can I improve my time management skills?", "category": "generic"}
+ {"question_id": 2, "text": "What are the most effective ways to deal with stress?", "category": "generic"}
+ {"question_id": 3, "text": "What are the main differences between Python and JavaScript programming languages?", "category": "generic"}
+ {"question_id": 4, "text": "How can I increase my productivity while working from home?", "category": "generic"}
+ {"question_id": 5, "text": "Can you explain the basics of quantum computing?", "category": "generic"}
+ {"question_id": 6, "text": "What are the differences between plant-based and animal-based protein sources?", "category": "generic"}
+ {"question_id": 7, "text": "How can I develop my critical thinking skills?", "category": "generic"}
+ {"question_id": 8, "text": "What are the major challenges faced by the education sector today?", "category": "generic"}
+ {"question_id": 9, "text": "What are the primary factors that influence consumer behavior?", "category": "generic"}
+ {"question_id": 10, "text": "What are the most effective strategies for conflict resolution in the workplace?", "category": "generic"}
+ {"question_id": 11, "text": "What are some potential implications of using a single-use plastic bottle versus a reusable bottle on both the environment and human health?", "category": "knowledge"}
+ {"question_id": 12, "text": "What factors would you consider when designing an inclusive and accessible public transportation system?", "category": "knowledge"}
+ {"question_id": 13, "text": "How can governments utilize fiscal and monetary policies to combat economic recessions?", "category": "knowledge"}
+ {"question_id": 14, "text": "How do language and cultural barriers affect the way people communicate and form relationships in multicultural societies?", "category": "knowledge"}
+ {"question_id": 15, "text": "Describe a scenario where artificial intelligence could be used to improve the quality and efficiency of healthcare delivery.", "category": "knowledge"}
+ {"question_id": 16, "text": "Explain the process of gene editing using CRISPR-Cas9 technology, and discuss its potential applications and ethical implications.", "category": "knowledge"}
+ {"question_id": 17, "text": "How do vaccinations work to protect individuals and communities from infectious diseases, and what is herd immunity?", "category": "knowledge"}
+ {"question_id": 18, "text": "How do social media platforms influence the way people consume and share news, and what are the potential implications for the spread of misinformation?", "category": "knowledge"}
+ {"question_id": 19, "text": "How do cultural, social, and economic factors influence people's food choices, and how can this knowledge be used to promote healthier diets?", "category": "knowledge"}
+ {"question_id": 20, "text": "Explain the process of natural selection and how it contributes to the evolution and adaptation of species.", "category": "knowledge"}
+ {"question_id": 21, "text": "How would you introduce yourself as a medieval knight at a royal banquet?", "category": "roleplay"}
+ {"question_id": 22, "text": "As a pirate captain, what would you say to your crew to motivate them to search for hidden treasure?", "category": "roleplay"}
+ {"question_id": 23, "text": "If you were a Shakespearean character, how would you declare your love for someone in a soliloquy?", "category": "roleplay"}
+ {"question_id": 24, "text": "As a superhero, how would you explain your origin story to a curious child?", "category": "roleplay"}
+ {"question_id": 25, "text": "Imagine you are a time traveler from the year 3000. What technological advancements would you tell people about?", "category": "roleplay"}
+ {"question_id": 26, "text": "As a sports commentator, describe the winning play in the final seconds of a championship game.", "category": "roleplay"}
+ {"question_id": 27, "text": "Pretend to be a world-famous chef. How would you describe your signature dish to a panel of judges?", "category": "roleplay"}
+ {"question_id": 28, "text": "You are a mountain climber reaching the summit of Mount Everest. Describe your emotions and the view from the top.", "category": "roleplay"}
+ {"question_id": 29, "text": "As a space colonist on Mars, describe your daily life and the challenges you face living on another planet.", "category": "roleplay"}
+ {"question_id": 30, "text": "Pretend to be a character in a post-apocalyptic world. Describe how you survive and the allies you encounter.", "category": "roleplay"}
+ {"question_id": 31, "text": "How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?", "category": "common-sense"}
+ {"question_id": 32, "text": "What are some subtle clues that suggest someone is pretending to understand a topic or conversation when they are actually confused or uninformed?", "category": "common-sense"}
+ {"question_id": 33, "text": "Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app?", "category": "common-sense"}
+ {"question_id": 34, "text": "How can you determine if a person is genuinely interested in a conversation or simply being polite?", "category": "common-sense"}
+ {"question_id": 35, "text": "Why might someone prefer to shop at a small, locally-owned business instead of a large chain store, even if the prices are higher?", "category": "common-sense"}
+ {"question_id": 36, "text": "How can you assess the credibility of a source of information, such as a news article or blog post, without relying solely on the reputation of the author or publisher?", "category": "common-sense"}
+ {"question_id": 37, "text": "Why do some people enjoy the sensation of being scared, such as by watching horror movies or going on roller coasters, while others avoid these experiences?", "category": "common-sense"}
+ {"question_id": 38, "text": "How can observing the behavior of other people in a social situation provide clues about cultural norms and expectations?", "category": "common-sense"}
+ {"question_id": 39, "text": "Do we have a moral obligation to explore space, or should we focus on solving Earth's problems first?", "category": "common-sense"}
+ {"question_id": 40, "text": "In a world where automation is becoming increasingly prevalent, is it more important to prioritize job creation or technological progress?", "category": "common-sense"}
+ {"question_id": 41, "text": "How many times does the average human blink in a lifetime? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 42, "text": "How many atoms are in a grain of salt? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 43, "text": "How many lightning strikes occur on Earth each day? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 44, "text": "How many balloons would it take to lift a house like in the movie \"Up\"? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 45, "text": "How many text messages are sent globally in a minute? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 46, "text": "How many words are spoken daily on Earth? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 47, "text": "How many snowflakes fall during a typical winter? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 48, "text": "How many pages are in all the books ever written? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 49, "text": "How many times has the Earth orbited the Sun since the beginning of life? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 50, "text": "How many songs have been recorded throughout history? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
+ {"question_id": 51, "text": "What if the Internet had been invented during the Renaissance period?", "category": "counterfactual"}
+ {"question_id": 52, "text": "What if the Aztecs had successfully repelled the Spanish conquistadors?", "category": "counterfactual"}
+ {"question_id": 53, "text": "What if the Black Death had not occurred in the 14th century?", "category": "counterfactual"}
+ {"question_id": 54, "text": "What if Isaac Newton had focused on biology instead of physics?", "category": "counterfactual"}
+ {"question_id": 55, "text": "What if the Beatles had never formed as a band?", "category": "counterfactual"}
+ {"question_id": 56, "text": "What if Alan Turing had not cracked the Enigma code during World War II?", "category": "counterfactual"}
+ {"question_id": 57, "text": "What if the Suez Canal had never been constructed?", "category": "counterfactual"}
+ {"question_id": 58, "text": "What if the Maya civilization had never mysteriously collapsed?", "category": "counterfactual"}
+ {"question_id": 59, "text": "What if Christopher Columbus had not discovered the Americas?", "category": "counterfactual"}
+ {"question_id": 60, "text": "What if Vincent van Gogh had been a successful artist during his lifetime?", "category": "counterfactual"}
+ {"question_id": 61, "text": "Develop a C++ program that reads a text file line by line and counts the number of occurrences of a specific word in the file.", "category": "coding"}
+ {"question_id": 62, "text": "Implement a Python function to find the longest common subsequence of two input strings using dynamic programming.", "category": "coding"}
+ {"question_id": 63, "text": "Implement a regular expression in Python to validate an email address.", "category": "coding"}
+ {"question_id": 64, "text": "Write a program to find the nth Fibonacci number using dynamic programming.", "category": "coding"}
+ {"question_id": 65, "text": "Implement a binary search algorithm to find a specific element in a sorted array.", "category": "coding"}
+ {"question_id": 66, "text": "Implement a queue data structure using two stacks in Python.", "category": "coding"}
+ {"question_id": 67, "text": "Implement a program to find the common elements in two arrays without using any extra data structures.", "category": "coding"}
+ {"question_id": 68, "text": "Given that f(x) = 5x^3 - 2x + 3, find the value of f(2).", "category": "math"}
+ {"question_id": 69, "text": "Solve for x in the equation 3x + 10 = 5(x - 2).", "category": "math"}
+ {"question_id": 70, "text": "If the endpoints of a line segment are (2, -2) and (10, 4), what is the length of the segment?", "category": "math"}
+ {"question_id": 71, "text": "Can you help me write a formal email to a potential business partner proposing a joint venture?", "category": "writing"}
+ {"question_id": 72, "text": "Can you help me write a resignation letter to my current employer, while leaving on good terms and expressing gratitude for the opportunities provided?", "category": "writing"}
+ {"question_id": 73, "text": "Use an appropriate format to structure a formal letter of recommendation for a student applying to a prestigious graduate program in computer science.", "category": "writing"}
+ {"question_id": 74, "text": "Write a compelling product launch announcement email to inform our customers of our new software solution.", "category": "writing"}
+ {"question_id": 75, "text": "Draft an apology email to a customer who experienced a delay in their order, and provide reassurance that the issue has been resolved.", "category": "writing"}
+ {"question_id": 76, "text": "Write a script for a YouTube video exploring the history and cultural significance of jazz.", "category": "writing"}
+ {"question_id": 77, "text": "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", "category": "writing"}
+ {"question_id": 78, "text": "Write a captivating movie review for a recently released science fiction film, discussing its plot, characters, and special effects.", "category": "writing"}
+ {"question_id": 79, "text": "Structure a podcast script for an episode discussing the influence of streaming platforms on the music industry.", "category": "writing"}
+ {"question_id": 80, "text": "Write a symphony concert review, discussing the orchestra's performance and overall audience experience.", "category": "writing"}
ChatUniVi/eval/table/reviewer.jsonl ADDED
@@ -0,0 +1,4 @@
+ {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"}
+ {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"}
+ {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
+ {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for visual questions"}
ChatUniVi/eval/table/rule.json ADDED
@@ -0,0 +1,11 @@
+ {
+ "coding": {"role": "Assistant", "prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."},
+ "math": {"role": "Assistant", "prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."},
+ "default": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
+ "conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
+ "detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
+ "complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
+ "llava_bench_conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
+ "llava_bench_detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
+ "llava_bench_complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}
+ }
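The rule table above maps a question category (matching the `category` field in question.jsonl) to the judge prompt used for that question. A minimal sketch of how such a lookup might work, with a fallback to the `"default"` rule for categories that have no dedicated prompt; the trimmed `rules` dict and the `pick_rule` helper below are illustrative, not part of this repository:

```python
# Trimmed stand-in for the rule.json table above (prompts elided).
rules = {
    "coding": {"role": "Assistant", "prompt": "..."},
    "math": {"role": "Assistant", "prompt": "..."},
    "default": {"role": "Assistant", "prompt": "..."},
}

def pick_rule(category: str) -> dict:
    # Categories such as "fermi" or "writing" fall back to the generic rule.
    return rules.get(category, rules["default"])

print(pick_rule("math")["role"])                 # Assistant
print(pick_rule("writing") is rules["default"])  # True
```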
ChatUniVi/model/__init__.py ADDED
@@ -0,0 +1 @@
+ from .language_model.llama import ChatUniViLlamaForCausalLM, ChatUniViConfig
ChatUniVi/model/apply_delta.py ADDED
@@ -0,0 +1,44 @@
+ import argparse
+
+ import torch
+ from tqdm import tqdm
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from ChatUniVi import ChatUniViLlamaForCausalLM
+
+
+ def apply_delta(base_model_path, target_model_path, delta_path):
+     print("Loading base model")
+     base = AutoModelForCausalLM.from_pretrained(
+         base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+     print("Loading delta")
+     delta = ChatUniViLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+     delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
+
+     print("Applying delta")
+     for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
+         if name not in base.state_dict():
+             assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
+             continue
+         if param.data.shape == base.state_dict()[name].shape:
+             param.data += base.state_dict()[name]
+         else:
+             assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
+                 f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
+             bparam = base.state_dict()[name]
+             param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
+
+     print("Saving target model")
+     delta.save_pretrained(target_model_path)
+     delta_tokenizer.save_pretrained(target_model_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base-model-path", type=str, required=True)
+     parser.add_argument("--target-model-path", type=str, required=True)
+     parser.add_argument("--delta-path", type=str, required=True)
+
+     args = parser.parse_args()
+
+     apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
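The merge rule in apply_delta.py can be illustrated without torch: same-shape parameters are summed elementwise, while parameters that grew relative to the base (e.g. embedding matrices extended with new tokens) add the base weights into their top-left block and keep the extra rows untouched. A sketch of that arithmetic on plain nested lists (the `merge` helper is illustrative, not part of this repository):

```python
def merge(delta, base):
    # Add `base` into the top-left block of `delta`; any rows/columns that
    # exist only in `delta` (new tokens) pass through unchanged.
    rows, cols = len(base), len(base[0])
    merged = [row[:] for row in delta]  # copy so the delta stays intact
    for i in range(rows):
        for j in range(cols):
            merged[i][j] += base[i][j]
    return merged

base = [[1.0, 2.0]]               # 1 x 2 "base" embedding
delta = [[0.5, 0.5], [3.0, 3.0]]  # 2 x 2 "delta" with one extra row
print(merge(delta, base))         # [[1.5, 2.5], [3.0, 3.0]]
```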
ChatUniVi/model/arch.py ADDED
@@ -0,0 +1,652 @@
+ from abc import ABC, abstractmethod
+ import torch
+ import torch.nn as nn
+
+ from .multimodal_encoder.builder import build_vision_tower
+ from ChatUniVi.constants import *
+ from .cluster import CTM, TCBlock
+ from collections import OrderedDict
+ from .multimodal_projector.builder import build_vision_projector
+
+
+ class MetaModel:
+     def __init__(self, config):
+         super(MetaModel, self).__init__(config)
+
+         if hasattr(config, "mm_vision_tower"):
+             self.vision_tower = build_vision_tower(config, delay_load=True)
+             self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
+
+         if hasattr(config, "config"):
+             self.use_cluster = config.config["use_cluster"]
+             if self.use_cluster:
+                 self.ctm0 = CTM(sample_ratio=config.config["spatial_cluster_rate0"], embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=5)
+                 self.block0 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
+
+                 self.ctm1 = CTM(sample_ratio=config.config["spatial_cluster_rate1"], embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=3)
+                 self.block1 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
+
+                 self.ctm2 = CTM(sample_ratio=config.config["spatial_cluster_rate2"], embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=3)
+                 self.block2 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
+
+                 self.ctm3 = CTM(sample_ratio=config.config["temporal_cluster_rate"], embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=5)
+                 self.block3 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
+         else:
+             self.use_cluster = False
+
+ def get_vision_tower(self):
39
+ vision_tower = getattr(self, 'vision_tower', None)
40
+ if type(vision_tower) is list:
41
+ vision_tower = vision_tower[0]
42
+ return vision_tower
43
+
44
+ def initialize_vision_modules(self, model_args, fsdp=None):
45
+ vision_tower = model_args.vision_tower
46
+ mm_vision_select_layer = model_args.mm_vision_select_layer
47
+ mm_vision_select_feature = model_args.mm_vision_select_feature
48
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
49
+
50
+ self.config.mm_vision_tower = vision_tower
51
+
52
+ vision_tower = build_vision_tower(model_args)
53
+
54
+ self.config.use_mm_proj = True
55
+ self.config.mm_hidden_size = vision_tower.hidden_size
56
+ self.config.mm_vision_select_layer = mm_vision_select_layer
57
+ self.config.mm_vision_select_feature = mm_vision_select_feature
58
+
59
+ if fsdp is not None and len(fsdp) > 0:
60
+ self.vision_tower = [vision_tower]
61
+ else:
62
+ self.vision_tower = vision_tower
63
+
64
+ if not hasattr(self, 'mm_projector'):
65
+ self.mm_projector = build_vision_projector(self.config)
66
+
67
+ if pretrain_mm_mlp_adapter is not None:
68
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
69
+ def get_w(weights, keyword):
70
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
71
+
72
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
73
+
74
+ def initialize_cluster_modules(self, model_args):
75
+ self.use_cluster = model_args.use_cluster
76
+
77
+ if self.use_cluster and not hasattr(self, 'ctm0'):
78
+ self.ctm0 = CTM(sample_ratio=model_args.spatial_cluster_rate0, embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=5)
79
+ self.block0 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
80
+
81
+ self.ctm1 = CTM(sample_ratio=model_args.spatial_cluster_rate1, embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=3)
82
+ self.block1 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
83
+
84
+ self.ctm2 = CTM(sample_ratio=model_args.spatial_cluster_rate2, embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=3)
85
+ self.block2 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
86
+
87
+ self.ctm3 = CTM(sample_ratio=model_args.temporal_cluster_rate, embed_dim=self.config.mm_hidden_size, dim_out=self.config.mm_hidden_size, k=5)
88
+ self.block3 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
89
+
90
+
91
+ class ChatUniViMetaForCausalLM(ABC):
92
+ @abstractmethod
93
+ def get_model(self):
94
+ pass
95
+
96
+ def get_vision_tower(self):
97
+ return self.get_model().get_vision_tower()
98
+
99
+ def encode_images(self, images):
100
+ image_features = self.get_model().get_vision_tower()(images, select_feature="patch")
101
+ return image_features
102
+
103
+ def positional_encoding(self, x, num_features=1024, max_len=64):
104
+ p = torch.zeros((1, max_len, num_features))
105
+ _x = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000,
106
+ torch.arange(0, num_features, 2, dtype=torch.float32) / num_features)
107
+
108
+ p[:, :, 0::2] = torch.sin(_x)
109
+ p[:, :, 1::2] = torch.cos(_x)
110
+ x = x + p[:, :x.shape[1], :].to(x.device).to(x.dtype)
111
+ return x
112
+
113
+ def project(self, image_features, input_type="image"):
114
+ if self.get_model().use_cluster:
115
+ if input_type == "image":
116
+ cluster_image_features = []
117
+ token_dict = {'x': image_features,
118
+ 'token_num': image_features.size(1),
119
+ 'idx_token': torch.arange(image_features.size(1))[None, :].repeat(
120
+ image_features.size(0), 1),
121
+ 'agg_weight': image_features.new_ones(image_features.size(0), image_features.size(1),
122
+ 1),
123
+ 'mask': None}
124
+
125
+ token_dict = self.get_model().block0(self.get_model().ctm0(token_dict))
126
+ cluster_image_features.append(token_dict["x"])
127
+
128
+ token_dict = self.get_model().block1(self.get_model().ctm1(token_dict))
129
+ cluster_image_features.append(token_dict["x"])
130
+
131
+ token_dict = self.get_model().block2(self.get_model().ctm2(token_dict))
132
+ cluster_image_features.append(token_dict["x"])
133
+
134
+ image_features = torch.cat(cluster_image_features, dim=1)
135
+ image_features = image_features.to(self.get_model().mm_projector.weight.dtype)
136
+ else:
137
+ cls_features = torch.mean(image_features, dim=1, keepdim=False).unsqueeze(0).clone()
138
+ token_dict = {'x': cls_features,
139
+ 'token_num': cls_features.size(1),
140
+ 'idx_token': torch.arange(cls_features.size(1))[None, :].repeat(
141
+ cls_features.size(0), 1),
142
+ 'agg_weight': cls_features.new_ones(cls_features.size(0), cls_features.size(1),
143
+ 1),
144
+ 'mask': None}
145
+
146
+ down_dict, token_dict = self.get_model().ctm3(token_dict)
147
+ events = OrderedDict()
148
+
149
+ max_len = 0
150
+ for id, i in enumerate(down_dict["idx_token"][0].tolist()):
151
+ if i not in events:
152
+ events[i] = [id]
153
+ else:
154
+ events[i].append(id)
155
+ max_len = len(events[i]) if max_len < len(events[i]) else max_len
156
+
157
+ cluster_image_features = []
158
+ token_dict = {'x': image_features,
159
+ 'token_num': image_features.size(1),
160
+ 'idx_token': torch.arange(image_features.size(1))[None, :].repeat(
161
+ image_features.size(0), 1),
162
+ 'agg_weight': image_features.new_ones(image_features.size(0), image_features.size(1),
163
+ 1),
164
+ 'mask': None}
165
+
166
+ token_dict0 = self.get_model().block0(self.get_model().ctm0(token_dict))
167
+ token_dict1 = self.get_model().block1(self.get_model().ctm1(token_dict0))
168
+ token_dict2 = self.get_model().block2(self.get_model().ctm2(token_dict1))
169
+
170
+ for id, key in enumerate(events):
171
+ cur_image_features0 = torch.cat([token_dict0["x"][i] for i in events[key]], dim=0).unsqueeze(0)
172
+ token_dict = {'x': cur_image_features0,
173
+ 'token_num': cur_image_features0.size(1),
174
+ 'idx_token': torch.arange(cur_image_features0.size(1))[None, :].repeat(
175
+ cur_image_features0.size(0), 1),
176
+ 'agg_weight': cur_image_features0.new_ones(cur_image_features0.size(0),
177
+ cur_image_features0.size(1),
178
+ 1),
179
+ 'mask': None}
180
+
181
+ cur_token_dict0 = self.get_model().block0(self.get_model().ctm0(token_dict))
182
+ cluster_image_features.append(cur_token_dict0["x"])
183
+
184
+ cur_image_features1 = torch.cat([token_dict1["x"][i] for i in events[key]], dim=0).unsqueeze(0)
185
+ token_dict = {'x': cur_image_features1,
186
+ 'token_num': cur_image_features1.size(1),
187
+ 'idx_token': torch.arange(cur_image_features1.size(1))[None, :].repeat(
188
+ cur_image_features1.size(0), 1),
189
+ 'agg_weight': cur_image_features1.new_ones(cur_image_features1.size(0),
190
+ cur_image_features1.size(1),
191
+ 1),
192
+ 'mask': None}
193
+
194
+ cur_token_dict1 = self.get_model().block1(self.get_model().ctm1(token_dict))
195
+ cluster_image_features.append(cur_token_dict1["x"])
196
+
197
+ cur_image_features2 = torch.cat([token_dict2["x"][i] for i in events[key]], dim=0).unsqueeze(0)
198
+ token_dict = {'x': cur_image_features2,
199
+ 'token_num': cur_image_features2.size(1),
200
+ 'idx_token': torch.arange(cur_image_features2.size(1))[None, :].repeat(
201
+ cur_image_features2.size(0), 1),
202
+ 'agg_weight': cur_image_features2.new_ones(cur_image_features2.size(0),
203
+ cur_image_features2.size(1),
204
+ 1),
205
+ 'mask': None}
206
+
207
+ cur_token_dict2 = self.get_model().block2(self.get_model().ctm2(token_dict))
208
+ cluster_image_features.append(cur_token_dict2["x"])
209
+
210
+ image_features = torch.cat(cluster_image_features, dim=1)
211
+ image_features = image_features.to(self.get_model().mm_projector.weight.dtype)
212
+
213
+ else:
214
+ if input_type == "video":
215
+ image_features, cls_features = torch.mean(image_features, dim=0, keepdim=False).unsqueeze(
216
+ 0), torch.mean(image_features, dim=1, keepdim=False).unsqueeze(0)
217
+ image_features = torch.cat([image_features, cls_features], dim=1)
218
+
219
+ image_features = self.get_model().mm_projector(image_features)
220
+ return image_features # 不同的type形状相同
221
+
222
+    def prepare_inputs_labels_for_multimodal(
+        self, input_ids, attention_mask, past_key_values, labels, images, audio_features=None, target_frame=0, ref_ids=None
+    ):
+        IMAGE_TOKEN_INDEX = -200
+        AUDIO_TOKEN_INDEX = -300
+        # print("\ncalling prepare_inputs_labels_for_multimodal")
+        vision_tower = self.get_vision_tower()
+        # print("got vision_tower")
+        num_frames = images[0].shape[0]  # T
+
+        if vision_tower is None or images is None or input_ids.shape[1] == 1:
+            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
+                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
+            return input_ids, attention_mask, past_key_values, None, labels
+
+        if ref_ids is not None:
+            ref_embeds = []
+            for ref_id in ref_ids:
+                ref_embed = self.get_model().embed_tokens(ref_id)  # [L, 4096]
+                ref_embeds.append(ref_embed)
+            # list[B]: [len_ref, 4096]
+
+        if type(images) is list or images.ndim == 5:
+            # print("concatenate the listed images first")
+            concat_images = torch.cat([image for image in images], dim=0)  # [BT, 3, H, W]
+            org_image_features = self.encode_images(concat_images)  # [BT, 256, 1024]
+
+            # if audio_features is not None and hasattr(self, "audio_adapter"):
+            if True:
+                # image_features = self.audio_adapter(org_image_features, audio_features, ref_embeds_T)
+                # image_features = self.token_compressor(org_image_features, ref_embeds)
+                # print("image_features after compress:", image_features.shape)
+                image_features = org_image_features
+            else:
+                image_features = org_image_features
+            # split_sizes = [image.shape[0] for image in images]
+            split_sizes = 1
+            image_features = torch.split(image_features, split_sizes, dim=0)  # list[BT]: [1, 256, 1024]
+            image_features = [x.flatten(0, 1) for x in image_features]  # list[BT]: [256, 1024]
+
+            org_image_features = torch.split(org_image_features, split_sizes, dim=0)
+            org_image_features = [x.flatten(0, 1) for x in org_image_features]
+
+        else:
+            # print("get image_features directly")
+            image_features = self.encode_images(images)
+            org_image_features = image_features
+
+        new_input_embeds = []
+        new_labels = [] if labels is not None else None
+        cur_image_idx = 0
+        for batch_idx, cur_input_ids in enumerate(input_ids):
+            # cur_image_idx += 1
+
+            # check whether the current input_ids contain any image tokens
+            # print("cur_input_ids shape:", cur_input_ids.shape)
+            # print("cur_input_ids:", cur_input_ids)
+            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
+                # print("no IMAGE token in input_ids")
+                # multimodal LLM, but the current sample is not multimodal
+                # embed input_ids as plain text
+                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
+                cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
+                new_input_embeds.append(cur_input_embeds)
+                if labels is not None:
+                    new_labels.append(labels[batch_idx])
+                cur_image_idx += 1
+                continue
+
+            image_token_indices = torch.where((cur_input_ids == IMAGE_TOKEN_INDEX) | (cur_input_ids == AUDIO_TOKEN_INDEX))[0]
+            audio_token_indices = torch.where(cur_input_ids == AUDIO_TOKEN_INDEX)[0]
+            # print("audio indices:", audio_token_indices)
+            # print("image and audio indices:", image_token_indices)
+
+            cur_new_input_embeds = []
+            if labels is not None:
+                cur_labels = labels[batch_idx]
+                cur_new_labels = []
+                assert cur_labels.shape == cur_input_ids.shape
+
+            # multiple image tokens ---------------------------------------------
+            if len(image_token_indices) > 1:
+                # print("multiple image tokens")
+                # return 0
+
+                temp = []
+
+                cur, pre = image_token_indices[0], image_token_indices[0]
+                # group consecutive <image> positions into one list; a gap starts a new group
+                for i in image_token_indices:
+                    cur = i
+                    # if this <image> immediately follows the previous one
+                    if cur - pre == 1:
+                        temp[-1] = temp[-1] + [cur]
+                    else:
+                        temp.append([cur])
+                    pre = cur
+
+                pre_image_token_end = 0
+                cur_frames = 0
+                for i in temp:
+                    # positions of the first and last <image> in this group
+                    image_token_start = i[0]
+                    image_token_end = i[-1]
+                    cur_image_features = []
+
+                    if len(i) >= 2:  # handle video features built from T frames
+                        for frame_idx in range(num_frames):
+                            cur_image_features.append(org_image_features[batch_idx * num_frames + frame_idx])
+                            # print(batch_idx * num_frames + frame_idx)
+                    elif i[0] not in audio_token_indices:
+                        cur_image_features.append(org_image_features[batch_idx * num_frames + target_frame])
+                        # print(batch_idx * num_frames + target_frame)
+                    else:
+                        cur_image_features.append(audio_features[batch_idx])
+                        # print(f"audio{batch_idx}")
+                    # ------------------------------------------------------------------
+                    # # i holds the indices of one <image> group; fetch that many features from image_features
+                    # for _ in i:
+                    #     # processing an <image> token
+                    #     if _ not in audio_token_indices:
+                    #         # a single image
+                    #         if cur_frames == num_frames:
+                    #             # cur_image_features.append(org_image_features[cur_image_idx-num_frames+target_frame])
+                    #             cur_image_features.append(org_image_features[batch_idx*num_frames+target_frame])
+                    #             # print(cur_image_idx-num_frames+target_frame)
+                    #         # multiple images
+                    #         else:
+                    #             cur_image_features.append(image_features[cur_image_idx])
+                    #             # print(cur_image_idx)
+                    #         cur_image_idx += 1
+                    #         cur_frames += 1
+                    #     # handle <audio>
+                    #     else:
+                    #         # cur_image_features.append(self.audio_feature_layer(audio_features[batch_idx]))
+                    #         cur_image_features.append(audio_features[batch_idx])
+                    #         # print("audio:", batch_idx)
+                    # # cur_image_features list[len(i)]: [256, 1024]
+
+                    # if the current group holds multiple images, it represents a video
+                    if len(i) >= 2:
+                        if not self.compress:
+                            # compress the collected frame features and project them
+                            cur_image_features = torch.stack(cur_image_features, dim=0)  # [len(i), 256, 1024]
+                            cur_image_features = self.project(cur_image_features, input_type="video")
+                            t, l, n = cur_image_features.size()
+                            cur_image_features = cur_image_features.contiguous().view(t * l, n)  # [112, 4096]
+                            # print(f"no compression, cur_image_features{cur_image_features.shape}")
+                        else:
+                            compressed_frames = []
+                            for cur_image_feature in cur_image_features:
+                                cur_image_feature = self.project(cur_image_feature.unsqueeze(0), input_type="image")  # [1, 256, 1024]
+                                t, l, n = cur_image_feature.size()
+                                cur_image_feature = cur_image_feature.contiguous().view(t * l, n)  # [112, 4096]
+
+                                compressed_frames.append(cur_image_feature.mean(dim=0).unsqueeze(0))  # [1, 4096]
+                            compressed_frames = torch.cat(compressed_frames, dim=0)  # [T, 4096]
+
+                            cur_image_features = torch.stack(cur_image_features, dim=0)  # [len(i), 256, 1024]
+                            cur_image_features = self.project(cur_image_features, input_type="video")
+                            t, l, n = cur_image_features.size()
+                            cur_image_features = cur_image_features.contiguous().view(t * l, n)  # [112, 4096]
+
+                            # cur_image_features = torch.cat([cur_image_features, compressed_frames], dim=0)  # [122, 4096]
+                            cur_image_features = torch.cat([compressed_frames, cur_image_features], dim=0)  # [122, 4096]
+
+                    # a single special token: the <image> case
+                    elif i[0] not in audio_token_indices:
+                        cur_image_features = torch.stack(cur_image_features, dim=0)
+                        cur_image_features = self.project(cur_image_features, input_type="image")
+                        t, l, n = cur_image_features.size()
+                        cur_image_features = cur_image_features.contiguous().view(t * l, n)  # [112, 4096]
+                    else:
+                        cur_image_features = cur_image_features[0]  # [10, 4096]
+
+                    if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+                        # embed the text before im_start
+                        cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[pre_image_token_end:image_token_start - 1]).detach())
+                        # embed the im_start token
+                        cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start - 1:image_token_start]))
+                        # image features
+                        cur_new_input_embeds.append(cur_image_features)
+                        # im_end
+                        cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_end + 1:image_token_end + 2]))
+                        if labels is not None:
+                            cur_new_labels.append(cur_labels[pre_image_token_end:image_token_start])
+                            # pad cur_new_labels
+                            cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
+                            cur_new_labels.append(cur_labels[image_token_end:image_token_end + 1])
+
+                        # set cur_labels to the remaining labels
+                        # cur_labels = cur_labels[image_token_end + 2:]
+                    else:
+                        cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[pre_image_token_end:image_token_start]))
+                        cur_new_input_embeds.append(cur_image_features)
+                        if labels is not None:
+                            cur_new_labels.append(cur_labels[pre_image_token_end:image_token_start])
+                            cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
+                        # cur_labels = cur_labels[image_token_end + 1:]
+
+                    pre_image_token_end = image_token_end + 1
+
+                # set cur_input_ids to the remaining ids
+                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+                    cur_input_ids = cur_input_ids[image_token_end + 2:]
+                    cur_labels = cur_labels[image_token_end + 2:]
+                else:
+                    cur_input_ids = cur_input_ids[image_token_end + 1:]
+                    cur_labels = cur_labels[image_token_end + 1:]
+
+            # together with the branch above, this handles exactly one image token
+            elif image_token_indices.numel() > 0:
+                # print("only one image token")
+
+                cur_image_features = []
+                image_token_start = image_token_indices[0]
+                image_token_end = image_token_indices[-1]
+                # print("image_token_start:", image_token_start, " image_token_end:", image_token_end)
+
+                # append one image feature per image token
+                for _ in image_token_indices:
+                    cur_image_features.append(image_features[cur_image_idx])
+                    cur_image_idx += 1
+                # print("cur_image_features length:", len(cur_image_features))
+
+                # stack the image features along a new dimension
+                cur_image_features = torch.stack(cur_image_features, dim=0)
+                # print("cur_image_features_stacked shape:", cur_image_features.shape)
+                cur_image_features = self.project(cur_image_features, input_type="image")
+                # print("cur_image_features_projected shape:", cur_image_features.shape)
+
+                # image feature dimensions: num, len, dim
+                t, l, n = cur_image_features.size()
+                cur_image_features = cur_image_features.contiguous().view(t * l, n)
+                # print("cur_image_features_viewed shape:", cur_image_features.shape)
+
+                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+                    print("tune_mm_mlp_adapter and mm_use_im_start_end")
+                    # these two lines embed the text before the image token
+                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start - 1]).detach())
+                    # append the image-start token here
+                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start - 1:image_token_start]))
+                    print("cur_new_input_embeds length:", len(cur_new_input_embeds))
+                    print("cur_new_input_embeds shape:", cur_new_input_embeds[0].shape)
+                    print("cur_new_input_embeds shape:", cur_new_input_embeds[1].shape)
+
+                    # insert the image features at the image-token position
+                    cur_new_input_embeds.append(cur_image_features)
+                    print("cur_new_input_embeds length:", len(cur_new_input_embeds))
+                    # print("cur_new_input_embeds shape:", cur_new_input_embeds[2].shape)
+
+                    # append the image-end token after the image tokens
+                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_end + 1:image_token_end + 2]))
+                    print("cur_new_input_embeds length:", len(cur_new_input_embeds))
+
+                    if labels is not None:
+                        # append the labels before the image token
+                        cur_new_labels.append(cur_labels[:image_token_start])
+                        # append IGNORE_INDEX entries matching the image feature length
+                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
+                        # append the image-end token label
+                        cur_new_labels.append(cur_labels[image_token_end:image_token_end + 1])
+                        # keep the remaining text labels
+                        cur_labels = cur_labels[image_token_end + 2:]
+
+                else:
+                    # print("tune_mm_mlp_adapter / mm_use_im_start_end")
+                    # embed the text tokens before the image token
+                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
+                    cur_new_input_embeds.append(cur_image_features)
+                    # print("cur_new_input_embeds length:", len(cur_new_input_embeds))
+
+                    if labels is not None:
+                        # copy the labels before the image
+                        cur_new_labels.append(cur_labels[:image_token_start])
+                        # append shape[0] IGNORE_INDEX entries to match the image features
+                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
+                        # print("cur_new_labels length:", len(cur_new_labels))
+                        # print("cur_new_labels:", cur_new_labels)
+                        # print(cur_new_labels[0].shape, ' ', cur_new_labels[1].shape)
+
+                        # keep cur_labels as the remaining unprocessed labels
+                        cur_labels = cur_labels[image_token_end + 1:]
+                        # print("labels after image:", cur_labels)
+                        # print(len(cur_labels))
+
+                # replace cur_input_ids with the remaining (post-image) ids
+                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+                    cur_input_ids = cur_input_ids[image_token_end + 2:]
+                else:
+                    cur_input_ids = cur_input_ids[image_token_end + 1:]
+                # print("input_ids after image:", cur_input_ids)
+
+            # if there are still text tokens after the image tokens
+            if cur_input_ids.numel() > 0:
+                # print("text tokens remain after the image token")
+                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
+                else:
+                    # embed the remaining input_ids
+                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
+
+                # print("cur_new_input_embeds length:", len(cur_new_input_embeds))
+                # print("cur_new_input_embeds shape:", cur_new_input_embeds[0].shape, cur_new_input_embeds[1].shape, cur_new_input_embeds[2].shape)
+
+                if labels is not None:
+                    # append the remaining labels
+                    cur_new_labels.append(cur_labels)
+
+            cur_new_input_embeds = [x.to(device='cuda') for x in cur_new_input_embeds]
+            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
+
+            new_input_embeds.append(cur_new_input_embeds)
+            if labels is not None:
+                cur_new_labels = torch.cat(cur_new_labels, dim=0)
+                new_labels.append(cur_new_labels)
+
+        # if the embedded inputs within a batch have different lengths
+        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
+            print("lengths differ within the batch")
+            max_len = max(x.shape[0] for x in new_input_embeds)
+
+            new_input_embeds_align = []
+            for cur_new_embed in new_input_embeds:
+                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
+                new_input_embeds_align.append(cur_new_embed)
+            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
+
+            if labels is not None:
+                new_labels_align = []
+                _new_labels = new_labels
+                for cur_new_label in new_labels:
+                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
+                    new_labels_align.append(cur_new_label)
+                new_labels = torch.stack(new_labels_align, dim=0)
+
+            if attention_mask is not None:
+                new_attention_mask = []
+                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
+                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
+                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
+                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
+                    new_attention_mask.append(cur_new_attention_mask)
+                attention_mask = torch.stack(new_attention_mask, dim=0)
+                assert attention_mask.shape == new_labels.shape
+
+        # lengths match within the batch
+        else:
+            # stack the batch into [B, token_len, dim]
+            new_input_embeds = torch.stack(new_input_embeds, dim=0)
+            if labels is not None:
+                new_labels = torch.stack(new_labels, dim=0)
+
+            if attention_mask is not None:
+                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
+                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
+                assert attention_mask.shape == new_input_embeds.shape[:2]
+
+        return None, attention_mask, past_key_values, new_input_embeds, new_labels
+
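The multi-image branch above collapses runs of consecutive `<image>` positions into sublists, so a run of T frame tokens becomes one "video" group while isolated tokens stay singletons. The grouping step in isolation can be sketched as a plain-Python helper (a hypothetical stand-alone function for illustration, not part of the repo):

```python
def group_consecutive(indices):
    # Same logic as the `temp` grouping loop in
    # prepare_inputs_labels_for_multimodal: consecutive indices
    # share a group; any gap starts a new group.
    groups = []
    pre = None
    for cur in indices:
        if pre is not None and cur - pre == 1:
            groups[-1].append(cur)
        else:
            groups.append([cur])
        pre = cur
    return groups

# group_consecutive([3, 4, 5, 9, 12, 13]) → [[3, 4, 5], [9], [12, 13]]
```

Each resulting sublist is then routed to the video, single-image, or audio handling path depending on its length and whether its first index is an audio token.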
+    def initialize_vision_tokenizer(self, model_args, tokenizer):
+        if model_args.mm_use_im_patch_token:
+            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+            tokenizer.add_tokens([DEFAULT_VIDEO_PATCH_TOKEN], special_tokens=True)
+            self.resize_token_embeddings(len(tokenizer))
+
+        if model_args.mm_use_im_start_end:
+            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN], special_tokens=True)
+            self.resize_token_embeddings(len(tokenizer))
+
+            if num_new_tokens > 0:
+                input_embeddings = self.get_input_embeddings().weight.data
+                output_embeddings = self.get_output_embeddings().weight.data
+
+                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+                input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+            if model_args.tune_mm_mlp_adapter:
+                for p in self.get_input_embeddings().parameters():
+                    p.requires_grad = True
+                for p in self.get_output_embeddings().parameters():
+                    p.requires_grad = False
+
+            if model_args.pretrain_mm_mlp_adapter:
+                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
+                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
+                assert num_new_tokens == 2
+                if input_embeddings.shape == embed_tokens_weight.shape:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+                elif embed_tokens_weight.shape[0] == num_new_tokens:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
+                else:
+                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
+        elif model_args.mm_use_im_patch_token:
+            if model_args.tune_mm_mlp_adapter:
+                for p in self.get_input_embeddings().parameters():
+                    p.requires_grad = False
+                for p in self.get_output_embeddings().parameters():
+                    p.requires_grad = False
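The `positional_encoding` method above pairs `sin` and `cos` of the same angle at even/odd feature indices, with angle `pos / 10000**(2i / num_features)`. The table it adds to `x` can be sketched in plain Python (a hypothetical stand-alone helper mirroring the torch code, for illustration only):

```python
import math

def positional_encoding_table(max_len=64, num_features=1024):
    # Sketch of the table built in positional_encoding:
    #   p[pos, 2i]   = sin(pos / 10000**(2i / num_features))
    #   p[pos, 2i+1] = cos(pos / 10000**(2i / num_features))
    table = [[0.0] * num_features for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, num_features, 2):
            angle = pos / (10000 ** (i / num_features))
            table[pos][i] = math.sin(angle)
            if i + 1 < num_features:
                table[pos][i + 1] = math.cos(angle)
    return table
```

The model adds `table[:seq_len]` element-wise to the token features, so position information is injected without any learned parameters.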
ChatUniVi/model/builder.py ADDED
@@ -0,0 +1,118 @@
+import os
+import shutil
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+import torch
+from ChatUniVi.model import *
+from ChatUniVi.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from transformers import AutoConfig, AutoModelForCausalLM
+
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
+    kwargs = {"device_map": device_map}
+
+    if load_8bit:
+        kwargs['load_in_8bit'] = True
+    elif load_4bit:
+        kwargs['load_in_4bit'] = True
+        kwargs['quantization_config'] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type='nf4'
+        )
+    else:
+        kwargs['torch_dtype'] = torch.float16
+
+    if 'chatunivi' in model_name.lower():
+        # Load ChatUniVi model
+        if 'lora' in model_name.lower() and model_base is not None:
+            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
+            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+            print('Loading ChatUniVi from base model...')
+            model = ChatUniViLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
+            if model.lm_head.weight.shape[0] != token_num:
+                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
+                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
+
+            print('Loading additional ChatUniVi weights...')
+            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
+                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
+            else:
+                # this is probably from HF Hub
+                from huggingface_hub import hf_hub_download
+
+                def load_from_hf(repo_id, filename, subfolder=None):
+                    cache_file = hf_hub_download(
+                        repo_id=repo_id,
+                        filename=filename,
+                        subfolder=subfolder)
+                    return torch.load(cache_file, map_location='cpu')
+
+                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
+            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
+            if any(k.startswith('model.model.') for k in non_lora_trainables):
+                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
+            model.load_state_dict(non_lora_trainables, strict=False)
+
+            from peft import PeftModel
+            print('Loading LoRA weights...')
+            model = PeftModel.from_pretrained(model, model_path)
+            print('Merging LoRA weights...')
+            model = model.merge_and_unload()
+            print('Model is loaded...')
+        elif model_base is not None:
+            # this may be mm projector only
+            print('Loading ChatUniVi from base model...')
+            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+            cfg_pretrained = AutoConfig.from_pretrained(model_path)
+            model = ChatUniViLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+
+            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
+            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
+            model.load_state_dict(mm_projector_weights, strict=False)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+    else:
+        # Load language model
+        if model_base is not None:
+            # PEFT model
+            from peft import PeftModel
+            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+            model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
+            print(f"Loading LoRA weights from {model_path}")
+            model = PeftModel.from_pretrained(model, model_path)
+            print(f"Merging weights")
+            model = model.merge_and_unload()
+            print('Convert to FP16...')
+            model.to(torch.float16)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+    image_processor = None
+
+    if 'chatunivi' in model_name.lower():
+        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+        if mm_use_im_patch_token:
+            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+        if mm_use_im_start_end:
+            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+        model.resize_token_embeddings(len(tokenizer))
+
+        vision_tower = model.get_vision_tower()
+        if not vision_tower.is_loaded:
+            vision_tower.load_model()
+        vision_tower.to(device='cuda', dtype=torch.float16)
+
+        image_processor = vision_tower.image_eval_processor
+
+    if hasattr(model.config, "max_sequence_length"):
+        context_len = model.config.max_sequence_length
+    else:
+        context_len = 2048
+
+    return tokenizer, model, image_processor, context_len
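`load_pretrained_model` dispatches on `model_name` and `model_base` into three ChatUniVi paths (LoRA merge, projector-only, full checkpoint) plus two plain-LM paths. The branch selection alone can be sketched as a small helper (a hypothetical function with illustrative labels, not part of the repo):

```python
def select_loading_strategy(model_name, model_base):
    # Mirrors the if/elif ladder in load_pretrained_model above;
    # the returned strings are illustrative labels only.
    name = model_name.lower()
    if "chatunivi" in name:
        if "lora" in name and model_base is not None:
            return "base + non_lora_trainables + merged LoRA"
        if model_base is not None:
            return "base + mm_projector.bin only"
        return "full ChatUniVi checkpoint"
    if model_base is not None:
        return "base + merged PEFT adapter"
    return "plain causal LM"
```

Note that only the ChatUniVi paths go on to load the vision tower and return a non-`None` `image_processor`.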
ChatUniVi/model/cluster.py ADDED
@@ -0,0 +1,287 @@
+ import torch
+ import math
+ import torch.nn as nn
+ import warnings
+
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                       "The distribution of values may be incorrect.",
+                       stacklevel=2)
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     # type: (Tensor, float, float, float, float) -> Tensor
+     r"""Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \leq \text{mean} \leq b`.
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     Examples:
+         >>> w = torch.empty(3, 5)
+         >>> nn.init.trunc_normal_(w)
+     """
+     try:
+         return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+     except Exception:
+         # Fall back to the untouched tensor if the in-place init fails.
+         return tensor
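The inverse-CDF trick that `_no_grad_trunc_normal_` implements with `erfinv_` can be traced without torch. The sketch below is a hypothetical stdlib-only analogue (not part of this diff) that uses `statistics.NormalDist` for the CDF and inverse CDF: sample uniformly inside the truncated CDF range, then map back through the inverse CDF.

```python
import random
from statistics import NormalDist

def sample_trunc_normal(mean=0.0, std=1.0, a=-2.0, b=2.0, n=1000, seed=0):
    """Draw n samples from N(mean, std^2) truncated to [a, b] via inverse CDF."""
    nd = NormalDist()
    rng = random.Random(seed)
    lo = nd.cdf((a - mean) / std)  # CDF value at the lower bound
    hi = nd.cdf((b - mean) / std)  # CDF value at the upper bound
    out = []
    for _ in range(n):
        u = rng.uniform(lo, hi)                  # uniform in the truncated CDF range
        out.append(mean + std * nd.inv_cdf(u))   # map back through the inverse CDF
    return out
```

Every sample lands in `[a, b]` by construction, which is exactly why the torch version only needs a final `clamp_` for numerical safety.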
+
+
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+
+ def index_points(points, idx):
+     """Sample features following the index.
+     Returns:
+         new_points: indexed points data, [B, S, C]
+
+     Args:
+         points: input points data, [B, N, C]
+         idx: sample index data, [B, S]
+     """
+     device = points.device
+     B = points.shape[0]
+     view_shape = list(idx.shape)
+     view_shape[1:] = [1] * (len(view_shape) - 1)
+     repeat_shape = list(idx.shape)
+     repeat_shape[0] = 1
+     batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
+     new_points = points[batch_indices, idx, :]
+     return new_points
+
+
+ def cluster_dpc_knn(token_dict, cluster_num, k=5, token_mask=None):
+     """Cluster tokens with the DPC-KNN algorithm.
+     Return:
+         idx_cluster (Tensor[B, N]): cluster index of each token.
+         cluster_num (int): actual cluster number. The same as the
+             input cluster number.
+     Args:
+         token_dict (dict): dict for token information
+         cluster_num (int): cluster number
+         k (int): number of the nearest neighbors used for local density.
+         token_mask (Tensor[B, N]): mask indicating whether a token is a
+             padded empty token. A non-zero value means the token is
+             meaningful; zero means it is an empty token. If set to None,
+             all tokens are regarded as meaningful.
+     """
+     with torch.no_grad():
+         x = token_dict["x"]
+         B, N, C = x.shape
+
+         dist_matrix = torch.cdist(x.float(), x.float()) / (C ** 0.5)
+
+         if token_mask is not None:
+             token_mask = token_mask > 0
+             # in order to not affect the local density, the distance between empty tokens
+             # and any other tokens should be the maximal distance.
+             dist_matrix = dist_matrix * token_mask[:, None, :] + \
+                 (dist_matrix.max() + 1) * (~token_mask[:, None, :])
+
+         # get local density
+         dist_nearest, index_nearest = torch.topk(dist_matrix, k=k, dim=-1, largest=False)
+         density = (-(dist_nearest ** 2).mean(dim=-1)).exp()
+         # add a little noise to ensure no tokens have the same density.
+         density = density + torch.rand(
+             density.shape, device=density.device, dtype=density.dtype) * 1e-6
+
+         if token_mask is not None:
+             # the density of an empty token should be 0
+             density = density * token_mask
+
+         # get distance indicator: distance to the nearest token with higher density
+         mask = density[:, None, :] > density[:, :, None]
+         mask = mask.type(x.dtype)
+         dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None]
+         dist, index_parent = (dist_matrix * mask + dist_max * (1 - mask)).min(dim=-1)
+
+         # select clustering centers according to score
+         score = dist * density
+         _, index_down = torch.topk(score, k=cluster_num, dim=-1)
+
+         # assign tokens to the nearest center
+         dist_matrix = index_points(dist_matrix, index_down)
+
+         idx_cluster = dist_matrix.argmin(dim=1)
+
+         # make sure each cluster center is assigned to itself
+         idx_batch = torch.arange(B, device=x.device)[:, None].expand(B, cluster_num)
+         idx_tmp = torch.arange(cluster_num, device=x.device)[None, :].expand(B, cluster_num)
+         idx_cluster[idx_batch.reshape(-1), index_down.reshape(-1)] = idx_tmp.reshape(-1)
+
+     return idx_cluster, cluster_num
+
+
+ def merge_tokens(token_dict, idx_cluster, cluster_num, token_weight=None):
+     """Merge tokens in the same cluster to a single token.
+     Implemented by torch.index_add(). Flops: B*N*(C+2)
+     Return:
+         out_dict (dict): dict for output token information
+
+     Args:
+         token_dict (dict): dict for input token information
+         idx_cluster (Tensor[B, N]): cluster index of each token.
+         cluster_num (int): cluster number
+         token_weight (Tensor[B, N, 1]): weight for each token.
+     """
+
+     x = token_dict['x']
+     idx_token = token_dict['idx_token']
+     agg_weight = token_dict['agg_weight']
+
+     B, N, C = x.shape
+     if token_weight is None:
+         token_weight = x.new_ones(B, N, 1)
+
+     idx_batch = torch.arange(B, device=x.device)[:, None]
+     idx = idx_cluster + idx_batch * cluster_num
+
+     all_weight = token_weight.new_zeros(B * cluster_num, 1)
+     all_weight.index_add_(dim=0, index=idx.reshape(B * N),
+                           source=token_weight.reshape(B * N, 1))
+     all_weight = all_weight + 1e-6
+     norm_weight = token_weight / all_weight[idx]
+
+     # average token features
+     x_merged = x.new_zeros(B * cluster_num, C)
+     source = x * norm_weight
+
+     x_merged.index_add_(dim=0, index=idx.reshape(B * N),
+                         source=source.reshape(B * N, C).type(x.dtype))
+     x_merged = x_merged.reshape(B, cluster_num, C)
+
+     idx_token_new = index_points(idx_cluster[..., None], idx_token).squeeze(-1)
+     weight_t = index_points(norm_weight, idx_token)
+     agg_weight_new = agg_weight * weight_t
+     agg_weight_new = agg_weight_new / agg_weight_new.max(dim=1, keepdim=True)[0]
+
+     out_dict = {}
+     out_dict['x'] = x_merged
+     out_dict['token_num'] = cluster_num
+     out_dict['idx_token'] = idx_token_new
+     out_dict['agg_weight'] = agg_weight_new
+     out_dict['mask'] = None
+     return out_dict
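The core of `merge_tokens` is a weight-normalised average per cluster. The function below is a hypothetical stdlib-only sketch of that arithmetic for scalar features (the torch version does the same thing with `index_add_` over feature vectors); it is not part of the diff.

```python
def merge_1d(features, assign, cluster_num, weights=None):
    """Weighted average of scalar token features per cluster (plain-Python sketch)."""
    n = len(features)
    weights = weights or [1.0] * n
    total = [1e-6] * cluster_num          # summed weight per cluster (eps avoids /0)
    merged = [0.0] * cluster_num
    for c, w in zip(assign, weights):
        total[c] += w
    for f, c, w in zip(features, assign, weights):
        merged[c] += f * (w / total[c])   # normalised contribution, as in merge_tokens
    return merged
```

With unit weights this reduces to a per-cluster mean, which is why equal-weight clusters simply average their member tokens.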
+
+
+ class CTM(nn.Module):
+     def __init__(self, sample_ratio, embed_dim, dim_out, k=5):
+         super().__init__()
+         self.sample_ratio = sample_ratio
+         self.dim_out = dim_out
+         self.k = k
+
+     def forward(self, token_dict, sample_ratio=None):
+         x = token_dict["x"]
+         B, N, C = x.shape
+
+         token_weight = x.new_ones(B, N)
+
+         if token_dict["mask"] is not None:
+             token_weight.masked_fill_((1 - token_dict["mask"]).to(torch.bool), float("-inf"))
+         token_weight = token_weight.unsqueeze(2)
+         token_dict['x'] = x
+
+         if sample_ratio is not None:
+             cluster_num = max(math.ceil(N * sample_ratio), 1)
+         elif self.sample_ratio > 1:
+             cluster_num = max(math.ceil(self.sample_ratio), 1)
+         else:
+             cluster_num = max(math.ceil(N * self.sample_ratio), 1)
+
+         k = min(3, max(cluster_num // 2, 1)) if self.k > cluster_num else self.k
+         idx_cluster, cluster_num = cluster_dpc_knn(
+             token_dict, cluster_num, k, token_mask=token_dict["mask"])
+
+         down_dict = merge_tokens(token_dict, idx_cluster, cluster_num, token_weight)
+         return down_dict, token_dict
+
+
+ class TCBlock(nn.Module):
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, use_sr_layer=False):
+         super().__init__()
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+             fan_out //= m.groups
+             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+             if m.bias is not None:
+                 m.bias.data.zero_()
+
+     def forward(self, inputs):
+         if isinstance(inputs, (tuple, list)):
+             q_dict, kv_dict = inputs
+         else:
+             q_dict, kv_dict = inputs, None
+
+         x = q_dict['x']
+         return q_dict
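`CTM.forward` overloads `sample_ratio`: values at or below 1 act as a fraction of the current token count, values above 1 act as an absolute cluster count, and a per-call override is always treated as a fraction. A hypothetical restatement of just that branch (names are illustrative, not from the repo):

```python
import math

def pick_cluster_num(n_tokens, sample_ratio, override=None):
    """Mirror of CTM's cluster-count rule: ratios <= 1 are fractions of N,
    values > 1 are absolute cluster counts, and an explicit override wins."""
    if override is not None:
        return max(math.ceil(n_tokens * override), 1)
    if sample_ratio > 1:
        return max(math.ceil(sample_ratio), 1)
    return max(math.ceil(n_tokens * sample_ratio), 1)
```

The `max(..., 1)` guard keeps at least one cluster even for tiny token counts or very small ratios.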
ChatUniVi/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
+ """
+ Usage:
+ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
+ """
+ import argparse
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from llava.model import *
+ from llava.model.utils import auto_upgrade
+
+
+ def consolidate_ckpt(src_path, dst_path):
+     print("Loading model")
+     auto_upgrade(src_path)
+     src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+     src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+     src_model.save_pretrained(dst_path)
+     src_tokenizer.save_pretrained(dst_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--src", type=str, required=True)
+     parser.add_argument("--dst", type=str, required=True)
+
+     args = parser.parse_args()
+
+     consolidate_ckpt(args.src, args.dst)
ChatUniVi/model/dataloader.py ADDED
@@ -0,0 +1,67 @@
+ from PIL import Image
+ import math
+ from decord import VideoReader, cpu
+ import numpy as np
+ import os
+ import torch
+
+
+ def _get_rawvideo_dec(video_path, image_processor, max_frames=64, image_resolution=224, video_framerate=1, s=None, e=None):
+     # speed up video decode via decord.
+     video_mask = np.zeros(max_frames, dtype=np.int64)
+     max_video_length = 0
+
+     # T x 3 x H x W
+     video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64)
+
+     if s is None:
+         start_time, end_time = None, None
+     else:
+         start_time = int(s)
+         end_time = int(e)
+         start_time = start_time if start_time >= 0. else 0.
+         end_time = end_time if end_time >= 0. else 0.
+         if start_time > end_time:
+             start_time, end_time = end_time, start_time
+         elif start_time == end_time:
+             end_time = start_time + 1
+
+     if os.path.exists(video_path):
+         vreader = VideoReader(video_path, ctx=cpu(0))
+     else:
+         print(video_path)
+         raise FileNotFoundError
+
+     fps = vreader.get_avg_fps()
+     f_start = 0 if start_time is None else int(start_time * fps)
+     f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
+     num_frames = f_end - f_start + 1
+     if num_frames > 0:
+         # T x 3 x H x W
+         sample_fps = int(video_framerate)
+         t_stride = int(round(float(fps) / sample_fps))
+
+         all_pos = list(range(f_start, f_end + 1, t_stride))
+         if len(all_pos) > max_frames:
+             sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
+         else:
+             sample_pos = all_pos
+
+         patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
+         patch_images = [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images]
+         slice_len = len(patch_images)
+         return patch_images, slice_len
+         # NOTE: everything below the early return is unreachable dead code,
+         # left over from an earlier fixed-length padding variant of this loader.
+         max_video_length = max_video_length if max_video_length > slice_len else slice_len
+         if slice_len < 1:
+             pass
+         else:
+             while len(patch_images) < max_frames:
+                 patch_images.append(torch.zeros((3, image_resolution, image_resolution)))
+             # video[:slice_len, ...] = patch_images
+     else:
+         print("video path: {} error.".format(video_path))
+
+     video_mask[:max_video_length] = [1] * max_video_length
+
+     return patch_images, video_mask
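The frame-selection arithmetic in `_get_rawvideo_dec` is independent of decord: stride through the clip at roughly one frame per `1/video_framerate` seconds, then thin uniformly if more than `max_frames` positions remain. The helper below is a hypothetical stdlib-only sketch of that plan (it uses `round` where the original uses `np.linspace(..., dtype=int)`, so boundary indices can differ by one).

```python
def plan_frame_positions(n_frames, fps, target_fps=1, max_frames=64):
    """Sketch of _get_rawvideo_dec's index arithmetic: stride by
    round(fps / target_fps), then uniformly thin to at most max_frames."""
    stride = max(int(round(fps / target_fps)), 1)
    positions = list(range(0, n_frames, stride))
    if len(positions) > max_frames:
        # evenly spaced picks over the strided list (integer linspace)
        idx = [round(i * (len(positions) - 1) / (max_frames - 1)) for i in range(max_frames)]
        positions = [positions[i] for i in idx]
    return positions
```

A 10-second 30 fps clip at `target_fps=1` yields ten positions; a much longer clip is thinned down to the `max_frames` budget while keeping the first and last strided frames.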
ChatUniVi/model/language_model/language_model/configuration_phi.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT license.
+
+ import math
+ from typing import Optional
+
+ from transformers import PretrainedConfig
+
+
+ class PhiConfig(PretrainedConfig):
+     """Phi configuration."""
+
+     model_type = "phi-msft"
+     attribute_map = {
+         "max_position_embeddings": "n_positions",
+         "hidden_size": "n_embd",
+         "num_attention_heads": "n_head",
+         "num_hidden_layers": "n_layer",
+     }
+
+     def __init__(
+         self,
+         vocab_size: int = 50304,
+         n_positions: int = 2048,
+         n_embd: int = 1024,
+         n_layer: int = 20,
+         n_inner: Optional[int] = None,
+         n_head: int = 16,
+         n_head_kv: Optional[int] = None,
+         rotary_dim: Optional[int] = 32,
+         activation_function: Optional[str] = "gelu_new",
+         flash_attn: bool = False,
+         flash_rotary: bool = False,
+         fused_dense: bool = False,
+         attn_pdrop: float = 0.0,
+         embd_pdrop: float = 0.0,
+         resid_pdrop: float = 0.0,
+         layer_norm_epsilon: float = 1e-5,
+         initializer_range: float = 0.02,
+         tie_word_embeddings: bool = False,
+         pad_vocab_size_multiple: int = 64,
+         **kwargs
+     ) -> None:
+         self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
+         self.n_positions = n_positions
+         self.n_embd = n_embd
+         self.n_layer = n_layer
+         self.n_inner = n_inner
+         self.n_head = n_head
+         self.n_head_kv = n_head_kv
+         self.rotary_dim = min(rotary_dim, n_embd // n_head)
+         self.activation_function = activation_function
+         self.flash_attn = flash_attn
+         self.flash_rotary = flash_rotary
+         self.fused_dense = fused_dense
+         self.attn_pdrop = attn_pdrop
+         self.embd_pdrop = embd_pdrop
+         self.resid_pdrop = resid_pdrop
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+
+         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
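`PhiConfig.__init__` silently rounds `vocab_size` up to a multiple of `pad_vocab_size_multiple` (padded vocabularies give GEMM-friendly embedding and output-projection shapes). That one line of arithmetic, restated as a hypothetical standalone helper:

```python
import math

def pad_vocab(vocab_size, multiple=64):
    """PhiConfig rounds the vocabulary up to the next multiple of `multiple`."""
    return int(math.ceil(vocab_size / multiple) * multiple)
```

So a GPT-2-style vocabulary of 50257 becomes 50304, and an already-padded value is left unchanged.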
ChatUniVi/model/language_model/language_model/modeling_phi.py ADDED
@@ -0,0 +1,984 @@
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT license.
+ #
+ # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
+ # Licensed under the BSD 3-Clause License.
+
+ from __future__ import annotations
+
+ import math
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange, repeat
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.activations import ACT2FN
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from .configuration_phi import PhiConfig
+
+ try:
+     from flash_attn.bert_padding import pad_input, unpad_input
+     from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
+     from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
+     # from flash_attn.ops.fused_dense import FusedDense
+ except ImportError:
+     # flash-attn is optional; fall back to the PyTorch attention paths below.
+     pad_input, unpad_input = None, None
+     FlashRotaryEmbedding = None
+     FlashSelfAttention, FlashCrossAttention = None, None
+ FusedDense = None
+
+
+ @dataclass
+ class InferenceParams:
+     """Inference parameters passed to model to efficiently calculate
+     and store context during inference.
+
+     Reference:
+         https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
+
+     Args:
+         max_seqlen: Maximum sequence length.
+         max_batch_size: Maximum batch size.
+         seqlen_offset: Sequence length offset.
+         batch_size_offset: Batch size offset.
+         key_value_memory_dict: Key value memory dictionary.
+         lengths_per_sample: Lengths per sample.
+
+     """
+
+     max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
+
+     max_batch_size: int = field(metadata={"help": "Maximum batch size."})
+
+     seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
+
+     batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
+
+     key_value_memory_dict: Dict[str, Any] = field(
+         default_factory=dict, metadata={"help": "Key value memory dictionary."}
+     )
+
+     lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
+
+
+ class Embedding(nn.Module):
+     """Token embedding with dropout."""
+
+     def __init__(self, config: PretrainedConfig) -> None:
+         super().__init__()
+
+         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+         self.drop = nn.Dropout(config.embd_pdrop)
+
+     def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
+         input_shape = input_ids.size()
+         input_ids = input_ids.view(-1, input_shape[-1])
+
+         hidden_states = self.wte(input_ids)
+         hidden_states = self.drop(hidden_states)
+
+         return hidden_states
+
+
+ def _apply_rotary_emb(
+     x: torch.FloatTensor,
+     cos: torch.FloatTensor,
+     sin: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+     _, seqlen, _, _ = x.shape
+     _, rotary_dim = cos.shape
+     rotary_dim *= 2
+
+     x_rot = x[:, :, :, :rotary_dim]
+     x_pass = x[:, :, :, rotary_dim:]
+
+     x1, x2 = x_rot.chunk(2, dim=-1)
+     c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
+     x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
+
+     x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
+
+     return torch.cat([x_rot, x_pass], axis=-1)
+
+
+ def _apply_rotary_emb_kv(
+     kv: torch.FloatTensor,
+     cos: torch.FloatTensor,
+     sin: torch.FloatTensor,
+     cos_k: Optional[torch.FloatTensor] = None,
+     sin_k: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+     _, seqlen, _, _, _ = kv.shape
+     _, rotary_dim = cos.shape
+     rotary_dim *= 2
+
+     k_rot = kv[:, :, 0, :, :rotary_dim]
+     k_pass = kv[:, :, 0, :, rotary_dim:]
+
+     k1, k2 = k_rot.chunk(2, dim=-1)
+     c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
+     k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
+
+     k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
+
+     return torch.cat(
+         [
+             torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
+             kv[:, :, 1:2, :, :],
+         ],
+         axis=2,
+     )
+
+
+ def _apply_rotary_emb_qkv(
+     qkv: torch.FloatTensor,
+     cos: torch.FloatTensor,
+     sin: torch.FloatTensor,
+     cos_k: Optional[torch.FloatTensor] = None,
+     sin_k: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+     _, seqlen, _, _, _ = qkv.shape
+     _, rotary_dim = cos.shape
+     rotary_dim *= 2
+
+     q_rot = qkv[:, :, 0, :, :rotary_dim]
+     q_pass = qkv[:, :, 0, :, rotary_dim:]
+
+     k_rot = qkv[:, :, 1, :, :rotary_dim]
+     k_pass = qkv[:, :, 1, :, rotary_dim:]
+
+     q1, q2 = q_rot.chunk(2, dim=-1)
+     k1, k2 = k_rot.chunk(2, dim=-1)
+     c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
+     q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
+
+     q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
+     k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
+
+     return torch.cat(
+         [
+             torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
+             torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
+             qkv[:, :, 2:3, :, :],
+         ],
+         axis=2,
+     )
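Each of the `_apply_rotary_emb*` helpers applies the same primitive to every channel pair: a 2-D rotation by the position-dependent angle. A hypothetical one-pair sketch of that primitive (illustrative, not repository code):

```python
import math

def rotate_pair(x1, x2, theta):
    """The 2-D rotation each (x1, x2) channel pair undergoes in the
    rotary helpers: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)."""
    c, s = math.cos(theta), math.sin(theta)
    return x1 * c - x2 * s, x1 * s + x2 * c
```

Because this is a pure rotation it preserves vector norms, which is what lets RoPE encode relative position in the query-key dot product without changing token magnitudes.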
+
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary positional embedding (RoPE).
+
+     Reference:
+         RoFormer: Enhanced Transformer with Rotary Position Embedding.
+         https://arxiv.org/pdf/2104.09864.pdf.
+
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         base: int = 10000,
+         scale_base: Optional[float] = None,
+         pos_idx_in_fp32: bool = True,
+         max_position_embeddings: int = 2048,
+         device: Optional[str] = None,
+         **kwargs,
+     ) -> None:
+         super().__init__()
+
+         if scale_base is not None:
+             raise NotImplementedError
+
+         self.dim = dim
+         self.base = float(base)
+         self.scale_base = scale_base
+         self.pos_idx_in_fp32 = pos_idx_in_fp32
+         self.max_position_embeddings = max_position_embeddings
+         self.device = device
+
+         # Generate and save the inverse frequency buffer (non-trainable)
+         inv_freq = self._compute_inv_freq(device)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         # Generate and save the scale buffer (non-trainable)
+         scale = (
+             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+             if scale_base is not None
+             else None
+         )
+         self.register_buffer("scale", scale, persistent=False)
+
+         # Initialize cached attributes since ONNX can't rely on dynamic initialization
+         self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
+
+     def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
+         return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
+
+     def _update_cos_sin_cache(
+         self,
+         seqlen: int,
+         device: Optional[str] = None,
+         dtype: Optional[torch.dtype] = None,
+     ) -> None:
+         self._seq_len_cached = seqlen
+
+         # fp32 is preferred since the output of `torch.arange` can be quite large
+         # and bf16 would lose a lot of precision
+         if self.pos_idx_in_fp32:
+             t = torch.arange(seqlen, device=device, dtype=torch.float32)
+             if self.inv_freq.dtype != torch.float32:
+                 inv_freq = self._compute_inv_freq(device=device)
+             else:
+                 inv_freq = self.inv_freq
+         else:
+             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+             inv_freq = self.inv_freq
+
+         # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
+         freqs = torch.outer(t, inv_freq)
+         if self.scale is None:
+             self._cos_cached = torch.cos(freqs).to(dtype)
+             self._sin_cached = torch.sin(freqs).to(dtype)
+         else:
+             power = (
+                 torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
+             ) / self.scale_base
+             scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+
+             # Force the scale multiplication to happen in fp32
+             self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+             self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+             self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+             self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+
+     def forward(
+         self,
+         qkv: torch.Tensor,
+         kv: Optional[torch.Tensor] = None,
+         seqlen_offset: int = 0,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if (
+             self._seq_len_cached < qkv.shape[1] + seqlen_offset
+             or self._cos_cached.device != qkv.device
+             or self._cos_cached.dtype != qkv.dtype
+             or (self.training and self._cos_cached.is_inference())
+         ):
+             self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+
+         if kv is None:
+             return _apply_rotary_emb_qkv(
+                 qkv,
+                 self._cos_cached[seqlen_offset:],
+                 self._sin_cached[seqlen_offset:],
+             )
+         else:
+             q = _apply_rotary_emb(
+                 qkv,
+                 self._cos_cached[seqlen_offset:],
+                 self._sin_cached[seqlen_offset:],
+             )
+             kv = _apply_rotary_emb_kv(
+                 kv,
+                 self._cos_cached[seqlen_offset:],
+                 self._sin_cached[seqlen_offset:],
+             )
+
+             return q, kv
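`_compute_inv_freq` assigns each channel pair a geometric frequency, so low-index pairs rotate quickly (fine-grained positions) and high-index pairs rotate slowly (coarse positions). A hypothetical plain-Python restatement of that buffer:

```python
def rope_inv_freq(dim, base=10000.0):
    """Per-pair rotation frequencies as in RotaryEmbedding._compute_inv_freq:
    1 / base^(i/dim) for even i = 0, 2, ..., dim - 2."""
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
```

For `dim=8` this gives four strictly decreasing frequencies starting at 1.0, matching the geometric schedule in the RoFormer paper.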
+
+
+ class MLP(nn.Module):
+     """Multi-Layer Perceptron.
+
+     Reference:
+         Attention Is All You Need.
+         https://arxiv.org/pdf/1706.03762.pdf.
+
+     """
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         n_inner: Optional[int] = None,
+         act_fn: Optional[str] = None,
+     ) -> None:
+         super().__init__()
+
+         act_fn = config.activation_function if act_fn is None else act_fn
+
+         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
+         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
+
+         self.fc1 = nn.Linear(config.n_embd, n_inner)
+         self.fc2 = nn.Linear(n_inner, config.n_embd)
+         self.act = ACT2FN[act_fn]
+
+     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+
+         return hidden_states
+
+
+ class SelfAttention(nn.Module):
+     """Self-attention layer (compatible with PyTorch).
+
+     Reference:
+         https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
+
+     """
+
+     def __init__(
+         self,
+         causal: bool = True,
+         softmax_scale: Optional[float] = None,
+         attention_dropout: float = 0.0,
+     ) -> None:
+         super().__init__()
+
+         self.causal = causal
+         self.softmax_scale = softmax_scale
+         self.drop = nn.Dropout(attention_dropout)
+
+     @torch.autocast("cpu", enabled=False)
+     @torch.autocast("cuda", enabled=False)
+     def forward(
+         self,
+         qkv: torch.FloatTensor,
+         causal: bool = None,
+         key_padding_mask: Optional[torch.BoolTensor] = None,
+         **kwargs,
+     ) -> torch.FloatTensor:
+         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+         q, k, v = qkv.unbind(dim=2)
+
+         q = q.to(torch.float32)
+         k = k.to(torch.float32)
+
+         causal = self.causal if causal is None else causal
+         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
+
+         # Autocast is manually disabled to avoid `torch.einsum` performing the operation
+         # using float16, which might lead to overflow
+         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+
+         if key_padding_mask is not None:
+             padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
+             padding_mask.masked_fill_(key_padding_mask, 0.0)
+
+             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+
+         if causal:
+             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+             scores = scores + causal_mask.to(dtype=scores.dtype)
+
+         attention = torch.softmax(scores, dim=-1).to(v.dtype)
+         attention = self.drop(attention)
+
+         output = torch.einsum("bhts,bshd->bthd", attention, v)
+
+         return output
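The einsum-based path above is ordinary scaled dot-product attention with an additive -10000.0 causal mask. A hypothetical single-head, list-of-lists sketch of the same computation (illustrative only, not repository code):

```python
import math

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask,
    mirroring the additive -10000.0 masking used by SelfAttention."""
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for t in range(n):
        # future positions (s > t) get a large negative score, as in the torch mask
        scores = [
            sum(q[t][h] * k[s][h] for h in range(d)) * scale if s <= t else -10000.0
            for s in range(n)
        ]
        m = max(scores)                          # subtract max for numerical stability
        exp = [math.exp(x - m) for x in scores]
        z = sum(exp)
        weights = [e / z for e in exp]
        out.append([sum(weights[s] * v[s][h] for s in range(n)) for h in range(d)])
    return out
```

At position 0 the only unmasked key is position 0, so the output is (numerically almost exactly) `v[0]`; later positions mix all earlier values with softmax weights that sum to one.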
387
+
388
+
389
+ class CrossAttention(nn.Module):
390
+ """Cross-attention layer (compatible with PyTorch).
391
+
392
+ Reference:
393
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
394
+
395
+ """
396
+
397
+ def __init__(
398
+ self,
399
+ causal: bool = True,
400
+ softmax_scale: Optional[float] = None,
401
+ attention_dropout: float = 0.0,
402
+ ) -> None:
403
+ super().__init__()
404
+
405
+ self.causal = causal
406
+ self.softmax_scale = softmax_scale
407
+ self.drop = nn.Dropout(attention_dropout)
408
+
409
+ @torch.autocast("cpu", enabled=False)
410
+ @torch.autocast("cuda", enabled=False)
411
+ def forward(
412
+ self,
413
+ q: torch.FloatTensor,
414
+ kv: torch.FloatTensor,
415
+ causal: bool = None,
416
+ key_padding_mask: Optional[torch.BoolTensor] = None,
417
+ **kwargs,
418
+ ) -> torch.FloatTensor:
419
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
420
+ seqlen_k = kv.shape[1]
421
+
422
+ if kv.shape[3] != q.shape[2]:
423
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
424
+ k, v = kv.unbind(dim=2)
425
+
426
+ q = q.to(torch.float32)
427
+ k = k.to(torch.float32)
428
+
429
+ causal = self.causal if causal is None else causal
430
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
431
+
432
+ # Autocast is manually disabled to avoid `torch.einsum` performing the operation
433
+ # using float16, which might lead to overflow
434
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
435
+
436
+ if key_padding_mask is not None:
437
+ padding_mask = torch.full(
438
+ (batch_size, seqlen_k),
439
+ -10000.0,
440
+ dtype=scores.dtype,
441
+ device=scores.device,
442
+ )
443
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
444
+
445
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
446
+
447
+ if causal:
448
+ rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
449
+ cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
450
+ causal_mask = cols > rows + seqlen_k - seqlen_q
451
+
452
+ scores = scores.masked_fill(causal_mask, -10000.0)
453
+
454
+ attention = torch.softmax(scores, dim=-1).to(v.dtype)
455
+ attention = self.drop(attention)
456
+
457
+ output = torch.einsum("bhts,bshd->bthd", attention, v)
458
+
459
+ return output
460
+
461
+
462
+ def _find_mha_dims(
463
+ config: PretrainedConfig,
464
+ n_head: Optional[int] = None,
465
+ n_head_kv: Optional[int] = None,
466
+ head_dim: Optional[int] = None,
467
+ ) -> Tuple[int, int, int]:
468
+ if n_head is None and head_dim is None:
469
+ head_dim = config.n_embd // config.n_head
470
+ n_head = config.n_head
471
+ elif n_head is None or head_dim is None:
472
+ raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
473
+
474
+ if n_head_kv is None:
475
+ n_head_kv = getattr(config, "n_head_kv", None) or n_head
476
+
477
+ return n_head, n_head_kv, head_dim
478
+
479
+
480
+ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
481
+ num_heads, head_dim = kv.shape[-2:]
482
+
483
+ if layer_idx not in inference_params.key_value_memory_dict:
484
+ inference_params.key_value_memory_dict[layer_idx] = torch.empty(
485
+ inference_params.max_batch_size,
486
+ inference_params.max_seqlen,
487
+ 2,
488
+ num_heads,
489
+ head_dim,
490
+ dtype=kv.dtype,
491
+ device=kv.device,
492
+ )
493
+
494
+ batch_start = inference_params.batch_size_offset
495
+ batch_end = batch_start + kv.shape[0]
496
+
497
+ sequence_start = inference_params.seqlen_offset
498
+ sequence_end = sequence_start + kv.shape[1]
499
+
500
+ # When the current sequence length is equal to or larger than the maximum sequence length,
501
+ # we need to concatenate the current `kv` with the cached `kv` to expand its length
502
+ if sequence_end >= inference_params.max_seqlen:
503
+ inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
504
+
505
+ inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
506
+ kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
507
+
508
+ return kv
509
+
510
+
511
+ class MHA(nn.Module):
512
+ """Multi-head attention layer."""
513
+
514
+ def __init__(
515
+ self,
516
+ config: PretrainedConfig,
517
+ dtype: Optional[torch.dtype] = None,
518
+ device: Optional[str] = None,
519
+ rotary_dim: Optional[int] = None,
520
+ rotary_base: float = 10000.0,
521
+ rotary_scale_base: Optional[float] = None,
522
+ n_head: Optional[int] = None,
523
+ n_head_kv: Optional[int] = None,
524
+ head_dim: Optional[int] = None,
525
+ bias: bool = True,
526
+ causal: bool = True,
527
+ softmax_scale: Optional[float] = None,
528
+ layer_idx: Optional[int] = None,
529
+ return_residual: bool = False,
530
+ checkpointing: bool = False,
531
+ ) -> None:
532
+ super().__init__()
533
+
534
+ # Rotary embedding
535
+ self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
536
+ if self.rotary_dim > 0:
537
+ rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
538
+ if rotary_cls is None:
539
+ rotary_cls = RotaryEmbedding
540
+
541
+ rotary_kwargs = {}
542
+ if rotary_cls is RotaryEmbedding:
543
+ rotary_kwargs["max_position_embeddings"] = config.n_positions
544
+
545
+ self.rotary_emb = rotary_cls(
546
+ self.rotary_dim,
547
+ base=rotary_base,
548
+ scale_base=rotary_scale_base,
549
+ device=device,
550
+ **rotary_kwargs,
551
+ )
552
+
553
+ # QKV and output projections
554
+ self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
555
+ config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
556
+ )
557
+ op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
558
+ hidden_size = config.n_embd
559
+
560
+ linear_cls = FusedDense if config.fused_dense else nn.Linear
561
+ if linear_cls is None:
562
+ linear_cls = nn.Linear
563
+
564
+ self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
565
+ self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
566
+
567
+ # Attention
568
+ # attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
569
+ attn_cls = FlashSelfAttention
570
+ if attn_cls is None:
571
+ attn_cls = SelfAttention
572
+
573
+ # cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
574
+ cross_attn_cls = FlashCrossAttention
575
+ if cross_attn_cls is None:
576
+ cross_attn_cls = CrossAttention
577
+
578
+ self.inner_attn = attn_cls(
579
+ causal=causal,
580
+ softmax_scale=softmax_scale,
581
+ attention_dropout=config.attn_pdrop,
582
+ )
583
+ self.inner_cross_attn = cross_attn_cls(
584
+ causal=causal,
585
+ softmax_scale=softmax_scale,
586
+ attention_dropout=config.attn_pdrop,
587
+ )
588
+
589
+ # self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
590
+ self.flash_attn = True
591
+ self.layer_idx = layer_idx
592
+ self.return_residual = return_residual
593
+ self.checkpointing = checkpointing
594
+
595
+ def _forward_self_attn(
596
+ self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
597
+ ) -> torch.FloatTensor:
598
+ qkv = self.Wqkv(x)
599
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
600
+
601
+ if self.rotary_dim > 0:
602
+ qkv = self.rotary_emb(qkv)
603
+
604
+ if self.flash_attn:
605
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
606
+
607
+ cu_seqlens, max_seqlen = None, None
608
+ if key_padding_mask is not None:
609
+ # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
610
+ # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
611
+ qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
612
+
613
+ if self.checkpointing:
614
+ attn_output = torch.utils.checkpoint.checkpoint(
615
+ self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
616
+ )
617
+ else:
618
+ attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
619
+
620
+ # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
621
+ return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
622
+
623
+ if self.checkpointing:
624
+ return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
625
+
626
+ return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
627
+
628
+ def _forward_cross_attn(
629
+ self,
630
+ x: torch.FloatTensor,
631
+ past_key_values: Optional[InferenceParams],
632
+ key_padding_mask: Optional[torch.BoolTensor],
633
+ ) -> torch.FloatTensor:
634
+ batch_size = x.shape[0]
635
+
636
+ qkv = self.Wqkv(x)
637
+
638
+ q = qkv[..., : self.n_head * self.head_dim]
639
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
640
+
641
+ kv = qkv[..., self.n_head * self.head_dim :]
642
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
643
+
644
+ seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
645
+ causal = None if seqlen_offset == 0 else False
646
+ if self.rotary_dim > 0:
647
+ q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
648
+
649
+ if past_key_values is not None:
650
+ kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
651
+
652
+ if self.flash_attn:
653
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
654
+ seqlen_k = kv.shape[1]
655
+
656
+ cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
657
+ None,
658
+ None,
659
+ None,
660
+ None,
661
+ )
662
+ if key_padding_mask is not None:
663
+ kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
664
+
665
+ if seqlen_q == 1:
666
+ key_padding_mask = torch.ones(batch_size, 1, device=q.device)
667
+ elif seqlen_q != seqlen_k:
668
+ key_padding_mask = key_padding_mask[:, -seqlen_q:]
669
+
670
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
671
+
672
+ if self.checkpointing:
673
+ attn_output = torch.utils.checkpoint.checkpoint(
674
+ self.inner_cross_attn,
675
+ q,
676
+ kv,
677
+ causal=causal,
678
+ cu_seqlens=cu_seqlens_q,
679
+ max_seqlen=max_seqlen_q,
680
+ cu_seqlens_k=cu_seqlens_k,
681
+ max_seqlen_k=max_seqlen_k,
682
+ )
683
+ else:
684
+ attn_output = self.inner_cross_attn(
685
+ q,
686
+ kv,
687
+ causal=causal,
688
+ cu_seqlens=cu_seqlens_q,
689
+ max_seqlen=max_seqlen_q,
690
+ cu_seqlens_k=cu_seqlens_k,
691
+ max_seqlen_k=max_seqlen_k,
692
+ )
693
+
694
+ return (
695
+ pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
696
+ if key_padding_mask is not None
697
+ else attn_output
698
+ )
699
+
700
+ if self.checkpointing:
701
+ return torch.utils.checkpoint.checkpoint(
702
+ self.inner_cross_attn,
703
+ q,
704
+ kv,
705
+ key_padding_mask=key_padding_mask,
706
+ causal=causal,
707
+ )
708
+
709
+ return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
710
+
711
+ def forward(
712
+ self,
713
+ x: torch.FloatTensor,
714
+ past_key_values: Optional[InferenceParams] = None,
715
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
716
+ **kwargs,
717
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
718
+ attention_mask = attention_mask.bool() if attention_mask is not None else None
722
+
723
+ # MHA
724
+ if self.n_head == self.n_head_kv:
725
+ if past_key_values is None:
726
+ # If `past_key_values` are not supplied, we run self-attention
727
+ attn_output = self._forward_self_attn(x, attention_mask)
728
+ else:
729
+ # If `past_key_values` are supplied, it means that we might have cached values and
730
+ # could take advantage of cross-attention
731
+ attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
732
+ # MQA / GQA
733
+ else:
734
+ # Regardless of whether `past_key_values` is supplied, this always uses cross-attention
735
+ # because `q` and `kv` lengths might be different
736
+ attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
737
+
738
+ output = rearrange(attn_output, "... h d -> ... (h d)")
739
+ output = self.out_proj(output)
740
+
741
+ return output if not self.return_residual else (output, x)
742
+
743
+
744
+ class ParallelBlock(nn.Module):
745
+ """Parallel block.
746
+
747
+ This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
748
+
749
+ """
750
+
751
+ def __init__(
752
+ self,
753
+ config: PretrainedConfig,
754
+ block_idx: Optional[int] = None,
755
+ ) -> None:
756
+ super().__init__()
757
+
758
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
759
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
760
+ self.block_idx = block_idx
761
+
762
+ self.mixer = MHA(config, layer_idx=block_idx)
763
+ self.mlp = MLP(config)
764
+
765
+ def forward(
766
+ self,
767
+ hidden_states: torch.FloatTensor,
768
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
769
+ attention_mask: Optional[torch.BoolTensor] = None,
770
+ **kwargs,
771
+ ) -> torch.FloatTensor:
772
+ residual = hidden_states
773
+ hidden_states = self.ln(hidden_states)
774
+
775
+ attn_outputs = self.mixer(
776
+ hidden_states,
777
+ past_key_values=past_key_values,
778
+ attention_mask=attention_mask,
779
+ )
780
+ if isinstance(attn_outputs, tuple):
781
+ attn_outputs = attn_outputs[0]
782
+
783
+ attn_outputs = self.resid_dropout(attn_outputs)
784
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
785
+
786
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
787
+
788
+ return hidden_states
789
+
790
+
791
+ class CausalLMHead(nn.Module):
792
+ """Causal Language Modeling head.
793
+
794
+ Reference:
795
+ Improving Language Understanding by Generative Pre-Training.
796
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
797
+
798
+ """
799
+
800
+ def __init__(self, config: PretrainedConfig) -> None:
801
+ super().__init__()
802
+
803
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
804
+ self.linear = nn.Linear(config.n_embd, config.vocab_size)
805
+
806
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
807
+ hidden_states = self.ln(hidden_states)
808
+ logits = self.linear(hidden_states).to(torch.float32)
809
+
810
+ return logits
811
+
812
+
813
+ class CausalLMLoss(nn.Module):
814
+ """Causal Language Modeling loss.
815
+
816
+ Reference:
817
+ Improving Language Understanding by Generative Pre-Training.
818
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
819
+
820
+ """
821
+
822
+ def __init__(self, shift_labels: bool = True) -> None:
823
+ super().__init__()
824
+
825
+ self.shift_labels = shift_labels
826
+ self.loss_fct = nn.CrossEntropyLoss()
827
+
828
+ def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
829
+ if self.shift_labels:
830
+ logits = logits[..., :-1, :].contiguous()
831
+ labels = labels[..., 1:].contiguous()
832
+
833
+ loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
834
+
835
+ return loss
836
+
837
+
838
+ class PhiPreTrainedModel(PreTrainedModel):
839
+ """Phi pre-trained model."""
840
+
841
+ config_class = PhiConfig
842
+ base_model_prefix = "transformer"
843
+ supports_gradient_checkpointing = False
844
+ _no_split_modules = ["ParallelBlock"]
845
+
846
+ def __init__(self, *inputs, **kwargs) -> None:
847
+ super().__init__(*inputs, **kwargs)
848
+
849
+ def _init_weights(self, module: nn.Module) -> None:
850
+ if isinstance(module, (nn.Linear,)):
851
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
852
+ if module.bias is not None:
853
+ module.bias.data.zero_()
854
+ elif isinstance(module, nn.Embedding):
855
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
856
+ if module.padding_idx is not None:
857
+ module.weight.data[module.padding_idx].zero_()
858
+ elif isinstance(module, nn.LayerNorm):
859
+ if module.bias is not None:
860
+ module.bias.data.zero_()
861
+ module.weight.data.fill_(1.0)
862
+
863
+ def prepare_inputs_for_generation(
864
+ self,
865
+ input_ids: torch.LongTensor,
866
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
867
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
868
+ **kwargs,
869
+ ) -> Dict[str, Any]:
870
+ if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
871
+ past_key_values = InferenceParams(
872
+ max_seqlen=self.config.n_positions,
873
+ max_batch_size=input_ids.shape[0],
874
+ seqlen_offset=0,
875
+ batch_size_offset=0,
876
+ key_value_memory_dict={},
877
+ lengths_per_sample=None,
878
+ )
879
+ else:
880
+ # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
881
+ past_key_values.seqlen_offset = input_ids.shape[1] - 1
882
+ input_ids = input_ids[:, -1].unsqueeze(-1)
883
+
884
+ return {
885
+ "input_ids": input_ids,
886
+ "past_key_values": past_key_values,
887
+ "attention_mask": attention_mask,
888
+ }
889
+
890
+
891
+ class PhiModel(PhiPreTrainedModel):
892
+ """Phi model."""
893
+
894
+ _keys_to_ignore_on_load_missing = [""]
895
+ _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
896
+
897
+ def __init__(self, config: PhiConfig) -> None:
898
+ super().__init__(config)
899
+
900
+ self.embd = Embedding(config)
901
+ self.embed_tokens = self.embd
902
+ self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
903
+ self.gradient_checkpointing = False
904
+ self.post_init()
905
+
906
+ def get_input_embeddings(self) -> nn.Embedding:
907
+ return self.embd.wte
908
+
909
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
910
+ self.embd.wte = new_embeddings
911
+
912
+ def forward(
913
+ self,
914
+ input_ids: torch.LongTensor,
915
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
916
+ attention_mask: Optional[torch.BoolTensor] = None,
917
+ inputs_embeds: Optional[torch.FloatTensor] = None,
918
+ ) -> torch.FloatTensor:
919
+ if inputs_embeds is None:
920
+ hidden_states = self.embd(input_ids)
921
+ else:
922
+ hidden_states = inputs_embeds
923
+
924
+ for layer in self.h:
925
+ hidden_states = layer(
926
+ hidden_states,
927
+ past_key_values=past_key_values,
928
+ attention_mask=attention_mask,
929
+ )
930
+
931
+ return hidden_states
932
+
933
+
934
+ class PhiForCausalLM(PhiPreTrainedModel):
935
+ """Phi for Causal Language Modeling."""
936
+
937
+ _keys_to_ignore_on_load_missing = [""]
938
+ _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
939
+
940
+ def __init__(self, config: PhiConfig) -> None:
941
+ super().__init__(config)
942
+
943
+ self.transformer = PhiModel(config)
944
+ self.lm_head = CausalLMHead(config)
945
+ self.loss = CausalLMLoss()
946
+
947
+ self.post_init()
948
+
949
+ def set_input_embeddings(self, value):
950
+ self.transformer.embd = value
951
+
952
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
953
+ def set_decoder(self, decoder):
954
+ self.transformer = decoder
955
+
956
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
957
+ def get_decoder(self):
958
+ return self.transformer
959
+
960
+ def get_input_embeddings(self):
961
+ return self.transformer.embd
962
+
963
+ def get_output_embeddings(self) -> nn.Linear:
964
+ return self.lm_head.linear
965
+
966
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
967
+ self.lm_head.linear = new_embeddings
968
+
969
+ def forward(
970
+ self,
971
+ input_ids: torch.LongTensor,
972
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
973
+ attention_mask: Optional[torch.BoolTensor] = None,
974
+ labels: Optional[torch.LongTensor] = None,
975
+ **kwargs,
976
+ ) -> CausalLMOutputWithPast:
977
+ hidden_states = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask)
978
+ lm_logits = self.lm_head(hidden_states)
979
+
980
+ loss = None
981
+ if labels is not None:
982
+ loss = self.loss(lm_logits, labels)
983
+
984
+ return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
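The `CrossAttention` class above builds its causal mask with `cols > rows + seqlen_k - seqlen_q`, which aligns the last query with the last key so that, during cached decoding (`seqlen_k > seqlen_q`), every new query can still attend to the whole cached prefix. A minimal pure-Python sketch of that rule (the function name is illustrative, not from the source):

```python
def causal_mask(seqlen_q: int, seqlen_k: int):
    """Return a boolean mask where True marks key positions a query may NOT attend to.

    Mirrors the CrossAttention rule `cols > rows + seqlen_k - seqlen_q`:
    the last query row is aligned with the last key column, so when keys
    include a cached prefix, all queries still see the full past.
    """
    offset = seqlen_k - seqlen_q
    return [[col > row + offset for col in range(seqlen_k)]
            for row in range(seqlen_q)]

# Square case: strictly upper-triangular mask, i.e. standard causal attention.
square = causal_mask(3, 3)
# Decoding with a 3-token cache: the single new query masks nothing.
decode_step = causal_mask(1, 4)
```

With equal lengths this reduces to the usual `torch.triu(..., 1)` mask used by `SelfAttention` above; the offset only matters once a KV cache makes the key sequence longer than the query sequence.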
ChatUniVi/model/language_model/llama.py ADDED
@@ -0,0 +1,136 @@
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import CrossEntropyLoss
5
+ from transformers import AutoConfig, AutoModelForCausalLM, \
6
+ LlamaConfig, LlamaModel, LlamaForCausalLM
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+ from models.tf.modeling_outputs import CausalLMOutputWithPastAndLabel
9
+
10
+ from ChatUniVi.model.arch import MetaModel, ChatUniViMetaForCausalLM
11
+
12
+
13
+ class ChatUniViConfig(LlamaConfig):
14
+ model_type = "ChatUniVi"
15
+
16
+
17
+ class ChatUniViLlamaModel(MetaModel, LlamaModel):
18
+ config_class = ChatUniViConfig
19
+
20
+ def __init__(self, config: LlamaConfig):
21
+ super(ChatUniViLlamaModel, self).__init__(config)
22
+
23
+
24
+ class ChatUniViLlamaForCausalLM(LlamaForCausalLM, ChatUniViMetaForCausalLM):
25
+ config_class = ChatUniViConfig
26
+
27
+ def __init__(self, config):
28
+ super(LlamaForCausalLM, self).__init__(config)
29
+ self.model = ChatUniViLlamaModel(config)
30
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
31
+ # Initialize weights and apply final processing
32
+ self.post_init()
33
+
34
+ def get_model(self):
35
+ return self.model
36
+
37
+ def forward(
38
+ self,
39
+ input_ids: torch.LongTensor = None,
40
+ attention_mask: Optional[torch.Tensor] = None,
41
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
42
+ inputs_embeds: Optional[torch.FloatTensor] = None,
43
+ labels: Optional[torch.LongTensor] = None,
44
+ use_cache: Optional[bool] = None,
45
+ output_attentions: Optional[bool] = None,
46
+ output_hidden_states: Optional[bool] = None,
47
+ images: Optional[torch.FloatTensor] = None,
48
+ return_dict: Optional[bool] = None,
49
+
50
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
51
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
52
+ output_hidden_states = (
53
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
54
+ )
55
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
56
+
57
+ # print(use_cache, output_attentions, return_dict)
58
+ # return 0
59
+ if inputs_embeds is None:
60
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
61
+ # else:
62
+ # print("prepare_inputs_labels_for_multimodal is not called")
63
+
64
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
65
+
66
+ outputs = self.model(
67
+ input_ids=input_ids,
68
+ attention_mask=attention_mask,
69
+ past_key_values=past_key_values,
70
+ inputs_embeds=inputs_embeds,
71
+ use_cache=use_cache,
72
+ output_attentions=output_attentions,
73
+ output_hidden_states=output_hidden_states,
74
+ return_dict=return_dict
75
+ )
76
+
77
+ hidden_states = outputs[0]
78
+ logits = self.lm_head(hidden_states)
79
+
80
+ loss = None
81
+ if labels is not None:
82
+ # Shift so that tokens < n predict n
83
+ shift_logits = logits[..., :-1, :].contiguous()
84
+ shift_labels = labels[..., 1:].contiguous()
85
+ # Flatten the tokens
86
+ loss_fct = CrossEntropyLoss()
87
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
88
+ shift_labels = shift_labels.view(-1)
89
+ # Enable model/pipeline parallelism
90
+ shift_labels = shift_labels.to(shift_logits.device)
91
+ loss = loss_fct(shift_logits, shift_labels)
92
+
93
+ if not return_dict:
94
+ output = (logits,) + outputs[1:]
95
+ return (loss,) + output if loss is not None else output
96
+
97
+ # return CausalLMOutputWithPast(
98
+ # loss=loss,
99
+ # logits=logits,
100
+ # past_key_values=outputs.past_key_values,
101
+ # hidden_states=outputs.hidden_states,
102
+ # attentions=outputs.attentions,
103
+ # )
104
+ return CausalLMOutputWithPastAndLabel(
105
+ loss=loss,
106
+ labels=labels,
107
+ logits=logits,
108
+ past_key_values=outputs.past_key_values,
109
+ hidden_states=outputs.hidden_states,
110
+ attentions=outputs.attentions,
111
+ )
112
+
113
+ def prepare_inputs_for_generation(
114
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
115
+ ):
116
+ if past_key_values:
117
+ input_ids = input_ids[:, -1:]
118
+
119
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
120
+ if inputs_embeds is not None and past_key_values is None:
121
+ model_inputs = {"inputs_embeds": inputs_embeds}
122
+ else:
123
+ model_inputs = {"input_ids": input_ids}
124
+
125
+ model_inputs.update(
126
+ {
127
+ "past_key_values": past_key_values,
128
+ "use_cache": kwargs.get("use_cache"),
129
+ "attention_mask": attention_mask,
130
+ "images": kwargs.get("images", None),
131
+ }
132
+ )
133
+ return model_inputs
134
+
135
+ AutoConfig.register("ChatUniVi", ChatUniViConfig)
136
+ AutoModelForCausalLM.register(ChatUniViConfig, ChatUniViLlamaForCausalLM)
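The loss computation in `forward` above shifts logits and labels so that the logits at position t are scored against the token at position t+1. A tiny sketch of that alignment on plain lists (names are illustrative):

```python
def shift_for_next_token(logits, labels):
    """Align sequences for next-token prediction, as in forward() above:
    position t's logits predict the label at t+1, so the last logit and
    the first label are dropped."""
    return logits[:-1], labels[1:]

# With 4 positions, 3 (logit, label) pairs remain.
shift_for_next_token(["l0", "l1", "l2", "l3"], [101, 102, 103, 104])
# → (['l0', 'l1', 'l2'], [102, 103, 104])
```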
ChatUniVi/model/language_model/phi.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+ from transformers import AutoConfig, AutoModelForCausalLM
22
+ from .modeling_phi.modeling_phi import PhiModel, PhiForCausalLM, CausalLMHead, CausalLMLoss
23
+ from .modeling_phi.configuration_phi import PhiConfig
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+
26
+ from ChatUniVi.model.arch import MetaModel, ChatUniViMetaForCausalLM
27
+
28
+
29
+ class ChatUniViConfig(PhiConfig):
30
+ model_type = "ChatUniViPhi2"
31
+
32
+
33
+ class ChatUniViPhiModel(MetaModel, PhiModel):
34
+ config_class = ChatUniViConfig
35
+
36
+ def __init__(self, config: PhiConfig):
37
+ super(ChatUniViPhiModel, self).__init__(config)
38
+
39
+
40
+ class ChatUniViPhiForCausalLM(PhiForCausalLM, ChatUniViMetaForCausalLM):
41
+ config_class = ChatUniViConfig
42
+ supports_gradient_checkpointing = True
43
+
44
+ def __init__(self, config):
45
+ super(PhiForCausalLM, self).__init__(config)
46
+ self.config = config
47
+ self.transformer = ChatUniViPhiModel(config)
48
+ self.lm_head = CausalLMHead(config)
49
+ self.loss = CausalLMLoss()
50
+
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.transformer
55
+
56
+ def _set_gradient_checkpointing(self, module, value=False):
57
+ module.gradient_checkpointing = value
58
+
59
+ def forward(
60
+ self,
61
+ input_ids: torch.LongTensor = None,
62
+ attention_mask: Optional[torch.Tensor] = None,
63
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
64
+ inputs_embeds: Optional[torch.FloatTensor] = None,
65
+ labels: Optional[torch.LongTensor] = None,
66
+ use_cache: Optional[bool] = None,
67
+ output_attentions: Optional[bool] = None,
68
+ output_hidden_states: Optional[bool] = None,
69
+ images: Optional[torch.FloatTensor] = None,
70
+ return_dict: Optional[bool] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
73
+ output_hidden_states = (
74
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
75
+ )
76
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
77
+
78
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
79
+
80
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
81
+
82
+ outputs = self.transformer(
83
+ input_ids=input_ids,
84
+ attention_mask=attention_mask,
85
+ past_key_values=past_key_values,
86
+ inputs_embeds=inputs_embeds,
87
+ )
88
+
89
+ hidden_states = outputs
90
+ logits = self.lm_head(hidden_states)
91
+
92
+ loss = None
93
+ if labels is not None:
94
+ # Shift so that tokens < n predict n
95
+ shift_logits = logits[..., :-1, :].contiguous()
96
+ shift_labels = labels[..., 1:].contiguous()
97
+ # Flatten the tokens
98
+ loss_fct = CrossEntropyLoss()
99
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
100
+ shift_labels = shift_labels.view(-1)
101
+ # Enable model/pipeline parallelism
102
+ shift_labels = shift_labels.to(shift_logits.device)
103
+ try:
104
+ loss = loss_fct(shift_logits, shift_labels)
105
+ except Exception:  # fall back to a zero, trainable loss rather than crashing
106
+ loss = torch.nn.Parameter(torch.zeros(1), requires_grad=True)
107
+
108
+ if not return_dict:
109
+ output = (logits, outputs)  # `outputs` is the hidden-states tensor, not a tuple
110
+ return (loss,) + output if loss is not None else output
111
+
112
+ return CausalLMOutputWithPast(
113
+ loss=loss,
114
+ logits=logits,
115
+ hidden_states=outputs,
116
+ )
117
+
118
+ def prepare_inputs_for_generation(
119
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
120
+ ):
121
+ if past_key_values:
122
+ input_ids = input_ids[:, -1:]
123
+
124
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
125
+ if inputs_embeds is not None and past_key_values is None:
126
+ model_inputs = {"inputs_embeds": inputs_embeds}
127
+ else:
128
+ model_inputs = {"input_ids": input_ids}
129
+
130
+ model_inputs.update(
131
+ {
132
+ "past_key_values": past_key_values,
133
+ "use_cache": kwargs.get("use_cache"),
134
+ "attention_mask": attention_mask,
135
+ "images": kwargs.get("images", None),
136
+ }
137
+ )
138
+ return model_inputs
139
+
140
+
141
+ AutoConfig.register("ChatUniViPhi2", ChatUniViConfig)
142
+ AutoModelForCausalLM.register(ChatUniViConfig, ChatUniViPhiForCausalLM)
ChatUniVi/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava.model.utils import auto_upgrade
11
+
12
+
13
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading target model")
19
+ auto_upgrade(target_model_path)
20
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21
+
22
+ print("Calculating delta")
23
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data -= base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31
+ bparam = base.state_dict()[name]
32
+ param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33
+
34
+ print("Saving delta")
35
+ if hub_repo_id:
36
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37
+ else:
38
+ kwargs = {}
39
+ target.save_pretrained(delta_path, **kwargs)
40
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("--base-model-path", type=str, required=True)
47
+ parser.add_argument("--target-model-path", type=str, required=True)
48
+ parser.add_argument("--delta-path", type=str, required=True)
49
+ parser.add_argument("--hub-repo-id", type=str, default=None)
50
+ args = parser.parse_args()
51
+
52
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
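The delta script above stores, per tensor, only the difference between the fine-tuned target weights and the base weights; adding the delta back onto the base recovers the target. A minimal pure-Python sketch of that arithmetic (toy lists standing in for state-dict tensors; the function names here are illustrative, not from the repo):

```python
def make_delta_vec(base, target):
    # delta[i] = target[i] - base[i], mirroring `param.data -= base.state_dict()[name]`
    return [t - b for b, t in zip(base, target)]


def apply_delta_vec(base, delta):
    # the inverse step used at release time: base + delta == target
    return [b + d for b, d in zip(base, delta)]


base = [0.5, -1.0, 2.0]
target = [0.75, -0.5, 1.5]
delta = make_delta_vec(base, target)      # [0.25, 0.5, -0.5]
restored = apply_delta_vec(base, delta)   # equals `target`
```

Shapes that exist only in the target (e.g. the multimodal projector) have no delta and are shipped as-is, which is why the real script skips names missing from the base model.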
ChatUniVi/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,14 @@
+ from .clip_encoder import CLIPVisionTower
+ from .eva_encoder import EVAVisionTower
+
+
+ def build_vision_tower(vision_tower_cfg, **kwargs):
+     vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
+     # if vision_tower.startswith("openai") or vision_tower.startswith("laion"):
+     #     return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+     #
+     # elif vision_tower.startswith("eva_vit_g"):
+     #     return EVAVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+     #
+     # raise ValueError(f'Unknown vision tower: {vision_tower}')
+     return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
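The commented-out branch shows the intended dispatch: the checkpoint name prefix selects the encoder class (CLIP for `openai`/`laion` checkpoints, EVA for `eva_vit_g`), while the current code unconditionally builds a CLIP tower. A standalone sketch of that prefix dispatch (returning class names as strings for illustration):

```python
def pick_tower(name: str) -> str:
    # mirrors the commented-out dispatch in build_vision_tower
    if name.startswith("openai") or name.startswith("laion"):
        return "CLIPVisionTower"
    elif name.startswith("eva_vit_g"):
        return "EVAVisionTower"
    raise ValueError(f"Unknown vision tower: {name}")


choice = pick_tower("openai/clip-vit-large-patch14")  # "CLIPVisionTower"
```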
ChatUniVi/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import torch.nn as nn
+
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
+
+
+ class CLIPVisionTower(nn.Module):
+     def __init__(self, vision_tower, args=None, delay_load=False):
+         super().__init__()
+
+         self.is_loaded = False
+
+         self.vision_tower_name = vision_tower
+         if args is None:
+             self.select_layer = -2
+             self.select_feature = 'patch'
+         else:
+             self.select_layer = args.mm_vision_select_layer
+             self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
+
+         if not delay_load:
+             self.load_model()
+         else:
+             self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
+
+     def load_model(self):
+         self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+         self.image_eval_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+         self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
+         self.vision_tower.requires_grad_(False)
+
+         self.is_loaded = True
+
+     def feature_select(self, image_forward_outs, select_feature='patch'):
+         image_features = image_forward_outs.hidden_states[self.select_layer]
+         if select_feature == 'patch':
+             # drop the leading CLS token, keep only the patch tokens
+             image_features = image_features[:, 1:]
+         elif select_feature == 'cls_patch':
+             image_features = image_features
+         else:
+             raise ValueError(f'Unexpected select feature: {select_feature}')
+         return image_features
+
+     @torch.no_grad()
+     def forward(self, images, select_feature='patch'):
+         if type(images) is list:
+             image_features = []
+             for image in images:
+                 image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+                 image_feature = self.feature_select(image_forward_out, select_feature).to(image.dtype)
+                 image_features.append(image_feature)
+         else:
+             image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+             image_features = self.feature_select(image_forward_outs, select_feature).to(images.dtype)
+
+         return image_features
+
+     @property
+     def dummy_feature(self):
+         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+     @property
+     def dtype(self):
+         return self.vision_tower.dtype
+
+     @property
+     def device(self):
+         return self.vision_tower.device
+
+     @property
+     def config(self):
+         if self.is_loaded:
+             return self.vision_tower.config
+         else:
+             return self.cfg_only
+
+     @property
+     def hidden_size(self):
+         return self.config.hidden_size
+
+     @property
+     def num_patches(self):
+         return (self.config.image_size // self.config.patch_size) ** 2
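The `num_patches` property encodes the standard ViT patch count: a square image of side `image_size` split into non-overlapping `patch_size × patch_size` patches yields `(image_size // patch_size)²` tokens. A quick sketch of that arithmetic for a few common CLIP configurations:

```python
def num_patches(image_size: int, patch_size: int) -> int:
    # a ViT splits the image into (image_size // patch_size) patches per side
    return (image_size // patch_size) ** 2


clip_l_336 = num_patches(336, 14)  # CLIP ViT-L/14 at 336px -> 576 patch tokens
clip_l_224 = num_patches(224, 14)  # CLIP ViT-L/14 at 224px -> 256 patch tokens
clip_b_224 = num_patches(224, 16)  # ViT-B/16 at 224px -> 196 patch tokens
```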
ChatUniVi/model/multimodal_encoder/eva_encoder.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import torch.nn as nn
+ from .eva_vit import create_eva_vit_g, _cfg
+ from .processor import ImageTrainProcessor, ImageEvalProcessor
+
+
+ class EVAVisionTower(nn.Module):
+     def __init__(self, vision_tower, args, delay_load=False):
+         super().__init__()
+
+         self.is_loaded = False
+
+         self.vision_tower_name = vision_tower
+         self.select_layer = args.mm_vision_select_layer
+         self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
+
+         if not delay_load:
+             self.load_model()
+         else:
+             self.cfg_only = _cfg()
+
+     def load_model(self):
+         self.image_processor = ImageTrainProcessor()
+         self.image_eval_processor = ImageEvalProcessor()
+         self.vision_tower = create_eva_vit_g(
+             img_size=224, drop_path_rate=0, use_checkpoint=False, precision="fp16"
+         )
+         # self.vision_tower.requires_grad_(False)
+
+         self.is_loaded = True
+
+     def feature_select(self, image_forward_outs, select_feature='patch'):
+         image_features = image_forward_outs[self.select_layer]
+         if select_feature == 'patch':
+             # drop the leading CLS token, keep only the patch tokens
+             image_features = image_features[:, 1:]
+         elif select_feature == 'cls_patch':
+             image_features = image_features
+         else:
+             raise ValueError(f'Unexpected select feature: {select_feature}')
+         return image_features
+
+     @torch.no_grad()
+     def forward(self, images, select_feature='patch'):
+         if type(images) is list:
+             image_features = []
+             for image in images:
+                 image_forward_out = self.vision_tower.get_intermediate_layers(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
+                 image_feature = self.feature_select(image_forward_out, select_feature).to(image.dtype)
+                 image_features.append(image_feature)
+         else:
+             image_forward_outs = self.vision_tower.get_intermediate_layers(images.to(device=self.device, dtype=self.dtype))
+             image_features = self.feature_select(image_forward_outs, select_feature).to(images.dtype)
+
+         return image_features
+
+     @property
+     def dummy_feature(self):
+         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+     @property
+     def dtype(self):
+         return self.vision_tower.cls_token.dtype
+
+     @property
+     def device(self):
+         return self.vision_tower.cls_token.device
+
+     @property
+     def config(self):
+         if self.is_loaded:
+             return self.vision_tower.config
+         else:
+             return self.cfg_only
+
+     @property
+     def hidden_size(self):
+         return self.vision_tower.num_features
+
+     @property
+     def num_patches(self):
+         return (self.config.image_size // self.config.patch_size) ** 2
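Both towers share the same `feature_select` convention: `select_layer` indexes a hidden layer (e.g. `-2` for the second-to-last), and `'patch'` mode drops token 0, the CLS token. A tiny pure-Python sketch of that selection on a toy token list (list slicing standing in for tensor slicing):

```python
def feature_select(tokens, mode="patch"):
    # tokens[0] is the CLS token; 'patch' keeps only the patch tokens
    if mode == "patch":
        return tokens[1:]
    elif mode == "cls_patch":
        return tokens
    raise ValueError(f"Unexpected select feature: {mode}")


tokens = ["cls", "p0", "p1", "p2"]
patches = feature_select(tokens)  # ["p0", "p1", "p2"]
```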
ChatUniVi/model/multimodal_encoder/eva_vit.py ADDED
@@ -0,0 +1,448 @@
+ # Based on EVA, BEIT, timm and DeiT code bases
+ # https://github.com/baaivision/EVA
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
+ # https://github.com/microsoft/unilm/tree/master/beit
+ # https://github.com/facebookresearch/deit/
+ # https://github.com/facebookresearch/dino
+ # --------------------------------------------------------
+ import math
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+ from timm.models.registry import register_model
+
+ from .utils import download_cached_file
+
+
+ def _cfg(url='', **kwargs):
+     return {
+         'url': url,
+         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+         'crop_pct': .9, 'interpolation': 'bicubic',
+         'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+         **kwargs
+     }
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+     def extra_repr(self) -> str:
+         return 'p={}'.format(self.drop_prob)
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         # x = self.drop(x)
+         # commented out to match the original BERT implementation
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(
+             self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+             proj_drop=0., window_size=None, attn_head_dim=None):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         if attn_head_dim is not None:
+             head_dim = attn_head_dim
+         all_head_dim = head_dim * self.num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+         if qkv_bias:
+             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+         else:
+             self.q_bias = None
+             self.v_bias = None
+
+         if window_size:
+             self.window_size = window_size
+             self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+             self.relative_position_bias_table = nn.Parameter(
+                 torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+             # cls to token & token to cls & cls to cls
+
+             # get pair-wise relative position index for each token inside the window
+             coords_h = torch.arange(window_size[0])
+             coords_w = torch.arange(window_size[1])
+             coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+             coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+             relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+             relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+             relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+             relative_coords[:, :, 1] += window_size[1] - 1
+             relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+             relative_position_index = \
+                 torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+             relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+             relative_position_index[0, 0:] = self.num_relative_distance - 3
+             relative_position_index[0:, 0] = self.num_relative_distance - 2
+             relative_position_index[0, 0] = self.num_relative_distance - 1
+
+             self.register_buffer("relative_position_index", relative_position_index)
+         else:
+             self.window_size = None
+             self.relative_position_bias_table = None
+             self.relative_position_index = None
+
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(all_head_dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x, rel_pos_bias=None):
+         B, N, C = x.shape
+         qkv_bias = None
+         if self.q_bias is not None:
+             qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+         # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+         q = q * self.scale
+         attn = (q @ k.transpose(-2, -1))
+
+         if self.relative_position_bias_table is not None:
+             relative_position_bias = \
+                 self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                     self.window_size[0] * self.window_size[1] + 1,
+                     self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+             relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+             attn = attn + relative_position_bias.unsqueeze(0)
+
+         if rel_pos_bias is not None:
+             attn = attn + rel_pos_bias
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class Block(nn.Module):
+
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                  window_size=None, attn_head_dim=None):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+             attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+         if init_values is not None and init_values > 0:
+             self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+             self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+         else:
+             self.gamma_1, self.gamma_2 = None, None
+
+     def forward(self, x, rel_pos_bias=None):
+         if self.gamma_1 is None:
+             x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+             x = x + self.drop_path(self.mlp(self.norm2(x)))
+         else:
+             x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+             x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+         self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x, **kwargs):
+         B, C, H, W = x.shape
+         # FIXME look at relaxing size constraints
+         assert H == self.img_size[0] and W == self.img_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+
+ class RelativePositionBias(nn.Module):
+
+     def __init__(self, window_size, num_heads):
+         super().__init__()
+         self.window_size = window_size
+         self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+         self.relative_position_bias_table = nn.Parameter(
+             torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+         # cls to token & token to cls & cls to cls
+
+         # get pair-wise relative position index for each token inside the window
+         coords_h = torch.arange(window_size[0])
+         coords_w = torch.arange(window_size[1])
+         coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+         relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+         relative_coords[:, :, 1] += window_size[1] - 1
+         relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+         relative_position_index = \
+             torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+         relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+         relative_position_index[0, 0:] = self.num_relative_distance - 3
+         relative_position_index[0:, 0] = self.num_relative_distance - 2
+         relative_position_index[0, 0] = self.num_relative_distance - 1
+
+         self.register_buffer("relative_position_index", relative_position_index)
+
+         # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+     def forward(self):
+         relative_position_bias = \
+             self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                 self.window_size[0] * self.window_size[1] + 1,
+                 self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+         return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+
+ class VisionTransformer(nn.Module):
+     """ Vision Transformer with support for patch or hybrid CNN input stage
+     """
+
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                  drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
+                  use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+                  use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
+         super().__init__()
+         self.image_size = img_size
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         if use_abs_pos_emb:
+             self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         else:
+             self.pos_embed = None
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         if use_shared_rel_pos_bias:
+             self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+         else:
+             self.rel_pos_bias = None
+         self.use_checkpoint = use_checkpoint
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+         self.use_rel_pos_bias = use_rel_pos_bias
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                 init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+             for i in range(depth)])
+         # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+         # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+         # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+         if self.pos_embed is not None:
+             trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         # trunc_normal_(self.mask_token, std=.02)
+         # if isinstance(self.head, nn.Linear):
+         #     trunc_normal_(self.head.weight, std=.02)
+         self.apply(self._init_weights)
+         self.fix_init_weight()
+
+         # if isinstance(self.head, nn.Linear):
+         #     self.head.weight.data.mul_(init_scale)
+         #     self.head.bias.data.mul_(init_scale)
+
+     def fix_init_weight(self):
+         def rescale(param, layer_id):
+             param.div_(math.sqrt(2.0 * layer_id))
+
+         for layer_id, layer in enumerate(self.blocks):
+             rescale(layer.attn.proj.weight.data, layer_id + 1)
+             rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def get_classifier(self):
+         return self.head
+
+     def reset_classifier(self, num_classes, global_pool=''):
+         self.num_classes = num_classes
+         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+     def forward_features(self, x):
+         x = self.patch_embed(x)
+         batch_size, seq_len, _ = x.size()
+
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         if self.pos_embed is not None:
+             x = x + self.pos_embed
+         x = self.pos_drop(x)
+
+         rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+         for blk in self.blocks:
+             if self.use_checkpoint:
+                 x = checkpoint.checkpoint(blk, x, rel_pos_bias)
+             else:
+                 x = blk(x, rel_pos_bias)
+         return x
+
+         # x = self.norm(x)
+
+         # if self.fc_norm is not None:
+         #     t = x[:, 1:, :]
+         #     return self.fc_norm(t.mean(1))
+         # else:
+         #     return x[:, 0]
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         # x = self.head(x)
+         return x
+
+     def get_intermediate_layers(self, x):
+         x = self.patch_embed(x)
+         batch_size, seq_len, _ = x.size()
+
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         if self.pos_embed is not None:
+             x = x + self.pos_embed
+         x = self.pos_drop(x)
+
+         features = []
+         rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+         for blk in self.blocks:
+             x = blk(x, rel_pos_bias)
+             features.append(x)
+
+         return features
+
+
+ def interpolate_pos_embed(model, checkpoint_model):
+     if 'pos_embed' in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         # height (== width) for the checkpoint position embedding
+         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int(num_patches ** 0.5)
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model['pos_embed'] = new_pos_embed
+
+
+ def convert_weights_to_fp16(model: nn.Module):
+     """Convert applicable model parameters to fp16"""
+
+     def _convert_weights_to_fp16(l):
+         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+             l.weight.data = l.weight.data.half()
+             if l.bias is not None:
+                 l.bias.data = l.bias.data.half()
+
+         # if isinstance(l, (nn.MultiheadAttention, Attention)):
+         #     for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+         #         tensor = getattr(l, attr)
+         #         if tensor is not None:
+         #             tensor.data = tensor.data.half()
+
+     model.apply(_convert_weights_to_fp16)
+
+
+ def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
+     model = VisionTransformer(
+         img_size=img_size,
+         patch_size=14,
+         use_mean_pooling=False,
+         embed_dim=1408,
+         depth=39,
+         num_heads=1408 // 88,
+         mlp_ratio=4.3637,
+         qkv_bias=True,
+         drop_path_rate=drop_path_rate,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         use_checkpoint=use_checkpoint,
+     )
+     url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
+     cached_file = download_cached_file(
+         url, check_hash=False, progress=True
+     )
+     state_dict = torch.load(cached_file, map_location="cpu")
+     interpolate_pos_embed(model, state_dict)
+
+     incompatible_keys = model.load_state_dict(state_dict, strict=False)
+     # print(incompatible_keys)
+
+     if precision == "fp16":
+         # model.to("cuda")
+         convert_weights_to_fp16(model)
+     return model
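`VisionTransformer.__init__` assigns each block a drop-path probability from the "stochastic depth decay rule": `torch.linspace(0, drop_path_rate, depth)`, so the first block never drops and the last drops at the full rate. A pure-Python sketch of that schedule (function name is illustrative):

```python
def drop_path_rates(max_rate: float, depth: int):
    # linear decay rule: rate grows from 0 at the first block to max_rate at the last,
    # mirroring `torch.linspace(0, drop_path_rate, depth)`
    if depth == 1:
        return [0.0]
    return [i * max_rate / (depth - 1) for i in range(depth)]


rates = drop_path_rates(0.4, 5)  # 0.0 at block 0, 0.4 at block 4
```

Deeper blocks see more regularization; EVA ViT-g uses `depth=39`, so the default `drop_path_rate=0.4` is spread across 39 blocks in this way (ChatUniVi loads it with `drop_path_rate=0`, disabling drop path entirely).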
ChatUniVi/model/multimodal_encoder/processor.py ADDED
@@ -0,0 +1,68 @@
+ import re
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+
+
+ class BaseProcessor:
+     def __init__(self, mean=None, std=None):
+         if mean is None:
+             mean = (0.48145466, 0.4578275, 0.40821073)
+         if std is None:
+             std = (0.26862954, 0.26130258, 0.27577711)
+
+         self.normalize = transforms.Normalize(mean, std)
+
+
+ class ImageTrainProcessor(BaseProcessor):
+     def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
+         super().__init__(mean=mean, std=std)
+
+         self.transform = transforms.Compose(
+             [
+                 transforms.Resize(
+                     (image_size, image_size), interpolation=InterpolationMode.BICUBIC
+                 ),
+                 transforms.ToTensor(),
+                 self.normalize,
+             ]
+         )
+
+     def preprocess(self, item, return_tensors):
+         return {'pixel_values': [self.transform(item)]}
+
+
+ class ImageEvalProcessor(BaseProcessor):
+     def __init__(self, image_size=224, mean=None, std=None):
+         super().__init__(mean=mean, std=std)
+
+         self.transform = transforms.Compose(
+             [
+                 transforms.Resize(
+                     (image_size, image_size), interpolation=InterpolationMode.BICUBIC
+                 ),
+                 transforms.ToTensor(),
+                 self.normalize,
+             ]
+         )
+
+     def preprocess(self, item, return_tensors):
+         return {'pixel_values': [self.transform(item)]}
+
+
+ class QWenImageProcessor(BaseProcessor):
+     def __init__(self, image_size=224, mean=None, std=None):
+         super().__init__(mean=mean, std=std)
+
+         mean = (0.48145466, 0.4578275, 0.40821073)
+         std = (0.26862954, 0.26130258, 0.27577711)
+         self.transform = transforms.Compose([
+             transforms.Resize(
+                 (448, 448),
+                 interpolation=InterpolationMode.BICUBIC
+             ),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=mean, std=std),
+         ])
+
+     def preprocess(self, item, return_tensors):
+         return {'pixel_values': [self.transform(item)]}
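All three processors end with `transforms.Normalize`, which maps each channel value `x` to `(x - mean) / std` using the CLIP channel statistics defaulted in `BaseProcessor`. A scalar sketch of that per-channel step (function name is illustrative):

```python
def normalize_pixel(value, mean, std):
    # transforms.Normalize applies (x - mean) / std independently per channel
    return (value - mean) / std


# CLIP channel statistics used as defaults in BaseProcessor above
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
red_white = normalize_pixel(1.0, mean[0], std[0])  # a pure-white red channel after ToTensor
```

`ToTensor` first scales pixels into [0, 1], so after normalization a white pixel's red channel becomes roughly 1.93 and a black pixel's roughly -1.79.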
ChatUniVi/model/multimodal_encoder/utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+  Copyright (c) 2022, salesforce.com, inc.
+  All rights reserved.
+  SPDX-License-Identifier: BSD-3-Clause
+  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import datetime
+ import functools
+ import os
+
+ import torch
+ import torch.distributed as dist
+ import timm.models.hub as timm_hub
+
+
+ def setup_for_distributed(is_master):
+     """
+     This function disables printing when not in master process
+     """
+     import builtins as __builtin__
+
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop("force", False)
+         if is_master or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def init_distributed_mode(args):
+     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ["WORLD_SIZE"])
+         args.gpu = int(os.environ["LOCAL_RANK"])
+     elif "SLURM_PROCID" in os.environ:
+         args.rank = int(os.environ["SLURM_PROCID"])
+         args.gpu = args.rank % torch.cuda.device_count()
+     else:
+         print("Not using distributed mode")
+         args.distributed = False
+         return
+
+     args.distributed = True
+
+     torch.cuda.set_device(args.gpu)
+     args.dist_backend = "nccl"
+     print(
+         "| distributed init (rank {}, world {}): {}".format(
+             args.rank, args.world_size, args.dist_url
+         ),
+         flush=True,
+     )
+     torch.distributed.init_process_group(
+         backend=args.dist_backend,
+         init_method=args.dist_url,
+         world_size=args.world_size,
+         rank=args.rank,
+         timeout=datetime.timedelta(
+             days=365
+         ),  # allow auto-downloading and de-compressing
+     )
+     torch.distributed.barrier()
+     setup_for_distributed(args.rank == 0)
+
+
+ def get_dist_info():
+     if torch.__version__ < "1.0":
+         initialized = dist._initialized
+     else:
+         initialized = dist.is_initialized()
+     if initialized:
+         rank = dist.get_rank()
+         world_size = dist.get_world_size()
+     else:  # non-distributed training
+         rank = 0
+         world_size = 1
+     return rank, world_size
+
+
+ def main_process(func):
+     @functools.wraps(func)
+     def wrapper(*args, **kwargs):
+         rank, _ = get_dist_info()
+         if rank == 0:
+             return func(*args, **kwargs)
+
+     return wrapper
+
+
+ def download_cached_file(url, check_hash=True, progress=False):
+     """
+     Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+     If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+     """
+
+     def get_cached_file_path():
+         # a hack to sync the file path across processes
+         parts = torch.hub.urlparse(url)
+         filename = os.path.basename(parts.path)
+         cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+         return cached_file
+
+     if is_main_process():
+         timm_hub.download_cached_file(url, check_hash, progress)
+
+     if is_dist_avail_and_initialized():
+         dist.barrier()
+
+     return get_cached_file_path()
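The `main_process` decorator above gates a function so it only runs on rank 0. The pattern can be sketched standalone, with a stubbed `get_dist_info` (a hypothetical stand-in, no torch required), to show that rank-0 callers get the return value while other ranks would fall through to `None`:

```python
import functools

def get_dist_info_stub(rank=0, world_size=1):
    # Stand-in for get_dist_info() when torch.distributed is not initialized:
    # non-distributed runs report rank 0, world size 1.
    return rank, world_size

def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info_stub()
        if rank == 0:
            return func(*args, **kwargs)
        # Non-zero ranks fall through and implicitly return None.
    return wrapper

@main_process
def save_checkpoint(path):
    return f"saved to {path}"

print(save_checkpoint("ckpt.pt"))  # rank 0 -> "saved to ckpt.pt"
```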
ChatUniVi/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,52 @@
+ import torch
+ import torch.nn as nn
+ import re
+
+
+ class IdentityMap(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, *args, **kwargs):
+         return x
+
+     @property
+     def config(self):
+         return {"mm_projector_type": 'identity'}
+
+
+ class SimpleResBlock(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.pre_norm = nn.LayerNorm(channels)
+
+         self.proj = nn.Sequential(
+             nn.Linear(channels, channels),
+             nn.GELU(),
+             nn.Linear(channels, channels)
+         )
+
+     def forward(self, x):
+         x = self.pre_norm(x)
+         return x + self.proj(x)
+
+
+ def build_vision_projector(config, delay_load=False, **kwargs):
+     projector_type = getattr(config, 'mm_projector_type', 'linear')
+
+     if projector_type == 'linear':
+         return nn.Linear(config.mm_hidden_size, config.hidden_size)
+
+     print("projector_type:", projector_type)
+     mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
+     if mlp_gelu_match:
+         mlp_depth = int(mlp_gelu_match.group(1))
+         modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+         for _ in range(1, mlp_depth):
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+         return nn.Sequential(*modules)
+
+     if projector_type == 'identity':
+         return IdentityMap()
+
+     raise ValueError(f'Unknown projector type: {projector_type}')
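`build_vision_projector` encodes MLP depth in the projector-type string: `mlp2x_gelu` means a 2-layer MLP with GELU between the linears. The naming regex can be checked in isolation (a standalone sketch, no torch dependency; `parse_projector_depth` is a hypothetical helper mirroring the regex above):

```python
import re

def parse_projector_depth(projector_type):
    # Mirrors the regex in build_vision_projector: "mlp2x_gelu" -> depth 2,
    # anything else -> None (handled by the other branches of the factory).
    m = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    return int(m.group(1)) if m else None

print(parse_projector_depth("mlp2x_gelu"))  # 2
print(parse_projector_depth("linear"))      # None
```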
ChatUniVi/train/llama_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,124 @@
+ from typing import List, Optional, Tuple
+ import logging
+
+ import torch
+ from torch import nn
+
+ import transformers
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+ from einops import rearrange
+
+ try:
+     from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+ except ImportError:
+     from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+ from flash_attn.bert_padding import unpad_input, pad_input
+
+
+ def forward(
+     self,
+     hidden_states: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.Tensor] = None,
+     past_key_value: Optional[Tuple[torch.Tensor]] = None,
+     output_attentions: bool = False,
+     use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+     """Input shape: Batch x Time x Channel
+
+     attention_mask: [bsz, q_len]
+     """
+     bsz, q_len, _ = hidden_states.size()
+
+     query_states = (
+         self.q_proj(hidden_states)
+         .view(bsz, q_len, self.num_heads, self.head_dim)
+         .transpose(1, 2)
+     )
+     key_states = (
+         self.k_proj(hidden_states)
+         .view(bsz, q_len, self.num_heads, self.head_dim)
+         .transpose(1, 2)
+     )
+     value_states = (
+         self.v_proj(hidden_states)
+         .view(bsz, q_len, self.num_heads, self.head_dim)
+         .transpose(1, 2)
+     )
+     # [bsz, q_len, nh, hd]
+     # [bsz, nh, q_len, hd]
+
+     kv_seq_len = key_states.shape[-2]
+     assert past_key_value is None, "past_key_value is not supported"
+
+     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+     query_states, key_states = apply_rotary_pos_emb(
+         query_states, key_states, cos, sin, position_ids
+     )
+     # [bsz, nh, t, hd]
+     assert not output_attentions, "output_attentions is not supported"
+     assert not use_cache, "use_cache is not supported"
+
+     # Flash attention codes from
+     # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+     # transform the data into the format required by flash attention
+     qkv = torch.stack(
+         [query_states, key_states, value_states], dim=2
+     )  # [bsz, nh, 3, q_len, hd]
+     qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+     # We have disabled _prepare_decoder_attention_mask in LlamaModel
+     # the attention_mask should be the same as the key_padding_mask
+     key_padding_mask = attention_mask
+
+     if key_padding_mask is None:
+         qkv = rearrange(qkv, "b s ... -> (b s) ...")
+         max_s = q_len
+         cu_q_lens = torch.arange(
+             0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+         )
+         output = flash_attn_unpadded_qkvpacked_func(
+             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+         )
+         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+     else:
+         nheads = qkv.shape[-2]
+         x = rearrange(qkv, "b s three h d -> b s (three h d)")
+         x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+         x_unpad = rearrange(
+             x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+         )
+         output_unpad = flash_attn_unpadded_qkvpacked_func(
+             x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+         )
+         output = rearrange(
+             pad_input(
+                 rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+             ),
+             "b s (h d) -> b s h d",
+             h=nheads,
+         )
+     return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+
+
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
+ # requires the attention mask to be the same as the key_padding_mask
+ def _prepare_decoder_attention_mask(
+     self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
+     # [bsz, seq_len]
+     return attention_mask
+
+
+ def replace_llama_attn_with_flash_attn():
+     cuda_major, cuda_minor = torch.cuda.get_device_capability()
+     if cuda_major < 8:
+         logging.warning(
+             "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+             "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+         )
+     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+         _prepare_decoder_attention_mask
+     )
+     transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
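In the unpadded branch above, `cu_q_lens` is the cumulative-sequence-length vector that the varlen flash-attention kernel expects: the start offset of each sequence in the flattened `(b s)` token dimension, plus a final entry for the total token count. For equal-length sequences this is just the arithmetic range produced by `torch.arange(0, (bsz + 1) * q_len, step=q_len)`, which can be sketched in pure Python:

```python
def cu_seqlens(bsz, q_len):
    # Start offset of each of the bsz sequences in the flattened token
    # dimension, plus a trailing entry equal to the total token count
    # (bsz * q_len), giving bsz + 1 entries overall.
    return list(range(0, (bsz + 1) * q_len, q_len))

print(cu_seqlens(3, 4))  # [0, 4, 8, 12]
```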
ChatUniVi/train/train.py ADDED
@@ -0,0 +1,1232 @@
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import copy
+ from dataclasses import dataclass, field
+ import json
+ import logging
+ import pathlib
+ from typing import Dict, Optional, Sequence, List
+ import torch
+ import transformers
+ from ChatUniVi.constants import *
+ from torch.utils.data import Dataset
+ from ChatUniVi.train.trainer import ChatUniViTrainer
+ from ChatUniVi import conversation as conversation_lib
+ from ChatUniVi.model import *
+ from ChatUniVi.mm_utils import tokenizer_image_token
+ from ChatUniVi.config import ModelConfig, DataConfig
+ from PIL import Image
+ import random
+ import numpy as np
+ from ChatUniVi.model.dataloader import _get_rawvideo_dec
+
+ local_rank = None
+
+
+ def rank0_print(*args):
+     if local_rank == 0:
+         print(*args)
+
+
+ @dataclass
+ class ModelArguments:
+     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+     version: Optional[str] = field(default="v0")
+     freeze_backbone: bool = field(default=False)
+     tune_mm_mlp_adapter: bool = field(default=False)
+     vision_tower: Optional[str] = field(default=None)
+     mm_vision_select_layer: Optional[int] = field(default=-1)  # default to the last layer
+     pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+     mm_use_im_start_end: bool = field(default=False)
+     mm_use_im_patch_token: bool = field(default=True)
+     mm_vision_select_feature: Optional[str] = field(default="patch")
+
+     mm_projector_type: Optional[str] = field(default='linear')
+     model_use: str = field(default="BASE")
+     mm_use_box_start_end: bool = field(default=False)
+
+
+ @dataclass
+ class DataArguments:
+     lazy_preprocess: bool = False
+     is_multimodal: bool = False
+     image_aspect_ratio: str = 'square'
+     image_grid_pinpoints: Optional[str] = field(default=None)
+
+     dataset_use: str = field(default="Pretrain")
+
+
+ @dataclass
+ class TrainingArguments(transformers.TrainingArguments):
+     cache_dir: Optional[str] = field(default=None)
+     optim: str = field(default="adamw_torch")
+     remove_unused_columns: bool = field(default=False)
+     freeze_mm_mlp_adapter: bool = field(default=False)
+     mpt_attn_impl: Optional[str] = field(default="triton")
+     model_max_length: int = field(
+         default=512,
+         metadata={
+             "help":
+                 "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+         },
+     )
+     double_quant: bool = field(
+         default=True,
+         metadata={"help": "Compress the quantization statistics through double quantization."}
+     )
+     quant_type: str = field(
+         default="nf4",
+         metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+     )
+     bits: int = field(
+         default=16,
+         metadata={"help": "How many bits to use."}
+     )
+     lora_enable: bool = False
+     lora_r: int = 64
+     lora_alpha: int = 16
+     lora_dropout: float = 0.05
+     lora_weight_path: str = ""
+     lora_bias: str = "none"
+
+     seed = 42
+
+
+ def maybe_zero_3(param, ignore_status=False, name=None):
+     from deepspeed import zero
+     from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+     if hasattr(param, "ds_id"):
+         if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+             if not ignore_status:
+                 logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+         with zero.GatheredParameters([param]):
+             param = param.data.detach().cpu().clone()
+     else:
+         param = param.detach().cpu().clone()
+     return param
+
+
+ # Borrowed from peft.utils.get_peft_model_state_dict
+ def get_peft_state_maybe_zero_3(named_params, bias):
+     if bias == "none":
+         to_return = {k: t for k, t in named_params if "lora_" in k}
+     elif bias == "all":
+         to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+     elif bias == "lora_only":
+         to_return = {}
+         maybe_lora_bias = {}
+         lora_bias_names = set()
+         for k, t in named_params:
+             if "lora_" in k:
+                 to_return[k] = t
+                 bias_name = k.split("lora_")[0] + "bias"
+                 lora_bias_names.add(bias_name)
+             elif "bias" in k:
+                 maybe_lora_bias[k] = t
+         for k, t in maybe_lora_bias.items():
+             if k in lora_bias_names:
+                 to_return[k] = t
+     else:
+         raise NotImplementedError
+     to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
+     return to_return
+
+
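The `bias == "lora_only"` branch above keeps a bias tensor only when a LoRA weight with the same module prefix exists. The filtering logic can be sketched without torch, using hypothetical parameter names in place of real `model.named_parameters()` output:

```python
def filter_lora_only(named_params):
    # named_params: iterable of (name, tensor) pairs, as yielded by
    # model.named_parameters(). Keep LoRA weights, plus only those biases
    # whose module also carries a LoRA weight.
    to_return, maybe_lora_bias, lora_bias_names = {}, {}, set()
    for k, t in named_params:
        if "lora_" in k:
            to_return[k] = t
            # e.g. "q_proj.lora_A" -> expected sibling bias "q_proj.bias"
            lora_bias_names.add(k.split("lora_")[0] + "bias")
        elif "bias" in k:
            maybe_lora_bias[k] = t
    for k, t in maybe_lora_bias.items():
        if k in lora_bias_names:
            to_return[k] = t
    return to_return

params = [("q_proj.lora_A", 1), ("q_proj.bias", 2), ("v_proj.bias", 3)]
print(sorted(filter_lora_only(params)))  # ['q_proj.bias', 'q_proj.lora_A']
```

`v_proj.bias` is dropped because no `v_proj` LoRA weight is present, mirroring the intent of the upstream peft helper.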
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+     to_return = {k: t for k, t in named_params if "lora_" not in k}
+     if require_grad_only:
+         to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+     to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+     return to_return
+
+
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+     to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+     to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+     return to_return
+
+
+ def find_all_linear_names(model):
+     cls = torch.nn.Linear
+     lora_module_names = set()
+     for name, module in model.named_modules():
+         if isinstance(module, cls):
+             names = name.split('.')
+             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+     if 'lm_head' in lora_module_names:  # needed for 16-bit
+         lora_module_names.remove('lm_head')
+     return list(lora_module_names)
+
+
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+                                    output_dir: str):
+     """Collects the state dict and dump to disk."""
+
+     if getattr(trainer.args, "tune_mm_mlp_adapter", False):
+         # Only save Adapter
+         keys_to_match = ['mm_projector', "ctm", "block"]
+         if getattr(trainer.args, "use_im_start_end", False):
+             keys_to_match.extend(['embed_tokens', 'embed_in'])
+
+         weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
+         trainer.model.config.save_pretrained(output_dir)
+
+         current_folder = output_dir.split('/')[-1]
+         parent_folder = os.path.dirname(output_dir)
+         if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+             if current_folder.startswith('checkpoint-'):
+                 mm_projector_folder = os.path.join(parent_folder, "mm_projector")
+                 os.makedirs(mm_projector_folder, exist_ok=True)
+                 torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+             else:
+                 torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
+
+     if trainer.deepspeed:
+         torch.cuda.synchronize()
+         trainer.save_model(output_dir)
+         return
+
+     state_dict = trainer.model.state_dict()
+     if trainer.args.should_save:
+         cpu_state_dict = {
+             key: value.cpu()
+             for key, value in state_dict.items()
+         }
+         del state_dict
+         trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
+
+
+ def smart_tokenizer_and_embedding_resize(
+     special_tokens_dict: Dict,
+     tokenizer: transformers.PreTrainedTokenizer,
+     model: transformers.PreTrainedModel,
+ ):
+     """Resize tokenizer and embedding.
+
+     Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+     """
+     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+     model.resize_token_embeddings(len(tokenizer))
+
+     if num_new_tokens > 0:
+         input_embeddings = model.get_input_embeddings().weight.data
+         output_embeddings = model.get_output_embeddings().weight.data
+
+         input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+             dim=0, keepdim=True)
+         output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+             dim=0, keepdim=True)
+
+         input_embeddings[-num_new_tokens:] = input_embeddings_avg
+         output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+ def _tokenize_fn(strings: Sequence[str],
+                  tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+     """Tokenize a list of strings."""
+     tokenized_list = [
+         tokenizer(
+             text,
+             return_tensors="pt",
+             padding="longest",
+             max_length=tokenizer.model_max_length,
+             truncation=True,
+         ) for text in strings
+     ]
+     input_ids = labels = [
+         tokenized.input_ids[0] for tokenized in tokenized_list
+     ]
+     input_ids_lens = labels_lens = [
+         tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+         for tokenized in tokenized_list
+     ]
+     return dict(
+         input_ids=input_ids,
+         labels=labels,
+         input_ids_lens=input_ids_lens,
+         labels_lens=labels_lens,
+     )
+
+
+ def _mask_targets(target, tokenized_lens, speakers):
+     # cur_idx = 0
+     cur_idx = tokenized_lens[0]
+     tokenized_lens = tokenized_lens[1:]
+     target[:cur_idx] = IGNORE_INDEX
+     for tokenized_len, speaker in zip(tokenized_lens, speakers):
+         if speaker == "human":
+             target[cur_idx + 2:cur_idx + tokenized_len] = IGNORE_INDEX
+         cur_idx += tokenized_len
+
+
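`_mask_targets` above overwrites spans of the label tensor with `IGNORE_INDEX` so the loss is only computed on assistant tokens. The indexing can be sketched on a plain Python list (assuming `IGNORE_INDEX` is -100, the usual value for HF cross-entropy labels, and hypothetical turn lengths):

```python
IGNORE_INDEX = -100  # assumed value, as in typical HF label masking

def mask_targets(target, tokenized_lens, speakers):
    # target: list of token ids; tokenized_lens[0] is the header length,
    # the rest are per-turn lengths aligned with speakers.
    cur_idx = tokenized_lens[0]
    target[:cur_idx] = [IGNORE_INDEX] * cur_idx
    for tokenized_len, speaker in zip(tokenized_lens[1:], speakers):
        if speaker == "human":
            # The cur_idx + 2 offset mirrors _mask_targets, which leaves
            # the first two tokens of a human round unmasked.
            n = tokenized_len - 2
            target[cur_idx + 2:cur_idx + tokenized_len] = [IGNORE_INDEX] * n
        cur_idx += tokenized_len
    return target

print(mask_targets(list(range(10)), [2, 4, 4], ["human", "gpt"]))
```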
+ def _add_speaker_and_signal(header, source, get_conversation=True):
+     """Add speaker and start/end signal on each round."""
+     BEGIN_SIGNAL = "### "
+     END_SIGNAL = "\n"
+     conversation = header
+     for sentence in source:
+         from_str = sentence["from"]
+         if from_str.lower() == "human":
+             from_str = conversation_lib.default_conversation.roles[0]
+         elif from_str.lower() == "gpt":
+             from_str = conversation_lib.default_conversation.roles[1]
+         else:
+             from_str = 'unknown'
+         sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+                              sentence["value"] + END_SIGNAL)
+         if get_conversation:
+             conversation += sentence["value"]
+     conversation += BEGIN_SIGNAL
+     return conversation
+
+
+ def preprocess_multimodal(
+     sources: Sequence[str],
+     data_args: DataArguments,
+     image_token_num=1
+ ) -> Dict:
+     is_multimodal = data_args.is_multimodal
+     if not is_multimodal:
+         return sources
+
+     for source in sources:
+         for sentence in source:
+             if DEFAULT_IMAGE_TOKEN in sentence['value'] or DEFAULT_VIDEO_TOKEN in sentence['value']:
+                 sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN).strip()
+                 sentence['value'] = sentence['value'].replace('\n' + DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN).strip()
+                 if sentence['value'].endswith(DEFAULT_IMAGE_TOKEN):
+                     IMAGE_TOKEN_NUM = sentence['value'].count(DEFAULT_IMAGE_TOKEN)
+                     sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM, '').strip()
+                     sentence['value'] = DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM + sentence['value']
+                     sentence['value'] = sentence['value'].strip()
+                 if sentence['value'].endswith(DEFAULT_VIDEO_TOKEN):
+                     VIDEO_TOKEN_NUM = sentence['value'].count(DEFAULT_VIDEO_TOKEN)
+                     sentence['value'] = sentence['value'].replace(DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM, '').strip()
+                     sentence['value'] = DEFAULT_VIDEO_TOKEN * VIDEO_TOKEN_NUM + sentence['value']
+                     sentence['value'] = sentence['value'].strip()
+
+                 if "mmtag" in conversation_lib.default_conversation.version:
+                     sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN,
+                                                                   '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
+
+             IMAGE_TOKEN_NUM = sentence['value'].count(DEFAULT_IMAGE_TOKEN)
+             if IMAGE_TOKEN_NUM > MAX_IMAGE_LENGTH:
+                 sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN * IMAGE_TOKEN_NUM,
+                                                               DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH).strip()
+
+             replace_token, vid_replace_token = DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN * image_token_num
+             if data_args.mm_use_im_start_end:
+                 replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+                 vid_replace_token = DEFAULT_VID_START_TOKEN + vid_replace_token + DEFAULT_VID_END_TOKEN
+
+             sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token + '\n')
+             sentence['value'] = sentence['value'].replace(DEFAULT_VIDEO_TOKEN, vid_replace_token + '\n')
+             sentence['value'] = sentence['value'].replace('\n\n', '\n')
+
+     return sources
+
+
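`preprocess_multimodal` normalizes the placement of image tokens in a turn and caps how many may appear. The string manipulation can be sketched standalone, with assumed constants (`<image>` as the token and a cap of 64, stand-ins for `DEFAULT_IMAGE_TOKEN` and `MAX_IMAGE_LENGTH` from `ChatUniVi.constants`):

```python
IMAGE_TOKEN = "<image>"   # assumed value of DEFAULT_IMAGE_TOKEN
MAX_IMAGE_LENGTH = 64     # assumed cap, mirrors MAX_IMAGE_LENGTH

def normalize_image_tokens(value):
    # Strip newlines adjacent to the token, then cap repeated tokens.
    value = value.replace(IMAGE_TOKEN + "\n", IMAGE_TOKEN).strip()
    value = value.replace("\n" + IMAGE_TOKEN, IMAGE_TOKEN).strip()
    n = value.count(IMAGE_TOKEN)
    if n > MAX_IMAGE_LENGTH:
        value = value.replace(IMAGE_TOKEN * n, IMAGE_TOKEN * MAX_IMAGE_LENGTH).strip()
    return value

print(normalize_image_tokens("<image>\nWhat is shown?"))  # <image>What is shown?
```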
+ def preprocess_llama_2(
+     sources,
+     tokenizer: transformers.PreTrainedTokenizer,
+     has_image: bool = False
+ ) -> Dict:
+     conv = conversation_lib.default_conversation.copy()
+     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+     # Apply prompt templates
+     conversations = []
+     for i, source in enumerate(sources):
+         if roles[source[0]["from"]] != conv.roles[0]:
+             # Skip the first one if it is not from human
+             source = source[1:]
+
+         conv.messages = []
+         for j, sentence in enumerate(source):
+             role = roles[sentence["from"]]
+             assert role == conv.roles[j % 2], f"{i}"
+             conv.append_message(role, sentence["value"])
+         conversations.append(conv.get_prompt())
+
+     # Tokenize conversations
+     if has_image:
+         input_ids = torch.stack(
+             [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+     else:
+         input_ids = tokenizer(
+             conversations,
+             return_tensors="pt",
+             padding="longest",
+             max_length=tokenizer.model_max_length,
+             truncation=True,
+         ).input_ids
+
+     targets = input_ids.clone()
+
+     assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
+
+     # Mask targets
+     sep = "[/INST] "
+     for conversation, target in zip(conversations, targets):
+         total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+         rounds = conversation.split(conv.sep2)
+
+         cur_len = 1
+         target[:cur_len] = IGNORE_INDEX
+
+         for i, rou in enumerate(rounds):
+             if rou == "":
+                 break
+
+             parts = rou.split(sep)
+             if len(parts) != 2:
+                 break
+             parts[0] += sep
+
+             if has_image:
+                 round_len = len(tokenizer_image_token(rou, tokenizer))
+                 instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+             else:
+                 round_len = len(tokenizer(rou).input_ids)
+                 instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+             target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+
+             cur_len += round_len
+
+             if tokenizer.eos_token == tokenizer.pad_token:
+                 cur_len += 1
+
+         target[cur_len:] = IGNORE_INDEX
+
+         if cur_len < tokenizer.model_max_length:
+             if cur_len != total_len:
+                 target[:] = IGNORE_INDEX
+                 print(
+                     f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+                     f" (ignored)"
+                 )
+
+     return dict(
+         input_ids=input_ids,
+         labels=targets,
+     )
+
+
+ def preprocess_v1(
+     sources,
+     tokenizer: transformers.PreTrainedTokenizer,
+     has_image: bool = False
+ ) -> Dict:
+     conv = conversation_lib.default_conversation.copy()
+     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+     # Apply prompt templates
+     conversations = []
+     for i, source in enumerate(sources):
+         if roles[source[0]["from"]] != conv.roles[0]:
+             # Skip the first one if it is not from human
+             source = source[1:]
+
+         conv.messages = []
+         for j, sentence in enumerate(source):
+             role = roles[sentence["from"]]
+             assert role == conv.roles[j % 2], f"{i}"
+             conv.append_message(role, sentence["value"])
+         conversations.append(conv.get_prompt())
+
+     # Tokenize conversations
+     if has_image:
+         input_ids = torch.stack(
+             [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+     else:
+         input_ids = tokenizer(
+             conversations,
+             return_tensors="pt",
+             padding="longest",
+             max_length=tokenizer.model_max_length,
+             truncation=True,
+         ).input_ids
+
+     targets = input_ids.clone()
+     assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+     # Mask targets
+     round_len_list = []
+     sep = conv.sep + conv.roles[1] + ": "
+     for conversation, target in zip(conversations, targets):
+         total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+         rounds = conversation.split(conv.sep2)
+         cur_len = 1
+         target[:cur_len] = IGNORE_INDEX
+         for i, rou in enumerate(rounds):
+             if rou == "":
+                 break
+
+             parts = rou.split(sep)
+             if len(parts) != 2:
+                 break
+             parts[0] += sep
+
+             if has_image:
+                 round_len = len(tokenizer_image_token(rou, tokenizer))
+                 instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+             else:
+                 round_len = len(tokenizer(rou).input_ids)
+                 instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+             target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+             # print("rou:", rou)
+             # print(round_len, instruction_len)
+             # print(len(tokenizer(rou).input_ids), len(tokenizer_image_token(rou, tokenizer)))
+             cur_len += round_len
+             round_len_list.append(round_len)
+         target[cur_len:] = IGNORE_INDEX
+
+         if cur_len < tokenizer.model_max_length:
+             if cur_len != total_len:
+                 # print(conversations, target, round_len_list)
+                 target[:] = IGNORE_INDEX
+                 print(
+                     f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+                     f" (ignored)"
+                 )
+                 # exit()
+     # print("ok", conversations, target, round_len_list)
+     return dict(
+         input_ids=input_ids,
+         labels=targets,
+     )
+
+
+ def preprocess_mpt(
+     sources,
+     tokenizer: transformers.PreTrainedTokenizer,
+ ) -> Dict:
+     conv = conversation_lib.default_conversation.copy()
+     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+     # Apply prompt templates
+     conversations = []
+     for i, source in enumerate(sources):
+         if roles[source[0]["from"]] != conv.roles[0]:
+             # Skip the first one if it is not from human
+             source = source[1:]
+
+         conv.messages = []
+         for j, sentence in enumerate(source):
+             role = roles[sentence["from"]]
+             assert role == conv.roles[j % 2], f"{i}"
+             conv.append_message(role, sentence["value"])
+         conversations.append(conv.get_prompt())
+
+     # Tokenize conversations
+     input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations],
+                             dim=0)
+     targets = input_ids.clone()
+     assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
+
+     # Mask targets
+     sep = conv.sep + conv.roles[1]
+     for conversation, target in zip(conversations, targets):
+         total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+         rounds = conversation.split(conv.sep)
+         re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
+         for conv_idx in range(3, len(rounds), 2):
+             re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2]))  # user + gpt
+         cur_len = 0
+         target[:cur_len] = IGNORE_INDEX
+         for i, rou in enumerate(re_rounds):
+             if rou == "":
+                 break
+
+             parts = rou.split(sep)
+             if len(parts) != 2:
+                 break
+             parts[0] += sep
+             round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
+             instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
+             target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+
+             cur_len += round_len
+         target[cur_len:] = IGNORE_INDEX
+
+         if cur_len < tokenizer.model_max_length:
+             if cur_len != total_len:
+                 target[:] = IGNORE_INDEX
+                 print(
+                     f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+                     f" (ignored)"
+                 )
+
+     return dict(
+         input_ids=input_ids,
+         labels=targets,
+     )
+
+
+ def preprocess_plain(
+     sources: Sequence[str],
+     tokenizer: transformers.PreTrainedTokenizer,
+ ) -> Dict:
+     # add end signal and concatenate together
+     conversations = []
+     for source in sources:
+         assert len(source) == 2
+         assert DEFAULT_IMAGE_TOKEN in source[0]['value']
+         source[0]['value'] = DEFAULT_IMAGE_TOKEN
+         conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
+         conversations.append(conversation)
+     # tokenize conversations
+     input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+     targets = copy.deepcopy(input_ids)
+     for target, source in zip(targets, sources):
+         tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
+         target[:tokenized_len] = IGNORE_INDEX
+
+     return dict(input_ids=input_ids, labels=targets)
+
+
+ def preprocess_phi(
+     sources,
+     tokenizer: transformers.PreTrainedTokenizer,
+     has_image: bool = False
+ ) -> Dict:
+     conv = conversation_lib.default_conversation.copy()
+     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+     # Apply prompt templates
+     conversations = []
+     for i, source in enumerate(sources):
+         if roles[source[0]["from"]] != conv.roles[0]:
+             # Skip the first one if it is not from human
+             source = source[1:]
+
+         conv.messages = []
+         for j, sentence in enumerate(source):
+             role = roles[sentence["from"]]
+             assert role == conv.roles[j % 2], f"{i}"
+             conv.append_message(role, sentence["value"])
+         conversations.append(conv.get_prompt())
+
+     # Tokenize conversations
+     if has_image:
+         input_ids = torch.stack(
+             [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+     else:
+         input_ids = tokenizer(
638
+ conversations,
639
+ return_tensors="pt",
640
+ padding="longest",
641
+ max_length=tokenizer.model_max_length,
642
+ truncation=True,
643
+ ).input_ids
644
+
645
+ targets = input_ids.clone()
646
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
647
+
648
+ # Mask targets
649
+ round_len_list = []
650
+ sep = conv.sep + conv.roles[1] + ": "
651
+ for conversation, target in zip(conversations, targets):
652
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
653
+
654
+ rounds = conversation.split(conv.sep2)
655
+ cur_len = 0
656
+ pre_len = 0
657
+ for i, rou in enumerate(rounds):
658
+ if rou == "":
659
+ break
660
+
661
+ parts = rou.split(sep)
662
+ if len(parts) != 2:
663
+ break
664
+ parts[0] += sep
665
+
666
+ cur_len += 1
667
+ target[pre_len: cur_len] = IGNORE_INDEX
668
+
669
+ if has_image:
670
+ round_len = len(tokenizer_image_token(rou, tokenizer))
671
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
672
+ else:
673
+ round_len = len(tokenizer(rou).input_ids)
674
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
675
+
676
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
677
+ # print("rou:", rou)
678
+ # print(round_len, instruction_len)
679
+ # print(len(tokenizer(rou).input_ids), len(tokenizer_image_token(rou, tokenizer)))
680
+ cur_len += round_len
681
+ pre_len = cur_len
682
+ round_len_list.append(round_len)
683
+ target[cur_len:] = IGNORE_INDEX
684
+
685
+ if cur_len < tokenizer.model_max_length:
686
+ if cur_len != total_len + len(rounds) - 1:
687
+ # print(conversations, target, round_len_list)
688
+ target[:] = IGNORE_INDEX
689
+ print(
690
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
691
+ f" (ignored)"
692
+ )
693
+ # exit()
694
+ # print("ok", conversations, target, round_len_list)
695
+ return dict(
696
+ input_ids=input_ids,
697
+ labels=targets,
698
+ )
699
+
700
+
701
+ def preprocess(
702
+ sources: Sequence[str],
703
+ tokenizer: transformers.PreTrainedTokenizer,
704
+ has_image: bool = False
705
+ ) -> Dict:
706
+ """
707
+ Given a list of sources, each is a conversation list. This transform:
708
+ 1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
709
+ 2. Concatenate conversations together;
710
+ 3. Tokenize the concatenated conversation;
711
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
712
+ """
713
+ if conversation_lib.default_conversation.version.startswith("phi"):
714
+ return preprocess_phi(sources, tokenizer, has_image=has_image)
715
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
716
+ return preprocess_plain(sources, tokenizer)
717
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
718
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
719
+ if conversation_lib.default_conversation.version.startswith("v1"):
720
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
721
+ if conversation_lib.default_conversation.version == "mpt":
722
+ return preprocess_mpt(sources, tokenizer)
723
+ # add end signal and concatenate together
724
+ conversations = []
725
+ for source in sources:
726
+ header = f"{conversation_lib.default_conversation.system}\n\n"
727
+ conversation = _add_speaker_and_signal(header, source)
728
+ conversations.append(conversation)
729
+
730
+ # tokenize conversations
731
+ def get_tokenize_len(prompts):
732
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
733
+
734
+ if has_image:
735
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
736
+ else:
737
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
738
+ input_ids = conversations_tokenized["input_ids"]
739
+
740
+ targets = copy.deepcopy(input_ids)
741
+ for target, source in zip(targets, sources):
742
+ if has_image:
743
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
744
+ else:
745
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
746
+ speakers = [sentence["from"] for sentence in source]
747
+ _mask_targets(target, tokenized_lens, speakers)
748
+
749
+ return dict(input_ids=input_ids, labels=targets)
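For the default branch above, the `_mask_targets` step hides everything except assistant replies from the loss. A minimal sketch of that idea (assuming the helper masks the header and each human turn; `mask_targets` here is illustrative, not the exact implementation, which also masks speaker tokens):

```python
IGNORE_INDEX = -100

def mask_targets(target, tokenized_lens, speakers):
    # Hide the header, then every human turn, by writing IGNORE_INDEX
    # so only assistant tokens contribute to the loss.
    cur = tokenized_lens[0]                 # header length
    target[:cur] = [IGNORE_INDEX] * cur
    for tok_len, speaker in zip(tokenized_lens[1:], speakers):
        if speaker == "human":
            target[cur:cur + tok_len] = [IGNORE_INDEX] * tok_len
        cur += tok_len
    return target

labels = mask_targets(list(range(10)), [2, 3, 5], ["human", "gpt"])
```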
750
+
751
+
752
+ class LazySupervisedDataset(Dataset):
753
+ """Dataset for supervised fine-tuning."""
754
+
755
+ def __init__(self, tokenizer: transformers.PreTrainedTokenizer,
756
+ data_args: DataArguments):
757
+ super(LazySupervisedDataset, self).__init__()
758
+
759
+ dataset_list = DataConfig[str(data_args.dataset_use)]
760
+ print(dataset_list)
761
+
762
+ self.max_length = MAX_IMAGE_LENGTH
763
+ list_data_dict = []
764
+ self.folder_dict = {}
765
+ for i in dataset_list:
766
+ list_data_dict += json.load(open(i["chat_path"], "r"))
767
+
768
+ image_folder = [folder for folder in i if folder != "chat_path"]
769
+
770
+ for folder in image_folder:
771
+ if folder not in self.folder_dict:
772
+ self.folder_dict[folder] = i[folder]
773
+
774
+ random.shuffle(list_data_dict)
775
+
776
+ rank0_print("Formatting inputs...Skip in lazy mode")
777
+ self.tokenizer = tokenizer
778
+ self.list_data_dict = list_data_dict
779
+ self.data_args = data_args
780
+
781
+ def __len__(self):
782
+ return len(self.list_data_dict)
783
+
784
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
785
+ sources = self.list_data_dict[i]
786
+ if isinstance(i, int):
787
+ sources = [sources]
788
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
789
+ if 'image' in sources[0]:
790
+ image_file = self.list_data_dict[i]['image']
791
+
792
+ file = image_file[0] if type(image_file) is list else image_file
793
+
794
+ if "llava_image" in file:
795
+ image_folder = self.folder_dict['llava']
796
+ elif "\\" in file:
797
+ image_folder = self.folder_dict['ScienceQA']
798
+ elif "CGD" in file:
799
+ image_folder = self.folder_dict['CDG']
800
+ elif "DC" in file:
801
+ image_folder = self.folder_dict['DC']
802
+ elif "LA" in file:
803
+ image_folder = self.folder_dict['LA']
804
+ elif "SD" in file:
805
+ image_folder = self.folder_dict['SD']
806
+ elif "SN" in file:
807
+ image_folder = self.folder_dict['SN']
808
+ elif "TVC" in file:
809
+ image_folder = self.folder_dict['TVC']
810
+ elif "VST" in file:
811
+ image_folder = self.folder_dict['VST']
812
+ elif "GCC" in file:
813
+ image_folder = self.folder_dict['CC3M']
814
+ elif "COCO_train2014" in file:
815
+ image_folder = self.folder_dict['COCO2014']
816
+ else:
817
+ image_folder = self.folder_dict['COCO2017']
818
+
819
+ processor = self.data_args.image_processor
820
+
821
+ if type(image_file) is list:
822
+ image = [Image.open(os.path.join(image_folder, file.replace("\\", "/"))).convert('RGB') for file in
823
+ image_file]
824
+ if self.data_args.image_aspect_ratio == 'pad':
825
+ def expand2square(pil_img, background_color):
826
+ width, height = pil_img.size
827
+ if width == height:
828
+ return pil_img
829
+ elif width > height:
830
+ result = Image.new(pil_img.mode, (width, width), background_color)
831
+ result.paste(pil_img, (0, (width - height) // 2))
832
+ return result
833
+ else:
834
+ result = Image.new(pil_img.mode, (height, height), background_color)
835
+ result.paste(pil_img, ((height - width) // 2, 0))
836
+ return result
837
+
838
+ image = [expand2square(i, tuple(int(x * 255) for x in processor.image_mean)) for i in image]
839
+ image = [processor.preprocess(i, return_tensors='pt')['pixel_values'][0] for i in image]
840
+ else:
841
+ image = [processor.preprocess(i, return_tensors='pt')['pixel_values'][0] for i in image]
842
+ else:
843
+ image = Image.open(os.path.join(image_folder, image_file.replace("\\", "/"))).convert('RGB')
844
+ if self.data_args.image_aspect_ratio == 'pad':
845
+ def expand2square(pil_img, background_color):
846
+ width, height = pil_img.size
847
+ if width == height:
848
+ return pil_img
849
+ elif width > height:
850
+ result = Image.new(pil_img.mode, (width, width), background_color)
851
+ result.paste(pil_img, (0, (width - height) // 2))
852
+ return result
853
+ else:
854
+ result = Image.new(pil_img.mode, (height, height), background_color)
855
+ result.paste(pil_img, ((height - width) // 2, 0))
856
+ return result
857
+
858
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
859
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
860
+ else:
861
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
862
+
863
+ sources = preprocess_multimodal(
864
+ copy.deepcopy([e["conversations"] for e in sources]),
865
+ self.data_args)
866
+
867
+ data_dict = preprocess(
868
+ sources,
869
+ self.tokenizer,
870
+ has_image=True)
871
+
872
+ elif "video" in sources[0]:
873
+ video_file = self.list_data_dict[i]['video']
874
+
875
+ if "valley" in video_file:
876
+ video_folder = self.folder_dict['valley']
877
+ else:
878
+ video_folder = self.folder_dict['VIDEO']
879
+ processor = self.data_args.image_processor
880
+
881
+ if os.path.exists(os.path.join(video_folder, video_file)):
882
+ image, image_token_num = _get_rawvideo_dec(os.path.join(video_folder, video_file), processor,
883
+ max_frames=MAX_IMAGE_LENGTH)
884
+ flag = 0
885
+ else:
886
+ crop_size = self.data_args.image_processor.crop_size
887
+ image, image_token_num = torch.zeros(3, crop_size['height'], crop_size['width']), 1
888
+ flag = 1
889
+
890
+ sources = preprocess_multimodal(
891
+ copy.deepcopy([e["conversations"] for e in sources]),
892
+ self.data_args, image_token_num=image_token_num)
893
+
894
+ data_dict = preprocess(
895
+ sources,
896
+ self.tokenizer,
897
+ has_image=True)
898
+
899
+ if flag:
900
+ data_dict["labels"][:] = IGNORE_INDEX
901
+ print(
902
+ f"WARNING: video load failed: {os.path.join(video_folder, video_file)}."
903
+ f" (ignored)"
904
+ )
905
+
906
+ else:
907
+ sources = copy.deepcopy([e["conversations"] for e in sources])
908
+
909
+ data_dict = preprocess(
910
+ sources,
911
+ self.tokenizer,
912
+ has_image=False)
913
+
914
+ if isinstance(i, int):
915
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
916
+ labels=data_dict["labels"][0])
917
+
918
+ # image exist in the data
919
+ if 'image' in self.list_data_dict[i] or 'video' in self.list_data_dict[i]:
920
+ data_dict['image'] = image
921
+ elif self.data_args.is_multimodal:
922
+ # image does not exist in the data, but the model is multimodal
923
+ crop_size = self.data_args.image_processor.crop_size
924
+ data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
925
+ return data_dict
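The `expand2square` helper defined inside `__getitem__` pads the shorter side so the image becomes square (filled with the processor's mean color) before preprocessing. Its geometry can be sketched without PIL; `expand2square_box` below is an illustration, not part of the codebase:

```python
def expand2square_box(width, height):
    # Returns (canvas_size, paste_offset) matching expand2square above:
    # the canvas is max(width, height) squared, and the original image
    # is centered along the padded axis.
    if width == height:
        return (width, height), (0, 0)
    if width > height:
        return (width, width), (0, (width - height) // 2)
    return (height, height), ((height - width) // 2, 0)
```

For a 4x2 landscape image the canvas is 4x4 and the image is pasted one pixel down, which is exactly what `result.paste(pil_img, (0, (width - height) // 2))` does.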
926
+
927
+
928
+ @dataclass
929
+ class DataCollatorForSupervisedDataset(object):
930
+ """Collate examples for supervised fine-tuning."""
931
+
932
+ tokenizer: transformers.PreTrainedTokenizer
933
+
934
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
935
+ input_ids, labels = tuple([instance[key] for instance in instances]
936
+ for key in ("input_ids", "labels"))
937
+ input_ids = torch.nn.utils.rnn.pad_sequence(
938
+ input_ids,
939
+ batch_first=True,
940
+ padding_value=self.tokenizer.pad_token_id)
941
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
942
+ batch_first=True,
943
+ padding_value=IGNORE_INDEX)
944
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
945
+ labels = labels[:, :self.tokenizer.model_max_length]
946
+ batch = dict(
947
+ input_ids=input_ids,
948
+ labels=labels,
949
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
950
+ )
951
+
952
+ if 'image' in instances[0]:
953
+ images = [instance['image'] for instance in instances]
954
+
955
+ new_images = []
956
+ for image in images:
957
+ if type(image) is list:
958
+ for i in image:
959
+ new_images.append(i)
960
+ else:
961
+ new_images.append(image)
962
+ images = new_images
963
+
964
+ if all(x is not None and x.shape == images[0].shape for x in images):
965
+ batch['images'] = torch.stack(images)
966
+ else:
967
+ batch['images'] = images
968
+
969
+ return batch
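The collator pads `input_ids` with the pad token, pads `labels` with `IGNORE_INDEX`, and derives the attention mask from the pad positions. The same logic in plain Python (a sketch without torch; real batches are tensors and the pad id comes from the tokenizer):

```python
IGNORE_INDEX = -100
PAD_TOKEN_ID = 0  # illustrative; the real value is tokenizer.pad_token_id

def collate(batch):
    # Right-pad every example to the longest sequence in the batch.
    max_len = max(len(ex["input_ids"]) for ex in batch)
    out = {"input_ids": [], "labels": [], "attention_mask": []}
    for ex in batch:
        n_pad = max_len - len(ex["input_ids"])
        out["input_ids"].append(ex["input_ids"] + [PAD_TOKEN_ID] * n_pad)
        out["labels"].append(ex["labels"] + [IGNORE_INDEX] * n_pad)
        out["attention_mask"].append([True] * len(ex["input_ids"]) + [False] * n_pad)
    return out

batch = collate([
    {"input_ids": [5, 6, 7], "labels": [-100, 6, 7]},
    {"input_ids": [8, 9], "labels": [-100, 9]},
])
```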
970
+
971
+
972
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
973
+ data_args) -> Dict:
974
+ """Make dataset and collator for supervised fine-tuning."""
975
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)
976
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
977
+ return dict(train_dataset=train_dataset,
978
+ eval_dataset=None,
979
+ data_collator=data_collator)
980
+
981
+
982
+ def train():
983
+ global local_rank
984
+
985
+ parser = transformers.HfArgumentParser(
986
+ (ModelArguments, DataArguments, TrainingArguments))
987
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
988
+ local_rank = training_args.local_rank
989
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
990
+
991
+ random.seed(training_args.seed)
992
+ os.environ['PYTHONHASHSEED'] = str(training_args.seed)
993
+ np.random.seed(training_args.seed)
994
+ torch.manual_seed(training_args.seed)
995
+ torch.cuda.manual_seed(training_args.seed)
996
+ torch.cuda.manual_seed_all(training_args.seed) # if you are using multi-GPU.
997
+ torch.backends.cudnn.benchmark = False
998
+ torch.backends.cudnn.deterministic = True
999
+
1000
+ bnb_model_from_pretrained_args = {}
1001
+ if training_args.bits in [4, 8]:
1002
+ from transformers import BitsAndBytesConfig
1003
+ bnb_model_from_pretrained_args.update(dict(
1004
+ device_map={"": training_args.device},
1005
+ load_in_4bit=training_args.bits == 4,
1006
+ load_in_8bit=training_args.bits == 8,
1007
+ quantization_config=BitsAndBytesConfig(
1008
+ load_in_4bit=training_args.bits == 4,
1009
+ load_in_8bit=training_args.bits == 8,
1010
+ llm_int8_threshold=6.0,
1011
+ llm_int8_has_fp16_weight=False,
1012
+ bnb_4bit_compute_dtype=compute_dtype,
1013
+ bnb_4bit_use_double_quant=training_args.double_quant,
1014
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
1015
+ )
1016
+ ))
1017
+
1018
+ if model_args.vision_tower is not None:
1019
+ if "phi" in model_args.model_name_or_path.lower():
1020
+ from ChatUniVi.model.language_model.phi import ChatUniViPhiForCausalLM
1021
+ model = ChatUniViPhiForCausalLM.from_pretrained(
1022
+ model_args.model_name_or_path,
1023
+ cache_dir=training_args.cache_dir,
1024
+ **bnb_model_from_pretrained_args
1025
+ )
1026
+ else:
1027
+ model = ChatUniViLlamaForCausalLM.from_pretrained(
1028
+ model_args.model_name_or_path,
1029
+ cache_dir=training_args.cache_dir,
1030
+ **bnb_model_from_pretrained_args
1031
+ )
1032
+ else:
1033
+ model = transformers.LlamaForCausalLM.from_pretrained(
1034
+ model_args.model_name_or_path,
1035
+ cache_dir=training_args.cache_dir,
1036
+ **bnb_model_from_pretrained_args
1037
+ )
1038
+ model.config.use_cache = False
1039
+
1040
+ if model_args.freeze_backbone:
1041
+ model.model.requires_grad_(False)
1042
+
1043
+ if training_args.bits in [4, 8]:
1044
+ from peft import prepare_model_for_kbit_training
1045
+ model.config.torch_dtype = (
1046
+ torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
1047
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
1048
+
1049
+ if training_args.gradient_checkpointing:
1050
+ if hasattr(model, "enable_input_require_grads"):
1051
+ model.enable_input_require_grads()
1052
+ else:
1053
+ def make_inputs_require_grad(module, input, output):
1054
+ output.requires_grad_(True)
1055
+
1056
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
1057
+
1058
+ if training_args.lora_enable:
1059
+ from peft import LoraConfig, get_peft_model
1060
+ lora_config = LoraConfig(
1061
+ r=training_args.lora_r,
1062
+ lora_alpha=training_args.lora_alpha,
1063
+ target_modules=find_all_linear_names(model),
1064
+ lora_dropout=training_args.lora_dropout,
1065
+ bias=training_args.lora_bias,
1066
+ task_type="CAUSAL_LM",
1067
+ )
1068
+ if training_args.bits == 16:
1069
+ if training_args.bf16:
1070
+ model.to(torch.bfloat16)
1071
+ if training_args.fp16:
1072
+ model.to(torch.float16)
1073
+ rank0_print("Adding LoRA adapters...")
1074
+ model = get_peft_model(model, lora_config)
1075
+
1076
+ if 'mpt' in model_args.model_name_or_path:
1077
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
1078
+ model_args.model_name_or_path,
1079
+ cache_dir=training_args.cache_dir,
1080
+ model_max_length=training_args.model_max_length,
1081
+ padding_side="right"
1082
+ )
1083
+ else:
1084
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
1085
+ model_args.model_name_or_path,
1086
+ cache_dir=training_args.cache_dir,
1087
+ model_max_length=training_args.model_max_length,
1088
+ padding_side="right",
1089
+ use_fast=True,
1090
+ )
1091
+
1092
+ if model_args.version == "v0":
1093
+ if tokenizer.pad_token is None:
1094
+ smart_tokenizer_and_embedding_resize(
1095
+ special_tokens_dict=dict(pad_token="[PAD]"),
1096
+ tokenizer=tokenizer,
1097
+ model=model,
1098
+ )
1099
+ if "llama" in model_args.model_name_or_path.lower():
1100
+ tokenizer.add_special_tokens({
1101
+ "eos_token": "</s>",
1102
+ "bos_token": "<s>",
1103
+ "unk_token": "<unk>",
1104
+ })
1105
+ elif model_args.version == "v0.5":
1106
+ tokenizer.pad_token = tokenizer.unk_token
1107
+ elif model_args.version == "phi":
1108
+ tokenizer.pad_token = tokenizer.unk_token
1109
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
1110
+ else:
1111
+ tokenizer.pad_token = tokenizer.unk_token
1112
+ if model_args.version in conversation_lib.conv_templates:
1113
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
1114
+ else:
1115
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
1116
+
1117
+ if model_args.vision_tower is not None:
1118
+ model.get_model().initialize_vision_modules(
1119
+ model_args=model_args,
1120
+ fsdp=training_args.fsdp
1121
+ )
1122
+
1123
+ vision_tower = model.get_vision_tower()
1124
+ vision_tower.to(dtype=torch.float16, device=training_args.device)
1125
+
1126
+ data_args.image_processor = vision_tower.image_processor
1127
+ data_args.is_multimodal = True
1128
+
1129
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
1130
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
1131
+
1132
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
1133
+ if model_args.tune_mm_mlp_adapter:
1134
+ model.requires_grad_(False)
1135
+ for p in model.get_model().mm_projector.parameters():
1136
+ p.requires_grad = True
1137
+
1138
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
1139
+ if training_args.freeze_mm_mlp_adapter:
1140
+ for p in model.get_model().mm_projector.parameters():
1141
+ p.requires_grad = False
1142
+
1143
+ if training_args.bits in [4, 8]:
1144
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
1145
+
1146
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
1147
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
1148
+
1149
+ model.config.mm_use_box_start_end = data_args.mm_use_box_start_end = model_args.mm_use_box_start_end
1150
+ training_args.use_box_start_end = model_args.mm_use_box_start_end
1151
+
1152
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
1153
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
1154
+
1155
+ model_config = ModelConfig[str(model_args.model_use)]
1156
+ model.config.architectures = "LlavaLlamaForCausalLM"
1157
+
1158
+ model.config.config = model_config
1159
+ model_args.use_cluster = model_config["use_cluster"]
1160
+ model_args.spatial_cluster_rate0 = model_config["spatial_cluster_rate0"]
1161
+ model_args.spatial_cluster_rate1 = model_config["spatial_cluster_rate1"]
1162
+ model_args.spatial_cluster_rate2 = model_config["spatial_cluster_rate2"]
1163
+ model_args.temporal_cluster_rate = model_config.get("temporal_cluster_rate", 1 / 16)
1164
+ model.get_model().initialize_cluster_modules(model_args)
1165
+
1166
+ if model_args.use_cluster:
1167
+ for n, p in model.named_parameters():
1168
+ if "block" in n or "ctm" in n:
1169
+ p.requires_grad = True
1170
+
1171
+ if model.config.config["freeze"]:
1172
+ for n, p in model.named_parameters():
1173
+ if "block" not in n and "ctm" not in n:
1174
+ p.requires_grad = False
1175
+
1176
+ if model.config.config["mm_tune"]:
1177
+ for p in model.get_model().mm_projector.parameters():
1178
+ p.requires_grad = True
1179
+
1180
+ model_args.vision_tune = model_config["vision_tune"]
1181
+ for p in model.get_vision_tower().parameters():
1182
+ p.requires_grad = model_args.vision_tune
1183
+
1184
+ params_need_grad = [n for n, p in model.named_parameters() if p.requires_grad]
1185
+ print("Parameters require gradients: {}".format(params_need_grad))
1186
+
1187
+ if training_args.bits in [4, 8]:
1188
+ from peft.tuners.lora import LoraLayer
1189
+ for name, module in model.named_modules():
1190
+ if isinstance(module, LoraLayer):
1191
+ if training_args.bf16:
1192
+ module = module.to(torch.bfloat16)
1193
+ if 'norm' in name:
1194
+ module = module.to(torch.float32)
1195
+ if 'lm_head' in name or 'embed_tokens' in name:
1196
+ if hasattr(module, 'weight'):
1197
+ if training_args.bf16 and module.weight.dtype == torch.float32:
1198
+ module = module.to(torch.bfloat16)
1199
+
1200
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
1201
+ data_args=data_args)
1202
+
1203
+ trainer = ChatUniViTrainer(model=model,
1204
+ tokenizer=tokenizer,
1205
+ args=training_args,
1206
+ **data_module)
1207
+
1208
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
1209
+ trainer.train(resume_from_checkpoint=True)
1210
+ else:
1211
+ trainer.train()
1212
+
1213
+ model.config.use_cache = True
1214
+
1215
+ if training_args.lora_enable:
1216
+ state_dict = get_peft_state_maybe_zero_3(
1217
+ model.named_parameters(), training_args.lora_bias
1218
+ )
1219
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
1220
+ model.named_parameters()
1221
+ )
1222
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
1223
+ model.config.save_pretrained(training_args.output_dir)
1224
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
1225
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
1226
+ else:
1227
+ safe_save_model_for_hf_trainer(trainer=trainer,
1228
+ output_dir=training_args.output_dir)
1229
+
1230
+
1231
+ if __name__ == "__main__":
1232
+ train()
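`train()` passes `target_modules=find_all_linear_names(model)` to `LoraConfig`. A hedged sketch of that helper's idea, with module types passed in as strings so it runs without torch (the real helper walks `model.named_modules()` and checks `isinstance(m, nn.Linear)`; names and skip list here are assumptions):

```python
def find_linear_names(modules, skip=("mm_projector", "vision_tower", "lm_head")):
    # Collect leaf names of Linear modules, excluding multimodal parts
    # and the LM head so LoRA only wraps language-model projections.
    names = set()
    for name, kind in modules:
        if kind == "Linear" and not any(s in name for s in skip):
            names.add(name.split(".")[-1])
    return sorted(names)

targets = find_linear_names([
    ("model.layers.0.self_attn.q_proj", "Linear"),
    ("model.layers.0.self_attn.k_proj", "Linear"),
    ("model.mm_projector.0", "Linear"),
    ("model.layers.0.input_layernorm", "LayerNorm"),
])
```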
ChatUniVi/train/train_mem.py ADDED
@@ -0,0 +1,13 @@
1
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4
+
5
+ # Need to call this before importing transformers.
6
+ from ChatUniVi.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7
+
8
+ replace_llama_attn_with_flash_attn()
9
+
10
+ from ChatUniVi.train.train import train
11
+
12
+ if __name__ == "__main__":
13
+ train()
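`replace_llama_attn_with_flash_attn()` works by monkey patching: it reassigns the attention class's `forward` before transformers binds it, which is why it must run before any transformers import. The pattern in miniature (class and function names here are illustrative only):

```python
class Attention:
    def forward(self, x):
        return "slow"

def flash_forward(self, x):
    return "flash"

def replace_attn_with_flash_attn():
    # Reassign at the class level, so every existing and future
    # instance picks up the patched method.
    Attention.forward = flash_forward

replace_attn_with_flash_attn()
```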
ChatUniVi/train/trainer.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ import torch
3
+ from transformers import Trainer
4
+ from typing import Optional
5
+
6
+
7
+ def maybe_zero_3(param, ignore_status=False, name=None):
8
+ from deepspeed import zero
9
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
10
+ if hasattr(param, "ds_id"):
11
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
12
+ if not ignore_status:
13
+ print(name, 'no ignore status')
14
+ with zero.GatheredParameters([param]):
15
+ param = param.data.detach().cpu().clone()
16
+ else:
17
+ param = param.detach().cpu().clone()
18
+ return param
19
+
20
+
21
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
22
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
23
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
24
+ return to_return
25
+
26
+
27
+ class ChatUniViTrainer(Trainer):
28
+ def _save_checkpoint(self, model, trial, metrics=None):
29
+ if 0 and getattr(self.args, 'tune_mm_mlp_adapter', False):
30
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
31
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
32
+
33
+ run_dir = self._get_output_dir(trial=trial)
34
+ output_dir = os.path.join(run_dir, checkpoint_folder)
35
+
36
+ # Only save Adapter
37
+ keys_to_match = ['mm_projector', "ctm", "block"]
38
+ if getattr(self.args, "use_im_start_end", False):
39
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
40
+
41
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
42
+
43
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
44
+ self.model.config.save_pretrained(output_dir)
45
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
46
+ else:
47
+ super(ChatUniViTrainer, self)._save_checkpoint(model, trial, metrics)
48
+
49
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
50
+ if 0 and getattr(self.args, 'tune_mm_mlp_adapter', False):
51
+ pass
52
+ else:
53
+ super(ChatUniViTrainer, self)._save(output_dir, state_dict)
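Minus the DeepSpeed ZeRO-3 gathering, `get_mm_adapter_state_maybe_zero_3` is a substring filter over `named_parameters()`. A minimal sketch of that filtering step (parameter names below are hypothetical):

```python
def get_adapter_state(named_params, keys_to_match):
    # Keep only parameters whose name contains one of the adapter keys
    # (mm_projector / ctm / block in the trainer above).
    return {k: v for k, v in named_params if any(key in k for key in keys_to_match)}

adapter = get_adapter_state(
    [
        ("model.mm_projector.0.weight", 1),
        ("model.layers.0.self_attn.q_proj.weight", 2),
        ("model.ctm0.conf.weight", 3),
    ],
    ["mm_projector", "ctm", "block"],
)
```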
configs/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .config import args
configs/config.py ADDED
@@ -0,0 +1,84 @@
1
2
+ import os
3
+
4
+ import sys
5
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6
+ sys.path.append(BASE_DIR)
7
+
8
+ import cv2 # type: ignore
9
+
10
+ import argparse
11
+ import json
12
13
+ from typing import Any, Dict, List
14
+
15
+ # Dataset directory layout
16
+ file_arch = """
17
+ ./REFAVS/data
18
+ - /media
19
+ - /gt_mask
20
+ - /metadata.csv
21
+ - /audio_embed
22
+ - /image_embed
23
+ """
24
+ # print(f">>> File arch: {file_arch}")
25
+
26
+ parser = argparse.ArgumentParser(
27
+ description=(
28
+ "SimToken"
29
+ )
30
+ )
31
+
32
+
33
+
34
+ parser.add_argument("--vision_pretrained", type=str, default='/workspace/SimToken/models/segment_anything/sam_vit_h_4b8939.pth')
36
+ parser.add_argument("--vision_tower", type=str, default='openai/clip-vit-large-patch14')
37
+ parser.add_argument("--mllm", type=str, default='Chat-UniVi/Chat-UniVi-7B-v1.5')
37
+
38
+ parser.add_argument("--conv_template", type=int, default=1)
39
+ parser.add_argument("--ct_weight", type=float, default=0.1)
40
+ parser.add_argument("--input_type", type=str, default='refer')
41
+ parser.add_argument("--compress", action='store_false', default=True)
42
+ parser.add_argument("--start", type=int, default=0)
43
+
44
+
45
+ parser.add_argument("--name", type=str, default='testrun')
46
+ # path to ref-avs dataset
47
+ parser.add_argument("--data_dir", type=str, default='/workspace/SimToken/data', help=f"The data parent dir. File arch should be: {file_arch}")
48
+ # path to pretrained checkpoints
49
+ parser.add_argument("--saved_model", type=str, default='/workspace/SimToken/checkpoints/simtoken_pretrained.pth', help="Path to the pretrained SimToken checkpoint.")
50
+
51
+
52
+ parser.add_argument("--log_root", type=str, default='log', help="Where to save logs during training.")
54
+ parser.add_argument("--checkpoint_root", type=str, default='checkpoints', help="Where to save trained checkpoints during training.")
54
+
55
+ parser.add_argument("--visualization_root", type=str, default='visualization', help="Where to save visualization results during test.")
56
+
57
+
58
+
59
+
60
+ # parser.add_argument("--show_params", action='store_true', help=f"Show params names with Requires_grad==True.")
61
+
62
+ # learning rate
63
+ parser.add_argument("--lr", type=float, default=5e-5, help='lr to fine tuning adapters.')
64
+ # epochs
65
+ parser.add_argument("--epochs", type=int, default=10, help='epochs to fine tuning adapters.')
66
+ parser.add_argument("--batch_size", type=int, default=8)
67
+
68
+
69
+ parser.add_argument("--gpu_id", type=str, default="0", help="The GPU device to run generation on.")
70
+
71
+ parser.add_argument("--run", type=str, default='train', help="train, test")
72
+
73
+ parser.add_argument("--frame_n", type=int, default=10, help="Frame num of each video. Fixed to 10.")
74
+ parser.add_argument("--text_max_len", type=int, default=25, help="Maximum textual reference length.")
75
+ parser.add_argument("--max_eval_rows", type=int, default=-1, help="Max samples per split during eval; -1 = all.")
76
+ parser.add_argument("--eval_split", type=str, default="test_u", help="Which split to evaluate: test_s, test_u, test_n.")
77
+
78
+
79
+
80
+ args = parser.parse_args()
81
+
82
+ # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
83
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
84
+ # print(f'>>> Sys: set "CUDA_VISIBLE_DEVICES" - GPU: {args.gpu_id}')
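One subtlety worth noting in the parser above: `--compress` is declared with `action='store_false'`, so passing the flag turns compression *off*, while omitting it leaves the default `True`. A minimal reproduction of that behavior:

```python
import argparse

parser = argparse.ArgumentParser(description="SimToken")
parser.add_argument("--compress", action="store_false", default=True)
parser.add_argument("--run", type=str, default="train")

# With no flags, defaults apply; passing --compress stores False.
args = parser.parse_args([])
```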
data/metadata.csv ADDED
The diff for this file is too large to render. See raw diff