nateraw committed on
Commit 39d5658
1 Parent(s): b60f97b

Upload . with huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +7 -0
  2. .gitignore +129 -0
  3. CODE_OF_CONDUCT.md +5 -0
  4. CONTRIBUTING.md +39 -0
  5. LICENSE +22 -0
  6. app.py +146 -0
  7. assets/06919917-76bc-4adc-b944-2a722f165513.gif +3 -0
  8. assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4 +0 -0
  9. assets/ab865129-78fa-47d4-8a50-ff8c5533246f.gif +3 -0
  10. assets/cf7c12db-1a9e-46d3-96d6-38174bbe373c.gif +3 -0
  11. assets/mixkit-chef-preparing-a-sauce-in-a-blender-43034-medium.gif +3 -0
  12. assets/mixkit-hands-of-a-baker-kneading-a-dough-42467-medium.gif +3 -0
  13. assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.gif +3 -0
  14. assets/narrator.gif +3 -0
  15. assets/rephraser.gif +0 -0
  16. datasets/README.md +153 -0
  17. demo_narrator.py +97 -0
  18. demo_narrator_3rd_person.py +99 -0
  19. docs/INSTALL.md +15 -0
  20. docs/MODEL_ZOO.md +311 -0
  21. docs/PRETRAIN.md +125 -0
  22. eval_narrator.py +308 -0
  23. eval_zeroshot.py +389 -0
  24. lavila/data/__pycache__/datasets.cpython-38.pyc +0 -0
  25. lavila/data/__pycache__/video_transforms.cpython-38.pyc +0 -0
  26. lavila/data/datasets.py +517 -0
  27. lavila/data/video_transforms.py +186 -0
  28. lavila/models/__pycache__/distributed_utils.cpython-38.pyc +0 -0
  29. lavila/models/__pycache__/gpt2_gated.cpython-38.pyc +0 -0
  30. lavila/models/__pycache__/loss.cpython-38.pyc +0 -0
  31. lavila/models/__pycache__/models.cpython-38.pyc +0 -0
  32. lavila/models/bpe_simple_vocab_16e6.txt.gz +3 -0
  33. lavila/models/coca.py +131 -0
  34. lavila/models/distributed_utils.py +89 -0
  35. lavila/models/gpt2_gated.py +1615 -0
  36. lavila/models/loss.py +367 -0
  37. lavila/models/models.py +1218 -0
  38. lavila/models/narrator.py +385 -0
  39. lavila/models/openai_clip.py +237 -0
  40. lavila/models/openai_model.py +485 -0
  41. lavila/models/timesformer.py +390 -0
  42. lavila/models/tokenizer.py +239 -0
  43. lavila/models/utils.py +108 -0
  44. lavila/utils/distributed.py +102 -0
  45. lavila/utils/evaluation.py +36 -0
  46. lavila/utils/evaluation_charades.py +53 -0
  47. lavila/utils/evaluation_egomcq.py +25 -0
  48. lavila/utils/evaluation_ek100cls.py +35 -0
  49. lavila/utils/evaluation_ek100mir.py +201 -0
  50. lavila/utils/meter.py +65 -0
.gitattributes CHANGED
@@ -32,3 +32,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ assets/mixkit-hands-of-a-baker-kneading-a-dough-42467-medium.gif filter=lfs diff=lfs merge=lfs -text
36
+ assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/ab865129-78fa-47d4-8a50-ff8c5533246f.gif filter=lfs diff=lfs merge=lfs -text
38
+ assets/06919917-76bc-4adc-b944-2a722f165513.gif filter=lfs diff=lfs merge=lfs -text
39
+ assets/cf7c12db-1a9e-46d3-96d6-38174bbe373c.gif filter=lfs diff=lfs merge=lfs -text
40
+ assets/narrator.gif filter=lfs diff=lfs merge=lfs -text
41
+ assets/mixkit-chef-preparing-a-sauce-in-a-blender-43034-medium.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,129 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,5 @@
1
+ # Code of Conduct
2
+
3
+ Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4
+ Please read the [full text](https://code.fb.com/codeofconduct/)
5
+ so that you can understand what actions will and will not be tolerated.
CONTRIBUTING.md ADDED
@@ -0,0 +1,39 @@
1
+ # Contributing to LaViLa
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Our Development Process
6
+ Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis.
7
+
8
+ ## Pull Requests
9
+ We actively welcome your pull requests.
10
+
11
+ 1. Fork the repo and create your branch from `main`.
12
+ 2. If you've added code that should be tested, add tests.
13
+ 3. If you've changed APIs, update the documentation.
14
+ 4. Ensure the test suite passes.
15
+ 5. Make sure your code lints.
16
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
17
+
18
+ ## Contributor License Agreement ("CLA")
19
+ In order to accept your pull request, we need you to submit a CLA. You only need
20
+ to do this once to work on any of Facebook's open source projects.
21
+
22
+ Complete your CLA here: <https://code.facebook.com/cla>
23
+
24
+ ## Issues
25
+ We use GitHub issues to track public bugs. Please ensure your description is
26
+ clear and has sufficient instructions to be able to reproduce the issue.
27
+
28
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
29
+ disclosure of security bugs. In those cases, please go through the process
30
+ outlined on that page and do not file a public issue.
31
+
32
+ ## Coding Style
33
+ * 4 spaces for indentation rather than tabs
34
+ * 80 character line length
35
+ * PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/)
36
+
37
+ ## License
38
+ By contributing to LaViLa, you agree that your contributions will be licensed
39
+ under the LICENSE file in the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1,22 @@
1
+
2
+ MIT License
3
+
4
+ Copyright (c) Meta Platforms, Inc. and affiliates.
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,146 @@
1
+ import sys
2
+ sys.path.insert(0, './')
3
+
4
+ import decord
5
+ import numpy as np
6
+ import torch
7
+ import os
8
+
9
+ from lavila.data.video_transforms import Permute
10
+ from lavila.data.datasets import get_frame_ids, video_loader_by_frames
11
+ from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2
12
+ from lavila.models.tokenizer import MyGPT2Tokenizer
13
+ from collections import OrderedDict
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ import torchvision.transforms._transforms_video as transforms_video
17
+ import gradio as gr
18
+
19
+ def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
20
+ seg_size = float(end_frame - start_frame - 1) / num_segments
21
+ seq = []
22
+ for i in range(num_segments):
23
+ start = int(np.round(seg_size * i) + start_frame)
24
+ end = int(np.round(seg_size * (i + 1)) + start_frame)
25
+ end = min(end, end_frame)
26
+ if jitter:
27
+ frame_id = np.random.randint(low=start, high=(end + 1))
28
+ else:
29
+ frame_id = (start + end) // 2
30
+ seq.append(frame_id)
31
+ return seq
32
+
33
+ def video_loader_by_frames(root, vid, frame_ids):
34
+ vr = decord.VideoReader(os.path.join(root, vid))
35
+ try:
36
+ frames = vr.get_batch(frame_ids).asnumpy()
37
+ frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
38
+ except (IndexError, decord.DECORDError) as error:
39
+ print(error)
40
+ print("Erroneous video: ", vid)
41
+ frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
42
+ return torch.stack(frames, dim=0)
43
+
44
+ def iter_clips(video_path, num_segments=4, stride_size=16):
45
+ # The video is represented by `num_seg=4` frames
46
+ vr = decord.VideoReader(video_path)
47
+ frame_sample_size = num_segments * stride_size
48
+ max_start_frame = len(vr) - frame_sample_size
49
+ curr_frame = 0
50
+ fps = vr.get_avg_fps()
51
+ while curr_frame < max_start_frame:
52
+ stop_frame = min(curr_frame + frame_sample_size, len(vr))
53
+ curr_sec, stop_sec = curr_frame / fps, stop_frame / fps
54
+ frame_ids = get_frame_ids(curr_frame, stop_frame, num_segments=num_segments, jitter=False)
55
+ frames = video_loader_by_frames('./', video_path, frame_ids)
56
+ yield curr_sec, stop_sec, frames
57
+ curr_frame += frame_sample_size
58
+
59
+
60
+ class Pipeline:
61
+ def __init__(self, path=""):
62
+ ckpt_path = os.path.join(path, 'vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth')
63
+ ckpt = torch.load(ckpt_path, map_location='cpu')
64
+ state_dict = OrderedDict()
65
+ for k, v in ckpt['state_dict'].items():
66
+ state_dict[k.replace('module.', '')] = v
67
+
68
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
+ self.model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
70
+ text_use_cls_token=False,
71
+ project_embed_dim=256,
72
+ gated_xattn=True,
73
+ timesformer_gated_xattn=False,
74
+ freeze_lm_vclm=False,
75
+ freeze_visual_vclm=False,
76
+ freeze_visual_vclm_temporal=False,
77
+ num_frames=4,
78
+ drop_path_rate=0.
79
+ )
80
+ self.model.load_state_dict(state_dict, strict=True)
81
+ self.model.to(self.device)
82
+ self.model.eval()
83
+
84
+ self.tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
85
+
86
+ crop_size = 224
87
+ self.val_transform = transforms.Compose([
88
+ Permute([3, 0, 1, 2]),
89
+ transforms.Resize(crop_size),
90
+ transforms.CenterCrop(crop_size),
91
+ transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
92
+ ])
93
+
94
+ def decode_one(self, generated_ids, tokenizer):
95
+ # get the index of <EOS>
96
+ if tokenizer.eos_token_id == tokenizer.bos_token_id:
97
+ if tokenizer.eos_token_id in generated_ids[1:].tolist():
98
+ eos_id = generated_ids[1:].tolist().index(tokenizer.eos_token_id) + 1
99
+ else:
100
+ eos_id = len(generated_ids.tolist()) - 1
101
+ elif tokenizer.eos_token_id in generated_ids.tolist():
102
+ eos_id = generated_ids.tolist().index(tokenizer.eos_token_id)
103
+ else:
104
+ eos_id = len(generated_ids.tolist()) - 1
105
+ generated_text_str = tokenizer.tokenizer.decode(generated_ids[1:eos_id].tolist())
106
+ return generated_text_str
107
+
108
+ def __call__(self, video_path, temperature=0.7, top_p=0.95, max_text_length=77, num_return_sequences=10):
109
+ text = ""
110
+ with torch.autocast(self.device):
111
+ for start, stop, frames in iter_clips(video_path):
112
+ text += f"{'-'*30} Predictions From: {start:10.3f}-{stop:10.3f} seconds {'-'*30}\n"
113
+ frames = self.val_transform(frames).unsqueeze(0)
114
+ if self.device == 'cuda':
115
+ frames = frames.to(self.device).half()
116
+
117
+ with torch.no_grad():
118
+ image_features = self.model.encode_image(frames)
119
+ generated_text_ids, ppls = self.model.generate(
120
+ image_features,
121
+ self.tokenizer,
122
+ target=None, # free-form generation
123
+ max_text_length=max_text_length,
124
+ top_k=None,
125
+ top_p=top_p, # nucleus sampling
126
+ num_return_sequences=num_return_sequences, # number of candidates: 10
127
+ temperature=temperature,
128
+ early_stopping=True,
129
+ )
130
+ for i in range(num_return_sequences):
131
+ generated_text_str = self.decode_one(generated_text_ids[i], self.tokenizer)
132
+ text += '\t{}: {}\n'.format(i, generated_text_str)
133
+ return text
134
+
135
+ interface = gr.Interface(
136
+ Pipeline(),
137
+ inputs=[
138
+ gr.Video(label='video_path'),
139
+ gr.Slider(0.0, 1.0, 0.7, label='temperature'),
140
+ gr.Slider(0.0, 1.0, 0.95, label='top_p'),
141
+ ],
142
+ outputs='text'
143
+ )
144
+
145
+ if __name__ == '__main__':
146
+ interface.launch(debug=True)
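For reference, a minimal sketch of driving the `Pipeline` class above outside of Gradio. It assumes the checkpoint named in `__init__` has already been downloaded into the working directory; note that importing `app` also builds the Gradio interface (and loads the model once) at import time, so in practice you may prefer to copy the class into its own module.

```python
# Hypothetical direct use of the Pipeline defined in app.py; the Gradio UI is bypassed.
from app import Pipeline  # importing app also constructs its gr.Interface as a side effect

pipe = Pipeline(path='')  # looks for the narrator checkpoint in the current directory
captions = pipe(
    'assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4',
    temperature=0.7,
    top_p=0.95,
)
print(captions)  # timestamped blocks with 10 candidate narrations per clip window
```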
assets/06919917-76bc-4adc-b944-2a722f165513.gif ADDED

Git LFS Details

  • SHA256: 1f9162772e374a719d9ad0c2237afb787513508909b408507c16621ab90593e9
  • Pointer size: 132 Bytes
  • Size of remote file: 4.97 MB
assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4 ADDED
Binary file (127 kB).
 
assets/ab865129-78fa-47d4-8a50-ff8c5533246f.gif ADDED

Git LFS Details

  • SHA256: ed707740d4b644291abd106c7e5f98fbb39f4830a81c67aacbc22a82945374db
  • Pointer size: 132 Bytes
  • Size of remote file: 4.29 MB
assets/cf7c12db-1a9e-46d3-96d6-38174bbe373c.gif ADDED

Git LFS Details

  • SHA256: 8bc5878f7c211098062dafc4f2f6771b820bd9c735f0c22e32f691e22073d109
  • Pointer size: 132 Bytes
  • Size of remote file: 3.59 MB
assets/mixkit-chef-preparing-a-sauce-in-a-blender-43034-medium.gif ADDED

Git LFS Details

  • SHA256: f0d95669ec321ff88e10fe00b72effaaba51462d97bc54555abc6b6b12836259
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
assets/mixkit-hands-of-a-baker-kneading-a-dough-42467-medium.gif ADDED

Git LFS Details

  • SHA256: 24631b3b6fa2e07dfff335df2346fbb754ead3a3d43d9019432147dd0fca9ecc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.94 MB
assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.gif ADDED

Git LFS Details

  • SHA256: bb6a201a3496f368b35deccb54e9744d30206d4162b803279f4441738f29230b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
assets/narrator.gif ADDED

Git LFS Details

  • SHA256: 00288fdeac22cd4617922083810c7562ed062236df8f76e37a383d3d44f00297
  • Pointer size: 132 Bytes
  • Size of remote file: 1.95 MB
assets/rephraser.gif ADDED
datasets/README.md ADDED
@@ -0,0 +1,153 @@
1
+ # Preparing datasets for LAVILA
2
+
3
+ Please download the (selected) datasets from the official websites and place or symlink them under `$LAVILA_ROOT/datasets/`.
4
+
5
+ ```bash
6
+ $LAVILA_ROOT/datasets/
7
+ CharadesEgo/
8
+ EGTEA/
9
+ EK100/
10
+ Ego4D/
11
+ ```
12
+
13
+ ## Ego4D
14
+ 1. Download [Ego4D videos](https://ego4d-data.org/docs/start-here/#download-data) (license is required).
15
+
16
+ 2. Preprocess (TBA)
17
+
18
+ 3. Download annotations
19
+
20
+ a. Download [egomcq.json](https://drive.google.com/file/d/1-5iRYf4BCHmj4MYQYFRMY4bhsWJUN3rW/view) to `$LAVILA_ROOT/datasets/Ego4D` (if you want to evaluate EgoMCQ).
21
+
22
+ b. Download [metadata for train split](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_train.pkl) and [val split](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_val.pkl) to `$LAVILA_ROOT/datasets/Ego4D` (if you want to train LAVILA from scratch).
23
+
24
+ The folder should look like this:
25
+ ```bash
26
+ $LAVILA_ROOT/datasets/
27
+ Ego4D/
28
+ ego4d_train.pkl
29
+ ego4d_val.pkl
30
+ egomcq.json
31
+ video_288px/
32
+ 000786a7-3f9d-4fe6-bfb3-045b368f7d44.mp4/
33
+ 0.mp4
34
+ 300.mp4
35
+ 000a3525-6c98-4650-aaab-be7d2c7b9402.mp4/
36
+ 0.mp4
37
+ ...
38
+ ```
39
+
40
+
41
+ ## EPIC-Kitchens-100 (EK-100)
42
+
43
+ 1. Download annotations
44
+
45
+ ```bash
46
+ # Assume that you are under `datasets/EK100/`
47
+ git clone https://github.com/epic-kitchens/epic-kitchens-100-annotations
48
+ ```
49
+
50
+ 2. Download videos.
51
+
52
+ a. For raw videos, please download them from [https://epic-kitchens.github.io/](https://epic-kitchens.github.io/).
53
+
54
+ b. (Recommended) The raw videos are huge (~1 TB). As an alternative, please check out a [resized version]().
55
+
56
+ 3. (For EK-100 MIR)
57
+
58
+ a. Generate the relevancy matrix of train/val splits using [the official code](https://github.com/mwray/Joint-Part-of-Speech-Embeddings).
59
+
60
+ b. (Recommended) Generating the relevancy matrix involves some randomness. Therefore, we also provide a [replica of the train split](https://dl.fbaipublicfiles.com/lavila/metadata/EK100/caption_relevancy_EPIC_100_retrieval_train.pkl) and [val split](https://dl.fbaipublicfiles.com/lavila/metadata/EK100/caption_relevancy_EPIC_100_retrieval_test.pkl). Please put them in the folder `$LAVILA_ROOT/datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/`.
61
+
62
+
63
+ The folder should look like this:
64
+ ```bash
65
+ $LAVILA_ROOT/datasets/
66
+ EK100/
67
+ epic-kitchens-100-annotations/
68
+ EPIC_100_train.csv
69
+ EPIC_100_validation.csv
70
+ ...
71
+ retrieval_annotations/relevancy/ # this appears if you do 3.
72
+ caption_relevancy_EPIC_100_retrieval_train.pkl
73
+ caption_relevancy_EPIC_100_retrieval_test.pkl
74
+ video_ht256px/
75
+ P01/
76
+ P01_01.MP4
77
+ P01_02.MP4
78
+ ...
79
+ P01_19.MP4
80
+ P02/
81
+ P02_01.MP4
82
+ P02_02.MP4
83
+ ...
84
+ P02_15.MP4
85
+ ...
86
+ ```
87
+
88
+ ## CharadesEgo
89
+
90
+ 1. Download annotations at [https://prior.allenai.org/projects/charades-ego](https://prior.allenai.org/projects/charades-ego).
91
+ ```bash
92
+ ### Annotations
93
+ # Assume that you are under `datasets/CharadesEgo/`
94
+ wget https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/CharadesEgo.zip
95
+ unzip CharadesEgo.zip && rm CharadesEgo.zip
96
+ ```
97
+
98
+ 2. Download data (~11GB) at [https://prior.allenai.org/projects/charades-ego](https://prior.allenai.org/projects/charades-ego).
99
+ ```bash
100
+ ### Data
101
+ wget https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/CharadesEgo_v1_480.tar
102
+ tar -xvf CharadesEgo_v1_480.tar # Or specify an external path using `-C` and symlink it here
103
+ rm CharadesEgo_v1_480.tar
104
+ ```
105
+
106
+ 3. (For fine-tuning CharadesEgo) Download two additional metadata files: [clip-level metadata (train)](https://dl.fbaipublicfiles.com/lavila/metadata/CharadesEgo/metadata_filtered_train.pkl) and [clip-level metadata (val)](https://dl.fbaipublicfiles.com/lavila/metadata/CharadesEgo/metadata_filtered_val.pkl). Put them in the folder `$LAVILA_ROOT/datasets/CharadesEgo/CharadesEgo/`.
107
+
108
+ The folder should look like this:
109
+ ```bash
110
+ $LAVILA_ROOT/datasets/
111
+ CharadesEgo/
112
+ CharadesEgo/
113
+ CharadesEgo_v1_train_only1st.csv
114
+ CharadesEgo_v1_test_only1st.csv
115
+ ...
116
+ metadata_filtered_train.pkl # this appears if you do 3.
117
+ metadata_filtered_val.pkl # this appears if you do 3.
118
+ CharadesEgo_v1_480/
119
+ 005BU.mp4
120
+ 005BUEGO.mp4
121
+ ...
122
+ ```
123
+
124
+
125
+ ## EGTEA
126
+
127
+ 1. Visit [https://cbs.ic.gatech.edu/fpv/](https://cbs.ic.gatech.edu/fpv/).
128
+
129
+ 2. Download `TRIMMED_ACTION_CLIPS` (~20GB) and `ACTION_ANNOTATIONS` and untar to the current folder `$LAVILA_ROOT/datasets/EGTEA`.
130
+
131
+ ```bash
132
+ unzip action_annotation.zip -d EGTEA/ && rm action_annotation.zip
133
+ ```
134
+
135
+ The folder should look like this:
136
+ ```bash
137
+ $LAVILA_ROOT/datasets/
138
+ EGTEA/
139
+ train_split1.txt
140
+ test_split1.txt
141
+ cropped_clips/
142
+ OP01-R01-PastaSalad/
143
+ OP01-R01-PastaSalad-1002316-1004005-F024051-F024101.mp4
144
+ OP01-R01-PastaSalad-1004110-1021110-F024057-F024548.mp4
145
+ OP01-R01-PastaSalad-1022590-1024050-F024539-F024581.mp4
146
+ ...
147
+ OP01-R02-TurkeySandwich/
148
+ OP01-R02-TurkeySandwich-102320-105110-F002449-F002529.mp4
149
+ OP01-R02-TurkeySandwich-105440-106460-F002528-F002558.mp4
150
+ OP01-R02-TurkeySandwich-107332-133184-F002513-F003259.mp4
151
+ ...
152
+ ...
153
+ ```
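As a quick sanity check before training or evaluation, a small script along these lines (a hypothetical helper, not part of this repo) can confirm that the key files referenced above are where the data loaders expect them:

```python
import os

# Hypothetical layout check for the datasets prepared above; adjust LAVILA_ROOT as needed.
LAVILA_ROOT = os.environ.get('LAVILA_ROOT', '.')
expected = [
    'datasets/Ego4D/ego4d_train.pkl',
    'datasets/Ego4D/egomcq.json',
    'datasets/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
    'datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_train_only1st.csv',
    'datasets/EGTEA/train_split1.txt',
]
for rel in expected:
    path = os.path.join(LAVILA_ROOT, rel)
    print(('OK      ' if os.path.exists(path) else 'MISSING ') + path)
```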
demo_narrator.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import argparse
9
+ import os
10
+ import urllib.request
11
+ from collections import OrderedDict
12
+
13
+ import decord
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ import torchvision.transforms._transforms_video as transforms_video
17
+
18
+ from lavila.data.video_transforms import Permute
19
+ from lavila.data.datasets import get_frame_ids, video_loader_by_frames
20
+ from lavila.models.models import VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL
21
+ from lavila.models.tokenizer import MyGPT2Tokenizer
22
+ from eval_narrator import decode_one
23
+
24
+
25
+ def main(args):
26
+
27
+ vr = decord.VideoReader(args.video_path)
28
+ num_seg = 4
29
+ frame_ids = get_frame_ids(0, len(vr), num_segments=num_seg, jitter=False)
30
+ frames = video_loader_by_frames('./', args.video_path, frame_ids)
31
+
32
+ ckpt_name = 'vclm_openai_timesformer_large_336px_gpt2_xl.pt_ego4d.jobid_246897.ep_0003.md5sum_443263.pth'
33
+ ckpt_path = os.path.join('modelzoo/', ckpt_name)
34
+ os.makedirs('modelzoo/', exist_ok=True)
35
+ if not os.path.exists(ckpt_path):
36
+ print('downloading model to {}'.format(ckpt_path))
37
+ urllib.request.urlretrieve('https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/{}'.format(ckpt_name), ckpt_path)
38
+ ckpt = torch.load(ckpt_path, map_location='cpu')
39
+ state_dict = OrderedDict()
40
+ for k, v in ckpt['state_dict'].items():
41
+ state_dict[k.replace('module.', '')] = v
42
+
43
+ # instantiate the model, and load the pre-trained weights
44
+ model = VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL(
45
+ text_use_cls_token=False,
46
+ project_embed_dim=256,
47
+ gated_xattn=True,
48
+ timesformer_gated_xattn=False,
49
+ freeze_lm_vclm=False, # we use model.eval() anyway
50
+ freeze_visual_vclm=False, # we use model.eval() anyway
51
+ num_frames=4,
52
+ drop_path_rate=0.
53
+ )
54
+ model.load_state_dict(state_dict, strict=True)
55
+ if args.cuda:
56
+ model.cuda()
57
+ model.eval()
58
+
59
+ # transforms on input frames
60
+ crop_size = 336
61
+ val_transform = transforms.Compose([
62
+ Permute([3, 0, 1, 2]),
63
+ transforms.Resize(crop_size),
64
+ transforms.CenterCrop(crop_size),
65
+ transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
66
+ ])
67
+ frames = val_transform(frames)
68
+ frames = frames.unsqueeze(0) # fake a batch dimension
69
+
70
+ tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
71
+ with torch.no_grad():
72
+ if args.cuda:
73
+ frames = frames.cuda(non_blocking=True)
74
+ image_features = model.encode_image(frames)
75
+ generated_text_ids, ppls = model.generate(
76
+ image_features,
77
+ tokenizer,
78
+ target=None, # free-form generation
79
+ max_text_length=77,
80
+ top_k=None,
81
+ top_p=0.95, # nucleus sampling
82
+ num_return_sequences=10, # number of candidates: 10
83
+ temperature=0.7,
84
+ early_stopping=True,
85
+ )
86
+
87
+ for i in range(10):
88
+ generated_text_str = decode_one(generated_text_ids[i], tokenizer)
89
+ print('{}: {}'.format(i, generated_text_str))
90
+
91
+
92
+ if __name__ == '__main__':
93
+ parser = argparse.ArgumentParser('lavila narrator demo')
94
+ parser.add_argument('--cuda', action='store_true', help='use cuda')
95
+ parser.add_argument('--video-path', default='assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4', type=str, help='video path')
96
+ args = parser.parse_args()
97
+ main(args)
demo_narrator_3rd_person.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import argparse
9
+ import os
10
+ import urllib.request
11
+ from collections import OrderedDict
12
+
13
+ import decord
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ import torchvision.transforms._transforms_video as transforms_video
17
+
18
+ from lavila.data.video_transforms import Permute
19
+ from lavila.data.datasets import get_frame_ids, video_loader_by_frames
20
+ from lavila.models.models import VCLM_OPENAI_TIMESFORMER_LARGE_GPT2_XL
21
+ from lavila.models.tokenizer import MyGPT2Tokenizer
22
+ from eval_narrator import decode_one
23
+
24
+
25
+ def main(args):
26
+
27
+ vr = decord.VideoReader(args.video_path)
28
+ num_seg = 4
29
+ frame_ids = get_frame_ids(0, len(vr), num_segments=num_seg, jitter=False)
30
+ frames = video_loader_by_frames('./', args.video_path, frame_ids)
31
+
32
+ ckpt_name = 'vclm_openai_timesformer_large_gpt2_xl.pt_htm.jobid_341080.ep_0001.pth'
33
+ ckpt_path = os.path.join('modelzoo/', ckpt_name)
34
+ os.makedirs('modelzoo/', exist_ok=True)
35
+ if not os.path.exists(ckpt_path):
36
+ print('downloading model to {}'.format(ckpt_path))
37
+ urllib.request.urlretrieve('https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/htm_aa/{}'.format(ckpt_name), ckpt_path)
38
+ ckpt = torch.load(ckpt_path, map_location='cpu')
39
+ state_dict = OrderedDict()
40
+ for k, v in ckpt['state_dict'].items():
41
+ state_dict[k.replace('module.', '')] = v
42
+
43
+ # instantiate the model, and load the pre-trained weights
44
+ model = VCLM_OPENAI_TIMESFORMER_LARGE_GPT2_XL(
45
+ text_use_cls_token=False,
46
+ project_embed_dim=256,
47
+ gated_xattn=True,
48
+ timesformer_gated_xattn=False,
49
+ freeze_lm_vclm=False, # we use model.eval() anyway
50
+ freeze_visual_vclm=False, # we use model.eval() anyway
51
+ freeze_visual_vclm_temporal=False,
52
+ num_frames=4,
53
+ drop_path_rate=0.
54
+ )
55
+ model.load_state_dict(state_dict, strict=True)
56
+ if args.cuda:
57
+ model.cuda()
58
+ model.eval()
59
+
60
+ # transforms on input frames
61
+ crop_size = 224
62
+ val_transform = transforms.Compose([
63
+ Permute([3, 0, 1, 2]),
64
+ transforms.Resize(crop_size),
65
+ transforms.CenterCrop(crop_size),
66
+ transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
67
+ ])
68
+ frames = val_transform(frames)
69
+ frames = frames.unsqueeze(0) # fake a batch dimension
70
+
71
+ tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
72
+ with torch.no_grad():
73
+ if args.cuda:
74
+ frames = frames.cuda(non_blocking=True)
75
+ image_features = model.encode_image(frames)
76
+ generated_text_ids, ppls = model.generate(
77
+ image_features,
78
+ tokenizer,
79
+ target=None, # free-form generation
80
+ max_text_length=77,
81
+ top_k=None,
82
+ top_p=0.95, # nucleus sampling
83
+ num_return_sequences=10, # number of candidates: 10
84
+ temperature=0.7,
85
+ early_stopping=True,
86
+ )
87
+
88
+ for i in range(10):
89
+ generated_text_str = decode_one(generated_text_ids[i], tokenizer)
90
+ print('{}: {}'.format(i, generated_text_str))
91
+
92
+
93
+ if __name__ == '__main__':
94
+ parser = argparse.ArgumentParser('lavila narrator demo')
95
+ parser.add_argument('--cuda', action='store_true', help='use cuda')
96
+ parser.add_argument('--video-path', type=str,
97
+ default='assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.mp4')
98
+ args = parser.parse_args()
99
+ main(args)
docs/INSTALL.md ADDED
@@ -0,0 +1,15 @@
1
+ # Installation
2
+
3
+ ## Requirements
4
+
5
+
6
+ ## Example conda environment setup
7
+
8
+ ```bash
9
+ conda create --name lavila python=3.8 -y
10
+ conda activate lavila
11
+ pip install -r requirements.txt
12
+ ```
13
+
14
+ ## Datasets
15
+ If you want to train/evaluate on the datasets, please see [datasets/README.md](../datasets/README.md) for how we prepare the datasets for this project.
docs/MODEL_ZOO.md ADDED
@@ -0,0 +1,311 @@
1
+ # LAVILA Model Zoo
2
+
3
+ ## Multi-node Training
4
+ We use multi-node training on a SLURM cluster with [submitit](https://github.com/facebookincubator/submitit) for producing the results and models in the paper.
5
+ Please install `submitit` in your conda environment:
6
+ ```bash
7
+ pip install submitit
8
+ ```
9
+
10
+
11
+ ## Pre-training
12
+
13
+ Please refer to [PRETRAIN.md](./PRETRAIN.md).
14
+
15
+
16
+ ## Narrator
17
+
18
+ | Visual Encoder | Text Decoder | METEOR | ROUGE-L | CIDEr | Pre-trained<br>Vis. Encoder (md5) | checkpoint (md5) |
19
+ | :------------: | :----------: | :----: | :-----: | :---: | :-------------------------------: | :--------: |
20
+ | TSF-B | GPT-2 | 0.282 | 0.517 | 0.833 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.baseline.ep_0003.pth) (dbcc4d) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth) (68a71f) |
21
+ | TSF-L@HR | GPT-2 XL | 0.298 | 0.539 | 0.977 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large_336px_distilbert_base.baseline.ep_0003.pth) (5c69b8) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/vclm_openai_timesformer_large_336px_gpt2_xl.pt_ego4d.jobid_246897.ep_0003.md5sum_443263.pth) (443263) |
22
+
23
+
24
+ <details><summary>Ego4D val split</summary>
25
+ <p>
26
+
27
+ ```bash
28
+ torchrun --nproc_per_node=1 \
29
+ eval_narrator.py \
30
+ --caption-top-p 0.95 --caption-temperature 0.7 \
31
+ --eval-freq 10000 \
32
+ --resume $CHECKPOINT
33
+ ```
34
+
35
+ </p></details>
36
+
37
+ ## Zero-shot
38
+
39
+ <div class="table-wrapper" markdown="block">
40
+
41
+ | | Backbone | EK-100 MIR<br>avg. mAP | EK-100 MIR<br>avg. nDCG | Charades-Ego<br>mAP^ | EGTEA<br> mean acc. | EgoMCQ<br>intra-video acc. | checkpoint |
42
+ | :----------: | :------: | :--------------------: | :---------------------: | :------------------: | :-----------------: | :------------------------: | :----------: |
43
+ | Prev. SOTA^^ | TSF-B | 22.1/23.3 | 22.1/27.9 | 25.2 | 17.6 | 57.2 | [Epoch 1](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/egovlp_epo1_converted_f16.md5sum_7a3d3b.pth), [best epoch](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/egovlp_converted_f16.md5sum_c33363.pth) |
44
+ | LAVILA | TSF-B | 29.7/30.9 | 31.5/32.0 | 26.8 | 28.9 | 59.9 | [Epoch 1](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0001.md5sum_02dbb9.pth)^, [Epoch 5](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) |
45
+ | LAVILA | TSF-L | 35.0/36.1 | 34.2/34.6 | 28.9 | 34.1 | 63.1 | [Epoch 1](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0001.md5sum_9a25de.pth)^, [Epoch 3](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) |
46
+
47
+ </div>
48
+
49
+ ^ Note that the pre-trained checkpoint used to evaluate CharadesEgo is different from the one used to evaluate the other datasets.
50
+ Specifically, we use the checkpoint at epoch 1 for zero-shot evaluation on CharadesEgo, and the checkpoint that achieves the best average mAP on EK-100 MIR for the other datasets, as is done in [EgoVLP](https://arxiv.org/pdf/2206.01670.pdf).
51
+ Our guess is that, since CharadesEgo videos (captured by head-mounted mobile cameras) are visually different from Ego4D/EPIC-Kitchens videos (captured by professional action cameras, e.g. GoPro), pre-training on Ego4D videos for longer introduces some domain discrepancy.
52
+
53
+ ^^ We use the checkpoints released by [EgoVLP](https://github.com/showlab/EgoVLP) and convert them to be compatible with this codebase. Also note that our reproduced numbers are better than the reported numbers, especially on EK-100 MIR since we evaluate on raw videos directly (for more details, check out Appendix F & Table 10 in our paper).
54
+
55
+ <details><summary>1. EK-100 MIR</summary>
56
+ <p>
57
+
58
+ ```bash
59
+ python eval_zeroshot.py --dataset ek100_mir --root datasets/EK100/video_ht256px/ --clip-length 4 --resume $PATH
60
+ ```
61
+ By increasing the number of frames per clip, e.g. `--clip-length 16`, you can expect better performance.
62
+
63
+ </p></details>
64
+
65
+ <details><summary>2. EK-100 CLS</summary>
66
+ <p>
67
+
68
+ ```bash
69
+ python eval_zeroshot.py --dataset ek100_cls --metadata-val datasets/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv --resume $PATH
70
+ ```
71
+
72
+ </p></details>
73
+
74
+ <details><summary>3. Charades-Ego</summary>
75
+ <p>
76
+
77
+ ```bash
78
+ python eval_zeroshot.py --dataset charades_ego --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv --root datasets/CharadesEgo/CharadesEgo_v1_480/ --clip-length 16 --sparse-sample --resume $PATH
79
+ ```
80
+
81
+ </p></details>
82
+
83
+ <details><summary>4. EGTEA</summary>
84
+ <p>
85
+
86
+ ```bash
87
+ python eval_zeroshot.py --dataset egtea --metadata-val datasets/EGTEA/test_split1.txt --root datasets/EGTEA/cropped_clips/ --clip-length 16 --clip-stride 2 --num-crops 3 --num-clips 10 --resume $PATH
88
+ ```
89
+
90
+ </p></details>
91
+
92
+ <details><summary>5. EgoMCQ</summary>
93
+ <p>
94
+
95
+ ```bash
96
+ python eval_zeroshot.py --dataset ego4d_mcq --metadata-val datasets/Ego4D/egomcq.json --root datasets/Ego4D/video_5min_chunks_288px/ --clip-length 4 --resume $PATH --use-half -j 4
97
+ ```
98
+
99
+ </p></details>
100
+
101
+ ## Fine-tuned
102
+
103
+ ### EK-100 MIR
104
+
105
+ <div class="table-wrapper" markdown="block">
106
+
107
+ | | Backbone | avg mAP | avg nDCG | Pretrain (md5) | Fine-tuned checkpoint | training log |
108
+ | :----: | :-------:| :-----: | :------: | :----------: | :-------------------: | :----------: |
109
+ | LAVILA | TSF-B | 50.5 | 65.0 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_base.ft_ek100_mir.ep_0085.md5sum_c67d95.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_base.ft_ek100_mir.jobid_57361.log) |
110
+ | LAVILA | TSF-L | 50.9 | 66.5 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) (c89337) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_large.ft_ek100_mir.ep_0095.md5sum_bd508b.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_large.ft_ek100_mir.jobid_56606.log) |
111
+
112
+ </div>
113
+
114
+
115
+ <details><summary>Training and evaluating scripts</summary>
116
+ <p>
117
+
118
+ ### Multi-node training (Slurm)
119
+ ```bash
120
+ # TimeSformer-Base
121
+ python run_with_submitit_finetune_retrieval.py \
122
+ --pretrain-model $PATH \
123
+ --use-checkpoint --nodes 4
124
+
125
+ # TimeSformer-Large
126
+ python run_with_submitit_finetune_retrieval.py \
127
+ --pretrain-model $PATH \
128
+ --batch-size 4 \
129
+ --use-checkpoint --nodes 4
130
+ ```
131
+
132
+ ### Single-machine training
133
+ ```bash
134
+ torchrun --nproc_per_node=8 \
135
+ main_finetune_retrieval.py \
136
+ --output-dir $OUT_DIR \
137
+ --pretrain-model $PATH \
138
+ --use-checkpoint
139
+ ```
140
+
141
+ Note that you might see a slight drop in performance when training on a single node compared to multiple nodes (everything else being the same) because of the smaller total batch size.
142
+
143
+ ### Evaluation
144
+
145
+ Evaluation is done every `--eval-freq 5` epochs by default during fine-tuning.
146
+ If you want to evaluate any checkpoint after fine-tuning, please switch to `--evaluate` mode and specify the path to the checkpoint by `--resume $FINETUNED_CHECKPOINT`.
147
+ ```bash
148
+ torchrun --nproc_per_node=1 \
149
+ main_finetune_retrieval.py \
150
+ --output-dir $OUT_DIR \
151
+ --pretrain-model $PATH \
152
+ --use-checkpoint \
153
+ --evaluate \
154
+ --resume $FINETUNED_CHECKPOINT
155
+ ```
156
+
157
+
158
+ </p></details>
159
+
160
+ ### CharadesEgo
161
+
162
+ <div class="table-wrapper" markdown="block">
163
+
164
+ | | Backbone | video mAP |Pretrain^ (md5) | Fine-tuned checkpoint | training log |
165
+ | :----: | :-------:| :------: | :-------: | :-------------------: | :----------: |
166
+ | LAVILA | TSF-B | 33.7 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0001.md5sum_02dbb9.pth) (02dbb9) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_base.ft_charades_ego.ep_0005.md5sum_39bf4b.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_base.ft_charades_ego.jobid_65760.log) |
167
+ | LAVILA | TSF-L | 36.1 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0001.md5sum_9a25de.pth) (9a25de) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_large.ft_charades_ego.ep_0003.md5sum_9448b2.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_large.ft_charades_ego.jobid_65760.log) |
168
+
169
+ </div>
170
+
171
+ ^ Note that the pre-trained checkpoint for fine-tuning CharadesEgo is different from the one for fine-tuning EK-100 or EGTEA, for the same reason stated above.
172
+
173
+ <details><summary>Training and evaluating scripts</summary>
174
+ <p>
175
+
176
+ ### Multi-node training (Slurm)
177
+
178
+ ```bash
179
+ # TimeSformer-Base
180
+ python run_with_submitit_finetune_retrieval.py \
181
+ --dataset charades_ego \
182
+ --metadata datasets/CharadesEgo/CharadesEgo/metadata_filtered_train.pkl \
183
+ --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv \
184
+ --root datasets/CharadesEgo/CharadesEgo_v1_480/ \
185
+ --epochs 10 \
186
+ --save-freq 1 --eval-freq 1 \
187
+ --sparse-sample \
188
+ --pretrain-model $PATH \
189
+ --use-checkpoint --nodes 4
190
+
191
+ # TimeSformer-Large
192
+ python run_with_submitit_finetune_retrieval.py \
193
+ --dataset charades_ego \
194
+ --metadata datasets/CharadesEgo/CharadesEgo/metadata_filtered_train.pkl \
195
+ --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv \
196
+ --root datasets/CharadesEgo/CharadesEgo_v1_480/ \
197
+ --epochs 10 \
198
+ --save-freq 1 --eval-freq 1 \
199
+ --sparse-sample \
200
+ --pretrain-model $PATH \
201
+ --batch-size 4 \
202
+ --use-checkpoint --nodes 4
203
+ ```
204
+
205
+ ### Evaluation
206
+ ```bash
207
+ torchrun --nproc_per_node=1 \
208
+ main_finetune_retrieval.py \
209
+ --dataset charades_ego \
210
+ --metadata datasets/CharadesEgo/CharadesEgo/metadata_filtered_train.pkl \
211
+ --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv \
212
+ --root datasets/CharadesEgo/CharadesEgo_v1_480/ \
213
+ --output-dir $OUT_DIR \
214
+ --sparse-sample \
215
+ --pretrain-model $PATH \
216
+ --evaluate \
217
+ --resume $FINETUNED_CHECKPOINT
218
+ ```
219
+
220
+ </p></details>
221
+
222
+ ### EK-100 CLS
223
+
224
+ <div class="table-wrapper" markdown="block">
225
+
226
+ | | Backbone | V+N+A multi-head | Verb top-1 | Noun top-1 | Action top-1 | Pretrain (md5) | Fine-tuned checkpoint | training log |
227
+ | :----: | :-------:| :--------------: | :--------: | :--------: | :---------: | :------------: | :-------------------: | :----------: |
228
+ | LAVILA | TSF-B | no | 67.7 | 56.7 | 46.2 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.single_head.ep_0100.md5sum_e8aa0c.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.single_head.jobid_73363.log) |
229
+ | LAVILA | TSF-B | yes | 69.0 | 58.4 | 46.9 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.ep_0100.md5sum_4e3575.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.jobid_73361.log) |
230
+ | LAVILA | TSF-L | yes | 72.0 | 62.9 | 51.0 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) (c89337) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_large.ft_ek100_cls.ep_0090.md5sum_4a2509.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_large.ft_ek100_cls.jobid_74016.log) |
231
+ </div>
232
+
233
+ <details><summary>Training and evaluating scripts</summary>
234
+ <p>
235
+
236
+ ### Multi-node training (Slurm)
237
+
238
+ ```bash
239
+ # TimeSformer-Base
240
+ python run_with_submitit_finetune_classification.py \
241
+ --pretrain-model $PATH \
242
+ --use-vn-classifier --num-classes 97 300 3806 \
243
+ --use-sgd --wd 4e-5 --lr-multiplier-on-backbone 0.1 \
244
+ --use-checkpoint --node 1
245
+
246
+ # TimeSformer-Large
247
+ python run_with_submitit_finetune_classification.py \
248
+ --pretrain-model $PATH \
249
+ --use-vn-classifier --num-classes 97 300 3806 \
250
+ --use-sgd --wd 4e-5 --lr-multiplier-on-backbone 0.1 \
251
+ --use-checkpoint --node 4
252
+ ```
253
+
254
+ </p></details>
255
+
256
+ ### EGTEA
257
+
258
+ <div class="table-wrapper" markdown="block">
259
+
260
+ | | Backbone | mean Acc. | Pretrain (md5) | Fine-tuned checkpoint | training log |
261
+ | :----: | :-------:| :-------: | :------: | :-------------------: | :----------: |
262
+ | LAVILA | TSF-B | 70.12 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_base.ft_egtea.ep_0090.md5sum_3b1faf.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_base.ft_egtea.jobid_73358.log) |
263
+ | LAVILA | TSF-L | 76.00 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) (c89337) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_large.ft_egtea.ep_0095.md5sum_a5ba17.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_large.ft_egtea.jobid_74026.log) |
264
+
265
+ </div>
266
+
267
+ <details><summary>Training and evaluating scripts</summary>
268
+ <p>
269
+
270
+ ```bash
271
+ # TimeSformer-Base
272
+ python run_with_submitit_finetune_classification.py \
273
+ --dataset egtea \
274
+ --metadata-train datasets/EGTEA/train_split1.txt \
275
+ --metadata-val datasets/EGTEA/test_split1.txt \
276
+ --root datasets/EGTEA/cropped_clips/ \
277
+ --pretrain-model $PATH \
278
+ --num-classes 106 \
279
+ --use-sgd --wd 4e-5 \
280
+ --use-checkpoint --node 1
281
+
282
+ # TimeSformer-Large
283
+ python run_with_submitit_finetune_classification.py \
284
+ --dataset egtea \
285
+ --metadata-train datasets/EGTEA/train_split1.txt \
286
+ --metadata-val datasets/EGTEA/test_split1.txt \
287
+ --root datasets/EGTEA/cropped_clips/ \
288
+ --pretrain-model $PATH \
289
+ --num-classes 106 \
290
+ --use-sgd --wd 4e-5 \
291
+ --batch-size 4 \
292
+ --use-checkpoint --node 4
293
+ ```
294
+ ### Evaluation
295
+ ```bash
296
+ torchrun --nproc_per_node=1 \
297
+ main_finetune_classification.py \
298
+ --dataset egtea \
299
+ --metadata-train datasets/EGTEA/train_split1.txt \
300
+ --metadata-val datasets/EGTEA/test_split1.txt \
301
+ --root datasets/EGTEA/cropped_clips/ \
302
+ --output-dir $OUT_DIR \
303
+ --pretrain-model $PATH \
304
+ --num-classes 106 \
305
+ --use-sgd --wd 4e-5 \
306
+ --evaluate \
307
+ --resume $FINETUNED_CHECKPOINT \
308
+ --num-crops 3 --num-clips 10 \
309
+ --use-half
310
+ ```
311
+ </p></details>
docs/PRETRAIN.md ADDED
@@ -0,0 +1,125 @@
1
+ # LAVILA Pretraining
2
+
3
+ In this doc, we provide a step-by-step guide (with commands) to train LaViLa.
4
+ Note that we recommend running the following job with four 8x V100 (32GB) nodes (or eight nodes for the larger backbone) using [submitit](https://github.com/facebookincubator/submitit).
5
+ See [here](./MODEL_ZOO.md#multi-node-training) for how to install submitit.
6
+
7
+
8
+ ## Pre-training Dual-Encoder Baseline
9
+
10
+ We first pre-train a dual-encoder baseline with human annotations on Ego4D clips.
11
+ The goal is (1) to establish a comparable baseline for LAVILA, and (2) to provide a video encoder for the narrator (see below).
12
+ We use a default batch size of 32 per GPU so that the total batch size for the InfoNCE loss is `32*8*4=1024`.
13
+
14
+ <details><summary> Train a baseline dual-encoder (with TSF-B) </summary>
15
+
16
+ ```bash
17
+ python run_with_submitit_pretrain.py --model CLIP_OPENAI_TIMESFORMER_BASE \
18
+ --norm-embed --freeze-temperature \
19
+ --fix-lr --contrastive-use-vissl \
20
+ --nodes 4 --use_volta32
21
+ ```
22
+ </details>
23
+
24
+ To fit a high-resolution TimeSformer-Large with a sufficient batch size, we use [DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert), a memory-efficient text encoder, instead of the original CLIP text encoder. Additionally, we apply [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html) and the [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054).
25
+
26
+ <details><summary> Train a baseline dual-encoder (with TSF-L@HR) </summary>
27
+
28
+ ```bash
29
+ python run_with_submitit_pretrain.py --model CLIP_OPENAI_TIMESFORMER_LARGE_336PX_DISTILBERT_BASE \
30
+ --batch-size 8 \
31
+ --use-checkpoint --use-zero \
32
+ --norm-embed --freeze-temperature \
33
+ --fix-lr --contrastive-use-vissl \
34
+ --nodes 8 --use_volta32
35
+ ```
36
+ </details>
37
+
38
+ ## Training and Evaluating Narrator
39
+
40
+ The narrator is a *visually conditioned* large language model (VCLM), which comprises a pre-trained video encoder (obtained above), a text decoder (GPT-2 family), and a few gated cross-attention modules that attend to visual information while captioning. Both the video encoder and the text decoder are kept frozen; only the cross-attention modules are learnable.
41
+
42
+ Note that we turn off PyTorch's automatic mixed precision (AMP) when training the narrator; we observe that training is unstable with AMP on.
43
+
44
+ Also note that `$PATH` can be found in the `Vis. Encoder` column of [MODEL_ZOO.md#Narrator](./MODEL_ZOO.md#narrator). If you are using your own checkpoint (e.g. pre-trained in the previous step), please make sure that the following keys in the checkpoint have been dropped: `epoch`, `optimizer`, and `scaler`.
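If you are starting from your own checkpoint, a minimal sketch of stripping those keys (the file names below are placeholders):

```python
import torch

# Hypothetical paths; point these at your own pre-trained dual-encoder checkpoint.
ckpt = torch.load('modelzoo/my_dual_encoder.pth', map_location='cpu')
for key in ('epoch', 'optimizer', 'scaler'):
    ckpt.pop(key, None)  # drop training-state entries so the file can be passed to --resume
torch.save(ckpt, 'modelzoo/my_dual_encoder.weights_only.pth')
```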
45
+
46
+ <details><summary> Train a baseline narrator (TSF-B as visual encoder and GPT-2 base as textual decoder) </summary>
47
+
48
+ ```bash
49
+ python run_with_submitit_pretrain.py \
50
+ --model VCLM_OPENAI_TIMESFORMER_BASE_GPT2 \
51
+ --gated-xattn --freeze-lm-vclm --freeze-visual-vclm --freeze-visual-vclm-temporal \
52
+ --fix-lr --batch-size 8 --clip-grad-value 1.0 --eval-freq 1 --disable-amp \
53
+ --nodes 4 --use_volta32 --resume $PATH # Eg. $PATH can be "modelzoo/clip_openai_timesformer_base.baseline.ep_0003.pth"
54
+ ```
55
+
56
+ </details>
57
+
58
+ <details><summary> Train a strong narrator (TSF-L@HR as visual encoder and GPT-2 XL as textual decoder) </summary>
59
+
60
+ ```bash
61
+ python run_with_submitit_pretrain.py \
62
+ --model VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL \
63
+ --gated-xattn --freeze-lm-vclm --freeze-visual-vclm --freeze-visual-vclm-temporal --use-checkpoint \
64
+ --fix-lr --batch-size 8 --clip-grad-value 1.0 --eval-freq 1 --disable-amp \
65
+ --nodes 4 --use_volta32 --resume $PATH # Eg. $PATH can be "modelzoo/clip_openai_timesformer_large_336px_distilbert_base.baseline.ep_0003.pth"
66
+ ```
67
+ </details>
68
+
69
+ <details><summary> Evaluate the narrator on Ego4D val split </summary>
70
+
71
+ ```bash
72
+ # --eval-freq 10000: evaluate on the val split of Ego4D (a 1/10000 subset, for fast evaluation)
72
+ torchrun --nproc_per_node=1 eval_narrator.py \
73
+ --caption-top-p 0.95 --caption-temperature 0.7 \
74
+ --eval-freq 10000 \
75
+ --resume $VCLM_CHECKPOINT
76
+ ```
77
+ This will output some common NLG metrics, such as BLEU-x, METEOR, ROUGE_L, and CIDEr (using the human narrations as ground-truth).
78
+ </details>
79
+
80
+ ## Narrating video clips using LAVILA-Narrator
81
+
82
+
83
+ <details><summary> Infer the narrator </summary>
84
+
85
+ ```bash
86
+ python run_with_submitit_infer_narrator.py \
87
+ --metadata datasets/Ego4D/ego4d_train.pkl \
88
+ --batch-size 64 \
89
+ --resume $PATH --use-half \
90
+ --nodes 4 --use_volta32
91
+ ```
92
+ </details>
93
+
94
+ It will generate a pickle file (`$output_dir/total.pkl`), which is a list of quintuples: `(video_uid: str, start_time: float, end_time: float, narration_list: List[str], NLL_list: List[float])`.
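A short sketch of reading that file (assuming the quintuple structure described above; the path is illustrative):

```python
import pickle

# Substitute your actual $output_dir here.
with open('output_dir/total.pkl', 'rb') as f:
    clips = pickle.load(f)

video_uid, start_time, end_time, narration_list, nll_list = clips[0]
print(video_uid, start_time, end_time)
print(narration_list[0], nll_list[0])  # one generated narration and its negative log-likelihood
```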
95
+
96
+ For narrator-generated narrations on Ego4D ground-truth clips, we also provide a [replica](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_train.narrator_63690737.return_10.pkl). Note that the narrator used here is our best performing one.
97
+
98
+ ## Rephrasing human narrations using LAVILA-Rephraser
99
+
100
+ Rephraser is a standard LLM that can paraphrase narrations in existing clips.
101
+ Specifically, we use an off-the-shelf T5-based paraphraser which is publicly available at [Hugging Face's model hub](https://huggingface.co/ramsrigouthamg/t5-large-paraphraser-diverse-high-quality).
102
+ For more details, please refer to the [model card](https://huggingface.co/ramsrigouthamg/t5-large-paraphraser-diverse-high-quality).
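A rough sketch of calling that paraphraser via the `transformers` library (the `paraphrase:` prefix and the example narration are assumptions; check the model card for the exact prompt format):

```python
from transformers import pipeline

# Off-the-shelf T5-based paraphraser from the Hugging Face hub (see the model card above).
paraphraser = pipeline(
    'text2text-generation',
    model='ramsrigouthamg/t5-large-paraphraser-diverse-high-quality',
)
narration = '#C C opens a drawer.'  # illustrative Ego4D-style narration
outputs = paraphraser('paraphrase: ' + narration,
                      num_beams=5, num_return_sequences=3, max_length=64)
print([o['generated_text'] for o in outputs])
```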
103
+
104
+ For rephrased human narrations on Ego4D ground-truth clips, we provide a [replica](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_train.rephraser.no_punkt_top3.pkl).
105
+
106
+
107
+ ## Pre-training LAVILA Dual-Encoder
108
+ Now we are ready to pre-train LAVILA's dual-encoder by combining the human annotations (augmented by the Rephraser) and the Narrator-generated narrations.
109
+
110
+ <details><summary> Training a LaViLa dual-encoder </summary>
111
+
112
+ ```bash
113
+ python run_with_submitit_pretrain.py --model CLIP_OPENAI_TIMESFORMER_BASE \
114
+ --metadata datasets/Ego4D/ego4d_train.rephraser.no_punkt_top3.pkl \
115
+ --metadata-aux datasets/Ego4D/ego4d_train.narrator_63690737.return_10.pkl \
116
+ --norm-embed --freeze-temperature \
117
+ --freeze-pseudo-temperature \
118
+ --fix-lr --contrastive-use-vissl \
119
+ --nodes 4 --use_volta32
120
+ ```
121
+ </details>
122
+
123
+ ## Downstream Evaluation
124
+ With the pre-trained dual-encoder at hand, we can now run zero-shot or fine-tuned evaluations on downstream benchmarks.
125
+ Please refer to [MODEL_ZOO.md](./MODEL_ZOO.md#zero-shot) for more details.
eval_narrator.py ADDED
@@ -0,0 +1,308 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import os.path as osp
9
+ import time
10
+ from collections import OrderedDict
11
+
12
+ import numpy as np
13
+ # https://github.com/numpy/numpy/issues/21079
14
+ try:
15
+ import numpy.distutils
16
+ numpy.distutils.__config__.blas_opt_info = np.distutils.__config__.blas_ilp64_opt_info
17
+ except Exception:
18
+ pass
19
+ from nlgeval import NLGEval
20
+
21
+ import torch
22
+ import torchvision.transforms as transforms
23
+ import torchvision.transforms._transforms_video as transforms_video
24
+
25
+ from lavila.data import datasets
26
+ from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
27
+ from lavila.models import models
28
+ from lavila.models.utils import inflate_positional_embeds
29
+ from lavila.utils import distributed as dist_utils
30
+ from lavila.utils.preprocess import generate_tokenizer
31
+
32
+
33
+ def decode_one(generated_ids, tokenizer):
34
+ # get the index of <EOS>
35
+ if tokenizer.eos_token_id == tokenizer.bos_token_id:
36
+ if tokenizer.eos_token_id in generated_ids[1:].tolist():
37
+ eos_id = generated_ids[1:].tolist().index(tokenizer.eos_token_id) + 1
38
+ else:
39
+ eos_id = len(generated_ids.tolist()) - 1
40
+ elif tokenizer.eos_token_id in generated_ids.tolist():
41
+ eos_id = generated_ids.tolist().index(tokenizer.eos_token_id)
42
+ else:
43
+ eos_id = len(generated_ids.tolist()) - 1
44
+ generated_text_str = tokenizer.tokenizer.decode(generated_ids[1:eos_id].tolist())
45
+ return generated_text_str
46
+
47
+
48
+ def get_args_parser():
49
+ parser = argparse.ArgumentParser(description='LAVILA 0-shot evaluations', add_help=False)
50
+ parser.add_argument('--dataset', default='ego4d', type=str,
51
+ choices=['ego4d'])
52
+ parser.add_argument('--root',
53
+ default='datasets/Ego4D/video_5min_chunks_288px/',
54
+ type=str, help='path to dataset root')
55
+ parser.add_argument('--metadata-val',
56
+ default='datasets/Ego4D/ego4d_val.pkl',
57
+ type=str, help='path to metadata file (val set)')
58
+ parser.add_argument('--output-dir', default='./', type=str, help='output dir')
59
+ parser.add_argument('--num-crops', default=1, type=int, help='number of crops in transforms')
60
+ parser.add_argument('--num-clips', default=1, type=int, help='number of clips (for untrimmed videos, e.g. Charades)')
61
+ parser.add_argument('--clip-length', default=4, type=int, help='clip length')
62
+ parser.add_argument('--clip-stride', default=16, type=int, help='clip stride')
63
+ parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling')
64
+ parser.add_argument('--batch-size', default=16, type=int, help='batch_size')
65
+ # captioning options
66
+ parser.add_argument('--caption-sample', default='multinomial_sample',
67
+ choices=['multinomial_sample', 'beam_sample', 'group_beam_search'])
68
+ parser.add_argument('--caption-top-k', default=None, type=int, help='top-k sampling (predecessor of nucleus sampling)')
69
+ parser.add_argument('--caption-top-p', default=0.95, type=float, help='top-p sampling (aka nucleus sampling)')
70
+ parser.add_argument('--caption-num-beams', default=3, type=int)
71
+ parser.add_argument('--caption-num-beam-groups', default=1, type=int)
72
+ parser.add_argument('--caption-temperature', default=0.7, type=float)
73
+ parser.add_argument('--caption-length-penalty', default=1.0, type=float)
74
+ parser.add_argument('--caption-num-return-sequences', default=1, type=int)
75
+ parser.add_argument('--caption-max-len', default=77, type=int)
76
+ parser.add_argument('--caption-disable-visual', action='store_true')
77
+ parser.add_argument('--caption-early-stop', action='store_true', help='early stopping to save computation')
78
+ parser.add_argument('--caption-output-filename', default='caption.txt', type=str)
79
+ # others
80
+ parser.add_argument('--eval-freq', default=1000, type=int,
81
+ help='percentage (1/eval_freq) of val data to evaluate (for fast prototyping)')
82
+ parser.add_argument('--print-freq', default=10, type=int)
83
+ parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
84
+ help='number of data loading workers per process')
85
+ parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
86
+ parser.add_argument('--use-half', action='store_true')
87
+ return parser
88
+
89
+
90
+ def main(args):
91
+ if args.resume:
92
+ ckpt_path = args.resume
93
+ elif osp.isfile(osp.join(args.output_dir, 'checkpoint_best.pt')):
94
+ ckpt_path = osp.join(args.output_dir, 'checkpoint_best.pt')
95
+ else:
96
+ raise Exception('no checkpoint found')
97
+
98
+ ckpt = torch.load(ckpt_path, map_location='cpu')
99
+
100
+ # create model
101
+ state_dict = OrderedDict()
102
+ for k, v in ckpt['state_dict'].items():
103
+ state_dict[k.replace('module.', '')] = v
104
+
105
+ old_args = ckpt['args']
106
+ print('=> creating model: {}'.format(old_args.model))
107
+ model = getattr(models, old_args.model)(
108
+ text_use_cls_token=old_args.use_cls_token,
109
+ project_embed_dim=old_args.project_embed_dim,
110
+ gated_xattn=False if 'gated_xattn' not in old_args else old_args.gated_xattn,
111
+ timesformer_gated_xattn=False if 'timesformer_gated_xattn' not in old_args else old_args.timesformer_gated_xattn,
112
+ timesformer_freeze_space=False if 'timesformer_freeze_space' not in old_args else old_args.timesformer_freeze_space,
113
+ freeze_lm_vclm=False if 'freeze_lm_vclm' not in old_args else old_args.freeze_lm_vclm,
114
+ freeze_visual_vclm=False if 'freeze_visual_vclm' not in old_args else old_args.freeze_visual_vclm,
115
+ num_frames=args.clip_length,
116
+ drop_path_rate=0,
117
+ )
118
+ model.cuda()
119
+ if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model:
120
+ # inflate weight
121
+ print('=> inflating PE in models due to different frame numbers')
122
+ state_dict = inflate_positional_embeds(
123
+ model.state_dict(), state_dict,
124
+ num_frames=args.clip_length,
125
+ load_temporal_fix='bilinear',
126
+ )
127
+ model.load_state_dict(state_dict, strict=True)
128
+ print("=> loaded resume checkpoint '{}' (epoch {}, best_metric = {})".format(args.resume, ckpt['epoch'], ckpt['best_acc1']))
129
+
130
+ torch.backends.cudnn.benchmark = True
131
+
132
+ tokenizer = generate_tokenizer(old_args.model)
133
+ crop_size = 224 if '336PX' not in old_args.model else 336
134
+ if args.num_crops == 1 and args.num_clips == 1:
135
+ val_transform = transforms.Compose([
136
+ Permute([3, 0, 1, 2]), # T H W C -> C T H W
137
+ transforms.Resize(crop_size),
138
+ transforms.CenterCrop(crop_size),
139
+ (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else
140
+ transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])),
141
+ ])
142
+ else:
143
+ val_transform = transforms.Compose([
144
+ Permute([3, 0, 1, 2]), # T H W C -> C T H W
145
+ transforms.Resize(crop_size),
146
+ (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else
147
+ transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])),
148
+ TemporalCrop(frames_per_clip=args.clip_length, stride=args.clip_length),
149
+ SpatialCrop(crop_size=crop_size, num_crops=args.num_crops),
150
+ ])
151
+
152
+ val_dataset = datasets.VideoCaptionDatasetCLIP(
153
+ args.dataset,
154
+ args.root,
155
+ args.metadata_val,
156
+ transform=val_transform,
157
+ is_training=False,
158
+ tokenizer=tokenizer,
159
+ clip_length=args.clip_length,
160
+ clip_stride=args.clip_stride,
161
+ sparse_sample=False,
162
+ subsample_stride=args.eval_freq,
163
+ )
164
+
165
+ val_loader = torch.utils.data.DataLoader(
166
+ val_dataset, batch_size=args.batch_size, shuffle=False,
167
+ num_workers=args.workers, pin_memory=True, drop_last=False)
168
+
169
+ validate_caption(val_loader, model, tokenizer, args.caption_output_filename, use_half=args.use_half)
170
+
171
+
172
+ def validate_caption(val_loader, model, tokenizer, output_filename='caption.txt', use_half=False):
173
+ model.eval()
174
+ if args.use_half:
175
+ model = model.half()
176
+ nlgeval = NLGEval()
177
+ f = open(output_filename, 'w')
178
+ ppls_all = []
179
+ ppls_with_teacher_all = []
180
+ reference = []
181
+ hypothesis = []
182
+ end_time = time.time()
183
+ id_offset = 0
184
+ print('=> start forwarding')
185
+ with torch.no_grad():
186
+ for i, inputs in enumerate(val_loader):
187
+ if i % args.print_freq == 0:
188
+ print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time))
189
+ end_time = time.time()
190
+ images = inputs[0].cuda(non_blocking=True)
191
+ if use_half:
192
+ images = images.half()
193
+ target = inputs[1].cuda(non_blocking=True)
194
+
195
+ # encode images
196
+ image_features = dist_utils.get_model(model).encode_image(images)
197
+
198
+ # teacher forcing (to get standard ppl metric)
199
+ generated_text_ids_with_teacher, ppls_with_teacher = dist_utils.get_model(model).generate(
200
+ image_features,
201
+ tokenizer,
202
+ target=target,
203
+ max_text_length=args.caption_max_len,
204
+ top_k=args.caption_top_k,
205
+ top_p=args.caption_top_p,
206
+ teacher_forcing=True,
207
+ early_stopping=args.caption_early_stop,
208
+ )
209
+
210
+ if args.caption_sample == 'multinomial_sample':
211
+ assert args.caption_num_beam_groups == 1
212
+ generated_text_ids, ppls = dist_utils.get_model(model).generate(
213
+ image_features,
214
+ tokenizer,
215
+ target=target.repeat_interleave(args.caption_num_return_sequences, dim=0),
216
+ max_text_length=args.caption_max_len,
217
+ top_k=args.caption_top_k,
218
+ top_p=args.caption_top_p,
219
+ num_return_sequences=args.caption_num_return_sequences,
220
+ temperature=args.caption_temperature,
221
+ early_stopping=args.caption_early_stop,
222
+ )
223
+ elif args.caption_sample == 'beam_sample':
224
+ assert args.caption_num_beam_groups == 1
225
+ generated_text_ids, ppls = dist_utils.get_model(model).beam_sample(
226
+ image_features,
227
+ tokenizer,
228
+ target=target,
229
+ max_text_length=args.caption_max_len,
230
+ top_k=args.caption_top_k,
231
+ top_p=args.caption_top_p,
232
+ temperature=args.caption_temperature,
233
+ length_penalty=args.caption_length_penalty,
234
+ num_beams=args.caption_num_beams,
235
+ num_return_sequences=args.caption_num_return_sequences,
236
+ early_stopping=args.caption_early_stop,
237
+ )
238
+ elif args.caption_sample == 'group_beam_search':
239
+ assert args.caption_num_beam_groups > 1 and args.caption_num_beams % args.caption_num_beam_groups == 0
240
+ generated_text_ids, ppls = dist_utils.get_model(model).group_beam_search(
241
+ image_features,
242
+ tokenizer,
243
+ target=target if not args.caption_no_gt else None,
244
+ max_text_length=args.caption_max_len,
245
+ top_k=args.caption_top_k,
246
+ top_p=args.caption_top_p,
247
+ temperature=args.caption_temperature,
248
+ length_penalty=args.caption_length_penalty,
249
+ num_beams=args.caption_num_beams,
250
+ num_beam_groups=args.caption_num_beam_groups,
251
+ num_return_sequences=args.caption_num_return_sequences,
252
+ early_stopping=args.caption_early_stop,
253
+ )
254
+ else:
255
+ raise NotImplementedError
256
+ ppls_all.append(ppls.reshape(-1, args.caption_num_return_sequences).mean(1))
257
+ ppls_with_teacher_all.append(ppls_with_teacher)
258
+
259
+ for j in range(generated_text_ids.shape[0] // args.caption_num_return_sequences):
260
+ for k in range(args.caption_num_return_sequences):
261
+ jj = j * args.caption_num_return_sequences + k
262
+
263
+ generated_text_str = decode_one(generated_text_ids[jj], tokenizer)
264
+ gt_text = decode_one(target[j], tokenizer)
265
+ generated_text_str_with_teacher = decode_one(generated_text_ids_with_teacher[j], tokenizer)
266
+
267
+ from transformers import BertTokenizer
268
+ bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
269
+ gt_text = bert_tokenizer.decode(bert_tokenizer(gt_text)['input_ids'][1:-1])
270
+ generated_text_str = bert_tokenizer.decode(bert_tokenizer(generated_text_str)['input_ids'][1:-1])
271
+ generated_text_str_with_teacher = bert_tokenizer.decode(bert_tokenizer(generated_text_str_with_teacher)['input_ids'][1:-1])
272
+ reference.append(gt_text)
273
+ hypothesis.append(generated_text_str)
274
+ s1 = '[{:6d}] Groundtruth | | {}'.format(id_offset + j, gt_text)
275
+ s2 = '[{:6d}] Generated | PPL : {:9.3f} | {}'.format(id_offset + j, ppls[jj], generated_text_str)
276
+ s3 = '[{:6d}] Generated (w/. teacher) | PPL : {:9.3f} | {}'.format(id_offset + j, ppls_with_teacher[j], generated_text_str_with_teacher)
277
+ for s in [s1, s2, s3]:
278
+ # if i % args.print_freq == 0:
279
+ # print(s)
280
+ f.write('{} \n'.format(s))
281
+ id_offset += generated_text_ids.shape[0] // args.caption_num_return_sequences
282
+
283
+ ppls_with_teacher_all = torch.cat(ppls_with_teacher_all, dim=0)
284
+ ppls_all = torch.cat(ppls_all, dim=0)
285
+
286
+ print('PPL (w/. teacher) = {:9.3f}'.format(ppls_with_teacher_all.mean().item()))
287
+ print('PPL (w/o. teacher) = {:9.3f}'.format(ppls_all.mean().item()))
288
+ f.write('PPL (w/. teacher) = {:9.3f} \n'.format(ppls_with_teacher_all.mean().item()))
289
+ f.write('PPL (w/o. teacher) = {:9.3f} \n'.format(ppls_all.mean().item()))
290
+
291
+ print('Avg length for reference: {:9.3f}'.format(sum(map(lambda sentence: len(sentence.split(' ')), reference)) / len(reference)))
292
+ print('Avg length for hypothesis: {:9.3f}'.format(sum(map(lambda sentence: len(sentence.split(' ')), hypothesis)) / len(hypothesis)))
293
+ f.write('Avg length for reference: {:9.3f} \n'.format(sum(map(lambda sentence: len(sentence.split(' ')), reference)) / len(reference)))
294
+ f.write('Avg length for hypothesis: {:9.3f} \n'.format(sum(map(lambda sentence: len(sentence.split(' ')), hypothesis)) / len(hypothesis)))
295
+
296
+ print('=> Calling NLGEval')
297
+ f.write('=> Calling NLGEval\n')
298
+ metrics_dict = nlgeval.compute_metrics([reference], hypothesis)
299
+ for k in metrics_dict:
300
+ print('{:16s} = {:9.3f}'.format(k, metrics_dict[k]))
301
+ f.write('{:16s} = {:9.3f} \n'.format(k, metrics_dict[k]))
302
+ f.close()
303
+
304
+
305
+ if __name__ == '__main__':
306
+ parser = argparse.ArgumentParser('lavila 0-shot evaluations', parents=[get_args_parser()])
307
+ args = parser.parse_args()
308
+ main(args)
eval_zeroshot.py ADDED
@@ -0,0 +1,389 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import numpy as np
9
+ import os.path as osp
10
+ import time
11
+ from collections import OrderedDict
12
+
13
+ import pandas as pd
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ import torchvision.transforms._transforms_video as transforms_video
17
+ from sklearn.metrics import confusion_matrix
18
+
19
+ from lavila.data import datasets
20
+ from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
21
+ from lavila.models import models
22
+ from lavila.models.utils import inflate_positional_embeds
23
+ from lavila.utils import distributed as dist_utils
24
+ from lavila.utils.evaluation import accuracy, get_mean_accuracy
25
+ from lavila.utils.evaluation_egomcq import egomcq_accuracy_metrics
26
+ from lavila.utils.evaluation_ek100mir import (calculate_k_counts, calculate_IDCG, calculate_mAP, calculate_nDCG)
27
+ from lavila.utils.evaluation_charades import charades_map
28
+ from lavila.utils.preprocess import generate_label_map, generate_tokenizer
29
+
30
+
31
+ def get_args_parser():
32
+ parser = argparse.ArgumentParser(description='LAVILA 0-shot evaluations', add_help=False)
33
+ parser.add_argument('--dataset', default='ek100_mir', type=str,
34
+ choices=['ek100_cls', 'ek100_mir', 'charades_ego', 'egtea', 'ego4d_mcq'])
35
+ parser.add_argument('--root',
36
+ default='datasets/EK100/video_ht256px/',
37
+ type=str, help='path to dataset root')
38
+ parser.add_argument('--metadata-val',
39
+ default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv',
40
+ type=str, help='path to metadata file (val set)')
41
+ parser.add_argument('--relevancy-path',
42
+ default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl',
43
+ type=str, help='path to relevancy matrix (val set)')
44
+ parser.add_argument('--output-dir', default='./', type=str, help='output dir')
45
+ parser.add_argument('--num-crops', default=1, type=int, help='number of crops in transforms')
46
+ parser.add_argument('--num-clips', default=1, type=int, help='number of clips (for untrimmed videos, eg. Charades)')
47
+ parser.add_argument('--clip-length', default=4, type=int, help='clip length')
48
+ parser.add_argument('--clip-stride', default=16, type=int, help='clip stride')
49
+ parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling')
50
+ parser.add_argument('--batch-size', default=16, type=int, help='batch_size')
51
+ parser.add_argument('--cls-use-template', action='store_true', help='use prompt in 0-shot classification')
52
+ parser.add_argument('--print-freq', default=100, type=int)
53
+ parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
54
+ help='number of data loading workers per process')
55
+ parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
56
+ parser.add_argument('--use-half', action='store_true')
57
+ return parser
58
+
59
+
60
+ def main(args):
61
+ if args.resume:
62
+ ckpt_path = args.resume
63
+ elif osp.isfile(osp.join(args.output_dir, 'checkpoint_best.pt')):
64
+ ckpt_path = osp.join(args.output_dir, 'checkpoint_best.pt')
65
+ else:
66
+ raise Exception('no checkpoint found')
67
+
68
+ ckpt = torch.load(ckpt_path, map_location='cpu')
69
+
70
+ # create model
71
+ state_dict = OrderedDict()
72
+ for k, v in ckpt['state_dict'].items():
73
+ state_dict[k.replace('module.', '')] = v
74
+
75
+ old_args = ckpt['args']
76
+ print('=> creating model: {}'.format(old_args.model))
77
+ model = getattr(models, old_args.model)(
78
+ text_use_cls_token=old_args.use_cls_token,
79
+ project_embed_dim=old_args.project_embed_dim,
80
+ gated_xattn=False if 'gated_xattn' not in old_args else old_args.gated_xattn,
81
+ timesformer_gated_xattn=False if 'timesformer_gated_xattn' not in old_args else old_args.timesformer_gated_xattn,
82
+ timesformer_freeze_space=False if 'timesformer_freeze_space' not in old_args else old_args.timesformer_freeze_space,
83
+ freeze_lm_vclm=False if 'freeze_lm_vclm' not in old_args else old_args.freeze_lm_vclm,
84
+ freeze_visual_vclm=False if 'freeze_visual_vclm' not in old_args else old_args.freeze_visual_vclm,
85
+ num_frames=args.clip_length,
86
+ drop_path_rate=0,
87
+ )
88
+ model.cuda()
89
+ if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model:
90
+ # inflate weight
91
+ print('=> inflating PE in models due to different frame numbers')
92
+ state_dict = inflate_positional_embeds(
93
+ model.state_dict(), state_dict,
94
+ num_frames=args.clip_length,
95
+ load_temporal_fix='bilinear',
96
+ )
97
+ model.load_state_dict(state_dict, strict=True)
98
+ print("=> loaded resume checkpoint '{}' (epoch {}, best_metric = {})".format(args.resume, ckpt['epoch'], ckpt['best_acc1']))
99
+
100
+ torch.backends.cudnn.benchmark = True
101
+
102
+ if args.dataset in ['ek100_cls', 'charades_ego', 'egtea']:
103
+ labels, mapping_vn2act = generate_label_map(args.dataset)
104
+ else:
105
+ mapping_vn2act = None
106
+ tokenizer = generate_tokenizer(old_args.model)
107
+ crop_size = 224 if '336PX' not in old_args.model else 336
108
+ if args.num_crops == 1 and args.num_clips == 1:
109
+ val_transform = transforms.Compose([
110
+ Permute([3, 0, 1, 2]), # T H W C -> C T H W
111
+ transforms.Resize(crop_size),
112
+ transforms.CenterCrop(crop_size),
113
+ (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else
114
+ transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])),
115
+ ])
116
+ else:
117
+ val_transform = transforms.Compose([
118
+ Permute([3, 0, 1, 2]), # T H W C -> C T H W
119
+ transforms.Resize(crop_size),
120
+ (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else
121
+ transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])),
122
+ TemporalCrop(frames_per_clip=args.clip_length, stride=args.clip_length),
123
+ SpatialCrop(crop_size=crop_size, num_crops=args.num_crops),
124
+ ])
125
+
126
+ val_dataset = datasets.get_downstream_dataset(
127
+ val_transform, tokenizer, args, subset='val', label_mapping=mapping_vn2act,
128
+ )
129
+
130
+ val_loader = torch.utils.data.DataLoader(
131
+ val_dataset, batch_size=args.batch_size, shuffle=False,
132
+ num_workers=args.workers, pin_memory=True, drop_last=False)
133
+
134
+ if args.cls_use_template:
135
+ templates = ['#C C {}', '#C {}']
136
+ else:
137
+ templates = ['{}']
138
+
139
+ if args.dataset in ['ek100_cls', 'charades_ego', 'egtea']:
140
+ preds, targets = validate_zeroshot(val_loader, templates, labels, model, tokenizer)
141
+ if args.dataset == 'ek100_cls':
142
+ if args.use_half:
143
+ preds = preds.float()
144
+ top1, top5 = accuracy(preds, targets, topk=(1, 5))
145
+ print('top1 = {:.3f}'.format(top1.item()))
146
+ print('top5 = {:.3f}'.format(top5.item()))
147
+ elif args.dataset == 'charades_ego':
148
+ preds, targets = preds.numpy(), targets.numpy()
149
+ m_ap, _, _ = charades_map(preds, targets)
150
+ print('mAP = {:.3f}'.format(m_ap))
151
+ elif args.dataset == 'egtea':
152
+ preds, targets = preds.numpy(), targets.numpy()
153
+ print(preds.shape, targets.shape)
154
+ cm = confusion_matrix(targets, preds.argmax(axis=1))
155
+ mean_class_acc, acc = get_mean_accuracy(cm)
156
+ print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
157
+
158
+ if args.dataset == 'ek100_mir':
159
+ val_dataset = datasets.VideoCaptionDatasetCLIP(
160
+ 'ek100_mir',
161
+ args.root,
162
+ args.metadata_val,
163
+ transform=val_transform, is_training=False,
164
+ tokenizer=tokenizer,
165
+ clip_length=args.clip_length,
166
+ clip_stride=args.clip_stride,
167
+ sparse_sample=False
168
+ )
169
+ val_loader = torch.utils.data.DataLoader(
170
+ val_dataset, batch_size=args.batch_size, shuffle=False,
171
+ num_workers=args.workers, pin_memory=True, drop_last=False
172
+ )
173
+ similarity_matrix = get_similarity_matrix(val_loader, model, print_freq=args.print_freq, use_half=args.use_half)
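+ # rescale cosine similarities from [-1, 1] to [0, 1] before computing mAP/nDCG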
174
+ similarity_matrix = (similarity_matrix + 1) / 2
175
+ video_id = pd.read_csv(args.metadata_val).values[:, 0]
176
+ text_id = pd.read_csv(args.metadata_val.replace("test.csv", "test_sentence.csv")).values[:, 0]
177
+ indexes = [video_id.tolist().index(elem) for elem in text_id]
178
+ similarity_matrix = similarity_matrix[:, indexes]
179
+ print(similarity_matrix.shape)
180
+ rel_matrix = pd.read_pickle(args.relevancy_path)
181
+ vis_map = calculate_mAP(similarity_matrix, rel_matrix)
182
+ txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
183
+ print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, (vis_map + txt_map) / 2))
184
+ vis_k_counts = calculate_k_counts(rel_matrix)
185
+ txt_k_counts = calculate_k_counts(rel_matrix.T)
186
+ vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
187
+ txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
188
+ vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
189
+ txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
190
+ print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2))
191
+
192
+ if args.dataset == 'ego4d_mcq':
193
+ val_dataset = datasets.VideoCaptionDatasetMCQ(
194
+ args.dataset,
195
+ args.root,
196
+ args.metadata_val,
197
+ transform=val_transform, is_training=False,
198
+ tokenizer=tokenizer,
199
+ clip_length=args.clip_length,
200
+ clip_stride=args.clip_stride,
201
+ sparse_sample=False,
202
+ )
203
+ val_loader = torch.utils.data.DataLoader(
204
+ val_dataset, batch_size=args.batch_size, shuffle=False,
205
+ num_workers=args.workers, pin_memory=True, drop_last=False
206
+ )
207
+ validate_mcq(val_loader, model, use_half=args.use_half)
208
+
209
+
210
+ def validate_zeroshot(val_loader, templates, labels, model, tokenizer):
211
+ model.eval()
212
+ if args.use_half:
213
+ model = model.half()
214
+ all_outputs = []
215
+ all_targets = []
216
+ all_vis_features = []
217
+ print('=> encoding captions')
218
+ with torch.no_grad():
219
+ text_features = []
220
+ for label in labels:
221
+ if isinstance(label, list):
222
+ texts = [tmpl.format(lbl) for tmpl in templates for lbl in label]
223
+ else:
224
+ texts = [tmpl.format(label) for tmpl in templates]
225
+ texts = tokenizer(texts)
226
+ if isinstance(texts, tuple):
227
+ # Bert-style tokenizer will output both ids and mask
228
+ texts, masks = texts
229
+ texts = texts.cuda(non_blocking=True)
230
+ masks = masks.cuda(non_blocking=True)
231
+ else:
232
+ texts = texts.cuda(non_blocking=True)
233
+ masks = None
234
+ texts = texts.view(-1, 77).contiguous()
235
+ masks = masks.view(-1, 77).contiguous() if masks is not None else None
236
+ if masks is not None:
237
+ class_embeddings = dist_utils.get_model(model).encode_text(texts, attention_mask=masks)
238
+ else:
239
+ class_embeddings = dist_utils.get_model(model).encode_text(texts)
240
+ class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
241
+
242
+ class_embeddings = class_embeddings.mean(dim=0)
243
+ class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
244
+
245
+ text_features.append(class_embeddings)
246
+ text_features = torch.stack(text_features, dim=0)
247
+
248
+ print('=> start forwarding')
249
+ end_time = time.time()
250
+ for i, (images, target) in enumerate(val_loader):
251
+ if i % args.print_freq == 0:
252
+ print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time))
253
+ end_time = time.time()
254
+ if isinstance(images, torch.Tensor):
255
+ images = images.cuda(non_blocking=True)
256
+ if args.use_half:
257
+ images = images.half()
258
+ target = target.cuda(non_blocking=True)
259
+
260
+ # encode images
261
+ image_features = dist_utils.get_model(model).encode_image(images)
262
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
263
+ all_vis_features.append(image_features)
264
+ # cosine similarity as logits
265
+ logits_per_image = image_features @ text_features.t()
266
+ # logits_per_image = torch.softmax(logits_per_image, dim=1)
267
+ else:
268
+ target = target.cuda(non_blocking=True)
269
+ images_list = images
270
+ logits_all_clips = []
271
+ for images in images_list:
272
+ images = images.cuda(non_blocking=True)
273
+ if args.use_half:
274
+ images = images.half()
275
+ image_features = dist_utils.get_model(model).encode_image(images)
276
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
277
+ logits_per_image = image_features @ text_features.t()
278
+ logits_all_clips.append(logits_per_image)
279
+
280
+ logits_all_clips = torch.stack(logits_all_clips, dim=0)
281
+ logits_per_image = logits_all_clips.max(0).values
282
+ # logits_per_image = logits_all_clips.mean(0)
283
+ logits_per_image = torch.softmax(logits_per_image, dim=1)
284
+
285
+ all_outputs.append(logits_per_image.cpu())
286
+ all_targets.append(target.cpu())
287
+
288
+ return torch.cat(all_outputs), torch.cat(all_targets)
289
+
290
+
291
+ def get_similarity_matrix(val_loader, model, print_freq=100, use_half=False):
292
+ model.eval()
293
+ if use_half:
294
+ model = model.half()
295
+ all_text_embed = []
296
+ all_video_embed = []
297
+ with torch.no_grad():
298
+ print('=> encoding visual and textual')
299
+ for i, inputs in enumerate(val_loader):
300
+ if i % print_freq == 0:
301
+ print('finish batch {}/{}'.format(i, len(val_loader)))
302
+ frames = inputs[0].cuda(non_blocking=True)
303
+ if use_half:
304
+ frames = frames.half()
305
+ texts = inputs[1].cuda(non_blocking=True)
306
+ if len(inputs) == 4:
307
+ masks = inputs[2].cuda(non_blocking=True)
308
+ else:
309
+ masks = None
310
+
311
+ # encode images
312
+ image_features = dist_utils.get_model(model).encode_image(frames)
313
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
314
+ all_video_embed.append(image_features.cpu().numpy())
315
+
316
+ if texts.ndim == 3:
317
+ is_multiple_narrations = True
318
+ texts = texts.view(-1, texts.shape[-1])
319
+ else:
320
+ is_multiple_narrations = False
321
+ if masks is not None:
322
+ text_features = dist_utils.get_model(model).encode_text(texts, attention_mask=masks)
323
+ else:
324
+ text_features = dist_utils.get_model(model).encode_text(texts)
325
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
326
+ all_text_embed.append(text_features.cpu().numpy())
327
+
328
+ all_text_embed = np.vstack(all_text_embed)
329
+ all_video_embed = np.vstack(all_video_embed)
330
+ similarity_matrix = np.matmul(all_video_embed, all_text_embed.T)
331
+ if is_multiple_narrations:
332
+ similarity_matrix = similarity_matrix.reshape(all_video_embed.shape[0], all_video_embed.shape[0], -1)
333
+
334
+ return similarity_matrix
335
+
336
+
337
+ def validate_mcq(val_loader, model, use_half=False):
338
+ model.eval()
339
+ if use_half:
340
+ model.half()
341
+ with torch.no_grad():
342
+ print('=> start forwarding')
343
+ all_preds = []
344
+ all_gts = []
345
+ all_types = []
346
+ end_time = time.time()
347
+ for i, inputs in enumerate(val_loader):
348
+ if i % args.print_freq == 0:
349
+ print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time))
350
+ end_time = time.time()
351
+ texts_query = inputs[0].cuda(non_blocking=True)
352
+ frames_options = inputs[1].cuda(non_blocking=True)
353
+ if use_half:
354
+ frames_options = frames_options.half()
355
+ answer = inputs[3]
356
+ q_type = inputs[4]
357
+ if len(inputs) == 7:
358
+ masks_query = inputs[5].cuda(non_blocking=True)
359
+ else:
360
+ masks_query = None
361
+
362
+ batch_size = frames_options.shape[0]
363
+
364
+ frames_options = frames_options.view(-1, *frames_options.shape[2:])
365
+ image_features = dist_utils.get_model(model).encode_image(frames_options)
366
+ image_features = image_features.view(batch_size, -1, *image_features.shape[1:])
367
+
368
+ if masks_query is not None:
369
+ query_features = dist_utils.get_model(model).encode_text(texts_query, attention_mask=masks_query)
370
+ else:
371
+ query_features = dist_utils.get_model(model).encode_text(texts_query)
372
+
373
+ all_gts.append(answer)
374
+ all_types.append(q_type)
375
+ for j in range(batch_size):
376
+ similarity_matrix = torch.matmul(query_features[j], image_features[j].T)
377
+ similarity_matrix = similarity_matrix.cpu().detach()
378
+ all_preds.append(similarity_matrix)
379
+ all_preds = torch.stack(all_preds)
380
+ all_gts = torch.cat(all_gts)
381
+ all_types = torch.cat(all_types)
382
+ metrics = egomcq_accuracy_metrics(all_preds, all_gts, all_types)
383
+ print(metrics)
384
+
385
+
386
+ if __name__ == '__main__':
387
+ parser = argparse.ArgumentParser('lavila 0-shot evaluations', parents=[get_args_parser()])
388
+ args = parser.parse_args()
389
+ main(args)
lavila/data/__pycache__/datasets.cpython-38.pyc ADDED
Binary file (14.4 kB). View file
 
lavila/data/__pycache__/video_transforms.cpython-38.pyc ADDED
Binary file (6.25 kB). View file
 
lavila/data/datasets.py ADDED
@@ -0,0 +1,517 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import csv
8
+ import glob
9
+ import json
10
+ import numpy as np
11
+ import os.path as osp
12
+ import pickle
13
+ import random
14
+
15
+ import decord
16
+ import pandas as pd
17
+ import torch
18
+
19
+
20
+ def datetime2sec(str):
21
+ hh, mm, ss = str.split(':')
22
+ return int(hh) * 3600 + int(mm) * 60 + float(ss)
23
+
24
+
25
+ def video_loader(root, vid, second, end_second=None, chunk_len=300, fps=30, clip_length=32, jitter=False):
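+ # Videos are stored either as a single file '<vid>.mp4' (chunk_len == -1) or as a directory
+ # '<vid>.mp4/' of fixed-length chunks named by their start second (e.g. '0.mp4', '300.mp4'
+ # when chunk_len == 300); `second`/`end_second` are absolute timestamps in the full video.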
26
+ if chunk_len == -1:
27
+ vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid)))
28
+ second_offset = second
29
+ if end_second is not None:
30
+ end_second = min(end_second, len(vr) / vr.get_avg_fps())
31
+ else:
32
+ end_second = len(vr) / vr.get_avg_fps()
33
+ else:
34
+ chunk_start = int(second) // chunk_len * chunk_len
35
+ second_offset = second - chunk_start
36
+ vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start)))
37
+ if fps == -1:
38
+ fps = vr.get_avg_fps()
39
+
40
+ # calculate frame_ids
41
+ frame_offset = int(np.round(second_offset * fps))
42
+ total_duration = max(int((end_second - second) * fps), clip_length)
43
+ if chunk_len == -1:
44
+ if end_second <= second:
45
+ raise ValueError("end_second should be greater than second")
46
+ else:
47
+ frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
48
+ else:
49
+ frame_ids = get_frame_ids(frame_offset, frame_offset + total_duration, num_segments=clip_length, jitter=jitter)
50
+
51
+ # load frames
52
+ if max(frame_ids) < len(vr):
53
+ try:
54
+ frames = vr.get_batch(frame_ids).asnumpy()
55
+ except decord.DECORDError as error:
56
+ print(error)
57
+ frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
58
+ else:
59
+ # find the remaining frames in the next chunk
60
+ try:
61
+ frame_ids_part1 = list(filter(lambda frame_id: frame_id < len(vr), frame_ids))
62
+ frames_part1 = vr.get_batch(frame_ids_part1).asnumpy()
63
+ vr2 = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start + chunk_len)))
64
+ frame_ids_part2 = list(filter(lambda frame_id: frame_id >= len(vr), frame_ids))
65
+ frame_ids_part2 = [min(frame_id % len(vr), len(vr2) - 1) for frame_id in frame_ids_part2]
66
+ frames_part2 = vr2.get_batch(frame_ids_part2).asnumpy()
67
+ frames = np.concatenate([frames_part1, frames_part2], axis=0)
68
+ # the next chunk does not exist; the current chunk is the last one
69
+ except (RuntimeError, decord.DECORDError) as error:
70
+ print(error)
71
+ frame_ids = get_frame_ids(min(frame_offset, len(vr) - 1), len(vr), num_segments=clip_length, jitter=jitter)
72
+ frames = vr.get_batch(frame_ids).asnumpy()
73
+
74
+ frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
75
+ return torch.stack(frames, dim=0)
76
+
77
+
78
+ def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
79
+ seg_size = float(end_frame - start_frame - 1) / num_segments
80
+ seq = []
81
+ for i in range(num_segments):
82
+ start = int(np.round(seg_size * i) + start_frame)
83
+ end = int(np.round(seg_size * (i + 1)) + start_frame)
84
+ end = min(end, end_frame)
85
+ if jitter:
86
+ frame_id = np.random.randint(low=start, high=(end + 1))
87
+ else:
88
+ frame_id = (start + end) // 2
89
+ seq.append(frame_id)
90
+ return seq
91
+
92
+
93
+ def video_loader_by_frames(root, vid, frame_ids):
94
+ vr = decord.VideoReader(osp.join(root, vid))
95
+ try:
96
+ frames = vr.get_batch(frame_ids).asnumpy()
97
+ frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
98
+ except (IndexError, decord.DECORDError) as error:
99
+ print(error)
100
+ print("Erroneous video: ", vid)
101
+ frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
102
+ return torch.stack(frames, dim=0)
103
+
104
+
105
+ class VideoCaptionDatasetBase(torch.utils.data.Dataset):
106
+ def __init__(self, dataset, root, metadata, is_trimmed=True):
107
+ self.dataset = dataset
108
+ self.root = root
109
+ self.is_trimmed = is_trimmed
110
+
111
+ if self.dataset == 'ego4d':
112
+ with open(metadata, 'rb') as f:
113
+ self.samples = pickle.load(f)
114
+ elif self.dataset == 'ego4d_mcq':
115
+ with open(metadata, 'r') as f:
116
+ self.samples = json.load(f)
117
+ elif self.dataset in ['ek100_cls', 'ek100_mir']:
118
+ video_list = glob.glob(osp.join(self.root, '*/*.MP4'))
119
+ fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
120
+ self.samples = []
121
+ with open(metadata) as f:
122
+ csv_reader = csv.reader(f)
123
+ _ = next(csv_reader) # skip the header
124
+ for row in csv_reader:
125
+ pid, vid = row[1:3]
126
+ # start_frame, end_frame = int(row[6]), int(row[7])
127
+ # Deprecated: some videos might have fps mismatch issue
128
+ start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
129
+ narration = row[8]
130
+ verb, noun = int(row[10]), int(row[12])
131
+ vid_path = '{}/{}.MP4'.format(pid, vid)
132
+ fps = fps_dict[osp.join(self.root, vid_path)]
133
+ start_frame = int(np.round(fps * start_timestamp))
134
+ end_frame = int(np.ceil(fps * end_timestamp))
135
+ self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun))
136
+ if self.dataset == 'ek100_mir':
137
+ self.metadata_sentence = pd.read_csv(metadata[:metadata.index('.csv')] + '_sentence.csv')
138
+ if 'train' in metadata:
139
+ self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb'))
140
+ elif 'test' in metadata:
141
+ self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb'))
142
+ else:
143
+ raise ValueError('{} should contain either "train" or "test"!'.format(metadata))
144
+ self.relevancy = .1
145
+ elif self.dataset == 'egtea':
146
+ video_list = glob.glob(osp.join(self.root, '*/*'))
147
+ len_dict = {video: len(decord.VideoReader(video)) for video in video_list}
148
+
149
+ vn_list, labels = [], []
150
+ for row in open(osp.join(osp.dirname(metadata), 'action_idx.txt')):
151
+ row = row.strip()
152
+ vn = int(row.split(' ')[-1])
153
+ vn_list.append(vn)
154
+ narration = ' '.join(row.split(' ')[:-1])
155
+ labels.append(narration.replace('_', ' ').lower())
156
+ # labels.append(narration)
157
+ mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)}
158
+
159
+ self.samples = []
160
+ with open(metadata) as f:
161
+ for row in f:
162
+ clip_id, action_idx = row.strip().split(' ')[:2]
163
+ video_id = '-'.join(clip_id.split('-')[:3])
164
+ vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id))
165
+ vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id))
166
+ self.samples.append((vid_relpath, 0, len_dict[vid_fullpath], mapping_act2narration[int(action_idx)]))
167
+ elif self.dataset == 'charades_ego':
168
+ video_list = glob.glob(osp.join(self.root, '*.mp4'))
169
+ fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
170
+ self.samples = []
171
+ with open(metadata) as f:
172
+ csv_reader = csv.reader(f)
173
+ _ = next(csv_reader) # skip the header
174
+ for row in csv_reader:
175
+ video_id = row[0]
176
+ if self.is_trimmed:
177
+ for action_tuple in row[9].split(';'):
178
+ if not action_tuple:
179
+ continue
180
+ action, start_timestamp, end_timestamp = action_tuple.split(' ')
181
+ start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp)
182
+ vid_path = '{}.mp4'.format(video_id)
183
+ fps = fps_dict[osp.join(self.root, vid_path)]
184
+ start_frame = int(np.round(fps * start_timestamp))
185
+ end_frame = int(np.ceil(fps * end_timestamp))
186
+ self.samples.append((vid_path, start_frame, end_frame, action))
187
+ else:
188
+ if not row[9]:
189
+ action_list = []
190
+ else:
191
+ action_list = [action_tuple.split(' ')[0] for action_tuple in row[9].split(';')]
192
+ vid_path = '{}.mp4'.format(video_id)
193
+ fps = fps_dict[osp.join(self.root, vid_path)]
194
+ duration = fps * float(row[10])
195
+ self.samples.append((vid_path, 0, duration, action_list))
196
+ elif self.dataset == 'charades_ego_trimmed':
197
+ with open(metadata, 'rb') as f:
198
+ self.samples = pickle.load(f)
199
+ else:
200
+ raise NotImplementedError
201
+
202
+ def get_raw_item(self, i, is_training=True, num_clips=1, clip_length=32, clip_stride=2, sparse_sample=False,
203
+ narration_selection='random'):
204
+ if self.dataset == 'ego4d':
205
+ if len(self.samples[i]) == 4:
206
+ vid, start_second, end_second, narration = self.samples[i]
207
+ frames = video_loader(self.root, vid, start_second,
208
+ end_second=end_second,
209
+ clip_length=clip_length,
210
+ jitter=is_training)
211
+ if isinstance(narration, list):
212
+ if narration_selection == 'random':
213
+ narration = random.choice(narration)
214
+ elif narration_selection == 'concat':
215
+ narration = '. '.join(narration)
216
+ elif narration_selection == 'list':
217
+ narration = narration
218
+ else:
219
+ raise ValueError
220
+ return frames, narration
221
+ elif len(self.samples[i]) == 5:
222
+ # TODO: need better filtering strategy based on nll
223
+ vid, start_second, end_second, narration, _ = self.samples[i]
224
+ frames = video_loader(self.root, vid, start_second,
225
+ end_second=end_second,
226
+ clip_length=clip_length,
227
+ jitter=is_training)
228
+ if isinstance(narration, list):
229
+ if narration_selection == 'random':
230
+ narration = random.choice(narration)
231
+ elif narration_selection == 'concat':
232
+ narration = '. '.join(narration)
233
+ elif narration_selection == 'list':
234
+ narration = narration
235
+ else:
236
+ raise ValueError
237
+ return frames, narration
238
+ elif self.dataset == 'ego4d_mcq':
239
+ itemMCQ = self.samples[str(i)]
240
+ answerIndex = itemMCQ['answer']
241
+ textQuery = itemMCQ['query']['clip_text']
242
+ sampleOptions = itemMCQ['choices']
243
+ frames_options = []
244
+ narration_options = []
245
+ for option_id in range(len(sampleOptions)):
246
+ option = sampleOptions[str(option_id)]
247
+ frames = video_loader(self.root, option['video_uid'],
248
+ float(option['clip_start']), end_second=float(option['clip_end']),
249
+ clip_length=clip_length,
250
+ jitter=is_training)
251
+ frames_options.append(frames)
252
+ narration_options.append(option['clip_text'])
253
+ return textQuery, frames_options, narration_options, answerIndex, itemMCQ['types']
254
+ elif self.dataset == 'ek100_mir':
255
+ vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
256
+ # from third_party.EgoVLP.base.base_dataset import sample_frames_start_end
257
+ # frame_ids = sample_frames_start_end(clip_length, start_frame, end_frame, sample='uniform', fix_start=None)
258
+ frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
259
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
260
+ if is_training:
261
+ positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
262
+ if positive_list != []:
263
+ pos = random.sample(positive_list, min(len(positive_list), 1))[0]
264
+ if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
265
+ return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
266
+ else:
267
+ return frames, (narration, 1)
268
+ elif self.dataset == 'ek100_cls':
269
+ vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
270
+ frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
271
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
272
+ return frames, '{}:{}'.format(verb, noun)
273
+ elif self.dataset == 'egtea':
274
+ vid_path, start_frame, end_frame, sentence = self.samples[i]
275
+ if is_training:
276
+ assert num_clips == 1
277
+ if end_frame < clip_length * clip_stride:
278
+ frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
279
+ zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
280
+ frames = torch.cat((frames, zeros), dim=0)
281
+ frames = frames[::clip_stride]
282
+ else:
283
+ start_id = np.random.randint(0, end_frame - clip_length * clip_stride + 1)
284
+ frame_ids = np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)
285
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
286
+ else:
287
+ if end_frame < clip_length * clip_stride:
288
+ frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
289
+ zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
290
+ frames = torch.cat((frames, zeros), dim=0)
291
+ frames = frames[::clip_stride]
292
+ frames = frames.repeat(num_clips, 1, 1, 1)
293
+ else:
294
+ frame_ids = []
295
+ for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
296
+ frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
297
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
298
+ return frames, sentence
299
+ elif self.dataset == 'charades_ego':
300
+ vid_path, start_frame, end_frame, action_list = self.samples[i]
301
+ if sparse_sample:
302
+ frame_ids = get_frame_ids(start_frame, end_frame, num_segments=num_clips * clip_length, jitter=is_training)
303
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
304
+ else:
305
+ if end_frame < clip_length * clip_stride:
306
+ frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
307
+ zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
308
+ frames = torch.cat((frames, zeros), dim=0)
309
+ frames = frames[::clip_stride]
310
+ frames = frames.repeat(num_clips, 1, 1, 1)
311
+ else:
312
+ frame_ids = []
313
+ for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
314
+ frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
315
+ print('frame_ids:', frame_ids)
316
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
317
+ return frames, action_list
318
+ elif self.dataset == 'charades_ego_trimmed':
319
+ vid, start_second, end_second, narration = self.samples[i]
320
+ frames = video_loader(self.root, vid, start_second,
321
+ end_second=end_second,
322
+ chunk_len=-1, # no chunk for CharadesEgo
323
+ fps=-1, # could be variable fps
324
+ clip_length=clip_length,
325
+ jitter=is_training)
326
+ return frames, narration
327
+ else:
328
+ raise NotImplementedError
329
+
330
+ def __getitem__(self, i):
331
+ raise NotImplementedError
332
+
333
+ def __len__(self):
334
+ return len(self.samples)
335
+
336
+
337
+ class VideoCaptionDatasetCLIP(VideoCaptionDatasetBase):
338
+ def __init__(self, dataset, root, metadata, transform=None,
339
+ is_training=True, tokenizer=None,
340
+ clip_length=32, clip_stride=2, sparse_sample=False,
341
+ narration_selection='random',
342
+ num_hard_negatives=0,
343
+ subsample_stride=None):
344
+ super().__init__(dataset, root, metadata)
345
+
346
+ self.full_samples = self.samples.copy()
347
+ if isinstance(subsample_stride, int):
348
+ self.samples = self.samples[::subsample_stride]
349
+ self.transform = transform
350
+ self.is_training = is_training
351
+ self.tokenizer = tokenizer
352
+ self.clip_length = clip_length
353
+ self.clip_stride = clip_stride
354
+ self.sparse_sample = sparse_sample
355
+ self.narration_selection = narration_selection
356
+ self.num_hard_negatives = num_hard_negatives
357
+ if num_hard_negatives > 0:
358
+ assert self.dataset == 'htm_aa'
359
+
360
+ def __getitem__(self, i):
361
+ frames, caption = self.get_raw_item(
362
+ i, is_training=self.is_training,
363
+ clip_length=self.clip_length,
364
+ clip_stride=self.clip_stride,
365
+ sparse_sample=self.sparse_sample,
366
+ narration_selection=self.narration_selection,
367
+ )
368
+
369
+ # ek100_mir will also output relevancy value
370
+ if isinstance(caption, tuple):
371
+ caption, relevancy = caption
372
+ else:
373
+ relevancy = 0.
374
+
375
+ # apply transformation
376
+ if self.transform is not None:
377
+ frames = self.transform(frames)
378
+
379
+ # tokenize caption
380
+ if self.tokenizer is not None:
381
+ caption = self.tokenizer(caption)
382
+
383
+ if isinstance(caption, tuple):
384
+ caption, mask = caption
385
+ return frames, caption, mask, relevancy
386
+ else:
387
+ return frames, caption, relevancy
388
+
389
+
390
+ class VideoCaptionDatasetMCQ(VideoCaptionDatasetBase):
391
+ def __init__(self, dataset, root, metadata, transform=None,
392
+ is_training=True, tokenizer=None,
393
+ clip_length=32, clip_stride=2, sparse_sample=False,
394
+ narration_selection='random'):
395
+ super().__init__(dataset, root, metadata)
396
+
397
+ self.full_samples = self.samples.copy()
398
+ self.transform = transform
399
+ self.is_training = is_training
400
+ self.tokenizer = tokenizer
401
+ self.clip_length = clip_length
402
+ self.clip_stride = clip_stride
403
+ self.sparse_sample = sparse_sample
404
+ self.narration_selection = narration_selection
405
+
406
+ def __getitem__(self, i):
407
+
408
+ textQuery, frames_options, narration_options, answerIndex, q_type = self.get_raw_item(
409
+ i, is_training=self.is_training,
410
+ clip_length=self.clip_length,
411
+ clip_stride=self.clip_stride,
412
+ sparse_sample=self.sparse_sample,
413
+ narration_selection=self.narration_selection,
414
+ )
415
+
416
+ # apply transformation
417
+ if self.transform is not None:
418
+ frames_options = [self.transform(frames) for frames in frames_options]
419
+
420
+ # tokenize caption
421
+ if self.tokenizer is not None:
422
+ textQuery = self.tokenizer(textQuery)
423
+ narration_options = self.tokenizer(narration_options)
424
+ if isinstance(textQuery, tuple):
425
+ textQuery, mask_query = textQuery
426
+ narration_options, mask_options = narration_options
427
+ return (
428
+ textQuery, torch.stack(frames_options, dim=0),
429
+ narration_options, answerIndex, q_type,
430
+ mask_query, mask_options
431
+ )
432
+ else:
433
+ return textQuery, torch.stack(frames_options, dim=0), narration_options, answerIndex, q_type
434
+
435
+
436
+ class VideoClassyDataset(VideoCaptionDatasetBase):
437
+ def __init__(
438
+ self, dataset, root, metadata, transform=None,
439
+ is_training=True, label_mapping=None,
440
+ num_clips=1,
441
+ clip_length=32, clip_stride=2,
442
+ sparse_sample=False,
443
+ is_trimmed=True,
444
+ ):
445
+ super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)
446
+
447
+ self.transform = transform
448
+ self.is_training = is_training
449
+ self.label_mapping = label_mapping
450
+ self.num_clips = num_clips
451
+ self.clip_length = clip_length
452
+ self.clip_stride = clip_stride
453
+ self.sparse_sample = sparse_sample
454
+
455
+ def __getitem__(self, i):
456
+ frames, label = self.get_raw_item(
457
+ i, is_training=self.is_training,
458
+ num_clips=self.num_clips,
459
+ clip_length=self.clip_length,
460
+ clip_stride=self.clip_stride,
461
+ sparse_sample=self.sparse_sample,
462
+ )
463
+
464
+ # apply transformation
465
+ if self.transform is not None:
466
+ frames = self.transform(frames)
467
+
468
+ if self.label_mapping is not None:
469
+ if isinstance(label, list):
470
+ # multi-label case
471
+ res_array = np.zeros(len(self.label_mapping))
472
+ for lbl in label:
473
+ res_array[self.label_mapping[lbl]] = 1.
474
+ label = res_array
475
+ else:
476
+ label = self.label_mapping[label]
477
+
478
+ return frames, label
479
+
480
+
481
+ def get_dataset(train_transform, tokenizer, args, is_training=True):
482
+ if 'narration_selection' not in args:
483
+ args.narration_selection = 'random'
484
+ if args.model.startswith('CLIP') or args.model.startswith('VCLM'):
485
+ return VideoCaptionDatasetCLIP(
486
+ args.dataset, args.root, args.metadata, train_transform,
487
+ is_training=is_training,
488
+ tokenizer=tokenizer,
489
+ clip_length=args.clip_length, clip_stride=args.clip_stride,
490
+ sparse_sample=args.sparse_sample,
491
+ narration_selection=args.narration_selection,
492
+ num_hard_negatives=args.num_hard_neg if 'num_hard_neg' in args else 0,
493
+ )
494
+ else:
495
+ raise NotImplementedError
496
+
497
+
498
+ def get_downstream_dataset(transform, tokenizer, args, subset='train', label_mapping=None):
499
+ if subset == 'train':
500
+ return VideoClassyDataset(
501
+ args.dataset, args.root, args.metadata_train, transform,
502
+ is_training=True, label_mapping=label_mapping,
503
+ num_clips=args.num_clips,
504
+ clip_length=args.clip_length, clip_stride=args.clip_stride,
505
+ sparse_sample=args.sparse_sample,
506
+ )
507
+ elif subset == 'val':
508
+ return VideoClassyDataset(
509
+ args.dataset, args.root, args.metadata_val, transform,
510
+ is_training=False, label_mapping=label_mapping,
511
+ num_clips=args.num_clips,
512
+ clip_length=args.clip_length, clip_stride=args.clip_stride,
513
+ sparse_sample=args.sparse_sample,
514
+ is_trimmed=not args.dataset == 'charades_ego'
515
+ )
516
+ else:
517
+ assert ValueError("subset should be either 'train' or 'val'")
lavila/data/video_transforms.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Sequence
9
+ import torch
10
+ import torch.nn as nn
11
+ from torchvision import transforms
12
+
13
+
14
+ class Permute(nn.Module):
15
+ """
16
+ Permutation as an op
17
+ """
18
+
19
+ def __init__(self, ordering):
20
+ super().__init__()
21
+ self.ordering = ordering
22
+
23
+ def forward(self, frames):
24
+ """
25
+ Args:
26
+ frames in some ordering, by default (C, T, H, W)
27
+ Returns:
28
+ frames in the ordering that was specified
29
+ """
30
+ return frames.permute(self.ordering)
31
+
32
+
33
+ class TemporalCrop(nn.Module):
34
+ """
35
+ Convert the video into smaller clips temporally.
36
+ """
37
+
38
+ def __init__(
39
+ self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
40
+ ):
41
+ super().__init__()
42
+ self.frames = frames_per_clip
43
+ self.stride = stride
44
+ self.frame_stride = frame_stride
45
+
46
+ def forward(self, video):
47
+ assert video.ndim == 4, "Must be (C, T, H, W)"
48
+ res = []
49
+ for start in range(
50
+ 0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride
51
+ ):
52
+ end = start + (self.frames) * self.frame_stride
53
+ res.append(video[:, start: end: self.frame_stride, ...])
54
+ return res
55
+
56
+
57
+ def crop_boxes(boxes, x_offset, y_offset):
58
+ """
59
+ Perform crop on the bounding boxes given the offsets.
60
+ Args:
61
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
62
+ is `num boxes` x 4.
63
+ x_offset (int): cropping offset in the x axis.
64
+ y_offset (int): cropping offset in the y axis.
65
+ Returns:
66
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
67
+ `num boxes` x 4.
68
+ """
69
+ cropped_boxes = boxes.copy()
70
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
71
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
72
+
73
+ return cropped_boxes
74
+
75
+
76
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
77
+ """
78
+ Perform uniform spatial sampling on the images and corresponding boxes.
79
+ Args:
80
+ images (tensor): images to perform uniform crop. The dimension is
81
+ `num frames` x `channel` x `height` x `width`.
82
+ size (int): size of height and width to crop the images.
83
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
84
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
85
+ crop if height is larger than width.
86
+ boxes (ndarray or None): optional. Corresponding boxes to images.
87
+ Dimension is `num boxes` x 4.
88
+ scale_size (int): optional. If not None, resize the images to scale_size before
89
+ performing any crop.
90
+ Returns:
91
+ cropped (tensor): images with dimension of
92
+ `num frames` x `channel` x `size` x `size`.
93
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
94
+ `num boxes` x 4.
95
+ """
96
+ assert spatial_idx in [0, 1, 2]
97
+ ndim = len(images.shape)
98
+ if ndim == 3:
99
+ images = images.unsqueeze(0)
100
+ height = images.shape[2]
101
+ width = images.shape[3]
102
+
103
+ if scale_size is not None:
104
+ if width <= height:
105
+ width, height = scale_size, int(height / width * scale_size)
106
+ else:
107
+ width, height = int(width / height * scale_size), scale_size
108
+ images = torch.nn.functional.interpolate(
109
+ images,
110
+ size=(height, width),
111
+ mode="bilinear",
112
+ align_corners=False,
113
+ )
114
+
115
+ y_offset = int(math.ceil((height - size) / 2))
116
+ x_offset = int(math.ceil((width - size) / 2))
117
+
118
+ if height > width:
119
+ if spatial_idx == 0:
120
+ y_offset = 0
121
+ elif spatial_idx == 2:
122
+ y_offset = height - size
123
+ else:
124
+ if spatial_idx == 0:
125
+ x_offset = 0
126
+ elif spatial_idx == 2:
127
+ x_offset = width - size
128
+ cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
129
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
130
+ if ndim == 3:
131
+ cropped = cropped.squeeze(0)
132
+ return cropped, cropped_boxes
133
+
134
+
135
+ class SpatialCrop(nn.Module):
136
+ """
137
+ Convert the video into 3 smaller clips spatially. Must be used after the
138
+ temporal crops to get spatial crops, and should be used with
139
+ -2 in the spatial crop at the slowfast augmentation stage (so full
140
+ frames are passed in here). Will return a larger list with the
141
+ 3x spatial crops as well. It's useful for 3x4 testing (eg in SwinT)
142
+ or 3x10 testing in SlowFast etc.
143
+ """
144
+
145
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
146
+ super().__init__()
147
+ self.crop_size = crop_size
148
+ if num_crops == 6:
149
+ self.crops_to_ext = [0, 1, 2]
150
+ # I guess Swin uses 5 crops without flipping, but that doesn't
151
+ # make sense given they first resize to 224 and take 224 crops.
152
+ # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf)
153
+ # So I'm assuming we can use flipped crops and that will add something.
154
+ self.flipped_crops_to_ext = [0, 1, 2]
155
+ elif num_crops == 3:
156
+ self.crops_to_ext = [0, 1, 2]
157
+ self.flipped_crops_to_ext = []
158
+ elif num_crops == 1:
159
+ self.crops_to_ext = [1]
160
+ self.flipped_crops_to_ext = []
161
+ else:
162
+ raise NotImplementedError(
163
+ "Nothing else supported yet, "
164
+ "slowfast only takes 0, 1, 2 as arguments"
165
+ )
166
+
167
+ def forward(self, videos: Sequence[torch.Tensor]):
168
+ """
169
+ Args:
170
+ videos: A list of C, T, H, W videos.
171
+ Returns:
172
+ videos: A list with 3x the number of elements. Each video converted
173
+ to C, T, H', W' by spatial cropping.
174
+ """
175
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
176
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
177
+ res = []
178
+ for video in videos:
179
+ for spatial_idx in self.crops_to_ext:
180
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
181
+ if not self.flipped_crops_to_ext:
182
+ continue
183
+ flipped_video = transforms.functional.hflip(video)
184
+ for spatial_idx in self.flipped_crops_to_ext:
185
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
186
+ return res
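A minimal, self-contained sketch (not part of the file above) of how these transforms compose for multi-view evaluation: `TemporalCrop` slices a long clip into fixed-length windows, and `SpatialCrop` then takes three spatial views of each window.

```python
# Sketch with dummy data: a 32-frame clip -> 2 temporal windows -> 6 spatial views.
import torch

from lavila.data.video_transforms import SpatialCrop, TemporalCrop

video = torch.randn(3, 32, 256, 320)                   # (C, T, H, W) dummy clip
temporal = TemporalCrop(frames_per_clip=16, stride=16)
spatial = SpatialCrop(crop_size=224, num_crops=3)

clips = temporal(video)                                # list of 2 windows, 16 frames each
views = spatial(clips)                                 # 2 x 3 = 6 views of shape (3, 16, 224, 224)
print(len(views), views[0].shape)
```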
lavila/models/__pycache__/distributed_utils.cpython-38.pyc ADDED
Binary file (2.95 kB).
 
lavila/models/__pycache__/gpt2_gated.cpython-38.pyc ADDED
Binary file (46.9 kB).
 
lavila/models/__pycache__/loss.cpython-38.pyc ADDED
Binary file (8.63 kB).
 
lavila/models/__pycache__/models.cpython-38.pyc ADDED
Binary file (22.6 kB).
 
lavila/models/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
lavila/models/coca.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under MIT License
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch import einsum
15
+ from einops import rearrange
16
+
17
+
18
+ def exists(val):
19
+ return val is not None
20
+
21
+
22
+ def default(val, d):
23
+ return val if exists(val) else d
24
+
25
+
26
+ # normalization
27
+ # they use layernorm without bias, something that pytorch does not offer
28
+ class LayerNorm(nn.Module):
29
+ def __init__(self, dim):
30
+ super().__init__()
31
+ self.gamma = nn.Parameter(torch.ones(dim))
32
+ self.register_buffer("beta", torch.zeros(dim))
33
+
34
+ def forward(self, x):
35
+ return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
36
+
37
+
38
+ class Residual(nn.Module):
39
+ def __init__(self, fn):
40
+ super().__init__()
41
+ self.fn = fn
42
+
43
+ def forward(self, x, *args, **kwargs):
44
+ return self.fn(x, *args, **kwargs) + x
45
+
46
+
47
+ # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
48
+ # https://arxiv.org/abs/2002.05202
49
+ class SwiGLU(nn.Module):
50
+ def forward(self, x):
51
+ x, gate = x.chunk(2, dim=-1)
52
+ return F.silu(gate) * x
53
+
54
+
55
+ class CrossAttention(nn.Module):
56
+ def __init__(
57
+ self,
58
+ dim,
59
+ *,
60
+ context_dim=None,
61
+ dim_head=64,
62
+ heads=8,
63
+ parallel_ff=False,
64
+ ff_mult=4,
65
+ norm_context=False
66
+ ):
67
+ super().__init__()
68
+ self.heads = heads
69
+ self.scale = dim_head ** -0.5
70
+ inner_dim = heads * dim_head
71
+ context_dim = default(context_dim, dim)
72
+
73
+ self.norm = LayerNorm(dim)
74
+ self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
75
+
76
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
77
+ self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
78
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
79
+
80
+ # whether to have parallel feedforward
81
+
82
+ ff_inner_dim = ff_mult * dim
83
+
84
+ self.ff = nn.Sequential(
85
+ nn.Linear(dim, ff_inner_dim * 2, bias=False),
86
+ SwiGLU(),
87
+ nn.Linear(ff_inner_dim, dim, bias=False)
88
+ ) if parallel_ff else None
89
+
90
+ def forward(self, x, context):
91
+ """
92
+ einstein notation
93
+ b - batch
94
+ h - heads
95
+ n, i, j - sequence length (base sequence length, source, target)
96
+ d - feature dimension
97
+ """
98
+
99
+ # pre-layernorm, for queries and context
100
+ x = self.norm(x)
101
+ context = self.context_norm(context)
102
+
103
+ # get queries
104
+ q = self.to_q(x)
105
+ q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
106
+
107
+ # scale
108
+ q = q * self.scale
109
+
110
+ # get key / values
111
+ k, v = self.to_kv(context).chunk(2, dim=-1)
112
+
113
+ # query / key similarity
114
+ sim = einsum('b h i d, b j d -> b h i j', q, k)
115
+
116
+ # attention
117
+ sim = sim - sim.amax(dim=-1, keepdim=True)
118
+ attn = sim.softmax(dim=-1)
119
+
120
+ # aggregate
121
+ out = einsum('b h i j, b j d -> b h i d', attn, v)
122
+
123
+ # merge and combine heads
124
+ out = rearrange(out, 'b h n d -> b n (h d)')
125
+ out = self.to_out(out)
126
+
127
+ # add parallel feedforward (for multimodal layers)
128
+ if exists(self.ff):
129
+ out = out + self.ff(x)
130
+
131
+ return out
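A small sketch (dummy shapes, not from the repo) of running this cross-attention block with text tokens as queries and visual tokens as context:

```python
import torch

from lavila.models.coca import CrossAttention

# 512-d text tokens attend over 768-d visual tokens; parallel_ff adds the
# SwiGLU feedforward branch used in the multimodal layers.
xattn = CrossAttention(dim=512, context_dim=768, dim_head=64, heads=8, parallel_ff=True)

text_tokens = torch.randn(2, 20, 512)    # (batch, text length, dim)
visual_tokens = torch.randn(2, 49, 768)  # (batch, visual tokens, context_dim)

out = xattn(text_tokens, visual_tokens)  # (2, 20, 512): queries updated by the context
print(out.shape)
```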
lavila/models/distributed_utils.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # Part of the code is from
7
+ # `https://github.com/facebookresearch/vissl/blob/main/vissl/utils/distributed_utils.py` and
8
+ # `https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/distributed_util.py`
9
+ # Modified by Yue Zhao
10
+ # The original code is under MIT License
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ from typing import Tuple
15
+
16
+
17
+ def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
18
+ """
19
+ For some backends, such as NCCL, communication only works if the
20
+ tensor is on the GPU. This helper function converts to the correct
21
+ device and returns the tensor + original device.
22
+ """
23
+ orig_device = "cpu" if not tensor.is_cuda else "gpu"
24
+ if (
25
+ torch.distributed.is_available()
26
+ and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
27
+ and not tensor.is_cuda
28
+ ):
29
+ tensor = tensor.cuda()
30
+ return (tensor, orig_device)
31
+
32
+
33
+ def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
34
+ """
35
+ For some backends, such as NCCL, communication only works if the
36
+ tensor is on the GPU. This converts the tensor back to original device.
37
+ """
38
+ if tensor.is_cuda and orig_device == "cpu":
39
+ tensor = tensor.cpu()
40
+ return tensor
41
+
42
+
43
+ def is_distributed_training_run() -> bool:
44
+ return (
45
+ torch.distributed.is_available()
46
+ and torch.distributed.is_initialized()
47
+ and (torch.distributed.get_world_size() > 1)
48
+ )
49
+
50
+
51
+ class GatherLayer(torch.autograd.Function):
52
+ """
53
+ Gather tensors from all workers with support for backward propagation:
54
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
55
+ """
56
+
57
+ @staticmethod
58
+ def forward(ctx, x):
59
+ output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
60
+ dist.all_gather(output, x)
61
+ return tuple(output)
62
+
63
+ @staticmethod
64
+ def backward(ctx, *grads):
65
+ all_gradients = torch.stack(grads)
66
+ dist.all_reduce(all_gradients)
67
+ return all_gradients[dist.get_rank()]
68
+
69
+
70
+ def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
71
+ """
72
+ Similar to classy_vision.generic.distributed_util.gather_from_all
73
+ except that it does not cut the gradients
74
+ """
75
+ if tensor.ndim == 0:
76
+ # 0 dim tensors cannot be gathered. so unsqueeze
77
+ tensor = tensor.unsqueeze(0)
78
+
79
+ if is_distributed_training_run():
80
+ tensor, orig_device = convert_to_distributed_tensor(tensor)
81
+ gathered_tensors = GatherLayer.apply(tensor)
82
+ gathered_tensors = [
83
+ convert_to_normal_tensor(_tensor, orig_device)
84
+ for _tensor in gathered_tensors
85
+ ]
86
+ else:
87
+ gathered_tensors = [tensor]
88
+ gathered_tensor = torch.cat(gathered_tensors, 0)
89
+ return gathered_tensor
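A brief sketch (an assumption about typical usage, not code from the repo): `gather_from_all` lets a contrastive loss see the embeddings of every rank while still backpropagating to each worker, unlike a plain `all_gather`. Outside a distributed run it simply returns the local tensor.

```python
import torch

from lavila.models.distributed_utils import gather_from_all

local_image_emb = torch.randn(8, 256, requires_grad=True)  # this rank's batch
local_text_emb = torch.randn(8, 256, requires_grad=True)

# Single-process: a no-op. Under torch.distributed: concatenates the batches from
# all ranks via GatherLayer, keeping gradients intact.
all_image_emb = gather_from_all(local_image_emb)
all_text_emb = gather_from_all(local_text_emb)

logits = all_image_emb @ all_text_emb.t()                  # global similarity matrix
print(logits.shape)
```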
lavila/models/gpt2_gated.py ADDED
@@ -0,0 +1,1615 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8
+ # Modified by Yue Zhao
9
+ #
10
+ #
11
+ # coding=utf-8
12
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
13
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
14
+ #
15
+ # Licensed under the Apache License, Version 2.0 (the "License");
16
+ # you may not use this file except in compliance with the License.
17
+ # You may obtain a copy of the License at
18
+ #
19
+ # http://www.apache.org/licenses/LICENSE-2.0
20
+ #
21
+ # Unless required by applicable law or agreed to in writing, software
22
+ # distributed under the License is distributed on an "AS IS" BASIS,
23
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24
+ # See the License for the specific language governing permissions and
25
+ # limitations under the License.
26
+ """PyTorch OpenAI GPT-2 model."""
27
+
28
+ import copy
29
+ import math
30
+ import os
31
+ from dataclasses import dataclass
32
+ from typing import Optional, Tuple, Union
33
+
34
+ import torch
35
+ import torch.utils.checkpoint
36
+ from packaging import version
37
+ from torch import nn
38
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
39
+
40
+
41
+ if version.parse(torch.__version__) >= version.parse("1.6"):
42
+ is_amp_available = True
43
+ from torch.cuda.amp import autocast
44
+ else:
45
+ is_amp_available = False
46
+
47
+ from transformers.activations import ACT2FN
48
+ from transformers.modeling_outputs import (
49
+ BaseModelOutputWithPastAndCrossAttentions,
50
+ CausalLMOutputWithCrossAttentions,
51
+ SequenceClassifierOutputWithPast,
52
+ TokenClassifierOutput,
53
+ )
54
+ from transformers.modeling_utils import PreTrainedModel, SequenceSummary
55
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
56
+ from transformers.utils import (
57
+ ModelOutput,
58
+ add_code_sample_docstrings,
59
+ add_start_docstrings,
60
+ add_start_docstrings_to_model_forward,
61
+ logging,
62
+ replace_return_docstrings,
63
+ )
64
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
65
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
66
+
67
+
68
+ logger = logging.get_logger(__name__)
69
+
70
+ _CHECKPOINT_FOR_DOC = "gpt2"
71
+ _CONFIG_FOR_DOC = "GPT2Config"
72
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
73
+
74
+ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
75
+ "gpt2",
76
+ "gpt2-medium",
77
+ "gpt2-large",
78
+ "gpt2-xl",
79
+ "distilgpt2",
80
+ # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
81
+ ]
82
+
83
+
84
+ def augment_gpt2_config(config, cross_attn_freq=1, gated_xattn=True):
85
+ new_config = copy.deepcopy(config)
86
+ new_config.add_cross_attention = True
87
+ new_config.add_cross_attention_freq = cross_attn_freq
88
+ new_config.is_tanh_gating = gated_xattn
89
+ return new_config
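A quick sketch of what the helper above produces (the `cross_attn_freq=2` value is just an example):

```python
from transformers import GPT2Config

from lavila.models.gpt2_gated import augment_gpt2_config

config = augment_gpt2_config(GPT2Config(), cross_attn_freq=2, gated_xattn=True)
print(config.add_cross_attention)       # True
print(config.add_cross_attention_freq)  # 2 -> cross-attention in every other block
print(config.is_tanh_gating)            # True -> zero-initialized tanh gates
```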
90
+
91
+
92
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
93
+ """Load tf checkpoints in a pytorch model"""
94
+ try:
95
+ import re
96
+
97
+ import tensorflow as tf
98
+ except ImportError:
99
+ logger.error(
100
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
101
+ "https://www.tensorflow.org/install/ for installation instructions."
102
+ )
103
+ raise
104
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
105
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
106
+ # Load weights from TF model
107
+ init_vars = tf.train.list_variables(tf_path)
108
+ names = []
109
+ arrays = []
110
+ for name, shape in init_vars:
111
+ logger.info(f"Loading TF weight {name} with shape {shape}")
112
+ array = tf.train.load_variable(tf_path, name)
113
+ names.append(name)
114
+ arrays.append(array.squeeze())
115
+
116
+ for name, array in zip(names, arrays):
117
+ name = name[6:] # skip "model/"
118
+ name = name.split("/")
119
+ pointer = model
120
+ for m_name in name:
121
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
122
+ scope_names = re.split(r"(\d+)", m_name)
123
+ else:
124
+ scope_names = [m_name]
125
+ if scope_names[0] == "w" or scope_names[0] == "g":
126
+ pointer = getattr(pointer, "weight")
127
+ elif scope_names[0] == "b":
128
+ pointer = getattr(pointer, "bias")
129
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
130
+ pointer = getattr(pointer, scope_names[0])
131
+ pointer = getattr(pointer, "weight")
132
+ else:
133
+ pointer = getattr(pointer, scope_names[0])
134
+ if len(scope_names) >= 2:
135
+ num = int(scope_names[1])
136
+ pointer = pointer[num]
137
+ try:
138
+ assert (
139
+ pointer.shape == array.shape
140
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
141
+ except AssertionError as e:
142
+ e.args += (pointer.shape, array.shape)
143
+ raise
144
+ logger.info(f"Initialize PyTorch weight {name}")
145
+ pointer.data = torch.from_numpy(array)
146
+ return model
147
+
148
+
149
+ class GPT2Attention(nn.Module):
150
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
151
+ super().__init__()
152
+
153
+ max_positions = config.max_position_embeddings
154
+ self.register_buffer(
155
+ "bias",
156
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
157
+ 1, 1, max_positions, max_positions
158
+ ),
159
+ )
160
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
161
+
162
+ self.embed_dim = config.hidden_size
163
+ self.num_heads = config.num_attention_heads
164
+ self.head_dim = self.embed_dim // self.num_heads
165
+ self.split_size = self.embed_dim
166
+ if self.head_dim * self.num_heads != self.embed_dim:
167
+ raise ValueError(
168
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
169
+ )
170
+
171
+ self.scale_attn_weights = config.scale_attn_weights
172
+ self.is_cross_attention = is_cross_attention
173
+
174
+ # Layer-wise attention scaling, reordering, and upcasting
175
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
176
+ self.layer_idx = layer_idx
177
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
178
+
179
+ if self.is_cross_attention:
180
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
181
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
182
+ else:
183
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
184
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
185
+
186
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
187
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
188
+
189
+ self.pruned_heads = set()
190
+
191
+ def prune_heads(self, heads):
192
+ if len(heads) == 0:
193
+ return
194
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
195
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
196
+
197
+ # Prune conv1d layers
198
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
199
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
200
+
201
+ # Update hyper params
202
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
203
+ self.num_heads = self.num_heads - len(heads)
204
+ self.pruned_heads = self.pruned_heads.union(heads)
205
+
206
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
207
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
208
+
209
+ if self.scale_attn_weights:
210
+ attn_weights = attn_weights / (value.size(-1) ** 0.5)
211
+
212
+ # Layer-wise attention scaling
213
+ if self.scale_attn_by_inverse_layer_idx:
214
+ attn_weights = attn_weights / float(self.layer_idx + 1)
215
+
216
+ if not self.is_cross_attention:
217
+ # if only "normal" attention layer implements causal mask
218
+ query_length, key_length = query.size(-2), key.size(-2)
219
+ causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
220
+ attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
221
+
222
+ if attention_mask is not None:
223
+ # Apply the attention mask
224
+ attn_weights = attn_weights + attention_mask
225
+
226
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
227
+
228
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
229
+ attn_weights = attn_weights.type(value.dtype)
230
+ attn_weights = self.attn_dropout(attn_weights)
231
+
232
+ # Mask heads if we want to
233
+ if head_mask is not None:
234
+ attn_weights = attn_weights * head_mask
235
+
236
+ attn_output = torch.matmul(attn_weights, value)
237
+
238
+ return attn_output, attn_weights
239
+
240
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
241
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
242
+ bsz, num_heads, q_seq_len, dk = query.size()
243
+ _, _, k_seq_len, _ = key.size()
244
+
245
+ # Preallocate attn_weights for `baddbmm`
246
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
247
+
248
+ # Compute Scale Factor
249
+ scale_factor = 1.0
250
+ if self.scale_attn_weights:
251
+ scale_factor /= float(value.size(-1)) ** 0.5
252
+
253
+ if self.scale_attn_by_inverse_layer_idx:
254
+ scale_factor /= float(self.layer_idx + 1)
255
+
256
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
257
+ if is_amp_available:
258
+ with autocast(enabled=False):
259
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
260
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
261
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
262
+ else:
263
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
264
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
265
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
266
+
267
+ if not self.is_cross_attention:
268
+ # if only "normal" attention layer implements causal mask
269
+ query_length, key_length = query.size(-2), key.size(-2)
270
+ causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
271
+ attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
272
+
273
+ if attention_mask is not None:
274
+ # Apply the attention mask
275
+ attn_weights = attn_weights + attention_mask
276
+
277
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
278
+
279
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
280
+ if attn_weights.dtype != torch.float32:
281
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
282
+ attn_weights = attn_weights.type(value.dtype)
283
+ attn_weights = self.attn_dropout(attn_weights)
284
+
285
+ # Mask heads if we want to
286
+ if head_mask is not None:
287
+ attn_weights = attn_weights * head_mask
288
+
289
+ attn_output = torch.matmul(attn_weights, value)
290
+
291
+ return attn_output, attn_weights
292
+
293
+ def _split_heads(self, tensor, num_heads, attn_head_size):
294
+ """
295
+ Splits hidden_size dim into attn_head_size and num_heads
296
+ """
297
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
298
+ tensor = tensor.view(new_shape)
299
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
300
+
301
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
302
+ """
303
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
304
+ """
305
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
306
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
307
+ return tensor.view(new_shape)
308
+
309
+ def forward(
310
+ self,
311
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
312
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
313
+ attention_mask: Optional[torch.FloatTensor] = None,
314
+ head_mask: Optional[torch.FloatTensor] = None,
315
+ encoder_hidden_states: Optional[torch.Tensor] = None,
316
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
317
+ use_cache: Optional[bool] = False,
318
+ output_attentions: Optional[bool] = False,
319
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
320
+ if encoder_hidden_states is not None:
321
+ if not hasattr(self, "q_attn"):
322
+ raise ValueError(
323
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
324
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
325
+ )
326
+
327
+ query = self.q_attn(hidden_states)
328
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
329
+ attention_mask = encoder_attention_mask
330
+ else:
331
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
332
+
333
+ query = self._split_heads(query, self.num_heads, self.head_dim)
334
+ key = self._split_heads(key, self.num_heads, self.head_dim)
335
+ value = self._split_heads(value, self.num_heads, self.head_dim)
336
+
337
+ if layer_past is not None:
338
+ past_key, past_value = layer_past
339
+ key = torch.cat((past_key, key), dim=-2)
340
+ value = torch.cat((past_value, value), dim=-2)
341
+
342
+ if use_cache is True:
343
+ present = (key, value)
344
+ else:
345
+ present = None
346
+
347
+ if self.reorder_and_upcast_attn:
348
+ attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
349
+ else:
350
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
351
+
352
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
353
+ attn_output = self.c_proj(attn_output)
354
+ attn_output = self.resid_dropout(attn_output)
355
+
356
+ outputs = (attn_output, present)
357
+ if output_attentions:
358
+ outputs += (attn_weights,)
359
+
360
+ return outputs # a, present, (attentions)
361
+
362
+
363
+ class SqReLU(nn.Module):
364
+ """
365
+ See So: Primer: Searching for Efficient Transformers for Language Modeling (So., https://arxiv.org/abs/2109.08668).
366
+ """
367
+
368
+ def __init__(self):
369
+ super().__init__()
370
+ self.act = self._sqrelu_python
371
+
372
+ def _sqrelu_python(self, input: torch.Tensor) -> torch.Tensor:
373
+ return torch.pow(nn.functional.relu(input), 2)
374
+
375
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
376
+ return self.act(input)
377
+
378
+
379
+ class GPT2MLP(nn.Module):
380
+ def __init__(self, intermediate_size, config, squared_relu=False):
381
+ super().__init__()
382
+ embed_dim = config.hidden_size
383
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
384
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
385
+ if squared_relu:
386
+ self.act = SqReLU()
387
+ else:
388
+ self.act = ACT2FN[config.activation_function]
389
+ self.dropout = nn.Dropout(config.resid_pdrop)
390
+
391
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
392
+ hidden_states = self.c_fc(hidden_states)
393
+ hidden_states = self.act(hidden_states)
394
+ hidden_states = self.c_proj(hidden_states)
395
+ hidden_states = self.dropout(hidden_states)
396
+ return hidden_states
397
+
398
+
399
+ class GPT2Block(nn.Module):
400
+ def __init__(self, config, layer_idx=None):
401
+ super().__init__()
402
+ hidden_size = config.hidden_size
403
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
404
+
405
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
406
+ self.attn = GPT2Attention(config, layer_idx=layer_idx)
407
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
408
+
409
+ self.add_cross_attention_freq = config.add_cross_attention_freq
410
+ if config.add_cross_attention and layer_idx % config.add_cross_attention_freq == 0:
411
+ self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
412
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
413
+ self.mlp_crossattention = GPT2MLP(inner_dim, config, squared_relu=True)
414
+ self.ln_2_crossattention = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
415
+ if config.is_tanh_gating:
416
+ self.alpha_cattn = nn.Parameter(torch.zeros([]))
417
+ self.alpha_dense = nn.Parameter(torch.zeros([]))
418
+
419
+ self.mlp = GPT2MLP(inner_dim, config)
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
424
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
425
+ attention_mask: Optional[torch.FloatTensor] = None,
426
+ head_mask: Optional[torch.FloatTensor] = None,
427
+ encoder_hidden_states: Optional[torch.Tensor] = None,
428
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
429
+ use_cache: Optional[bool] = False,
430
+ output_attentions: Optional[bool] = False,
431
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
432
+ if encoder_hidden_states is not None and self.attn.layer_idx % self.add_cross_attention_freq == 0:
433
+ # add one self-attention block for cross-attention
434
+ if not hasattr(self, "crossattention"):
435
+ raise ValueError(
436
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
437
+ "cross-attention layers by setting `config.add_cross_attention=True`"
438
+ )
439
+ residual = hidden_states
440
+ hidden_states = self.ln_cross_attn(hidden_states)
441
+ cross_attn_outputs = self.crossattention(
442
+ hidden_states,
443
+ attention_mask=attention_mask,
444
+ head_mask=head_mask,
445
+ encoder_hidden_states=encoder_hidden_states,
446
+ encoder_attention_mask=encoder_attention_mask,
447
+ output_attentions=output_attentions,
448
+ )
449
+ attn_output = cross_attn_outputs[0]
450
+ if hasattr(self, "alpha_cattn"):
451
+ attn_output = torch.tanh(self.alpha_cattn) * attn_output
452
+ # residual connection
453
+ hidden_states = residual + attn_output
454
+
455
+ residual = hidden_states
456
+ hidden_states = self.ln_2_crossattention(hidden_states)
457
+ feed_forward_hidden_states = self.mlp_crossattention(hidden_states)
458
+ if hasattr(self, "alpha_dense"):
459
+ feed_forward_hidden_states = torch.tanh(self.alpha_dense) * feed_forward_hidden_states
460
+ # residual connection
461
+ hidden_states = residual + feed_forward_hidden_states
462
+
463
+ # Self-Attention
464
+ residual = hidden_states
465
+ hidden_states = self.ln_1(hidden_states)
466
+ attn_outputs = self.attn(
467
+ hidden_states,
468
+ layer_past=layer_past,
469
+ attention_mask=attention_mask,
470
+ head_mask=head_mask,
471
+ use_cache=use_cache,
472
+ output_attentions=output_attentions,
473
+ )
474
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
475
+ outputs = attn_outputs[1:]
476
+ # residual connection
477
+ hidden_states = attn_output + residual
478
+
479
+ # add cross attentions (follow the original order so as not to mess things up)
480
+ if encoder_hidden_states is not None and self.attn.layer_idx % self.add_cross_attention_freq == 0:
481
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
482
+
483
+ # FFN
484
+ residual = hidden_states
485
+ hidden_states = self.ln_2(hidden_states)
486
+ feed_forward_hidden_states = self.mlp(hidden_states)
487
+ # residual connection
488
+ hidden_states = residual + feed_forward_hidden_states
489
+
490
+ if use_cache:
491
+ outputs = (hidden_states,) + outputs
492
+ else:
493
+ outputs = (hidden_states,) + outputs[1:]
494
+
495
+ return outputs # hidden_states, present, (attentions, cross_attentions)
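The two `torch.tanh(alpha) * ...` terms above implement zero-initialized gating: because `alpha_cattn` and `alpha_dense` start at zero, a newly added cross-attention block (and its MLP) is initially a no-op and its contribution is learned gradually. A stripped-down illustration of the same pattern (hypothetical module, not the class above):

```python
import torch
import torch.nn as nn


class GatedResidual(nn.Module):
    # Residual branch scaled by tanh(alpha); contributes nothing at initialization.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.alpha = nn.Parameter(torch.zeros([]))

    def forward(self, x):
        return x + torch.tanh(self.alpha) * self.fn(x)


block = GatedResidual(nn.Linear(16, 16))
x = torch.randn(2, 4, 16)
print(torch.allclose(block(x), x))  # True: the gate is closed at initialization
```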
496
+
497
+
498
+ class GPT2PreTrainedModel(PreTrainedModel):
499
+ """
500
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
501
+ models.
502
+ """
503
+
504
+ config_class = GPT2Config
505
+ load_tf_weights = load_tf_weights_in_gpt2
506
+ base_model_prefix = "transformer"
507
+ is_parallelizable = True
508
+ supports_gradient_checkpointing = True
509
+
510
+ def __init__(self, *inputs, **kwargs):
511
+ super().__init__(*inputs, **kwargs)
512
+
513
+ def _init_weights(self, module):
514
+ """Initialize the weights."""
515
+ if isinstance(module, (nn.Linear, Conv1D)):
516
+ # Slightly different from the TF version which uses truncated_normal for initialization
517
+ # cf https://github.com/pytorch/pytorch/pull/5617
518
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
519
+ if module.bias is not None:
520
+ module.bias.data.zero_()
521
+ elif isinstance(module, nn.Embedding):
522
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
523
+ if module.padding_idx is not None:
524
+ module.weight.data[module.padding_idx].zero_()
525
+ elif isinstance(module, nn.LayerNorm):
526
+ module.bias.data.zero_()
527
+ module.weight.data.fill_(1.0)
528
+
529
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
530
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
531
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
532
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
533
+ #
534
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
535
+ for name, p in module.named_parameters():
536
+ if "c_proj" in name and "weight" in name:
537
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
538
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
539
+
540
+ def _set_gradient_checkpointing(self, module, value=False):
541
+ if isinstance(module, GPT2Model):
542
+ module.gradient_checkpointing = value
543
+
544
+
545
+ @dataclass
546
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
547
+ """
548
+ Base class for outputs of models predicting if two sentences are consecutive or not.
549
+
550
+ Args:
551
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
552
+ Language modeling loss.
553
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
554
+ Multiple choice classification loss.
555
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
556
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
557
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
558
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
559
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
560
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
561
+ sequence_length, embed_size_per_head)`).
562
+
563
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
564
+ `past_key_values` input) to speed up sequential decoding.
565
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
566
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
567
+ shape `(batch_size, sequence_length, hidden_size)`.
568
+
569
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
570
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
571
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
572
+ sequence_length)`.
573
+
574
+ GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
575
+ self-attention heads.
576
+ """
577
+
578
+ loss: Optional[torch.FloatTensor] = None
579
+ mc_loss: Optional[torch.FloatTensor] = None
580
+ logits: torch.FloatTensor = None
581
+ mc_logits: torch.FloatTensor = None
582
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
583
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
584
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
585
+
586
+
587
+ GPT2_START_DOCSTRING = r"""
588
+
589
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
590
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
591
+ etc.)
592
+
593
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
594
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
595
+ and behavior.
596
+
597
+ Parameters:
598
+ config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
599
+ Initializing with a config file does not load the weights associated with the model, only the
600
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
601
+ """
602
+
603
+ GPT2_INPUTS_DOCSTRING = r"""
604
+ Args:
605
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
606
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
607
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
608
+ sequence tokens in the vocabulary.
609
+
610
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
611
+ `input_ids`.
612
+
613
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
614
+ [`PreTrainedTokenizer.__call__`] for details.
615
+
616
+ [What are input IDs?](../glossary#input-ids)
617
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
618
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
619
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
620
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
621
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
622
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
623
+
624
+ - 1 for tokens that are **not masked**,
625
+ - 0 for tokens that are **masked**.
626
+
627
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
628
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
629
+ `len(past_key_values) + len(input_ids)`
630
+
631
+ [What are attention masks?](../glossary#attention-mask)
632
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
633
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
634
+ 1]`:
635
+
636
+ - 0 corresponds to a *sentence A* token,
637
+ - 1 corresponds to a *sentence B* token.
638
+
639
+ [What are token type IDs?](../glossary#token-type-ids)
640
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
641
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
642
+ config.max_position_embeddings - 1]`.
643
+
644
+ [What are position IDs?](../glossary#position-ids)
645
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
646
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
647
+
648
+ - 1 indicates the head is **not masked**,
649
+ - 0 indicates the head is **masked**.
650
+
651
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
652
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
653
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
654
+ model's internal embedding lookup matrix.
655
+
656
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
657
+ `past_key_values`).
658
+ use_cache (`bool`, *optional*):
659
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
660
+ `past_key_values`).
661
+ output_attentions (`bool`, *optional*):
662
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
663
+ tensors for more detail.
664
+ output_hidden_states (`bool`, *optional*):
665
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
666
+ more detail.
667
+ return_dict (`bool`, *optional*):
668
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
669
+ """
670
+ PARALLELIZE_DOCSTRING = r"""
671
+ This is an experimental feature and is subject to change at a moment's notice.
672
+
673
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
674
+ it will evenly distribute blocks across all devices.
675
+
676
+ Args:
677
+ device_map (`Dict[int, list]`, optional, defaults to None):
678
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
679
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
680
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
681
+ following number of attention modules:
682
+
683
+ - gpt2: 12
684
+ - gpt2-medium: 24
685
+ - gpt2-large: 36
686
+ - gpt2-xl: 48
687
+
688
+ Example:
689
+
690
+ ```python
691
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
692
+ model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
693
+ device_map = {
694
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
695
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
696
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
697
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
698
+ }
699
+ model.parallelize(device_map)
700
+ ```
701
+ """
702
+ DEPARALLELIZE_DOCSTRING = r"""
703
+ Moves the model to cpu from a model parallel state.
704
+
705
+ Example:
706
+
707
+ ```python
708
+ # On a 4 GPU machine with gpt2-large:
709
+ model = GPT2LMHeadModel.from_pretrained("gpt2-large")
710
+ device_map = {
711
+ 0: [0, 1, 2, 3, 4, 5, 6, 7],
712
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
713
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
714
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
715
+ }
716
+ model.parallelize(device_map) # Splits the model across several devices
717
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
718
+ ```
719
+ """
720
+
721
+
722
+ @add_start_docstrings(
723
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
724
+ GPT2_START_DOCSTRING,
725
+ )
726
+ class GPT2Model(GPT2PreTrainedModel):
727
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
728
+
729
+ def __init__(self, config):
730
+ super().__init__(config)
731
+
732
+ self.embed_dim = config.hidden_size
733
+
734
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
735
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
736
+
737
+ self.drop = nn.Dropout(config.embd_pdrop)
738
+ self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
739
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
740
+
741
+ # Model parallel
742
+ self.model_parallel = False
743
+ self.device_map = None
744
+ self.gradient_checkpointing = False
745
+
746
+ # Initialize weights and apply final processing
747
+ self.post_init()
748
+
749
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
750
+ def parallelize(self, device_map=None):
751
+ # Check validity of device_map
752
+ self.device_map = (
753
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
754
+ )
755
+ assert_device_map(self.device_map, len(self.h))
756
+ self.model_parallel = True
757
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
758
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
759
+ self.wte = self.wte.to(self.first_device)
760
+ self.wpe = self.wpe.to(self.first_device)
761
+ # Load onto devices
762
+ for k, v in self.device_map.items():
763
+ for block in v:
764
+ cuda_device = "cuda:" + str(k)
765
+ self.h[block] = self.h[block].to(cuda_device)
766
+ # ln_f to last
767
+ self.ln_f = self.ln_f.to(self.last_device)
768
+
769
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
770
+ def deparallelize(self):
771
+ self.model_parallel = False
772
+ self.device_map = None
773
+ self.first_device = "cpu"
774
+ self.last_device = "cpu"
775
+ self.wte = self.wte.to("cpu")
776
+ self.wpe = self.wpe.to("cpu")
777
+ for index in range(len(self.h)):
778
+ self.h[index] = self.h[index].to("cpu")
779
+ self.ln_f = self.ln_f.to("cpu")
780
+ torch.cuda.empty_cache()
781
+
782
+ def get_input_embeddings(self):
783
+ return self.wte
784
+
785
+ def set_input_embeddings(self, new_embeddings):
786
+ self.wte = new_embeddings
787
+
788
+ def _prune_heads(self, heads_to_prune):
789
+ """
790
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
791
+ """
792
+ for layer, heads in heads_to_prune.items():
793
+ self.h[layer].attn.prune_heads(heads)
794
+
795
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
796
+ @add_code_sample_docstrings(
797
+ processor_class=_TOKENIZER_FOR_DOC,
798
+ checkpoint=_CHECKPOINT_FOR_DOC,
799
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
800
+ config_class=_CONFIG_FOR_DOC,
801
+ )
802
+ def forward(
803
+ self,
804
+ input_ids: Optional[torch.LongTensor] = None,
805
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
806
+ attention_mask: Optional[torch.FloatTensor] = None,
807
+ token_type_ids: Optional[torch.LongTensor] = None,
808
+ position_ids: Optional[torch.LongTensor] = None,
809
+ head_mask: Optional[torch.FloatTensor] = None,
810
+ inputs_embeds: Optional[torch.FloatTensor] = None,
811
+ encoder_hidden_states: Optional[torch.Tensor] = None,
812
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
813
+ use_cache: Optional[bool] = None,
814
+ output_attentions: Optional[bool] = None,
815
+ output_hidden_states: Optional[bool] = None,
816
+ return_dict: Optional[bool] = None,
817
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
818
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
819
+ output_hidden_states = (
820
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
821
+ )
822
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
823
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
824
+
825
+ if input_ids is not None and inputs_embeds is not None:
826
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
827
+ elif input_ids is not None:
828
+ input_shape = input_ids.size()
829
+ input_ids = input_ids.view(-1, input_shape[-1])
830
+ batch_size = input_ids.shape[0]
831
+ elif inputs_embeds is not None:
832
+ input_shape = inputs_embeds.size()[:-1]
833
+ batch_size = inputs_embeds.shape[0]
834
+ else:
835
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
836
+
837
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
838
+
839
+ if token_type_ids is not None:
840
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
841
+ if position_ids is not None:
842
+ position_ids = position_ids.view(-1, input_shape[-1])
843
+
844
+ if past_key_values is None:
845
+ past_length = 0
846
+ past_key_values = tuple([None] * len(self.h))
847
+ else:
848
+ past_length = past_key_values[0][0].size(-2)
849
+ if position_ids is None:
850
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
851
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
852
+
853
+ # GPT2Attention mask.
854
+ if attention_mask is not None:
855
+ if batch_size <= 0:
856
+ raise ValueError("batch_size has to be defined and > 0")
857
+ attention_mask = attention_mask.view(batch_size, -1)
858
+ # We create a 3D attention mask from a 2D tensor mask.
859
+ # Sizes are [batch_size, 1, 1, to_seq_length]
860
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
861
+ # this attention mask is more simple than the triangular masking of causal attention
862
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
863
+ attention_mask = attention_mask[:, None, None, :]
864
+
865
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
866
+ # masked positions, this operation will create a tensor which is 0.0 for
867
+ # positions we want to attend and -10000.0 for masked positions.
868
+ # Since we are adding it to the raw scores before the softmax, this is
869
+ # effectively the same as removing these entirely.
870
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
871
+ attention_mask = (1.0 - attention_mask) * -10000.0
872
+
873
+ # If a 2D or 3D attention mask is provided for the cross-attention
874
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
875
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
876
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
877
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
878
+ if encoder_attention_mask is None:
879
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
880
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
881
+ else:
882
+ encoder_attention_mask = None
883
+
884
+ # Prepare head mask if needed
885
+ # 1.0 in head_mask indicate we keep the head
886
+ # attention_probs has shape bsz x n_heads x N x N
887
+ # head_mask has shape n_layer x batch x n_heads x N x N
888
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
889
+
890
+ if inputs_embeds is None:
891
+ inputs_embeds = self.wte(input_ids)
892
+ position_embeds = self.wpe(position_ids)
893
+ hidden_states = inputs_embeds + position_embeds
894
+
895
+ if token_type_ids is not None:
896
+ token_type_embeds = self.wte(token_type_ids)
897
+ hidden_states = hidden_states + token_type_embeds
898
+
899
+ hidden_states = self.drop(hidden_states)
900
+
901
+ output_shape = input_shape + (hidden_states.size(-1),)
902
+
903
+ presents = () if use_cache else None
904
+ all_self_attentions = () if output_attentions else None
905
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
906
+ all_hidden_states = () if output_hidden_states else None
907
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
908
+
909
+ # Model parallel
910
+ if self.model_parallel:
911
+ torch.cuda.set_device(hidden_states.device)
912
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
913
+ if layer_past is not None:
914
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
915
+ # Ensure that attention_mask is always on the same device as hidden_states
916
+ if attention_mask is not None:
917
+ attention_mask = attention_mask.to(hidden_states.device)
918
+ if isinstance(head_mask, torch.Tensor):
919
+ head_mask = head_mask.to(hidden_states.device)
920
+ if output_hidden_states:
921
+ all_hidden_states = all_hidden_states + (hidden_states,)
922
+
923
+ if self.gradient_checkpointing and self.training:
924
+
925
+ if use_cache:
926
+ logger.warning(
927
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
928
+ )
929
+ use_cache = False
930
+
931
+ def create_custom_forward(module):
932
+ def custom_forward(*inputs):
933
+ # None for past_key_value
934
+ return module(*inputs, use_cache, output_attentions)
935
+
936
+ return custom_forward
937
+
938
+ outputs = torch.utils.checkpoint.checkpoint(
939
+ create_custom_forward(block),
940
+ hidden_states,
941
+ None,
942
+ attention_mask,
943
+ head_mask[i],
944
+ encoder_hidden_states,
945
+ encoder_attention_mask,
946
+ )
947
+ else:
948
+ outputs = block(
949
+ hidden_states,
950
+ layer_past=layer_past,
951
+ attention_mask=attention_mask,
952
+ head_mask=head_mask[i],
953
+ encoder_hidden_states=encoder_hidden_states,
954
+ encoder_attention_mask=encoder_attention_mask,
955
+ use_cache=use_cache,
956
+ output_attentions=output_attentions,
957
+ )
958
+
959
+ hidden_states = outputs[0]
960
+ if use_cache is True:
961
+ presents = presents + (outputs[1],)
962
+
963
+ if output_attentions:
964
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
965
+ if self.config.add_cross_attention:
966
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
967
+
968
+ # Model Parallel: If it's the last layer for that device, put things on the next device
969
+ if self.model_parallel:
970
+ for k, v in self.device_map.items():
971
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
972
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
973
+
974
+ hidden_states = self.ln_f(hidden_states)
975
+
976
+ hidden_states = hidden_states.view(output_shape)
977
+ # Add last hidden state
978
+ if output_hidden_states:
979
+ all_hidden_states = all_hidden_states + (hidden_states,)
980
+
981
+ if not return_dict:
982
+ return tuple(
983
+ v
984
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
985
+ if v is not None
986
+ )
987
+
988
+ return BaseModelOutputWithPastAndCrossAttentions(
989
+ last_hidden_state=hidden_states,
990
+ past_key_values=presents,
991
+ hidden_states=all_hidden_states,
992
+ attentions=all_self_attentions,
993
+ cross_attentions=all_cross_attentions,
994
+ )
995
+
996
+
997
+ @add_start_docstrings(
998
+ """
999
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
1000
+ embeddings).
1001
+ """,
1002
+ GPT2_START_DOCSTRING,
1003
+ )
1004
+ class GPT2LMHeadModel(GPT2PreTrainedModel):
1005
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
1006
+
1007
+ def __init__(self, config):
1008
+ super().__init__(config)
1009
+ self.transformer = GPT2Model(config)
1010
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1011
+
1012
+ # Model parallel
1013
+ self.model_parallel = False
1014
+ self.device_map = None
1015
+
1016
+ # Initialize weights and apply final processing
1017
+ self.post_init()
1018
+
1019
+ def freeze_lm_weights(self):
1020
+ freeze_list, unfreeze_list = [], []
1021
+ for n, p in self.named_parameters():
1022
+ if 'crossattention' in n or 'cross_attn' in n or 'alpha_cattn' in n or 'alpha_dense' in n:
1023
+ p.requires_grad = True
1024
+ unfreeze_list.append(n)
1025
+ else:
1026
+ p.requires_grad = False
1027
+ freeze_list.append(n)
1028
+ print("Freeze the pretrained parts in LM: {}".format(freeze_list))
1029
+ print(" Learn the rest parts in LM: {}".format(unfreeze_list))
1030
+
1031
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1032
+ def parallelize(self, device_map=None):
1033
+ self.device_map = (
1034
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1035
+ if device_map is None
1036
+ else device_map
1037
+ )
1038
+ assert_device_map(self.device_map, len(self.transformer.h))
1039
+ self.transformer.parallelize(self.device_map)
1040
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1041
+ self.model_parallel = True
1042
+
1043
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1044
+ def deparallelize(self):
1045
+ self.transformer.deparallelize()
1046
+ self.transformer = self.transformer.to("cpu")
1047
+ self.lm_head = self.lm_head.to("cpu")
1048
+ self.model_parallel = False
1049
+ torch.cuda.empty_cache()
1050
+
1051
+ def get_output_embeddings(self):
1052
+ return self.lm_head
1053
+
1054
+ def set_output_embeddings(self, new_embeddings):
1055
+ self.lm_head = new_embeddings
1056
+
1057
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1058
+ token_type_ids = kwargs.get("token_type_ids", None)
1059
+ # only last token for inputs_ids if past is defined in kwargs
1060
+ if past:
1061
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1062
+ if token_type_ids is not None:
1063
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
1064
+
1065
+ attention_mask = kwargs.get("attention_mask", None)
1066
+ position_ids = kwargs.get("position_ids", None)
1067
+
1068
+ if attention_mask is not None and position_ids is None:
1069
+ # create position_ids on the fly for batch generation
1070
+ position_ids = attention_mask.long().cumsum(-1) - 1
1071
+ position_ids.masked_fill_(attention_mask == 0, 1)
1072
+ if past:
1073
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1074
+ else:
1075
+ position_ids = None
1076
+ return {
1077
+ "input_ids": input_ids,
1078
+ "past_key_values": past,
1079
+ "use_cache": kwargs.get("use_cache"),
1080
+ "position_ids": position_ids,
1081
+ "attention_mask": attention_mask,
1082
+ "token_type_ids": token_type_ids,
1083
+ }
1084
+
1085
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1086
+ @add_code_sample_docstrings(
1087
+ processor_class=_TOKENIZER_FOR_DOC,
1088
+ checkpoint=_CHECKPOINT_FOR_DOC,
1089
+ output_type=CausalLMOutputWithCrossAttentions,
1090
+ config_class=_CONFIG_FOR_DOC,
1091
+ )
1092
+ def forward(
1093
+ self,
1094
+ input_ids: Optional[torch.LongTensor] = None,
1095
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1096
+ attention_mask: Optional[torch.FloatTensor] = None,
1097
+ token_type_ids: Optional[torch.LongTensor] = None,
1098
+ position_ids: Optional[torch.LongTensor] = None,
1099
+ head_mask: Optional[torch.FloatTensor] = None,
1100
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1101
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1102
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1103
+ labels: Optional[torch.LongTensor] = None,
1104
+ use_cache: Optional[bool] = None,
1105
+ output_attentions: Optional[bool] = None,
1106
+ output_hidden_states: Optional[bool] = None,
1107
+ return_dict: Optional[bool] = None,
1108
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1109
+ r"""
1110
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1111
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1112
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
1113
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
1114
+ """
1115
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1116
+
1117
+ transformer_outputs = self.transformer(
1118
+ input_ids,
1119
+ past_key_values=past_key_values,
1120
+ attention_mask=attention_mask,
1121
+ token_type_ids=token_type_ids,
1122
+ position_ids=position_ids,
1123
+ head_mask=head_mask,
1124
+ inputs_embeds=inputs_embeds,
1125
+ encoder_hidden_states=encoder_hidden_states,
1126
+ encoder_attention_mask=encoder_attention_mask,
1127
+ use_cache=use_cache,
1128
+ output_attentions=output_attentions,
1129
+ output_hidden_states=output_hidden_states,
1130
+ return_dict=return_dict,
1131
+ )
1132
+ hidden_states = transformer_outputs[0]
1133
+
1134
+ # Set device for model parallelism
1135
+ if self.model_parallel:
1136
+ torch.cuda.set_device(self.transformer.first_device)
1137
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1138
+
1139
+ lm_logits = self.lm_head(hidden_states)
1140
+
1141
+ loss = None
1142
+ if labels is not None:
1143
+ # Shift so that tokens < n predict n
1144
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1145
+ shift_labels = labels[..., 1:].contiguous()
1146
+ # Flatten the tokens
1147
+ loss_fct = CrossEntropyLoss()
1148
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1149
+
1150
+ if not return_dict:
1151
+ output = (lm_logits,) + transformer_outputs[1:]
1152
+ return ((loss,) + output) if loss is not None else output
1153
+
1154
+ return CausalLMOutputWithCrossAttentions(
1155
+ loss=loss,
1156
+ logits=lm_logits,
1157
+ past_key_values=transformer_outputs.past_key_values,
1158
+ hidden_states=transformer_outputs.hidden_states,
1159
+ attentions=transformer_outputs.attentions,
1160
+ cross_attentions=transformer_outputs.cross_attentions,
1161
+ )
1162
+
1163
+ @staticmethod
1164
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
1165
+ """
1166
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1167
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1168
+ beam_idx at every generation step.
1169
+ """
1170
+ return tuple(
1171
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1172
+ for layer_past in past
1173
+ )
1174
+
1175
+
1176
+ @add_start_docstrings(
1177
+ """
1178
+ The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
1179
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1180
+ input embeddings, the classification head takes as input the input of a specified classification token index in the
1181
+ input sequence).
1182
+ """,
1183
+ GPT2_START_DOCSTRING,
1184
+ )
1185
+ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1186
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
1187
+
1188
+ def __init__(self, config):
1189
+ super().__init__(config)
1190
+ config.num_labels = 1
1191
+ self.transformer = GPT2Model(config)
1192
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1193
+ self.multiple_choice_head = SequenceSummary(config)
1194
+
1195
+ # Model parallel
1196
+ self.model_parallel = False
1197
+ self.device_map = None
1198
+
1199
+ # Initialize weights and apply final processing
1200
+ self.post_init()
1201
+
1202
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1203
+ def parallelize(self, device_map=None):
1204
+ self.device_map = (
1205
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1206
+ if device_map is None
1207
+ else device_map
1208
+ )
1209
+ assert_device_map(self.device_map, len(self.transformer.h))
1210
+ self.transformer.parallelize(self.device_map)
1211
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1212
+ self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
1213
+ self.model_parallel = True
1214
+
1215
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1216
+ def deparallelize(self):
1217
+ self.transformer.deparallelize()
1218
+ self.transformer = self.transformer.to("cpu")
1219
+ self.lm_head = self.lm_head.to("cpu")
1220
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1221
+ self.model_parallel = False
1222
+ torch.cuda.empty_cache()
1223
+
1224
+ def get_output_embeddings(self):
1225
+ return self.lm_head
1226
+
1227
+ def set_output_embeddings(self, new_embeddings):
1228
+ self.lm_head = new_embeddings
1229
+
1230
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1231
+ token_type_ids = kwargs.get("token_type_ids", None)
1232
+ # only last token for inputs_ids if past is defined in kwargs
1233
+ if past:
1234
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1235
+ if token_type_ids is not None:
1236
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
1237
+
1238
+ attention_mask = kwargs.get("attention_mask", None)
1239
+ position_ids = kwargs.get("position_ids", None)
1240
+
1241
+ if attention_mask is not None and position_ids is None:
1242
+ # create position_ids on the fly for batch generation
1243
+ position_ids = attention_mask.long().cumsum(-1) - 1
1244
+ position_ids.masked_fill_(attention_mask == 0, 1)
1245
+ if past:
1246
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1247
+ else:
1248
+ position_ids = None
1249
+
1250
+ return {
1251
+ "input_ids": input_ids,
1252
+ "past_key_values": past,
1253
+ "use_cache": kwargs.get("use_cache"),
1254
+ "position_ids": position_ids,
1255
+ "attention_mask": attention_mask,
1256
+ "token_type_ids": token_type_ids,
1257
+ }
1258
+
1259
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1260
+ @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
1261
+ def forward(
1262
+ self,
1263
+ input_ids: Optional[torch.LongTensor] = None,
1264
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1265
+ attention_mask: Optional[torch.FloatTensor] = None,
1266
+ token_type_ids: Optional[torch.LongTensor] = None,
1267
+ position_ids: Optional[torch.LongTensor] = None,
1268
+ head_mask: Optional[torch.FloatTensor] = None,
1269
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1270
+ mc_token_ids: Optional[torch.LongTensor] = None,
1271
+ labels: Optional[torch.LongTensor] = None,
1272
+ mc_labels: Optional[torch.LongTensor] = None,
1273
+ use_cache: Optional[bool] = None,
1274
+ output_attentions: Optional[bool] = None,
1275
+ output_hidden_states: Optional[bool] = None,
1276
+ return_dict: Optional[bool] = None,
1277
+ **kwargs,
1278
+ ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
1279
+ r"""
1280
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
1281
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1282
+ 1[`.
1283
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1284
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1285
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size - 1]` All labels set to
1286
+ `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
1287
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
1288
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
1289
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
1290
+
1291
+ Return:
1292
+
1293
+ Example:
1294
+
1295
+ ```python
1296
+ >>> import torch
1297
+ >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
1298
+
1299
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
1300
+ >>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
1301
+
1302
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
1303
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
1304
+ >>> # Update the model embeddings with the new vocabulary size
1305
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
1306
+
1307
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1308
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1309
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1310
+
1311
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1312
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1313
+
1314
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1315
+ >>> lm_logits = outputs.logits
1316
+ >>> mc_logits = outputs.mc_logits
1317
+ ```"""
1318
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1319
+
1320
+ transformer_outputs = self.transformer(
1321
+ input_ids,
1322
+ past_key_values=past_key_values,
1323
+ attention_mask=attention_mask,
1324
+ token_type_ids=token_type_ids,
1325
+ position_ids=position_ids,
1326
+ head_mask=head_mask,
1327
+ inputs_embeds=inputs_embeds,
1328
+ use_cache=use_cache,
1329
+ output_attentions=output_attentions,
1330
+ output_hidden_states=output_hidden_states,
1331
+ return_dict=return_dict,
1332
+ )
1333
+
1334
+ hidden_states = transformer_outputs[0]
1335
+
1336
+ # Set device for model parallelism
1337
+ if self.model_parallel:
1338
+ torch.cuda.set_device(self.transformer.first_device)
1339
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1340
+
1341
+ lm_logits = self.lm_head(hidden_states)
1342
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1343
+
1344
+ mc_loss = None
1345
+ if mc_labels is not None:
1346
+ loss_fct = CrossEntropyLoss()
1347
+ mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
1348
+ lm_loss = None
1349
+ if labels is not None:
1350
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1351
+ shift_labels = labels[..., 1:].contiguous()
1352
+ loss_fct = CrossEntropyLoss()
1353
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1354
+
1355
+ if not return_dict:
1356
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
1357
+ if mc_loss is not None:
1358
+ output = (mc_loss,) + output
1359
+ return ((lm_loss,) + output) if lm_loss is not None else output
1360
+
1361
+ return GPT2DoubleHeadsModelOutput(
1362
+ loss=lm_loss,
1363
+ mc_loss=mc_loss,
1364
+ logits=lm_logits,
1365
+ mc_logits=mc_logits,
1366
+ past_key_values=transformer_outputs.past_key_values,
1367
+ hidden_states=transformer_outputs.hidden_states,
1368
+ attentions=transformer_outputs.attentions,
1369
+ )
1370
+
1371
+ @staticmethod
1372
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
1373
+ """
1374
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1375
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1376
+ beam_idx at every generation step.
1377
+ """
1378
+ return tuple(
1379
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1380
+ for layer_past in past
1381
+ )
1382
+
1383
+
1384
+ @add_start_docstrings(
1385
+ """
1386
+ The GPT2 Model transformer with a sequence classification head on top (linear layer).
1387
+
1388
+ [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1389
+ (e.g. GPT-1) do.
1390
+
1391
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1392
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1393
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1394
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1395
+ each row of the batch).
1396
+ """,
1397
+ GPT2_START_DOCSTRING,
1398
+ )
1399
+ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1400
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
1401
+
1402
+ def __init__(self, config):
1403
+ super().__init__(config)
1404
+ self.num_labels = config.num_labels
1405
+ self.transformer = GPT2Model(config)
1406
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1407
+
1408
+ # Model parallel
1409
+ self.model_parallel = False
1410
+ self.device_map = None
1411
+
1412
+ # Initialize weights and apply final processing
1413
+ self.post_init()
1414
+
1415
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1416
+ @add_code_sample_docstrings(
1417
+ processor_class=_TOKENIZER_FOR_DOC,
1418
+ checkpoint="microsoft/DialogRPT-updown",
1419
+ output_type=SequenceClassifierOutputWithPast,
1420
+ config_class=_CONFIG_FOR_DOC,
1421
+ expected_output="'LABEL_0'",
1422
+ expected_loss=5.28,
1423
+ )
1424
+ def forward(
1425
+ self,
1426
+ input_ids: Optional[torch.LongTensor] = None,
1427
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1428
+ attention_mask: Optional[torch.FloatTensor] = None,
1429
+ token_type_ids: Optional[torch.LongTensor] = None,
1430
+ position_ids: Optional[torch.LongTensor] = None,
1431
+ head_mask: Optional[torch.FloatTensor] = None,
1432
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1433
+ labels: Optional[torch.LongTensor] = None,
1434
+ use_cache: Optional[bool] = None,
1435
+ output_attentions: Optional[bool] = None,
1436
+ output_hidden_states: Optional[bool] = None,
1437
+ return_dict: Optional[bool] = None,
1438
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1439
+ r"""
1440
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1441
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1442
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1443
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1444
+ """
1445
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1446
+
1447
+ transformer_outputs = self.transformer(
1448
+ input_ids,
1449
+ past_key_values=past_key_values,
1450
+ attention_mask=attention_mask,
1451
+ token_type_ids=token_type_ids,
1452
+ position_ids=position_ids,
1453
+ head_mask=head_mask,
1454
+ inputs_embeds=inputs_embeds,
1455
+ use_cache=use_cache,
1456
+ output_attentions=output_attentions,
1457
+ output_hidden_states=output_hidden_states,
1458
+ return_dict=return_dict,
1459
+ )
1460
+ hidden_states = transformer_outputs[0]
1461
+ logits = self.score(hidden_states)
1462
+
1463
+ if input_ids is not None:
1464
+ batch_size, sequence_length = input_ids.shape[:2]
1465
+ else:
1466
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1467
+
1468
+ assert (
1469
+ self.config.pad_token_id is not None or batch_size == 1
1470
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
1471
+ if self.config.pad_token_id is None:
1472
+ sequence_lengths = -1
1473
+ else:
1474
+ if input_ids is not None:
1475
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1476
+ else:
1477
+ sequence_lengths = -1
1478
+ logger.warning(
1479
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1480
+ f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1481
+ )
1482
+
1483
+ pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
1484
+
1485
+ loss = None
1486
+ if labels is not None:
1487
+ if self.config.problem_type is None:
1488
+ if self.num_labels == 1:
1489
+ self.config.problem_type = "regression"
1490
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1491
+ self.config.problem_type = "single_label_classification"
1492
+ else:
1493
+ self.config.problem_type = "multi_label_classification"
1494
+
1495
+ if self.config.problem_type == "regression":
1496
+ loss_fct = MSELoss()
1497
+ if self.num_labels == 1:
1498
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1499
+ else:
1500
+ loss = loss_fct(pooled_logits, labels)
1501
+ elif self.config.problem_type == "single_label_classification":
1502
+ loss_fct = CrossEntropyLoss()
1503
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1504
+ elif self.config.problem_type == "multi_label_classification":
1505
+ loss_fct = BCEWithLogitsLoss()
1506
+ loss = loss_fct(pooled_logits, labels)
1507
+ if not return_dict:
1508
+ output = (pooled_logits,) + transformer_outputs[1:]
1509
+ return ((loss,) + output) if loss is not None else output
1510
+
1511
+ return SequenceClassifierOutputWithPast(
1512
+ loss=loss,
1513
+ logits=pooled_logits,
1514
+ past_key_values=transformer_outputs.past_key_values,
1515
+ hidden_states=transformer_outputs.hidden_states,
1516
+ attentions=transformer_outputs.attentions,
1517
+ )
1518
+
1519
+
1520
+ @add_start_docstrings(
1521
+ """
1522
+ GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1523
+ Named-Entity-Recognition (NER) tasks.
1524
+ """,
1525
+ GPT2_START_DOCSTRING,
1526
+ )
1527
+ class GPT2ForTokenClassification(GPT2PreTrainedModel):
1528
+ def __init__(self, config):
1529
+ super().__init__(config)
1530
+ self.num_labels = config.num_labels
1531
+
1532
+ self.transformer = GPT2Model(config)
1533
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1534
+ classifier_dropout = config.classifier_dropout
1535
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1536
+ classifier_dropout = config.hidden_dropout
1537
+ else:
1538
+ classifier_dropout = 0.1
1539
+ self.dropout = nn.Dropout(classifier_dropout)
1540
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1541
+
1542
+ # Model parallel
1543
+ self.model_parallel = False
1544
+ self.device_map = None
1545
+
1546
+ # Initialize weights and apply final processing
1547
+ self.post_init()
1548
+
1549
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1550
+ # fmt: off
1551
+ @add_code_sample_docstrings(
1552
+ processor_class=_TOKENIZER_FOR_DOC,
1553
+ checkpoint="brad1141/gpt2-finetuned-comp2",
1554
+ output_type=TokenClassifierOutput,
1555
+ config_class=_CONFIG_FOR_DOC,
1556
+ expected_loss=0.25,
1557
+ expected_output=["Lead", "Lead", "Lead", "Position", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead"],
1558
+ )
1559
+ # fmt: on
1560
+ def forward(
1561
+ self,
1562
+ input_ids: Optional[torch.LongTensor] = None,
1563
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1564
+ attention_mask: Optional[torch.FloatTensor] = None,
1565
+ token_type_ids: Optional[torch.LongTensor] = None,
1566
+ position_ids: Optional[torch.LongTensor] = None,
1567
+ head_mask: Optional[torch.FloatTensor] = None,
1568
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1569
+ labels: Optional[torch.LongTensor] = None,
1570
+ use_cache: Optional[bool] = None,
1571
+ output_attentions: Optional[bool] = None,
1572
+ output_hidden_states: Optional[bool] = None,
1573
+ return_dict: Optional[bool] = None,
1574
+ ) -> Union[Tuple, TokenClassifierOutput]:
1575
+ r"""
1576
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1577
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1578
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1579
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1580
+ """
1581
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1582
+
1583
+ transformer_outputs = self.transformer(
1584
+ input_ids,
1585
+ past_key_values=past_key_values,
1586
+ attention_mask=attention_mask,
1587
+ token_type_ids=token_type_ids,
1588
+ position_ids=position_ids,
1589
+ head_mask=head_mask,
1590
+ inputs_embeds=inputs_embeds,
1591
+ use_cache=use_cache,
1592
+ output_attentions=output_attentions,
1593
+ output_hidden_states=output_hidden_states,
1594
+ return_dict=return_dict,
1595
+ )
1596
+
1597
+ hidden_states = transformer_outputs[0]
1598
+ hidden_states = self.dropout(hidden_states)
1599
+ logits = self.classifier(hidden_states)
1600
+
1601
+ loss = None
1602
+ if labels is not None:
1603
+ loss_fct = CrossEntropyLoss()
1604
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1605
+
1606
+ if not return_dict:
1607
+ output = (logits,) + transformer_outputs[2:]
1608
+ return ((loss,) + output) if loss is not None else output
1609
+
1610
+ return TokenClassifierOutput(
1611
+ loss=loss,
1612
+ logits=logits,
1613
+ hidden_states=transformer_outputs.hidden_states,
1614
+ attentions=transformer_outputs.attentions,
1615
+ )
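As a rough sketch of how `freeze_lm_weights` above behaves: it leaves trainable only the parameters whose names contain `crossattention`, `cross_attn`, `alpha_cattn` or `alpha_dense` (the cross-attention and gating weights added for the narrator) and freezes everything pretrained. The snippet below replays that selection rule on a tiny stock `transformers` GPT-2, purely as an assumption so the example is self-contained; the `alpha_*` gates themselves only exist in the gated copy in this file, and the real config is prepared elsewhere (see `augment_gpt2_config` in `lavila/models/models.py`).

```python
# Sketch: the parameter-selection rule used by freeze_lm_weights(), run on a
# tiny stock GPT-2 so it executes stand-alone (assumed stand-in for the gated copy).
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(n_layer=2, n_head=4, n_embd=128, add_cross_attention=True)
lm = GPT2LMHeadModel(config)

keep = ('crossattention', 'cross_attn', 'alpha_cattn', 'alpha_dense')
for name, param in lm.named_parameters():
    # train only the cross-attention (and, in the gated copy, the tanh gates)
    param.requires_grad = any(k in name for k in keep)

trainable = sum(p.numel() for p in lm.parameters() if p.requires_grad)
total = sum(p.numel() for p in lm.parameters())
print(f"trainable: {trainable}/{total} parameters")
```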
lavila/models/loss.py ADDED
@@ -0,0 +1,367 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+ import torch.distributed.nn
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .distributed_utils import gather_from_all
16
+
17
+
18
+ def gather_features(
19
+ image_features,
20
+ text_features,
21
+ local_loss=False,
22
+ gather_with_grad=False,
23
+ rank=0,
24
+ world_size=1,
25
+ ):
26
+ # Adapted from: https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py
27
+ # We gather tensors from all gpus
28
+ if gather_with_grad:
29
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
30
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
31
+ else:
32
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
33
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
34
+ dist.all_gather(gathered_image_features, image_features)
35
+ dist.all_gather(gathered_text_features, text_features)
36
+ if not local_loss:
37
+ # ensure grads for local rank when all_* features don't have a gradient
38
+ gathered_image_features[rank] = image_features
39
+ gathered_text_features[rank] = text_features
40
+ all_image_features = torch.cat(gathered_image_features, dim=0)
41
+ all_text_features = torch.cat(gathered_text_features, dim=0)
42
+
43
+ return all_image_features, all_text_features
44
+
45
+
46
+ class CLIPLoss(nn.Module):
47
+
48
+ def __init__(
49
+ self,
50
+ use_vissl=False,
51
+ local_loss=False,
52
+ gather_with_grad=False,
53
+ cache_labels=False,
54
+ rank=0,
55
+ world_size=1,
56
+ ):
57
+ super().__init__()
58
+ self.use_vissl = use_vissl
59
+ self.local_loss = local_loss
60
+ self.gather_with_grad = gather_with_grad
61
+ self.cache_labels = cache_labels
62
+ self.rank = rank
63
+ self.world_size = world_size
64
+
65
+ # cache state
66
+ self.prev_num_logits = 0
67
+ self.labels = {}
68
+
69
+ def forward(self, outputs):
70
+ image_features = outputs['image_embed']
71
+ text_features = outputs['text_embed']
72
+ logit_scale = outputs['logit_scale']
73
+ device = image_features.device
74
+ if self.world_size > 1:
75
+ if self.use_vissl:
76
+ all_image_features = gather_from_all(image_features)
77
+ all_text_features = gather_from_all(text_features)
78
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
79
+ logits_per_text = logits_per_image.T
80
+ else:
81
+ all_image_features, all_text_features = gather_features(
82
+ image_features, text_features,
83
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size)
84
+
85
+ if self.local_loss:
86
+ logits_per_image = logit_scale * image_features @ all_text_features.T
87
+ logits_per_text = logit_scale * text_features @ all_image_features.T
88
+ else:
89
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
90
+ logits_per_text = logits_per_image.T
91
+ else:
92
+ logits_per_image = logit_scale * image_features @ text_features.T
93
+ logits_per_text = logit_scale * text_features @ image_features.T
94
+
95
+ # calculate ground-truth labels and cache them if enabled
96
+ num_logits = logits_per_image.shape[0]
97
+ if self.prev_num_logits != num_logits or device not in self.labels:
98
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
99
+ if self.world_size > 1 and self.local_loss:
100
+ labels = labels + num_logits * self.rank
101
+ if self.cache_labels:
102
+ self.labels[device] = labels
103
+ self.prev_num_logits = num_logits
104
+ else:
105
+ labels = self.labels[device]
106
+
107
+ loss = (
108
+ F.cross_entropy(logits_per_image, labels) +
109
+ F.cross_entropy(logits_per_text, labels)
110
+ ) / 2
111
+
112
+ # compute accuracy
113
+ with torch.no_grad():
114
+ pred = torch.argmax(logits_per_image, dim=-1)
115
+ correct = pred.eq(labels).sum()
116
+ acc = 100 * correct / logits_per_image.size(0)
117
+
118
+ return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc}
119
+
120
+
121
+ class SSLCLIPLoss(nn.Module):
122
+
123
+ def __init__(
124
+ self,
125
+ use_vissl=False,
126
+ local_loss=False,
127
+ gather_with_grad=False,
128
+ cache_labels=False,
129
+ rank=0,
130
+ world_size=1,
131
+ scale_init=0.08,
132
+ freeze_scale=False,
133
+ ):
134
+ super().__init__()
135
+ self.use_vissl = use_vissl
136
+ self.local_loss = local_loss
137
+ self.gather_with_grad = gather_with_grad
138
+ self.cache_labels = cache_labels
139
+ self.rank = rank
140
+ self.world_size = world_size
141
+ self.logit_scale_pseudo = nn.Parameter(torch.ones([]) * np.log(1 / scale_init))
142
+ if freeze_scale:
143
+ self.logit_scale_pseudo.requires_grad = False
144
+
145
+ # cache state
146
+ self.prev_num_logits = 0
147
+ self.labels = {}
148
+
149
+ def forward(self, outputs, gt_indicators):
150
+ image_features = outputs['image_embed']
151
+ text_features = outputs['text_embed']
152
+ logit_scale = outputs['logit_scale']
153
+ logit_scale_pseudo = self.logit_scale_pseudo.exp()
154
+ device = image_features.device
155
+ if self.world_size > 1:
156
+ if self.use_vissl:
157
+ all_image_features = gather_from_all(image_features)
158
+ all_text_features = gather_from_all(text_features)
159
+ all_gt_indicators = gather_from_all(gt_indicators)
160
+ num = all_gt_indicators.shape[0]
161
+ mask = all_gt_indicators.repeat(num, 1) + all_gt_indicators.repeat(num, 1).T
162
+ logit_scale_mat = torch.ones((num, num), device=device)
163
+ logit_scale_mat[mask == 0] = logit_scale_pseudo
164
+ logit_scale_mat[mask == 1] = torch.sqrt(logit_scale_pseudo * logit_scale)
165
+ logit_scale_mat[mask == 2] = logit_scale
166
+ logits_per_image = logit_scale_mat * (all_image_features @ all_text_features.T)
167
+ logits_per_text = logits_per_image.T
168
+ else:
169
+ raise NotImplementedError
170
+ else:
171
+ all_gt_indicators = gt_indicators
172
+ num = gt_indicators.shape[0]
173
+ mask = gt_indicators.repeat(num, 1) + gt_indicators.repeat(num, 1).T
174
+ logit_scale_mat = torch.ones((num, num), device=device)
175
+ logit_scale_mat[mask == 0] = logit_scale_pseudo
176
+ logit_scale_mat[mask == 1] = torch.sqrt(logit_scale_pseudo * logit_scale)
177
+ logit_scale_mat[mask == 2] = logit_scale
178
+ logits_per_image = logit_scale_mat * (image_features @ text_features.T)
179
+ logits_per_text = logit_scale_mat * (text_features @ image_features.T)
180
+
181
+ # calculate ground-truth labels and cache them if enabled
182
+ num_logits = logits_per_image.shape[0]
183
+ if self.prev_num_logits != num_logits or device not in self.labels:
184
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
185
+ if self.world_size > 1 and self.local_loss:
186
+ labels = labels + num_logits * self.rank
187
+ if self.cache_labels:
188
+ self.labels[device] = labels
189
+ self.prev_num_logits = num_logits
190
+ else:
191
+ labels = self.labels[device]
192
+
193
+ loss = (
194
+ F.cross_entropy(logits_per_image, labels) +
195
+ F.cross_entropy(logits_per_text, labels)
196
+ ) / 2
197
+
198
+ # compute accuracy
199
+ with torch.no_grad():
200
+ pred = torch.argmax(logits_per_image, dim=-1)
201
+ correct = pred.eq(labels).sum()
202
+ acc = 100 * correct / logits_per_image.size(0)
203
+ pred_gt = pred[all_gt_indicators == 1]
204
+ labels_gt = labels[all_gt_indicators == 1]
205
+ pred_pseudo = pred[all_gt_indicators == 0]
206
+ labels_pseudo = labels[all_gt_indicators == 0]
207
+ num_gt = pred_gt.shape[0]
208
+ num_pseudo = pred_pseudo.shape[0]
209
+ correct_gt = pred_gt.eq(labels_gt).sum()
210
+ correct_pseudo = pred_pseudo.eq(labels_pseudo).sum()
211
+ acc_gt = 100 * correct_gt / num_gt
212
+ acc_pseudo = 100 * correct_pseudo / num_pseudo
213
+
214
+ return {
215
+ 'loss': loss, 'clip_loss': loss, 'num_gt': torch.tensor([num_gt]), 'num_pseudo': torch.tensor([num_pseudo]),
216
+ 'clip_acc': acc, 'clip_acc_gt': acc_gt, 'clip_acc_pseudo': acc_pseudo
217
+ }
218
+
219
+
220
+ class CaptionLoss(nn.Module):
221
+ def __init__(self, pad_id=0, tokenizer=None):
222
+ super().__init__()
223
+ self.pad_id = pad_id
224
+ self.tokenizer = tokenizer
225
+ self.pad_id = tokenizer.pad_token_id if tokenizer is not None else pad_id  # fall back to the pad_id argument when no tokenizer is given
226
+
227
+ def forward(self, outputs):
228
+ logits = outputs['text_tokens_logits']
229
+ labels = outputs['labels']
230
+ # loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id)
231
+ loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id, reduction='none')
232
+
233
+ # compute accuracy
234
+ with torch.no_grad():
235
+ correct = 0.
236
+ total = 0.
237
+ ppls = []
238
+ for i in range(logits.size(0)):
239
+ pred = torch.argmax(logits[i], dim=0)
240
+ nopad = labels[i].ne(self.pad_id)
241
+ correct += (pred.eq(labels[i]) & nopad).sum()
242
+ total += nopad.sum()
243
+ ppl = torch.exp(loss[i].sum() / nopad.sum())
244
+ ppls.append(ppl)
245
+ # TODO: for debug only
246
+ # sep_pos = labels[i].tolist().index(self.tokenizer.tokenizer.sep_token_id)
247
+ # if self.tokenizer is not None:
248
+ # print('{} {} {}'.format(
249
+ # i, self.tokenizer.tokenizer.convert_ids_to_tokens(pred[:sep_pos]),
250
+ # self.tokenizer.tokenizer.convert_ids_to_tokens(labels[i, :sep_pos]),
251
+ # ))
252
+ acc = 100 * correct / (total + 1e-8)
253
+ return {'loss': loss.mean(), 'caption_loss': loss.mean(), 'caption_acc': acc, 'ppl': torch.tensor(ppls).mean()}
254
+
255
+
256
+ def sim_matrix(a, b, eps=1e-8):
257
+ """
258
+ added eps for numerical stability
259
+ """
260
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
261
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
262
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
263
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
264
+ return sim_mt
265
+
266
+
267
+ class MaxMarginRankingLoss(nn.Module):
268
+
269
+ def __init__(self, margin=0.2, fix_norm=True):
270
+ super().__init__()
271
+ self.fix_norm = fix_norm
272
+ self.loss = nn.MarginRankingLoss(margin)
273
+ self.margin = margin
274
+
275
+ def forward(self, outputs, weight=None):
276
+ image_features = outputs['image_embed']
277
+ text_features = outputs['text_embed']
278
+
279
+ all_image_features = gather_from_all(image_features)
280
+ all_text_features = gather_from_all(text_features)
281
+ x = sim_matrix(all_text_features, all_image_features)
282
+
283
+ n = x.size()[0]
284
+
285
+ x1 = torch.diag(x)
286
+ x1 = x1.unsqueeze(1)
287
+ x1 = x1.expand(n, n)
288
+ x1 = x1.contiguous().view(-1, 1)
289
+ x1 = torch.cat((x1, x1), 0)
290
+
291
+ x2 = x.view(-1, 1)
292
+ x3 = x.transpose(0, 1).contiguous().view(-1, 1)
293
+
294
+ x2 = torch.cat((x2, x3), 0)
295
+ max_margin = F.relu(self.margin - (x1 - x2))
296
+
297
+ if self.fix_norm:
298
+ # remove the elements from the diagonal
299
+ keep = torch.ones(x.shape) - torch.eye(x.shape[0]) # 128 x 128
300
+ keep1 = keep.view(-1, 1)
301
+ keep2 = keep.transpose(0, 1).contiguous().view(-1, 1)
302
+ keep_idx = torch.nonzero(torch.cat((keep1, keep2), 0).flatten()).flatten()
303
+ if x1.is_cuda:
304
+ keep_idx = keep_idx.cuda()
305
+ x1_ = torch.index_select(x1, dim=0, index=keep_idx)
306
+ x2_ = torch.index_select(x2, dim=0, index=keep_idx)
307
+ max_margin = F.relu(self.margin - (x1_ - x2_))
308
+
309
+ return {
310
+ 'loss': max_margin.mean(),
311
+ 'max_margin_loss': max_margin.mean()
312
+ }
313
+
314
+
315
+ class AdaptiveMaxMarginRankingLoss(nn.Module):
316
+
317
+ def __init__(self, margin=0.4, fix_norm=True):
318
+ super().__init__()
319
+ self.fix_norm = fix_norm
320
+ self.loss = nn.MarginRankingLoss(margin)
321
+ self.margin = margin
322
+
323
+ def forward(self, outputs, weight=None):
324
+ image_features = outputs['image_embed']
325
+ text_features = outputs['text_embed']
326
+
327
+ all_image_features = gather_from_all(image_features)
328
+ all_text_features = gather_from_all(text_features)
329
+ all_weights = gather_from_all(weight)
330
+ x = sim_matrix(all_text_features, all_image_features)
331
+
332
+ n = x.size()[0]
333
+
334
+ x1 = torch.diag(x)
335
+ x1 = x1.unsqueeze(1)
336
+ x1 = x1.expand(n, n)
337
+ x1 = x1.contiguous().view(-1, 1)
338
+ x1 = torch.cat((x1, x1), 0)
339
+
340
+ w1 = all_weights.unsqueeze(1)
341
+ w1 = w1.expand(n, n)
342
+ w1 = w1.contiguous().view(-1, 1)
343
+ w1 = torch.cat((w1, w1), 0)
344
+
345
+ x2 = x.view(-1, 1)
346
+ x3 = x.transpose(0, 1).contiguous().view(-1, 1)
347
+
348
+ x2 = torch.cat((x2, x3), 0)
349
+ max_margin = F.relu(w1 * self.margin - (x1 - x2))
350
+
351
+ if self.fix_norm:
352
+ # remove the elements from the diagonal
353
+ keep = torch.ones(x.shape) - torch.eye(x.shape[0]) # 128 x 128
354
+ keep1 = keep.view(-1, 1)
355
+ keep2 = keep.transpose(0, 1).contiguous().view(-1, 1)
356
+ keep_idx = torch.nonzero(torch.cat((keep1, keep2), 0).flatten()).flatten()
357
+ if x1.is_cuda:
358
+ keep_idx = keep_idx.cuda()
359
+ x1_ = torch.index_select(x1, dim=0, index=keep_idx)
360
+ w1_ = torch.index_select(w1, dim=0, index=keep_idx)
361
+ x2_ = torch.index_select(x2, dim=0, index=keep_idx)
362
+ max_margin = F.relu(w1_ * self.margin - (x1_ - x2_))
363
+
364
+ return {
365
+ 'loss': max_margin.mean(),
366
+ 'max_margin_loss': max_margin.mean()
367
+ }
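As a rough sketch of the interface above: `CLIPLoss` (like the other losses in this file) consumes the dictionary produced by the models in `lavila/models/models.py`, with keys `image_embed`, `text_embed` and `logit_scale`, and returns a dictionary of loss terms and accuracies. A minimal single-process example, with random embeddings standing in for real model outputs (an assumption for illustration only):

```python
# Sketch: feeding CLIPLoss the model-style output dict on a single process
# (world_size=1, so no all_gather is involved). Embeddings are random stand-ins.
import torch
import torch.nn.functional as F

from lavila.models.loss import CLIPLoss

batch, dim = 8, 256
outputs = {
    'image_embed': F.normalize(torch.randn(batch, dim), dim=-1),
    'text_embed': F.normalize(torch.randn(batch, dim), dim=-1),
    'logit_scale': torch.tensor(1 / 0.07),  # the models pass logit_scale.exp(), i.e. ~1/temperature
}

criterion = CLIPLoss(world_size=1)
metrics = criterion(outputs)
print(metrics['loss'].item(), metrics['clip_acc'].item())
```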
lavila/models/models.py ADDED
@@ -0,0 +1,1218 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import timm
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from transformers import DistilBertModel, GPT2LMHeadModel
13
+
14
+ import lavila.models.loss as loss
15
+ from lavila.models.gpt2_gated import GPT2LMHeadModel as GatedGPT2LMHeadModel
16
+ from lavila.models.gpt2_gated import augment_gpt2_config
17
+ from lavila.models.narrator import VCLM_HF
18
+ from lavila.models.openai_clip import load as load_openai_clip
19
+ from lavila.models.openai_model import QuickGELU, Transformer
20
+ from lavila.models.timesformer import SpaceTimeTransformer
21
+ from lavila.models.utils import remap_keys, rsetattr
22
+
23
+
24
+ class VideoClassifier(nn.Module):
25
+ def __init__(self,
26
+ vision_model: nn.Module,
27
+ dropout: float,
28
+ num_classes: int,
29
+ **kwargs,
30
+ ):
31
+ super().__init__()
32
+ self.visual = vision_model
33
+ self.dropout = nn.Dropout(dropout)
34
+ self.fc_cls = nn.Linear(vision_model.num_features, num_classes, bias=True)
35
+
36
+ self.fc_cls.weight.data.normal_(mean=0.0, std=0.01)
37
+ self.fc_cls.bias.data.zero_()
38
+
39
+ def forward(self, image, use_checkpoint=False):
40
+ image_embed = self.visual(image, use_checkpoint=use_checkpoint)
41
+ if isinstance(image_embed, list):
42
+ assert len(image_embed) == 1
43
+ image_embed = image_embed[0]
44
+ logit = self.fc_cls(self.dropout(image_embed))
45
+ return logit
46
+
47
+
48
+ class VideoClassifierMultiHead(nn.Module):
49
+ def __init__(self,
50
+ vision_model: nn.Module,
51
+ dropout: float,
52
+ num_classes_list: list,
53
+ **kwargs,
54
+ ):
55
+ super().__init__()
56
+ self.visual = vision_model
57
+ self.dropout = nn.Dropout(dropout)
58
+ self.fc_cls = nn.ModuleList(
59
+ [nn.Linear(vision_model.num_features, num_classes, bias=True) for num_classes in num_classes_list]
60
+ )
61
+
62
+ for m in self.fc_cls:
63
+ m.weight.data.normal_(mean=0.0, std=0.01)
64
+ m.bias.data.zero_()
65
+
66
+ def forward(self, image, use_checkpoint=False):
67
+ image_embed = self.visual(image, use_checkpoint=use_checkpoint)
68
+ if isinstance(image_embed, list):
69
+ assert len(image_embed) == 1
70
+ image_embed = image_embed[0]
71
+ logit_list = [m(self.dropout(image_embed)) for m in self.fc_cls]
72
+ return logit_list
73
+
74
+
75
+ class CLIP(nn.Module):
76
+ def __init__(self,
77
+ embed_dim: int,
78
+ # vision
79
+ vision_width: int,
80
+ vision_model: nn.Module,
81
+ # text
82
+ context_length: int,
83
+ vocab_size: int,
84
+ transformer_width: int,
85
+ transformer_heads: int,
86
+ transformer_layers: int,
87
+ tempearture_init=0.07,
88
+ **kwargs,
89
+ ):
90
+ super().__init__()
91
+
92
+ self.context_length = context_length
93
+ self.vision_width = vision_width
94
+
95
+ self.visual = vision_model
96
+ self.transformer = Transformer(
97
+ width=transformer_width,
98
+ layers=transformer_layers,
99
+ heads=transformer_heads,
100
+ attn_mask=self.build_attention_mask(),
101
+ )
102
+
103
+ self.vocab_size = vocab_size
104
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
105
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
106
+ self.ln_final = nn.LayerNorm(transformer_width) # used to be `models.transformer.LayerNorm`
107
+
108
+ self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
109
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
110
+ print("=> initialize initial temperature with {}".format(tempearture_init))
111
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))
112
+
113
+ self.initialize_parameters()
114
+
115
+ def initialize_parameters(self):
116
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
117
+ nn.init.normal_(self.positional_embedding, std=0.01)
118
+
119
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
120
+ attn_std = self.transformer.width ** -0.5
121
+ fc_std = (2 * self.transformer.width) ** -0.5
122
+ for block in self.transformer.resblocks:
123
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
124
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
125
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
126
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
127
+
128
+ nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
129
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
130
+
131
+ def build_attention_mask(self):
132
+ # lazily create causal attention mask, with full attention between the vision tokens
133
+ # pytorch uses additive attention mask; fill with -inf
134
+ mask = torch.empty(self.context_length, self.context_length)
135
+ mask.fill_(float("-inf"))
136
+ mask.triu_(1) # zero out the lower diagonal
137
+ return mask
138
+
139
+ def encode_image(self, image, use_checkpoint=False, apply_project=True):
140
+ x = self.visual(image, use_checkpoint=use_checkpoint)
141
+ if isinstance(x, list):
142
+ assert len(x) == 1
143
+ x = x[0]
144
+ if not apply_project:
145
+ return x
146
+ x = x @ self.image_projection
147
+
148
+ return x
149
+
150
+ def encode_text(self, text, use_checkpoint=False):
151
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
152
+ x = x + self.positional_embedding
153
+ x = x.permute(1, 0, 2) # NLD -> LND
154
+ x = self.transformer(x, use_checkpoint=use_checkpoint)
155
+ x = x.permute(1, 0, 2) # LND -> NLD
156
+ x = self.ln_final(x)
157
+
158
+ # x.shape = [batch_size, n_ctx, transformer.width]
159
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
160
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
161
+
162
+ return x
163
+
164
+ def forward(self, image, text, use_checkpoint=False, norm_embed=False):
165
+ image_embed = self.encode_image(image, use_checkpoint=use_checkpoint)
166
+ text_embed = self.encode_text(text, use_checkpoint=use_checkpoint)
167
+
168
+ if norm_embed:
169
+ image_embed = F.normalize(image_embed, dim=-1)
170
+ text_embed = F.normalize(text_embed, dim=-1)
171
+ return {'image_embed': image_embed,
172
+ 'text_embed': text_embed,
173
+ 'logit_scale': self.logit_scale.exp()}
174
+
175
+
176
+ class CLIP_HF(nn.Module):
177
+ def __init__(self,
178
+ embed_dim: int,
179
+ # vision
180
+ vision_width: int,
181
+ vision_model: nn.Module,
182
+ # text
183
+ text_width: int,
184
+ text_model: nn.Module,
185
+ text_use_cls_token: bool,
186
+ text_is_regressive: bool,
187
+ tempearture_init=0.07,
188
+ **kwargs,
189
+ ):
190
+ super().__init__()
191
+
192
+ self.vision_width = vision_width
193
+ self.visual = vision_model
194
+ self.text_width = text_width
195
+ self.textual = text_model
196
+ self.text_use_cls_token = text_use_cls_token
197
+ self.text_is_regressive = text_is_regressive
198
+
199
+ if 'projection' not in kwargs:
200
+ self.projection = 'default'
201
+ else:
202
+ self.projection = kwargs['projection']
203
+ if self.projection == 'default':
204
+ self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
205
+ self.text_projection = nn.Parameter(torch.empty(text_width, embed_dim))
206
+ elif self.projection == 'frozen_in_time':
207
+ self.image_projection = nn.Sequential(
208
+ nn.Linear(vision_width, embed_dim)
209
+ )
210
+ self.text_projection = nn.Sequential(
211
+ nn.ReLU(),
212
+ nn.Linear(text_width, embed_dim)
213
+ )
214
+ print("=> initialize initial temperature with {}".format(tempearture_init))
215
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))
216
+
217
+ self.initialize_parameters()
218
+
219
+ def initialize_parameters(self):
220
+ if self.projection == 'default':
221
+ nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
222
+ nn.init.normal_(self.text_projection, std=self.text_width ** -0.5)
223
+ else:
224
+ nn.init.normal_(self.image_projection[0].weight, std=self.vision_width ** -0.5)
225
+ nn.init.normal_(self.text_projection[1].weight, std=self.text_width ** -0.5)
226
+
227
+ def build_attention_mask(self):
228
+ # lazily create causal attention mask, with full attention between the vision tokens
229
+ # pytorch uses additive attention mask; fill with -inf
230
+ mask = torch.empty(self.context_length, self.context_length)
231
+ mask.fill_(float("-inf"))
232
+ mask.triu_(1) # zero out the lower diagonal
233
+ return mask
234
+
235
+ def encode_image(self, image, use_checkpoint=False, apply_project=True):
236
+ x = self.visual(image, use_checkpoint=use_checkpoint)
237
+ if isinstance(x, list):
238
+ assert len(x) == 1
239
+ x = x[0]
240
+ if not apply_project:
241
+ return x
242
+ if self.projection == 'default':
243
+ x = x @ self.image_projection
244
+ else:
245
+ x = self.image_projection(x)
246
+
247
+ return x
248
+
249
+ def encode_text(self, text, attention_mask=None, use_checkpoint=False):
250
+ if use_checkpoint:
251
+ if isinstance(self.textual, DistilBertModel):
252
+ pass
253
+ # print("DistilBertModel does not support gradient checkpointing. Skipping even if use_checkpoint=True")
254
+ else:
255
+ self.textual.gradient_checkpointing_enable()
256
+ else:
257
+ self.textual.gradient_checkpointing_disable()
258
+ # text, attention_mask = text.squeeze(1), attention_mask.squeeze(1)
259
+ # ^ uncomment this only when doing local debugging (distributed=False)
260
+ x = self.textual(text, attention_mask=attention_mask)
261
+
262
+ if self.text_is_regressive:
263
+ # gpt-style
264
+ x = x.last_hidden_state
265
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
266
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
267
+ else:
268
+ # bert-style
269
+ if self.text_use_cls_token:
270
+ x = x.last_hidden_state
271
+ x = x[torch.arange(x.shape[0]), 0, :]
272
+ else:
273
+ x = x.pooler_output
274
+ if self.projection == 'default':
275
+ x = x @ self.text_projection
276
+ else:
277
+ x = self.text_projection(x)
278
+
279
+ return x
280
+
281
+ def forward(self, image, text, mask=None, use_checkpoint=False, norm_embed=False):
282
+ image_embed = self.encode_image(image, use_checkpoint=use_checkpoint)
283
+ text_embed = self.encode_text(text, attention_mask=mask, use_checkpoint=use_checkpoint)
284
+
285
+ if norm_embed:
286
+ image_embed = F.normalize(image_embed, dim=-1)
287
+ text_embed = F.normalize(text_embed, dim=-1)
288
+ return {'image_embed': image_embed,
289
+ 'text_embed': text_embed,
290
+ 'logit_scale': self.logit_scale.exp()}
291
+
292
+
293
+ def get_loss(model, args, tokenizer=None):
294
+ if model.startswith('CLIP'):
295
+ return loss.CLIPLoss(
296
+ use_vissl=args.contrastive_use_vissl,
297
+ cache_labels=True,
298
+ rank=args.rank,
299
+ world_size=args.world_size,
300
+ )
301
+ elif model.startswith('VCLM'):
302
+ return loss.CaptionLoss(tokenizer=tokenizer)
303
+ else:
304
+ raise NotImplementedError
305
+
306
+
307
+ def get_metric_names(model):
308
+ if model.startswith('CLIP'):
309
+ return ['loss', 'clip_loss', 'clip_acc']
310
+ elif model.startswith('VCLM'):
311
+ return ['loss', 'caption_loss', 'caption_acc', 'ppl']
312
+ else:
313
+ raise NotImplementedError
314
+
315
+
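`get_loss` and `get_metric_names` dispatch purely on the model-name prefix ('CLIP*' vs. 'VCLM*'). As a minimal sketch of how they might be wired into a training script (the `args` fields below are assumptions standing in for the script's argparse namespace):

from types import SimpleNamespace

# stand-in for the training script's argparse namespace; field names are assumptions
args = SimpleNamespace(contrastive_use_vissl=False, rank=0, world_size=1)

criterion = get_loss('CLIP_OPENAI_TIMESFORMER_BASE', args)   # contrastive CLIPLoss
metrics = get_metric_names('CLIP_OPENAI_TIMESFORMER_BASE')   # ['loss', 'clip_loss', 'clip_acc']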
316
+ def CLIP_OPENAI_TIMESFORMER_BASE(
317
+ num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
318
+ temperature_init=0.07, project_embed_dim=256, **kwargs,
319
+ ):
320
+ vision_model = SpaceTimeTransformer(
321
+ num_frames=num_frames,
322
+ time_init='zeros',
323
+ attention_style='frozen-in-time',
324
+ ln_pre=True,
325
+ act_layer=QuickGELU,
326
+ is_tanh_gating=timesformer_gated_xattn,
327
+ drop_path_rate=drop_path_rate,
328
+ )
329
+ clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
330
+ print("=> Loading CLIP (ViT-B/16) weights")
331
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
332
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
333
+ print(res)
334
+ if timesformer_freeze_space:
335
+ print("=> Freeze the space part in TimeSformer")
336
+ freeze_list, unfreeze_list = [], []
337
+ for n, p in vision_model.named_parameters():
338
+ if n not in remapped_state_dict or n == 'cls_token':
339
+ p.requires_grad = True
340
+ unfreeze_list.append(n)
341
+ else:
342
+ p.requires_grad = False
343
+ freeze_list.append(n)
344
+ print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
345
+ print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
346
+
347
+ vision_model.head = nn.Identity()
348
+ vision_model.pre_logits = nn.Identity()
349
+ vision_model.fc = nn.Identity()
350
+ model = CLIP(
351
+ embed_dim=project_embed_dim,
352
+ vision_width=768,
353
+ vision_model=vision_model,
354
+ context_length=77,
355
+ vocab_size=49408,
356
+ transformer_width=512,
357
+ transformer_heads=8,
358
+ transformer_layers=12,
359
+ tempearture_init=temperature_init,
360
+ **kwargs
361
+ )
362
+ model.transformer.load_state_dict(clip_model.transformer.state_dict())
363
+ model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict())
364
+ model.positional_embedding.data.copy_(clip_model.positional_embedding.data)
365
+ model.ln_final.load_state_dict(clip_model.ln_final.state_dict())
366
+ if project_embed_dim == clip_model.text_projection.shape[1]:
367
+ print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
368
+ model.image_projection.data.copy_(clip_model.visual.proj.data)
369
+ model.text_projection.data.copy_(clip_model.text_projection.data)
370
+ model.logit_scale.data.copy_(clip_model.logit_scale.data)
371
+ return model
372
+
373
+
374
+ def CLIP_OPENAI_TIMESFORMER_LARGE(
375
+ num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
376
+ temperature_init=0.07, project_embed_dim=256, **kwargs,
377
+ ):
378
+ vision_model = SpaceTimeTransformer(
379
+ img_size=224, patch_size=14,
380
+ embed_dim=1024, depth=24, num_heads=16,
381
+ num_frames=num_frames,
382
+ time_init='zeros',
383
+ attention_style='frozen-in-time',
384
+ ln_pre=True,
385
+ act_layer=QuickGELU,
386
+ is_tanh_gating=timesformer_gated_xattn,
387
+ drop_path_rate=drop_path_rate,
388
+ )
389
+ clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
390
+ print("=> Loading CLIP (ViT-L/14) weights")
391
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
392
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
393
+ print(res)
394
+ if timesformer_freeze_space:
395
+ print("=> Freeze the space part in TimeSformer")
396
+ freeze_list, unfreeze_list = [], []
397
+ for n, p in vision_model.named_parameters():
398
+ if n not in remapped_state_dict or n == 'cls_token':
399
+ p.requires_grad = True
400
+ unfreeze_list.append(n)
401
+ else:
402
+ p.requires_grad = False
403
+ freeze_list.append(n)
404
+ print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
405
+ print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
406
+
407
+ vision_model.head = nn.Identity()
408
+ vision_model.pre_logits = nn.Identity()
409
+ vision_model.fc = nn.Identity()
410
+ model = CLIP(
411
+ embed_dim=project_embed_dim,
412
+ vision_width=1024,
413
+ vision_model=vision_model,
414
+ context_length=77,
415
+ vocab_size=49408,
416
+ transformer_width=768,
417
+ transformer_heads=12,
418
+ transformer_layers=12,
419
+ tempearture_init=temperature_init,
420
+ **kwargs
421
+ )
422
+ model.transformer.load_state_dict(clip_model.transformer.state_dict())
423
+ model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict())
424
+ model.positional_embedding.data.copy_(clip_model.positional_embedding.data)
425
+ model.ln_final.load_state_dict(clip_model.ln_final.state_dict())
426
+ if project_embed_dim == clip_model.text_projection.shape[1]:
427
+ print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
428
+ model.image_projection.data.copy_(clip_model.visual.proj.data)
429
+ model.text_projection.data.copy_(clip_model.text_projection.data)
430
+ model.logit_scale.data.copy_(clip_model.logit_scale.data)
431
+ return model
432
+
433
+
434
+ def CLIP_OPENAI_TIMESFORMER_LARGE_336PX(
435
+ num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
436
+ temperature_init=0.07, project_embed_dim=256, **kwargs,
437
+ ):
438
+ vision_model = SpaceTimeTransformer(
439
+ img_size=336, patch_size=14,
440
+ embed_dim=1024, depth=24, num_heads=16,
441
+ num_frames=num_frames,
442
+ time_init='zeros',
443
+ attention_style='frozen-in-time',
444
+ ln_pre=True,
445
+ act_layer=QuickGELU,
446
+ is_tanh_gating=timesformer_gated_xattn,
447
+ drop_path_rate=drop_path_rate,
448
+ )
449
+ clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
450
+ print("=> Loading CLIP (ViT-L/14@336px) weights")
451
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
452
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
453
+ print(res)
454
+ if timesformer_freeze_space:
455
+ print("=> Freeze the space part in TimeSformer")
456
+ freeze_list, unfreeze_list = [], []
457
+ for n, p in vision_model.named_parameters():
458
+ if n not in remapped_state_dict or n == 'cls_token':
459
+ p.requires_grad = True
460
+ unfreeze_list.append(n)
461
+ else:
462
+ p.requires_grad = False
463
+ freeze_list.append(n)
464
+ print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
465
+ print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
466
+
467
+ vision_model.head = nn.Identity()
468
+ vision_model.pre_logits = nn.Identity()
469
+ vision_model.fc = nn.Identity()
470
+ model = CLIP(
471
+ embed_dim=project_embed_dim,
472
+ vision_width=1024,
473
+ vision_model=vision_model,
474
+ context_length=77,
475
+ vocab_size=49408,
476
+ transformer_width=768,
477
+ transformer_heads=12,
478
+ transformer_layers=12,
479
+ tempearture_init=temperature_init,
480
+ **kwargs
481
+ )
482
+ model.transformer.load_state_dict(clip_model.transformer.state_dict())
483
+ model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict())
484
+ model.positional_embedding.data.copy_(clip_model.positional_embedding.data)
485
+ model.ln_final.load_state_dict(clip_model.ln_final.state_dict())
486
+ if project_embed_dim == clip_model.text_projection.shape[1]:
487
+ print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
488
+ model.image_projection.data.copy_(clip_model.visual.proj.data)
489
+ model.text_projection.data.copy_(clip_model.text_projection.data)
490
+ model.logit_scale.data.copy_(clip_model.logit_scale.data)
491
+ return model
492
+
493
+
494
+ def CLIP_OPENAI_TIMESFORMER_BASE_DISTILBERT_BASE(
495
+ num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
496
+ temperature_init=0.07, project_embed_dim=256, **kwargs,
497
+ ):
498
+ vision_model = SpaceTimeTransformer(
499
+ num_frames=num_frames,
500
+ time_init='zeros',
501
+ attention_style='frozen-in-time',
502
+ ln_pre=True,
503
+ act_layer=QuickGELU,
504
+ is_tanh_gating=timesformer_gated_xattn,
505
+ drop_path_rate=drop_path_rate,
506
+ )
507
+ clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
508
+ print("=> Loading CLIP (ViT-B/16) weights")
509
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
510
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
511
+ print(res)
512
+ if timesformer_freeze_space:
513
+ print("=> Freeze the space part in TimeSformer")
514
+ freeze_list, unfreeze_list = [], []
515
+ for n, p in vision_model.named_parameters():
516
+ if n not in remapped_state_dict or n == 'cls_token':
517
+ p.requires_grad = True
518
+ unfreeze_list.append(n)
519
+ else:
520
+ p.requires_grad = False
521
+ freeze_list.append(n)
522
+ print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
523
+ print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
524
+
525
+ vision_model.head = nn.Identity()
526
+ vision_model.pre_logits = nn.Identity()
527
+ vision_model.fc = nn.Identity()
528
+
529
+ text_model = DistilBertModel.from_pretrained(
530
+ 'distilbert-base-uncased',
531
+ )
532
+ kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
533
+ model = CLIP_HF(
534
+ embed_dim=project_embed_dim,
535
+ vision_width=vision_model.embed_dim,
536
+ vision_model=vision_model,
537
+ text_width=768,
538
+ text_model=text_model,
539
+ text_use_cls_token=True, # DistilBert does not have pooler on top
540
+ text_is_regressive=False,
541
+ tempearture_init=temperature_init,
542
+ **kwargs,
543
+ )
544
+
545
+ return model
546
+
547
+
548
+ def CLIP_OPENAI_TIMESFORMER_LARGE_DISTILBERT_BASE(
549
+ num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
550
+ temperature_init=0.07, project_embed_dim=256, **kwargs,
551
+ ):
552
+ vision_model = SpaceTimeTransformer(
553
+ img_size=224, patch_size=14,
554
+ embed_dim=1024, depth=24, num_heads=16,
555
+ num_frames=num_frames,
556
+ time_init='zeros',
557
+ attention_style='frozen-in-time',
558
+ ln_pre=True,
559
+ act_layer=QuickGELU,
560
+ is_tanh_gating=timesformer_gated_xattn,
561
+ drop_path_rate=drop_path_rate,
562
+ )
563
+ clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
564
+ print("=> Loading CLIP (ViT-L/14) weights")
565
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
566
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
567
+ print(res)
568
+ if timesformer_freeze_space:
569
+ print("=> Freeze the space part in TimeSformer")
570
+ freeze_list, unfreeze_list = [], []
571
+ for n, p in vision_model.named_parameters():
572
+ if n not in remapped_state_dict or n == 'cls_token':
573
+ p.requires_grad = True
574
+ unfreeze_list.append(n)
575
+ else:
576
+ p.requires_grad = False
577
+ freeze_list.append(n)
578
+ print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
579
+ print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
580
+
581
+ vision_model.head = nn.Identity()
582
+ vision_model.pre_logits = nn.Identity()
583
+ vision_model.fc = nn.Identity()
584
+
585
+ text_model = DistilBertModel.from_pretrained(
586
+ 'distilbert-base-uncased',
587
+ )
588
+ kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
589
+ model = CLIP_HF(
590
+ embed_dim=project_embed_dim,
591
+ vision_width=vision_model.embed_dim,
592
+ vision_model=vision_model,
593
+ text_width=768,
594
+ text_model=text_model,
595
+ text_use_cls_token=True, # DistilBert does not have pooler on top
596
+ text_is_regressive=False,
597
+ tempearture_init=temperature_init,
598
+ **kwargs,
599
+ )
600
+
601
+ return model
602
+
603
+
604
+ def CLIP_OPENAI_TIMESFORMER_LARGE_336PX_DISTILBERT_BASE(
605
+ num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
606
+ temperature_init=0.07, project_embed_dim=256, **kwargs,
607
+ ):
608
+ vision_model = SpaceTimeTransformer(
609
+ img_size=336, patch_size=14,
610
+ embed_dim=1024, depth=24, num_heads=16,
611
+ num_frames=num_frames,
612
+ time_init='zeros',
613
+ attention_style='frozen-in-time',
614
+ ln_pre=True,
615
+ act_layer=QuickGELU,
616
+ is_tanh_gating=timesformer_gated_xattn,
617
+ drop_path_rate=drop_path_rate,
618
+ )
619
+ clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
620
+ print("=> Loading CLIP (ViT-L/14@336px) weights")
621
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
622
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
623
+ print(res)
624
+ if timesformer_freeze_space:
625
+ print("=> Freeze the space part in TimeSformer")
626
+ freeze_list, unfreeze_list = [], []
627
+ for n, p in vision_model.named_parameters():
628
+ if n not in remapped_state_dict or n == 'cls_token':
629
+ p.requires_grad = True
630
+ unfreeze_list.append(n)
631
+ else:
632
+ p.requires_grad = False
633
+ freeze_list.append(n)
634
+ print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
635
+ print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
636
+
637
+ vision_model.head = nn.Identity()
638
+ vision_model.pre_logits = nn.Identity()
639
+ vision_model.fc = nn.Identity()
640
+
641
+ text_model = DistilBertModel.from_pretrained(
642
+ 'distilbert-base-uncased',
643
+ )
644
+ kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
645
+ model = CLIP_HF(
646
+ embed_dim=project_embed_dim,
647
+ vision_width=vision_model.embed_dim,
648
+ vision_model=vision_model,
649
+ text_width=768,
650
+ text_model=text_model,
651
+ text_use_cls_token=True, # DistilBert does not have pooler on top
652
+ text_is_regressive=False,
653
+ tempearture_init=temperature_init,
654
+ **kwargs,
655
+ )
656
+
657
+ return model
658
+
659
+
660
+ def CLIP_HF_EGOVLP_DISTILBERT_BASE(num_frames=4, project_embed_dim=256, **kwargs):
661
+ vision_model = SpaceTimeTransformer(
662
+ num_frames=num_frames,
663
+ time_init='zeros',
664
+ attention_style='frozen-in-time',
665
+ )
666
+ vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True)
667
+ vision_model.load_state_dict(vit_model.state_dict(), strict=False)
668
+ vision_model.head = nn.Identity()
669
+ vision_model.pre_logits = nn.Identity()
670
+ vision_model.fc = nn.Identity()
671
+
672
+ text_model = DistilBertModel.from_pretrained(
673
+ 'distilbert-base-uncased',
674
+ )
675
+ kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
676
+ kwargs.update({'projection': 'frozen_in_time'})
677
+ model = CLIP_HF(
678
+ embed_dim=project_embed_dim,
679
+ vision_width=vision_model.embed_dim,
680
+ vision_model=vision_model,
681
+ text_width=768,
682
+ text_model=text_model,
683
+ text_use_cls_token=True, # DistilBert does not have pooler on top
684
+ text_is_regressive=False,
685
+ **kwargs,
686
+ )
687
+
688
+ return model
689
+
690
+
691
+ def CLIP_HF_TIMESFORMER_DISTILBERT_BASE(num_frames=4, drop_path_rate=0, temperature_init=0.07, project_embed_dim=256, **kwargs):
692
+ vision_model = SpaceTimeTransformer(
693
+ num_frames=num_frames,
694
+ time_init='zeros',
695
+ attention_style='frozen-in-time',
696
+ drop_path_rate=drop_path_rate,
697
+ )
698
+ vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True)
699
+ vision_model.load_state_dict(vit_model.state_dict(), strict=False)
700
+ vision_model.head = nn.Identity()
701
+ vision_model.pre_logits = nn.Identity()
702
+ vision_model.fc = nn.Identity()
703
+
704
+ text_model = DistilBertModel.from_pretrained(
705
+ 'distilbert-base-uncased',
706
+ )
707
+ kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
708
+ model = CLIP_HF(
709
+ embed_dim=project_embed_dim,
710
+ vision_width=vision_model.embed_dim,
711
+ vision_model=vision_model,
712
+ text_width=768,
713
+ text_model=text_model,
714
+ text_use_cls_token=True, # DistilBert does not have pooler on top
715
+ text_is_regressive=False,
716
+ tempearture_init=temperature_init,
717
+ **kwargs,
718
+ )
719
+
720
+ return model
721
+
722
+
723
+ def VCLM_OPENAI_VITB16_GPT2_LARGE(gated_xattn=False, freeze_lm_vclm=False,
724
+ freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs):
725
+ clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
726
+ vision_model = clip_model.visual
727
+ kwargs.pop('text_use_cls_token')
728
+
729
+ gpt2 = GPT2LMHeadModel.from_pretrained(
730
+ "gpt2-large",
731
+ use_cache=False,
732
+ )
733
+ new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
734
+ text_decoder = GatedGPT2LMHeadModel(new_config)
735
+ for n, p in gpt2.named_parameters():
736
+ rsetattr(text_decoder, n + '.data', p.data)
737
+
738
+ if freeze_lm_vclm:
739
+ print('Freeze the LM part of TextDecoder of VCLM')
740
+ text_decoder.freeze_lm_weights()
741
+
742
+ if freeze_visual_vclm:
743
+ print('Freeze the spatial part of VideoEncoder of VCLM')
744
+ vision_model.freeze_spatial_weights()
745
+
746
+ if freeze_visual_vclm_temporal:
747
+ print('Freeze the temporal part of VideoEncoder of VCLM')
748
+ vision_model.freeze_temporal_weights()
749
+
750
+ model = VCLM_HF(
751
+ vision_width=768,
752
+ vision_model=vision_model,
753
+ text_width=1280,
754
+ text_decoder=text_decoder,
755
+ num_img_queries=256,
756
+ dim_head=64,
757
+ heads=20,
758
+ **kwargs,
759
+ )
760
+
761
+ return model
762
+
763
+
764
+ def VCLM_OPENAI_VITB16_GPT2_XL(gated_xattn=False, freeze_lm_vclm=False,
765
+ freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs):
766
+ clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
767
+ vision_model = clip_model.visual
768
+ kwargs.pop('text_use_cls_token')
769
+
770
+ gpt2 = GPT2LMHeadModel.from_pretrained(
771
+ "gpt2-xl",
772
+ use_cache=False,
773
+ )
774
+ new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
775
+ text_decoder = GatedGPT2LMHeadModel(new_config)
776
+ for n, p in gpt2.named_parameters():
777
+ rsetattr(text_decoder, n + '.data', p.data)
778
+
779
+ if freeze_lm_vclm:
780
+ print('Freeze the LM part of TextDecoder of VCLM')
781
+ text_decoder.freeze_lm_weights()
782
+
783
+ if freeze_visual_vclm:
784
+ print('Freeze the spatial part of VideoEncoder of VCLM')
785
+ vision_model.freeze_spatial_weights()
786
+
787
+ if freeze_visual_vclm_temporal:
788
+ print('Freeze the temporal part of VideoEncoder of VCLM')
789
+ vision_model.freeze_temporal_weights()
790
+
791
+ model = VCLM_HF(
792
+ vision_width=768,
793
+ vision_model=vision_model,
794
+ text_width=1600,
795
+ text_decoder=text_decoder,
796
+ num_img_queries=256,
797
+ dim_head=64,
798
+ heads=25,
799
+ **kwargs,
800
+ )
801
+
802
+ return model
803
+
804
+
805
+ def VCLM_OPENAI_VITL14_GPT2_XL(gated_xattn=False, freeze_lm_vclm=False,
806
+ freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs):
807
+ clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
808
+ vision_model = clip_model.visual
809
+ kwargs.pop('text_use_cls_token')
810
+
811
+ gpt2 = GPT2LMHeadModel.from_pretrained(
812
+ "gpt2-xl",
813
+ use_cache=False,
814
+ )
815
+ new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
816
+ text_decoder = GatedGPT2LMHeadModel(new_config)
817
+ for n, p in gpt2.named_parameters():
818
+ rsetattr(text_decoder, n + '.data', p.data)
819
+
820
+ if freeze_lm_vclm:
821
+ print('Freeze the LM part of TextDecoder of VCLM')
822
+ text_decoder.freeze_lm_weights()
823
+
824
+ if freeze_visual_vclm:
825
+ print('Freeze the spatial part of VideoEncoder of VCLM')
826
+ vision_model.freeze_spatial_weights()
827
+
828
+ if freeze_visual_vclm_temporal:
829
+ print('Freeze the temporal part of VideoEncoder of VCLM')
830
+ vision_model.freeze_temporal_weights()
831
+
832
+ model = VCLM_HF(
833
+ vision_width=1024,
834
+ vision_model=vision_model,
835
+ text_width=1600,
836
+ text_decoder=text_decoder,
837
+ num_img_queries=256,
838
+ dim_head=64,
839
+ heads=25,
840
+ **kwargs,
841
+ )
842
+
843
+ return model
844
+
845
+
846
+ def VCLM_OPENAI_VITL14_336PX_GPT2_XL(gated_xattn=False, freeze_lm_vclm=False,
847
+ freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs):
848
+ clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
849
+ vision_model = clip_model.visual
850
+ kwargs.pop('text_use_cls_token')
851
+
852
+ gpt2 = GPT2LMHeadModel.from_pretrained(
853
+ "gpt2-xl",
854
+ use_cache=False,
855
+ )
856
+ new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
857
+ text_decoder = GatedGPT2LMHeadModel(new_config)
858
+ for n, p in gpt2.named_parameters():
859
+ rsetattr(text_decoder, n + '.data', p.data)
860
+
861
+ if freeze_lm_vclm:
862
+ print('Freeze the LM part of TextDecoder of VCLM')
863
+ text_decoder.freeze_lm_weights()
864
+
865
+ if freeze_visual_vclm:
866
+ print('Freeze the spatial part of VideoEncoder of VCLM')
867
+ vision_model.freeze_spatial_weights()
868
+
869
+ if freeze_visual_vclm_temporal:
870
+ print('Freeze the temporal part of VideoEncoder of VCLM')
871
+ vision_model.freeze_temporal_weights()
872
+
873
+ model = VCLM_HF(
874
+ vision_width=1024,
875
+ vision_model=vision_model,
876
+ text_width=1600,
877
+ text_decoder=text_decoder,
878
+ num_img_queries=256,
879
+ dim_head=64,
880
+ heads=25,
881
+ **kwargs,
882
+ )
883
+
884
+ return model
885
+
886
+
887
+ def VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
888
+ gated_xattn=False,
889
+ random_init_gpt2=False,
890
+ freeze_lm_vclm=False,
891
+ freeze_visual_vclm=False,
892
+ freeze_visual_vclm_temporal=False,
893
+ num_frames=4,
894
+ timesformer_gated_xattn=False,
895
+ **kwargs,
896
+ ):
897
+ vision_model = SpaceTimeTransformer(
898
+ num_frames=num_frames,
899
+ time_init='zeros',
900
+ attention_style='frozen-in-time',
901
+ ln_pre=True,
902
+ act_layer=QuickGELU,
903
+ is_tanh_gating=timesformer_gated_xattn,
904
+ )
905
+ clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
906
+ print("=> Loading CLIP (ViT-B/16) weights")
907
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
908
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
909
+ print(res)
910
+ vision_model.head = nn.Identity()
911
+ vision_model.pre_logits = nn.Identity()
912
+ vision_model.fc = nn.Identity()
913
+
914
+ gpt2 = GPT2LMHeadModel.from_pretrained(
915
+ "gpt2",
916
+ use_cache=False,
917
+ )
918
+ new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=1, gated_xattn=gated_xattn)
919
+ text_decoder = GatedGPT2LMHeadModel(new_config)
920
+ if not random_init_gpt2:
921
+ print('Loading LM from pretrained weights..')
922
+ for n, p in gpt2.named_parameters():
923
+ rsetattr(text_decoder, n + '.data', p.data)
924
+
925
+ if freeze_lm_vclm:
926
+ print('Freeze the LM part of TextDecoder of VCLM')
927
+ text_decoder.freeze_lm_weights()
928
+
929
+ if freeze_visual_vclm:
930
+ print('Freeze the spatial part of VideoEncoder of VCLM')
931
+ vision_model.freeze_spatial_weights()
932
+
933
+ if freeze_visual_vclm_temporal:
934
+ print('Freeze the temporal part of VideoEncoder of VCLM')
935
+ vision_model.freeze_temporal_weights()
936
+
937
+ model = VCLM_HF(
938
+ vision_width=768,
939
+ vision_model=vision_model,
940
+ text_width=768,
941
+ text_decoder=text_decoder,
942
+ num_img_queries=256,
943
+ dim_head=64,
944
+ heads=12,
945
+ **kwargs,
946
+ )
947
+
948
+ return model
949
+
950
+
951
+ def VCLM_OPENAI_TIMESFORMER_BASE_GPT2_XL(
952
+ gated_xattn=False,
953
+ freeze_lm_vclm=False,
954
+ freeze_visual_vclm=False,
955
+ freeze_visual_vclm_temporal=False,
956
+ num_frames=4,
957
+ timesformer_gated_xattn=False,
958
+ **kwargs,
959
+ ):
960
+ vision_model = SpaceTimeTransformer(
961
+ num_frames=num_frames,
962
+ time_init='zeros',
963
+ attention_style='frozen-in-time',
964
+ ln_pre=True,
965
+ act_layer=QuickGELU,
966
+ is_tanh_gating=timesformer_gated_xattn,
967
+ )
968
+ clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
969
+ print("=> Loading CLIP (ViT-B/16) weights")
970
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
971
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
972
+ print(res)
973
+ vision_model.head = nn.Identity()
974
+ vision_model.pre_logits = nn.Identity()
975
+ vision_model.fc = nn.Identity()
976
+
977
+ gpt2 = GPT2LMHeadModel.from_pretrained(
978
+ "gpt2-xl",
979
+ use_cache=False,
980
+ )
981
+ new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
982
+ text_decoder = GatedGPT2LMHeadModel(new_config)
983
+ for n, p in gpt2.named_parameters():
984
+ rsetattr(text_decoder, n + '.data', p.data)
985
+
986
+ if freeze_lm_vclm:
987
+ print('Freeze the LM part of TextDecoder of VCLM')
988
+ text_decoder.freeze_lm_weights()
989
+
990
+ if freeze_visual_vclm:
991
+ print('Freeze the spatial part of VideoEncoder of VCLM')
992
+ vision_model.freeze_spatial_weights()
993
+
994
+ if freeze_visual_vclm_temporal:
995
+ print('Freeze the temporal part of VideoEncoder of VCLM')
996
+ vision_model.freeze_temporal_weights()
997
+
998
+ model = VCLM_HF(
999
+ vision_width=768,
1000
+ vision_model=vision_model,
1001
+ text_width=1600,
1002
+ text_decoder=text_decoder,
1003
+ num_img_queries=256,
1004
+ dim_head=64,
1005
+ heads=25,
1006
+ **kwargs,
1007
+ )
1008
+
1009
+ return model
1010
+
1011
+
1012
+ def VCLM_OPENAI_TIMESFORMER_LARGE_GPT2_XL(
1013
+ gated_xattn=False,
1014
+ freeze_lm_vclm=False,
1015
+ freeze_visual_vclm=False,
1016
+ freeze_visual_vclm_temporal=False,
1017
+ num_frames=4,
1018
+ timesformer_gated_xattn=False,
1019
+ **kwargs,
1020
+ ):
1021
+ vision_model = SpaceTimeTransformer(
1022
+ img_size=224, patch_size=14,
1023
+ embed_dim=1024, depth=24, num_heads=16,
1024
+ num_frames=num_frames,
1025
+ time_init='zeros',
1026
+ attention_style='frozen-in-time',
1027
+ ln_pre=True,
1028
+ act_layer=QuickGELU,
1029
+ is_tanh_gating=timesformer_gated_xattn,
1030
+ )
1031
+ clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
1032
+ print("=> Loading CLIP (ViT-L/14x) weights")
1033
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
1034
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
1035
+ print(res)
1036
+ vision_model.head = nn.Identity()
1037
+ vision_model.pre_logits = nn.Identity()
1038
+ vision_model.fc = nn.Identity()
1039
+
1040
+ gpt2 = GPT2LMHeadModel.from_pretrained(
1041
+ "gpt2-xl",
1042
+ use_cache=False,
1043
+ )
1044
+ new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
1045
+ text_decoder = GatedGPT2LMHeadModel(new_config)
1046
+ for n, p in gpt2.named_parameters():
1047
+ rsetattr(text_decoder, n + '.data', p.data)
1048
+
1049
+ if freeze_lm_vclm:
1050
+ print('Freeze the LM part of TextDecoder of VCLM')
1051
+ text_decoder.freeze_lm_weights()
1052
+
1053
+ if freeze_visual_vclm:
1054
+ print('Freeze the spatial part of VideoEncoder of VCLM')
1055
+ vision_model.freeze_spatial_weights()
1056
+
1057
+ if freeze_visual_vclm_temporal:
1058
+ print('Freeze the temporal part of VideoEncoder of VCLM')
1059
+ vision_model.freeze_temporal_weights()
1060
+
1061
+ model = VCLM_HF(
1062
+ vision_width=1024,
1063
+ vision_model=vision_model,
1064
+ text_width=1600,
1065
+ text_decoder=text_decoder,
1066
+ num_img_queries=256,
1067
+ dim_head=64,
1068
+ heads=25,
1069
+ **kwargs,
1070
+ )
1071
+
1072
+ return model
1073
+
1074
+
1075
+ def VCLM_OPENAI_TIMESFORMER_LARGE_GPT2(
1076
+ gated_xattn=False,
1077
+ freeze_lm_vclm=False,
1078
+ freeze_visual_vclm=False,
1079
+ freeze_visual_vclm_temporal=False,
1080
+ num_frames=4,
1081
+ timesformer_gated_xattn=False,
1082
+ **kwargs
1083
+ ):
1084
+ vision_model = SpaceTimeTransformer(
1085
+ img_size=224, patch_size=14,
1086
+ embed_dim=1024, depth=24, num_heads=16,
1087
+ num_frames=num_frames,
1088
+ time_init='zeros',
1089
+ attention_style='frozen-in-time',
1090
+ ln_pre=True,
1091
+ act_layer=QuickGELU,
1092
+ is_tanh_gating=timesformer_gated_xattn,
1093
+ )
1094
+ clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
1095
+ print("=> Loading CLIP (ViT-L/14x) weights")
1096
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
1097
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
1098
+ print(res)
1099
+ vision_model.head = nn.Identity()
1100
+ vision_model.pre_logits = nn.Identity()
1101
+ vision_model.fc = nn.Identity()
1102
+
1103
+ gpt2 = GPT2LMHeadModel.from_pretrained(
1104
+ "gpt2",
1105
+ use_cache=False,
1106
+ )
1107
+ new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=1, gated_xattn=gated_xattn)
1108
+ text_decoder = GatedGPT2LMHeadModel(new_config)
1109
+ for n, p in gpt2.named_parameters():
1110
+ rsetattr(text_decoder, n + '.data', p.data)
1111
+
1112
+ if freeze_lm_vclm:
1113
+ print('Freeze the LM part of TextDecoder of VCLM')
1114
+ text_decoder.freeze_lm_weights()
1115
+
1116
+ if freeze_visual_vclm:
1117
+ print('Freeze the spatial part of VideoEncoder of VCLM')
1118
+ vision_model.freeze_spatial_weights()
1119
+
1120
+ if freeze_visual_vclm_temporal:
1121
+ print('Freeze the temporal part of VideoEncoder of VCLM')
1122
+ vision_model.freeze_temporal_weights()
1123
+
1124
+ model = VCLM_HF(
1125
+ vision_width=1024,
1126
+ vision_model=vision_model,
1127
+ text_width=768,
1128
+ text_decoder=text_decoder,
1129
+ num_img_queries=256,
1130
+ dim_head=64,
1131
+ heads=12,
1132
+ **kwargs,
1133
+ )
1134
+
1135
+ return model
1136
+
1137
+
1138
+ def VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL(
1139
+ gated_xattn=False,
1140
+ freeze_lm_vclm=False,
1141
+ freeze_visual_vclm=False,
1142
+ freeze_visual_vclm_temporal=False,
1143
+ num_frames=4,
1144
+ timesformer_gated_xattn=False,
1145
+ **kwargs,
1146
+ ):
1147
+ vision_model = SpaceTimeTransformer(
1148
+ img_size=336, patch_size=14,
1149
+ embed_dim=1024, depth=24, num_heads=16,
1150
+ num_frames=num_frames,
1151
+ time_init='zeros',
1152
+ attention_style='frozen-in-time',
1153
+ ln_pre=True,
1154
+ act_layer=QuickGELU,
1155
+ is_tanh_gating=timesformer_gated_xattn,
1156
+ )
1157
+ clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
1158
+ print("=> Loading CLIP (ViT-L/14@336px) weights")
1159
+ remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
1160
+ res = vision_model.load_state_dict(remapped_state_dict, strict=False)
1161
+ print(res)
1162
+ vision_model.head = nn.Identity()
1163
+ vision_model.pre_logits = nn.Identity()
1164
+ vision_model.fc = nn.Identity()
1165
+
1166
+ gpt2 = GPT2LMHeadModel.from_pretrained(
1167
+ "gpt2-xl",
1168
+ use_cache=False,
1169
+ )
1170
+ new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=3, gated_xattn=gated_xattn)
1171
+ text_decoder = GatedGPT2LMHeadModel(new_config)
1172
+ for n, p in gpt2.named_parameters():
1173
+ rsetattr(text_decoder, n + '.data', p.data)
1174
+
1175
+ if freeze_lm_vclm:
1176
+ print('Freeze the LM part of TextDecoder of VCLM')
1177
+ text_decoder.freeze_lm_weights()
1178
+
1179
+ if freeze_visual_vclm:
1180
+ print('Freeze the spatial part of VideoEncoder of VCLM')
1181
+ vision_model.freeze_spatial_weights()
1182
+
1183
+ if freeze_visual_vclm_temporal:
1184
+ print('Freeze the temporal part of VideoEncoder of VCLM')
1185
+ vision_model.freeze_temporal_weights()
1186
+
1187
+ model = VCLM_HF(
1188
+ vision_width=1024,
1189
+ vision_model=vision_model,
1190
+ text_width=1600,
1191
+ text_decoder=text_decoder,
1192
+ num_img_queries=256,
1193
+ dim_head=64,
1194
+ heads=25,
1195
+ **kwargs,
1196
+ )
1197
+
1198
+ return model
1199
+
1200
+
1201
+ def CLIP_OPENAI_VITB32(**kwargs):
1202
+ model, _ = load_openai_clip('ViT-B/32', 'cpu')
1203
+ return model
1204
+
1205
+
1206
+ def CLIP_OPENAI_VITB16(**kwargs):
1207
+ model, _ = load_openai_clip('ViT-B/16', 'cpu')
1208
+ return model
1209
+
1210
+
1211
+ def CLIP_OPENAI_VITL14(**kwargs):
1212
+ model, _ = load_openai_clip('ViT-L/14', 'cpu')
1213
+ return model
1214
+
1215
+
1216
+ def CLIP_OPENAI_VITL14_336PX(**kwargs):
1217
+ model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
1218
+ return model
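The factories above return ready-to-use modules with pretrained backbone weights already copied in. A minimal usage sketch for one of the video dual encoders, assuming the module is importable as `lavila.models.models` and that clips are passed as (B, C, T, H, W) float tensors:

import torch
from lavila.models.models import CLIP_OPENAI_TIMESFORMER_BASE  # assumed import path

model = CLIP_OPENAI_TIMESFORMER_BASE(num_frames=4, project_embed_dim=256)
model.eval()

frames = torch.randn(2, 3, 4, 224, 224)       # dummy clips; the (B, C, T, H, W) layout is an assumption
with torch.no_grad():
    video_embed = model.encode_image(frames)  # (2, 256) embeddings after the learned projection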
lavila/models/narrator.py ADDED
@@ -0,0 +1,385 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/huggingface/transformers/blob/main/src/transformers/generation_utils.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under Apache 2.0 License
10
+
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+ from einops import rearrange, repeat
16
+ from transformers import BeamSearchScorer
17
+ from transformers.generation_logits_process import (
18
+ LogitsProcessorList, TopKLogitsWarper, TopPLogitsWarper,
19
+ TemperatureLogitsWarper, TypicalLogitsWarper, LogitNormalization
20
+ )
21
+
22
+ from lavila.models.coca import CrossAttention, LayerNorm
23
+ from lavila.models.openai_model import VisionTransformer
24
+ from lavila.models.timesformer import SpaceTimeTransformer
25
+
26
+
27
+ class VCLM_HF(nn.Module):
28
+ def __init__(self,
29
+ # vision
30
+ vision_width: int,
31
+ vision_model: nn.Module,
32
+ # text
33
+ text_width: int,
34
+ text_decoder: nn.Module,
35
+ num_img_queries=256,
36
+ dim_head=64,
37
+ heads=8,
38
+ **kwargs,
39
+ ):
40
+ super().__init__()
41
+ self.vision_width = vision_width
42
+ self.visual = vision_model
43
+ self.text_width = text_width
44
+ self.text_decoder = text_decoder
45
+
46
+ self.img_queries = nn.Parameter(torch.empty(num_img_queries, text_width))
47
+ self.img_attn_pool = CrossAttention(
48
+ dim=text_width, context_dim=vision_width,
49
+ dim_head=dim_head, heads=heads,
50
+ norm_context=True
51
+ )
52
+ self.img_attn_pool_norm = LayerNorm(text_width)
53
+
54
+ self.initialize_parameters()
55
+
56
+ def initialize_parameters(self):
57
+ nn.init.normal_(self.img_queries, std=self.text_width ** -0.5)
58
+
59
+ def encode_image(self, image, use_checkpoint=False):
60
+ if isinstance(self.visual, VisionTransformer):
61
+ # openai_model.VisionTransformer accepts (N, C, H, W) instead of (N, C, T, H, W)
62
+ image = image.permute(0, 2, 1, 3, 4) # BCTHW -> BTCHW
63
+ bb, tt, _, _, _ = image.shape
64
+ x = self.visual(image.reshape(-1, *image.shape[2:]), use_checkpoint=use_checkpoint, cls_at_last=False) # NLD
65
+ x = x.view(bb, tt, *x.shape[1:])
66
+ x = x.permute(0, 3, 1, 2)
67
+ elif isinstance(self.visual, SpaceTimeTransformer):
68
+ image = image.permute(0, 2, 1, 3, 4).contiguous() # BCTHW -> BTCHW
69
+ bb, tt, _, _, _ = image.shape
70
+ x = self.visual.forward_features(image, use_checkpoint=use_checkpoint, cls_at_last=False) # NLD
71
+ x = x.permute(0, 2, 1)
72
+ else:
73
+ x = self.visual(image, use_checkpoint=use_checkpoint, mean_at_last=False)
74
+ if isinstance(x, list):
75
+ assert len(x) == 1
76
+ x = x[0]
77
+
78
+ x = x.flatten(start_dim=2) # BDTHW -> BD(THW)
79
+ x = x.permute(0, 2, 1) # BDN -> BND
80
+ img_queries = repeat(self.img_queries, 'n d -> b n d', b=x.shape[0])
81
+ img_queries = self.img_attn_pool(img_queries, x)
82
+ img_queries = self.img_attn_pool_norm(img_queries)
83
+ return img_queries
84
+
85
+ def forward(self, image, text, mask=None, use_checkpoint=False, norm_embed=False):
86
+ if use_checkpoint:
87
+ self.text_decoder.gradient_checkpointing_enable()
88
+ else:
89
+ self.text_decoder.gradient_checkpointing_disable()
90
+
91
+ text, labels = text[:, :-1], text[:, 1:]
92
+ # mask = mask[:, :-1]
93
+ image_tokens = self.encode_image(image, use_checkpoint=use_checkpoint)
94
+
95
+ output_decoder = self.text_decoder(text.contiguous(), encoder_hidden_states=image_tokens)
96
+ text_tokens_logits = output_decoder.logits
97
+ text_tokens_logits = rearrange(text_tokens_logits, 'b n c -> b c n')
98
+
99
+ return {'text_tokens_logits': text_tokens_logits,
100
+ 'labels': labels}
101
+
102
+ def generate(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None,
103
+ num_return_sequences=1, temperature=1.0, teacher_forcing=False, early_stopping=False):
104
+ image_tokens = image_tokens.repeat_interleave(num_return_sequences, dim=0)
105
+ device = image_tokens.device
106
+ generated_text_ids = torch.LongTensor([[tokenizer.bos_token_id]] * image_tokens.shape[0]).to(device)
107
+ condition_text_ids = generated_text_ids.clone()
108
+
109
+ logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=1)
110
+
111
+ nlls, num_tokens = torch.zeros(image_tokens.shape[0]).to(device), torch.zeros(image_tokens.shape[0]).to(device)
112
+ is_reach_eos = torch.zeros(image_tokens.shape[0]).bool().to(device)
113
+ with torch.no_grad():
114
+ for i in range(max_text_length - 1):
115
+ output_decoder = self.text_decoder(condition_text_ids, encoder_hidden_states=image_tokens)
116
+ decoded_token_logits = output_decoder.logits
117
+ next_token_logits = decoded_token_logits[:, -1, :]
118
+ if target is not None:
119
+ nll = F.cross_entropy(next_token_logits, target[:, i+1], ignore_index=tokenizer.pad_token_id, reduction='none')
120
+ nlls += nll
121
+ num_tokens += target[:, i+1].ne(tokenizer.pad_token_id)
122
+ else:
123
+ nll = torch.special.entr(F.softmax(next_token_logits, dim=1)).sum(dim=1)
124
+ nlls += nll * (~is_reach_eos)
125
+ num_tokens += (~is_reach_eos)
126
+ # filtered_p = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p, device=device)
127
+ next_token_logits = logits_warper(generated_text_ids, next_token_logits)
128
+ filtered_p = F.softmax(next_token_logits, dim=-1)
129
+ next_token = torch.multinomial(filtered_p, num_samples=1)
130
+ is_reach_eos = is_reach_eos | (next_token[:, 0] == tokenizer.eos_token_id)
131
+ if early_stopping and torch.all(is_reach_eos):
132
+ break
133
+
134
+ if teacher_forcing:
135
+ condition_text_ids = target[:, :i+2]
136
+ else:
137
+ condition_text_ids = torch.cat((generated_text_ids, next_token), dim=1)
138
+
139
+ generated_text_ids = torch.cat((generated_text_ids, next_token), dim=1)
140
+ # both branches return the same expression: exp of the mean per-token score
+ # (perplexity when `target` is given, exp of the mean predictive entropy otherwise)
+ return generated_text_ids, torch.exp(nlls / num_tokens)
144
+
145
+ def beam_sample(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None,
146
+ temperature=1.0, length_penalty=1.,
147
+ num_beams=3, num_return_sequences=1, teacher_forcing=False, early_stopping=False):
148
+ batch_size = image_tokens.shape[0]
149
+ device = image_tokens.device
150
+ input_ids = torch.ones((batch_size, 1), device=device, dtype=torch.long)
151
+ input_ids = input_ids * tokenizer.bos_token_id
152
+
153
+ expanded_return_idx = (
154
+ torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_beams * num_return_sequences).view(-1).to(device)
155
+ )
156
+ input_ids = input_ids.index_select(0, expanded_return_idx)
157
+
158
+ batch_beam_size, cur_len = input_ids.shape
159
+
160
+ logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=num_beams)
161
+
162
+ beam_scorer = BeamSearchScorer(
163
+ batch_size=batch_size * num_return_sequences, num_beams=num_beams,
164
+ device=device,
165
+ length_penalty=length_penalty,
166
+ )
167
+ batch_size = len(beam_scorer._beam_hyps)
168
+ num_beams = beam_scorer.num_beams
169
+
170
+ beam_scores = torch.zeros((batch_size, num_beams)).to(device)
171
+ beam_scores = beam_scores.view((batch_size * num_beams,))
172
+
173
+ is_reach_eos = torch.zeros(batch_beam_size).bool().to(device)
174
+ with torch.no_grad():
175
+ for i in range(max_text_length - 1):
176
+ output_decoder = self.text_decoder(
177
+ input_ids,
178
+ encoder_hidden_states=image_tokens.repeat_interleave(num_beams * num_return_sequences, dim=0)
179
+ )
180
+ decoded_token_logits = output_decoder.logits
181
+ next_token_logits = decoded_token_logits[:, -1, :]
182
+
183
+ next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
184
+ # supposed to be the line below, but ignore temporarily
185
+ # next_token_scores_processed = logits_processor(input_ids, next_token_scores)
186
+ next_token_scores_processed = next_token_scores
187
+ next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
188
+ # supposed to be the line below, but do a simple top_k+top_p temporarily
189
+ next_token_scores = logits_warper(input_ids, next_token_scores)
190
+ # next_token_scores = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p, device=device)
191
+
192
+ vocab_size = next_token_scores.shape[-1]
193
+ next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
194
+
195
+ probs = F.softmax(next_token_scores, dim=-1)
196
+ next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
197
+ next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
198
+
199
+ next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
200
+ next_tokens = torch.gather(next_tokens, -1, _indices)
201
+
202
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
203
+ next_tokens = next_tokens % vocab_size
204
+
205
+ # stateless
206
+ beam_outputs = beam_scorer.process(
207
+ input_ids,
208
+ next_token_scores,
209
+ next_tokens,
210
+ next_indices,
211
+ pad_token_id=tokenizer.pad_token_id,
212
+ eos_token_id=tokenizer.eos_token_id,
213
+ )
214
+
215
+ beam_scores = beam_outputs["next_beam_scores"]
216
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
217
+ beam_idx = beam_outputs["next_beam_indices"]
218
+
219
+ input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
220
+
221
+ is_reach_eos = is_reach_eos | (input_ids[:, -1] == tokenizer.eos_token_id)
222
+ if beam_scorer.is_done or torch.all(is_reach_eos):
223
+ break
224
+
225
+ sequence_outputs = beam_scorer.finalize(
226
+ input_ids,
227
+ beam_scores,
228
+ next_tokens,
229
+ next_indices,
230
+ pad_token_id=tokenizer.pad_token_id,
231
+ eos_token_id=tokenizer.eos_token_id,
232
+ max_length=max_text_length,
233
+ )
234
+
235
+ sequences = sequence_outputs["sequences"]
236
+ sequence_scores = sequence_outputs["sequence_scores"]
237
+ return sequences, sequence_scores
238
+
239
+ def group_beam_search(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None,
240
+ temperature=1.0, length_penalty=1.,
241
+ num_beams=6, num_beam_groups=3,
242
+ num_return_sequences=1, teacher_forcing=False, early_stopping=False):
243
+ batch_size = image_tokens.shape[0]
244
+ device = image_tokens.device
245
+ input_ids = torch.ones((batch_size, 1), device=device, dtype=torch.long)
246
+ input_ids = input_ids * tokenizer.bos_token_id
247
+
248
+ expanded_return_idx = (
249
+ torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_beams).view(-1).to(device)
250
+ )
251
+ input_ids = input_ids.index_select(0, expanded_return_idx)
252
+
253
+ batch_beam_size, cur_len = input_ids.shape
254
+
255
+ logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=num_beams)
256
+
257
+ beam_scorer = BeamSearchScorer(
258
+ batch_size=batch_size, num_beams=num_beams,
259
+ num_beam_groups=num_beam_groups,
260
+ num_beam_hyps_to_keep=num_return_sequences, device=device,
261
+ length_penalty=length_penalty,
262
+ )
263
+ num_sub_beams = num_beams // num_beam_groups
264
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
265
+ beam_scores[:, ::num_sub_beams] = 0
266
+ beam_scores = beam_scores.view((batch_size * num_beams,))
267
+
268
+ is_reach_eos = torch.zeros(batch_beam_size).bool().to(device)
269
+ with torch.no_grad():
270
+
271
+ # predicted tokens in cur_len step
272
+ current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
273
+
274
+ # indices which will form the beams in the next time step
275
+ reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
276
+
277
+ for i in range(max_text_length - 1):
278
+ output_decoder = self.text_decoder(
279
+ input_ids,
280
+ encoder_hidden_states=image_tokens.repeat_interleave(num_beams, dim=0)
281
+ )
282
+ decoded_token_logits = output_decoder.logits
283
+
284
+ for beam_group_idx in range(num_beam_groups):
285
+ group_start_idx = beam_group_idx * num_sub_beams
286
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
287
+ group_size = group_end_idx - group_start_idx
288
+
289
+ # indices of beams of current group among all sentences in batch
290
+ batch_group_indices = []
291
+
292
+ for batch_idx in range(batch_size):
293
+ batch_group_indices.extend(
294
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
295
+ )
296
+ group_input_ids = input_ids[batch_group_indices]
297
+
298
+ # select outputs of beams of current group only
299
+ next_token_logits = decoded_token_logits[batch_group_indices, -1, :]
300
+
301
+ next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
302
+ vocab_size = next_token_scores.shape[-1]
303
+
304
+ # supposed to be the line below, but ignore temporarily
305
+ # next_token_scores_processed = logits_processor(input_ids, next_token_scores)
306
+ next_token_scores_processed = next_token_scores
307
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
308
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
309
+ next_token_scores = logits_warper(input_ids, next_token_scores)
310
+ # next_token_scores = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p, device=device)
311
+
312
+ # reshape for beam search
313
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
314
+
315
+ next_token_scores, next_tokens = torch.topk(
316
+ next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
317
+ )
318
+
319
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
320
+ next_tokens = next_tokens % vocab_size
321
+
322
+ # stateless
323
+ beam_outputs = beam_scorer.process(
324
+ group_input_ids,
325
+ next_token_scores,
326
+ next_tokens,
327
+ next_indices,
328
+ pad_token_id=tokenizer.pad_token_id,
329
+ eos_token_id=tokenizer.eos_token_id,
330
+ beam_indices=None
331
+ )
332
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
333
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
334
+ beam_idx = beam_outputs["next_beam_indices"]
335
+
336
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
337
+ group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
338
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
339
+ reordering_indices[batch_group_indices] = (
340
+ num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
341
+ )
342
+
343
+ input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
344
+
345
+ is_reach_eos = is_reach_eos | (input_ids[:, -1] == tokenizer.eos_token_id)
346
+ if beam_scorer.is_done or torch.all(is_reach_eos):
347
+ break
348
+
349
+ sequence_outputs = beam_scorer.finalize(
350
+ input_ids,
351
+ beam_scores,
352
+ next_tokens,
353
+ next_indices,
354
+ pad_token_id=tokenizer.pad_token_id,
355
+ eos_token_id=tokenizer.eos_token_id,
356
+ max_length=max_text_length,
357
+ beam_indices=None,
358
+ )
359
+
360
+ sequences = sequence_outputs["sequences"]
361
+ sequence_scores = sequence_outputs["sequence_scores"]
362
+ return sequences, sequence_scores
363
+
364
+ def _get_logits_warper(
365
+ self, top_k=None, top_p=None, typical_p=None,
366
+ temperature=None, num_beams=None, renormalize_logits=None,
367
+ ):
368
+ top_k = top_k if top_k is not None else 0
369
+ top_p = top_p if top_p is not None else 1.0
370
+ typical_p = typical_p if typical_p is not None else 1.
371
+ temperature = temperature if temperature is not None else 1.
372
+ warpers = LogitsProcessorList()
373
+
374
+ if temperature is not None and temperature != 1.0:
375
+ warpers.append(TemperatureLogitsWarper(temperature))
376
+ if top_k is not None and top_k != 0:
377
+ warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
378
+ if top_p is not None and top_p < 1.0:
379
+ warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
380
+ if typical_p is not None and typical_p < 1.0:
381
+ warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
382
+ # `LogitNormalization` should always be the last logit processor, when present
383
+ if renormalize_logits is True:
384
+ warpers.append(LogitNormalization())
385
+ return warpers
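A minimal inference sketch for the captioning model above, assuming the VCLM_OPENAI_TIMESFORMER_BASE_GPT2 factory from lavila.models.models, a plain Hugging Face GPT2Tokenizer as a stand-in tokenizer, and clips laid out as (B, C, T, H, W); all of these choices are assumptions:

import torch
from transformers import GPT2Tokenizer
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2  # assumed import path

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token          # GPT-2 ships without a pad token

model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(gated_xattn=True, num_frames=4)
model.eval()

frames = torch.randn(1, 3, 4, 224, 224)            # dummy clip; layout is an assumption
with torch.no_grad():
    image_tokens = model.encode_image(frames)
    ids, _ = model.generate(image_tokens, tokenizer, max_text_length=77,
                            top_p=0.95, temperature=0.7,
                            num_return_sequences=3, early_stopping=True)
captions = [tokenizer.decode(seq, skip_special_tokens=True) for seq in ids]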
lavila/models/openai_clip.py ADDED
@@ -0,0 +1,237 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/openai/CLIP/blob/main/clip/clip.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under MIT License
10
+
11
+ import hashlib
12
+ import os
13
+ import urllib
14
+ import warnings
15
+ from typing import Union, List
16
+ from pkg_resources import packaging
17
+
18
+ import torch
19
+ from PIL import Image
20
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
21
+ from tqdm import tqdm
22
+
23
+ from .openai_model import build_model
24
+ from .tokenizer import SimpleTokenizer as _Tokenizer
25
+
26
+ try:
27
+ from torchvision.transforms import InterpolationMode
28
+ BICUBIC = InterpolationMode.BICUBIC
29
+ except ImportError:
30
+ BICUBIC = Image.BICUBIC
31
+
32
+
33
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
34
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
35
+
36
+
37
+ __all__ = ["available_models", "load", "tokenize"]
38
+ _tokenizer = _Tokenizer()
39
+
40
+ _MODELS = {
41
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
42
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
43
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
44
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
45
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
46
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
47
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
48
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
49
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
50
+ }
51
+
52
+
53
+ def _download(url: str, root: str):
54
+ os.makedirs(root, exist_ok=True)
55
+ filename = os.path.basename(url)
56
+
57
+ expected_sha256 = url.split("/")[-2]
58
+ download_target = os.path.join(root, filename)
59
+
60
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
61
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
62
+
63
+ if os.path.isfile(download_target):
64
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
65
+ return download_target
66
+ else:
67
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
68
+
69
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
70
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
71
+ while True:
72
+ buffer = source.read(8192)
73
+ if not buffer:
74
+ break
75
+
76
+ output.write(buffer)
77
+ loop.update(len(buffer))
78
+
79
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
80
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
81
+
82
+ return download_target
83
+
84
+
85
+ def _convert_image_to_rgb(image):
86
+ return image.convert("RGB")
87
+
88
+
89
+ def _transform(n_px):
90
+ return Compose([
91
+ Resize(n_px, interpolation=BICUBIC),
92
+ CenterCrop(n_px),
93
+ _convert_image_to_rgb,
94
+ ToTensor(),
95
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
96
+ ])
97
+
98
+
99
+ def available_models() -> List[str]:
100
+ """Returns the names of available CLIP models"""
101
+ return list(_MODELS.keys())
102
+
103
+
104
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
105
+ """Load a CLIP model
106
+ Parameters
107
+ ----------
108
+ name : str
109
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
110
+ device : Union[str, torch.device]
111
+ The device to put the loaded model
112
+ jit : bool
113
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
114
+ download_root: str
115
+ path to download the model files; by default, it uses "~/.cache/clip"
116
+ Returns
117
+ -------
118
+ model : torch.nn.Module
119
+ The CLIP model
120
+ preprocess : Callable[[PIL.Image], torch.Tensor]
121
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
122
+ """
123
+ if name in _MODELS:
124
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
125
+ elif os.path.isfile(name):
126
+ model_path = name
127
+ else:
128
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
129
+
130
+ with open(model_path, 'rb') as opened_file:
131
+ try:
132
+ # loading JIT archive
133
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
134
+ state_dict = None
135
+ except RuntimeError:
136
+ # loading saved state dict
137
+ if jit:
138
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
139
+ jit = False
140
+ state_dict = torch.load(opened_file, map_location="cpu")
141
+
142
+ if not jit:
143
+ model = build_model(state_dict or model.state_dict()).to(device)
144
+ if str(device) == "cpu":
145
+ model.float()
146
+ return model, _transform(model.visual.input_resolution)
147
+
148
+ # patch the device names
149
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
150
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
151
+
152
+ def patch_device(module):
153
+ try:
154
+ graphs = [module.graph] if hasattr(module, "graph") else []
155
+ except RuntimeError:
156
+ graphs = []
157
+
158
+ if hasattr(module, "forward1"):
159
+ graphs.append(module.forward1.graph)
160
+
161
+ for graph in graphs:
162
+ for node in graph.findAllNodes("prim::Constant"):
163
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
164
+ node.copyAttributes(device_node)
165
+
166
+ model.apply(patch_device)
167
+ patch_device(model.encode_image)
168
+ patch_device(model.encode_text)
169
+
170
+ # patch dtype to float32 on CPU
171
+ if str(device) == "cpu":
172
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
173
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
174
+ float_node = float_input.node()
175
+
176
+ def patch_float(module):
177
+ try:
178
+ graphs = [module.graph] if hasattr(module, "graph") else []
179
+ except RuntimeError:
180
+ graphs = []
181
+
182
+ if hasattr(module, "forward1"):
183
+ graphs.append(module.forward1.graph)
184
+
185
+ for graph in graphs:
186
+ for node in graph.findAllNodes("aten::to"):
187
+ inputs = list(node.inputs())
188
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
189
+ if inputs[i].node()["value"] == 5:
190
+ inputs[i].node().copyAttributes(float_node)
191
+
192
+ model.apply(patch_float)
193
+ patch_float(model.encode_image)
194
+ patch_float(model.encode_text)
195
+
196
+ model.float()
197
+
198
+ return model, _transform(model.input_resolution.item())
199
+
200
+
201
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
202
+ """
203
+ Returns the tokenized representation of given input string(s)
204
+ Parameters
205
+ ----------
206
+ texts : Union[str, List[str]]
207
+ An input string or a list of input strings to tokenize
208
+ context_length : int
209
+ The context length to use; all CLIP models use 77 as the context length
210
+ truncate: bool
211
+ Whether to truncate the text in case its encoding is longer than the context length
212
+ Returns
213
+ -------
214
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
215
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
216
+ """
217
+ if isinstance(texts, str):
218
+ texts = [texts]
219
+
220
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
221
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
222
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
224
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225
+ else:
226
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
227
+
228
+ for i, tokens in enumerate(all_tokens):
229
+ if len(tokens) > context_length:
230
+ if truncate:
231
+ tokens = tokens[:context_length]
232
+ tokens[-1] = eot_token
233
+ else:
234
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
235
+ result[i, :len(tokens)] = torch.tensor(tokens)
236
+
237
+ return result
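A minimal usage sketch for the helpers above (editor's note, not part of the commit; it assumes the repository root is on PYTHONPATH and that the ViT-B/16 weights can be fetched into ~/.cache/clip):

import torch
from lavila.models.openai_clip import available_models, load, tokenize

print(available_models())                      # includes 'ViT-B/16' among others
model, preprocess = load("ViT-B/16", device="cpu", jit=False)
tokens = tokenize(["a person kneads dough"], context_length=77, truncate=True)
with torch.no_grad():
    text_embed = model.encode_text(tokens)
print(text_embed.shape)                        # [1, embed_dim], e.g. [1, 512] for ViT-B/16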
lavila/models/openai_model.py ADDED
@@ -0,0 +1,485 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/openai/CLIP/blob/main/clip/model.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under MIT License
10
+
11
+ from collections import OrderedDict
12
+ from typing import Tuple, Union
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torch.utils.checkpoint as checkpoint
18
+ from torch import nn
19
+
20
+
21
+ class Bottleneck(nn.Module):
22
+ expansion = 4
23
+
24
+ def __init__(self, inplanes, planes, stride=1):
25
+ super().__init__()
26
+
27
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
28
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
29
+ self.bn1 = nn.BatchNorm2d(planes)
30
+ self.relu1 = nn.ReLU(inplace=True)
31
+
32
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
33
+ self.bn2 = nn.BatchNorm2d(planes)
34
+ self.relu2 = nn.ReLU(inplace=True)
35
+
36
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
37
+
38
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
39
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
40
+ self.relu3 = nn.ReLU(inplace=True)
41
+
42
+ self.downsample = None
43
+ self.stride = stride
44
+
45
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
46
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
47
+ self.downsample = nn.Sequential(OrderedDict([
48
+ ("-1", nn.AvgPool2d(stride)),
49
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
50
+ ("1", nn.BatchNorm2d(planes * self.expansion))
51
+ ]))
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ identity = x
55
+
56
+ out = self.relu1(self.bn1(self.conv1(x)))
57
+ out = self.relu2(self.bn2(self.conv2(out)))
58
+ out = self.avgpool(out)
59
+ out = self.bn3(self.conv3(out))
60
+
61
+ if self.downsample is not None:
62
+ identity = self.downsample(x)
63
+
64
+ out += identity
65
+ out = self.relu3(out)
66
+ return out
67
+
68
+
69
+ class AttentionPool2d(nn.Module):
70
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
71
+ super().__init__()
72
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
73
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
74
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
75
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
76
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
77
+ self.num_heads = num_heads
78
+
79
+ def forward(self, x):
80
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
81
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
82
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
83
+ x, _ = F.multi_head_attention_forward(
84
+ query=x[:1], key=x, value=x,
85
+ embed_dim_to_check=x.shape[-1],
86
+ num_heads=self.num_heads,
87
+ q_proj_weight=self.q_proj.weight,
88
+ k_proj_weight=self.k_proj.weight,
89
+ v_proj_weight=self.v_proj.weight,
90
+ in_proj_weight=None,
91
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
92
+ bias_k=None,
93
+ bias_v=None,
94
+ add_zero_attn=False,
95
+ dropout_p=0,
96
+ out_proj_weight=self.c_proj.weight,
97
+ out_proj_bias=self.c_proj.bias,
98
+ use_separate_proj_weight=True,
99
+ training=self.training,
100
+ need_weights=False
101
+ )
102
+ return x.squeeze(0)
103
+
104
+
105
+ class ModifiedResNet(nn.Module):
106
+ """
107
+ A ResNet class that is similar to torchvision's but contains the following changes:
108
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
109
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
110
+ - The final pooling layer is a QKV attention instead of an average pool
111
+ """
112
+
113
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
114
+ super().__init__()
115
+ self.output_dim = output_dim
116
+ self.input_resolution = input_resolution
117
+
118
+ # the 3-layer stem
119
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
120
+ self.bn1 = nn.BatchNorm2d(width // 2)
121
+ self.relu1 = nn.ReLU(inplace=True)
122
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
123
+ self.bn2 = nn.BatchNorm2d(width // 2)
124
+ self.relu2 = nn.ReLU(inplace=True)
125
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
126
+ self.bn3 = nn.BatchNorm2d(width)
127
+ self.relu3 = nn.ReLU(inplace=True)
128
+ self.avgpool = nn.AvgPool2d(2)
129
+
130
+ # residual layers
131
+ self._inplanes = width # this is a *mutable* variable used during construction
132
+ self.layer1 = self._make_layer(width, layers[0])
133
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
134
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
135
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
136
+
137
+ embed_dim = width * 32 # the ResNet feature dimension
138
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
139
+
140
+ def _make_layer(self, planes, blocks, stride=1):
141
+ layers = [Bottleneck(self._inplanes, planes, stride)]
142
+
143
+ self._inplanes = planes * Bottleneck.expansion
144
+ for _ in range(1, blocks):
145
+ layers.append(Bottleneck(self._inplanes, planes))
146
+
147
+ return nn.Sequential(*layers)
148
+
149
+ def forward(self, x):
150
+ def stem(x):
151
+ x = self.relu1(self.bn1(self.conv1(x)))
152
+ x = self.relu2(self.bn2(self.conv2(x)))
153
+ x = self.relu3(self.bn3(self.conv3(x)))
154
+ x = self.avgpool(x)
155
+ return x
156
+
157
+ x = x.type(self.conv1.weight.dtype)
158
+ x = stem(x)
159
+ x = self.layer1(x)
160
+ x = self.layer2(x)
161
+ x = self.layer3(x)
162
+ x = self.layer4(x)
163
+ x = self.attnpool(x)
164
+
165
+ return x
166
+
167
+
168
+ class LayerNorm(nn.LayerNorm):
169
+ """Subclass torch's LayerNorm to handle fp16."""
170
+
171
+ def forward(self, x: torch.Tensor):
172
+ orig_type = x.dtype
173
+ ret = super().forward(x.type(torch.float32))
174
+ return ret.type(orig_type)
175
+
176
+
177
+ class QuickGELU(nn.Module):
178
+ def forward(self, x: torch.Tensor):
179
+ return x * torch.sigmoid(1.702 * x)
180
+
181
+
182
+ class ResidualAttentionBlock(nn.Module):
183
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
184
+ super().__init__()
185
+
186
+ self.attn = nn.MultiheadAttention(d_model, n_head)
187
+ self.ln_1 = nn.LayerNorm(d_model) # used to be `models.transformer.LayerNorm`
188
+ self.mlp = nn.Sequential(OrderedDict([
189
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
190
+ ("gelu", QuickGELU()),
191
+ ("c_proj", nn.Linear(d_model * 4, d_model))
192
+ ]))
193
+ self.ln_2 = nn.LayerNorm(d_model) # used to be `models.transformer.LayerNorm`
194
+ self.attn_mask = attn_mask
195
+
196
+ def attention(self, x: torch.Tensor):
197
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
198
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
199
+
200
+ def forward_part1(self, x):
201
+ return self.attention(self.ln_1(x))
202
+
203
+ def forward_part2(self, x):
204
+ return self.mlp(self.ln_2(x))
205
+
206
+ def forward(self, x: torch.Tensor, use_checkpoint=False):
207
+ if use_checkpoint:
208
+ x = x + checkpoint.checkpoint(self.forward_part1, x)
209
+ else:
210
+ x = x + self.forward_part1(x)
211
+
212
+ if use_checkpoint:
213
+ x = x + checkpoint.checkpoint(self.forward_part2, x)
214
+ else:
215
+ x = x + self.forward_part2(x)
216
+ return x
217
+
218
+
219
+ class Transformer(nn.Module):
220
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
221
+ super().__init__()
222
+ self.width = width
223
+ self.layers = layers
224
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
225
+
226
+ def forward(self, x: torch.Tensor, use_checkpoint=False):
227
+ if use_checkpoint:
228
+ for i in range(self.layers):
229
+ x = checkpoint.checkpoint(self.resblocks[i], x)
230
+ return x
231
+ else:
232
+ return self.resblocks(x)
233
+
234
+
235
+ class VisionTransformer(nn.Module):
236
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
237
+ super().__init__()
238
+ self.input_resolution = input_resolution
239
+ self.output_dim = output_dim
240
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
241
+
242
+ scale = width ** -0.5
243
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
244
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
245
+ self.ln_pre = LayerNorm(width)
246
+
247
+ self.transformer = Transformer(width, layers, heads)
248
+
249
+ self.ln_post = LayerNorm(width)
250
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
251
+
252
+ def forward(self, x: torch.Tensor, apply_project=True, use_checkpoint=False, cls_at_last=True):
253
+ x = self.conv1(x) # shape = [*, width, grid, grid]
254
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
255
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
256
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
257
+ x = x + self.positional_embedding.to(x.dtype)
258
+ x = self.ln_pre(x)
259
+
260
+ x = x.permute(1, 0, 2) # NLD -> LND
261
+ x = self.transformer(x, use_checkpoint=use_checkpoint)
262
+ x = x.permute(1, 0, 2) # LND -> NLD
263
+
264
+ if cls_at_last:
265
+ x = self.ln_post(x[:, 0, :])
266
+
267
+ if self.proj is not None and apply_project:
268
+ x = x @ self.proj
269
+
270
+ return x
271
+ else:
272
+ return x[:, 1:, :]
273
+
274
+
275
+ class CLIP(nn.Module):
276
+ def __init__(self,
277
+ embed_dim: int,
278
+ # vision
279
+ image_resolution: int,
280
+ vision_layers: Union[Tuple[int, int, int, int], int],
281
+ vision_width: int,
282
+ vision_patch_size: int,
283
+ # text
284
+ context_length: int,
285
+ vocab_size: int,
286
+ transformer_width: int,
287
+ transformer_heads: int,
288
+ transformer_layers: int
289
+ ):
290
+ super().__init__()
291
+
292
+ self.context_length = context_length
293
+
294
+ if isinstance(vision_layers, (tuple, list)):
295
+ vision_heads = vision_width * 32 // 64
296
+ self.visual = ModifiedResNet(
297
+ layers=vision_layers,
298
+ output_dim=embed_dim,
299
+ heads=vision_heads,
300
+ input_resolution=image_resolution,
301
+ width=vision_width
302
+ )
303
+ else:
304
+ vision_heads = vision_width // 64
305
+ self.visual = VisionTransformer(
306
+ input_resolution=image_resolution,
307
+ patch_size=vision_patch_size,
308
+ width=vision_width,
309
+ layers=vision_layers,
310
+ heads=vision_heads,
311
+ output_dim=embed_dim
312
+ )
313
+
314
+ self.transformer = Transformer(
315
+ width=transformer_width,
316
+ layers=transformer_layers,
317
+ heads=transformer_heads,
318
+ attn_mask=self.build_attention_mask()
319
+ )
320
+
321
+ self.vocab_size = vocab_size
322
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
323
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
324
+ self.ln_final = LayerNorm(transformer_width)
325
+
326
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
327
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
328
+
329
+ self.initialize_parameters()
330
+
331
+ def initialize_parameters(self):
332
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
333
+ nn.init.normal_(self.positional_embedding, std=0.01)
334
+
335
+ if isinstance(self.visual, ModifiedResNet):
336
+ if self.visual.attnpool is not None:
337
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
338
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
339
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
340
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
341
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
342
+
343
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
344
+ for name, param in resnet_block.named_parameters():
345
+ if name.endswith("bn3.weight"):
346
+ nn.init.zeros_(param)
347
+
348
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
349
+ attn_std = self.transformer.width ** -0.5
350
+ fc_std = (2 * self.transformer.width) ** -0.5
351
+ for block in self.transformer.resblocks:
352
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
353
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
354
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
355
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
356
+
357
+ if self.text_projection is not None:
358
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
359
+
360
+ def build_attention_mask(self):
361
+ # lazily create causal attention mask, with full attention between the vision tokens
362
+ # pytorch uses additive attention mask; fill with -inf
363
+ mask = torch.empty(self.context_length, self.context_length)
364
+ mask.fill_(float("-inf"))
365
+ mask.triu_(1) # zero out the lower diagonal
366
+ return mask
367
+
368
+ @property
369
+ def dtype(self):
370
+ return self.visual.conv1.weight.dtype
371
+
372
+ def encode_image(self, image, apply_project=True, use_checkpoint=False):
373
+ if image.ndim == 4:
374
+ return self.visual(image.type(self.dtype))
375
+ else:
376
+ image = image.permute(0, 2, 1, 3, 4) # BCTHW -> BTCHW
377
+ bb, tt, _, _, _ = image.shape
378
+ x = self.visual(image.reshape(-1, *image.shape[2:]), apply_project=apply_project, use_checkpoint=use_checkpoint) # ND
379
+ x = x.view(bb, tt, -1)
380
+ image_features = x.mean(1)
381
+ # image_features = x.max(1).values
382
+ return image_features
383
+
384
+ def encode_text(self, text, use_checkpoint=False):
385
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
386
+
387
+ x = x + self.positional_embedding.type(self.dtype)
388
+ x = x.permute(1, 0, 2) # NLD -> LND
389
+ x = self.transformer(x, use_checkpoint=use_checkpoint)
390
+ x = x.permute(1, 0, 2) # LND -> NLD
391
+ x = self.ln_final(x).type(self.dtype)
392
+
393
+ # x.shape = [batch_size, n_ctx, transformer.width]
394
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
395
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
396
+
397
+ return x
398
+
399
+ def forward(self, image, text, use_checkpoint=False, norm_embed=True):
400
+ image_features = self.encode_image(image, use_checkpoint=use_checkpoint)
401
+ text_features = self.encode_text(text, use_checkpoint=use_checkpoint)
402
+
403
+ # normalized features
404
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
405
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
406
+
407
+ # # cosine similarity as logits
408
+ # logit_scale = self.logit_scale.exp()
409
+ # logits_per_image = logit_scale * image_features @ text_features.t()
410
+ # logits_per_text = logits_per_image.t()
411
+
412
+ # # shape = [global_batch_size, global_batch_size]
413
+ # return logits_per_image, logits_per_text
414
+
415
+ return {'image_embed': image_features,
416
+ 'text_embed': text_features,
417
+ 'logit_scale': self.logit_scale.exp()}
418
+
419
+
420
+ def convert_weights(model: nn.Module):
421
+ """Convert applicable model parameters to fp16"""
422
+
423
+ def _convert_weights_to_fp16(l):
424
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
425
+ l.weight.data = l.weight.data.half()
426
+ if l.bias is not None:
427
+ l.bias.data = l.bias.data.half()
428
+
429
+ if isinstance(l, nn.MultiheadAttention):
430
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
431
+ tensor = getattr(l, attr)
432
+ if tensor is not None:
433
+ tensor.data = tensor.data.half()
434
+
435
+ for name in ["text_projection", "proj"]:
436
+ if hasattr(l, name):
437
+ attr = getattr(l, name)
438
+ if attr is not None:
439
+ attr.data = attr.data.half()
440
+
441
+ model.apply(_convert_weights_to_fp16)
442
+
443
+
444
+ def build_model(state_dict: dict):
445
+ vit = "visual.proj" in state_dict
446
+
447
+ if vit:
448
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
449
+ vision_layers = len(
450
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]
451
+ )
452
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
453
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
454
+ image_resolution = vision_patch_size * grid_size
455
+ else:
456
+ counts: list = [
457
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]
458
+ ]
459
+ vision_layers = tuple(counts)
460
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
461
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
462
+ vision_patch_size = None
463
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
464
+ image_resolution = output_width * 32
465
+
466
+ embed_dim = state_dict["text_projection"].shape[1]
467
+ context_length = state_dict["positional_embedding"].shape[0]
468
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
469
+ transformer_width = state_dict["ln_final.weight"].shape[0]
470
+ transformer_heads = transformer_width // 64
471
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
472
+
473
+ model = CLIP(
474
+ embed_dim,
475
+ image_resolution, vision_layers, vision_width, vision_patch_size,
476
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
477
+ )
478
+
479
+ for key in ["input_resolution", "context_length", "vocab_size"]:
480
+ if key in state_dict:
481
+ del state_dict[key]
482
+
483
+ convert_weights(model)
484
+ model.load_state_dict(state_dict)
485
+ return model.eval()
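For orientation only, a rough sketch of how this modified CLIP forward is typically driven (the returned dict replaces the original logit matrices; the random video tensor is just a shape check and the numbers are illustrative):

import torch
from lavila.models.openai_clip import load, tokenize

model, _ = load("ViT-B/16", device="cpu")          # builds the CLIP class above via build_model()
video = torch.randn(2, 3, 4, 224, 224)             # B C T H W; encode_image averages frame features over T
text = tokenize(["cut the loaf into slices", "blend the sauce"])
with torch.no_grad():
    out = model(video, text)
print(out["image_embed"].shape, out["text_embed"].shape, float(out["logit_scale"]))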
lavila/models/timesformer.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/m-bain/frozen-in-time/blob/main/model/video_transformer.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under MIT License
10
+
11
+ """
12
+ Implementations of Video Transformers in PyTorch
13
+ A PyTorch implementation of space-time transformer as described in
14
+ 'Frozen in Time: A Joint Image and Video Encoder for End-to-End Retrieval' - https://arxiv.org/abs/2104.00650
15
+ A PyTorch implementation of timesformer as described in
16
+ 'Is Space-Time Attention All You Need for Video Understanding?' - https://arxiv.org/abs/2102.05095
17
+ Acknowledgments:
18
+ - This code builds on Ross Wightman's vision_transformer code in pytorch-image-models:
19
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
20
+ - It is also inspired by lucidrains timesformer implementation:
21
+ https://github.com/lucidrains/TimeSformer-pytorch
22
+ Hacked together by Max Bain
23
+ """
24
+
25
+ from collections import OrderedDict
26
+ from functools import partial
27
+
28
+ import torch
29
+ import torch.utils.checkpoint as checkpoint
30
+ from einops import rearrange, repeat
31
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
32
+ from torch import einsum, nn
33
+
34
+
35
+ def attn(q, k, v):
36
+ sim = einsum('b i d, b j d -> b i j', q, k)
37
+ attn = sim.softmax(dim=-1)
38
+ out = einsum('b i j, b j d -> b i d', attn, v)
39
+ return out
40
+
41
+
42
+ class Mlp(nn.Module):
43
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
44
+ super().__init__()
45
+ out_features = out_features or in_features
46
+ hidden_features = hidden_features or in_features
47
+ self.fc1 = nn.Linear(in_features, hidden_features)
48
+ self.act = act_layer()
49
+ self.fc2 = nn.Linear(hidden_features, out_features)
50
+ self.drop = nn.Dropout(drop)
51
+
52
+ def forward(self, x):
53
+ x = self.fc1(x)
54
+ x = self.act(x)
55
+ x = self.drop(x)
56
+ x = self.fc2(x)
57
+ x = self.drop(x)
58
+ return x
59
+
60
+
61
+ class VideoPatchEmbed(nn.Module):
62
+ """ Video to Patch Embedding
63
+ """
64
+
65
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
66
+ num_frames=8, ln_pre=False):
67
+ super().__init__()
68
+ img_size = to_2tuple(img_size)
69
+ patch_size = to_2tuple(patch_size)
70
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * num_frames
71
+ self.img_size = img_size
72
+ self.patch_size = patch_size
73
+ self.num_patches = num_patches
74
+ self.num_frames = num_frames
75
+ self.embed_dim = embed_dim
76
+ # ln_pre is inserted to be compatible with CLIP-style model
77
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=not ln_pre)
78
+
79
+ def forward(self, x):
80
+ B, F, C, H, W = x.shape
81
+ assert F <= self.num_frames
82
+ x = x.view(-1, C, H, W)
83
+ x = self.proj(x)
84
+ return x
85
+
86
+
87
+ class VarAttention(nn.Module):
88
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
89
+ initialize='random'):
90
+ super().__init__()
91
+ self.num_heads = num_heads
92
+ head_dim = dim // num_heads
93
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
94
+ self.scale = qk_scale or head_dim ** -0.5
95
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
96
+ self.proj = nn.Linear(dim, dim)
97
+ if initialize == 'zeros':
98
+ self.qkv.weight.data.fill_(0)
99
+ self.qkv.bias.data.fill_(0)
100
+ # fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs
101
+ # are multiplied by 0*0, which is hard for the model to move out of.
102
+ self.proj.weight.data.fill_(1)
103
+ self.proj.bias.data.fill_(0)
104
+ self.attn_drop = nn.Dropout(attn_drop)
105
+ self.proj_drop = nn.Dropout(proj_drop)
106
+
107
+ def forward(self, x, einops_from, einops_to, einops_dims):
108
+ h = self.num_heads
109
+ # project x to q, k, v values
110
+ q, k, v = self.qkv(x).chunk(3, dim=-1)
111
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
112
+
113
+ q *= self.scale
114
+
115
+ # splice out CLS token at index 1
116
+ (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
117
+
118
+ # let CLS token attend to key / values of all patches across time and space
119
+ cls_out = attn(cls_q, k, v)
120
+ # rearrange across time or space
121
+ q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))
122
+
123
+ # expand cls token keys and values across time or space and concat
124
+ r = q_.shape[0] // cls_k.shape[0]
125
+ cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))
126
+
127
+ k_ = torch.cat((cls_k, k_), dim=1)
128
+ v_ = torch.cat((cls_v, v_), dim=1)
129
+
130
+ # attention
131
+ out = attn(q_, k_, v_)
132
+
133
+ # merge back time or space
134
+ out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
135
+
136
+ # concat back the cls token
137
+ out = torch.cat((cls_out, out), dim=1)
138
+
139
+ # merge back the heads
140
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
141
+ # to out
142
+ x = self.proj(out)
143
+ x = self.proj_drop(x)
144
+ return x
145
+
146
+
147
+ class SpaceTimeBlock(nn.Module):
148
+
149
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
150
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, time_init='zeros',
151
+ attention_style='frozen-in-time', is_tanh_gating=False):
152
+ super().__init__()
153
+ self.norm1 = norm_layer(dim)
154
+ self.attn = VarAttention(
155
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
156
+
157
+ self.timeattn = VarAttention(
158
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
159
+ initialize=time_init)
160
+
161
+ if is_tanh_gating:
162
+ self.alpha_timeattn = nn.Parameter(torch.zeros([]))
163
+
164
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
165
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
166
+ self.norm2 = norm_layer(dim)
167
+ mlp_hidden_dim = int(dim * mlp_ratio)
168
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
169
+ self.norm3 = norm_layer(dim)
170
+
171
+ self.attention_style = attention_style
172
+
173
+ def forward(self, x, einops_from_space, einops_to_space, einops_from_time, einops_to_time,
174
+ time_n, space_f, use_checkpoint=False):
175
+ if use_checkpoint:
176
+ time_output = checkpoint.checkpoint(
177
+ self.timeattn, self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}
178
+ )
179
+ else:
180
+ time_output = self.timeattn(self.norm3(x), einops_from_time, einops_to_time, {"n": time_n})
181
+ if hasattr(self, "alpha_timeattn"):
182
+ time_output = torch.tanh(self.alpha_timeattn) * time_output
183
+ time_residual = x + time_output
184
+ if use_checkpoint:
185
+ space_output = checkpoint.checkpoint(
186
+ self.attn, self.norm1(time_residual), einops_from_space, einops_to_space, {"f": space_f}
187
+ )
188
+ else:
189
+ space_output = self.attn(self.norm1(time_residual), einops_from_space,
190
+ einops_to_space, {"f": space_f})
191
+ if self.attention_style == 'frozen-in-time':
192
+ space_residual = x + self.drop_path(space_output)
193
+ else:
194
+ raise NotImplementedError
195
+
196
+ x = space_residual + self.drop_path(self.mlp(self.norm2(space_residual)))
197
+
198
+ return x
199
+
200
+
201
+ class SpaceTimeTransformer(nn.Module):
202
+ """ Vision Transformer
203
+ A PyTorch impl of : `Space-Time Transformer` from Frozen-in-time - by Max Bain.
204
+ https://arxiv.org/abs/2104.00650
205
+ Based off:
206
+ - ViT implementation from the timm library [https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py]
207
+ lucidrains timesformer implementation [https://github.com/lucidrains/TimeSformer-pytorch].
208
+ Notable differences:
209
+ - allows for variable length input frames (<= num_frames)
210
+ - allows for variable length input resolution (<= (img_size, img_size)) [UNTESTED]
211
+ - different attention block mechanism
212
+ """
213
+
214
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
215
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
216
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
217
+ num_frames=8, time_init='rand', attention_style='frozen-in-time', ln_pre=False,
218
+ act_layer=nn.GELU, is_tanh_gating=False):
219
+ """
220
+ Args:
221
+ img_size (int, tuple): input image size
222
+ patch_size (int, tuple): patch size
223
+ in_chans (int): number of input channels
224
+ num_classes (int): number of classes for classification head
225
+ embed_dim (int): embedding dimension
226
+ depth (int): depth of transformer
227
+ num_heads (int): number of attention heads
228
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
229
+ qkv_bias (bool): enable bias for qkv if True
230
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
231
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
232
+ drop_rate (float): dropout rate
233
+ attn_drop_rate (float): attention dropout rate
234
+ drop_path_rate (float): stochastic depth rate
235
+ hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
236
+ norm_layer: (nn.Module): normalization layer
237
+ num_frames: (int) maximum number of frames expected as input
238
+ time_init: (str) how to initialise the time attention layer, 'zeros' allows for the timesformer to start off
239
+ as ViT.
240
+ attention_style: (str) how to attend to space and time.
241
+ """
242
+ super().__init__()
243
+ self.num_classes = num_classes
244
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
245
+ self.num_frames = num_frames
246
+ self.embed_dim = embed_dim
247
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
248
+ print("######USING ATTENTION STYLE: ", attention_style)
249
+ if hybrid_backbone is not None:
250
+ raise NotImplementedError('hybrid backbone not implemented')
251
+ else:
252
+ self.patch_embed = VideoPatchEmbed(
253
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=num_frames, ln_pre=ln_pre)
254
+ num_patches = self.patch_embed.num_patches
255
+ self.patches_per_frame = num_patches // num_frames
256
+
257
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
258
+ self.pos_embed = nn.Parameter(
259
+ torch.zeros(1, self.patches_per_frame + 1,
260
+ embed_dim)) # remember to take pos_embed[1:] for tiling over time
261
+ self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
262
+
263
+ if ln_pre:
264
+ self.ln_pre = nn.LayerNorm(embed_dim)
265
+ else:
266
+ self.ln_pre = None
267
+
268
+ self.pos_drop = nn.Dropout(p=drop_rate)
269
+
270
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
271
+ self.blocks = nn.ModuleList([
272
+ SpaceTimeBlock(
273
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
274
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, time_init=time_init,
275
+ attention_style=attention_style, act_layer=act_layer, is_tanh_gating=is_tanh_gating)
276
+ for i in range(depth)])
277
+ self.norm = norm_layer(embed_dim)
278
+
279
+ # Representation layer
280
+ if representation_size:
281
+ self.num_features = representation_size
282
+ self.pre_logits = nn.Sequential(OrderedDict([
283
+ ('fc', nn.Linear(embed_dim, representation_size)),
284
+ ('act', nn.Tanh())
285
+ ]))
286
+ else:
287
+ self.pre_logits = nn.Identity()
288
+
289
+ # Classifier head
290
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
291
+
292
+ trunc_normal_(self.pos_embed, std=.02)
293
+ trunc_normal_(self.cls_token, std=.02)
294
+
295
+ # if num_frames > 1, the ViT weights are inflated and time attention is initialised to zero, so this init is not necessary.
296
+ if num_frames == 1:
297
+ self.apply(self._init_weights)
298
+
299
+ # einops transformations
300
+ self.einops_from_space = 'b (f n) d'
301
+ self.einops_to_space = '(b f) n d'
302
+ self.einops_from_time = 'b (f n) d'
303
+ self.einops_to_time = '(b n) f d'
304
+
305
+ def _init_weights(self, m):
306
+ if isinstance(m, nn.Linear):
307
+ trunc_normal_(m.weight, std=.02)
308
+ if isinstance(m, nn.Linear) and m.bias is not None:
309
+ nn.init.constant_(m.bias, 0)
310
+ elif isinstance(m, nn.LayerNorm):
311
+ nn.init.constant_(m.bias, 0)
312
+ nn.init.constant_(m.weight, 1.0)
313
+
314
+ @torch.jit.ignore
315
+ def no_weight_decay(self):
316
+ return {'pos_embed', 'cls_token'}
317
+
318
+ def get_classifier(self):
319
+ return self.head
320
+
321
+ def reset_classifier(self, num_classes, global_pool=''):
322
+ self.num_classes = num_classes
323
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
324
+
325
+ def freeze_spatial_weights(self):
326
+ freeze_list = []
327
+ for n, p in self.named_parameters():
328
+ if 'temporal_embed' in n or 'timeattn' in n or 'norm3' in n:
329
+ pass
330
+ else:
331
+ p.requires_grad = False
332
+ freeze_list.append(n)
333
+ print("Freeze the pretrained parts in vision model: {}".format(freeze_list))
334
+
335
+ def freeze_temporal_weights(self):
336
+ freeze_list = []
337
+ for n, p in self.named_parameters():
338
+ if 'temporal_embed' in n or 'timeattn' in n or 'norm3' in n:
339
+ p.requires_grad = False
340
+ freeze_list.append(n)
341
+ else:
342
+ pass
343
+ print("Freeze the pretrained parts in vision model: {}".format(freeze_list))
344
+
345
+ def forward_features(self, x, use_checkpoint=False, cls_at_last=True):
346
+ # print(x.shape)
347
+ b, curr_frames, channels, _, _ = x.shape
348
+ x = self.patch_embed(x)
349
+ x = x.flatten(2).transpose(2, 1)
350
+ x = x.reshape(b, -1, self.patch_embed.embed_dim)
351
+
352
+ BF = x.shape[0]
353
+ cls_tokens = self.cls_token.expand(BF, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
354
+ x = torch.cat((cls_tokens, x), dim=1)
355
+ # positional embed needs to be tiled for each frame (this does [1,2,3] --> [1,2,3,1,2,3]...)
356
+ cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
357
+ tile_pos_embed = self.pos_embed[:, 1:, :].repeat(1, self.num_frames, 1)
358
+ # temporal embed needs to be repeated within each frame (this does [1,2,3] --> [1,1,1,2,2,2,3,3,3]...)
359
+ tile_temporal_embed = self.temporal_embed.repeat_interleave(self.patches_per_frame, 1)
360
+ total_pos_embed = tile_pos_embed + tile_temporal_embed
361
+ total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
362
+
363
+ curr_patches = x.shape[1]
364
+ x = x + total_pos_embed[:, :curr_patches]
365
+ if self.ln_pre is not None:
366
+ x = self.ln_pre(x)
367
+ x = self.pos_drop(x)
368
+ n = self.patches_per_frame
369
+ f = curr_frames
370
+
371
+ for blk in self.blocks:
372
+ x = blk(x, self.einops_from_space, self.einops_to_space, self.einops_from_time,
373
+ self.einops_to_time,
374
+ time_n=n, space_f=f, use_checkpoint=use_checkpoint)
375
+
376
+ if cls_at_last:
377
+ x = self.norm(x)[:, 0]
378
+ x = self.pre_logits(x)
379
+
380
+ return x
381
+ else:
382
+ return self.norm(x)
383
+
384
+ def forward(self, x, use_checkpoint=False):
385
+ # Note: B C T H W => B T C H W
386
+ # The default input order is different from the one in Frozen-in-Time
387
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
388
+ x = self.forward_features(x, use_checkpoint=use_checkpoint)
389
+ x = self.head(x)
390
+ return x
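As a sanity check on the expected input contract, a sketch with a deliberately tiny configuration (not part of the commit; it assumes timm and einops are installed):

import torch
from lavila.models.timesformer import SpaceTimeTransformer

vit = SpaceTimeTransformer(img_size=224, patch_size=16, embed_dim=192, depth=2,
                           num_heads=3, num_frames=4, time_init='zeros', num_classes=0)
clip = torch.randn(1, 3, 4, 224, 224)   # B C T H W; forward() permutes to B T C H W internally
with torch.no_grad():
    feats = vit(clip)
print(feats.shape)                       # [1, 192]: the CLS feature, since num_classes=0 makes the head an Identity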
lavila/models/tokenizer.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
8
+ # Modified by Yue Zhao
9
+ # The original code is under MIT License
10
+
11
+ import gzip
12
+ import html
13
+ import os
14
+ from functools import lru_cache
15
+
16
+ import ftfy
17
+ import regex as re
18
+ import torch
19
+
20
+ from transformers import (BertTokenizer, DistilBertTokenizer, GPT2Tokenizer)
21
+
22
+
23
+ @lru_cache()
24
+ def default_bpe():
25
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
26
+
27
+
28
+ @lru_cache()
29
+ def bytes_to_unicode():
30
+ """
31
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
32
+ The reversible bpe codes work on unicode strings.
33
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
34
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
35
+ This is a significant percentage of your normal, say, 32K bpe vocab.
36
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
37
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
38
+ """
39
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
40
+ cs = bs[:]
41
+ n = 0
42
+ for b in range(2**8):
43
+ if b not in bs:
44
+ bs.append(b)
45
+ cs.append(2**8+n)
46
+ n += 1
47
+ cs = [chr(n) for n in cs]
48
+ return dict(zip(bs, cs))
49
+
50
+
51
+ def get_pairs(word):
52
+ """Return set of symbol pairs in a word.
53
+ Word is represented as tuple of symbols (symbols being variable-length strings).
54
+ """
55
+ pairs = set()
56
+ prev_char = word[0]
57
+ for char in word[1:]:
58
+ pairs.add((prev_char, char))
59
+ prev_char = char
60
+ return pairs
61
+
62
+
63
+ def basic_clean(text):
64
+ text = ftfy.fix_text(text)
65
+ text = html.unescape(html.unescape(text))
66
+ return text.strip()
67
+
68
+
69
+ def whitespace_clean(text):
70
+ text = re.sub(r'\s+', ' ', text)
71
+ text = text.strip()
72
+ return text
73
+
74
+
75
+ class SimpleTokenizer(object):
76
+ def __init__(self, bpe_path: str = default_bpe()):
77
+ self.byte_encoder = bytes_to_unicode()
78
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
79
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
80
+ merges = merges[1:49152-256-2+1]
81
+ merges = [tuple(merge.split()) for merge in merges]
82
+ vocab = list(bytes_to_unicode().values())
83
+ vocab = vocab + [v+'</w>' for v in vocab]
84
+ for merge in merges:
85
+ vocab.append(''.join(merge))
86
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
87
+ self.encoder = dict(zip(vocab, range(len(vocab))))
88
+ self.decoder = {v: k for k, v in self.encoder.items()}
89
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
90
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
91
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
92
+
93
+ def bpe(self, token):
94
+ if token in self.cache:
95
+ return self.cache[token]
96
+ word = tuple(token[:-1]) + (token[-1] + '</w>',)
97
+ pairs = get_pairs(word)
98
+
99
+ if not pairs:
100
+ return token+'</w>'
101
+
102
+ while True:
103
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
104
+ if bigram not in self.bpe_ranks:
105
+ break
106
+ first, second = bigram
107
+ new_word = []
108
+ i = 0
109
+ while i < len(word):
110
+ try:
111
+ j = word.index(first, i)
112
+ new_word.extend(word[i:j])
113
+ i = j
114
+ except ValueError:  # 'first' no longer occurs in the remaining word
115
+ new_word.extend(word[i:])
116
+ break
117
+
118
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
119
+ new_word.append(first+second)
120
+ i += 2
121
+ else:
122
+ new_word.append(word[i])
123
+ i += 1
124
+ new_word = tuple(new_word)
125
+ word = new_word
126
+ if len(word) == 1:
127
+ break
128
+ else:
129
+ pairs = get_pairs(word)
130
+ word = ' '.join(word)
131
+ self.cache[token] = word
132
+ return word
133
+
134
+ def encode(self, text):
135
+ bpe_tokens = []
136
+ text = whitespace_clean(basic_clean(text)).lower()
137
+ for token in re.findall(self.pat, text):
138
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
139
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
140
+ return bpe_tokens
141
+
142
+ def decode(self, tokens):
143
+ text = ''.join([self.decoder[token] for token in tokens])
144
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
145
+ return text
146
+
147
+ def __call__(self, texts, context_length=77):
148
+ if isinstance(texts, str):
149
+ texts = [texts]
150
+
151
+ sot_token = self.encoder["<|startoftext|>"]
152
+ eot_token = self.encoder["<|endoftext|>"]
153
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
154
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
155
+
156
+ for i, tokens in enumerate(all_tokens):
157
+ tokens = tokens[:context_length]
158
+ result[i, :len(tokens)] = torch.tensor(tokens)
159
+
160
+ if len(result) == 1:
161
+ return result[0]
162
+ return result
163
+
164
+
165
+ class MyBertTokenizer(object):
166
+ def __init__(self, name=''):
167
+ print('=> Initialize MyBertTokenizer ({})'.format(name))
168
+ self.tokenizer = BertTokenizer.from_pretrained(name)
169
+ self.bos_token_id, self.eos_token_id = self.tokenizer('').input_ids
170
+ self.pad_token_id = 0
171
+
172
+ def __call__(self, texts, context_length=77):
173
+ if isinstance(texts, str):
174
+ texts = [texts]
175
+ result = torch.zeros(len(texts), context_length, dtype=torch.long)
176
+ mask = torch.zeros(len(texts), context_length, dtype=torch.float32)
177
+ for i, text in enumerate(texts):
178
+ tokens = self.tokenizer(text)
179
+ input_ids = tokens.input_ids[:context_length]
180
+ attention_mask = tokens.attention_mask[:context_length]
181
+ result[i, :len(input_ids)] = torch.tensor(input_ids)
182
+ mask[i, :len(attention_mask)] = torch.tensor(attention_mask)
183
+
184
+ if len(result) == 1:
185
+ return result[0], mask[0]
186
+ return result, mask
187
+
188
+
189
+ class MyDistilBertTokenizer(object):
190
+ def __init__(self, name=''):
191
+ print('=> Initialize MyDistilBertTokenizer ({})'.format(name))
192
+ self.tokenizer = DistilBertTokenizer.from_pretrained(name)
193
+
194
+ def __call__(self, texts, context_length=77):
195
+ if isinstance(texts, str):
196
+ texts = [texts]
197
+ result = torch.zeros(len(texts), context_length, dtype=torch.long)
198
+ mask = torch.zeros(len(texts), context_length, dtype=torch.float32)
199
+ for i, text in enumerate(texts):
200
+ tokens = self.tokenizer(text)
201
+ input_ids = tokens.input_ids[:context_length]
202
+ attention_mask = tokens.attention_mask[:context_length]
203
+ result[i, :len(input_ids)] = torch.tensor(input_ids)
204
+ mask[i, :len(attention_mask)] = torch.tensor(attention_mask)
205
+
206
+ if len(result) == 1:
207
+ return result[0], mask[0]
208
+ return result, mask
209
+
210
+
211
+ class MyGPT2Tokenizer(object):
212
+ def __init__(self, name='', add_bos=False):
213
+ print('=> Initialize MyGPT2Tokenizer ({})'.format(name))
214
+ self.tokenizer = GPT2Tokenizer.from_pretrained(name)
215
+ self.bos_token_id, self.eos_token_id = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
216
+ self.pad_token_id = 0
217
+ self.add_bos = add_bos
218
+ # num_added_tokens = self.tokenizer.add_special_tokens({'pad_token': "[PAD]"})
219
+ # print('num_added_tokens={}'.format(len(num_added_tokens)))
220
+
221
+ def __call__(self, texts, context_length=77):
222
+ if isinstance(texts, str):
223
+ texts = [texts]
224
+ result = torch.zeros(len(texts), context_length, dtype=torch.long)
225
+ for i, text in enumerate(texts):
226
+ tokens = self.tokenizer(text)
227
+ if not self.add_bos:
228
+ input_ids = tokens.input_ids[:context_length - 1]
229
+ input_ids = input_ids + [self.tokenizer.eos_token_id] # add [EOS]
230
+ else:
231
+ input_ids = tokens.input_ids[:context_length - 2]
232
+ input_ids = [self.tokenizer.bos_token_id] + input_ids + [self.tokenizer.eos_token_id] # add [EOS]
233
+ # attention_mask = tokens.attention_mask[:context_length]
234
+ # attention_mask = attention_mask + [0.] * pad_length
235
+ result[i, :len(input_ids)] = torch.tensor(input_ids)
236
+
237
+ if len(result) == 1:
238
+ return result[0]
239
+ return result
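A quick sketch of the CLIP-style BPE tokenizer defined above (illustrative only; the merges file ships with the repo as bpe_simple_vocab_16e6.txt.gz):

from lavila.models.tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                       # loads the bundled BPE merges
ids = tokenizer("C cuts the bread", context_length=77)
print(ids.shape)                                    # torch.Size([77]) for a single input string
print(tokenizer.decode([t for t in ids.tolist() if t != 0]))
# roughly '<|startoftext|>c cuts the bread <|endoftext|>' (lower-cased, padding stripped)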
lavila/models/utils.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import OrderedDict
8
+ import functools
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def inflate_positional_embeds(
14
+ current_model_state_dict, new_state_dict,
15
+ num_frames=4,
16
+ load_temporal_fix='bilinear',
17
+ ):
18
+ # allow loading of timesformer with fewer num_frames
19
+ curr_keys = list(current_model_state_dict.keys())
20
+ if 'visual.temporal_embed' in new_state_dict and 'visual.temporal_embed' in curr_keys:
21
+ load_temporal_embed = new_state_dict['visual.temporal_embed']
22
+ load_num_frames = load_temporal_embed.shape[1]
23
+ curr_num_frames = num_frames
24
+ embed_dim = load_temporal_embed.shape[2]
25
+
26
+ if load_num_frames != curr_num_frames:
27
+ if load_num_frames > curr_num_frames:
28
+ print(f'### loaded SpaceTimeTransformer model has MORE frames than current... '
29
+ f'### loading weights, filling in the extras via {load_temporal_fix}')
30
+ new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :]
31
+ else:
32
+ print(f'### loaded SpaceTimeTransformer model has FEWER frames than current... '
33
+ f'### loading weights, filling in the extras via {load_temporal_fix}')
34
+ if load_temporal_fix == 'zeros':
35
+ new_temporal_embed = torch.zeros([load_temporal_embed.shape[0], curr_num_frames, embed_dim])
36
+ new_temporal_embed[:, :load_num_frames] = load_temporal_embed
37
+ elif load_temporal_fix in ['interp', 'bilinear']:
38
+ # interpolate
39
+ # unsqueeze so pytorch thinks its an image
40
+ mode = 'nearest'
41
+ if load_temporal_fix == 'bilinear':
42
+ mode = 'bilinear'
43
+ load_temporal_embed = load_temporal_embed.unsqueeze(0)
44
+ new_temporal_embed = F.interpolate(load_temporal_embed,
45
+ (curr_num_frames, embed_dim), mode=mode).squeeze(0)
46
+ else:
47
+ raise NotImplementedError
48
+ new_state_dict['visual.temporal_embed'] = new_temporal_embed
49
+ # allow loading with smaller spatial patches. assumes custom border crop, to append the
50
+ # border patches to the input sequence
51
+ if 'visual.pos_embed' in new_state_dict and 'visual.pos_embed' in curr_keys:
52
+ load_pos_embed = new_state_dict['visual.pos_embed']
53
+ load_num_patches = load_pos_embed.shape[1]
54
+ curr_pos_embed = current_model_state_dict['visual.pos_embed']
55
+ if load_num_patches != curr_pos_embed.shape[1]:
56
+ raise NotImplementedError(
57
+ 'Loading models with different spatial resolution / patch number not yet implemented, sorry.')
58
+
59
+ return new_state_dict
60
+
61
+
62
+ def rsetattr(obj, attr, val):
63
+ pre, _, post = attr.rpartition('.')
64
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
65
+
66
+
67
+ def rgetattr(obj, attr, *args):
68
+ def _getattr(obj, attr):
69
+ return getattr(obj, attr, *args)
70
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
71
+
72
+
73
+ # util functions to convert CLIP-style model keys to TimeSformer-style
74
+ def remap_keys(clip_state_dict, transformer_layers=12):
75
+ remapped_state_dict = OrderedDict()
76
+ key_mapping = {
77
+ "class_embedding": "cls_token",
78
+ "positional_embedding": "pos_embed",
79
+ "conv1.weight": "patch_embed.proj.weight",
80
+ "ln_pre.weight": "ln_pre.weight",
81
+ "ln_pre.bias": "ln_pre.bias",
82
+ "ln_post.weight": "norm.weight",
83
+ "ln_post.bias": "norm.bias",
84
+ }
85
+ for layer in range(transformer_layers):
86
+ key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_weight"] = f"blocks.{layer}.attn.qkv.weight"
87
+ key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_bias"] = f"blocks.{layer}.attn.qkv.bias"
88
+ key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.weight"] = f"blocks.{layer}.attn.proj.weight"
89
+ key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.bias"] = f"blocks.{layer}.attn.proj.bias"
90
+ key_mapping[f"transformer.resblocks.{layer}.ln_1.weight"] = f"blocks.{layer}.norm1.weight"
91
+ key_mapping[f"transformer.resblocks.{layer}.ln_1.bias"] = f"blocks.{layer}.norm1.bias"
92
+ key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.weight"] = f"blocks.{layer}.mlp.fc1.weight"
93
+ key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.bias"] = f"blocks.{layer}.mlp.fc1.bias"
94
+ key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.weight"] = f"blocks.{layer}.mlp.fc2.weight"
95
+ key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.bias"] = f"blocks.{layer}.mlp.fc2.bias"
96
+ key_mapping[f"transformer.resblocks.{layer}.ln_2.weight"] = f"blocks.{layer}.norm2.weight"
97
+ key_mapping[f"transformer.resblocks.{layer}.ln_2.bias"] = f"blocks.{layer}.norm2.bias"
98
+
99
+ for key in clip_state_dict:
100
+ if key == 'proj':
101
+ continue # due to possible dim mismatch, we load this later
102
+ if key == "class_embedding":
103
+ clip_state_dict[key] = clip_state_dict[key].unsqueeze(0).unsqueeze(0)
104
+ if key == "positional_embedding":
105
+ clip_state_dict[key] = clip_state_dict[key].unsqueeze(0)
106
+ remapped_state_dict[key_mapping[key]] = clip_state_dict[key]
107
+
108
+ return remapped_state_dict
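To make the intent of remap_keys concrete, a hedged sketch that converts the visual tower of an OpenAI CLIP checkpoint into the TimeSformer-style key names used by SpaceTimeTransformer (loading the result into a model is left out; dims and layer count follow ViT-B/16):

from lavila.models.openai_clip import load
from lavila.models.utils import remap_keys

clip_model, _ = load("ViT-B/16", device="cpu")
visual_sd = clip_model.visual.state_dict()                # keys like 'conv1.weight', 'transformer.resblocks.0.ln_1.weight'
remapped = remap_keys(visual_sd, transformer_layers=12)   # -> 'patch_embed.proj.weight', 'blocks.0.norm1.weight', ... ('proj' is skipped)
print(sorted(remapped.keys())[:4])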
lavila/utils/distributed.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import shutil
9
+ import torch
10
+ import torch.distributed as dist
11
+
12
+
13
+ def get_model(model):
14
+ if isinstance(model, torch.nn.DataParallel) \
15
+ or isinstance(model, torch.nn.parallel.DistributedDataParallel):
16
+ return model.module
17
+ else:
18
+ return model
19
+
20
+
21
+ def setup_for_distributed(is_master):
22
+ """
23
+ This function disables printing when not in master process
24
+ """
25
+ import builtins as __builtin__
26
+ builtin_print = __builtin__.print
27
+
28
+ def print(*args, **kwargs):
29
+ force = kwargs.pop('force', False)
30
+ if is_master or force:
31
+ builtin_print(*args, **kwargs)
32
+
33
+ __builtin__.print = print
34
+
35
+
36
+ def is_dist_avail_and_initialized():
37
+ if not dist.is_available():
38
+ return False
39
+ if not dist.is_initialized():
40
+ return False
41
+ return True
42
+
43
+
44
+ def get_world_size():
45
+ if not is_dist_avail_and_initialized():
46
+ return 1
47
+ else:
48
+ return dist.get_world_size()
49
+
50
+
51
+ def get_rank():
52
+ if not is_dist_avail_and_initialized():
53
+ return 0
54
+ return dist.get_rank()
55
+
56
+
57
+ def is_main_process():
58
+ return get_rank() == 0
59
+
60
+
61
+ def save_on_master(state, is_best, output_dir, is_epoch=True):
62
+ if is_main_process():
63
+ ckpt_path = f'{output_dir}/checkpoint.pt'
64
+ best_path = f'{output_dir}/checkpoint_best.pt'
65
+ if is_best:
66
+ torch.save(state, best_path)
67
+ if is_epoch:
68
+ if isinstance(state['epoch'], int):
69
+ ckpt2_path = '{}/checkpoint_{:04d}.pt'.format(output_dir, state['epoch'])
70
+ else:
71
+ ckpt2_path = '{}/checkpoint_{:.4f}.pt'.format(output_dir, state['epoch'])
72
+ torch.save(state, ckpt_path)
73
+ shutil.copy(ckpt_path, ckpt2_path)
74
+
75
+
76
+ def init_distributed_mode(args):
77
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
78
+ args.rank = int(os.environ["RANK"])
79
+ args.world_size = int(os.environ['WORLD_SIZE'])
80
+ args.gpu = int(os.environ['LOCAL_RANK'])
81
+ elif 'SLURM_PROCID' in os.environ:
82
+ args.rank = int(os.environ['SLURM_PROCID'])
83
+ args.gpu = args.rank % torch.cuda.device_count()
84
+ else:
85
+ print('Not using distributed mode')
86
+ args.distributed = False
87
+ return
88
+
89
+ args.distributed = True
90
+
91
+ torch.cuda.set_device(args.gpu)
92
+ args.dist_backend = 'nccl'
93
+ print('| distributed init (rank {}): {}'.format(
94
+ args.rank, args.dist_url), flush=True)
95
+ torch.distributed.init_process_group(
96
+ backend=args.dist_backend,
97
+ init_method=args.dist_url,
98
+ world_size=args.world_size,
99
+ rank=args.rank
100
+ )
101
+ torch.distributed.barrier()
102
+ setup_for_distributed(args.rank == 0)
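A minimal sketch of how `init_distributed_mode` is typically driven: under `torchrun` the `RANK`/`WORLD_SIZE`/`LOCAL_RANK` environment variables select the NCCL path, and `args.dist_url` is usually `env://`. The `argparse.Namespace` below is an illustrative stand-in for a training script's parsed arguments; run as a single process it simply falls back to non-distributed mode.

```python
import argparse
from lavila.utils.distributed import init_distributed_mode, is_main_process, get_world_size

args = argparse.Namespace(dist_url='env://', distributed=False, rank=0, world_size=1, gpu=0)
init_distributed_mode(args)  # prints 'Not using distributed mode' when RANK/WORLD_SIZE are unset
if is_main_process():
    print('world size:', get_world_size())  # 1 in the single-process fallback

# Under torchrun, RANK/WORLD_SIZE/LOCAL_RANK are set automatically, the same call
# initialises the NCCL process group, and print() is silenced on non-master ranks.
```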
lavila/utils/evaluation.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def accuracy(output, target, topk=(1,)):
12
+ """Computes the accuracy over the k top predictions for the specified values of k"""
13
+ with torch.no_grad():
14
+ maxk = max(topk)
15
+ batch_size = target.size(0)
16
+
17
+ _, pred = output.topk(maxk, 1, True, True)
18
+ pred = pred.t()
19
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
20
+
21
+ res = []
22
+ for k in topk:
23
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
24
+ res.append(correct_k.mul_(100.0 / batch_size))
25
+ return res
26
+
27
+
28
+ def get_mean_accuracy(cm):
29
+ list_acc = []
30
+ for i in range(len(cm)):
31
+ acc = 0
32
+ if cm[i, :].sum() > 0:
33
+ acc = cm[i, i] / cm[i, :].sum()
34
+ list_acc.append(acc)
35
+
36
+ return 100 * np.mean(list_acc), 100 * np.trace(cm) / np.sum(cm)
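A tiny worked example for `accuracy` (synthetic logits, 3 samples over 4 classes):

```python
import torch
from lavila.utils.evaluation import accuracy

logits = torch.tensor([[0.1, 0.9, 0.0, 0.0],   # top-1 = class 1 (correct)
                       [0.8, 0.6, 0.1, 0.0],   # top-1 = class 0, target is 1 (top-2 correct)
                       [0.1, 0.2, 0.3, 0.4]])  # top-1 = class 3, target is 0 (wrong)
targets = torch.tensor([1, 1, 0])
top1, top2 = accuracy(logits, targets, topk=(1, 2))
print(f'{top1.item():.1f} {top2.item():.1f}')  # 33.3 66.7
```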
lavila/utils/evaluation_charades.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+
9
+
10
+ def compute_map(submission_array, gt_array):
11
+ """ Returns mAP, weighted mAP, and AP array """
12
+ m_aps = []
13
+ n_classes = submission_array.shape[1]
14
+ for oc_i in range(n_classes):
15
+ sorted_idxs = np.argsort(-submission_array[:, oc_i])
16
+ tp = gt_array[:, oc_i][sorted_idxs] == 1
17
+ fp = np.invert(tp)
18
+ n_pos = tp.sum()
19
+ if n_pos < 0.1:
20
+ m_aps.append(float('nan'))
21
+ continue
22
+ fp.sum()
23
+ f_pcs = np.cumsum(fp)
24
+ t_pcs = np.cumsum(tp)
25
+ prec = t_pcs / (f_pcs+t_pcs).astype(float)
26
+ avg_prec = 0
27
+ for i in range(submission_array.shape[0]):
28
+ if tp[i]:
29
+ avg_prec += prec[i]
30
+ m_aps.append(avg_prec / n_pos.astype(float))
31
+ m_aps = np.array(m_aps)
32
+ m_ap = np.mean(m_aps)
33
+ w_ap = (m_aps * gt_array.sum(axis=0) / gt_array.sum().sum().astype(float))
34
+ return m_ap, w_ap, m_aps
35
+
36
+
37
+ def charades_map(submission_array, gt_array):
38
+ """
39
+ Approximate version of the charades evaluation function
40
+ For precise numbers, use the submission file with the official matlab script
41
+ """
42
+ fix = submission_array.copy()
43
+ empty = np.sum(gt_array, axis=1) == 0
44
+ fix[empty, :] = np.NINF
45
+ return compute_map(fix, gt_array)
46
+
47
+
48
+ def create_submission(video_list, predictions, out_file):
49
+ assert len(video_list) == predictions.shape[0]
50
+ with open(out_file, 'w') as f:
51
+ for i, video_id in enumerate(video_list):
52
+ pred_str = ' '.join(map(lambda x: str(x), predictions[i].tolist()))
53
+ f.write('{} {}\n\n'.format(video_id, pred_str))
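A small synthetic example for `charades_map` (3 clips, 2 action classes, multi-label ground truth):

```python
import numpy as np
from lavila.utils.evaluation_charades import charades_map

scores = np.array([[0.9, 0.2],
                   [0.7, 0.8],
                   [0.1, 0.6]])
gt = np.array([[1, 0],
               [0, 1],
               [1, 1]])
m_ap, weighted_ap, per_class_ap = charades_map(scores, gt)
print(round(m_ap, 3))         # 0.917
print(per_class_ap.round(3))  # [0.833 1.   ]
```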
lavila/utils/evaluation_egomcq.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+
10
+ def egomcq_accuracy_metrics(preds, labels, types):
11
+ metrics = {}
12
+ type_list = torch.unique(types)
13
+ group_list = ["Intra-video", "Inter-video"]
14
+ for type_i, group_i in zip(type_list, group_list):
15
+ correct = 0
16
+ total = 0
17
+ for pred, label, type in zip(preds, labels, types):
18
+ if type == type_i:
19
+ pred_ = torch.argmax(pred)
20
+ if pred_.item() == label.item():
21
+ correct += 1
22
+ total += 1
23
+ accuracy = correct/total
24
+ metrics[group_i] = accuracy * 100
25
+ return metrics
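An illustrative call to `egomcq_accuracy_metrics` with four synthetic five-way multiple-choice predictions. Note that the helper pairs the sorted unique type ids with `["Intra-video", "Inter-video"]` in that order, so the smaller id is reported under "Intra-video".

```python
import torch
from lavila.utils.evaluation_egomcq import egomcq_accuracy_metrics

preds = torch.tensor([[0.9, 0.1, 0.0, 0.0, 0.0],
                      [0.1, 0.7, 0.1, 0.1, 0.0],
                      [0.2, 0.2, 0.5, 0.1, 0.0],
                      [0.6, 0.2, 0.1, 0.05, 0.05]])
labels = torch.tensor([0, 1, 3, 0])
types = torch.tensor([1, 1, 2, 2])
print(egomcq_accuracy_metrics(preds, labels, types))
# {'Intra-video': 100.0, 'Inter-video': 50.0}
```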
lavila/utils/evaluation_ek100cls.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from https://github.com/fpv-iplab/rulstm/blob/master/RULSTM/utils.py
8
+ # Modified by Yue Zhao
9
+
10
+ import numpy as np
11
+
12
+
13
+ def get_marginal_indexes(actions, mode):
14
+ """For each verb/noun retrieve the list of actions containing that verb/noun
15
+ Input:
16
+ mode: "verb" or "noun"
17
+ Output:
18
+ a list of numpy array of indexes. If verb/noun 3 is contained in actions 2,8,19,
19
+ then output[3] will be np.array([2,8,19])
20
+ """
21
+ vi = []
22
+ for v in range(actions[mode].max()+1):
23
+ vals = actions[actions[mode] == v].index.values
24
+ if len(vals) > 0:
25
+ vi.append(vals)
26
+ else:
27
+ vi.append(np.array([0]))
28
+ return vi
29
+
30
+
31
+ def marginalize(probs, indexes):
32
+ mprobs = []
33
+ for ilist in indexes:
34
+ mprobs.append(probs[:, ilist].sum(1))
35
+ return np.array(mprobs).T
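A short sketch of marginalising action probabilities into verb probabilities. The `actions` DataFrame below mimics the EPIC-Kitchens action list (one row per action class, indexed by action id, with `verb` and `noun` columns); the numbers are synthetic.

```python
import numpy as np
import pandas as pd
from lavila.utils.evaluation_ek100cls import get_marginal_indexes, marginalize

actions = pd.DataFrame({'verb': [0, 0, 1], 'noun': [0, 1, 1]})  # 3 action classes
verb_indexes = get_marginal_indexes(actions, 'verb')  # [array([0, 1]), array([2])]

action_probs = np.array([[0.2, 0.3, 0.5],
                         [0.6, 0.1, 0.3]])  # 2 clips x 3 action classes
print(marginalize(action_probs, verb_indexes))
# [[0.5 0.5]
#  [0.7 0.3]]
```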
lavila/utils/evaluation_ek100mir.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Part of the code is from
8
+ # `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/NDCG.py`
9
+ # and
10
+ # `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/mAP.py`
11
+ # Modified by Yue Zhao
12
+
13
+ import numpy as np
14
+
15
+
16
+ def calculate_DCG(similarity_matrix, relevancy_matrix, k_counts):
17
+ """
18
+ Calculates the Discounted Cumulative Gain (DCG) between two modalities for
19
+ the first modality.
20
+ DCG = \sum_{i=1}^k \frac{rel_i}{log_2(i + 1)}
21
+ i.e. the sum of the k relevant retrievals which is calculated as the scaled
22
+ relevancy for the ith item. The scale is designed such that early
23
+ retrievals are more important than later retrievals.
24
+ Params:
25
+ - similarity_matrix: matrix of size n1 x n2 where n1 is the number of
26
+ items in the first modality and n2 is the number of items in the
27
+ second modality. The [ith,jth] element is the predicted similarity
28
+ between the ith item from the first modality and the jth item from
29
+ the second modality.
30
+ - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix
31
+ above). The [ith, jth] element is the semantic relevancy between the
32
+ ith item from the first modality and the jth item from the second
33
+ modality.
34
+ - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which
35
+ includes information on which items to use to calculate the DCG for
36
+ (see calculate_k_counts for more info on this matrix).
37
+ Returns:
38
+ - The DCG for each item in the first modality, a n1 length vector.
39
+ """
40
+ x_sz, y_sz = similarity_matrix.shape
41
+ ranks = np.argsort(similarity_matrix)[:, ::-1]
42
+ # Create vector of size (n,) where n is the length of the last dimension in
43
+ # similarity matrix
44
+ # This vector is of the form log(i+1)
45
+ logs = np.log2(np.arange(y_sz) + 2)
46
+ # Convert logs into the divisor for the DCG calculation, of size similarity
47
+ # matrix
48
+ divisors = np.repeat(np.expand_dims(logs, axis=0), x_sz, axis=0)
49
+
50
+ # mask out the sorted relevancy matrix to only use the first k relevant
51
+ # retrievals for each item.
52
+ columns = np.repeat(np.expand_dims(np.arange(x_sz), axis=1), y_sz, axis=1)
53
+ numerators = relevancy_matrix[columns, ranks] * k_counts
54
+ # Calculate the final DCG score (note that this isn't expected to sum to 1)
55
+ return np.sum(numerators / divisors, axis=1)
56
+
57
+
58
+ def calculate_k_counts(relevancy_matrix):
59
+ """
60
+ Works out the maximum number of allowed retrievals when working out the
61
+ Discounted Cumulative Gain. For each query the DCG only uses the first k
62
+ items retrieved which constitute the k relevant items for that query
63
+ (otherwise the nDCG scores can be deceptively high for bad rankings).
64
+ Params:
65
+ - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of
66
+ items in the first modality and n2 is the number of items in the
67
+ second modality. The [ith, jth] element is the semantic relevancy
68
+ between the ith item from the first modality and the jth item from
69
+ the second modality.
70
+ Returns:
71
+ - Matrix of size n1 x n2 (see relevancy matrix for more info). This is
72
+ created as a mask such that if the [ith, jth] element is 1 it
73
+ represents a valid item to use for the calculation of DCG for the
74
+ ith item after sorting. For example, if relevancy matrix of:
75
+ [[1, 0.5, 0],
76
+ [0, 0 , 1]]
77
+ is given, then the k_counts matrix will be:
78
+ [[1, 1, 0],
79
+ [1, 0, 0]]
80
+ i.e. the first row has 2 non-zero items, so the first two retrieved
81
+ items should be used in the calculation. In the second row there is
82
+ only 1 relevant item, therefore only the first retrieved item should
83
+ be used for the DCG calculation.
84
+ """
85
+ return (np.sort(relevancy_matrix)[:, ::-1] > 0).astype(int)
86
+
87
+
88
+ def calculate_IDCG(relevancy_matrix, k_counts):
89
+ """
90
+ Calculates the Ideal Discounted Cumulative Gain (IDCG) which is the value
91
+ of the Discounted Cumulative Gain (DCG) for a perfect retrieval, i.e. the
92
+ items in the second modality were retrieved in order of their descending
93
+ relevancy.
94
+ Params:
95
+ - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of
96
+ items in the first modality and n2 is the number of items in the
97
+ second modality. The [ith, jth] element is the semantic relevancy
98
+ between the ith item from the first modality and the jth item from
99
+ the second modality.
100
+ - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which
101
+ includes information on which items to use to calculate the DCG for
102
+ (see calculate_k_counts for more info on this matrix).
103
+ """
104
+ return calculate_DCG(relevancy_matrix, relevancy_matrix, k_counts)
105
+
106
+
107
+ def calculate_nDCG(similarity_matrix, relevancy_matrix, k_counts=None, IDCG=None, reduction='mean'):
108
+ """
109
+ Calculates the normalised Discounted Cumulative Gain (nDCG) between two
110
+ modalities for the first modality using the Discounted Cumulative Gain
111
+ (DCG) and the Ideal Discounted Cumulative Gain (IDCG).
112
+ nDCG = \frac{DCG}{IDCG}
113
+ Params:
114
+ - similarity_matrix: matrix of size n1 x n2 where n1 is the number of
115
+ items in the first modality and n2 is the number of items in the second
116
+ modality. The [ith,jth] element is the predicted similarity between
117
+ the ith item from the first modality and the jth item from the second
118
+ modality.
119
+ - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix
120
+ above). The [ith, jth] element is the semantic relevancy between the
121
+ ith item from the first modality and the jth item from the second
122
+ modality.
123
+ - k_counts: optional parameter: matrix of size n1 x n2 (see
124
+ similarity_matrix above) which includes information on which items to
125
+ use to calculate the DCG for (see calculate_k_counts for more info on
126
+ this matrix). This will be calculated using calculate_IDCG if not
127
+ present, but should be pre-processed for efficiency.
128
+ - IDCG: Optional parameter which includes the pre-processed Ideal
129
+ Discounted Cumulative Gain (IDCG). This is a vector of size n1 (see
130
+ similarity_matrix above) which contains the IDCG value for each item
131
+ from the first modality. This will be calculated using calculate_IDCG
132
+ if not present, but should be pre-processed for efficiency.
133
+ - reduction: what to use to reduce the different nDCG scores. By
134
+ default this applies np.mean across all different queries.
135
+ Returns:
136
+ - The nDCG values for the first modality.
137
+ """
138
+ if k_counts is None:
139
+ k_counts = calculate_k_counts(relevancy_matrix)
140
+ DCG = calculate_DCG(similarity_matrix, relevancy_matrix, k_counts)
141
+ if IDCG is None:
142
+ IDCG = calculate_IDCG(relevancy_matrix, k_counts)
143
+ if reduction == 'mean':
144
+ return np.mean(DCG / IDCG)
145
+ elif reduction is None:
146
+ return DCG / IDCG
147
+
148
+
149
+ def calculate_mAP(sim_mat, relevancy_matrix):
150
+ """
151
+ Computes the mean average precision according to the following formula of
152
+ average precision:
153
+ \frac{\sum_{k=1}^n p(k) x rel(k)}{num_rel_docs}
154
+ where p(k) is the precision at k, rel(k) is an indicator function
155
+ determining whether the kth returned item is relevant or not and
156
+ num_rel_docs is the number of relevant items to find within the search.
157
+ The mean average precision is the mean of the average precision for each
158
+ query item (i.e row in the matrix)
159
+ This function takes in two parameters:
160
+ - sim_mat: a NxM matrix which represents the similarity between two
161
+ modalities (with modality 1 being of size N and modality 2 of size M).
162
+ - relevancy_matrix: an NxM matrix which represents the relevancy between two
163
+ modalities of items (with modality 1 being of size N and modality 2 of
164
+ size M).
165
+ """
166
+ # Find the order of the items in modality 2 according to modality 1
167
+ ranked_order = (-sim_mat).argsort()
168
+ ranked_sim_mat = sim_mat[np.arange(sim_mat.shape[0])[:, None], ranked_order]
169
+ # re-order the relevancy matrix to accommodate the proposals
170
+ ranked_rel_mat = relevancy_matrix[np.arange(relevancy_matrix.shape[0])[:, None], ranked_order]
171
+
172
+ # find the number of relevant items found at each k
173
+ cumulative_rel_mat = np.cumsum(ranked_rel_mat, axis=1)
174
+ # Mask this ensuring that it is non zero if the kth term is 1 (rel(k) above)
175
+ cumulative_rel_mat[ranked_rel_mat != 1] = 0
176
+ # find the divisor for p(k)
177
+ divisor = np.arange(ranked_rel_mat.shape[1]) + 1
178
+
179
+ # find the number of relevant docs per query item
180
+ number_rel_docs = np.sum(ranked_rel_mat == 1, axis=1)
181
+
182
+ # find the average precision per query, within np.sum finds p(k) * rel(k)
183
+ avg_precision = np.sum(cumulative_rel_mat / divisor, axis=1) / number_rel_docs
184
+ mAP = np.mean(avg_precision)
185
+ return mAP
186
+
187
+
188
+ def get_mAP(similarity_matrix, rel_matrix):
189
+ vis_map = calculate_mAP(similarity_matrix, rel_matrix)
190
+ txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
191
+ return vis_map, txt_map, (vis_map + txt_map) / 2
192
+
193
+
194
+ def get_nDCG(similarity_matrix, rel_matrix):
195
+ vis_k_counts = calculate_k_counts(rel_matrix)
196
+ txt_k_counts = calculate_k_counts(rel_matrix.T)
197
+ vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
198
+ txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
199
+ vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
200
+ txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
201
+ return vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2
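A toy end-to-end check of the EK-100 multi-instance retrieval metrics. The 3x3 similarity and relevancy matrices are synthetic: each query is relevant only to its own caption and already ranked first, so both metrics come out at 1.0.

```python
import numpy as np
from lavila.utils.evaluation_ek100mir import get_mAP, get_nDCG

sim = np.array([[0.9, 0.2, 0.1],
                [0.3, 0.8, 0.4],
                [0.1, 0.5, 0.7]])
rel = np.eye(3)  # each video matches only its own caption

vis_map, txt_map, avg_map = get_mAP(sim, rel)
vis_ndcg, txt_ndcg, avg_ndcg = get_nDCG(sim, rel)
print(avg_map, avg_ndcg)  # 1.0 1.0
```

With a graded (non-binary) relevancy matrix, nDCG additionally rewards ranking highly relevant captions earlier, which is what the EK-100 MIR benchmark measures.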
lavila/utils/meter.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from lavila.utils import distributed as dist_utils
10
+
11
+
12
+ class AverageMeter(object):
13
+ """Computes and stores the average and current value"""
14
+ def __init__(self, name, fmt=':f'):
15
+ self.name = name
16
+ self.fmt = fmt
17
+ self.reset()
18
+
19
+ def reset(self):
20
+ self.val = 0
21
+ self.avg = 0
22
+ self.sum = 0
23
+ self.count = 0
24
+
25
+ def update(self, val, n=1):
26
+ self.val = val
27
+ self.sum += val * n
28
+ self.count += n
29
+ self.avg = self.sum / self.count
30
+
31
+ def synchronize(self):
32
+ if not dist_utils.is_dist_avail_and_initialized():
33
+ return
34
+ t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda')
35
+ dist.barrier()
36
+ dist.all_reduce(t)
37
+ t = t.tolist()
38
+ self.sum = int(t[0])
39
+ self.count = t[1]
40
+ self.avg = self.sum / self.count
41
+
42
+ def __str__(self):
43
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
44
+ return fmtstr.format(**self.__dict__)
45
+
46
+
47
+ class ProgressMeter(object):
48
+ def __init__(self, num_batches, meters, prefix=""):
49
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
50
+ self.meters = meters
51
+ self.prefix = prefix
52
+
53
+ def display(self, batch):
54
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
55
+ entries += [str(meter) for meter in self.meters]
56
+ print('\t'.join(entries))
57
+
58
+ def synchronize(self):
59
+ for meter in self.meters:
60
+ meter.synchronize()
61
+
62
+ def _get_batch_fmtstr(self, num_batches):
63
+ num_digits = len(str(num_batches // 1))
64
+ fmt = '{:' + str(num_digits) + 'd}'
65
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
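Minimal single-process usage of `AverageMeter` and `ProgressMeter`; `synchronize()` only all-reduces when `torch.distributed` has been initialised, so nothing extra is needed here. The loop values are synthetic.

```python
from lavila.utils.meter import AverageMeter, ProgressMeter

losses = AverageMeter('Loss', ':.4f')
progress = ProgressMeter(num_batches=3, meters=[losses], prefix='Epoch: [0]')

for step, loss in enumerate([0.9, 0.7, 0.5]):
    losses.update(loss, n=32)  # n = batch size
    progress.display(step)
# last step prints: "Epoch: [0][2/3]" + "\t" + "Loss 0.5000 (0.7000)"
```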