Upload . with huggingface_hub
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +7 -0
- .gitignore +129 -0
- CODE_OF_CONDUCT.md +5 -0
- CONTRIBUTING.md +39 -0
- LICENSE +22 -0
- app.py +146 -0
- assets/06919917-76bc-4adc-b944-2a722f165513.gif +3 -0
- assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4 +0 -0
- assets/ab865129-78fa-47d4-8a50-ff8c5533246f.gif +3 -0
- assets/cf7c12db-1a9e-46d3-96d6-38174bbe373c.gif +3 -0
- assets/mixkit-chef-preparing-a-sauce-in-a-blender-43034-medium.gif +3 -0
- assets/mixkit-hands-of-a-baker-kneading-a-dough-42467-medium.gif +3 -0
- assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.gif +3 -0
- assets/narrator.gif +3 -0
- assets/rephraser.gif +0 -0
- datasets/README.md +153 -0
- demo_narrator.py +97 -0
- demo_narrator_3rd_person.py +99 -0
- docs/INSTALL.md +15 -0
- docs/MODEL_ZOO.md +311 -0
- docs/PRETRAIN.md +125 -0
- eval_narrator.py +308 -0
- eval_zeroshot.py +389 -0
- lavila/data/__pycache__/datasets.cpython-38.pyc +0 -0
- lavila/data/__pycache__/video_transforms.cpython-38.pyc +0 -0
- lavila/data/datasets.py +517 -0
- lavila/data/video_transforms.py +186 -0
- lavila/models/__pycache__/distributed_utils.cpython-38.pyc +0 -0
- lavila/models/__pycache__/gpt2_gated.cpython-38.pyc +0 -0
- lavila/models/__pycache__/loss.cpython-38.pyc +0 -0
- lavila/models/__pycache__/models.cpython-38.pyc +0 -0
- lavila/models/bpe_simple_vocab_16e6.txt.gz +3 -0
- lavila/models/coca.py +131 -0
- lavila/models/distributed_utils.py +89 -0
- lavila/models/gpt2_gated.py +1615 -0
- lavila/models/loss.py +367 -0
- lavila/models/models.py +1218 -0
- lavila/models/narrator.py +385 -0
- lavila/models/openai_clip.py +237 -0
- lavila/models/openai_model.py +485 -0
- lavila/models/timesformer.py +390 -0
- lavila/models/tokenizer.py +239 -0
- lavila/models/utils.py +108 -0
- lavila/utils/distributed.py +102 -0
- lavila/utils/evaluation.py +36 -0
- lavila/utils/evaluation_charades.py +53 -0
- lavila/utils/evaluation_egomcq.py +25 -0
- lavila/utils/evaluation_ek100cls.py +35 -0
- lavila/utils/evaluation_ek100mir.py +201 -0
- lavila/utils/meter.py +65 -0
.gitattributes
CHANGED
@@ -32,3 +32,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/mixkit-hands-of-a-baker-kneading-a-dough-42467-medium.gif filter=lfs diff=lfs merge=lfs -text
+assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.gif filter=lfs diff=lfs merge=lfs -text
+assets/ab865129-78fa-47d4-8a50-ff8c5533246f.gif filter=lfs diff=lfs merge=lfs -text
+assets/06919917-76bc-4adc-b944-2a722f165513.gif filter=lfs diff=lfs merge=lfs -text
+assets/cf7c12db-1a9e-46d3-96d6-38174bbe373c.gif filter=lfs diff=lfs merge=lfs -text
+assets/narrator.gif filter=lfs diff=lfs merge=lfs -text
+assets/mixkit-chef-preparing-a-sauce-in-a-blender-43034-medium.gif filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,5 @@
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.
CONTRIBUTING.md
ADDED
@@ -0,0 +1,39 @@
# Contributing to LaViLa
We want to make contributing to this project as easy and transparent as
possible.

## Our Development Process
Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## Coding Style
* 4 spaces for indentation rather than tabs
* 80 character line length
* PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/)

## License
By contributing to LaViLa, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
LICENSE
ADDED
@@ -0,0 +1,22 @@
MIT License

Copyright (c) Meta Platforms, Inc. and affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py
ADDED
@@ -0,0 +1,146 @@
import sys
sys.path.insert(0, './')

import decord
import numpy as np
import torch
import os

from lavila.data.video_transforms import Permute
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2
from lavila.models.tokenizer import MyGPT2Tokenizer
from collections import OrderedDict
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
import gradio as gr


def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    seg_size = float(end_frame - start_frame - 1) / num_segments
    seq = []
    for i in range(num_segments):
        start = int(np.round(seg_size * i) + start_frame)
        end = int(np.round(seg_size * (i + 1)) + start_frame)
        end = min(end, end_frame)
        if jitter:
            frame_id = np.random.randint(low=start, high=(end + 1))
        else:
            frame_id = (start + end) // 2
        seq.append(frame_id)
    return seq


def video_loader_by_frames(root, vid, frame_ids):
    vr = decord.VideoReader(os.path.join(root, vid))
    try:
        frames = vr.get_batch(frame_ids).asnumpy()
        frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
    except (IndexError, decord.DECORDError) as error:
        print(error)
        print("Erroneous video: ", vid)
        frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(frames, dim=0)


def iter_clips(video_path, num_segments=4, stride_size=16):
    # The video is represented by `num_seg=4` frames
    vr = decord.VideoReader(video_path)
    frame_sample_size = num_segments * stride_size
    max_start_frame = len(vr) - frame_sample_size
    curr_frame = 0
    fps = vr.get_avg_fps()
    while curr_frame < max_start_frame:
        stop_frame = min(frame_sample_size, len(vr))
        curr_sec, stop_sec = curr_frame / fps, stop_frame / fps
        frame_ids = get_frame_ids(curr_frame, stop_frame, num_segments=num_segments, jitter=False)
        frames = video_loader_by_frames('./', video_path, frame_ids)
        yield curr_sec, stop_sec, frames
        curr_frame += frame_sample_size


class Pipeline:
    def __init__(self, path=""):
        ckpt_path = os.path.join(path, 'vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth')
        ckpt = torch.load(ckpt_path, map_location='cpu')
        state_dict = OrderedDict()
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
            text_use_cls_token=False,
            project_embed_dim=256,
            gated_xattn=True,
            timesformer_gated_xattn=False,
            freeze_lm_vclm=False,
            freeze_visual_vclm=False,
            freeze_visual_vclm_temporal=False,
            num_frames=4,
            drop_path_rate=0.
        )
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)

        crop_size = 224
        self.val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
        ])

    def decode_one(self, generated_ids, tokenizer):
        # get the index of <EOS>
        if tokenizer.eos_token_id == tokenizer.bos_token_id:
            if tokenizer.eos_token_id in generated_ids[1:].tolist():
                eos_id = generated_ids[1:].tolist().index(tokenizer.eos_token_id) + 1
            else:
                eos_id = len(generated_ids.tolist()) - 1
        elif tokenizer.eos_token_id in generated_ids.tolist():
            eos_id = generated_ids.tolist().index(tokenizer.eos_token_id)
        else:
            eos_id = len(generated_ids.tolist()) - 1
        generated_text_str = tokenizer.tokenizer.decode(generated_ids[1:eos_id].tolist())
        return generated_text_str

    def __call__(self, video_path, temperature=0.7, top_p=0.95, max_text_length=77, num_return_sequences=10):
        text = ""
        with torch.autocast(self.device):
            for start, stop, frames in iter_clips(video_path):
                text += f"{'-'*30} Predictions From: {start:10.3f}-{stop:10.3f} seconds {'-'*30}\n"
                frames = self.val_transform(frames).unsqueeze(0)
                if self.device == 'cuda':
                    frames = frames.to(self.device).half()

                with torch.no_grad():
                    image_features = self.model.encode_image(frames)
                    generated_text_ids, ppls = self.model.generate(
                        image_features,
                        self.tokenizer,
                        target=None,  # free-form generation
                        max_text_length=max_text_length,
                        top_k=None,
                        top_p=top_p,  # nucleus sampling
                        num_return_sequences=num_return_sequences,  # number of candidates: 10
                        temperature=temperature,
                        early_stopping=True,
                    )
                for i in range(num_return_sequences):
                    generated_text_str = self.decode_one(generated_text_ids[i], self.tokenizer)
                    text += '\t{}: {}\n'.format(i, generated_text_str)
        return text


interface = gr.Interface(
    Pipeline(),
    inputs=[
        gr.Video(label='video_path'),
        gr.Slider(0.0, 1.0, 0.7, label='temperature'),
        gr.Slider(0.0, 1.0, 0.95, label='top_p'),
    ],
    outputs='text'
)

if __name__ == '__main__':
    interface.launch(debug=True)
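For reference, a minimal sketch of driving the `Pipeline` above without the Gradio UI. It assumes the checkpoint named in `__init__` is present next to the script and uses the bundled sample clip; note that importing `app` also constructs (but does not launch) the Gradio interface at module level, so this is illustrative only.

```python
# Headless usage sketch for the Space's Pipeline class (assumes the checkpoint is available locally).
from app import Pipeline

narrator = Pipeline()  # loads the TSF-B + GPT-2 narrator checkpoint on CPU or GPU
captions = narrator(
    'assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4',
    temperature=0.7,   # same defaults as the Gradio sliders
    top_p=0.95,
)
print(captions)  # 10 candidate narrations per 64-frame window (4 segments x stride 16)
```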
assets/06919917-76bc-4adc-b944-2a722f165513.gif
ADDED
Git LFS Details
assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4
ADDED
Binary file (127 kB)
assets/ab865129-78fa-47d4-8a50-ff8c5533246f.gif
ADDED
Git LFS Details
assets/cf7c12db-1a9e-46d3-96d6-38174bbe373c.gif
ADDED
Git LFS Details
assets/mixkit-chef-preparing-a-sauce-in-a-blender-43034-medium.gif
ADDED
Git LFS Details
assets/mixkit-hands-of-a-baker-kneading-a-dough-42467-medium.gif
ADDED
Git LFS Details
assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.gif
ADDED
Git LFS Details
assets/narrator.gif
ADDED
Git LFS Details
assets/rephraser.gif
ADDED
datasets/README.md
ADDED
@@ -0,0 +1,153 @@
# Preparing datasets for LAVILA

Please download the (selected) datasets from the official websites and place or symlink them under `$LAVILA_ROOT/datasets/`.

```bash
$LAVILA_ROOT/datasets/
    CharadesEgo/
    EGTEA/
    EK100/
    Ego4D/
```

## Ego4D
1. Download [Ego4D videos](https://ego4d-data.org/docs/start-here/#download-data) (license is required).

2. Preprocess (TBA)

3. Download annotations

    a. Download [egomcq.json](https://drive.google.com/file/d/1-5iRYf4BCHmj4MYQYFRMY4bhsWJUN3rW/view) to `$LAVILA_ROOT/datasets/Ego4D` (if you want to evaluate EgoMCQ).

    b. Download [metadata for train split](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_train.pkl) and [val split](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_val.pkl) to `$LAVILA_ROOT/datasets/Ego4D` (if you want to train LAVILA from scratch).

The folder should look like this:
```bash
$LAVILA_ROOT/datasets/
    Ego4D/
        ego4d_train.pkl
        ego4d_val.pkl
        egomcq.json
        video_288px/
            000786a7-3f9d-4fe6-bfb3-045b368f7d44.mp4/
                0.mp4
                300.mp4
            000a3525-6c98-4650-aaab-be7d2c7b9402.mp4/
                0.mp4
            ...
```


## EPIC-Kitchens-100 (EK-100)

1. Download annotations

```bash
# Assume that you are under `datasets/EK100/`
git clone https://github.com/epic-kitchens/epic-kitchens-100-annotations
```

2. Download videos.

    a. For raw videos, please download them from [https://epic-kitchens.github.io/](https://epic-kitchens.github.io/).

    b. (Recommended) The raw videos are huge (~1 TB). As an alternative, please check out a [resized version]().

3. (For EK-100 MIR)

    a. Generate the relevancy matrix of the train/val splits using [the official code](https://github.com/mwray/Joint-Part-of-Speech-Embeddings).

    b. (Recommended) The generated result has some randomness. Therefore, we also provide the [replica of the train split](https://dl.fbaipublicfiles.com/lavila/metadata/EK100/caption_relevancy_EPIC_100_retrieval_train.pkl) and [val split](https://dl.fbaipublicfiles.com/lavila/metadata/EK100/caption_relevancy_EPIC_100_retrieval_test.pkl). Please put them in the folder `$LAVILA_ROOT/datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/`.


The folder should look like this:
```bash
$LAVILA_ROOT/datasets/
    EK100/
        epic-kitchens-100-annotations/
            EPIC_100_train.csv
            EPIC_100_validation.csv
            ...
            retrieval_annotations/relevancy/  # this appears if you do 3.
                caption_relevancy_EPIC_100_retrieval_train.pkl
                caption_relevancy_EPIC_100_retrieval_test.pkl
        video_ht256px/
            P01/
                P01_01.MP4
                P01_02.MP4
                ...
                P01_19.MP4
            P02/
                P02_01.MP4
                P02_02.MP4
                ...
                P02_15.MP4
            ...
```

## CharadesEgo

1. Download annotations at [https://prior.allenai.org/projects/charades-ego](https://prior.allenai.org/projects/charades-ego).
```bash
### Annotations
# Assume that you are under `datasets/CharadesEgo/`
wget https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/CharadesEgo.zip
unzip CharadesEgo.zip && rm CharadesEgo.zip
```

2. Download data (~11GB) at [https://prior.allenai.org/projects/charades-ego](https://prior.allenai.org/projects/charades-ego).
```bash
### Data
wget https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/CharadesEgo_v1_480.tar
tar -xvf CharadesEgo_v1_480.tar  # Or specify an external path using `-C` and symlink it to here
rm CharadesEgo_v1_480.tar
```

3. (For fine-tuning on CharadesEgo) Download two additional metadata files: [clip-level metadata (train)](https://dl.fbaipublicfiles.com/lavila/metadata/CharadesEgo/metadata_filtered_train.pkl) and [clip-level metadata (val)](https://dl.fbaipublicfiles.com/lavila/metadata/CharadesEgo/metadata_filtered_val.pkl). Put them in the folder `$LAVILA_ROOT/datasets/CharadesEgo/CharadesEgo/`.

The folder should look like this:
```bash
$LAVILA_ROOT/datasets/
    CharadesEgo/
        CharadesEgo/
            CharadesEgo_v1_train_only1st.csv
            CharadesEgo_v1_test_only1st.csv
            ...
            metadata_filtered_train.pkl  # this appears if you do 3.
            metadata_filtered_val.pkl    # this appears if you do 3.
        CharadesEgo_v1_480/
            005BU.mp4
            005BUEGO.mp4
            ...
```


## EGTEA

1. Visit [https://cbs.ic.gatech.edu/fpv/](https://cbs.ic.gatech.edu/fpv/).

2. Download `TRIMMED_ACTION_CLIPS` (~20GB) and `ACTION_ANNOTATIONS` and untar them to the current folder `$LAVILA_ROOT/datasets/EGTEA`.

```bash
unzip action_annotation.zip -d EGTEA/ && rm action_annotation.zip
```

The folder should look like this:
```bash
$LAVILA_ROOT/datasets/
    EGTEA/
        train_split1.txt
        test_split1.txt
        cropped_clips/
            OP01-R01-PastaSalad/
                OP01-R01-PastaSalad-1002316-1004005-F024051-F024101.mp4
                OP01-R01-PastaSalad-1004110-1021110-F024057-F024548.mp4
                OP01-R01-PastaSalad-1022590-1024050-F024539-F024581.mp4
                ...
            OP01-R02-TurkeySandwich/
                OP01-R02-TurkeySandwich-102320-105110-F002449-F002529.mp4
                OP01-R02-TurkeySandwich-105440-106460-F002528-F002558.mp4
                OP01-R02-TurkeySandwich-107332-133184-F002513-F003259.mp4
                ...
            ...
```
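As a quick check before training, a small sanity-check sketch (not part of the upstream repo) can confirm the layouts described above are in place; the paths simply mirror the trees listed in this README:

```python
# Sanity-check sketch: confirm a few of the dataset files described above exist.
from pathlib import Path

LAVILA_ROOT = Path('.')  # adjust to your $LAVILA_ROOT
expected = [
    'datasets/Ego4D/ego4d_train.pkl',
    'datasets/Ego4D/egomcq.json',
    'datasets/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
    'datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_train_only1st.csv',
    'datasets/EGTEA/train_split1.txt',
]
for rel in expected:
    path = LAVILA_ROOT / rel
    print(f"{'OK     ' if path.exists() else 'MISSING'} {path}")
```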
demo_narrator.py
ADDED
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import argparse
import os
import urllib.request
from collections import OrderedDict

import decord
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video

from lavila.data.video_transforms import Permute
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL
from lavila.models.tokenizer import MyGPT2Tokenizer
from eval_narrator import decode_one


def main(args):

    vr = decord.VideoReader(args.video_path)
    num_seg = 4
    frame_ids = get_frame_ids(0, len(vr), num_segments=num_seg, jitter=False)
    frames = video_loader_by_frames('./', args.video_path, frame_ids)

    ckpt_name = 'vclm_openai_timesformer_large_336px_gpt2_xl.pt_ego4d.jobid_246897.ep_0003.md5sum_443263.pth'
    ckpt_path = os.path.join('modelzoo/', ckpt_name)
    os.makedirs('modelzoo/', exist_ok=True)
    if not os.path.exists(ckpt_path):
        print('downloading model to {}'.format(ckpt_path))
        urllib.request.urlretrieve('https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/{}'.format(ckpt_name), ckpt_path)
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = OrderedDict()
    for k, v in ckpt['state_dict'].items():
        state_dict[k.replace('module.', '')] = v

    # instantiate the model, and load the pre-trained weights
    model = VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL(
        text_use_cls_token=False,
        project_embed_dim=256,
        gated_xattn=True,
        timesformer_gated_xattn=False,
        freeze_lm_vclm=False,      # we use model.eval() anyway
        freeze_visual_vclm=False,  # we use model.eval() anyway
        num_frames=4,
        drop_path_rate=0.
    )
    model.load_state_dict(state_dict, strict=True)
    if args.cuda:
        model.cuda()
    model.eval()

    # transforms on input frames
    crop_size = 336
    val_transform = transforms.Compose([
        Permute([3, 0, 1, 2]),
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
    ])
    frames = val_transform(frames)
    frames = frames.unsqueeze(0)  # fake a batch dimension

    tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
    with torch.no_grad():
        if args.cuda:
            frames = frames.cuda(non_blocking=True)
        image_features = model.encode_image(frames)
        generated_text_ids, ppls = model.generate(
            image_features,
            tokenizer,
            target=None,  # free-form generation
            max_text_length=77,
            top_k=None,
            top_p=0.95,  # nucleus sampling
            num_return_sequences=10,  # number of candidates: 10
            temperature=0.7,
            early_stopping=True,
        )

    for i in range(10):
        generated_text_str = decode_one(generated_text_ids[i], tokenizer)
        print('{}: {}'.format(i, generated_text_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('lavila narrator demo')
    parser.add_argument('--cuda', action='store_true', help='use cuda')
    parser.add_argument('--video-path', default='assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4', type=str, help='video path')
    args = parser.parse_args()
    main(args)
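As a usage note: running `python demo_narrator.py` downloads the TSF-L@HR + GPT-2 XL narrator checkpoint into `modelzoo/` on first use, narrates the bundled sample clip (`assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4`), and prints 10 candidate narrations; pass `--cuda` to run on GPU and `--video-path` to narrate your own clip.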
demo_narrator_3rd_person.py
ADDED
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import argparse
import os
import urllib.request
from collections import OrderedDict

import decord
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video

from lavila.data.video_transforms import Permute
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_LARGE_GPT2_XL
from lavila.models.tokenizer import MyGPT2Tokenizer
from eval_narrator import decode_one


def main(args):

    vr = decord.VideoReader(args.video_path)
    num_seg = 4
    frame_ids = get_frame_ids(0, len(vr), num_segments=num_seg, jitter=False)
    frames = video_loader_by_frames('./', args.video_path, frame_ids)

    ckpt_name = 'vclm_openai_timesformer_large_gpt2_xl.pt_htm.jobid_341080.ep_0001.pth'
    ckpt_path = os.path.join('modelzoo/', ckpt_name)
    os.makedirs('modelzoo/', exist_ok=True)
    if not os.path.exists(ckpt_path):
        print('downloading model to {}'.format(ckpt_path))
        urllib.request.urlretrieve('https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/htm_aa/{}'.format(ckpt_name), ckpt_path)
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = OrderedDict()
    for k, v in ckpt['state_dict'].items():
        state_dict[k.replace('module.', '')] = v

    # instantiate the model, and load the pre-trained weights
    model = VCLM_OPENAI_TIMESFORMER_LARGE_GPT2_XL(
        text_use_cls_token=False,
        project_embed_dim=256,
        gated_xattn=True,
        timesformer_gated_xattn=False,
        freeze_lm_vclm=False,      # we use model.eval() anyway
        freeze_visual_vclm=False,  # we use model.eval() anyway
        freeze_visual_vclm_temporal=False,
        num_frames=4,
        drop_path_rate=0.
    )
    model.load_state_dict(state_dict, strict=True)
    if args.cuda:
        model.cuda()
    model.eval()

    # transforms on input frames
    crop_size = 224
    val_transform = transforms.Compose([
        Permute([3, 0, 1, 2]),
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
    ])
    frames = val_transform(frames)
    frames = frames.unsqueeze(0)  # fake a batch dimension

    tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
    with torch.no_grad():
        if args.cuda:
            frames = frames.cuda(non_blocking=True)
        image_features = model.encode_image(frames)
        generated_text_ids, ppls = model.generate(
            image_features,
            tokenizer,
            target=None,  # free-form generation
            max_text_length=77,
            top_k=None,
            top_p=0.95,  # nucleus sampling
            num_return_sequences=10,  # number of candidates: 10
            temperature=0.7,
            early_stopping=True,
        )

    for i in range(10):
        generated_text_str = decode_one(generated_text_ids[i], tokenizer)
        print('{}: {}'.format(i, generated_text_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('lavila narrator demo')
    parser.add_argument('--cuda', action='store_true', help='use cuda')
    parser.add_argument('--video-path', type=str,
                        default='assets/mixkit-pastry-chef-cutting-a-loaf-into-slices-43015-medium.mp4')
    args = parser.parse_args()
    main(args)
docs/INSTALL.md
ADDED
@@ -0,0 +1,15 @@
# Installation

## Requirements


## Example conda environment setup

```bash
conda create --name lavila python=3.8 -y
conda activate lavila
pip install -r requirements.txt
```

## Datasets
If you want to train/evaluate on the datasets, please see [datasets/README.md](../datasets/README.md) for how we prepare the datasets for this project.
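A quick post-install check can save a failed training run later. The sketch below assumes `requirements.txt` pulls in at least `torch`, `torchvision`, and `decord`, which the demo scripts import:

```python
# Post-install sanity check: print versions of the core dependencies used by the demos.
import torch
import torchvision
import decord

print('torch      ', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('torchvision', torchvision.__version__)
print('decord     ', decord.__version__)
```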
docs/MODEL_ZOO.md
ADDED
@@ -0,0 +1,311 @@
# LAVILA Model Zoo

## Multi-node Training
We use multi-node training on a SLURM cluster with [submitit](https://github.com/facebookincubator/submitit) to produce the results and models in the paper.
Please install `submitit` in your conda environment:
```bash
pip install submitit
```


## Pre-training

Please refer to [PRETRAIN.md](./PRETRAIN.md).


## Narrator

| Visual Encoder | Text Decoder | METEOR | ROUGE-L | CIDEr | Pre-trained<br>Vis. Encoder (md5) | checkpoint (md5) |
| :------------: | :----------: | :----: | :-----: | :---: | :-------------------------------: | :--------: |
| TSF-B | GPT-2 | 0.282 | 0.517 | 0.833 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.baseline.ep_0003.pth) (dbcc4d) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth) (68a71f) |
| TSF-L@HR | GPT-2 XL | 0.298 | 0.539 | 0.977 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large_336px_distilbert_base.baseline.ep_0003.pth) (5c69b8) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/narrator/vclm_openai_timesformer_large_336px_gpt2_xl.pt_ego4d.jobid_246897.ep_0003.md5sum_443263.pth) (443263) |


<details><summary>Ego4D val split</summary>
<p>

```bash
torchrun --nproc_per_node=1 \
    eval_narrator.py \
    --caption-top-p 0.95 --caption-temperature 0.7 \
    --eval-freq 10000 \
    --resume $CHECKPOINT
```

</p></details>

## Zero-shot

<div class="table-wrapper" markdown="block">

| | Backbone | EK-100 MIR<br>avg. mAP | EK-100 MIR<br>avg. nDCG | Charades-Ego<br>mAP^ | EGTEA<br> mean acc. | EgoMCQ<br>intra-video acc. | checkpoint |
| :----------: | :------: | :--------------------: | :---------------------: | :------------------: | :-----------------: | :------------------------: | :----------: |
| Prev. SOTA^^ | TSF-B | 22.1/23.3 | 22.1/27.9 | 25.2 | 17.6 | 57.2 | [Epoch 1](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/egovlp_epo1_converted_f16.md5sum_7a3d3b.pth), [best epoch](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/egovlp_converted_f16.md5sum_c33363.pth) |
| LAVILA | TSF-B | 29.7/30.9 | 31.5/32.0 | 26.8 | 28.9 | 59.9 | [Epoch 1](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0001.md5sum_02dbb9.pth)^, [Epoch 5](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) |
| LAVILA | TSF-L | 35.0/36.1 | 34.2/34.6 | 28.9 | 34.1 | 63.1 | [Epoch 1](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0001.md5sum_9a25de.pth)^, [Epoch 3](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) |

</div>

^ Note that the pre-trained checkpoint used to evaluate CharadesEgo is different from the one used to evaluate the other datasets.
Specifically, we use the checkpoint at epoch 1 to zero-shot evaluate CharadesEgo and the checkpoint that achieves the best average mAP on EK-100 MIR to evaluate the other datasets, as is done in [EgoVLP](https://arxiv.org/pdf/2206.01670.pdf).
Our guess is that since CharadesEgo videos (captured by head-mounted mobile cameras) are visually different from Ego4D/EPIC-Kitchens videos (captured by professional action cameras, e.g. GoPro), pre-training on Ego4D videos for longer introduces some domain discrepancy.

^^ We use the checkpoints released by [EgoVLP](https://github.com/showlab/EgoVLP) and convert them to be compatible with this codebase. Also note that our reproduced numbers are better than the reported numbers, especially on EK-100 MIR, since we evaluate on raw videos directly (for more details, check out Appendix F & Table 10 in our paper).

<details><summary>1. EK-100 MIR</summary>
<p>

```bash
python eval_zeroshot.py --dataset ek100_mir --root datasets/EK100/video_ht256px/ --clip-length 4 --resume $PATH
```
By increasing the number of frames per clip, e.g. `--clip-length 16`, you can expect better performance.

</p></details>

<details><summary>2. EK-100 CLS</summary>
<p>

```bash
python eval_zeroshot.py --dataset ek100_cls --metadata-val datasets/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv --resume $PATH
```

</p></details>

<details><summary>3. Charades-Ego</summary>
<p>

```bash
python eval_zeroshot.py --dataset charades_ego --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv --root datasets/CharadesEgo/CharadesEgo_v1_480/ --clip-length 16 --sparse-sample --resume $PATH
```

</p></details>

<details><summary>4. EGTEA</summary>
<p>

```bash
python eval_zeroshot.py --dataset egtea --metadata-val datasets/EGTEA/test_split1.txt --root datasets/EGTEA/cropped_clips/ --clip-length 16 --clip-stride 2 --num-crops 3 --num-clips 10 --resume $PATH
```

</p></details>

<details><summary>5. EgoMCQ</summary>
<p>

```bash
python eval_zeroshot.py --dataset ego4d_mcq --metadata-val datasets/Ego4D/egomcq.json --root datasets/Ego4D/video_5min_chunks_288px/ --clip-length 4 --resume $PATH --use-half -j 4
```

</p></details>

## Fine-tuned

### EK-100 MIR

<div class="table-wrapper" markdown="block">

| | Backbone | avg mAP | avg nDCG | Pretrain (md5) | Fine-tuned checkpoint | training log |
| :----: | :-------:| :-----: | :------: | :----------: | :-------------------: | :----------: |
| LAVILA | TSF-B | 50.5 | 65.0 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_base.ft_ek100_mir.ep_0085.md5sum_c67d95.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_base.ft_ek100_mir.jobid_57361.log) |
| LAVILA | TSF-L | 50.9 | 66.5 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) (c89337) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_large.ft_ek100_mir.ep_0095.md5sum_bd508b.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_mir/clip_openai_timesformer_large.ft_ek100_mir.jobid_56606.log) |

</div>


<details><summary>Training and evaluating scripts</summary>
<p>

### Multi-node training (Slurm)
```bash
# TimeSformer-Base
python run_with_submitit_finetune_retrieval.py \
    --pretrain-model $PATH \
    --use-checkpoint --nodes 4

# TimeSformer-Large
python run_with_submitit_finetune_retrieval.py \
    --pretrain-model $PATH \
    --batch-size 4 \
    --use-checkpoint --nodes 4
```

### Single-machine training
```bash
torchrun --nproc_per_node=8 \
    main_finetune_retrieval.py \
    --output-dir $OUT_DIR \
    --pretrain-model $PATH \
    --use-checkpoint
```

Note that you might see a slight drop in performance when training on a single node compared to multiple nodes (everything else being the same) because of a smaller total batch size.

### Evaluation

Evaluation is done every `--eval-freq 5` epochs by default during fine-tuning.
If you want to evaluate any checkpoint after fine-tuning, please switch to `--evaluate` mode and specify the path to the checkpoint by `--resume $FINETUNED_CHECKPOINT`.
```bash
torchrun --nproc_per_node=1 \
    main_finetune_retrieval.py \
    --output-dir $OUT_DIR \
    --pretrain-model $PATH \
    --use-checkpoint \
    --evaluate \
    --resume $FINETUNED_CHECKPOINT
```


</p></details>

### CharadesEgo

<div class="table-wrapper" markdown="block">

| | Backbone | video mAP | Pretrain^ (md5) | Fine-tuned checkpoint | training log |
| :----: | :-------:| :------: | :-------: | :-------------------: | :----------: |
| LAVILA | TSF-B | 33.7 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0001.md5sum_02dbb9.pth) (02dbb9) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_base.ft_charades_ego.ep_0005.md5sum_39bf4b.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_base.ft_charades_ego.jobid_65760.log) |
| LAVILA | TSF-L | 36.1 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0001.md5sum_9a25de.pth) (9a25de) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_large.ft_charades_ego.ep_0003.md5sum_9448b2.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/charades_ego/clip_openai_timesformer_large.ft_charades_ego.jobid_65760.log) |

</div>

^ Note that the pre-trained checkpoint for fine-tuning CharadesEgo is different from the one for fine-tuning EK-100 or EGTEA, for the same reason stated above.

<details><summary>Training and evaluating scripts</summary>
<p>

### Multi-node training (Slurm)

```bash
# TimeSformer-Base
python run_with_submitit_finetune_retrieval.py \
    --dataset charades_ego \
    --metadata datasets/CharadesEgo/CharadesEgo/metadata_filtered_train.pkl \
    --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv \
    --root datasets/CharadesEgo/CharadesEgo_v1_480/ \
    --epochs 10 \
    --save-freq 1 --eval-freq 1 \
    --sparse-sample \
    --pretrain-model $PATH \
    --use-checkpoint --nodes 4

# TimeSformer-Large
python run_with_submitit_finetune_retrieval.py \
    --dataset charades_ego \
    --metadata datasets/CharadesEgo/CharadesEgo/metadata_filtered_train.pkl \
    --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv \
    --root datasets/CharadesEgo/CharadesEgo_v1_480/ \
    --epochs 10 \
    --save-freq 1 --eval-freq 1 \
    --sparse-sample \
    --pretrain-model $PATH \
    --batch-size 4 \
    --use-checkpoint --nodes 4
```

### Evaluation
```bash
torchrun --nproc_per_node=1 \
    main_finetune_retrieval.py \
    --dataset charades_ego \
    --metadata datasets/CharadesEgo/CharadesEgo/metadata_filtered_train.pkl \
    --metadata-val datasets/CharadesEgo/CharadesEgo/CharadesEgo_v1_test_only1st.csv \
    --root datasets/CharadesEgo/CharadesEgo_v1_480/ \
    --output-dir $OUT_DIR \
    --sparse-sample \
    --pretrain-model $PATH \
    --evaluate \
    --resume $FINETUNED_CHECKPOINT
```

</p></details>

### EK-100 CLS

<div class="table-wrapper" markdown="block">

| | Backbone | V+N+A multi-head | Verb top-1 | Noun top-1 | Action top-1 | Pretrain (md5) | Fine-tuned checkpoint | training log |
| :----: | :-------:| :--------------: | :--------: | :--------: | :---------: | :------------: | :-------------------: | :----------: |
| LAVILA | TSF-B | no | 67.7 | 56.7 | 46.2 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.single_head.ep_0100.md5sum_e8aa0c.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.single_head.jobid_73363.log) |
| LAVILA | TSF-B | yes | 69.0 | 58.4 | 46.9 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.ep_0100.md5sum_4e3575.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_base.ft_ek100_cls.jobid_73361.log) |
| LAVILA | TSF-L | yes | 72.0 | 62.9 | 51.0 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) (c89337) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_large.ft_ek100_cls.ep_0090.md5sum_4a2509.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ek100_cls/clip_openai_timesformer_large.ft_ek100_cls.jobid_74016.log) |
</div>

<details><summary>Training and evaluating scripts</summary>
<p>

### Multi-node training (Slurm)

```bash
# TimeSformer-Base
python run_with_submitit_finetune_classification.py \
    --pretrain-model $PATH \
    --use-vn-classifier --num-classes 97 300 3806 \
    --use-sgd --wd 4e-5 --lr-multiplier-on-backbone 0.1 \
    --use-checkpoint --node 1

# TimeSformer-Large
python run_with_submitit_finetune_classification.py \
    --pretrain-model $PATH \
    --use-vn-classifier --num-classes 97 300 3806 \
    --use-sgd --wd 4e-5 --lr-multiplier-on-backbone 0.1 \
    --use-checkpoint --node 4
```

</p></details>

### EGTEA

<div class="table-wrapper" markdown="block">

| | Backbone | mean Acc. | Pretrain (md5) | Fine-tuned checkpoint | training log |
| :----: | :-------:| :-------: | :------: | :-------------------: | :----------: |
| LAVILA | TSF-B | 70.12 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth) (d73a9c) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_base.ft_egtea.ep_0090.md5sum_3b1faf.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_base.ft_egtea.jobid_73358.log) |
| LAVILA | TSF-L | 76.00 | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/ego4d/clip_openai_timesformer_large.narrator_rephraser.ep_0003.md5sum_c89337.pth) (c89337) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_large.ft_egtea.ep_0095.md5sum_a5ba17.pth) | [download](https://dl.fbaipublicfiles.com/lavila/checkpoints/dual_encoders/egtea/clip_openai_timesformer_large.ft_egtea.jobid_74026.log) |

</div>

<details><summary>Training and evaluating scripts</summary>
<p>

```bash
# TimeSformer-Base
python run_with_submitit_finetune_classification.py \
    --dataset egtea \
    --metadata-train datasets/EGTEA/train_split1.txt \
    --metadata-val datasets/EGTEA/test_split1.txt \
    --root datasets/EGTEA/cropped_clips/ \
    --pretrain-model $PATH \
    --num-classes 106 \
    --use-sgd --wd 4e-5 \
    --use-checkpoint --node 1

# TimeSformer-Large
python run_with_submitit_finetune_classification.py \
    --dataset egtea \
    --metadata-train datasets/EGTEA/train_split1.txt \
    --metadata-val datasets/EGTEA/test_split1.txt \
    --root datasets/EGTEA/cropped_clips/ \
    --pretrain-model $PATH \
    --num-classes 106 \
    --use-sgd --wd 4e-5 \
    --batch-size 4 \
    --use-checkpoint --node 4
```
### Evaluation
```bash
torchrun --nproc_per_node=1 \
    main_finetune_classification.py \
    --dataset egtea \
    --metadata-train datasets/EGTEA/train_split1.txt \
    --metadata-val datasets/EGTEA/test_split1.txt \
    --root datasets/EGTEA/cropped_clips/ \
    --output-dir $OUT_DIR \
    --pretrain-model $PATH \
    --num-classes 106 \
    --use-sgd --wd 4e-5 \
    --evaluate \
    --resume $FINETUNED_CHECKPOINT \
    --num-crops 3 --num-clips 10 \
    --use-half
```
</p></details>
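The `(md5)` values in the tables above are short prefixes of each checkpoint's md5 digest (they match the `md5sum_xxxxxx` fragment in the file names). A small verification sketch, assuming you have already downloaded a checkpoint locally:

```python
# Verify a downloaded checkpoint against the short md5 prefix listed in the tables above.
import hashlib

def md5_prefix(path, length=6, chunk_size=1 << 20):
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()[:length]

ckpt = 'modelzoo/clip_openai_timesformer_base.narrator_rephraser.ep_0005.md5sum_d73a9c.pth'
assert md5_prefix(ckpt) == 'd73a9c', 'checksum mismatch -- re-download the file'
print('checksum OK')
```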
docs/PRETRAIN.md
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LAVILA Pretraining
|
2 |
+
|
3 |
+
In this doc, we provide a step-by-step guide (with commands) to train LaViLa.
|
4 |
+
Note that we recommend running the following job with four 8x V100 (32GB) nodes (or eight nodes for the larger backbone) using [submitit](https://github.com/facebookincubator/submitit).
|
5 |
+
See how to install submitit at [here](./MODEL_ZOO.md#multi-node-training).
|
6 |
+
|
7 |
+
|
8 |
+
## Pre-training Dual-Encoder Baseline
|
9 |
+
|
10 |
+
We first pre-train a dual-encoder baseline with human annotations on Ego4d clips.
|
11 |
+
The goal is (1) to establish a comparable baseline for LAVILA, and (2) provide a video encoder for narrator (see below).
|
12 |
+
We use a default batch size of 32 per gpu so that the total batch size for InfoNCE loss is `32*8*4=1024`.
|
13 |
+
|
14 |
+
<details><summary> Train a baseline dual-encoder (with TSF-B) </summary>
|
15 |
+
|
16 |
+
```bash
|
17 |
+
python run_with_submitit_pretrain.py --model CLIP_OPENAI_TIMESFORMER_BASE \
|
18 |
+
--norm-embed --freeze-temperature \
|
19 |
+
--fix-lr --contrastive-use-vissl \
|
20 |
+
--nodes 4 --use_volta32
|
21 |
+
```
|
22 |
+
</details>
|
23 |
+
|
24 |
+
To fit a High-Resolution TimeSformer-Large with a sufficient batch size, we use [DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert), a memory-efficient text encoder, instead of the original text encoder in the CLIP. Additionally we apply [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html) and [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054).
|
25 |
+
|
26 |
+
<details><summary> Train a baseline dual-encoder (with TSF-L@HR) </summary>
|
27 |
+
|
28 |
+
```bash
|
29 |
+
python run_with_submitit_pretrain.py --model CLIP_OPENAI_TIMESFORMER_LARGE_336PX_DISTILBERT_BASE \
|
30 |
+
--batch-size 8 \
|
31 |
+
--use-checkpoint --use-zero \
|
32 |
+
--norm-embed --freeze-temperature \
|
33 |
+
--fix-lr --contrastive-use-vissl \
|
34 |
+
--nodes 8 --use_volta32
|
35 |
+
```
|
36 |
+
</details>
|
37 |
+
|
38 |
+
## Training and Evaluating Narrator
|
39 |
+
|
40 |
+
The narrator is a *visually conditioned* large language model (VCLM), which comprises a pre-trained video encoder (obtained above), a text decoder (GPT-2 family), and a few gated cross-attention modules that attend to visual information while captioning. Both the video encoder and the text decoder are kept frozen, while the cross-attention modules are learnable.
|
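To make the wording concrete, here is a rough sketch of a tanh-gated cross-attention block in the spirit of this design. Module and parameter names are hypothetical and intentionally simplified; the repo's actual implementation lives in `lavila/models/gpt2_gated.py`.

```python
import torch
import torch.nn as nn

class GatedCrossAttentionBlock(nn.Module):
    """Inserted between frozen LM blocks; only this module receives gradients."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # gate starts at 0 so the frozen GPT-2 is unchanged at initialization
        self.attn_gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_tokens, visual_tokens):
        # text_tokens: (B, T, D) hidden states from the frozen text decoder
        # visual_tokens: (B, N, D) features from the frozen video encoder
        attended, _ = self.attn(self.norm(text_tokens), visual_tokens, visual_tokens)
        return text_tokens + torch.tanh(self.attn_gate) * attended
```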
41 |
+
|
42 |
+
Note that we turn off PyTorch's automatic mixed precision (AMP) when training the narrator; we observe that training is unstable with AMP enabled.
|
43 |
+
|
44 |
+
Also note that `$PATH` can be found in the `Vis. Encoder` column of [MODEL_ZOO.md#Narrator](./MODEL_ZOO.md#narrator). If you are using your own checkpoint (e.g. pre-trained in the previous step), please make sure that the following keys in the checkpoint have been dropped: `epoch`, `optimizer`, and `scaler`.
|
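For example, a few lines like the following (file names are placeholders) strip those keys from your own checkpoint before passing it via `--resume`:

```python
import torch

ckpt = torch.load('checkpoint_best.pt', map_location='cpu')  # your pre-trained dual-encoder
for key in ('epoch', 'optimizer', 'scaler'):
    ckpt.pop(key, None)  # drop training-only state so it is loaded as a plain visual encoder
torch.save(ckpt, 'modelzoo/clip_openai_timesformer_base.baseline.ep_0003.pth')
```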
45 |
+
|
46 |
+
<details><summary> Train a baseline narrator (TSF-B as visual encoder and GPT-2 base as textual decoder) </summary>
|
47 |
+
|
48 |
+
```bash
|
49 |
+
python run_with_submitit_pretrain.py \
|
50 |
+
--model VCLM_OPENAI_TIMESFORMER_BASE_GPT2 \
|
51 |
+
--gated-xattn --freeze-lm-vclm --freeze-visual-vclm --freeze-visual-vclm-temporal \
|
52 |
+
--fix-lr --batch-size 8 --clip-grad-value 1.0 --eval-freq 1 --disable-amp \
|
53 |
+
--nodes 4 --use_volta32 --resume $PATH  # e.g. $PATH can be "modelzoo/clip_openai_timesformer_base.baseline.ep_0003.pth"
|
54 |
+
```
|
55 |
+
|
56 |
+
</details>
|
57 |
+
|
58 |
+
<details><summary> Train a strong narrator (TSF-L@HR as visual encoder and GPT-2 XL as textual decoder) </summary>
|
59 |
+
|
60 |
+
```bash
|
61 |
+
python run_with_submitit_pretrain.py \
|
62 |
+
--model VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL \
|
63 |
+
--gated-xattn --freeze-lm-vclm --freeze-visual-vclm --freeze-visual-vclm-temporal --use-checkpoint \
|
64 |
+
--fix-lr --batch-size 8 --clip-grad-value 1.0 --eval-freq 1 --disable-amp \
|
65 |
+
--nodes 4 --use_volta32 --resume $PATH  # e.g. $PATH can be "modelzoo/clip_openai_timesformer_large_336px_distilbert_base.baseline.ep_0003.pth"
|
66 |
+
```
|
67 |
+
</details>
|
68 |
+
|
69 |
+
<details><summary> Evaluate the narrator on Ego4D val split </summary>
|
70 |
+
|
71 |
+
```bash
|
72 |
+
torchrun --nproc_per_node=1 eval_narrator.py \
|
73 |
+
--caption-top-p 0.95 --caption-temperature 0.7 \
|
74 |
+
--eval-freq 10000 \
|
75 |
+
--resume $VCLM_CHECKPOINT
|
76 |
+
```
|
77 |
+
Here `--eval-freq 10000` evaluates on a 1/10000 subset of the Ego4D val split for fast evaluation. This will output some common NLG metrics, such as BLEU-x, METEOR, ROUGE_L, and CIDEr (using the human narrations as ground truth).
|
78 |
+
</details>
|
79 |
+
|
80 |
+
## Narrating video clips using LAVILA-Narrator
|
81 |
+
|
82 |
+
|
83 |
+
<details><summary> Infer the narrator </summary>
|
84 |
+
|
85 |
+
```bash
|
86 |
+
python run_with_submitit_infer_narrator.py \
|
87 |
+
--metadata datasets/Ego4D/ego4d_train.pkl \
|
88 |
+
--batch-size 64 \
|
89 |
+
--resume $PATH --use-half \
|
90 |
+
--nodes 4 --use_volta32
|
91 |
+
```
|
92 |
+
</details>
|
93 |
+
|
94 |
+
It will generate a pickle file (`$output_dir/total.pkl`), which is a list of quintuples: `(video_uid: str, start_time: float, end_time: float, narration_list: List[str], NLL_list: List[float])`.
|
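Assuming the quintuple format above, you can inspect the generated narrations (and, for example, filter them by their negative log-likelihood) with a few lines of Python; the threshold below is purely illustrative:

```python
import pickle

with open('total.pkl', 'rb') as f:  # i.e. $output_dir/total.pkl
    samples = pickle.load(f)

video_uid, start_time, end_time, narration_list, nll_list = samples[0]
# keep only the narrations the narrator itself is most confident about
confident = [n for n, nll in zip(narration_list, nll_list) if nll < 4.0]
print(video_uid, start_time, end_time, confident[:3])
```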
95 |
+
|
96 |
+
For narrator-generated narrations on Ego4D ground-truth clips, we also provide a [replica](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_train.narrator_63690737.return_10.pkl). Note that the narrator used here is our best-performing one.
|
97 |
+
|
98 |
+
## Rephrasing human narrations using LAVILA-Rephraser
|
99 |
+
|
100 |
+
Rephraser is a standard LLM that can paraphrase narrations in existing clips.
|
101 |
+
Specifically, we use an off-the-shelf T5-based paraphraser which is publicly available at [Hugging Face's model hub](https://huggingface.co/ramsrigouthamg/t5-large-paraphraser-diverse-high-quality).
|
102 |
+
For more details, please refer to the [model card](https://huggingface.co/ramsrigouthamg/t5-large-paraphraser-diverse-high-quality).
|
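A minimal sketch of calling this paraphraser with Hugging Face `transformers` is given below. The `paraphrase:` prefix and the diverse-beam-search settings are assumptions based on the model card, so please double-check them there before running at scale.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

name = 'ramsrigouthamg/t5-large-paraphraser-diverse-high-quality'
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name)

narration = '#C C opens the fridge.'
inputs = tokenizer('paraphrase: ' + narration, return_tensors='pt')
outputs = model.generate(**inputs, num_beams=10, num_beam_groups=5,
                         diversity_penalty=0.5, num_return_sequences=3, max_length=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```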
103 |
+
|
104 |
+
For rephrased human narrations on Ego4D ground-truth clips, we provide a [replica](https://dl.fbaipublicfiles.com/lavila/metadata/ego4d/ego4d_train.rephraser.no_punkt_top3.pkl).
|
105 |
+
|
106 |
+
|
107 |
+
## Pre-training LAVILA Dual-Encoder
|
108 |
+
Now we are ready to pre-train LAVILA's dual-encoder by combining the human annotations (augmented by the Rephraser) with the Narrator-generated narrations.
|
109 |
+
|
110 |
+
<details><summary> Training a LaViLa dual-encoder </summary>
|
111 |
+
|
112 |
+
```bash
|
113 |
+
python run_with_submitit_pretrain.py --model CLIP_OPENAI_TIMESFORMER_BASE \
|
114 |
+
--metadata datasets/Ego4D/ego4d_train.rephraser.no_punkt_top3.pkl \
|
115 |
+
--metadata-aux datasets/Ego4D/ego4d_train.narrator_63690737.return_10.pkl \
|
116 |
+
--norm-embed --freeze-temperature \
|
117 |
+
--freeze-pseudo-temperature \
|
118 |
+
--fix-lr --contrastive-use-vissl \
|
119 |
+
--nodes 4 --use_volta32
|
120 |
+
```
|
121 |
+
</details>
|
122 |
+
|
123 |
+
## Downstream Evaluation
|
124 |
+
With the pre-trained dual-encoder at hand, we can now run zero-shot or fine-tuning evaluations on downstream benchmarks.
|
125 |
+
Please refer to [MODEL_ZOO.md](./MODEL_ZOO.md#zero-shot) for more details.
|
eval_narrator.py
ADDED
@@ -0,0 +1,308 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import os.path as osp
|
9 |
+
import time
|
10 |
+
from collections import OrderedDict
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
# https://github.com/numpy/numpy/issues/21079
|
14 |
+
try:
|
15 |
+
import numpy.distutils
|
16 |
+
numpy.distutils.__config__.blas_opt_info = np.distutils.__config__.blas_ilp64_opt_info
|
17 |
+
except Exception:
|
18 |
+
pass
|
19 |
+
from nlgeval import NLGEval
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torchvision.transforms as transforms
|
23 |
+
import torchvision.transforms._transforms_video as transforms_video
|
24 |
+
|
25 |
+
from lavila.data import datasets
|
26 |
+
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
|
27 |
+
from lavila.models import models
|
28 |
+
from lavila.models.utils import inflate_positional_embeds
|
29 |
+
from lavila.utils import distributed as dist_utils
|
30 |
+
from lavila.utils.preprocess import generate_tokenizer
|
31 |
+
|
32 |
+
|
33 |
+
def decode_one(generated_ids, tokenizer):
|
34 |
+
# get the index of <EOS>
|
35 |
+
if tokenizer.eos_token_id == tokenizer.bos_token_id:
|
36 |
+
if tokenizer.eos_token_id in generated_ids[1:].tolist():
|
37 |
+
eos_id = generated_ids[1:].tolist().index(tokenizer.eos_token_id) + 1
|
38 |
+
else:
|
39 |
+
eos_id = len(generated_ids.tolist()) - 1
|
40 |
+
elif tokenizer.eos_token_id in generated_ids.tolist():
|
41 |
+
eos_id = generated_ids.tolist().index(tokenizer.eos_token_id)
|
42 |
+
else:
|
43 |
+
eos_id = len(generated_ids.tolist()) - 1
|
44 |
+
generated_text_str = tokenizer.tokenizer.decode(generated_ids[1:eos_id].tolist())
|
45 |
+
return generated_text_str
|
46 |
+
|
47 |
+
|
48 |
+
def get_args_parser():
|
49 |
+
parser = argparse.ArgumentParser(description='LAVILA 0-shot evaluations', add_help=False)
|
50 |
+
parser.add_argument('--dataset', default='ego4d', type=str,
|
51 |
+
choices=['ego4d'])
|
52 |
+
parser.add_argument('--root',
|
53 |
+
default='datasets/Ego4D/video_5min_chunks_288px/',
|
54 |
+
type=str, help='path to dataset root')
|
55 |
+
parser.add_argument('--metadata-val',
|
56 |
+
default='datasets/Ego4D/ego4d_val.pkl',
|
57 |
+
type=str, help='path to metadata file (val set)')
|
58 |
+
parser.add_argument('--output-dir', default='./', type=str, help='output dir')
|
59 |
+
parser.add_argument('--num-crops', default=1, type=int, help='number of crops in transforms')
|
60 |
+
parser.add_argument('--num-clips', default=1, type=int, help='number of clips (for untrimmed videos, eg. Charades)')
|
61 |
+
parser.add_argument('--clip-length', default=4, type=int, help='clip length')
|
62 |
+
parser.add_argument('--clip-stride', default=16, type=int, help='clip stride')
|
63 |
+
parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling')
|
64 |
+
parser.add_argument('--batch-size', default=16, type=int, help='batch_size')
|
65 |
+
# captioning options
|
66 |
+
parser.add_argument('--caption-sample', default='multinomial_sample',
|
67 |
+
choices=['multinomial_sample', 'beam_sample', 'group_beam_search'])
|
68 |
+
parser.add_argument('--caption-top-k', default=None, type=int, help='top-k sampling (predecessor of nucleus sampling)')
|
69 |
+
parser.add_argument('--caption-top-p', default=0.95, type=float, help='top-p sampling sampling (aka nucleus sampling)')
|
70 |
+
parser.add_argument('--caption-num-beams', default=3, type=int)
|
71 |
+
parser.add_argument('--caption-num-beam-groups', default=1, type=int)
|
72 |
+
parser.add_argument('--caption-temperature', default=0.7, type=float)
|
73 |
+
parser.add_argument('--caption-length-penalty', default=1.0, type=float)
|
74 |
+
parser.add_argument('--caption-num-return-sequences', default=1, type=int)
|
75 |
+
parser.add_argument('--caption-max-len', default=77, type=int)
|
76 |
+
parser.add_argument('--caption-disable-visual', action='store_true')
|
77 |
+
parser.add_argument('--caption-early-stop', action='store_true', help='early stopping to save computation')
|
78 |
+
parser.add_argument('--caption-output-filename', default='caption.txt', type=str)
|
79 |
+
# others
|
80 |
+
parser.add_argument('--eval-freq', default=1000, type=int,
|
81 |
+
help='percentage (1/eval_freq) of val data to evaluate (for fast prototyping)')
|
82 |
+
parser.add_argument('--print-freq', default=10, type=int)
|
83 |
+
parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
|
84 |
+
help='number of data loading workers per process')
|
85 |
+
parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
|
86 |
+
parser.add_argument('--use-half', action='store_true')
|
87 |
+
return parser
|
88 |
+
|
89 |
+
|
90 |
+
def main(args):
|
91 |
+
if args.resume:
|
92 |
+
ckpt_path = args.resume
|
93 |
+
elif osp.isfile(osp.join(args.output_dir, 'checkpoint_best.pt')):
|
94 |
+
ckpt_path = osp.join(args.output_dir, 'checkpoint_best.pt')
|
95 |
+
else:
|
96 |
+
raise Exception('no checkpoint found')
|
97 |
+
|
98 |
+
ckpt = torch.load(ckpt_path, map_location='cpu')
|
99 |
+
|
100 |
+
# create model
|
101 |
+
state_dict = OrderedDict()
|
102 |
+
for k, v in ckpt['state_dict'].items():
|
103 |
+
state_dict[k.replace('module.', '')] = v
|
104 |
+
|
105 |
+
old_args = ckpt['args']
|
106 |
+
print('=> creating model: {}'.format(old_args.model))
|
107 |
+
model = getattr(models, old_args.model)(
|
108 |
+
text_use_cls_token=old_args.use_cls_token,
|
109 |
+
project_embed_dim=old_args.project_embed_dim,
|
110 |
+
gated_xattn=False if 'gated_xattn' not in old_args else old_args.gated_xattn,
|
111 |
+
timesformer_gated_xattn=False if 'timesformer_gated_xattn' not in old_args else old_args.timesformer_gated_xattn,
|
112 |
+
timesformer_freeze_space=False if 'timesformer_freeze_space' not in old_args else old_args.timesformer_freeze_space,
|
113 |
+
freeze_lm_vclm=False if 'freeze_lm_vclm' not in old_args else old_args.freeze_lm_vclm,
|
114 |
+
freeze_visual_vclm=False if 'freeze_visual_vclm' not in old_args else old_args.freeze_visual_vclm,
|
115 |
+
num_frames=args.clip_length,
|
116 |
+
drop_path_rate=0,
|
117 |
+
)
|
118 |
+
model.cuda()
|
119 |
+
if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model:
|
120 |
+
# inflate weight
|
121 |
+
print('=> inflating PE in models due to different frame numbers')
|
122 |
+
state_dict = inflate_positional_embeds(
|
123 |
+
model.state_dict(), state_dict,
|
124 |
+
num_frames=args.clip_length,
|
125 |
+
load_temporal_fix='bilinear',
|
126 |
+
)
|
127 |
+
model.load_state_dict(state_dict, strict=True)
|
128 |
+
print("=> loaded resume checkpoint '{}' (epoch {}, best_metric = {})".format(args.resume, ckpt['epoch'], ckpt['best_acc1']))
|
129 |
+
|
130 |
+
torch.backends.cudnn.benchmark = True
|
131 |
+
|
132 |
+
tokenizer = generate_tokenizer(old_args.model)
|
133 |
+
crop_size = 224 if '336PX' not in old_args.model else 336
|
134 |
+
if args.num_crops == 1 and args.num_clips == 1:
|
135 |
+
val_transform = transforms.Compose([
|
136 |
+
Permute([3, 0, 1, 2]), # T H W C -> C T H W
|
137 |
+
transforms.Resize(crop_size),
|
138 |
+
transforms.CenterCrop(crop_size),
|
139 |
+
(transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else
|
140 |
+
transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])),
|
141 |
+
])
|
142 |
+
else:
|
143 |
+
val_transform = transforms.Compose([
|
144 |
+
Permute([3, 0, 1, 2]), # T H W C -> C T H W
|
145 |
+
transforms.Resize(crop_size),
|
146 |
+
(transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else
|
147 |
+
transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])),
|
148 |
+
TemporalCrop(frames_per_clip=args.clip_length, stride=args.clip_length),
|
149 |
+
SpatialCrop(crop_size=crop_size, num_crops=args.num_crops),
|
150 |
+
])
|
151 |
+
|
152 |
+
val_dataset = datasets.VideoCaptionDatasetCLIP(
|
153 |
+
args.dataset,
|
154 |
+
args.root,
|
155 |
+
args.metadata_val,
|
156 |
+
transform=val_transform,
|
157 |
+
is_training=False,
|
158 |
+
tokenizer=tokenizer,
|
159 |
+
clip_length=args.clip_length,
|
160 |
+
clip_stride=args.clip_stride,
|
161 |
+
sparse_sample=False,
|
162 |
+
subsample_stride=args.eval_freq,
|
163 |
+
)
|
164 |
+
|
165 |
+
val_loader = torch.utils.data.DataLoader(
|
166 |
+
val_dataset, batch_size=args.batch_size, shuffle=False,
|
167 |
+
num_workers=args.workers, pin_memory=True, drop_last=False)
|
168 |
+
|
169 |
+
validate_caption(val_loader, model, tokenizer, args.caption_output_filename, use_half=args.use_half)
|
170 |
+
|
171 |
+
|
172 |
+
def validate_caption(val_loader, model, tokenizer, output_filename='caption.txt', use_half=False):
|
173 |
+
model.eval()
|
174 |
+
if args.use_half:
|
175 |
+
model = model.half()
|
176 |
+
nlgeval = NLGEval()
|
177 |
+
f = open(output_filename, 'w')
|
178 |
+
ppls_all = []
|
179 |
+
ppls_with_teacher_all = []
|
180 |
+
reference = []
|
181 |
+
hypothesis = []
|
182 |
+
end_time = time.time()
|
183 |
+
id_offset = 0
|
184 |
+
print('=> start forwarding')
|
185 |
+
with torch.no_grad():
|
186 |
+
for i, inputs in enumerate(val_loader):
|
187 |
+
if i % args.print_freq == 0:
|
188 |
+
print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time))
|
189 |
+
end_time = time.time()
|
190 |
+
images = inputs[0].cuda(non_blocking=True)
|
191 |
+
if use_half:
|
192 |
+
images = images.half()
|
193 |
+
target = inputs[1].cuda(non_blocking=True)
|
194 |
+
|
195 |
+
# encode images
|
196 |
+
image_features = dist_utils.get_model(model).encode_image(images)
|
197 |
+
|
198 |
+
# teacher forcing (to get standard ppl metric)
|
199 |
+
generated_text_ids_with_teacher, ppls_with_teacher = dist_utils.get_model(model).generate(
|
200 |
+
image_features,
|
201 |
+
tokenizer,
|
202 |
+
target=target,
|
203 |
+
max_text_length=args.caption_max_len,
|
204 |
+
top_k=args.caption_top_k,
|
205 |
+
top_p=args.caption_top_p,
|
206 |
+
teacher_forcing=True,
|
207 |
+
early_stopping=args.caption_early_stop,
|
208 |
+
)
|
209 |
+
|
210 |
+
if args.caption_sample == 'multinomial_sample':
|
211 |
+
assert args.caption_num_beam_groups == 1
|
212 |
+
generated_text_ids, ppls = dist_utils.get_model(model).generate(
|
213 |
+
image_features,
|
214 |
+
tokenizer,
|
215 |
+
target=target.repeat_interleave(args.caption_num_return_sequences, dim=0),
|
216 |
+
max_text_length=args.caption_max_len,
|
217 |
+
top_k=args.caption_top_k,
|
218 |
+
top_p=args.caption_top_p,
|
219 |
+
num_return_sequences=args.caption_num_return_sequences,
|
220 |
+
temperature=args.caption_temperature,
|
221 |
+
early_stopping=args.caption_early_stop,
|
222 |
+
)
|
223 |
+
elif args.caption_sample == 'beam_sample':
|
224 |
+
assert args.caption_num_beam_groups == 1
|
225 |
+
generated_text_ids, ppls = dist_utils.get_model(model).beam_sample(
|
226 |
+
image_features,
|
227 |
+
tokenizer,
|
228 |
+
target=target,
|
229 |
+
max_text_length=args.caption_max_len,
|
230 |
+
top_k=args.caption_top_k,
|
231 |
+
top_p=args.caption_top_p,
|
232 |
+
temperature=args.caption_temperature,
|
233 |
+
length_penalty=args.caption_length_penalty,
|
234 |
+
num_beams=args.caption_num_beams,
|
235 |
+
num_return_sequences=args.caption_num_return_sequences,
|
236 |
+
early_stopping=args.caption_early_stop,
|
237 |
+
)
|
238 |
+
elif args.caption_sample == 'group_beam_search':
|
239 |
+
assert args.caption_num_beam_groups > 1 and args.caption_num_beams % args.caption_num_beam_groups == 0
|
240 |
+
generated_text_ids, ppls = dist_utils.get_model(model).group_beam_search(
|
241 |
+
image_features,
|
242 |
+
tokenizer,
|
243 |
+
target=target if not args.caption_no_gt else None,
|
244 |
+
max_text_length=args.caption_max_len,
|
245 |
+
top_k=args.caption_top_k,
|
246 |
+
top_p=args.caption_top_p,
|
247 |
+
temperature=args.caption_temperature,
|
248 |
+
length_penalty=args.caption_length_penalty,
|
249 |
+
num_beams=args.caption_num_beams,
|
250 |
+
num_beam_groups=args.caption_num_beam_groups,
|
251 |
+
num_return_sequences=args.caption_num_return_sequences,
|
252 |
+
early_stopping=args.caption_early_stop,
|
253 |
+
)
|
254 |
+
else:
|
255 |
+
raise NotImplementedError
|
256 |
+
ppls_all.append(ppls.reshape(-1, args.caption_num_return_sequences).mean(1))
|
257 |
+
ppls_with_teacher_all.append(ppls_with_teacher)
|
258 |
+
|
259 |
+
for j in range(generated_text_ids.shape[0] // args.caption_num_return_sequences):
|
260 |
+
for k in range(args.caption_num_return_sequences):
|
261 |
+
jj = j * args.caption_num_return_sequences + k
|
262 |
+
|
263 |
+
generated_text_str = decode_one(generated_text_ids[jj], tokenizer)
|
264 |
+
gt_text = decode_one(target[j], tokenizer)
|
265 |
+
generated_text_str_with_teacher = decode_one(generated_text_ids_with_teacher[j], tokenizer)
|
266 |
+
|
267 |
+
from transformers import BertTokenizer
|
268 |
+
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
269 |
+
gt_text = bert_tokenizer.decode(bert_tokenizer(gt_text)['input_ids'][1:-1])
|
270 |
+
generated_text_str = bert_tokenizer.decode(bert_tokenizer(generated_text_str)['input_ids'][1:-1])
|
271 |
+
generated_text_str_with_teacher = bert_tokenizer.decode(bert_tokenizer(generated_text_str_with_teacher)['input_ids'][1:-1])
|
272 |
+
reference.append(gt_text)
|
273 |
+
hypothesis.append(generated_text_str)
|
274 |
+
s1 = '[{:6d}] Groundtruth | | {}'.format(id_offset + j, gt_text)
|
275 |
+
s2 = '[{:6d}] Generated | PPL : {:9.3f} | {}'.format(id_offset + j, ppls[jj], generated_text_str)
|
276 |
+
s3 = '[{:6d}] Generated (w/. teacher) | PPL : {:9.3f} | {}'.format(id_offset + j, ppls_with_teacher[j], generated_text_str_with_teacher)
|
277 |
+
for s in [s1, s2, s3]:
|
278 |
+
# if i % args.print_freq == 0:
|
279 |
+
# print(s)
|
280 |
+
f.write('{} \n'.format(s))
|
281 |
+
id_offset += generated_text_ids.shape[0] // args.caption_num_return_sequences
|
282 |
+
|
283 |
+
ppls_with_teacher_all = torch.cat(ppls_with_teacher_all, dim=0)
|
284 |
+
ppls_all = torch.cat(ppls_all, dim=0)
|
285 |
+
|
286 |
+
print('PPL (w/. teacher) = {:9.3f}'.format(ppls_with_teacher_all.mean().item()))
|
287 |
+
print('PPL (w/o. teacher) = {:9.3f}'.format(ppls_all.mean().item()))
|
288 |
+
f.write('PPL (w/. teacher) = {:9.3f} \n'.format(ppls_with_teacher_all.mean().item()))
|
289 |
+
f.write('PPL (w/o. teacher) = {:9.3f} \n'.format(ppls_all.mean().item()))
|
290 |
+
|
291 |
+
print('Avg length for reference: {:9.3f}'.format(sum(map(lambda sentence: len(sentence.split(' ')), reference)) / len(reference)))
|
292 |
+
print('Avg length for hypothesis: {:9.3f}'.format(sum(map(lambda sentence: len(sentence.split(' ')), hypothesis)) / len(hypothesis)))
|
293 |
+
f.write('Avg length for reference: {:9.3f} \n'.format(sum(map(lambda sentence: len(sentence.split(' ')), reference)) / len(reference)))
|
294 |
+
f.write('Avg length for hypothesis: {:9.3f} \n'.format(sum(map(lambda sentence: len(sentence.split(' ')), hypothesis)) / len(hypothesis)))
|
295 |
+
|
296 |
+
print('=> Calling NLGEval')
|
297 |
+
f.write('=> Calling NLGEval\n')
|
298 |
+
metrics_dict = nlgeval.compute_metrics([reference], hypothesis)
|
299 |
+
for k in metrics_dict:
|
300 |
+
print('{:16s} = {:9.3f}'.format(k, metrics_dict[k]))
|
301 |
+
f.write('{:16s} = {:9.3f} \n'.format(k, metrics_dict[k]))
|
302 |
+
f.close()
|
303 |
+
|
304 |
+
|
305 |
+
if __name__ == '__main__':
|
306 |
+
parser = argparse.ArgumentParser('lavila 0-shot evaluations', parents=[get_args_parser()])
|
307 |
+
args = parser.parse_args()
|
308 |
+
main(args)
|
eval_zeroshot.py
ADDED
@@ -0,0 +1,389 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import numpy as np
|
9 |
+
import os.path as osp
|
10 |
+
import time
|
11 |
+
from collections import OrderedDict
|
12 |
+
|
13 |
+
import pandas as pd
|
14 |
+
import torch
|
15 |
+
import torchvision.transforms as transforms
|
16 |
+
import torchvision.transforms._transforms_video as transforms_video
|
17 |
+
from sklearn.metrics import confusion_matrix
|
18 |
+
|
19 |
+
from lavila.data import datasets
|
20 |
+
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
|
21 |
+
from lavila.models import models
|
22 |
+
from lavila.models.utils import inflate_positional_embeds
|
23 |
+
from lavila.utils import distributed as dist_utils
|
24 |
+
from lavila.utils.evaluation import accuracy, get_mean_accuracy
|
25 |
+
from lavila.utils.evaluation_egomcq import egomcq_accuracy_metrics
|
26 |
+
from lavila.utils.evaluation_ek100mir import (calculate_k_counts, calculate_IDCG, calculate_mAP, calculate_nDCG)
|
27 |
+
from lavila.utils.evaluation_charades import charades_map
|
28 |
+
from lavila.utils.preprocess import generate_label_map, generate_tokenizer
|
29 |
+
|
30 |
+
|
31 |
+
def get_args_parser():
|
32 |
+
parser = argparse.ArgumentParser(description='LAVILA 0-shot evaluations', add_help=False)
|
33 |
+
parser.add_argument('--dataset', default='ek100_mir', type=str,
|
34 |
+
choices=['ek100_cls', 'ek100_mir', 'charades_ego', 'egtea', 'ego4d_mcq'])
|
35 |
+
parser.add_argument('--root',
|
36 |
+
default='datasets/EK100/video_ht256px/',
|
37 |
+
type=str, help='path to dataset root')
|
38 |
+
parser.add_argument('--metadata-val',
|
39 |
+
default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv',
|
40 |
+
type=str, help='path to metadata file (val set)')
|
41 |
+
parser.add_argument('--relevancy-path',
|
42 |
+
default='datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl',
|
43 |
+
type=str, help='path to relevancy matrix (val set)')
|
44 |
+
parser.add_argument('--output-dir', default='./', type=str, help='output dir')
|
45 |
+
parser.add_argument('--num-crops', default=1, type=int, help='number of crops in transforms')
|
46 |
+
parser.add_argument('--num-clips', default=1, type=int, help='number of clips (for untrimmed videos, eg. Charades)')
|
47 |
+
parser.add_argument('--clip-length', default=4, type=int, help='clip length')
|
48 |
+
parser.add_argument('--clip-stride', default=16, type=int, help='clip stride')
|
49 |
+
parser.add_argument('--sparse-sample', action='store_true', help='switch to sparse sampling')
|
50 |
+
parser.add_argument('--batch-size', default=16, type=int, help='batch_size')
|
51 |
+
parser.add_argument('--cls-use-template', action='store_true', help='use prompt in 0-shot classification')
|
52 |
+
parser.add_argument('--print-freq', default=100, type=int)
|
53 |
+
parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
|
54 |
+
help='number of data loading workers per process')
|
55 |
+
parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
|
56 |
+
parser.add_argument('--use-half', action='store_true')
|
57 |
+
return parser
|
58 |
+
|
59 |
+
|
60 |
+
def main(args):
|
61 |
+
if args.resume:
|
62 |
+
ckpt_path = args.resume
|
63 |
+
elif osp.isfile(osp.join(args.output_dir, 'checkpoint_best.pt')):
|
64 |
+
ckpt_path = osp.join(args.output_dir, 'checkpoint_best.pt')
|
65 |
+
else:
|
66 |
+
raise Exception('no checkpoint found')
|
67 |
+
|
68 |
+
ckpt = torch.load(ckpt_path, map_location='cpu')
|
69 |
+
|
70 |
+
# create model
|
71 |
+
state_dict = OrderedDict()
|
72 |
+
for k, v in ckpt['state_dict'].items():
|
73 |
+
state_dict[k.replace('module.', '')] = v
|
74 |
+
|
75 |
+
old_args = ckpt['args']
|
76 |
+
print('=> creating model: {}'.format(old_args.model))
|
77 |
+
model = getattr(models, old_args.model)(
|
78 |
+
text_use_cls_token=old_args.use_cls_token,
|
79 |
+
project_embed_dim=old_args.project_embed_dim,
|
80 |
+
gated_xattn=False if 'gated_xattn' not in old_args else old_args.gated_xattn,
|
81 |
+
timesformer_gated_xattn=False if 'timesformer_gated_xattn' not in old_args else old_args.timesformer_gated_xattn,
|
82 |
+
timesformer_freeze_space=False if 'timesformer_freeze_space' not in old_args else old_args.timesformer_freeze_space,
|
83 |
+
freeze_lm_vclm=False if 'freeze_lm_vclm' not in old_args else old_args.freeze_lm_vclm,
|
84 |
+
freeze_visual_vclm=False if 'freeze_visual_vclm' not in old_args else old_args.freeze_visual_vclm,
|
85 |
+
num_frames=args.clip_length,
|
86 |
+
drop_path_rate=0,
|
87 |
+
)
|
88 |
+
model.cuda()
|
89 |
+
if 'TIMESFORMER' in old_args.model or 'EGOVLP' in old_args.model:
|
90 |
+
# inflate weight
|
91 |
+
print('=> inflating PE in models due to different frame numbers')
|
92 |
+
state_dict = inflate_positional_embeds(
|
93 |
+
model.state_dict(), state_dict,
|
94 |
+
num_frames=args.clip_length,
|
95 |
+
load_temporal_fix='bilinear',
|
96 |
+
)
|
97 |
+
model.load_state_dict(state_dict, strict=True)
|
98 |
+
print("=> loaded resume checkpoint '{}' (epoch {}, best_metric = {})".format(args.resume, ckpt['epoch'], ckpt['best_acc1']))
|
99 |
+
|
100 |
+
torch.backends.cudnn.benchmark = True
|
101 |
+
|
102 |
+
if args.dataset in ['ek100_cls', 'charades_ego', 'egtea']:
|
103 |
+
labels, mapping_vn2act = generate_label_map(args.dataset)
|
104 |
+
else:
|
105 |
+
mapping_vn2act = None
|
106 |
+
tokenizer = generate_tokenizer(old_args.model)
|
107 |
+
crop_size = 224 if '336PX' not in old_args.model else 336
|
108 |
+
if args.num_crops == 1 and args.num_clips == 1:
|
109 |
+
val_transform = transforms.Compose([
|
110 |
+
Permute([3, 0, 1, 2]), # T H W C -> C T H W
|
111 |
+
transforms.Resize(crop_size),
|
112 |
+
transforms.CenterCrop(crop_size),
|
113 |
+
(transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else
|
114 |
+
transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])),
|
115 |
+
])
|
116 |
+
else:
|
117 |
+
val_transform = transforms.Compose([
|
118 |
+
Permute([3, 0, 1, 2]), # T H W C -> C T H W
|
119 |
+
transforms.Resize(crop_size),
|
120 |
+
(transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) if ('OPENAI' not in old_args.model) else
|
121 |
+
transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])),
|
122 |
+
TemporalCrop(frames_per_clip=args.clip_length, stride=args.clip_length),
|
123 |
+
SpatialCrop(crop_size=crop_size, num_crops=args.num_crops),
|
124 |
+
])
|
125 |
+
|
126 |
+
val_dataset = datasets.get_downstream_dataset(
|
127 |
+
val_transform, tokenizer, args, subset='val', label_mapping=mapping_vn2act,
|
128 |
+
)
|
129 |
+
|
130 |
+
val_loader = torch.utils.data.DataLoader(
|
131 |
+
val_dataset, batch_size=args.batch_size, shuffle=False,
|
132 |
+
num_workers=args.workers, pin_memory=True, drop_last=False)
|
133 |
+
|
134 |
+
if args.cls_use_template:
|
135 |
+
templates = ['#C C {}', '#C {}']
|
136 |
+
else:
|
137 |
+
templates = ['{}']
|
138 |
+
|
139 |
+
if args.dataset in ['ek100_cls', 'charades_ego', 'egtea']:
|
140 |
+
preds, targets = validate_zeroshot(val_loader, templates, labels, model, tokenizer)
|
141 |
+
if args.dataset == 'ek100_cls':
|
142 |
+
if args.use_half:
|
143 |
+
preds = preds.float()
|
144 |
+
top1, top5 = accuracy(preds, targets, topk=(1, 5))
|
145 |
+
print('top1 = {:.3f}'.format(top1.item()))
|
146 |
+
print('top5 = {:.3f}'.format(top5.item()))
|
147 |
+
elif args.dataset == 'charades_ego':
|
148 |
+
preds, targets = preds.numpy(), targets.numpy()
|
149 |
+
m_ap, _, _ = charades_map(preds, targets)
|
150 |
+
print('mAP = {:.3f}'.format(m_ap))
|
151 |
+
elif args.dataset == 'egtea':
|
152 |
+
preds, targets = preds.numpy(), targets.numpy()
|
153 |
+
print(preds.shape, targets.shape)
|
154 |
+
cm = confusion_matrix(targets, preds.argmax(axis=1))
|
155 |
+
mean_class_acc, acc = get_mean_accuracy(cm)
|
156 |
+
print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
|
157 |
+
|
158 |
+
if args.dataset == 'ek100_mir':
|
159 |
+
val_dataset = datasets.VideoCaptionDatasetCLIP(
|
160 |
+
'ek100_mir',
|
161 |
+
args.root,
|
162 |
+
args.metadata_val,
|
163 |
+
transform=val_transform, is_training=False,
|
164 |
+
tokenizer=tokenizer,
|
165 |
+
clip_length=args.clip_length,
|
166 |
+
clip_stride=args.clip_stride,
|
167 |
+
sparse_sample=False
|
168 |
+
)
|
169 |
+
val_loader = torch.utils.data.DataLoader(
|
170 |
+
val_dataset, batch_size=args.batch_size, shuffle=False,
|
171 |
+
num_workers=args.workers, pin_memory=True, drop_last=False
|
172 |
+
)
|
173 |
+
similarity_matrix = get_similarity_matrix(val_loader, model, print_freq=args.print_freq, use_half=args.use_half)
|
174 |
+
similarity_matrix = (similarity_matrix + 1) / 2
|
175 |
+
video_id = pd.read_csv(args.metadata_val).values[:, 0]
|
176 |
+
text_id = pd.read_csv(args.metadata_val.replace("test.csv", "test_sentence.csv")).values[:, 0]
|
177 |
+
indexes = [video_id.tolist().index(elem) for elem in text_id]
|
178 |
+
similarity_matrix = similarity_matrix[:, indexes]
|
179 |
+
print(similarity_matrix.shape)
|
180 |
+
rel_matrix = pd.read_pickle(args.relevancy_path)
|
181 |
+
vis_map = calculate_mAP(similarity_matrix, rel_matrix)
|
182 |
+
txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
|
183 |
+
print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, (vis_map + txt_map) / 2))
|
184 |
+
vis_k_counts = calculate_k_counts(rel_matrix)
|
185 |
+
txt_k_counts = calculate_k_counts(rel_matrix.T)
|
186 |
+
vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
|
187 |
+
txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
|
188 |
+
vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
|
189 |
+
txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
|
190 |
+
print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2))
|
191 |
+
|
192 |
+
if args.dataset == 'ego4d_mcq':
|
193 |
+
val_dataset = datasets.VideoCaptionDatasetMCQ(
|
194 |
+
args.dataset,
|
195 |
+
args.root,
|
196 |
+
args.metadata_val,
|
197 |
+
transform=val_transform, is_training=False,
|
198 |
+
tokenizer=tokenizer,
|
199 |
+
clip_length=args.clip_length,
|
200 |
+
clip_stride=args.clip_stride,
|
201 |
+
sparse_sample=False,
|
202 |
+
)
|
203 |
+
val_loader = torch.utils.data.DataLoader(
|
204 |
+
val_dataset, batch_size=args.batch_size, shuffle=False,
|
205 |
+
num_workers=args.workers, pin_memory=True, drop_last=False
|
206 |
+
)
|
207 |
+
validate_mcq(val_loader, model, use_half=args.use_half)
|
208 |
+
|
209 |
+
|
210 |
+
def validate_zeroshot(val_loader, templates, labels, model, tokenizer):
|
211 |
+
model.eval()
|
212 |
+
if args.use_half:
|
213 |
+
model = model.half()
|
214 |
+
all_outputs = []
|
215 |
+
all_targets = []
|
216 |
+
all_vis_features = []
|
217 |
+
print('=> encoding captions')
|
218 |
+
with torch.no_grad():
|
219 |
+
text_features = []
|
220 |
+
for label in labels:
|
221 |
+
if isinstance(label, list):
|
222 |
+
texts = [tmpl.format(lbl) for tmpl in templates for lbl in label]
|
223 |
+
else:
|
224 |
+
texts = [tmpl.format(label) for tmpl in templates]
|
225 |
+
texts = tokenizer(texts)
|
226 |
+
if isinstance(texts, tuple):
|
227 |
+
# Bert-style tokenizer will output both ids and mask
|
228 |
+
texts, masks = texts
|
229 |
+
texts = texts.cuda(non_blocking=True)
|
230 |
+
masks = masks.cuda(non_blocking=True)
|
231 |
+
else:
|
232 |
+
texts = texts.cuda(non_blocking=True)
|
233 |
+
masks = None
|
234 |
+
texts = texts.view(-1, 77).contiguous()
|
235 |
+
masks = masks.view(-1, 77).contiguous() if masks is not None else None
|
236 |
+
if masks is not None:
|
237 |
+
class_embeddings = dist_utils.get_model(model).encode_text(texts, attention_mask=masks)
|
238 |
+
else:
|
239 |
+
class_embeddings = dist_utils.get_model(model).encode_text(texts)
|
240 |
+
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
|
241 |
+
|
242 |
+
class_embeddings = class_embeddings.mean(dim=0)
|
243 |
+
class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
|
244 |
+
|
245 |
+
text_features.append(class_embeddings)
|
246 |
+
text_features = torch.stack(text_features, dim=0)
|
247 |
+
|
248 |
+
print('=> start forwarding')
|
249 |
+
end_time = time.time()
|
250 |
+
for i, (images, target) in enumerate(val_loader):
|
251 |
+
if i % args.print_freq == 0:
|
252 |
+
print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time))
|
253 |
+
end_time = time.time()
|
254 |
+
if isinstance(images, torch.Tensor):
|
255 |
+
images = images.cuda(non_blocking=True)
|
256 |
+
if args.use_half:
|
257 |
+
images = images.half()
|
258 |
+
target = target.cuda(non_blocking=True)
|
259 |
+
|
260 |
+
# encode images
|
261 |
+
image_features = dist_utils.get_model(model).encode_image(images)
|
262 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
263 |
+
all_vis_features.append(image_features)
|
264 |
+
# cosine similarity as logits
|
265 |
+
logits_per_image = image_features @ text_features.t()
|
266 |
+
# logits_per_image = torch.softmax(logits_per_image, dim=1)
|
267 |
+
else:
|
268 |
+
target = target.cuda(non_blocking=True)
|
269 |
+
images_list = images
|
270 |
+
logits_all_clips = []
|
271 |
+
for images in images_list:
|
272 |
+
images = images.cuda(non_blocking=True)
|
273 |
+
if args.use_half:
|
274 |
+
images = images.half()
|
275 |
+
image_features = dist_utils.get_model(model).encode_image(images)
|
276 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
277 |
+
logits_per_image = image_features @ text_features.t()
|
278 |
+
logits_all_clips.append(logits_per_image)
|
279 |
+
|
280 |
+
logits_all_clips = torch.stack(logits_all_clips, dim=0)
|
281 |
+
logits_per_image = logits_all_clips.max(0).values
|
282 |
+
# logits_per_image = logits_all_clips.mean(0)
|
283 |
+
logits_per_image = torch.softmax(logits_per_image, dim=1)
|
284 |
+
|
285 |
+
all_outputs.append(logits_per_image.cpu())
|
286 |
+
all_targets.append(target.cpu())
|
287 |
+
|
288 |
+
return torch.cat(all_outputs), torch.cat(all_targets)
|
289 |
+
|
290 |
+
|
291 |
+
def get_similarity_matrix(val_loader, model, print_freq=100, use_half=False):
|
292 |
+
model.eval()
|
293 |
+
if use_half:
|
294 |
+
model = model.half()
|
295 |
+
all_text_embed = []
|
296 |
+
all_video_embed = []
|
297 |
+
with torch.no_grad():
|
298 |
+
print('=> encoding visual and textual')
|
299 |
+
for i, inputs in enumerate(val_loader):
|
300 |
+
if i % print_freq == 0:
|
301 |
+
print('finish batch {}/{}'.format(i, len(val_loader)))
|
302 |
+
frames = inputs[0].cuda(non_blocking=True)
|
303 |
+
if use_half:
|
304 |
+
frames = frames.half()
|
305 |
+
texts = inputs[1].cuda(non_blocking=True)
|
306 |
+
if len(inputs) == 4:
|
307 |
+
masks = inputs[2].cuda(non_blocking=True)
|
308 |
+
else:
|
309 |
+
masks = None
|
310 |
+
|
311 |
+
# encode images
|
312 |
+
image_features = dist_utils.get_model(model).encode_image(frames)
|
313 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
314 |
+
all_video_embed.append(image_features.cpu().numpy())
|
315 |
+
|
316 |
+
if texts.ndim == 3:
|
317 |
+
is_multiple_narrations = True
|
318 |
+
texts = texts.view(-1, texts.shape[-1])
|
319 |
+
else:
|
320 |
+
is_multiple_narrations = False
|
321 |
+
if masks is not None:
|
322 |
+
text_features = dist_utils.get_model(model).encode_text(texts, attention_mask=masks)
|
323 |
+
else:
|
324 |
+
text_features = dist_utils.get_model(model).encode_text(texts)
|
325 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
326 |
+
all_text_embed.append(text_features.cpu().numpy())
|
327 |
+
|
328 |
+
all_text_embed = np.vstack(all_text_embed)
|
329 |
+
all_video_embed = np.vstack(all_video_embed)
|
330 |
+
similarity_matrix = np.matmul(all_video_embed, all_text_embed.T)
|
331 |
+
if is_multiple_narrations:
|
332 |
+
similarity_matrix = similarity_matrix.reshape(all_video_embed.shape[0], all_video_embed.shape[0], -1)
|
333 |
+
|
334 |
+
return similarity_matrix
|
335 |
+
|
336 |
+
|
337 |
+
def validate_mcq(val_loader, model, use_half=False):
|
338 |
+
model.eval()
|
339 |
+
if use_half:
|
340 |
+
model.half()
|
341 |
+
with torch.no_grad():
|
342 |
+
print('=> start forwarding')
|
343 |
+
all_preds = []
|
344 |
+
all_gts = []
|
345 |
+
all_types = []
|
346 |
+
end_time = time.time()
|
347 |
+
for i, inputs in enumerate(val_loader):
|
348 |
+
if i % args.print_freq == 0:
|
349 |
+
print('finish batch {}/{} in {} sec'.format(i, len(val_loader), time.time() - end_time))
|
350 |
+
end_time = time.time()
|
351 |
+
texts_query = inputs[0].cuda(non_blocking=True)
|
352 |
+
frames_options = inputs[1].cuda(non_blocking=True)
|
353 |
+
if use_half:
|
354 |
+
frames_options = frames_options.half()
|
355 |
+
answer = inputs[3]
|
356 |
+
q_type = inputs[4]
|
357 |
+
if len(inputs) == 7:
|
358 |
+
masks_query = inputs[5].cuda(non_blocking=True)
|
359 |
+
else:
|
360 |
+
masks_query = None
|
361 |
+
|
362 |
+
batch_size = frames_options.shape[0]
|
363 |
+
|
364 |
+
frames_options = frames_options.view(-1, *frames_options.shape[2:])
|
365 |
+
image_features = dist_utils.get_model(model).encode_image(frames_options)
|
366 |
+
image_features = image_features.view(batch_size, -1, *image_features.shape[1:])
|
367 |
+
|
368 |
+
if masks_query is not None:
|
369 |
+
query_features = dist_utils.get_model(model).encode_text(texts_query, attention_mask=masks_query)
|
370 |
+
else:
|
371 |
+
query_features = dist_utils.get_model(model).encode_text(texts_query)
|
372 |
+
|
373 |
+
all_gts.append(answer)
|
374 |
+
all_types.append(q_type)
|
375 |
+
for j in range(batch_size):
|
376 |
+
similarity_matrix = torch.matmul(query_features[j], image_features[j].T)
|
377 |
+
similarity_matrix = similarity_matrix.cpu().detach()
|
378 |
+
all_preds.append(similarity_matrix)
|
379 |
+
all_preds = torch.stack(all_preds)
|
380 |
+
all_gts = torch.cat(all_gts)
|
381 |
+
all_types = torch.cat(all_types)
|
382 |
+
metrics = egomcq_accuracy_metrics(all_preds, all_gts, all_types)
|
383 |
+
print(metrics)
|
384 |
+
|
385 |
+
|
386 |
+
if __name__ == '__main__':
|
387 |
+
parser = argparse.ArgumentParser('lavila 0-shot evaluations', parents=[get_args_parser()])
|
388 |
+
args = parser.parse_args()
|
389 |
+
main(args)
|
lavila/data/__pycache__/datasets.cpython-38.pyc
ADDED
Binary file (14.4 kB).
|
|
lavila/data/__pycache__/video_transforms.cpython-38.pyc
ADDED
Binary file (6.25 kB).
|
|
lavila/data/datasets.py
ADDED
@@ -0,0 +1,517 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import csv
|
8 |
+
import glob
|
9 |
+
import json
|
10 |
+
import numpy as np
|
11 |
+
import os.path as osp
|
12 |
+
import pickle
|
13 |
+
import random
|
14 |
+
|
15 |
+
import decord
|
16 |
+
import pandas as pd
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
def datetime2sec(str):
|
21 |
+
hh, mm, ss = str.split(':')
|
22 |
+
return int(hh) * 3600 + int(mm) * 60 + float(ss)
|
23 |
+
|
24 |
+
|
25 |
+
def video_loader(root, vid, second, end_second=None, chunk_len=300, fps=30, clip_length=32, jitter=False):
|
26 |
+
if chunk_len == -1:
|
27 |
+
vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid)))
|
28 |
+
second_offset = second
|
29 |
+
if end_second is not None:
|
30 |
+
end_second = min(end_second, len(vr) / vr.get_avg_fps())
|
31 |
+
else:
|
32 |
+
end_second = len(vr) / vr.get_avg_fps()
|
33 |
+
else:
|
34 |
+
chunk_start = int(second) // chunk_len * chunk_len
|
35 |
+
second_offset = second - chunk_start
|
36 |
+
vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start)))
|
37 |
+
if fps == -1:
|
38 |
+
fps = vr.get_avg_fps()
|
39 |
+
|
40 |
+
# calculate frame_ids
|
41 |
+
frame_offset = int(np.round(second_offset * fps))
|
42 |
+
total_duration = max(int((end_second - second) * fps), clip_length)
|
43 |
+
if chunk_len == -1:
|
44 |
+
if end_second <= second:
|
45 |
+
raise ValueError("end_second should be greater than second")
|
46 |
+
else:
|
47 |
+
frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
|
48 |
+
else:
|
49 |
+
frame_ids = get_frame_ids(frame_offset, frame_offset + total_duration, num_segments=clip_length, jitter=jitter)
|
50 |
+
|
51 |
+
# load frames
|
52 |
+
if max(frame_ids) < len(vr):
|
53 |
+
try:
|
54 |
+
frames = vr.get_batch(frame_ids).asnumpy()
|
55 |
+
except decord.DECORDError as error:
|
56 |
+
print(error)
|
57 |
+
frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
|
58 |
+
else:
|
59 |
+
# find the remaining frames in the next chunk
|
60 |
+
try:
|
61 |
+
frame_ids_part1 = list(filter(lambda frame_id: frame_id < len(vr), frame_ids))
|
62 |
+
frames_part1 = vr.get_batch(frame_ids_part1).asnumpy()
|
63 |
+
vr2 = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start + chunk_len)))
|
64 |
+
frame_ids_part2 = list(filter(lambda frame_id: frame_id >= len(vr), frame_ids))
|
65 |
+
frame_ids_part2 = [min(frame_id % len(vr), len(vr2) - 1) for frame_id in frame_ids_part2]
|
66 |
+
frames_part2 = vr2.get_batch(frame_ids_part2).asnumpy()
|
67 |
+
frames = np.concatenate([frames_part1, frames_part2], axis=0)
|
68 |
+
# the next chunk does not exist; the current chunk is the last one
|
69 |
+
except (RuntimeError, decord.DECORDError) as error:
|
70 |
+
print(error)
|
71 |
+
frame_ids = get_frame_ids(min(frame_offset, len(vr) - 1), len(vr), num_segments=clip_length, jitter=jitter)
|
72 |
+
frames = vr.get_batch(frame_ids).asnumpy()
|
73 |
+
|
74 |
+
frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
|
75 |
+
return torch.stack(frames, dim=0)
|
76 |
+
|
77 |
+
|
78 |
+
def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
|
79 |
+
seg_size = float(end_frame - start_frame - 1) / num_segments
|
80 |
+
seq = []
|
81 |
+
for i in range(num_segments):
|
82 |
+
start = int(np.round(seg_size * i) + start_frame)
|
83 |
+
end = int(np.round(seg_size * (i + 1)) + start_frame)
|
84 |
+
end = min(end, end_frame)
|
85 |
+
if jitter:
|
86 |
+
frame_id = np.random.randint(low=start, high=(end + 1))
|
87 |
+
else:
|
88 |
+
frame_id = (start + end) // 2
|
89 |
+
seq.append(frame_id)
|
90 |
+
return seq
|
91 |
+
|
92 |
+
|
93 |
+
def video_loader_by_frames(root, vid, frame_ids):
|
94 |
+
vr = decord.VideoReader(osp.join(root, vid))
|
95 |
+
try:
|
96 |
+
frames = vr.get_batch(frame_ids).asnumpy()
|
97 |
+
frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
|
98 |
+
except (IndexError, decord.DECORDError) as error:
|
99 |
+
print(error)
|
100 |
+
print("Erroneous video: ", vid)
|
101 |
+
frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
|
102 |
+
return torch.stack(frames, dim=0)
|
103 |
+
|
104 |
+
|
105 |
+
class VideoCaptionDatasetBase(torch.utils.data.Dataset):
|
106 |
+
def __init__(self, dataset, root, metadata, is_trimmed=True):
|
107 |
+
self.dataset = dataset
|
108 |
+
self.root = root
|
109 |
+
self.is_trimmed = is_trimmed
|
110 |
+
|
111 |
+
if self.dataset == 'ego4d':
|
112 |
+
with open(metadata, 'rb') as f:
|
113 |
+
self.samples = pickle.load(f)
|
114 |
+
elif self.dataset == 'ego4d_mcq':
|
115 |
+
with open(metadata, 'r') as f:
|
116 |
+
self.samples = json.load(f)
|
117 |
+
elif self.dataset in ['ek100_cls', 'ek100_mir']:
|
118 |
+
video_list = glob.glob(osp.join(self.root, '*/*.MP4'))
|
119 |
+
fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
|
120 |
+
self.samples = []
|
121 |
+
with open(metadata) as f:
|
122 |
+
csv_reader = csv.reader(f)
|
123 |
+
_ = next(csv_reader) # skip the header
|
124 |
+
for row in csv_reader:
|
125 |
+
pid, vid = row[1:3]
|
126 |
+
# start_frame, end_frame = int(row[6]), int(row[7])
|
127 |
+
# Deprecated: some videos might have fps mismatch issue
|
128 |
+
start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
|
129 |
+
narration = row[8]
|
130 |
+
verb, noun = int(row[10]), int(row[12])
|
131 |
+
vid_path = '{}/{}.MP4'.format(pid, vid)
|
132 |
+
fps = fps_dict[osp.join(self.root, vid_path)]
|
133 |
+
start_frame = int(np.round(fps * start_timestamp))
|
134 |
+
end_frame = int(np.ceil(fps * end_timestamp))
|
135 |
+
self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun))
|
136 |
+
if self.dataset == 'ek100_mir':
|
137 |
+
self.metadata_sentence = pd.read_csv(metadata[:metadata.index('.csv')] + '_sentence.csv')
|
138 |
+
if 'train' in metadata:
|
139 |
+
self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb'))
|
140 |
+
elif 'test' in metadata:
|
141 |
+
self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb'))
|
142 |
+
else:
|
143 |
+
raise ValueError('{} should contain either "train" or "test"!'.format(metadata))
|
144 |
+
self.relevancy = .1
|
145 |
+
elif self.dataset == 'egtea':
|
146 |
+
video_list = glob.glob(osp.join(self.root, '*/*'))
|
147 |
+
len_dict = {video: len(decord.VideoReader(video)) for video in video_list}
|
148 |
+
|
149 |
+
vn_list, labels = [], []
|
150 |
+
for row in open(osp.join(osp.dirname(metadata), 'action_idx.txt')):
|
151 |
+
row = row.strip()
|
152 |
+
vn = int(row.split(' ')[-1])
|
153 |
+
vn_list.append(vn)
|
154 |
+
narration = ' '.join(row.split(' ')[:-1])
|
155 |
+
labels.append(narration.replace('_', ' ').lower())
|
156 |
+
# labels.append(narration)
|
157 |
+
mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)}
|
158 |
+
|
159 |
+
self.samples = []
|
160 |
+
with open(metadata) as f:
|
161 |
+
for row in f:
|
162 |
+
clip_id, action_idx = row.strip().split(' ')[:2]
|
163 |
+
video_id = '-'.join(clip_id.split('-')[:3])
|
164 |
+
vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id))
|
165 |
+
vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id))
|
166 |
+
self.samples.append((vid_relpath, 0, len_dict[vid_fullpath], mapping_act2narration[int(action_idx)]))
|
167 |
+
elif self.dataset == 'charades_ego':
|
168 |
+
video_list = glob.glob(osp.join(self.root, '*.mp4'))
|
169 |
+
fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
|
170 |
+
self.samples = []
|
171 |
+
with open(metadata) as f:
|
172 |
+
csv_reader = csv.reader(f)
|
173 |
+
_ = next(csv_reader) # skip the header
|
174 |
+
for row in csv_reader:
|
175 |
+
video_id = row[0]
|
176 |
+
if self.is_trimmed:
|
177 |
+
for action_tuple in row[9].split(';'):
|
178 |
+
if not action_tuple:
|
179 |
+
continue
|
180 |
+
action, start_timestamp, end_timestamp = action_tuple.split(' ')
|
181 |
+
start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp)
|
182 |
+
vid_path = '{}.mp4'.format(video_id)
|
183 |
+
fps = fps_dict[osp.join(self.root, vid_path)]
|
184 |
+
start_frame = int(np.round(fps * start_timestamp))
|
185 |
+
end_frame = int(np.ceil(fps * end_timestamp))
|
186 |
+
self.samples.append((vid_path, start_frame, end_frame, action))
|
187 |
+
else:
|
188 |
+
if not row[9]:
|
189 |
+
action_list = []
|
190 |
+
else:
|
191 |
+
action_list = [action_tuple.split(' ')[0] for action_tuple in row[9].split(';')]
|
192 |
+
vid_path = '{}.mp4'.format(video_id)
|
193 |
+
fps = fps_dict[osp.join(self.root, vid_path)]
|
194 |
+
duration = fps * float(row[10])
|
195 |
+
self.samples.append((vid_path, 0, duration, action_list))
|
196 |
+
elif self.dataset == 'charades_ego_trimmed':
|
197 |
+
with open(metadata, 'rb') as f:
|
198 |
+
self.samples = pickle.load(f)
|
199 |
+
else:
|
200 |
+
raise NotImplementedError
|
201 |
+
|
202 |
+
def get_raw_item(self, i, is_training=True, num_clips=1, clip_length=32, clip_stride=2, sparse_sample=False,
|
203 |
+
narration_selection='random'):
|
204 |
+
if self.dataset == 'ego4d':
|
205 |
+
if len(self.samples[i]) == 4:
|
206 |
+
vid, start_second, end_second, narration = self.samples[i]
|
207 |
+
frames = video_loader(self.root, vid, start_second,
|
208 |
+
end_second=end_second,
|
209 |
+
clip_length=clip_length,
|
210 |
+
jitter=is_training)
|
211 |
+
if isinstance(narration, list):
|
212 |
+
if narration_selection == 'random':
|
213 |
+
narration = random.choice(narration)
|
214 |
+
elif narration_selection == 'concat':
|
215 |
+
narration = '. '.join(narration)
|
216 |
+
elif narration_selection == 'list':
|
217 |
+
narration = narration
|
218 |
+
else:
|
219 |
+
raise ValueError
|
220 |
+
return frames, narration
|
221 |
+
elif len(self.samples[i]) == 5:
|
222 |
+
# TODO: need better filtering strategy based on nll
|
223 |
+
vid, start_second, end_second, narration, _ = self.samples[i]
|
224 |
+
frames = video_loader(self.root, vid, start_second,
|
225 |
+
end_second=end_second,
|
226 |
+
clip_length=clip_length,
|
227 |
+
jitter=is_training)
|
228 |
+
if isinstance(narration, list):
|
229 |
+
if narration_selection == 'random':
|
230 |
+
narration = random.choice(narration)
|
231 |
+
elif narration_selection == 'concat':
|
232 |
+
narration = '. '.join(narration)
|
233 |
+
elif narration_selection == 'list':
|
234 |
+
narration = narration
|
235 |
+
else:
|
236 |
+
raise ValueError
|
237 |
+
return frames, narration
|
238 |
+
elif self.dataset == 'ego4d_mcq':
|
239 |
+
itemMCQ = self.samples[str(i)]
|
240 |
+
answerIndex = itemMCQ['answer']
|
241 |
+
textQuery = itemMCQ['query']['clip_text']
|
242 |
+
sampleOptions = itemMCQ['choices']
|
243 |
+
frames_options = []
|
244 |
+
narration_options = []
|
245 |
+
for option_id in range(len(sampleOptions)):
|
246 |
+
option = sampleOptions[str(option_id)]
|
247 |
+
frames = video_loader(self.root, option['video_uid'],
|
248 |
+
float(option['clip_start']), end_second=float(option['clip_end']),
|
249 |
+
clip_length=clip_length,
|
250 |
+
jitter=is_training)
|
251 |
+
frames_options.append(frames)
|
252 |
+
narration_options.append(option['clip_text'])
|
253 |
+
return textQuery, frames_options, narration_options, answerIndex, itemMCQ['types']
|
254 |
+
elif self.dataset == 'ek100_mir':
|
255 |
+
vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
|
256 |
+
# from third_party.EgoVLP.base.base_dataset import sample_frames_start_end
|
257 |
+
# frame_ids = sample_frames_start_end(clip_length, start_frame, end_frame, sample='uniform', fix_start=None)
|
258 |
+
frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
|
259 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
260 |
+
if is_training:
|
261 |
+
positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
|
262 |
+
if positive_list != []:
|
263 |
+
pos = random.sample(positive_list, min(len(positive_list), 1))[0]
|
264 |
+
if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
|
265 |
+
return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
|
266 |
+
else:
|
267 |
+
return frames, (narration, 1)
|
268 |
+
elif self.dataset == 'ek100_cls':
|
269 |
+
vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
|
270 |
+
frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
|
271 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
272 |
+
return frames, '{}:{}'.format(verb, noun)
|
273 |
+
elif self.dataset == 'egtea':
|
274 |
+
vid_path, start_frame, end_frame, sentence = self.samples[i]
|
275 |
+
if is_training:
|
276 |
+
assert num_clips == 1
|
277 |
+
if end_frame < clip_length * clip_stride:
|
278 |
+
frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
|
279 |
+
zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
|
280 |
+
frames = torch.cat((frames, zeros), dim=0)
|
281 |
+
frames = frames[::clip_stride]
|
282 |
+
else:
|
283 |
+
start_id = np.random.randint(0, end_frame - clip_length * clip_stride + 1)
|
284 |
+
frame_ids = np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)
|
285 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
286 |
+
else:
|
287 |
+
if end_frame < clip_length * clip_stride:
|
288 |
+
frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
|
289 |
+
zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
|
290 |
+
frames = torch.cat((frames, zeros), dim=0)
|
291 |
+
frames = frames[::clip_stride]
|
292 |
+
frames = frames.repeat(num_clips, 1, 1, 1)
|
293 |
+
else:
|
294 |
+
frame_ids = []
|
295 |
+
for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
|
296 |
+
frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
|
297 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
298 |
+
return frames, sentence
|
299 |
+
elif self.dataset == 'charades_ego':
|
300 |
+
vid_path, start_frame, end_frame, action_list = self.samples[i]
|
301 |
+
if sparse_sample:
|
302 |
+
frame_ids = get_frame_ids(start_frame, end_frame, num_segments=num_clips * clip_length, jitter=is_training)
|
303 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
304 |
+
else:
|
305 |
+
if end_frame < clip_length * clip_stride:
|
306 |
+
frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
|
307 |
+
zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
|
308 |
+
frames = torch.cat((frames, zeros), dim=0)
|
309 |
+
frames = frames[::clip_stride]
|
310 |
+
frames = frames.repeat(num_clips, 1, 1, 1)
|
311 |
+
else:
|
312 |
+
frame_ids = []
|
313 |
+
for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
|
314 |
+
frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
|
315 |
+
print('frame_ids:', frame_ids)
|
316 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
317 |
+
return frames, action_list
|
318 |
+
elif self.dataset == 'charades_ego_trimmed':
|
319 |
+
vid, start_second, end_second, narration = self.samples[i]
|
320 |
+
frames = video_loader(self.root, vid, start_second,
|
321 |
+
end_second=end_second,
|
322 |
+
chunk_len=-1, # no chunk for CharadesEgo
|
323 |
+
fps=-1, # could be variable fps
|
324 |
+
clip_length=clip_length,
|
325 |
+
jitter=is_training)
|
326 |
+
return frames, narration
|
327 |
+
else:
|
328 |
+
raise NotImplementedError
|
329 |
+
|
330 |
+
def __getitem__(self, i):
|
331 |
+
raise NotImplementedError
|
332 |
+
|
333 |
+
def __len__(self):
|
334 |
+
return len(self.samples)
|
335 |
+
|
336 |
+
|
337 |
+
class VideoCaptionDatasetCLIP(VideoCaptionDatasetBase):
|
338 |
+
def __init__(self, dataset, root, metadata, transform=None,
|
339 |
+
is_training=True, tokenizer=None,
|
340 |
+
clip_length=32, clip_stride=2, sparse_sample=False,
|
341 |
+
narration_selection='random',
|
342 |
+
num_hard_negatives=0,
|
343 |
+
subsample_stride=None):
|
344 |
+
super().__init__(dataset, root, metadata)
|
345 |
+
|
346 |
+
self.full_samples = self.samples.copy()
|
347 |
+
if isinstance(subsample_stride, int):
|
348 |
+
self.samples = self.samples[::subsample_stride]
|
349 |
+
self.transform = transform
|
350 |
+
self.is_training = is_training
|
351 |
+
self.tokenizer = tokenizer
|
352 |
+
self.clip_length = clip_length
|
353 |
+
self.clip_stride = clip_stride
|
354 |
+
self.sparse_sample = sparse_sample
|
355 |
+
self.narration_selection = narration_selection
|
356 |
+
self.num_hard_negatives = num_hard_negatives
|
357 |
+
if num_hard_negatives > 0:
|
358 |
+
assert self.dataset == 'htm_aa'
|
359 |
+
|
360 |
+
def __getitem__(self, i):
|
361 |
+
frames, caption = self.get_raw_item(
|
362 |
+
i, is_training=self.is_training,
|
363 |
+
clip_length=self.clip_length,
|
364 |
+
clip_stride=self.clip_stride,
|
365 |
+
sparse_sample=self.sparse_sample,
|
366 |
+
narration_selection=self.narration_selection,
|
367 |
+
)
|
368 |
+
|
369 |
+
# ek100_mir will also output relevancy value
|
370 |
+
if isinstance(caption, tuple):
|
371 |
+
caption, relevancy = caption
|
372 |
+
else:
|
373 |
+
relevancy = 0.
|
374 |
+
|
375 |
+
# apply transformation
|
376 |
+
if self.transform is not None:
|
377 |
+
frames = self.transform(frames)
|
378 |
+
|
379 |
+
# tokenize caption
|
380 |
+
if self.tokenizer is not None:
|
381 |
+
caption = self.tokenizer(caption)
|
382 |
+
|
383 |
+
if isinstance(caption, tuple):
|
384 |
+
caption, mask = caption
|
385 |
+
return frames, caption, mask, relevancy
|
386 |
+
else:
|
387 |
+
return frames, caption, relevancy
|
388 |
+
|
389 |
+
|
390 |
+
class VideoCaptionDatasetMCQ(VideoCaptionDatasetBase):
|
391 |
+
def __init__(self, dataset, root, metadata, transform=None,
|
392 |
+
is_training=True, tokenizer=None,
|
393 |
+
clip_length=32, clip_stride=2, sparse_sample=False,
|
394 |
+
narration_selection='random'):
|
395 |
+
super().__init__(dataset, root, metadata)
|
396 |
+
|
397 |
+
self.full_samples = self.samples.copy()
|
398 |
+
self.transform = transform
|
399 |
+
self.is_training = is_training
|
400 |
+
self.tokenizer = tokenizer
|
401 |
+
self.clip_length = clip_length
|
402 |
+
self.clip_stride = clip_stride
|
403 |
+
self.sparse_sample = sparse_sample
|
404 |
+
self.narration_selection = narration_selection
|
405 |
+
|
406 |
+
def __getitem__(self, i):
|
407 |
+
|
408 |
+
textQuery, frames_options, narration_options, answerIndex, q_type = self.get_raw_item(
|
409 |
+
i, is_training=self.is_training,
|
410 |
+
clip_length=self.clip_length,
|
411 |
+
clip_stride=self.clip_stride,
|
412 |
+
sparse_sample=self.sparse_sample,
|
413 |
+
narration_selection=self.narration_selection,
|
414 |
+
)
|
415 |
+
|
416 |
+
# apply transformation
|
417 |
+
if self.transform is not None:
|
418 |
+
frames_options = [self.transform(frames) for frames in frames_options]
|
419 |
+
|
420 |
+
# tokenize caption
|
421 |
+
if self.tokenizer is not None:
|
422 |
+
textQuery = self.tokenizer(textQuery)
|
423 |
+
narration_options = self.tokenizer(narration_options)
|
424 |
+
if isinstance(textQuery, tuple):
|
425 |
+
textQuery, mask_query = textQuery
|
426 |
+
narration_options, mask_options = narration_options
|
427 |
+
return (
|
428 |
+
textQuery, torch.stack(frames_options, dim=0),
|
429 |
+
narration_options, answerIndex, q_type,
|
430 |
+
mask_query, mask_options
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
return textQuery, torch.stack(frames_options, dim=0), narration_options, answerIndex, q_type
|
434 |
+
|
435 |
+
|
436 |
+
class VideoClassyDataset(VideoCaptionDatasetBase):
|
437 |
+
def __init__(
|
438 |
+
self, dataset, root, metadata, transform=None,
|
439 |
+
is_training=True, label_mapping=None,
|
440 |
+
num_clips=1,
|
441 |
+
clip_length=32, clip_stride=2,
|
442 |
+
sparse_sample=False,
|
443 |
+
is_trimmed=True,
|
444 |
+
):
|
445 |
+
super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)
|
446 |
+
|
447 |
+
self.transform = transform
|
448 |
+
self.is_training = is_training
|
449 |
+
self.label_mapping = label_mapping
|
450 |
+
self.num_clips = num_clips
|
451 |
+
self.clip_length = clip_length
|
452 |
+
self.clip_stride = clip_stride
|
453 |
+
self.sparse_sample = sparse_sample
|
454 |
+
|
455 |
+
def __getitem__(self, i):
|
456 |
+
frames, label = self.get_raw_item(
|
457 |
+
i, is_training=self.is_training,
|
458 |
+
num_clips=self.num_clips,
|
459 |
+
clip_length=self.clip_length,
|
460 |
+
clip_stride=self.clip_stride,
|
461 |
+
sparse_sample=self.sparse_sample,
|
462 |
+
)
|
463 |
+
|
464 |
+
# apply transformation
|
465 |
+
if self.transform is not None:
|
466 |
+
frames = self.transform(frames)
|
467 |
+
|
468 |
+
if self.label_mapping is not None:
|
469 |
+
if isinstance(label, list):
|
470 |
+
# multi-label case
|
471 |
+
res_array = np.zeros(len(self.label_mapping))
|
472 |
+
for lbl in label:
|
473 |
+
res_array[self.label_mapping[lbl]] = 1.
|
474 |
+
label = res_array
|
475 |
+
else:
|
476 |
+
label = self.label_mapping[label]
|
477 |
+
|
478 |
+
return frames, label
|
479 |
+
|
480 |
+
|
481 |
+
def get_dataset(train_transform, tokenizer, args, is_training=True):
|
482 |
+
if 'narration_selection' not in args:
|
483 |
+
args.narration_selection = 'random'
|
484 |
+
if args.model.startswith('CLIP') or args.model.startswith('VCLM'):
|
485 |
+
return VideoCaptionDatasetCLIP(
|
486 |
+
args.dataset, args.root, args.metadata, train_transform,
|
487 |
+
is_training=is_training,
|
488 |
+
tokenizer=tokenizer,
|
489 |
+
clip_length=args.clip_length, clip_stride=args.clip_stride,
|
490 |
+
sparse_sample=args.sparse_sample,
|
491 |
+
narration_selection=args.narration_selection,
|
492 |
+
num_hard_negatives=args.num_hard_neg if 'num_hard_neg' in args else 0,
|
493 |
+
)
|
494 |
+
else:
|
495 |
+
raise NotImplementedError
|
496 |
+
|
497 |
+
|
498 |
+
def get_downstream_dataset(transform, tokenizer, args, subset='train', label_mapping=None):
|
499 |
+
if subset == 'train':
|
500 |
+
return VideoClassyDataset(
|
501 |
+
args.dataset, args.root, args.metadata_train, transform,
|
502 |
+
is_training=True, label_mapping=label_mapping,
|
503 |
+
num_clips=args.num_clips,
|
504 |
+
clip_length=args.clip_length, clip_stride=args.clip_stride,
|
505 |
+
sparse_sample=args.sparse_sample,
|
506 |
+
)
|
507 |
+
elif subset == 'val':
|
508 |
+
return VideoClassyDataset(
|
509 |
+
args.dataset, args.root, args.metadata_val, transform,
|
510 |
+
is_training=False, label_mapping=label_mapping,
|
511 |
+
num_clips=args.num_clips,
|
512 |
+
clip_length=args.clip_length, clip_stride=args.clip_stride,
|
513 |
+
sparse_sample=args.sparse_sample,
|
514 |
+
is_trimmed=not args.dataset == 'charades_ego'
|
515 |
+
)
|
516 |
+
else:
|
517 |
+
raise ValueError("subset should be either 'train' or 'val'")
|
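For reference, a minimal usage sketch of the classification dataset defined above. The metadata path, root directory, and label keys below are placeholders, and the transform assumes the loaders return uint8 frames of shape (T, H, W, C); normalization is omitted for brevity.

# Hypothetical usage of VideoClassyDataset; paths and label keys are made up.
import torchvision.transforms as T
from lavila.data.datasets import VideoClassyDataset
from lavila.data.video_transforms import Permute

label_mapping = {'2:10': 0, '3:10': 1}  # 'verb:noun' class ids -> label indices (illustrative)
transform = T.Compose([
    Permute([3, 0, 1, 2]),   # (T, H, W, C) -> (C, T, H, W)
    T.Resize(256),
    T.CenterCrop(224),
])
dataset = VideoClassyDataset(
    'ek100_cls', root='datasets/EK100/video_ht256px/',
    metadata='datasets/EK100/EPIC_100_validation.csv',
    transform=transform, is_training=False, label_mapping=label_mapping,
    num_clips=1, clip_length=16, clip_stride=4,
)
frames, label = dataset[0]   # frames: (3, 16, 224, 224) tensor, label: int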
lavila/data/video_transforms.py
ADDED
@@ -0,0 +1,186 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
from typing import Sequence
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torchvision import transforms
|
12 |
+
|
13 |
+
|
14 |
+
class Permute(nn.Module):
|
15 |
+
"""
|
16 |
+
Permutation as an op
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, ordering):
|
20 |
+
super().__init__()
|
21 |
+
self.ordering = ordering
|
22 |
+
|
23 |
+
def forward(self, frames):
|
24 |
+
"""
|
25 |
+
Args:
|
26 |
+
frames in some ordering, by default (C, T, H, W)
|
27 |
+
Returns:
|
28 |
+
frames in the ordering that was specified
|
29 |
+
"""
|
30 |
+
return frames.permute(self.ordering)
|
31 |
+
|
32 |
+
|
33 |
+
class TemporalCrop(nn.Module):
|
34 |
+
"""
|
35 |
+
Convert the video into smaller clips temporally.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
self.frames = frames_per_clip
|
43 |
+
self.stride = stride
|
44 |
+
self.frame_stride = frame_stride
|
45 |
+
|
46 |
+
def forward(self, video):
|
47 |
+
assert video.ndim == 4, "Must be (C, T, H, W)"
|
48 |
+
res = []
|
49 |
+
for start in range(
|
50 |
+
0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride
|
51 |
+
):
|
52 |
+
end = start + (self.frames) * self.frame_stride
|
53 |
+
res.append(video[:, start: end: self.frame_stride, ...])
|
54 |
+
return res
|
55 |
+
|
56 |
+
|
57 |
+
def crop_boxes(boxes, x_offset, y_offset):
|
58 |
+
"""
|
59 |
+
Peform crop on the bounding boxes given the offsets.
|
60 |
+
Args:
|
61 |
+
boxes (ndarray or None): bounding boxes to perform crop. The dimension
|
62 |
+
is `num boxes` x 4.
|
63 |
+
x_offset (int): cropping offset in the x axis.
|
64 |
+
y_offset (int): cropping offset in the y axis.
|
65 |
+
Returns:
|
66 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
67 |
+
`num boxes` x 4.
|
68 |
+
"""
|
69 |
+
cropped_boxes = boxes.copy()
|
70 |
+
cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
|
71 |
+
cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
|
72 |
+
|
73 |
+
return cropped_boxes
|
74 |
+
|
75 |
+
|
76 |
+
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
|
77 |
+
"""
|
78 |
+
Perform uniform spatial sampling on the images and corresponding boxes.
|
79 |
+
Args:
|
80 |
+
images (tensor): images to perform uniform crop. The dimension is
|
81 |
+
`num frames` x `channel` x `height` x `width`.
|
82 |
+
size (int): size of height and width to crop the images.
|
83 |
+
spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
|
84 |
+
is larger than height. Or 0, 1, or 2 for top, center, and bottom
|
85 |
+
crop if height is larger than width.
|
86 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
87 |
+
Dimension is `num boxes` x 4.
|
88 |
+
scale_size (int): optional. If not None, resize the images to scale_size before
|
89 |
+
performing any crop.
|
90 |
+
Returns:
|
91 |
+
cropped (tensor): images with dimension of
|
92 |
+
`num frames` x `channel` x `size` x `size`.
|
93 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
94 |
+
`num boxes` x 4.
|
95 |
+
"""
|
96 |
+
assert spatial_idx in [0, 1, 2]
|
97 |
+
ndim = len(images.shape)
|
98 |
+
if ndim == 3:
|
99 |
+
images = images.unsqueeze(0)
|
100 |
+
height = images.shape[2]
|
101 |
+
width = images.shape[3]
|
102 |
+
|
103 |
+
if scale_size is not None:
|
104 |
+
if width <= height:
|
105 |
+
width, height = scale_size, int(height / width * scale_size)
|
106 |
+
else:
|
107 |
+
width, height = int(width / height * scale_size), scale_size
|
108 |
+
images = torch.nn.functional.interpolate(
|
109 |
+
images,
|
110 |
+
size=(height, width),
|
111 |
+
mode="bilinear",
|
112 |
+
align_corners=False,
|
113 |
+
)
|
114 |
+
|
115 |
+
y_offset = int(math.ceil((height - size) / 2))
|
116 |
+
x_offset = int(math.ceil((width - size) / 2))
|
117 |
+
|
118 |
+
if height > width:
|
119 |
+
if spatial_idx == 0:
|
120 |
+
y_offset = 0
|
121 |
+
elif spatial_idx == 2:
|
122 |
+
y_offset = height - size
|
123 |
+
else:
|
124 |
+
if spatial_idx == 0:
|
125 |
+
x_offset = 0
|
126 |
+
elif spatial_idx == 2:
|
127 |
+
x_offset = width - size
|
128 |
+
cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
|
129 |
+
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
130 |
+
if ndim == 3:
|
131 |
+
cropped = cropped.squeeze(0)
|
132 |
+
return cropped, cropped_boxes
|
133 |
+
|
134 |
+
|
135 |
+
class SpatialCrop(nn.Module):
|
136 |
+
"""
|
137 |
+
Convert the video into 3 smaller clips spatially. Must be used after the
|
138 |
+
temporal crops to get spatial crops, and should be used with
|
139 |
+
-2 in the spatial crop at the slowfast augmentation stage (so full
|
140 |
+
frames are passed in here). Will return a larger list with the
|
141 |
+
3x spatial crops as well. It's useful for 3x4 testing (eg in SwinT)
|
142 |
+
or 3x10 testing in SlowFast etc.
|
143 |
+
"""
|
144 |
+
|
145 |
+
def __init__(self, crop_size: int = 224, num_crops: int = 3):
|
146 |
+
super().__init__()
|
147 |
+
self.crop_size = crop_size
|
148 |
+
if num_crops == 6:
|
149 |
+
self.crops_to_ext = [0, 1, 2]
|
150 |
+
# I guess Swin uses 5 crops without flipping, but that doesn't
|
151 |
+
# make sense given they first resize to 224 and take 224 crops.
|
152 |
+
# (pg 6 of https://arxiv.org/pdf/2106.13230.pdf)
|
153 |
+
# So I'm assuming we can use flipped crops and that will add sth..
|
154 |
+
self.flipped_crops_to_ext = [0, 1, 2]
|
155 |
+
elif num_crops == 3:
|
156 |
+
self.crops_to_ext = [0, 1, 2]
|
157 |
+
self.flipped_crops_to_ext = []
|
158 |
+
elif num_crops == 1:
|
159 |
+
self.crops_to_ext = [1]
|
160 |
+
self.flipped_crops_to_ext = []
|
161 |
+
else:
|
162 |
+
raise NotImplementedError(
|
163 |
+
"Nothing else supported yet, "
|
164 |
+
"slowfast only takes 0, 1, 2 as arguments"
|
165 |
+
)
|
166 |
+
|
167 |
+
def forward(self, videos: Sequence[torch.Tensor]):
|
168 |
+
"""
|
169 |
+
Args:
|
170 |
+
videos: A list of C, T, H, W videos.
|
171 |
+
Returns:
|
172 |
+
videos: A list with 3x the number of elements. Each video converted
|
173 |
+
to C, T, H', W' by spatial cropping.
|
174 |
+
"""
|
175 |
+
assert isinstance(videos, list), "Must be a list of videos after temporal crops"
|
176 |
+
assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
|
177 |
+
res = []
|
178 |
+
for video in videos:
|
179 |
+
for spatial_idx in self.crops_to_ext:
|
180 |
+
res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
|
181 |
+
if not self.flipped_crops_to_ext:
|
182 |
+
continue
|
183 |
+
flipped_video = transforms.functional.hflip(video)
|
184 |
+
for spatial_idx in self.flipped_crops_to_ext:
|
185 |
+
res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
|
186 |
+
return res
|
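For reference, a minimal sketch of how the TemporalCrop and SpatialCrop modules above can be chained for multi-crop evaluation; the clip length, strides, and crop size are illustrative assumptions, not values mandated by the repo.

# Illustrative multi-crop evaluation pipeline (sizes and strides are assumptions).
import torch
from lavila.data.video_transforms import TemporalCrop, SpatialCrop

video = torch.randn(3, 32, 256, 320)            # (C, T, H, W)
temporal_crop = TemporalCrop(frames_per_clip=16, stride=16)
spatial_crop = SpatialCrop(crop_size=224, num_crops=3)

clips = temporal_crop(video)                    # 2 temporal clips of shape (3, 16, 256, 320)
crops = spatial_crop(clips)                     # 3 spatial crops per clip -> 6 tensors
print(len(clips), len(crops), crops[0].shape)   # 2 6 torch.Size([3, 16, 224, 224])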
lavila/models/__pycache__/distributed_utils.cpython-38.pyc
ADDED
Binary file (2.95 kB).
|
|
lavila/models/__pycache__/gpt2_gated.cpython-38.pyc
ADDED
Binary file (46.9 kB).
|
|
lavila/models/__pycache__/loss.cpython-38.pyc
ADDED
Binary file (8.63 kB).
|
|
lavila/models/__pycache__/models.cpython-38.pyc
ADDED
Binary file (22.6 kB).
|
|
lavila/models/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
lavila/models/coca.py
ADDED
@@ -0,0 +1,131 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Part of the code is from https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py
|
8 |
+
# Modified by Yue Zhao
|
9 |
+
# The original code is under MIT License
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch import einsum
|
15 |
+
from einops import rearrange
|
16 |
+
|
17 |
+
|
18 |
+
def exists(val):
|
19 |
+
return val is not None
|
20 |
+
|
21 |
+
|
22 |
+
def default(val, d):
|
23 |
+
return val if exists(val) else d
|
24 |
+
|
25 |
+
|
26 |
+
# normalization
|
27 |
+
# they use layernorm without bias, something that pytorch does not offer
|
28 |
+
class LayerNorm(nn.Module):
|
29 |
+
def __init__(self, dim):
|
30 |
+
super().__init__()
|
31 |
+
self.gamma = nn.Parameter(torch.ones(dim))
|
32 |
+
self.register_buffer("beta", torch.zeros(dim))
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
36 |
+
|
37 |
+
|
38 |
+
class Residual(nn.Module):
|
39 |
+
def __init__(self, fn):
|
40 |
+
super().__init__()
|
41 |
+
self.fn = fn
|
42 |
+
|
43 |
+
def forward(self, x, *args, **kwargs):
|
44 |
+
return self.fn(x, *args, **kwargs) + x
|
45 |
+
|
46 |
+
|
47 |
+
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
|
48 |
+
# https://arxiv.org/abs/2002.05202
|
49 |
+
class SwiGLU(nn.Module):
|
50 |
+
def forward(self, x):
|
51 |
+
x, gate = x.chunk(2, dim=-1)
|
52 |
+
return F.silu(gate) * x
|
53 |
+
|
54 |
+
|
55 |
+
class CrossAttention(nn.Module):
|
56 |
+
def __init__(
|
57 |
+
self,
|
58 |
+
dim,
|
59 |
+
*,
|
60 |
+
context_dim=None,
|
61 |
+
dim_head=64,
|
62 |
+
heads=8,
|
63 |
+
parallel_ff=False,
|
64 |
+
ff_mult=4,
|
65 |
+
norm_context=False
|
66 |
+
):
|
67 |
+
super().__init__()
|
68 |
+
self.heads = heads
|
69 |
+
self.scale = dim_head ** -0.5
|
70 |
+
inner_dim = heads * dim_head
|
71 |
+
context_dim = default(context_dim, dim)
|
72 |
+
|
73 |
+
self.norm = LayerNorm(dim)
|
74 |
+
self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
|
75 |
+
|
76 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
77 |
+
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
|
78 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
79 |
+
|
80 |
+
# whether to have parallel feedforward
|
81 |
+
|
82 |
+
ff_inner_dim = ff_mult * dim
|
83 |
+
|
84 |
+
self.ff = nn.Sequential(
|
85 |
+
nn.Linear(dim, ff_inner_dim * 2, bias=False),
|
86 |
+
SwiGLU(),
|
87 |
+
nn.Linear(ff_inner_dim, dim, bias=False)
|
88 |
+
) if parallel_ff else None
|
89 |
+
|
90 |
+
def forward(self, x, context):
|
91 |
+
"""
|
92 |
+
einstein notation
|
93 |
+
b - batch
|
94 |
+
h - heads
|
95 |
+
n, i, j - sequence length (base sequence length, source, target)
|
96 |
+
d - feature dimension
|
97 |
+
"""
|
98 |
+
|
99 |
+
# pre-layernorm, for queries and context
|
100 |
+
x = self.norm(x)
|
101 |
+
context = self.context_norm(context)
|
102 |
+
|
103 |
+
# get queries
|
104 |
+
q = self.to_q(x)
|
105 |
+
q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
|
106 |
+
|
107 |
+
# scale
|
108 |
+
q = q * self.scale
|
109 |
+
|
110 |
+
# get key / values
|
111 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
112 |
+
|
113 |
+
# query / key similarity
|
114 |
+
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
115 |
+
|
116 |
+
# attention
|
117 |
+
sim = sim - sim.amax(dim=-1, keepdim=True)
|
118 |
+
attn = sim.softmax(dim=-1)
|
119 |
+
|
120 |
+
# aggregate
|
121 |
+
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
122 |
+
|
123 |
+
# merge and combine heads
|
124 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
125 |
+
out = self.to_out(out)
|
126 |
+
|
127 |
+
# add parallel feedforward (for multimodal layers)
|
128 |
+
if exists(self.ff):
|
129 |
+
out = out + self.ff(x)
|
130 |
+
|
131 |
+
return out
|
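For reference, a minimal shape check for the CrossAttention block above; the batch size, token counts, and dimensions are arbitrary assumptions.

# Shape check for CrossAttention (all dimensions here are arbitrary).
import torch
from lavila.models.coca import CrossAttention

xattn = CrossAttention(dim=512, context_dim=768, dim_head=64, heads=8, parallel_ff=True)
text_tokens = torch.randn(2, 77, 512)      # (batch, text tokens, dim)
visual_tokens = torch.randn(2, 196, 768)   # (batch, visual tokens, context_dim)
out = xattn(text_tokens, visual_tokens)    # queries attend to the visual context
print(out.shape)                           # torch.Size([2, 77, 512])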
lavila/models/distributed_utils.py
ADDED
@@ -0,0 +1,89 @@
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# Part of the code is from
|
7 |
+
# `https://github.com/facebookresearch/vissl/blob/main/vissl/utils/distributed_utils.py` and
|
8 |
+
# `https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/distributed_util.py`
|
9 |
+
# Modified by Yue Zhao
|
10 |
+
# The original code is under MIT License
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
from typing import Tuple
|
15 |
+
|
16 |
+
|
17 |
+
def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
|
18 |
+
"""
|
19 |
+
For some backends, such as NCCL, communication only works if the
|
20 |
+
tensor is on the GPU. This helper function converts to the correct
|
21 |
+
device and returns the tensor + original device.
|
22 |
+
"""
|
23 |
+
orig_device = "cpu" if not tensor.is_cuda else "gpu"
|
24 |
+
if (
|
25 |
+
torch.distributed.is_available()
|
26 |
+
and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
|
27 |
+
and not tensor.is_cuda
|
28 |
+
):
|
29 |
+
tensor = tensor.cuda()
|
30 |
+
return (tensor, orig_device)
|
31 |
+
|
32 |
+
|
33 |
+
def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
|
34 |
+
"""
|
35 |
+
For some backends, such as NCCL, communication only works if the
|
36 |
+
tensor is on the GPU. This converts the tensor back to original device.
|
37 |
+
"""
|
38 |
+
if tensor.is_cuda and orig_device == "cpu":
|
39 |
+
tensor = tensor.cpu()
|
40 |
+
return tensor
|
41 |
+
|
42 |
+
|
43 |
+
def is_distributed_training_run() -> bool:
|
44 |
+
return (
|
45 |
+
torch.distributed.is_available()
|
46 |
+
and torch.distributed.is_initialized()
|
47 |
+
and (torch.distributed.get_world_size() > 1)
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
class GatherLayer(torch.autograd.Function):
|
52 |
+
"""
|
53 |
+
Gather tensors from all workers with support for backward propagation:
|
54 |
+
This implementation does not cut the gradients as torch.distributed.all_gather does.
|
55 |
+
"""
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def forward(ctx, x):
|
59 |
+
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
|
60 |
+
dist.all_gather(output, x)
|
61 |
+
return tuple(output)
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def backward(ctx, *grads):
|
65 |
+
all_gradients = torch.stack(grads)
|
66 |
+
dist.all_reduce(all_gradients)
|
67 |
+
return all_gradients[dist.get_rank()]
|
68 |
+
|
69 |
+
|
70 |
+
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
|
71 |
+
"""
|
72 |
+
Similar to classy_vision.generic.distributed_util.gather_from_all
|
73 |
+
except that it does not cut the gradients
|
74 |
+
"""
|
75 |
+
if tensor.ndim == 0:
|
76 |
+
# 0 dim tensors cannot be gathered. so unsqueeze
|
77 |
+
tensor = tensor.unsqueeze(0)
|
78 |
+
|
79 |
+
if is_distributed_training_run():
|
80 |
+
tensor, orig_device = convert_to_distributed_tensor(tensor)
|
81 |
+
gathered_tensors = GatherLayer.apply(tensor)
|
82 |
+
gathered_tensors = [
|
83 |
+
convert_to_normal_tensor(_tensor, orig_device)
|
84 |
+
for _tensor in gathered_tensors
|
85 |
+
]
|
86 |
+
else:
|
87 |
+
gathered_tensors = [tensor]
|
88 |
+
gathered_tensor = torch.cat(gathered_tensors, 0)
|
89 |
+
return gathered_tensor
|
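For reference, a minimal sketch of how gather_from_all above is typically used to build a contrastive similarity matrix across workers while keeping gradients; the feature shapes are assumptions.

# Sketch: gather embeddings from every rank before computing contrastive logits.
import torch
from lavila.models.distributed_utils import gather_from_all

img_embed = torch.randn(8, 256)            # per-rank image features
txt_embed = torch.randn(8, 256)            # per-rank text features

all_img = gather_from_all(img_embed)       # no-op concat in single-process runs
all_txt = gather_from_all(txt_embed)       # (world_size * 8, 256) under torch.distributed
logits = all_img @ all_txt.t()             # full similarity matrix, gradients preserved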
lavila/models/gpt2_gated.py
ADDED
@@ -0,0 +1,1615 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Part of the code is from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
8 |
+
# Modified by Yue Zhao
|
9 |
+
#
|
10 |
+
#
|
11 |
+
# coding=utf-8
|
12 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
13 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
14 |
+
#
|
15 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
16 |
+
# you may not use this file except in compliance with the License.
|
17 |
+
# You may obtain a copy of the License at
|
18 |
+
#
|
19 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
20 |
+
#
|
21 |
+
# Unless required by applicable law or agreed to in writing, software
|
22 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
23 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
24 |
+
# See the License for the specific language governing permissions and
|
25 |
+
# limitations under the License.
|
26 |
+
"""PyTorch OpenAI GPT-2 model."""
|
27 |
+
|
28 |
+
import copy
|
29 |
+
import math
|
30 |
+
import os
|
31 |
+
from dataclasses import dataclass
|
32 |
+
from typing import Optional, Tuple, Union
|
33 |
+
|
34 |
+
import torch
|
35 |
+
import torch.utils.checkpoint
|
36 |
+
from packaging import version
|
37 |
+
from torch import nn
|
38 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
39 |
+
|
40 |
+
|
41 |
+
if version.parse(torch.__version__) >= version.parse("1.6"):
|
42 |
+
is_amp_available = True
|
43 |
+
from torch.cuda.amp import autocast
|
44 |
+
else:
|
45 |
+
is_amp_available = False
|
46 |
+
|
47 |
+
from transformers.activations import ACT2FN
|
48 |
+
from transformers.modeling_outputs import (
|
49 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
50 |
+
CausalLMOutputWithCrossAttentions,
|
51 |
+
SequenceClassifierOutputWithPast,
|
52 |
+
TokenClassifierOutput,
|
53 |
+
)
|
54 |
+
from transformers.modeling_utils import PreTrainedModel, SequenceSummary
|
55 |
+
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
56 |
+
from transformers.utils import (
|
57 |
+
ModelOutput,
|
58 |
+
add_code_sample_docstrings,
|
59 |
+
add_start_docstrings,
|
60 |
+
add_start_docstrings_to_model_forward,
|
61 |
+
logging,
|
62 |
+
replace_return_docstrings,
|
63 |
+
)
|
64 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
65 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
66 |
+
|
67 |
+
|
68 |
+
logger = logging.get_logger(__name__)
|
69 |
+
|
70 |
+
_CHECKPOINT_FOR_DOC = "gpt2"
|
71 |
+
_CONFIG_FOR_DOC = "GPT2Config"
|
72 |
+
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
73 |
+
|
74 |
+
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
75 |
+
"gpt2",
|
76 |
+
"gpt2-medium",
|
77 |
+
"gpt2-large",
|
78 |
+
"gpt2-xl",
|
79 |
+
"distilgpt2",
|
80 |
+
# See all GPT-2 models at https://huggingface.co/models?filter=gpt2
|
81 |
+
]
|
82 |
+
|
83 |
+
|
84 |
+
def augment_gpt2_config(config, cross_attn_freq=1, gated_xattn=True):
|
85 |
+
new_config = copy.deepcopy(config)
|
86 |
+
new_config.add_cross_attention = True
|
87 |
+
new_config.add_cross_attention_freq = cross_attn_freq
|
88 |
+
new_config.is_tanh_gating = gated_xattn
|
89 |
+
return new_config
|
90 |
+
|
91 |
+
|
92 |
+
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
93 |
+
"""Load tf checkpoints in a pytorch model"""
|
94 |
+
try:
|
95 |
+
import re
|
96 |
+
|
97 |
+
import tensorflow as tf
|
98 |
+
except ImportError:
|
99 |
+
logger.error(
|
100 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
101 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
102 |
+
)
|
103 |
+
raise
|
104 |
+
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
105 |
+
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
106 |
+
# Load weights from TF model
|
107 |
+
init_vars = tf.train.list_variables(tf_path)
|
108 |
+
names = []
|
109 |
+
arrays = []
|
110 |
+
for name, shape in init_vars:
|
111 |
+
logger.info(f"Loading TF weight {name} with shape {shape}")
|
112 |
+
array = tf.train.load_variable(tf_path, name)
|
113 |
+
names.append(name)
|
114 |
+
arrays.append(array.squeeze())
|
115 |
+
|
116 |
+
for name, array in zip(names, arrays):
|
117 |
+
name = name[6:] # skip "model/"
|
118 |
+
name = name.split("/")
|
119 |
+
pointer = model
|
120 |
+
for m_name in name:
|
121 |
+
if re.fullmatch(r"[A-Za-z]+\d+", m_name):
|
122 |
+
scope_names = re.split(r"(\d+)", m_name)
|
123 |
+
else:
|
124 |
+
scope_names = [m_name]
|
125 |
+
if scope_names[0] == "w" or scope_names[0] == "g":
|
126 |
+
pointer = getattr(pointer, "weight")
|
127 |
+
elif scope_names[0] == "b":
|
128 |
+
pointer = getattr(pointer, "bias")
|
129 |
+
elif scope_names[0] == "wpe" or scope_names[0] == "wte":
|
130 |
+
pointer = getattr(pointer, scope_names[0])
|
131 |
+
pointer = getattr(pointer, "weight")
|
132 |
+
else:
|
133 |
+
pointer = getattr(pointer, scope_names[0])
|
134 |
+
if len(scope_names) >= 2:
|
135 |
+
num = int(scope_names[1])
|
136 |
+
pointer = pointer[num]
|
137 |
+
try:
|
138 |
+
assert (
|
139 |
+
pointer.shape == array.shape
|
140 |
+
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
141 |
+
except AssertionError as e:
|
142 |
+
e.args += (pointer.shape, array.shape)
|
143 |
+
raise
|
144 |
+
logger.info(f"Initialize PyTorch weight {name}")
|
145 |
+
pointer.data = torch.from_numpy(array)
|
146 |
+
return model
|
147 |
+
|
148 |
+
|
149 |
+
class GPT2Attention(nn.Module):
|
150 |
+
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
151 |
+
super().__init__()
|
152 |
+
|
153 |
+
max_positions = config.max_position_embeddings
|
154 |
+
self.register_buffer(
|
155 |
+
"bias",
|
156 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
157 |
+
1, 1, max_positions, max_positions
|
158 |
+
),
|
159 |
+
)
|
160 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
161 |
+
|
162 |
+
self.embed_dim = config.hidden_size
|
163 |
+
self.num_heads = config.num_attention_heads
|
164 |
+
self.head_dim = self.embed_dim // self.num_heads
|
165 |
+
self.split_size = self.embed_dim
|
166 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
167 |
+
raise ValueError(
|
168 |
+
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
169 |
+
)
|
170 |
+
|
171 |
+
self.scale_attn_weights = config.scale_attn_weights
|
172 |
+
self.is_cross_attention = is_cross_attention
|
173 |
+
|
174 |
+
# Layer-wise attention scaling, reordering, and upcasting
|
175 |
+
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
|
176 |
+
self.layer_idx = layer_idx
|
177 |
+
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
|
178 |
+
|
179 |
+
if self.is_cross_attention:
|
180 |
+
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
|
181 |
+
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
|
182 |
+
else:
|
183 |
+
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
184 |
+
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
|
185 |
+
|
186 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
187 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
188 |
+
|
189 |
+
self.pruned_heads = set()
|
190 |
+
|
191 |
+
def prune_heads(self, heads):
|
192 |
+
if len(heads) == 0:
|
193 |
+
return
|
194 |
+
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
|
195 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
196 |
+
|
197 |
+
# Prune conv1d layers
|
198 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
199 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
200 |
+
|
201 |
+
# Update hyper params
|
202 |
+
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
|
203 |
+
self.num_heads = self.num_heads - len(heads)
|
204 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
205 |
+
|
206 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
207 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
208 |
+
|
209 |
+
if self.scale_attn_weights:
|
210 |
+
attn_weights = attn_weights / (value.size(-1) ** 0.5)
|
211 |
+
|
212 |
+
# Layer-wise attention scaling
|
213 |
+
if self.scale_attn_by_inverse_layer_idx:
|
214 |
+
attn_weights = attn_weights / float(self.layer_idx + 1)
|
215 |
+
|
216 |
+
if not self.is_cross_attention:
|
217 |
+
# if only "normal" attention layer implements causal mask
|
218 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
219 |
+
causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
|
220 |
+
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
221 |
+
|
222 |
+
if attention_mask is not None:
|
223 |
+
# Apply the attention mask
|
224 |
+
attn_weights = attn_weights + attention_mask
|
225 |
+
|
226 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
227 |
+
|
228 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
|
229 |
+
attn_weights = attn_weights.type(value.dtype)
|
230 |
+
attn_weights = self.attn_dropout(attn_weights)
|
231 |
+
|
232 |
+
# Mask heads if we want to
|
233 |
+
if head_mask is not None:
|
234 |
+
attn_weights = attn_weights * head_mask
|
235 |
+
|
236 |
+
attn_output = torch.matmul(attn_weights, value)
|
237 |
+
|
238 |
+
return attn_output, attn_weights
|
239 |
+
|
240 |
+
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
|
241 |
+
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
|
242 |
+
bsz, num_heads, q_seq_len, dk = query.size()
|
243 |
+
_, _, k_seq_len, _ = key.size()
|
244 |
+
|
245 |
+
# Preallocate attn_weights for `baddbmm`
|
246 |
+
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
|
247 |
+
|
248 |
+
# Compute Scale Factor
|
249 |
+
scale_factor = 1.0
|
250 |
+
if self.scale_attn_weights:
|
251 |
+
scale_factor /= float(value.size(-1)) ** 0.5
|
252 |
+
|
253 |
+
if self.scale_attn_by_inverse_layer_idx:
|
254 |
+
scale_factor /= float(self.layer_idx + 1)
|
255 |
+
|
256 |
+
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
257 |
+
if is_amp_available:
|
258 |
+
with autocast(enabled=False):
|
259 |
+
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
260 |
+
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
261 |
+
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
262 |
+
else:
|
263 |
+
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
264 |
+
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
265 |
+
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
266 |
+
|
267 |
+
if not self.is_cross_attention:
|
268 |
+
# if only "normal" attention layer implements causal mask
|
269 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
270 |
+
causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
|
271 |
+
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
272 |
+
|
273 |
+
if attention_mask is not None:
|
274 |
+
# Apply the attention mask
|
275 |
+
attn_weights = attn_weights + attention_mask
|
276 |
+
|
277 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
278 |
+
|
279 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
|
280 |
+
if attn_weights.dtype != torch.float32:
|
281 |
+
raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
|
282 |
+
attn_weights = attn_weights.type(value.dtype)
|
283 |
+
attn_weights = self.attn_dropout(attn_weights)
|
284 |
+
|
285 |
+
# Mask heads if we want to
|
286 |
+
if head_mask is not None:
|
287 |
+
attn_weights = attn_weights * head_mask
|
288 |
+
|
289 |
+
attn_output = torch.matmul(attn_weights, value)
|
290 |
+
|
291 |
+
return attn_output, attn_weights
|
292 |
+
|
293 |
+
def _split_heads(self, tensor, num_heads, attn_head_size):
|
294 |
+
"""
|
295 |
+
Splits hidden_size dim into attn_head_size and num_heads
|
296 |
+
"""
|
297 |
+
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
298 |
+
tensor = tensor.view(new_shape)
|
299 |
+
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
300 |
+
|
301 |
+
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
302 |
+
"""
|
303 |
+
Merges attn_head_size dim and num_attn_heads dim into hidden_size
|
304 |
+
"""
|
305 |
+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
306 |
+
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
307 |
+
return tensor.view(new_shape)
|
308 |
+
|
309 |
+
def forward(
|
310 |
+
self,
|
311 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
312 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
313 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
314 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
315 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
316 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
317 |
+
use_cache: Optional[bool] = False,
|
318 |
+
output_attentions: Optional[bool] = False,
|
319 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
320 |
+
if encoder_hidden_states is not None:
|
321 |
+
if not hasattr(self, "q_attn"):
|
322 |
+
raise ValueError(
|
323 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
324 |
+
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
|
325 |
+
)
|
326 |
+
|
327 |
+
query = self.q_attn(hidden_states)
|
328 |
+
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
329 |
+
attention_mask = encoder_attention_mask
|
330 |
+
else:
|
331 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
332 |
+
|
333 |
+
query = self._split_heads(query, self.num_heads, self.head_dim)
|
334 |
+
key = self._split_heads(key, self.num_heads, self.head_dim)
|
335 |
+
value = self._split_heads(value, self.num_heads, self.head_dim)
|
336 |
+
|
337 |
+
if layer_past is not None:
|
338 |
+
past_key, past_value = layer_past
|
339 |
+
key = torch.cat((past_key, key), dim=-2)
|
340 |
+
value = torch.cat((past_value, value), dim=-2)
|
341 |
+
|
342 |
+
if use_cache is True:
|
343 |
+
present = (key, value)
|
344 |
+
else:
|
345 |
+
present = None
|
346 |
+
|
347 |
+
if self.reorder_and_upcast_attn:
|
348 |
+
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
|
349 |
+
else:
|
350 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
351 |
+
|
352 |
+
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
353 |
+
attn_output = self.c_proj(attn_output)
|
354 |
+
attn_output = self.resid_dropout(attn_output)
|
355 |
+
|
356 |
+
outputs = (attn_output, present)
|
357 |
+
if output_attentions:
|
358 |
+
outputs += (attn_weights,)
|
359 |
+
|
360 |
+
return outputs # a, present, (attentions)
|
361 |
+
|
362 |
+
|
363 |
+
class SqReLU(nn.Module):
|
364 |
+
"""
|
365 |
+
See Primer: Searching for Efficient Transformers for Language Modeling (So et al., https://arxiv.org/abs/2109.08668).
|
366 |
+
"""
|
367 |
+
|
368 |
+
def __init__(self):
|
369 |
+
super().__init__()
|
370 |
+
self.act = self._sqrelu_python
|
371 |
+
|
372 |
+
def _sqrelu_python(self, input: torch.Tensor) -> torch.Tensor:
|
373 |
+
return torch.pow(nn.functional.relu(input), 2)
|
374 |
+
|
375 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
376 |
+
return self.act(input)
|
377 |
+
|
378 |
+
|
379 |
+
class GPT2MLP(nn.Module):
|
380 |
+
def __init__(self, intermediate_size, config, squared_relu=False):
|
381 |
+
super().__init__()
|
382 |
+
embed_dim = config.hidden_size
|
383 |
+
self.c_fc = Conv1D(intermediate_size, embed_dim)
|
384 |
+
self.c_proj = Conv1D(embed_dim, intermediate_size)
|
385 |
+
if squared_relu:
|
386 |
+
self.act = SqReLU()
|
387 |
+
else:
|
388 |
+
self.act = ACT2FN[config.activation_function]
|
389 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
390 |
+
|
391 |
+
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
|
392 |
+
hidden_states = self.c_fc(hidden_states)
|
393 |
+
hidden_states = self.act(hidden_states)
|
394 |
+
hidden_states = self.c_proj(hidden_states)
|
395 |
+
hidden_states = self.dropout(hidden_states)
|
396 |
+
return hidden_states
|
397 |
+
|
398 |
+
|
399 |
+
class GPT2Block(nn.Module):
|
400 |
+
def __init__(self, config, layer_idx=None):
|
401 |
+
super().__init__()
|
402 |
+
hidden_size = config.hidden_size
|
403 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
404 |
+
|
405 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
        self.attn = GPT2Attention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.add_cross_attention_freq = config.add_cross_attention_freq
        if config.add_cross_attention and layer_idx % config.add_cross_attention_freq == 0:
            self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            self.mlp_crossattention = GPT2MLP(inner_dim, config, squared_relu=True)
            self.ln_2_crossattention = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            if config.is_tanh_gating:
                self.alpha_cattn = nn.Parameter(torch.zeros([]))
                self.alpha_dense = nn.Parameter(torch.zeros([]))

        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        if encoder_hidden_states is not None and self.attn.layer_idx % self.add_cross_attention_freq == 0:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            if hasattr(self, "alpha_cattn"):
                attn_output = torch.tanh(self.alpha_cattn) * attn_output
            # residual connection
            hidden_states = residual + attn_output

            residual = hidden_states
            hidden_states = self.ln_2_crossattention(hidden_states)
            feed_forward_hidden_states = self.mlp_crossattention(hidden_states)
            if hasattr(self, "alpha_dense"):
                feed_forward_hidden_states = torch.tanh(self.alpha_dense) * feed_forward_hidden_states
            # residual connection
            hidden_states = residual + feed_forward_hidden_states

        # Self-Attention
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        # add cross attentions (follow the original order, not to mess things up)
        if encoder_hidden_states is not None and self.attn.layer_idx % self.add_cross_attention_freq == 0:
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        # FFN
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)

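# Note on the gating above (added comment, not in the original source): `alpha_cattn` and
# `alpha_dense` are zero-initialized scalar parameters, so `torch.tanh(alpha) == 0` when training
# starts and the newly inserted cross-attention block and its extra MLP contribute nothing to the
# residual stream. The block therefore reproduces the pretrained GPT-2 layer exactly at
# initialization, and visual conditioning is blended in only as the gates are learned, e.g.:
#
#     gate = torch.tanh(torch.zeros([]))             # tensor(0.)
#     hidden_states = residual + gate * attn_output  # == residual at step 0
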
class GPT2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPT2Config
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if "c_proj" in name and "weight" in name:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GPT2Model):
            module.gradient_checkpointing = value

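    # Worked example for the scaled init above (added comment, not in the original source):
    # with the default GPT-2 small config (initializer_range=0.02, n_layer=12), every `c_proj`
    # weight is drawn with
    #     std = 0.02 / sqrt(2 * 12) ≈ 0.0041
    # instead of 0.02, shrinking residual-branch outputs by roughly 1/sqrt(N) so the sum over
    # N residual layers keeps a comparable scale at initialization.
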
545 |
+
@dataclass
|
546 |
+
class GPT2DoubleHeadsModelOutput(ModelOutput):
|
547 |
+
"""
|
548 |
+
Base class for outputs of models predicting if two sentences are consecutive or not.
|
549 |
+
|
550 |
+
Args:
|
551 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
552 |
+
Language modeling loss.
|
553 |
+
mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
|
554 |
+
Multiple choice classification loss.
|
555 |
+
logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
556 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
557 |
+
mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
|
558 |
+
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
559 |
+
past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
560 |
+
Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
|
561 |
+
sequence_length, embed_size_per_head)`).
|
562 |
+
|
563 |
+
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
564 |
+
`past_key_values` input) to speed up sequential decoding.
|
565 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
566 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
567 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
568 |
+
|
569 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
570 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
571 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
572 |
+
sequence_length)`.
|
573 |
+
|
574 |
+
GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
|
575 |
+
self-attention heads.
|
576 |
+
"""
|
577 |
+
|
578 |
+
loss: Optional[torch.FloatTensor] = None
|
579 |
+
mc_loss: Optional[torch.FloatTensor] = None
|
580 |
+
logits: torch.FloatTensor = None
|
581 |
+
mc_logits: torch.FloatTensor = None
|
582 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
583 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
584 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
585 |
+
|
586 |
+
|
587 |
+
GPT2_START_DOCSTRING = r"""
|
588 |
+
|
589 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
590 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
591 |
+
etc.)
|
592 |
+
|
593 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
594 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
595 |
+
and behavior.
|
596 |
+
|
597 |
+
Parameters:
|
598 |
+
config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
|
599 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
600 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
601 |
+
"""
|
602 |
+
|
603 |
+
GPT2_INPUTS_DOCSTRING = r"""
|
604 |
+
Args:
|
605 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
606 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
|
607 |
+
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
|
608 |
+
sequence tokens in the vocabulary.
|
609 |
+
|
610 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
611 |
+
`input_ids`.
|
612 |
+
|
613 |
+
Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
614 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
615 |
+
|
616 |
+
[What are input IDs?](../glossary#input-ids)
|
617 |
+
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
|
618 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
619 |
+
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
620 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
621 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
622 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
623 |
+
|
624 |
+
- 1 for tokens that are **not masked**,
|
625 |
+
- 0 for tokens that are **masked**.
|
626 |
+
|
627 |
+
If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
|
628 |
+
`past_key_values`. In other words, the `attention_mask` always has to have the length:
|
629 |
+
`len(past_key_values) + len(input_ids)`
|
630 |
+
|
631 |
+
[What are attention masks?](../glossary#attention-mask)
|
632 |
+
token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
|
633 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
634 |
+
1]`:
|
635 |
+
|
636 |
+
- 0 corresponds to a *sentence A* token,
|
637 |
+
- 1 corresponds to a *sentence B* token.
|
638 |
+
|
639 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
640 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
641 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
642 |
+
config.max_position_embeddings - 1]`.
|
643 |
+
|
644 |
+
[What are position IDs?](../glossary#position-ids)
|
645 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
646 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
647 |
+
|
648 |
+
- 1 indicates the head is **not masked**,
|
649 |
+
- 0 indicates the head is **masked**.
|
650 |
+
|
651 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
652 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
653 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
654 |
+
model's internal embedding lookup matrix.
|
655 |
+
|
656 |
+
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
657 |
+
`past_key_values`).
|
658 |
+
use_cache (`bool`, *optional*):
|
659 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
660 |
+
`past_key_values`).
|
661 |
+
output_attentions (`bool`, *optional*):
|
662 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
663 |
+
tensors for more detail.
|
664 |
+
output_hidden_states (`bool`, *optional*):
|
665 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
666 |
+
more detail.
|
667 |
+
return_dict (`bool`, *optional*):
|
668 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
669 |
+
"""
|
670 |
+
PARALLELIZE_DOCSTRING = r"""
|
671 |
+
This is an experimental feature and is a subject to change at a moment's notice.
|
672 |
+
|
673 |
+
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
|
674 |
+
it will evenly distribute blocks across all devices.
|
675 |
+
|
676 |
+
Args:
|
677 |
+
device_map (`Dict[int, list]`, optional, defaults to None):
|
678 |
+
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
|
679 |
+
automatically mapped to the first device (for esoteric reasons). That means that the first device should
|
680 |
+
have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
|
681 |
+
following number of attention modules:
|
682 |
+
|
683 |
+
- gpt2: 12
|
684 |
+
- gpt2-medium: 24
|
685 |
+
- gpt2-large: 36
|
686 |
+
- gpt2-xl: 48
|
687 |
+
|
688 |
+
Example:
|
689 |
+
|
690 |
+
```python
|
691 |
+
# Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
|
692 |
+
model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
|
693 |
+
device_map = {
|
694 |
+
0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
|
695 |
+
1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
|
696 |
+
2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
|
697 |
+
3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
|
698 |
+
}
|
699 |
+
model.parallelize(device_map)
|
700 |
+
```
|
701 |
+
"""
|
702 |
+
DEPARALLELIZE_DOCSTRING = r"""
|
703 |
+
Moves the model to cpu from a model parallel state.
|
704 |
+
|
705 |
+
Example:
|
706 |
+
|
707 |
+
```python
|
708 |
+
# On a 4 GPU machine with gpt2-large:
|
709 |
+
model = GPT2LMHeadModel.from_pretrained("gpt2-large")
|
710 |
+
device_map = {
|
711 |
+
0: [0, 1, 2, 3, 4, 5, 6, 7],
|
712 |
+
1: [8, 9, 10, 11, 12, 13, 14, 15],
|
713 |
+
2: [16, 17, 18, 19, 20, 21, 22, 23],
|
714 |
+
3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
|
715 |
+
}
|
716 |
+
model.parallelize(device_map) # Splits the model across several devices
|
717 |
+
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
|
718 |
+
```
|
719 |
+
"""
|
720 |
+
|
721 |
+
|
722 |
+
@add_start_docstrings(
|
723 |
+
"The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
|
724 |
+
GPT2_START_DOCSTRING,
|
725 |
+
)
|
726 |
+
class GPT2Model(GPT2PreTrainedModel):
|
727 |
+
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
|
728 |
+
|
729 |
+
def __init__(self, config):
|
730 |
+
super().__init__(config)
|
731 |
+
|
732 |
+
self.embed_dim = config.hidden_size
|
733 |
+
|
734 |
+
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
735 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
736 |
+
|
737 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
738 |
+
self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
739 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
740 |
+
|
741 |
+
# Model parallel
|
742 |
+
self.model_parallel = False
|
743 |
+
self.device_map = None
|
744 |
+
self.gradient_checkpointing = False
|
745 |
+
|
746 |
+
# Initialize weights and apply final processing
|
747 |
+
self.post_init()
|
748 |
+
|
749 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
750 |
+
def parallelize(self, device_map=None):
|
751 |
+
# Check validity of device_map
|
752 |
+
self.device_map = (
|
753 |
+
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
|
754 |
+
)
|
755 |
+
assert_device_map(self.device_map, len(self.h))
|
756 |
+
self.model_parallel = True
|
757 |
+
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
|
758 |
+
self.last_device = "cuda:" + str(max(self.device_map.keys()))
|
759 |
+
self.wte = self.wte.to(self.first_device)
|
760 |
+
self.wpe = self.wpe.to(self.first_device)
|
761 |
+
# Load onto devices
|
762 |
+
for k, v in self.device_map.items():
|
763 |
+
for block in v:
|
764 |
+
cuda_device = "cuda:" + str(k)
|
765 |
+
self.h[block] = self.h[block].to(cuda_device)
|
766 |
+
# ln_f to last
|
767 |
+
self.ln_f = self.ln_f.to(self.last_device)
|
768 |
+
|
769 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
770 |
+
def deparallelize(self):
|
771 |
+
self.model_parallel = False
|
772 |
+
self.device_map = None
|
773 |
+
self.first_device = "cpu"
|
774 |
+
self.last_device = "cpu"
|
775 |
+
self.wte = self.wte.to("cpu")
|
776 |
+
self.wpe = self.wpe.to("cpu")
|
777 |
+
for index in range(len(self.h)):
|
778 |
+
self.h[index] = self.h[index].to("cpu")
|
779 |
+
self.ln_f = self.ln_f.to("cpu")
|
780 |
+
torch.cuda.empty_cache()
|
781 |
+
|
782 |
+
def get_input_embeddings(self):
|
783 |
+
return self.wte
|
784 |
+
|
785 |
+
def set_input_embeddings(self, new_embeddings):
|
786 |
+
self.wte = new_embeddings
|
787 |
+
|
788 |
+
def _prune_heads(self, heads_to_prune):
|
789 |
+
"""
|
790 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
791 |
+
"""
|
792 |
+
for layer, heads in heads_to_prune.items():
|
793 |
+
self.h[layer].attn.prune_heads(heads)
|
794 |
+
|
795 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
796 |
+
@add_code_sample_docstrings(
|
797 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
798 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
799 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
800 |
+
config_class=_CONFIG_FOR_DOC,
|
801 |
+
)
|
802 |
+
def forward(
|
803 |
+
self,
|
804 |
+
input_ids: Optional[torch.LongTensor] = None,
|
805 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
806 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
807 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
808 |
+
position_ids: Optional[torch.LongTensor] = None,
|
809 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
810 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
811 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
812 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
813 |
+
use_cache: Optional[bool] = None,
|
814 |
+
output_attentions: Optional[bool] = None,
|
815 |
+
output_hidden_states: Optional[bool] = None,
|
816 |
+
return_dict: Optional[bool] = None,
|
817 |
+
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
818 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
819 |
+
output_hidden_states = (
|
820 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
821 |
+
)
|
822 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
823 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
824 |
+
|
825 |
+
if input_ids is not None and inputs_embeds is not None:
|
826 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
827 |
+
elif input_ids is not None:
|
828 |
+
input_shape = input_ids.size()
|
829 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
830 |
+
batch_size = input_ids.shape[0]
|
831 |
+
elif inputs_embeds is not None:
|
832 |
+
input_shape = inputs_embeds.size()[:-1]
|
833 |
+
batch_size = inputs_embeds.shape[0]
|
834 |
+
else:
|
835 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
836 |
+
|
837 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
838 |
+
|
839 |
+
if token_type_ids is not None:
|
840 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
841 |
+
if position_ids is not None:
|
842 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
843 |
+
|
844 |
+
if past_key_values is None:
|
845 |
+
past_length = 0
|
846 |
+
past_key_values = tuple([None] * len(self.h))
|
847 |
+
else:
|
848 |
+
past_length = past_key_values[0][0].size(-2)
|
849 |
+
if position_ids is None:
|
850 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
851 |
+
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
852 |
+
|
853 |
+
# GPT2Attention mask.
|
854 |
+
if attention_mask is not None:
|
855 |
+
if batch_size <= 0:
|
856 |
+
raise ValueError("batch_size has to be defined and > 0")
|
857 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
858 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
859 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
860 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
861 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
862 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
863 |
+
attention_mask = attention_mask[:, None, None, :]
|
864 |
+
|
865 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
866 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
867 |
+
# positions we want to attend and -10000.0 for masked positions.
|
868 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
869 |
+
# effectively the same as removing these entirely.
|
870 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
871 |
+
attention_mask = (1.0 - attention_mask) * -10000.0
|
872 |
+
|
873 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
874 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
875 |
+
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
876 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
877 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
878 |
+
if encoder_attention_mask is None:
|
879 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
880 |
+
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
881 |
+
else:
|
882 |
+
encoder_attention_mask = None
|
883 |
+
|
884 |
+
# Prepare head mask if needed
|
885 |
+
# 1.0 in head_mask indicate we keep the head
|
886 |
+
# attention_probs has shape bsz x n_heads x N x N
|
887 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
888 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
889 |
+
|
890 |
+
if inputs_embeds is None:
|
891 |
+
inputs_embeds = self.wte(input_ids)
|
892 |
+
position_embeds = self.wpe(position_ids)
|
893 |
+
hidden_states = inputs_embeds + position_embeds
|
894 |
+
|
895 |
+
if token_type_ids is not None:
|
896 |
+
token_type_embeds = self.wte(token_type_ids)
|
897 |
+
hidden_states = hidden_states + token_type_embeds
|
898 |
+
|
899 |
+
hidden_states = self.drop(hidden_states)
|
900 |
+
|
901 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
902 |
+
|
903 |
+
presents = () if use_cache else None
|
904 |
+
all_self_attentions = () if output_attentions else None
|
905 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
906 |
+
all_hidden_states = () if output_hidden_states else None
|
907 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
908 |
+
|
909 |
+
# Model parallel
|
910 |
+
if self.model_parallel:
|
911 |
+
torch.cuda.set_device(hidden_states.device)
|
912 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
913 |
+
if layer_past is not None:
|
914 |
+
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
915 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
916 |
+
if attention_mask is not None:
|
917 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
918 |
+
if isinstance(head_mask, torch.Tensor):
|
919 |
+
head_mask = head_mask.to(hidden_states.device)
|
920 |
+
if output_hidden_states:
|
921 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
922 |
+
|
923 |
+
if self.gradient_checkpointing and self.training:
|
924 |
+
|
925 |
+
if use_cache:
|
926 |
+
logger.warning(
|
927 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
928 |
+
)
|
929 |
+
use_cache = False
|
930 |
+
|
931 |
+
def create_custom_forward(module):
|
932 |
+
def custom_forward(*inputs):
|
933 |
+
# None for past_key_value
|
934 |
+
return module(*inputs, use_cache, output_attentions)
|
935 |
+
|
936 |
+
return custom_forward
|
937 |
+
|
938 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
939 |
+
create_custom_forward(block),
|
940 |
+
hidden_states,
|
941 |
+
None,
|
942 |
+
attention_mask,
|
943 |
+
head_mask[i],
|
944 |
+
encoder_hidden_states,
|
945 |
+
encoder_attention_mask,
|
946 |
+
)
|
947 |
+
else:
|
948 |
+
outputs = block(
|
949 |
+
hidden_states,
|
950 |
+
layer_past=layer_past,
|
951 |
+
attention_mask=attention_mask,
|
952 |
+
head_mask=head_mask[i],
|
953 |
+
encoder_hidden_states=encoder_hidden_states,
|
954 |
+
encoder_attention_mask=encoder_attention_mask,
|
955 |
+
use_cache=use_cache,
|
956 |
+
output_attentions=output_attentions,
|
957 |
+
)
|
958 |
+
|
959 |
+
hidden_states = outputs[0]
|
960 |
+
if use_cache is True:
|
961 |
+
presents = presents + (outputs[1],)
|
962 |
+
|
963 |
+
if output_attentions:
|
964 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
965 |
+
if self.config.add_cross_attention:
|
966 |
+
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
967 |
+
|
968 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
969 |
+
if self.model_parallel:
|
970 |
+
for k, v in self.device_map.items():
|
971 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
972 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
973 |
+
|
974 |
+
hidden_states = self.ln_f(hidden_states)
|
975 |
+
|
976 |
+
hidden_states = hidden_states.view(output_shape)
|
977 |
+
# Add last hidden state
|
978 |
+
if output_hidden_states:
|
979 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
980 |
+
|
981 |
+
if not return_dict:
|
982 |
+
return tuple(
|
983 |
+
v
|
984 |
+
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
985 |
+
if v is not None
|
986 |
+
)
|
987 |
+
|
988 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
989 |
+
last_hidden_state=hidden_states,
|
990 |
+
past_key_values=presents,
|
991 |
+
hidden_states=all_hidden_states,
|
992 |
+
attentions=all_self_attentions,
|
993 |
+
cross_attentions=all_cross_attentions,
|
994 |
+
)
|
995 |
+
|
996 |
+
|
997 |
+
@add_start_docstrings(
|
998 |
+
"""
|
999 |
+
The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
1000 |
+
embeddings).
|
1001 |
+
""",
|
1002 |
+
GPT2_START_DOCSTRING,
|
1003 |
+
)
|
1004 |
+
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
1005 |
+
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
|
1006 |
+
|
1007 |
+
def __init__(self, config):
|
1008 |
+
super().__init__(config)
|
1009 |
+
self.transformer = GPT2Model(config)
|
1010 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
1011 |
+
|
1012 |
+
# Model parallel
|
1013 |
+
self.model_parallel = False
|
1014 |
+
self.device_map = None
|
1015 |
+
|
1016 |
+
# Initialize weights and apply final processing
|
1017 |
+
self.post_init()
|
1018 |
+
|
    def freeze_lm_weights(self):
        freeze_list, unfreeze_list = [], []
        for n, p in self.named_parameters():
            if 'crossattention' in n or 'cross_attn' in n or 'alpha_cattn' in n or 'alpha_dense' in n:
                p.requires_grad = True
                unfreeze_list.append(n)
            else:
                p.requires_grad = False
                freeze_list.append(n)
        print("Freeze the pretrained parts in LM: {}".format(freeze_list))
        print(" Learn the rest parts in LM: {}".format(unfreeze_list))

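    # Usage sketch (added comment, not in the original source; the setup below is illustrative):
    # when fine-tuning a visually conditioned decoder, the pretrained GPT-2 weights can stay
    # frozen while only the newly added gated cross-attention parameters receive gradients, e.g.
    #
    #     model = GPT2LMHeadModel(config)   # config with add_cross_attention=True (hypothetical)
    #     model.freeze_lm_weights()         # only *crossattention* / alpha_* remain trainable
    #     trainable = [n for n, p in model.named_parameters() if p.requires_grad]
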
1031 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
1032 |
+
def parallelize(self, device_map=None):
|
1033 |
+
self.device_map = (
|
1034 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
1035 |
+
if device_map is None
|
1036 |
+
else device_map
|
1037 |
+
)
|
1038 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
1039 |
+
self.transformer.parallelize(self.device_map)
|
1040 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
1041 |
+
self.model_parallel = True
|
1042 |
+
|
1043 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
1044 |
+
def deparallelize(self):
|
1045 |
+
self.transformer.deparallelize()
|
1046 |
+
self.transformer = self.transformer.to("cpu")
|
1047 |
+
self.lm_head = self.lm_head.to("cpu")
|
1048 |
+
self.model_parallel = False
|
1049 |
+
torch.cuda.empty_cache()
|
1050 |
+
|
1051 |
+
def get_output_embeddings(self):
|
1052 |
+
return self.lm_head
|
1053 |
+
|
1054 |
+
def set_output_embeddings(self, new_embeddings):
|
1055 |
+
self.lm_head = new_embeddings
|
1056 |
+
|
1057 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
1058 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
1059 |
+
# only last token for inputs_ids if past is defined in kwargs
|
1060 |
+
if past:
|
1061 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
1062 |
+
if token_type_ids is not None:
|
1063 |
+
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
1064 |
+
|
1065 |
+
attention_mask = kwargs.get("attention_mask", None)
|
1066 |
+
position_ids = kwargs.get("position_ids", None)
|
1067 |
+
|
1068 |
+
if attention_mask is not None and position_ids is None:
|
1069 |
+
# create position_ids on the fly for batch generation
|
1070 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1071 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1072 |
+
if past:
|
1073 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
1074 |
+
else:
|
1075 |
+
position_ids = None
|
1076 |
+
return {
|
1077 |
+
"input_ids": input_ids,
|
1078 |
+
"past_key_values": past,
|
1079 |
+
"use_cache": kwargs.get("use_cache"),
|
1080 |
+
"position_ids": position_ids,
|
1081 |
+
"attention_mask": attention_mask,
|
1082 |
+
"token_type_ids": token_type_ids,
|
1083 |
+
}
|
1084 |
+
|
1085 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
1086 |
+
@add_code_sample_docstrings(
|
1087 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1088 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1089 |
+
output_type=CausalLMOutputWithCrossAttentions,
|
1090 |
+
config_class=_CONFIG_FOR_DOC,
|
1091 |
+
)
|
1092 |
+
def forward(
|
1093 |
+
self,
|
1094 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1095 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1096 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1097 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1098 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1099 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1100 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1101 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1102 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1103 |
+
labels: Optional[torch.LongTensor] = None,
|
1104 |
+
use_cache: Optional[bool] = None,
|
1105 |
+
output_attentions: Optional[bool] = None,
|
1106 |
+
output_hidden_states: Optional[bool] = None,
|
1107 |
+
return_dict: Optional[bool] = None,
|
1108 |
+
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
1109 |
+
r"""
|
1110 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1111 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
1112 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
1113 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
1114 |
+
"""
|
1115 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1116 |
+
|
1117 |
+
transformer_outputs = self.transformer(
|
1118 |
+
input_ids,
|
1119 |
+
past_key_values=past_key_values,
|
1120 |
+
attention_mask=attention_mask,
|
1121 |
+
token_type_ids=token_type_ids,
|
1122 |
+
position_ids=position_ids,
|
1123 |
+
head_mask=head_mask,
|
1124 |
+
inputs_embeds=inputs_embeds,
|
1125 |
+
encoder_hidden_states=encoder_hidden_states,
|
1126 |
+
encoder_attention_mask=encoder_attention_mask,
|
1127 |
+
use_cache=use_cache,
|
1128 |
+
output_attentions=output_attentions,
|
1129 |
+
output_hidden_states=output_hidden_states,
|
1130 |
+
return_dict=return_dict,
|
1131 |
+
)
|
1132 |
+
hidden_states = transformer_outputs[0]
|
1133 |
+
|
1134 |
+
# Set device for model parallelism
|
1135 |
+
if self.model_parallel:
|
1136 |
+
torch.cuda.set_device(self.transformer.first_device)
|
1137 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
1138 |
+
|
1139 |
+
lm_logits = self.lm_head(hidden_states)
|
1140 |
+
|
1141 |
+
loss = None
|
1142 |
+
if labels is not None:
|
1143 |
+
# Shift so that tokens < n predict n
|
1144 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1145 |
+
shift_labels = labels[..., 1:].contiguous()
|
1146 |
+
# Flatten the tokens
|
1147 |
+
loss_fct = CrossEntropyLoss()
|
1148 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1149 |
+
|
1150 |
+
if not return_dict:
|
1151 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1152 |
+
return ((loss,) + output) if loss is not None else output
|
1153 |
+
|
1154 |
+
return CausalLMOutputWithCrossAttentions(
|
1155 |
+
loss=loss,
|
1156 |
+
logits=lm_logits,
|
1157 |
+
past_key_values=transformer_outputs.past_key_values,
|
1158 |
+
hidden_states=transformer_outputs.hidden_states,
|
1159 |
+
attentions=transformer_outputs.attentions,
|
1160 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
1161 |
+
)
|
1162 |
+
|
1163 |
+
@staticmethod
|
1164 |
+
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
1165 |
+
"""
|
1166 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1167 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1168 |
+
beam_idx at every generation step.
|
1169 |
+
"""
|
1170 |
+
return tuple(
|
1171 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
1172 |
+
for layer_past in past
|
1173 |
+
)
|
1174 |
+
|
1175 |
+
|
1176 |
+
@add_start_docstrings(
|
1177 |
+
"""
|
1178 |
+
The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
|
1179 |
+
RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
|
1180 |
+
input embeddings, the classification head takes as input the input of a specified classification token index in the
|
1181 |
+
input sequence).
|
1182 |
+
""",
|
1183 |
+
GPT2_START_DOCSTRING,
|
1184 |
+
)
|
1185 |
+
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
1186 |
+
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
|
1187 |
+
|
1188 |
+
def __init__(self, config):
|
1189 |
+
super().__init__(config)
|
1190 |
+
config.num_labels = 1
|
1191 |
+
self.transformer = GPT2Model(config)
|
1192 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
1193 |
+
self.multiple_choice_head = SequenceSummary(config)
|
1194 |
+
|
1195 |
+
# Model parallel
|
1196 |
+
self.model_parallel = False
|
1197 |
+
self.device_map = None
|
1198 |
+
|
1199 |
+
# Initialize weights and apply final processing
|
1200 |
+
self.post_init()
|
1201 |
+
|
1202 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
1203 |
+
def parallelize(self, device_map=None):
|
1204 |
+
self.device_map = (
|
1205 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
1206 |
+
if device_map is None
|
1207 |
+
else device_map
|
1208 |
+
)
|
1209 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
1210 |
+
self.transformer.parallelize(self.device_map)
|
1211 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
1212 |
+
self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
|
1213 |
+
self.model_parallel = True
|
1214 |
+
|
1215 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
1216 |
+
def deparallelize(self):
|
1217 |
+
self.transformer.deparallelize()
|
1218 |
+
self.transformer = self.transformer.to("cpu")
|
1219 |
+
self.lm_head = self.lm_head.to("cpu")
|
1220 |
+
self.multiple_choice_head = self.multiple_choice_head.to("cpu")
|
1221 |
+
self.model_parallel = False
|
1222 |
+
torch.cuda.empty_cache()
|
1223 |
+
|
1224 |
+
def get_output_embeddings(self):
|
1225 |
+
return self.lm_head
|
1226 |
+
|
1227 |
+
def set_output_embeddings(self, new_embeddings):
|
1228 |
+
self.lm_head = new_embeddings
|
1229 |
+
|
1230 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
1231 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
1232 |
+
# only last token for inputs_ids if past is defined in kwargs
|
1233 |
+
if past:
|
1234 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
1235 |
+
if token_type_ids is not None:
|
1236 |
+
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
1237 |
+
|
1238 |
+
attention_mask = kwargs.get("attention_mask", None)
|
1239 |
+
position_ids = kwargs.get("position_ids", None)
|
1240 |
+
|
1241 |
+
if attention_mask is not None and position_ids is None:
|
1242 |
+
# create position_ids on the fly for batch generation
|
1243 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1244 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1245 |
+
if past:
|
1246 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
1247 |
+
else:
|
1248 |
+
position_ids = None
|
1249 |
+
|
1250 |
+
return {
|
1251 |
+
"input_ids": input_ids,
|
1252 |
+
"past_key_values": past,
|
1253 |
+
"use_cache": kwargs.get("use_cache"),
|
1254 |
+
"position_ids": position_ids,
|
1255 |
+
"attention_mask": attention_mask,
|
1256 |
+
"token_type_ids": token_type_ids,
|
1257 |
+
}
|
1258 |
+
|
1259 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
1260 |
+
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
1261 |
+
def forward(
|
1262 |
+
self,
|
1263 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1264 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1265 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1266 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1267 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1268 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1269 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1270 |
+
mc_token_ids: Optional[torch.LongTensor] = None,
|
1271 |
+
labels: Optional[torch.LongTensor] = None,
|
1272 |
+
mc_labels: Optional[torch.LongTensor] = None,
|
1273 |
+
use_cache: Optional[bool] = None,
|
1274 |
+
output_attentions: Optional[bool] = None,
|
1275 |
+
output_hidden_states: Optional[bool] = None,
|
1276 |
+
return_dict: Optional[bool] = None,
|
1277 |
+
**kwargs,
|
1278 |
+
) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
|
1279 |
+
r"""
|
1280 |
+
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
|
1281 |
+
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
|
1282 |
+
1]`.
|
1283 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1284 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
1285 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size - 1]` All labels set to
|
1286 |
+
`-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
|
1287 |
+
mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
|
1288 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
|
1289 |
+
where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
|
1290 |
+
|
1291 |
+
Return:
|
1292 |
+
|
1293 |
+
Example:
|
1294 |
+
|
1295 |
+
```python
|
1296 |
+
>>> import torch
|
1297 |
+
>>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
|
1298 |
+
|
1299 |
+
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
1300 |
+
>>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
|
1301 |
+
|
1302 |
+
>>> # Add a [CLS] to the vocabulary (we should train it also!)
|
1303 |
+
>>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
|
1304 |
+
>>> # Update the model embeddings with the new vocabulary size
|
1305 |
+
>>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
|
1306 |
+
|
1307 |
+
>>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
|
1308 |
+
>>> encoded_choices = [tokenizer.encode(s) for s in choices]
|
1309 |
+
>>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
|
1310 |
+
|
1311 |
+
>>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
|
1312 |
+
>>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
|
1313 |
+
|
1314 |
+
>>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
|
1315 |
+
>>> lm_logits = outputs.logits
|
1316 |
+
>>> mc_logits = outputs.mc_logits
|
1317 |
+
```"""
|
1318 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1319 |
+
|
1320 |
+
transformer_outputs = self.transformer(
|
1321 |
+
input_ids,
|
1322 |
+
past_key_values=past_key_values,
|
1323 |
+
attention_mask=attention_mask,
|
1324 |
+
token_type_ids=token_type_ids,
|
1325 |
+
position_ids=position_ids,
|
1326 |
+
head_mask=head_mask,
|
1327 |
+
inputs_embeds=inputs_embeds,
|
1328 |
+
use_cache=use_cache,
|
1329 |
+
output_attentions=output_attentions,
|
1330 |
+
output_hidden_states=output_hidden_states,
|
1331 |
+
return_dict=return_dict,
|
1332 |
+
)
|
1333 |
+
|
1334 |
+
hidden_states = transformer_outputs[0]
|
1335 |
+
|
1336 |
+
# Set device for model parallelism
|
1337 |
+
if self.model_parallel:
|
1338 |
+
torch.cuda.set_device(self.transformer.first_device)
|
1339 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
1340 |
+
|
1341 |
+
lm_logits = self.lm_head(hidden_states)
|
1342 |
+
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
|
1343 |
+
|
1344 |
+
mc_loss = None
|
1345 |
+
if mc_labels is not None:
|
1346 |
+
loss_fct = CrossEntropyLoss()
|
1347 |
+
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
|
1348 |
+
lm_loss = None
|
1349 |
+
if labels is not None:
|
1350 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1351 |
+
shift_labels = labels[..., 1:].contiguous()
|
1352 |
+
loss_fct = CrossEntropyLoss()
|
1353 |
+
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1354 |
+
|
1355 |
+
if not return_dict:
|
1356 |
+
output = (lm_logits, mc_logits) + transformer_outputs[1:]
|
1357 |
+
if mc_loss is not None:
|
1358 |
+
output = (mc_loss,) + output
|
1359 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
1360 |
+
|
1361 |
+
return GPT2DoubleHeadsModelOutput(
|
1362 |
+
loss=lm_loss,
|
1363 |
+
mc_loss=mc_loss,
|
1364 |
+
logits=lm_logits,
|
1365 |
+
mc_logits=mc_logits,
|
1366 |
+
past_key_values=transformer_outputs.past_key_values,
|
1367 |
+
hidden_states=transformer_outputs.hidden_states,
|
1368 |
+
attentions=transformer_outputs.attentions,
|
1369 |
+
)
|
1370 |
+
|
1371 |
+
@staticmethod
|
1372 |
+
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
1373 |
+
"""
|
1374 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1375 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1376 |
+
beam_idx at every generation step.
|
1377 |
+
"""
|
1378 |
+
return tuple(
|
1379 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
1380 |
+
for layer_past in past
|
1381 |
+
)
|
1382 |
+
|
1383 |
+
|
1384 |
+
@add_start_docstrings(
|
1385 |
+
"""
|
1386 |
+
The GPT2 Model transformer with a sequence classification head on top (linear layer).
|
1387 |
+
|
1388 |
+
[`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1389 |
+
(e.g. GPT-1) do.
|
1390 |
+
|
1391 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1392 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1393 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1394 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1395 |
+
each row of the batch).
|
1396 |
+
""",
|
1397 |
+
GPT2_START_DOCSTRING,
|
1398 |
+
)
|
1399 |
+
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
|
1400 |
+
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
1401 |
+
|
1402 |
+
def __init__(self, config):
|
1403 |
+
super().__init__(config)
|
1404 |
+
self.num_labels = config.num_labels
|
1405 |
+
self.transformer = GPT2Model(config)
|
1406 |
+
self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
|
1407 |
+
|
1408 |
+
# Model parallel
|
1409 |
+
self.model_parallel = False
|
1410 |
+
self.device_map = None
|
1411 |
+
|
1412 |
+
# Initialize weights and apply final processing
|
1413 |
+
self.post_init()
|
1414 |
+
|
1415 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
1416 |
+
@add_code_sample_docstrings(
|
1417 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1418 |
+
checkpoint="microsoft/DialogRPT-updown",
|
1419 |
+
output_type=SequenceClassifierOutputWithPast,
|
1420 |
+
config_class=_CONFIG_FOR_DOC,
|
1421 |
+
expected_output="'LABEL_0'",
|
1422 |
+
expected_loss=5.28,
|
1423 |
+
)
|
1424 |
+
def forward(
|
1425 |
+
self,
|
1426 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1427 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1428 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1429 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1430 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1431 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1432 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1433 |
+
labels: Optional[torch.LongTensor] = None,
|
1434 |
+
use_cache: Optional[bool] = None,
|
1435 |
+
output_attentions: Optional[bool] = None,
|
1436 |
+
output_hidden_states: Optional[bool] = None,
|
1437 |
+
return_dict: Optional[bool] = None,
|
1438 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1439 |
+
r"""
|
1440 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1441 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1442 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1443 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1444 |
+
"""
|
1445 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1446 |
+
|
1447 |
+
transformer_outputs = self.transformer(
|
1448 |
+
input_ids,
|
1449 |
+
past_key_values=past_key_values,
|
1450 |
+
attention_mask=attention_mask,
|
1451 |
+
token_type_ids=token_type_ids,
|
1452 |
+
position_ids=position_ids,
|
1453 |
+
head_mask=head_mask,
|
1454 |
+
inputs_embeds=inputs_embeds,
|
1455 |
+
use_cache=use_cache,
|
1456 |
+
output_attentions=output_attentions,
|
1457 |
+
output_hidden_states=output_hidden_states,
|
1458 |
+
return_dict=return_dict,
|
1459 |
+
)
|
1460 |
+
hidden_states = transformer_outputs[0]
|
1461 |
+
logits = self.score(hidden_states)
|
1462 |
+
|
1463 |
+
if input_ids is not None:
|
1464 |
+
batch_size, sequence_length = input_ids.shape[:2]
|
1465 |
+
else:
|
1466 |
+
batch_size, sequence_length = inputs_embeds.shape[:2]
|
1467 |
+
|
1468 |
+
assert (
|
1469 |
+
self.config.pad_token_id is not None or batch_size == 1
|
1470 |
+
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
1471 |
+
if self.config.pad_token_id is None:
|
1472 |
+
sequence_lengths = -1
|
1473 |
+
else:
|
1474 |
+
if input_ids is not None:
|
1475 |
+
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
1476 |
+
else:
|
1477 |
+
sequence_lengths = -1
|
1478 |
+
logger.warning(
|
1479 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
1480 |
+
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
1481 |
+
)
|
1482 |
+
|
1483 |
+
pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
|
1484 |
+
|
1485 |
+
loss = None
|
1486 |
+
if labels is not None:
|
1487 |
+
if self.config.problem_type is None:
|
1488 |
+
if self.num_labels == 1:
|
1489 |
+
self.config.problem_type = "regression"
|
1490 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1491 |
+
self.config.problem_type = "single_label_classification"
|
1492 |
+
else:
|
1493 |
+
self.config.problem_type = "multi_label_classification"
|
1494 |
+
|
1495 |
+
if self.config.problem_type == "regression":
|
1496 |
+
loss_fct = MSELoss()
|
1497 |
+
if self.num_labels == 1:
|
1498 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1499 |
+
else:
|
1500 |
+
loss = loss_fct(pooled_logits, labels)
|
1501 |
+
elif self.config.problem_type == "single_label_classification":
|
1502 |
+
loss_fct = CrossEntropyLoss()
|
1503 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1504 |
+
elif self.config.problem_type == "multi_label_classification":
|
1505 |
+
loss_fct = BCEWithLogitsLoss()
|
1506 |
+
loss = loss_fct(pooled_logits, labels)
|
1507 |
+
if not return_dict:
|
1508 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1509 |
+
return ((loss,) + output) if loss is not None else output
|
1510 |
+
|
1511 |
+
return SequenceClassifierOutputWithPast(
|
1512 |
+
loss=loss,
|
1513 |
+
logits=pooled_logits,
|
1514 |
+
past_key_values=transformer_outputs.past_key_values,
|
1515 |
+
hidden_states=transformer_outputs.hidden_states,
|
1516 |
+
attentions=transformer_outputs.attentions,
|
1517 |
+
)
|
1518 |
+
|
1519 |
+
|
1520 |
+
@add_start_docstrings(
|
1521 |
+
"""
|
1522 |
+
GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
1523 |
+
Named-Entity-Recognition (NER) tasks.
|
1524 |
+
""",
|
1525 |
+
GPT2_START_DOCSTRING,
|
1526 |
+
)
|
1527 |
+
class GPT2ForTokenClassification(GPT2PreTrainedModel):
|
1528 |
+
def __init__(self, config):
|
1529 |
+
super().__init__(config)
|
1530 |
+
self.num_labels = config.num_labels
|
1531 |
+
|
1532 |
+
self.transformer = GPT2Model(config)
|
1533 |
+
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
1534 |
+
classifier_dropout = config.classifier_dropout
|
1535 |
+
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
|
1536 |
+
classifier_dropout = config.hidden_dropout
|
1537 |
+
else:
|
1538 |
+
classifier_dropout = 0.1
|
1539 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1540 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1541 |
+
|
1542 |
+
# Model parallel
|
1543 |
+
self.model_parallel = False
|
1544 |
+
self.device_map = None
|
1545 |
+
|
1546 |
+
# Initialize weights and apply final processing
|
1547 |
+
self.post_init()
|
1548 |
+
|
1549 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
1550 |
+
# fmt: off
|
1551 |
+
@add_code_sample_docstrings(
|
1552 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
1553 |
+
checkpoint="brad1141/gpt2-finetuned-comp2",
|
1554 |
+
output_type=TokenClassifierOutput,
|
1555 |
+
config_class=_CONFIG_FOR_DOC,
|
1556 |
+
expected_loss=0.25,
|
1557 |
+
expected_output=["Lead", "Lead", "Lead", "Position", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead"],
|
1558 |
+
)
|
1559 |
+
# fmt: on
|
1560 |
+
def forward(
|
1561 |
+
self,
|
1562 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1563 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1564 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1565 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1566 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1567 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1568 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1569 |
+
labels: Optional[torch.LongTensor] = None,
|
1570 |
+
use_cache: Optional[bool] = None,
|
1571 |
+
output_attentions: Optional[bool] = None,
|
1572 |
+
output_hidden_states: Optional[bool] = None,
|
1573 |
+
return_dict: Optional[bool] = None,
|
1574 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
1575 |
+
r"""
|
1576 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1577 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1578 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1579 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1580 |
+
"""
|
1581 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1582 |
+
|
1583 |
+
transformer_outputs = self.transformer(
|
1584 |
+
input_ids,
|
1585 |
+
past_key_values=past_key_values,
|
1586 |
+
attention_mask=attention_mask,
|
1587 |
+
token_type_ids=token_type_ids,
|
1588 |
+
position_ids=position_ids,
|
1589 |
+
head_mask=head_mask,
|
1590 |
+
inputs_embeds=inputs_embeds,
|
1591 |
+
use_cache=use_cache,
|
1592 |
+
output_attentions=output_attentions,
|
1593 |
+
output_hidden_states=output_hidden_states,
|
1594 |
+
return_dict=return_dict,
|
1595 |
+
)
|
1596 |
+
|
1597 |
+
hidden_states = transformer_outputs[0]
|
1598 |
+
hidden_states = self.dropout(hidden_states)
|
1599 |
+
logits = self.classifier(hidden_states)
|
1600 |
+
|
1601 |
+
loss = None
|
1602 |
+
if labels is not None:
|
1603 |
+
loss_fct = CrossEntropyLoss()
|
1604 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1605 |
+
|
1606 |
+
if not return_dict:
|
1607 |
+
output = (logits,) + transformer_outputs[2:]
|
1608 |
+
return ((loss,) + output) if loss is not None else output
|
1609 |
+
|
1610 |
+
return TokenClassifierOutput(
|
1611 |
+
loss=loss,
|
1612 |
+
logits=logits,
|
1613 |
+
hidden_states=transformer_outputs.hidden_states,
|
1614 |
+
attentions=transformer_outputs.attentions,
|
1615 |
+
)
|
lavila/models/loss.py
ADDED
@@ -0,0 +1,367 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

import torch
import torch.distributed as dist
import torch.distributed.nn
import torch.nn as nn
import torch.nn.functional as F

from .distributed_utils import gather_from_all


def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
):
    # Adapted from: https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py
    # We gather tensors from all gpus
    if gather_with_grad:
        all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
        all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
    else:
        gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
        gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
        dist.all_gather(gathered_image_features, image_features)
        dist.all_gather(gathered_text_features, text_features)
        if not local_loss:
            # ensure grads for local rank when all_* features don't have a gradient
            gathered_image_features[rank] = image_features
            gathered_text_features[rank] = text_features
        all_image_features = torch.cat(gathered_image_features, dim=0)
        all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features

46 |
+
class CLIPLoss(nn.Module):
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
use_vissl=False,
|
51 |
+
local_loss=False,
|
52 |
+
gather_with_grad=False,
|
53 |
+
cache_labels=False,
|
54 |
+
rank=0,
|
55 |
+
world_size=1,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.use_vissl = use_vissl
|
59 |
+
self.local_loss = local_loss
|
60 |
+
self.gather_with_grad = gather_with_grad
|
61 |
+
self.cache_labels = cache_labels
|
62 |
+
self.rank = rank
|
63 |
+
self.world_size = world_size
|
64 |
+
|
65 |
+
# cache state
|
66 |
+
self.prev_num_logits = 0
|
67 |
+
self.labels = {}
|
68 |
+
|
69 |
+
def forward(self, outputs):
|
70 |
+
image_features = outputs['image_embed']
|
71 |
+
text_features = outputs['text_embed']
|
72 |
+
logit_scale = outputs['logit_scale']
|
73 |
+
device = image_features.device
|
74 |
+
if self.world_size > 1:
|
75 |
+
if self.use_vissl:
|
76 |
+
all_image_features = gather_from_all(image_features)
|
77 |
+
all_text_features = gather_from_all(text_features)
|
78 |
+
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
79 |
+
logits_per_text = logits_per_image.T
|
80 |
+
else:
|
81 |
+
all_image_features, all_text_features = gather_features(
|
82 |
+
image_features, text_features,
|
83 |
+
self.local_loss, self.gather_with_grad, self.rank, self.world_size)
|
84 |
+
|
85 |
+
if self.local_loss:
|
86 |
+
logits_per_image = logit_scale * image_features @ all_text_features.T
|
87 |
+
logits_per_text = logit_scale * text_features @ all_image_features.T
|
88 |
+
else:
|
89 |
+
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
90 |
+
logits_per_text = logits_per_image.T
|
91 |
+
else:
|
92 |
+
logits_per_image = logit_scale * image_features @ text_features.T
|
93 |
+
logits_per_text = logit_scale * text_features @ image_features.T
|
94 |
+
|
95 |
+
# calculated ground-truth and cache if enabled
|
96 |
+
num_logits = logits_per_image.shape[0]
|
97 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
98 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
99 |
+
if self.world_size > 1 and self.local_loss:
|
100 |
+
labels = labels + num_logits * self.rank
|
101 |
+
if self.cache_labels:
|
102 |
+
self.labels[device] = labels
|
103 |
+
self.prev_num_logits = num_logits
|
104 |
+
else:
|
105 |
+
labels = self.labels[device]
|
106 |
+
|
107 |
+
loss = (
|
108 |
+
F.cross_entropy(logits_per_image, labels) +
|
109 |
+
F.cross_entropy(logits_per_text, labels)
|
110 |
+
) / 2
|
111 |
+
|
112 |
+
# compute accuracy
|
113 |
+
with torch.no_grad():
|
114 |
+
pred = torch.argmax(logits_per_image, dim=-1)
|
115 |
+
correct = pred.eq(labels).sum()
|
116 |
+
acc = 100 * correct / logits_per_image.size(0)
|
117 |
+
|
118 |
+
return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc}
|
119 |
+
|
120 |
+
|
121 |
+
class SSLCLIPLoss(nn.Module):
|
122 |
+
|
123 |
+
def __init__(
|
124 |
+
self,
|
125 |
+
use_vissl=False,
|
126 |
+
local_loss=False,
|
127 |
+
gather_with_grad=False,
|
128 |
+
cache_labels=False,
|
129 |
+
rank=0,
|
130 |
+
world_size=1,
|
131 |
+
scale_init=0.08,
|
132 |
+
freeze_scale=False,
|
133 |
+
):
|
134 |
+
super().__init__()
|
135 |
+
self.use_vissl = use_vissl
|
136 |
+
self.local_loss = local_loss
|
137 |
+
self.gather_with_grad = gather_with_grad
|
138 |
+
self.cache_labels = cache_labels
|
139 |
+
self.rank = rank
|
140 |
+
self.world_size = world_size
|
141 |
+
self.logit_scale_pseudo = nn.Parameter(torch.ones([]) * np.log(1 / scale_init))
|
142 |
+
if freeze_scale:
|
143 |
+
self.logit_scale_pseudo.requires_grad = False
|
144 |
+
|
145 |
+
# cache state
|
146 |
+
self.prev_num_logits = 0
|
147 |
+
self.labels = {}
|
148 |
+
|
149 |
+
def forward(self, outputs, gt_indicators):
|
150 |
+
image_features = outputs['image_embed']
|
151 |
+
text_features = outputs['text_embed']
|
152 |
+
logit_scale = outputs['logit_scale']
|
153 |
+
logit_scale_pseudo = self.logit_scale_pseudo.exp()
|
154 |
+
device = image_features.device
|
155 |
+
if self.world_size > 1:
|
156 |
+
if self.use_vissl:
|
157 |
+
all_image_features = gather_from_all(image_features)
|
158 |
+
all_text_features = gather_from_all(text_features)
|
159 |
+
all_gt_indicators = gather_from_all(gt_indicators)
|
160 |
+
num = all_gt_indicators.shape[0]
|
161 |
+
mask = all_gt_indicators.repeat(num, 1) + all_gt_indicators.repeat(num, 1).T
|
162 |
+
logit_scale_mat = torch.ones((num, num), device=device)
|
163 |
+
logit_scale_mat[mask == 0] = logit_scale_pseudo
|
164 |
+
logit_scale_mat[mask == 1] = torch.sqrt(logit_scale_pseudo * logit_scale)
|
165 |
+
logit_scale_mat[mask == 2] = logit_scale
|
166 |
+
logits_per_image = logit_scale_mat * (all_image_features @ all_text_features.T)
|
167 |
+
logits_per_text = logits_per_image.T
|
168 |
+
else:
|
169 |
+
raise NotImplementedError
|
170 |
+
else:
|
171 |
+
all_gt_indicators = gt_indicators
|
172 |
+
num = gt_indicators.shape[0]
|
173 |
+
mask = gt_indicators.repeat(num, 1) + gt_indicators.repeat(num, 1).T
|
174 |
+
logit_scale_mat = torch.ones((num, num), device=device)
|
175 |
+
logit_scale_mat[mask == 0] = logit_scale_pseudo
|
176 |
+
logit_scale_mat[mask == 1] = torch.sqrt(logit_scale_pseudo * logit_scale)
|
177 |
+
logit_scale_mat[mask == 2] = logit_scale
|
178 |
+
logits_per_image = logit_scale_mat * (image_features @ text_features.T)
|
179 |
+
logits_per_text = logit_scale_mat * (text_features @ image_features.T)
|
180 |
+
|
181 |
+
# calculated ground-truth and cache if enabled
|
182 |
+
num_logits = logits_per_image.shape[0]
|
183 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
184 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
185 |
+
if self.world_size > 1 and self.local_loss:
|
186 |
+
labels = labels + num_logits * self.rank
|
187 |
+
if self.cache_labels:
|
188 |
+
self.labels[device] = labels
|
189 |
+
self.prev_num_logits = num_logits
|
190 |
+
else:
|
191 |
+
labels = self.labels[device]
|
192 |
+
|
193 |
+
loss = (
|
194 |
+
F.cross_entropy(logits_per_image, labels) +
|
195 |
+
F.cross_entropy(logits_per_text, labels)
|
196 |
+
) / 2
|
197 |
+
|
198 |
+
# compute accuracy
|
199 |
+
with torch.no_grad():
|
200 |
+
pred = torch.argmax(logits_per_image, dim=-1)
|
201 |
+
correct = pred.eq(labels).sum()
|
202 |
+
acc = 100 * correct / logits_per_image.size(0)
|
203 |
+
pred_gt = pred[all_gt_indicators == 1]
|
204 |
+
labels_gt = labels[all_gt_indicators == 1]
|
205 |
+
pred_pseudo = pred[all_gt_indicators == 0]
|
206 |
+
labels_pseudo = labels[all_gt_indicators == 0]
|
207 |
+
num_gt = pred_gt.shape[0]
|
208 |
+
num_pseudo = pred_pseudo.shape[0]
|
209 |
+
correct_gt = pred_gt.eq(labels_gt).sum()
|
210 |
+
correct_pseudo = pred_pseudo.eq(labels_pseudo).sum()
|
211 |
+
acc_gt = 100 * correct_gt / num_gt
|
212 |
+
acc_pseudo = 100 * correct_pseudo / num_pseudo
|
213 |
+
|
214 |
+
return {
|
215 |
+
'loss': loss, 'clip_loss': loss, 'num_gt': torch.tensor([num_gt]), 'num_pseudo': torch.tensor([num_pseudo]),
|
216 |
+
'clip_acc': acc, 'clip_acc_gt': acc_gt, 'clip_acc_pseudo': acc_pseudo
|
217 |
+
}
|
218 |
+
|
219 |
+
|
220 |
+
class CaptionLoss(nn.Module):
|
221 |
+
def __init__(self, pad_id=0, tokenizer=None):
|
222 |
+
super().__init__()
|
223 |
+
self.pad_id = pad_id
|
224 |
+
self.tokenizer = tokenizer
|
225 |
+
self.pad_id = tokenizer.pad_token_id
|
226 |
+
|
227 |
+
def forward(self, outputs):
|
228 |
+
logits = outputs['text_tokens_logits']
|
229 |
+
labels = outputs['labels']
|
230 |
+
# loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id)
|
231 |
+
loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id, reduction='none')
|
232 |
+
|
233 |
+
# compute accuracy
|
234 |
+
with torch.no_grad():
|
235 |
+
correct = 0.
|
236 |
+
total = 0.
|
237 |
+
ppls = []
|
238 |
+
for i in range(logits.size(0)):
|
239 |
+
pred = torch.argmax(logits[i], dim=0)
|
240 |
+
nopad = labels[i].ne(self.pad_id)
|
241 |
+
correct += (pred.eq(labels[i]) & nopad).sum()
|
242 |
+
total += nopad.sum()
|
243 |
+
ppl = torch.exp(loss[i].sum() / nopad.sum())
|
244 |
+
ppls.append(ppl)
|
245 |
+
# TODO: for debug only
|
246 |
+
# sep_pos = labels[i].tolist().index(self.tokenizer.tokenizer.sep_token_id)
|
247 |
+
# if self.tokenizer is not None:
|
248 |
+
# print('{} {} {}'.format(
|
249 |
+
# i, self.tokenizer.tokenizer.convert_ids_to_tokens(pred[:sep_pos]),
|
250 |
+
# self.tokenizer.tokenizer.convert_ids_to_tokens(labels[i, :sep_pos]),
|
251 |
+
# ))
|
252 |
+
acc = 100 * correct / (total + 1e-8)
|
253 |
+
return {'loss': loss.mean(), 'caption_loss': loss.mean(), 'caption_acc': acc, 'ppl': torch.tensor(ppls).mean()}
|
254 |
+
|
255 |
+
|
256 |
+
def sim_matrix(a, b, eps=1e-8):
|
257 |
+
"""
|
258 |
+
added eps for numerical stability
|
259 |
+
"""
|
260 |
+
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
|
261 |
+
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
|
262 |
+
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
|
263 |
+
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
|
264 |
+
return sim_mt
|
265 |
+
|
266 |
+
|
267 |
+
class MaxMarginRankingLoss(nn.Module):
|
268 |
+
|
269 |
+
def __init__(self, margin=0.2, fix_norm=True):
|
270 |
+
super().__init__()
|
271 |
+
self.fix_norm = fix_norm
|
272 |
+
self.loss = nn.MarginRankingLoss(margin)
|
273 |
+
self.margin = margin
|
274 |
+
|
275 |
+
def forward(self, outputs, weight=None):
|
276 |
+
image_features = outputs['image_embed']
|
277 |
+
text_features = outputs['text_embed']
|
278 |
+
|
279 |
+
all_image_features = gather_from_all(image_features)
|
280 |
+
all_text_features = gather_from_all(text_features)
|
281 |
+
x = sim_matrix(all_text_features, all_image_features)
|
282 |
+
|
283 |
+
n = x.size()[0]
|
284 |
+
|
285 |
+
x1 = torch.diag(x)
|
286 |
+
x1 = x1.unsqueeze(1)
|
287 |
+
x1 = x1.expand(n, n)
|
288 |
+
x1 = x1.contiguous().view(-1, 1)
|
289 |
+
x1 = torch.cat((x1, x1), 0)
|
290 |
+
|
291 |
+
x2 = x.view(-1, 1)
|
292 |
+
x3 = x.transpose(0, 1).contiguous().view(-1, 1)
|
293 |
+
|
294 |
+
x2 = torch.cat((x2, x3), 0)
|
295 |
+
max_margin = F.relu(self.margin - (x1 - x2))
|
296 |
+
|
297 |
+
if self.fix_norm:
|
298 |
+
# remove the elements from the diagonal
|
299 |
+
keep = torch.ones(x.shape) - torch.eye(x.shape[0]) # 128 x 128
|
300 |
+
keep1 = keep.view(-1, 1)
|
301 |
+
keep2 = keep.transpose(0, 1).contiguous().view(-1, 1)
|
302 |
+
keep_idx = torch.nonzero(torch.cat((keep1, keep2), 0).flatten()).flatten()
|
303 |
+
if x1.is_cuda:
|
304 |
+
keep_idx = keep_idx.cuda()
|
305 |
+
x1_ = torch.index_select(x1, dim=0, index=keep_idx)
|
306 |
+
x2_ = torch.index_select(x2, dim=0, index=keep_idx)
|
307 |
+
max_margin = F.relu(self.margin - (x1_ - x2_))
|
308 |
+
|
309 |
+
return {
|
310 |
+
'loss': max_margin.mean(),
|
311 |
+
'max_margin_loss': max_margin.mean()
|
312 |
+
}
|
313 |
+
|
314 |
+
|
315 |
+
class AdaptiveMaxMarginRankingLoss(nn.Module):
|
316 |
+
|
317 |
+
def __init__(self, margin=0.4, fix_norm=True):
|
318 |
+
super().__init__()
|
319 |
+
self.fix_norm = fix_norm
|
320 |
+
self.loss = nn.MarginRankingLoss(margin)
|
321 |
+
self.margin = margin
|
322 |
+
|
323 |
+
def forward(self, outputs, weight=None):
|
324 |
+
image_features = outputs['image_embed']
|
325 |
+
text_features = outputs['text_embed']
|
326 |
+
|
327 |
+
all_image_features = gather_from_all(image_features)
|
328 |
+
all_text_features = gather_from_all(text_features)
|
329 |
+
all_weights = gather_from_all(weight)
|
330 |
+
x = sim_matrix(all_text_features, all_image_features)
|
331 |
+
|
332 |
+
n = x.size()[0]
|
333 |
+
|
334 |
+
x1 = torch.diag(x)
|
335 |
+
x1 = x1.unsqueeze(1)
|
336 |
+
x1 = x1.expand(n, n)
|
337 |
+
x1 = x1.contiguous().view(-1, 1)
|
338 |
+
x1 = torch.cat((x1, x1), 0)
|
339 |
+
|
340 |
+
w1 = all_weights.unsqueeze(1)
|
341 |
+
w1 = w1.expand(n, n)
|
342 |
+
w1 = w1.contiguous().view(-1, 1)
|
343 |
+
w1 = torch.cat((w1, w1), 0)
|
344 |
+
|
345 |
+
x2 = x.view(-1, 1)
|
346 |
+
x3 = x.transpose(0, 1).contiguous().view(-1, 1)
|
347 |
+
|
348 |
+
x2 = torch.cat((x2, x3), 0)
|
349 |
+
max_margin = F.relu(w1 * self.margin - (x1 - x2))
|
350 |
+
|
351 |
+
if self.fix_norm:
|
352 |
+
# remove the elements from the diagonal
|
353 |
+
keep = torch.ones(x.shape) - torch.eye(x.shape[0]) # 128 x 128
|
354 |
+
keep1 = keep.view(-1, 1)
|
355 |
+
keep2 = keep.transpose(0, 1).contiguous().view(-1, 1)
|
356 |
+
keep_idx = torch.nonzero(torch.cat((keep1, keep2), 0).flatten()).flatten()
|
357 |
+
if x1.is_cuda:
|
358 |
+
keep_idx = keep_idx.cuda()
|
359 |
+
x1_ = torch.index_select(x1, dim=0, index=keep_idx)
|
360 |
+
w1_ = torch.index_select(w1, dim=0, index=keep_idx)
|
361 |
+
x2_ = torch.index_select(x2, dim=0, index=keep_idx)
|
362 |
+
max_margin = F.relu(w1_ * self.margin - (x1_ - x2_))
|
363 |
+
|
364 |
+
return {
|
365 |
+
'loss': max_margin.mean(),
|
366 |
+
'max_margin_loss': max_margin.mean()
|
367 |
+
}
|
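CLIPLoss above consumes the dictionary produced by the models in lavila/models/models.py. A toy, single-process sketch with random embeddings standing in for real video/text features (world_size=1, so no distributed gathering is exercised), assuming the package is importable as laid out in this diff:

import torch
import torch.nn.functional as F

from lavila.models.loss import CLIPLoss

batch_size, embed_dim = 8, 256
outputs = {
    'image_embed': F.normalize(torch.randn(batch_size, embed_dim), dim=-1),
    'text_embed': F.normalize(torch.randn(batch_size, embed_dim), dim=-1),
    # the models pass self.logit_scale.exp(), i.e. roughly 1/temperature
    'logit_scale': torch.tensor(1 / 0.07),
}
criterion = CLIPLoss(world_size=1)
metrics = criterion(outputs)
print(metrics['clip_loss'].item(), metrics['clip_acc'].item())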
lavila/models/models.py
ADDED
@@ -0,0 +1,1218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertModel, GPT2LMHeadModel

import lavila.models.loss as loss
from lavila.models.gpt2_gated import GPT2LMHeadModel as GatedGPT2LMHeadModel
from lavila.models.gpt2_gated import augment_gpt2_config
from lavila.models.narrator import VCLM_HF
from lavila.models.openai_clip import load as load_openai_clip
from lavila.models.openai_model import QuickGELU, Transformer
from lavila.models.timesformer import SpaceTimeTransformer
from lavila.models.utils import remap_keys, rsetattr


class VideoClassifier(nn.Module):
    def __init__(self,
                 vision_model: nn.Module,
                 dropout: float,
                 num_classes: int,
                 **kwargs,
                 ):
        super().__init__()
        self.visual = vision_model
        self.dropout = nn.Dropout(dropout)
        self.fc_cls = nn.Linear(vision_model.num_features, num_classes, bias=True)

        self.fc_cls.weight.data.normal_(mean=0.0, std=0.01)
        self.fc_cls.bias.data.zero_()

    def forward(self, image, use_checkpoint=False):
        image_embed = self.visual(image, use_checkpoint=use_checkpoint)
        if isinstance(image_embed, list):
            assert len(image_embed) == 1
            image_embed = image_embed[0]
        logit = self.fc_cls(self.dropout(image_embed))
        return logit


class VideoClassifierMultiHead(nn.Module):
    def __init__(self,
                 vision_model: nn.Module,
                 dropout: float,
                 num_classes_list: list,
                 **kwargs,
                 ):
        super().__init__()
        self.visual = vision_model
        self.dropout = nn.Dropout(dropout)
        self.fc_cls = nn.ModuleList(
            [nn.Linear(vision_model.num_features, num_classes, bias=True) for num_classes in num_classes_list]
        )

        for m in self.fc_cls:
            m.weight.data.normal_(mean=0.0, std=0.01)
            m.bias.data.zero_()

    def forward(self, image, use_checkpoint=False):
        image_embed = self.visual(image, use_checkpoint=use_checkpoint)
        if isinstance(image_embed, list):
            assert len(image_embed) == 1
            image_embed = image_embed[0]
        logit_list = [m(self.dropout(image_embed)) for m in self.fc_cls]
        return logit_list
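VideoClassifier and VideoClassifierMultiHead only add dropout plus linear head(s) on top of any backbone that exposes num_features and accepts a use_checkpoint flag. An illustrative sketch with a dummy backbone standing in for SpaceTimeTransformer (the two head sizes are arbitrary stand-ins, e.g. verb and noun heads):

import torch
import torch.nn as nn

from lavila.models.models import VideoClassifierMultiHead

class DummyBackbone(nn.Module):
    """Stand-in for SpaceTimeTransformer: exposes num_features, accepts use_checkpoint."""
    def __init__(self, num_features=768):
        super().__init__()
        self.num_features = num_features
        self.proj = nn.Linear(3 * 224 * 224, num_features)

    def forward(self, video, use_checkpoint=False):
        b = video.shape[0]
        # collapse (B, C, T, H, W) into one feature vector per clip
        return self.proj(video.mean(dim=2).reshape(b, -1))

heads = VideoClassifierMultiHead(DummyBackbone(), dropout=0.5, num_classes_list=[97, 300])
video = torch.randn(2, 3, 4, 224, 224)
verb_logits, noun_logits = heads(video)
print(verb_logits.shape, noun_logits.shape)   # torch.Size([2, 97]) torch.Size([2, 300])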


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 tempearture_init=0.07,
                 **kwargs,
                 ):
        super().__init__()

        self.context_length = context_length
        self.vision_width = vision_width

        self.visual = vision_model
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = nn.LayerNorm(transformer_width)  # used to be `models.transformer.LayerNorm``

        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        print("=> initialize initial temperature with {}".format(tempearture_init))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def encode_image(self, image, use_checkpoint=False, apply_project=True):
        x = self.visual(image, use_checkpoint=use_checkpoint)
        if isinstance(x, list):
            assert len(x) == 1
            x = x[0]
        if not apply_project:
            return x
        x = x @ self.image_projection

        return x

    def encode_text(self, text, use_checkpoint=False):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, use_checkpoint=use_checkpoint)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text, use_checkpoint=False, norm_embed=False):
        image_embed = self.encode_image(image, use_checkpoint=use_checkpoint)
        text_embed = self.encode_text(text, use_checkpoint=use_checkpoint)

        if norm_embed:
            image_embed = F.normalize(image_embed, dim=-1)
            text_embed = F.normalize(text_embed, dim=-1)
        return {'image_embed': image_embed,
                'text_embed': text_embed,
                'logit_scale': self.logit_scale.exp()}
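encode_text() above pools the feature at text.argmax(dim=-1) because CLIP's end-of-text token has the highest id (49407) in the 49408-token BPE vocabulary, so the argmax over token ids recovers the EOT position in each padded sequence. A small stand-alone illustration with hand-written token ids:

import torch

eot_id = 49407   # CLIP BPE end-of-text id (start-of-text is 49406); vocab_size above is 49408
tokens = torch.tensor([
    [49406, 320, 1125, eot_id, 0, 0, 0],
    [49406, 320, 2368, 537, 320, 1125, eot_id],
])
print(tokens.argmax(dim=-1))   # tensor([3, 6]): the EOT position of each padded sequence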


class CLIP_HF(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # text
                 text_width: int,
                 text_model: nn.Module,
                 text_use_cls_token: bool,
                 text_is_regressive: bool,
                 tempearture_init=0.07,
                 **kwargs,
                 ):
        super().__init__()

        self.vision_width = vision_width
        self.visual = vision_model
        self.text_width = text_width
        self.textual = text_model
        self.text_use_cls_token = text_use_cls_token
        self.text_is_regressive = text_is_regressive

        if 'projection' not in kwargs:
            self.projection = 'default'
        else:
            self.projection = kwargs['projection']
        if self.projection == 'default':
            self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
            self.text_projection = nn.Parameter(torch.empty(text_width, embed_dim))
        elif self.projection == 'frozen_in_time':
            self.image_projection = nn.Sequential(
                nn.Linear(vision_width, embed_dim)
            )
            self.text_projection = nn.Sequential(
                nn.ReLU(),
                nn.Linear(text_width, embed_dim)
            )
        print("=> initialize initial temperature with {}".format(tempearture_init))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))

        self.initialize_parameters()

    def initialize_parameters(self):
        if self.projection == 'default':
            nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
            nn.init.normal_(self.text_projection, std=self.text_width ** -0.5)
        else:
            nn.init.normal_(self.image_projection[0].weight, std=self.vision_width ** -0.5)
            nn.init.normal_(self.text_projection[1].weight, std=self.text_width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def encode_image(self, image, use_checkpoint=False, apply_project=True):
        x = self.visual(image, use_checkpoint=use_checkpoint)
        if isinstance(x, list):
            assert len(x) == 1
            x = x[0]
        if not apply_project:
            return x
        if self.projection == 'default':
            x = x @ self.image_projection
        else:
            x = self.image_projection(x)

        return x

    def encode_text(self, text, attention_mask=None, use_checkpoint=False):
        if use_checkpoint:
            if isinstance(self.textual, DistilBertModel):
                pass
                # print("DistilBertModel does not support gradient checkpointing. Skipping even if use_checkpoint=True")
            else:
                self.textual.gradient_checkpointing_enable()
        else:
            self.textual.gradient_checkpointing_disable()
        # text, attention_mask = text.squeeze(1), attention_mask.squeeze(1)
        # ^ uncomment this only when doing local debugging (distributed=False)
        x = self.textual(text, attention_mask=attention_mask)

        if self.text_is_regressive:
            # gpt-style
            x = x.last_hidden_state
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
        else:
            # bert-style
            if self.text_use_cls_token:
                x = x.last_hidden_state
                x = x[torch.arange(x.shape[0]), 0, :]
            else:
                x = x.pooler_output
        if self.projection == 'default':
            x = x @ self.text_projection
        else:
            x = self.text_projection(x)

        return x

    def forward(self, image, text, mask=None, use_checkpoint=False, norm_embed=False):
        image_embed = self.encode_image(image, use_checkpoint=use_checkpoint)
        text_embed = self.encode_text(text, attention_mask=mask, use_checkpoint=use_checkpoint)

        if norm_embed:
            image_embed = F.normalize(image_embed, dim=-1)
            text_embed = F.normalize(text_embed, dim=-1)
        return {'image_embed': image_embed,
                'text_embed': text_embed,
                'logit_scale': self.logit_scale.exp()}


def get_loss(model, args, tokenizer=None):
    if model.startswith('CLIP'):
        return loss.CLIPLoss(
            use_vissl=args.contrastive_use_vissl,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
        )
    elif model.startswith('VCLM'):
        return loss.CaptionLoss(tokenizer=tokenizer)
    else:
        raise NotImplementedError


def get_metric_names(model):
    if model.startswith('CLIP'):
        return ['loss', 'clip_loss', 'clip_acc']
    elif model.startswith('VCLM'):
        return ['loss', 'caption_loss', 'caption_acc', 'ppl']
    else:
        raise NotImplementedError
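get_loss() and get_metric_names() key everything off the model-name string. A small sketch of that dispatch (args here is a stand-in namespace, not this repo's actual argument parser):

from argparse import Namespace

from lavila.models.models import get_loss, get_metric_names

args = Namespace(contrastive_use_vissl=False, rank=0, world_size=1)
criterion = get_loss('CLIP_OPENAI_TIMESFORMER_BASE', args)
print(type(criterion).__name__)                               # CLIPLoss
print(get_metric_names('CLIP_OPENAI_TIMESFORMER_BASE'))       # ['loss', 'clip_loss', 'clip_acc']
print(get_metric_names('VCLM_OPENAI_TIMESFORMER_BASE_GPT2'))  # ['loss', 'caption_loss', 'caption_acc', 'ppl']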
313 |
+
raise NotImplementedError
|
314 |
+
|
315 |
+
|
316 |
+
def CLIP_OPENAI_TIMESFORMER_BASE(
|
317 |
+
num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
|
318 |
+
temperature_init=0.07, project_embed_dim=256, **kwargs,
|
319 |
+
):
|
320 |
+
vision_model = SpaceTimeTransformer(
|
321 |
+
num_frames=num_frames,
|
322 |
+
time_init='zeros',
|
323 |
+
attention_style='frozen-in-time',
|
324 |
+
ln_pre=True,
|
325 |
+
act_layer=QuickGELU,
|
326 |
+
is_tanh_gating=timesformer_gated_xattn,
|
327 |
+
drop_path_rate=drop_path_rate,
|
328 |
+
)
|
329 |
+
clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
|
330 |
+
print("=> Loading CLIP (ViT-B/16) weights")
|
331 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
|
332 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
333 |
+
print(res)
|
334 |
+
if timesformer_freeze_space:
|
335 |
+
print("=> Freeze the space part in TimeSformer")
|
336 |
+
freeze_list, unfreeze_list = [], []
|
337 |
+
for n, p in vision_model.named_parameters():
|
338 |
+
if n not in remapped_state_dict or n == 'cls_token':
|
339 |
+
p.requires_grad = True
|
340 |
+
unfreeze_list.append(n)
|
341 |
+
else:
|
342 |
+
p.requires_grad = False
|
343 |
+
freeze_list.append(n)
|
344 |
+
print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
|
345 |
+
print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
|
346 |
+
|
347 |
+
vision_model.head = nn.Identity()
|
348 |
+
vision_model.pre_logits = nn.Identity()
|
349 |
+
vision_model.fc = nn.Identity()
|
350 |
+
model = CLIP(
|
351 |
+
embed_dim=project_embed_dim,
|
352 |
+
vision_width=768,
|
353 |
+
vision_model=vision_model,
|
354 |
+
context_length=77,
|
355 |
+
vocab_size=49408,
|
356 |
+
transformer_width=512,
|
357 |
+
transformer_heads=8,
|
358 |
+
transformer_layers=12,
|
359 |
+
tempearture_init=temperature_init,
|
360 |
+
**kwargs
|
361 |
+
)
|
362 |
+
model.transformer.load_state_dict(clip_model.transformer.state_dict())
|
363 |
+
model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict())
|
364 |
+
model.positional_embedding.data.copy_(clip_model.positional_embedding.data)
|
365 |
+
model.ln_final.load_state_dict(clip_model.ln_final.state_dict())
|
366 |
+
if project_embed_dim == clip_model.text_projection.shape[1]:
|
367 |
+
print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
|
368 |
+
model.image_projection.data.copy_(clip_model.visual.proj.data)
|
369 |
+
model.text_projection.data.copy_(clip_model.text_projection.data)
|
370 |
+
model.logit_scale.data.copy_(clip_model.logit_scale.data)
|
371 |
+
return model
|
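The factory above loads the 2-D CLIP ViT-B/16 weights into the TimeSformer with strict=False, so the temporal parameters start randomly initialized. A sketch of instantiating it and counting parameters (assumes the CLIP checkpoint can be downloaded; the printed numbers are not claimed here):

from lavila.models.models import CLIP_OPENAI_TIMESFORMER_BASE

model = CLIP_OPENAI_TIMESFORMER_BASE(num_frames=4, project_embed_dim=256)
n_total = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_trainable / 1e6:.1f}M trainable / {n_total / 1e6:.1f}M total parameters")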
372 |
+
|
373 |
+
|
374 |
+
def CLIP_OPENAI_TIMESFORMER_LARGE(
|
375 |
+
num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
|
376 |
+
temperature_init=0.07, project_embed_dim=256, **kwargs,
|
377 |
+
):
|
378 |
+
vision_model = SpaceTimeTransformer(
|
379 |
+
img_size=224, patch_size=14,
|
380 |
+
embed_dim=1024, depth=24, num_heads=16,
|
381 |
+
num_frames=num_frames,
|
382 |
+
time_init='zeros',
|
383 |
+
attention_style='frozen-in-time',
|
384 |
+
ln_pre=True,
|
385 |
+
act_layer=QuickGELU,
|
386 |
+
is_tanh_gating=timesformer_gated_xattn,
|
387 |
+
drop_path_rate=drop_path_rate,
|
388 |
+
)
|
389 |
+
clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
|
390 |
+
print("=> Loading CLIP (ViT-L/14) weights")
|
391 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
|
392 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
393 |
+
print(res)
|
394 |
+
if timesformer_freeze_space:
|
395 |
+
print("=> Freeze the space part in TimeSformer")
|
396 |
+
freeze_list, unfreeze_list = [], []
|
397 |
+
for n, p in vision_model.named_parameters():
|
398 |
+
if n not in remapped_state_dict or n == 'cls_token':
|
399 |
+
p.requires_grad = True
|
400 |
+
unfreeze_list.append(n)
|
401 |
+
else:
|
402 |
+
p.requires_grad = False
|
403 |
+
freeze_list.append(n)
|
404 |
+
print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
|
405 |
+
print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
|
406 |
+
|
407 |
+
vision_model.head = nn.Identity()
|
408 |
+
vision_model.pre_logits = nn.Identity()
|
409 |
+
vision_model.fc = nn.Identity()
|
410 |
+
model = CLIP(
|
411 |
+
embed_dim=project_embed_dim,
|
412 |
+
vision_width=1024,
|
413 |
+
vision_model=vision_model,
|
414 |
+
context_length=77,
|
415 |
+
vocab_size=49408,
|
416 |
+
transformer_width=768,
|
417 |
+
transformer_heads=12,
|
418 |
+
transformer_layers=12,
|
419 |
+
tempearture_init=temperature_init,
|
420 |
+
**kwargs
|
421 |
+
)
|
422 |
+
model.transformer.load_state_dict(clip_model.transformer.state_dict())
|
423 |
+
model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict())
|
424 |
+
model.positional_embedding.data.copy_(clip_model.positional_embedding.data)
|
425 |
+
model.ln_final.load_state_dict(clip_model.ln_final.state_dict())
|
426 |
+
if project_embed_dim == clip_model.text_projection.shape[1]:
|
427 |
+
print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
|
428 |
+
model.image_projection.data.copy_(clip_model.visual.proj.data)
|
429 |
+
model.text_projection.data.copy_(clip_model.text_projection.data)
|
430 |
+
model.logit_scale.data.copy_(clip_model.logit_scale.data)
|
431 |
+
return model
|
432 |
+
|
433 |
+
|
434 |
+
def CLIP_OPENAI_TIMESFORMER_LARGE_336PX(
|
435 |
+
num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
|
436 |
+
temperature_init=0.07, project_embed_dim=256, **kwargs,
|
437 |
+
):
|
438 |
+
vision_model = SpaceTimeTransformer(
|
439 |
+
img_size=336, patch_size=14,
|
440 |
+
embed_dim=1024, depth=24, num_heads=16,
|
441 |
+
num_frames=num_frames,
|
442 |
+
time_init='zeros',
|
443 |
+
attention_style='frozen-in-time',
|
444 |
+
ln_pre=True,
|
445 |
+
act_layer=QuickGELU,
|
446 |
+
is_tanh_gating=timesformer_gated_xattn,
|
447 |
+
drop_path_rate=drop_path_rate,
|
448 |
+
)
|
449 |
+
clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
|
450 |
+
print("=> Loading CLIP (ViT-L/14@336px) weights")
|
451 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
|
452 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
453 |
+
print(res)
|
454 |
+
if timesformer_freeze_space:
|
455 |
+
print("=> Freeze the space part in TimeSformer")
|
456 |
+
freeze_list, unfreeze_list = [], []
|
457 |
+
for n, p in vision_model.named_parameters():
|
458 |
+
if n not in remapped_state_dict or n == 'cls_token':
|
459 |
+
p.requires_grad = True
|
460 |
+
unfreeze_list.append(n)
|
461 |
+
else:
|
462 |
+
p.requires_grad = False
|
463 |
+
freeze_list.append(n)
|
464 |
+
print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
|
465 |
+
print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
|
466 |
+
|
467 |
+
vision_model.head = nn.Identity()
|
468 |
+
vision_model.pre_logits = nn.Identity()
|
469 |
+
vision_model.fc = nn.Identity()
|
470 |
+
model = CLIP(
|
471 |
+
embed_dim=project_embed_dim,
|
472 |
+
vision_width=1024,
|
473 |
+
vision_model=vision_model,
|
474 |
+
context_length=77,
|
475 |
+
vocab_size=49408,
|
476 |
+
transformer_width=768,
|
477 |
+
transformer_heads=12,
|
478 |
+
transformer_layers=12,
|
479 |
+
tempearture_init=temperature_init,
|
480 |
+
**kwargs
|
481 |
+
)
|
482 |
+
model.transformer.load_state_dict(clip_model.transformer.state_dict())
|
483 |
+
model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict())
|
484 |
+
model.positional_embedding.data.copy_(clip_model.positional_embedding.data)
|
485 |
+
model.ln_final.load_state_dict(clip_model.ln_final.state_dict())
|
486 |
+
if project_embed_dim == clip_model.text_projection.shape[1]:
|
487 |
+
print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
|
488 |
+
model.image_projection.data.copy_(clip_model.visual.proj.data)
|
489 |
+
model.text_projection.data.copy_(clip_model.text_projection.data)
|
490 |
+
model.logit_scale.data.copy_(clip_model.logit_scale.data)
|
491 |
+
return model
|
492 |
+
|
493 |
+
|
494 |
+
def CLIP_OPENAI_TIMESFORMER_BASE_DISTILBERT_BASE(
|
495 |
+
num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
|
496 |
+
temperature_init=0.07, project_embed_dim=256, **kwargs,
|
497 |
+
):
|
498 |
+
vision_model = SpaceTimeTransformer(
|
499 |
+
num_frames=num_frames,
|
500 |
+
time_init='zeros',
|
501 |
+
attention_style='frozen-in-time',
|
502 |
+
ln_pre=True,
|
503 |
+
act_layer=QuickGELU,
|
504 |
+
is_tanh_gating=timesformer_gated_xattn,
|
505 |
+
drop_path_rate=drop_path_rate,
|
506 |
+
)
|
507 |
+
clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
|
508 |
+
print("=> Loading CLIP (ViT-B/16) weights")
|
509 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
|
510 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
511 |
+
print(res)
|
512 |
+
if timesformer_freeze_space:
|
513 |
+
print("=> Freeze the space part in TimeSformer")
|
514 |
+
freeze_list, unfreeze_list = [], []
|
515 |
+
for n, p in vision_model.named_parameters():
|
516 |
+
if n not in remapped_state_dict or n == 'cls_token':
|
517 |
+
p.requires_grad = True
|
518 |
+
unfreeze_list.append(n)
|
519 |
+
else:
|
520 |
+
p.requires_grad = False
|
521 |
+
freeze_list.append(n)
|
522 |
+
print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
|
523 |
+
print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
|
524 |
+
|
525 |
+
vision_model.head = nn.Identity()
|
526 |
+
vision_model.pre_logits = nn.Identity()
|
527 |
+
vision_model.fc = nn.Identity()
|
528 |
+
|
529 |
+
text_model = DistilBertModel.from_pretrained(
|
530 |
+
'distilbert-base-uncased',
|
531 |
+
)
|
532 |
+
kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
|
533 |
+
model = CLIP_HF(
|
534 |
+
embed_dim=project_embed_dim,
|
535 |
+
vision_width=vision_model.embed_dim,
|
536 |
+
vision_model=vision_model,
|
537 |
+
text_width=768,
|
538 |
+
text_model=text_model,
|
539 |
+
text_use_cls_token=True, # DistilBert does not have pooler on top
|
540 |
+
text_is_regressive=False,
|
541 |
+
tempearture_init=temperature_init,
|
542 |
+
**kwargs,
|
543 |
+
)
|
544 |
+
|
545 |
+
return model
|
546 |
+
|
547 |
+
|
548 |
+
def CLIP_OPENAI_TIMESFORMER_LARGE_DISTILBERT_BASE(
|
549 |
+
num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
|
550 |
+
temperature_init=0.07, project_embed_dim=256, **kwargs,
|
551 |
+
):
|
552 |
+
vision_model = SpaceTimeTransformer(
|
553 |
+
img_size=224, patch_size=14,
|
554 |
+
embed_dim=1024, depth=24, num_heads=16,
|
555 |
+
num_frames=num_frames,
|
556 |
+
time_init='zeros',
|
557 |
+
attention_style='frozen-in-time',
|
558 |
+
ln_pre=True,
|
559 |
+
act_layer=QuickGELU,
|
560 |
+
is_tanh_gating=timesformer_gated_xattn,
|
561 |
+
drop_path_rate=drop_path_rate,
|
562 |
+
)
|
563 |
+
clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
|
564 |
+
print("=> Loading CLIP (ViT-L/14) weights")
|
565 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
|
566 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
567 |
+
print(res)
|
568 |
+
if timesformer_freeze_space:
|
569 |
+
print("=> Freeze the space part in TimeSformer")
|
570 |
+
freeze_list, unfreeze_list = [], []
|
571 |
+
for n, p in vision_model.named_parameters():
|
572 |
+
if n not in remapped_state_dict or n == 'cls_token':
|
573 |
+
p.requires_grad = True
|
574 |
+
unfreeze_list.append(n)
|
575 |
+
else:
|
576 |
+
p.requires_grad = False
|
577 |
+
freeze_list.append(n)
|
578 |
+
print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
|
579 |
+
print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
|
580 |
+
|
581 |
+
vision_model.head = nn.Identity()
|
582 |
+
vision_model.pre_logits = nn.Identity()
|
583 |
+
vision_model.fc = nn.Identity()
|
584 |
+
|
585 |
+
text_model = DistilBertModel.from_pretrained(
|
586 |
+
'distilbert-base-uncased',
|
587 |
+
)
|
588 |
+
kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
|
589 |
+
model = CLIP_HF(
|
590 |
+
embed_dim=project_embed_dim,
|
591 |
+
vision_width=vision_model.embed_dim,
|
592 |
+
vision_model=vision_model,
|
593 |
+
text_width=768,
|
594 |
+
text_model=text_model,
|
595 |
+
text_use_cls_token=True, # DistilBert does not have pooler on top
|
596 |
+
text_is_regressive=False,
|
597 |
+
tempearture_init=temperature_init,
|
598 |
+
**kwargs,
|
599 |
+
)
|
600 |
+
|
601 |
+
return model
|
602 |
+
|
603 |
+
|
604 |
+
def CLIP_OPENAI_TIMESFORMER_LARGE_336PX_DISTILBERT_BASE(
|
605 |
+
num_frames=4, timesformer_gated_xattn=False, drop_path_rate=0, timesformer_freeze_space=False,
|
606 |
+
temperature_init=0.07, project_embed_dim=256, **kwargs,
|
607 |
+
):
|
608 |
+
vision_model = SpaceTimeTransformer(
|
609 |
+
img_size=336, patch_size=14,
|
610 |
+
embed_dim=1024, depth=24, num_heads=16,
|
611 |
+
num_frames=num_frames,
|
612 |
+
time_init='zeros',
|
613 |
+
attention_style='frozen-in-time',
|
614 |
+
ln_pre=True,
|
615 |
+
act_layer=QuickGELU,
|
616 |
+
is_tanh_gating=timesformer_gated_xattn,
|
617 |
+
drop_path_rate=drop_path_rate,
|
618 |
+
)
|
619 |
+
clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
|
620 |
+
print("=> Loading CLIP (ViT-L/14@336px) weights")
|
621 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
|
622 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
623 |
+
print(res)
|
624 |
+
if timesformer_freeze_space:
|
625 |
+
print("=> Freeze the space part in TimeSformer")
|
626 |
+
freeze_list, unfreeze_list = [], []
|
627 |
+
for n, p in vision_model.named_parameters():
|
628 |
+
if n not in remapped_state_dict or n == 'cls_token':
|
629 |
+
p.requires_grad = True
|
630 |
+
unfreeze_list.append(n)
|
631 |
+
else:
|
632 |
+
p.requires_grad = False
|
633 |
+
freeze_list.append(n)
|
634 |
+
print("Freeze the pretrained parts in TimeSformer: {}".format(freeze_list))
|
635 |
+
print(" Learn the rest parts in TimeSformer: {}".format(unfreeze_list))
|
636 |
+
|
637 |
+
vision_model.head = nn.Identity()
|
638 |
+
vision_model.pre_logits = nn.Identity()
|
639 |
+
vision_model.fc = nn.Identity()
|
640 |
+
|
641 |
+
text_model = DistilBertModel.from_pretrained(
|
642 |
+
'distilbert-base-uncased',
|
643 |
+
)
|
644 |
+
kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
|
645 |
+
model = CLIP_HF(
|
646 |
+
embed_dim=project_embed_dim,
|
647 |
+
vision_width=vision_model.embed_dim,
|
648 |
+
vision_model=vision_model,
|
649 |
+
text_width=768,
|
650 |
+
text_model=text_model,
|
651 |
+
text_use_cls_token=True, # DistilBert does not have pooler on top
|
652 |
+
text_is_regressive=False,
|
653 |
+
tempearture_init=temperature_init,
|
654 |
+
**kwargs,
|
655 |
+
)
|
656 |
+
|
657 |
+
return model
|
658 |
+
|
659 |
+
|
660 |
+
def CLIP_HF_EGOVLP_DISTILBERT_BASE(num_frames=4, project_embed_dim=256, **kwargs):
|
661 |
+
vision_model = SpaceTimeTransformer(
|
662 |
+
num_frames=num_frames,
|
663 |
+
time_init='zeros',
|
664 |
+
attention_style='frozen-in-time',
|
665 |
+
)
|
666 |
+
vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True)
|
667 |
+
vision_model.load_state_dict(vit_model.state_dict(), strict=False)
|
668 |
+
vision_model.head = nn.Identity()
|
669 |
+
vision_model.pre_logits = nn.Identity()
|
670 |
+
vision_model.fc = nn.Identity()
|
671 |
+
|
672 |
+
text_model = DistilBertModel.from_pretrained(
|
673 |
+
'distilbert-base-uncased',
|
674 |
+
)
|
675 |
+
kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
|
676 |
+
kwargs.update({'projection': 'frozen_in_time'})
|
677 |
+
model = CLIP_HF(
|
678 |
+
embed_dim=project_embed_dim,
|
679 |
+
vision_width=vision_model.embed_dim,
|
680 |
+
vision_model=vision_model,
|
681 |
+
text_width=768,
|
682 |
+
text_model=text_model,
|
683 |
+
text_use_cls_token=True, # DistilBert does not have pooler on top
|
684 |
+
text_is_regressive=False,
|
685 |
+
**kwargs,
|
686 |
+
)
|
687 |
+
|
688 |
+
return model
|
689 |
+
|
690 |
+
|
691 |
+
def CLIP_HF_TIMESFORMER_DISTILBERT_BASE(num_frames=4, drop_path_rate=0, temperature_init=0.07, project_embed_dim=256, **kwargs):
|
692 |
+
vision_model = SpaceTimeTransformer(
|
693 |
+
num_frames=num_frames,
|
694 |
+
time_init='zeros',
|
695 |
+
attention_style='frozen-in-time',
|
696 |
+
drop_path_rate=drop_path_rate,
|
697 |
+
)
|
698 |
+
vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True)
|
699 |
+
vision_model.load_state_dict(vit_model.state_dict(), strict=False)
|
700 |
+
vision_model.head = nn.Identity()
|
701 |
+
vision_model.pre_logits = nn.Identity()
|
702 |
+
vision_model.fc = nn.Identity()
|
703 |
+
|
704 |
+
text_model = DistilBertModel.from_pretrained(
|
705 |
+
'distilbert-base-uncased',
|
706 |
+
)
|
707 |
+
kwargs.pop('text_use_cls_token') # ignore args.use_cls_token since DistilBert does not have pooler on top
|
708 |
+
model = CLIP_HF(
|
709 |
+
embed_dim=project_embed_dim,
|
710 |
+
vision_width=vision_model.embed_dim,
|
711 |
+
vision_model=vision_model,
|
712 |
+
text_width=768,
|
713 |
+
text_model=text_model,
|
714 |
+
text_use_cls_token=True, # DistilBert does not have pooler on top
|
715 |
+
text_is_regressive=False,
|
716 |
+
tempearture_init=temperature_init,
|
717 |
+
**kwargs,
|
718 |
+
)
|
719 |
+
|
720 |
+
return model
|
721 |
+
|
722 |
+
|
723 |
+
def VCLM_OPENAI_VITB16_GPT2_LARGE(gated_xattn=False, freeze_lm_vclm=False,
|
724 |
+
freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs):
|
725 |
+
clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
|
726 |
+
vision_model = clip_model.visual
|
727 |
+
kwargs.pop('text_use_cls_token')
|
728 |
+
|
729 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(
|
730 |
+
"gpt2-large",
|
731 |
+
use_cache=False,
|
732 |
+
)
|
733 |
+
new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
|
734 |
+
text_decoder = GatedGPT2LMHeadModel(new_config)
|
735 |
+
for n, p in gpt2.named_parameters():
|
736 |
+
rsetattr(text_decoder, n + '.data', p.data)
|
737 |
+
|
738 |
+
if freeze_lm_vclm:
|
739 |
+
print('Freeze the LM part of TextDecoder of VCLM')
|
740 |
+
text_decoder.freeze_lm_weights()
|
741 |
+
|
742 |
+
if freeze_visual_vclm:
|
743 |
+
print('Freeze the spatial part of VideoEncoder of VCLM')
|
744 |
+
vision_model.freeze_spatial_weights()
|
745 |
+
|
746 |
+
if freeze_visual_vclm_temporal:
|
747 |
+
print('Freeze the temporal part of VideoEncoder of VCLM')
|
748 |
+
vision_model.freeze_temporal_weights()
|
749 |
+
|
750 |
+
model = VCLM_HF(
|
751 |
+
vision_width=768,
|
752 |
+
vision_model=vision_model,
|
753 |
+
text_width=1280,
|
754 |
+
text_decoder=text_decoder,
|
755 |
+
num_img_queries=256,
|
756 |
+
dim_head=64,
|
757 |
+
heads=20,
|
758 |
+
**kwargs,
|
759 |
+
)
|
760 |
+
|
761 |
+
return model
|
762 |
+
|
763 |
+
|
764 |
+
def VCLM_OPENAI_VITB16_GPT2_XL(gated_xattn=False, freeze_lm_vclm=False,
|
765 |
+
freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs):
|
766 |
+
clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
|
767 |
+
vision_model = clip_model.visual
|
768 |
+
kwargs.pop('text_use_cls_token')
|
769 |
+
|
770 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(
|
771 |
+
"gpt2-xl",
|
772 |
+
use_cache=False,
|
773 |
+
)
|
774 |
+
new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
|
775 |
+
text_decoder = GatedGPT2LMHeadModel(new_config)
|
776 |
+
for n, p in gpt2.named_parameters():
|
777 |
+
rsetattr(text_decoder, n + '.data', p.data)
|
778 |
+
|
779 |
+
if freeze_lm_vclm:
|
780 |
+
print('Freeze the LM part of TextDecoder of VCLM')
|
781 |
+
text_decoder.freeze_lm_weights()
|
782 |
+
|
783 |
+
if freeze_visual_vclm:
|
784 |
+
print('Freeze the spatial part of VideoEncoder of VCLM')
|
785 |
+
vision_model.freeze_spatial_weights()
|
786 |
+
|
787 |
+
if freeze_visual_vclm_temporal:
|
788 |
+
print('Freeze the temporal part of VideoEncoder of VCLM')
|
789 |
+
vision_model.freeze_temporal_weights()
|
790 |
+
|
791 |
+
model = VCLM_HF(
|
792 |
+
vision_width=768,
|
793 |
+
vision_model=vision_model,
|
794 |
+
text_width=1600,
|
795 |
+
text_decoder=text_decoder,
|
796 |
+
num_img_queries=256,
|
797 |
+
dim_head=64,
|
798 |
+
heads=25,
|
799 |
+
**kwargs,
|
800 |
+
)
|
801 |
+
|
802 |
+
return model
|
803 |
+
|
804 |
+
|
805 |
+
def VCLM_OPENAI_VITL14_GPT2_XL(gated_xattn=False, freeze_lm_vclm=False,
|
806 |
+
freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs):
|
807 |
+
clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
|
808 |
+
vision_model = clip_model.visual
|
809 |
+
kwargs.pop('text_use_cls_token')
|
810 |
+
|
811 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(
|
812 |
+
"gpt2-xl",
|
813 |
+
use_cache=False,
|
814 |
+
)
|
815 |
+
new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
|
816 |
+
text_decoder = GatedGPT2LMHeadModel(new_config)
|
817 |
+
for n, p in gpt2.named_parameters():
|
818 |
+
rsetattr(text_decoder, n + '.data', p.data)
|
819 |
+
|
820 |
+
if freeze_lm_vclm:
|
821 |
+
print('Freeze the LM part of TextDecoder of VCLM')
|
822 |
+
text_decoder.freeze_lm_weights()
|
823 |
+
|
824 |
+
if freeze_visual_vclm:
|
825 |
+
print('Freeze the spatial part of VideoEncoder of VCLM')
|
826 |
+
vision_model.freeze_spatial_weights()
|
827 |
+
|
828 |
+
if freeze_visual_vclm_temporal:
|
829 |
+
print('Freeze the temporal part of VideoEncoder of VCLM')
|
830 |
+
vision_model.freeze_temporal_weights()
|
831 |
+
|
832 |
+
model = VCLM_HF(
|
833 |
+
vision_width=1024,
|
834 |
+
vision_model=vision_model,
|
835 |
+
text_width=1600,
|
836 |
+
text_decoder=text_decoder,
|
837 |
+
num_img_queries=256,
|
838 |
+
dim_head=64,
|
839 |
+
heads=25,
|
840 |
+
**kwargs,
|
841 |
+
)
|
842 |
+
|
843 |
+
return model
|
844 |
+
|
845 |
+
|
846 |
+
def VCLM_OPENAI_VITL14_336PX_GPT2_XL(gated_xattn=False, freeze_lm_vclm=False,
|
847 |
+
freeze_visual_vclm=False, freeze_visual_vclm_temporal=False, **kwargs):
|
848 |
+
clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
|
849 |
+
vision_model = clip_model.visual
|
850 |
+
kwargs.pop('text_use_cls_token')
|
851 |
+
|
852 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(
|
853 |
+
"gpt2-xl",
|
854 |
+
use_cache=False,
|
855 |
+
)
|
856 |
+
new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
|
857 |
+
text_decoder = GatedGPT2LMHeadModel(new_config)
|
858 |
+
for n, p in gpt2.named_parameters():
|
859 |
+
rsetattr(text_decoder, n + '.data', p.data)
|
860 |
+
|
861 |
+
if freeze_lm_vclm:
|
862 |
+
print('Freeze the LM part of TextDecoder of VCLM')
|
863 |
+
text_decoder.freeze_lm_weights()
|
864 |
+
|
865 |
+
if freeze_visual_vclm:
|
866 |
+
print('Freeze the spatial part of VideoEncoder of VCLM')
|
867 |
+
vision_model.freeze_spatial_weights()
|
868 |
+
|
869 |
+
if freeze_visual_vclm_temporal:
|
870 |
+
print('Freeze the temporal part of VideoEncoder of VCLM')
|
871 |
+
vision_model.freeze_temporal_weights()
|
872 |
+
|
873 |
+
model = VCLM_HF(
|
874 |
+
vision_width=1024,
|
875 |
+
vision_model=vision_model,
|
876 |
+
text_width=1600,
|
877 |
+
text_decoder=text_decoder,
|
878 |
+
num_img_queries=256,
|
879 |
+
dim_head=64,
|
880 |
+
heads=25,
|
881 |
+
**kwargs,
|
882 |
+
)
|
883 |
+
|
884 |
+
return model
|
885 |
+
|
886 |
+
|
887 |
+
def VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
|
888 |
+
gated_xattn=False,
|
889 |
+
random_init_gpt2=False,
|
890 |
+
freeze_lm_vclm=False,
|
891 |
+
freeze_visual_vclm=False,
|
892 |
+
freeze_visual_vclm_temporal=False,
|
893 |
+
num_frames=4,
|
894 |
+
timesformer_gated_xattn=False,
|
895 |
+
**kwargs,
|
896 |
+
):
|
897 |
+
vision_model = SpaceTimeTransformer(
|
898 |
+
num_frames=num_frames,
|
899 |
+
time_init='zeros',
|
900 |
+
attention_style='frozen-in-time',
|
901 |
+
ln_pre=True,
|
902 |
+
act_layer=QuickGELU,
|
903 |
+
is_tanh_gating=timesformer_gated_xattn,
|
904 |
+
)
|
905 |
+
clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
|
906 |
+
print("=> Loading CLIP (ViT-B/16) weights")
|
907 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
|
908 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
909 |
+
print(res)
|
910 |
+
vision_model.head = nn.Identity()
|
911 |
+
vision_model.pre_logits = nn.Identity()
|
912 |
+
vision_model.fc = nn.Identity()
|
913 |
+
|
914 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(
|
915 |
+
"gpt2",
|
916 |
+
use_cache=False,
|
917 |
+
)
|
918 |
+
new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=1, gated_xattn=gated_xattn)
|
919 |
+
text_decoder = GatedGPT2LMHeadModel(new_config)
|
920 |
+
if not random_init_gpt2:
|
921 |
+
print('Loading LM from pretrained weights..')
|
922 |
+
for n, p in gpt2.named_parameters():
|
923 |
+
rsetattr(text_decoder, n + '.data', p.data)
|
924 |
+
|
925 |
+
if freeze_lm_vclm:
|
926 |
+
print('Freeze the LM part of TextDecoder of VCLM')
|
927 |
+
text_decoder.freeze_lm_weights()
|
928 |
+
|
929 |
+
if freeze_visual_vclm:
|
930 |
+
print('Freeze the spatial part of VideoEncoder of VCLM')
|
931 |
+
vision_model.freeze_spatial_weights()
|
932 |
+
|
933 |
+
if freeze_visual_vclm_temporal:
|
934 |
+
print('Freeze the temporal part of VideoEncoder of VCLM')
|
935 |
+
vision_model.freeze_temporal_weights()
|
936 |
+
|
937 |
+
model = VCLM_HF(
|
938 |
+
vision_width=768,
|
939 |
+
vision_model=vision_model,
|
940 |
+
text_width=768,
|
941 |
+
text_decoder=text_decoder,
|
942 |
+
num_img_queries=256,
|
943 |
+
dim_head=64,
|
944 |
+
heads=12,
|
945 |
+
**kwargs,
|
946 |
+
)
|
947 |
+
|
948 |
+
return model
|
949 |
+
|
950 |
+
|
951 |
+
def VCLM_OPENAI_TIMESFORMER_BASE_GPT2_XL(
|
952 |
+
gated_xattn=False,
|
953 |
+
freeze_lm_vclm=False,
|
954 |
+
freeze_visual_vclm=False,
|
955 |
+
freeze_visual_vclm_temporal=False,
|
956 |
+
num_frames=4,
|
957 |
+
timesformer_gated_xattn=False,
|
958 |
+
**kwargs,
|
959 |
+
):
|
960 |
+
vision_model = SpaceTimeTransformer(
|
961 |
+
num_frames=num_frames,
|
962 |
+
time_init='zeros',
|
963 |
+
attention_style='frozen-in-time',
|
964 |
+
ln_pre=True,
|
965 |
+
act_layer=QuickGELU,
|
966 |
+
is_tanh_gating=timesformer_gated_xattn,
|
967 |
+
)
|
968 |
+
clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
|
969 |
+
print("=> Loading CLIP (ViT-B/16) weights")
|
970 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
|
971 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
972 |
+
print(res)
|
973 |
+
vision_model.head = nn.Identity()
|
974 |
+
vision_model.pre_logits = nn.Identity()
|
975 |
+
vision_model.fc = nn.Identity()
|
976 |
+
|
977 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(
|
978 |
+
"gpt2-xl",
|
979 |
+
use_cache=False,
|
980 |
+
)
|
981 |
+
new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
|
982 |
+
text_decoder = GatedGPT2LMHeadModel(new_config)
|
983 |
+
for n, p in gpt2.named_parameters():
|
984 |
+
rsetattr(text_decoder, n + '.data', p.data)
|
985 |
+
|
986 |
+
if freeze_lm_vclm:
|
987 |
+
print('Freeze the LM part of TextDecoder of VCLM')
|
988 |
+
text_decoder.freeze_lm_weights()
|
989 |
+
|
990 |
+
if freeze_visual_vclm:
|
991 |
+
print('Freeze the spatial part of VideoEncoder of VCLM')
|
992 |
+
vision_model.freeze_spatial_weights()
|
993 |
+
|
994 |
+
if freeze_visual_vclm_temporal:
|
995 |
+
print('Freeze the temporal part of VideoEncoder of VCLM')
|
996 |
+
vision_model.freeze_temporal_weights()
|
997 |
+
|
998 |
+
model = VCLM_HF(
|
999 |
+
vision_width=768,
|
1000 |
+
vision_model=vision_model,
|
1001 |
+
text_width=1600,
|
1002 |
+
text_decoder=text_decoder,
|
1003 |
+
num_img_queries=256,
|
1004 |
+
dim_head=64,
|
1005 |
+
heads=25,
|
1006 |
+
**kwargs,
|
1007 |
+
)
|
1008 |
+
|
1009 |
+
return model
|
1010 |
+
|
1011 |
+
|
1012 |
+
def VCLM_OPENAI_TIMESFORMER_LARGE_GPT2_XL(
|
1013 |
+
gated_xattn=False,
|
1014 |
+
freeze_lm_vclm=False,
|
1015 |
+
freeze_visual_vclm=False,
|
1016 |
+
freeze_visual_vclm_temporal=False,
|
1017 |
+
num_frames=4,
|
1018 |
+
timesformer_gated_xattn=False,
|
1019 |
+
**kwargs,
|
1020 |
+
):
|
1021 |
+
vision_model = SpaceTimeTransformer(
|
1022 |
+
img_size=224, patch_size=14,
|
1023 |
+
embed_dim=1024, depth=24, num_heads=16,
|
1024 |
+
num_frames=num_frames,
|
1025 |
+
time_init='zeros',
|
1026 |
+
attention_style='frozen-in-time',
|
1027 |
+
ln_pre=True,
|
1028 |
+
act_layer=QuickGELU,
|
1029 |
+
is_tanh_gating=timesformer_gated_xattn,
|
1030 |
+
)
|
1031 |
+
clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
|
1032 |
+
print("=> Loading CLIP (ViT-L/14x) weights")
|
1033 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
|
1034 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
1035 |
+
print(res)
|
1036 |
+
vision_model.head = nn.Identity()
|
1037 |
+
vision_model.pre_logits = nn.Identity()
|
1038 |
+
vision_model.fc = nn.Identity()
|
1039 |
+
|
1040 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(
|
1041 |
+
"gpt2-xl",
|
1042 |
+
use_cache=False,
|
1043 |
+
)
|
1044 |
+
new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=2, gated_xattn=gated_xattn)
|
1045 |
+
text_decoder = GatedGPT2LMHeadModel(new_config)
|
1046 |
+
for n, p in gpt2.named_parameters():
|
1047 |
+
rsetattr(text_decoder, n + '.data', p.data)
|
1048 |
+
|
1049 |
+
if freeze_lm_vclm:
|
1050 |
+
print('Freeze the LM part of TextDecoder of VCLM')
|
1051 |
+
text_decoder.freeze_lm_weights()
|
1052 |
+
|
1053 |
+
if freeze_visual_vclm:
|
1054 |
+
print('Freeze the spatial part of VideoEncoder of VCLM')
|
1055 |
+
vision_model.freeze_spatial_weights()
|
1056 |
+
|
1057 |
+
if freeze_visual_vclm_temporal:
|
1058 |
+
print('Freeze the temporal part of VideoEncoder of VCLM')
|
1059 |
+
vision_model.freeze_temporal_weights()
|
1060 |
+
|
1061 |
+
model = VCLM_HF(
|
1062 |
+
vision_width=1024,
|
1063 |
+
vision_model=vision_model,
|
1064 |
+
text_width=1600,
|
1065 |
+
text_decoder=text_decoder,
|
1066 |
+
num_img_queries=256,
|
1067 |
+
dim_head=64,
|
1068 |
+
heads=25,
|
1069 |
+
**kwargs,
|
1070 |
+
)
|
1071 |
+
|
1072 |
+
return model
|
1073 |
+
|
1074 |
+
|
1075 |
+
def VCLM_OPENAI_TIMESFORMER_LARGE_GPT2(
|
1076 |
+
gated_xattn=False,
|
1077 |
+
freeze_lm_vclm=False,
|
1078 |
+
freeze_visual_vclm=False,
|
1079 |
+
freeze_visual_vclm_temporal=False,
|
1080 |
+
num_frames=4,
|
1081 |
+
timesformer_gated_xattn=False,
|
1082 |
+
**kwargs
|
1083 |
+
):
|
1084 |
+
vision_model = SpaceTimeTransformer(
|
1085 |
+
img_size=224, patch_size=14,
|
1086 |
+
embed_dim=1024, depth=24, num_heads=16,
|
1087 |
+
num_frames=num_frames,
|
1088 |
+
time_init='zeros',
|
1089 |
+
attention_style='frozen-in-time',
|
1090 |
+
ln_pre=True,
|
1091 |
+
act_layer=QuickGELU,
|
1092 |
+
is_tanh_gating=timesformer_gated_xattn,
|
1093 |
+
)
|
1094 |
+
clip_model, _ = load_openai_clip('ViT-L/14', 'cpu')
|
1095 |
+
print("=> Loading CLIP (ViT-L/14x) weights")
|
1096 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
|
1097 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
1098 |
+
print(res)
|
1099 |
+
vision_model.head = nn.Identity()
|
1100 |
+
vision_model.pre_logits = nn.Identity()
|
1101 |
+
vision_model.fc = nn.Identity()
|
1102 |
+
|
1103 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(
|
1104 |
+
"gpt2",
|
1105 |
+
use_cache=False,
|
1106 |
+
)
|
1107 |
+
new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=1, gated_xattn=gated_xattn)
|
1108 |
+
text_decoder = GatedGPT2LMHeadModel(new_config)
|
1109 |
+
for n, p in gpt2.named_parameters():
|
1110 |
+
rsetattr(text_decoder, n + '.data', p.data)
|
1111 |
+
|
1112 |
+
if freeze_lm_vclm:
|
1113 |
+
print('Freeze the LM part of TextDecoder of VCLM')
|
1114 |
+
text_decoder.freeze_lm_weights()
|
1115 |
+
|
1116 |
+
if freeze_visual_vclm:
|
1117 |
+
print('Freeze the spatial part of VideoEncoder of VCLM')
|
1118 |
+
vision_model.freeze_spatial_weights()
|
1119 |
+
|
1120 |
+
if freeze_visual_vclm_temporal:
|
1121 |
+
print('Freeze the temporal part of VideoEncoder of VCLM')
|
1122 |
+
vision_model.freeze_temporal_weights()
|
1123 |
+
|
1124 |
+
model = VCLM_HF(
|
1125 |
+
vision_width=1024,
|
1126 |
+
vision_model=vision_model,
|
1127 |
+
text_width=768,
|
1128 |
+
text_decoder=text_decoder,
|
1129 |
+
num_img_queries=256,
|
1130 |
+
dim_head=64,
|
1131 |
+
heads=12,
|
1132 |
+
**kwargs,
|
1133 |
+
)
|
1134 |
+
|
1135 |
+
return model
|
1136 |
+
|
1137 |
+
|
1138 |
+
def VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL(
|
1139 |
+
gated_xattn=False,
|
1140 |
+
freeze_lm_vclm=False,
|
1141 |
+
freeze_visual_vclm=False,
|
1142 |
+
freeze_visual_vclm_temporal=False,
|
1143 |
+
num_frames=4,
|
1144 |
+
timesformer_gated_xattn=False,
|
1145 |
+
**kwargs,
|
1146 |
+
):
|
1147 |
+
vision_model = SpaceTimeTransformer(
|
1148 |
+
img_size=336, patch_size=14,
|
1149 |
+
embed_dim=1024, depth=24, num_heads=16,
|
1150 |
+
num_frames=num_frames,
|
1151 |
+
time_init='zeros',
|
1152 |
+
attention_style='frozen-in-time',
|
1153 |
+
ln_pre=True,
|
1154 |
+
act_layer=QuickGELU,
|
1155 |
+
is_tanh_gating=timesformer_gated_xattn,
|
1156 |
+
)
|
1157 |
+
clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
|
1158 |
+
print("=> Loading CLIP (ViT-L/14@336px) weights")
|
1159 |
+
remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)
|
1160 |
+
res = vision_model.load_state_dict(remapped_state_dict, strict=False)
|
1161 |
+
print(res)
|
1162 |
+
vision_model.head = nn.Identity()
|
1163 |
+
vision_model.pre_logits = nn.Identity()
|
1164 |
+
vision_model.fc = nn.Identity()
|
1165 |
+
|
1166 |
+
gpt2 = GPT2LMHeadModel.from_pretrained(
|
1167 |
+
"gpt2-xl",
|
1168 |
+
use_cache=False,
|
1169 |
+
)
|
1170 |
+
new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=3, gated_xattn=gated_xattn)
|
1171 |
+
text_decoder = GatedGPT2LMHeadModel(new_config)
|
1172 |
+
for n, p in gpt2.named_parameters():
|
1173 |
+
rsetattr(text_decoder, n + '.data', p.data)
|
1174 |
+
|
1175 |
+
if freeze_lm_vclm:
|
1176 |
+
print('Freeze the LM part of TextDecoder of VCLM')
|
1177 |
+
text_decoder.freeze_lm_weights()
|
1178 |
+
|
1179 |
+
if freeze_visual_vclm:
|
1180 |
+
print('Freeze the spatial part of VideoEncoder of VCLM')
|
1181 |
+
vision_model.freeze_spatial_weights()
|
1182 |
+
|
1183 |
+
if freeze_visual_vclm_temporal:
|
1184 |
+
print('Freeze the temporal part of VideoEncoder of VCLM')
|
1185 |
+
vision_model.freeze_temporal_weights()
|
1186 |
+
|
1187 |
+
model = VCLM_HF(
|
1188 |
+
vision_width=1024,
|
1189 |
+
vision_model=vision_model,
|
1190 |
+
text_width=1600,
|
1191 |
+
text_decoder=text_decoder,
|
1192 |
+
num_img_queries=256,
|
1193 |
+
dim_head=64,
|
1194 |
+
heads=25,
|
1195 |
+
**kwargs,
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
return model
|
1199 |
+
|
1200 |
+
|
1201 |
+
def CLIP_OPENAI_VITB32(**kwargs):
|
1202 |
+
model, _ = load_openai_clip('ViT-B/32', 'cpu')
|
1203 |
+
return model
|
1204 |
+
|
1205 |
+
|
1206 |
+
def CLIP_OPENAI_VITB16(**kwargs):
|
1207 |
+
model, _ = load_openai_clip('ViT-B/16', 'cpu')
|
1208 |
+
return model
|
1209 |
+
|
1210 |
+
|
1211 |
+
def CLIP_OPENAI_VITL14(**kwargs):
|
1212 |
+
model, _ = load_openai_clip('ViT-L/14', 'cpu')
|
1213 |
+
return model
|
1214 |
+
|
1215 |
+
|
1216 |
+
def CLIP_OPENAI_VITL14_336PX(**kwargs):
|
1217 |
+
model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')
|
1218 |
+
return model
|
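The builders above differ only in which CLIP visual backbone they adapt and which GPT-2 variant they graft cross-attention into; each one returns a VCLM_HF captioner. A minimal usage sketch, assuming this file is importable as lavila.models.models and that the CLIP and GPT-2 checkpoints can be downloaded; the builder name and keyword arguments are taken from the code above, and the dummy input shape (B, C, T, H, W) follows encode_image:

# Hedged sketch, not part of the repository: instantiate the TimeSformer-B + GPT-2
# captioner defined above and pool a dummy clip into 256 visual query tokens.
import torch
from lavila.models import models

model = models.VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
    gated_xattn=True,   # add gated cross-attention into every GPT-2 block (cross_attn_freq=1)
    num_frames=4,       # TimeSformer temporal length
)
model.eval()

video = torch.zeros(1, 3, 4, 224, 224)           # (B, C, T, H, W) dummy clip
with torch.no_grad():
    image_tokens = model.encode_image(video)     # (1, 256, 768) pooled query tokens
print(image_tokens.shape)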
lavila/models/narrator.py
ADDED
@@ -0,0 +1,385 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Part of the code is from https://github.com/huggingface/transformers/blob/main/src/transformers/generation_utils.py
|
8 |
+
# Modified by Yue Zhao
|
9 |
+
# The original code is under Apache 2.0 License
|
10 |
+
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from einops import rearrange, repeat
|
16 |
+
from transformers import BeamSearchScorer
|
17 |
+
from transformers.generation_logits_process import (
|
18 |
+
LogitsProcessorList, TopKLogitsWarper, TopPLogitsWarper,
|
19 |
+
TemperatureLogitsWarper, TypicalLogitsWarper, LogitNormalization
|
20 |
+
)
|
21 |
+
|
22 |
+
from lavila.models.coca import CrossAttention, LayerNorm
|
23 |
+
from lavila.models.openai_model import VisionTransformer
|
24 |
+
from lavila.models.timesformer import SpaceTimeTransformer
|
25 |
+
|
26 |
+
|
27 |
+
class VCLM_HF(nn.Module):
|
28 |
+
def __init__(self,
|
29 |
+
# vision
|
30 |
+
vision_width: int,
|
31 |
+
vision_model: nn.Module,
|
32 |
+
# text
|
33 |
+
text_width: int,
|
34 |
+
text_decoder: nn.Module,
|
35 |
+
num_img_queries=256,
|
36 |
+
dim_head=64,
|
37 |
+
heads=8,
|
38 |
+
**kwargs,
|
39 |
+
):
|
40 |
+
super().__init__()
|
41 |
+
self.vision_width = vision_width
|
42 |
+
self.visual = vision_model
|
43 |
+
self.text_width = text_width
|
44 |
+
self.text_decoder = text_decoder
|
45 |
+
|
46 |
+
self.img_queries = nn.Parameter(torch.empty(num_img_queries, text_width))
|
47 |
+
self.img_attn_pool = CrossAttention(
|
48 |
+
dim=text_width, context_dim=vision_width,
|
49 |
+
dim_head=dim_head, heads=heads,
|
50 |
+
norm_context=True
|
51 |
+
)
|
52 |
+
self.img_attn_pool_norm = LayerNorm(text_width)
|
53 |
+
|
54 |
+
self.initialize_parameters()
|
55 |
+
|
56 |
+
def initialize_parameters(self):
|
57 |
+
nn.init.normal_(self.img_queries, std=self.text_width ** -0.5)
|
58 |
+
|
59 |
+
def encode_image(self, image, use_checkpoint=False):
|
60 |
+
if isinstance(self.visual, VisionTransformer):
|
61 |
+
# openai_model.VisionTransformer accepts (N, C, H, W) instead of (N, C, T, H, W)
|
62 |
+
image = image.permute(0, 2, 1, 3, 4) # BCTHW -> BTCHW
|
63 |
+
bb, tt, _, _, _ = image.shape
|
64 |
+
x = self.visual(image.reshape(-1, *image.shape[2:]), use_checkpoint=use_checkpoint, cls_at_last=False) # NLD
|
65 |
+
x = x.view(bb, tt, *x.shape[1:])
|
66 |
+
x = x.permute(0, 3, 1, 2)
|
67 |
+
elif isinstance(self.visual, SpaceTimeTransformer):
|
68 |
+
image = image.permute(0, 2, 1, 3, 4).contiguous() # BCTHW -> BTCHW
|
69 |
+
bb, tt, _, _, _ = image.shape
|
70 |
+
x = self.visual.forward_features(image, use_checkpoint=use_checkpoint, cls_at_last=False) # NLD
|
71 |
+
x = x.permute(0, 2, 1)
|
72 |
+
else:
|
73 |
+
x = self.visual(image, use_checkpoint=use_checkpoint, mean_at_last=False)
|
74 |
+
if isinstance(x, list):
|
75 |
+
assert len(x) == 1
|
76 |
+
x = x[0]
|
77 |
+
|
78 |
+
x = x.flatten(start_dim=2) # BDTHW -> BD(THW)
|
79 |
+
x = x.permute(0, 2, 1) # BDN -> BND
|
80 |
+
img_queries = repeat(self.img_queries, 'n d -> b n d', b=x.shape[0])
|
81 |
+
img_queries = self.img_attn_pool(img_queries, x)
|
82 |
+
img_queries = self.img_attn_pool_norm(img_queries)
|
83 |
+
return img_queries
|
84 |
+
|
85 |
+
def forward(self, image, text, mask=None, use_checkpoint=False, norm_embed=False):
|
86 |
+
if use_checkpoint:
|
87 |
+
self.text_decoder.gradient_checkpointing_enable()
|
88 |
+
else:
|
89 |
+
self.text_decoder.gradient_checkpointing_disable()
|
90 |
+
|
91 |
+
text, labels = text[:, :-1], text[:, 1:]
|
92 |
+
# mask = mask[:, :-1]
|
93 |
+
image_tokens = self.encode_image(image, use_checkpoint=use_checkpoint)
|
94 |
+
|
95 |
+
output_decoder = self.text_decoder(text.contiguous(), encoder_hidden_states=image_tokens)
|
96 |
+
text_tokens_logits = output_decoder.logits
|
97 |
+
text_tokens_logits = rearrange(text_tokens_logits, 'b n c -> b c n')
|
98 |
+
|
99 |
+
return {'text_tokens_logits': text_tokens_logits,
|
100 |
+
'labels': labels}
|
101 |
+
|
102 |
+
def generate(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None,
|
103 |
+
num_return_sequences=1, temperature=1.0, teacher_forcing=False, early_stopping=False):
|
104 |
+
image_tokens = image_tokens.repeat_interleave(num_return_sequences, dim=0)
|
105 |
+
device = image_tokens.device
|
106 |
+
generated_text_ids = torch.LongTensor([[tokenizer.bos_token_id]] * image_tokens.shape[0]).to(device)
|
107 |
+
condition_text_ids = generated_text_ids.clone()
|
108 |
+
|
109 |
+
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=1)
|
110 |
+
|
111 |
+
nlls, num_tokens = torch.zeros(image_tokens.shape[0]).to(device), torch.zeros(image_tokens.shape[0]).to(device)
|
112 |
+
is_reach_eos = torch.zeros(image_tokens.shape[0]).bool().to(device)
|
113 |
+
with torch.no_grad():
|
114 |
+
for i in range(max_text_length - 1):
|
115 |
+
output_decoder = self.text_decoder(condition_text_ids, encoder_hidden_states=image_tokens)
|
116 |
+
decoded_token_logits = output_decoder.logits
|
117 |
+
next_token_logits = decoded_token_logits[:, -1, :]
|
118 |
+
if target is not None:
|
119 |
+
nll = F.cross_entropy(next_token_logits, target[:, i+1], ignore_index=tokenizer.pad_token_id, reduction='none')
|
120 |
+
nlls += nll
|
121 |
+
num_tokens += target[:, i+1].ne(tokenizer.pad_token_id)
|
122 |
+
else:
|
123 |
+
nll = torch.special.entr(F.softmax(next_token_logits, dim=1)).sum(dim=1)
|
124 |
+
nlls += nll * (~is_reach_eos)
|
125 |
+
num_tokens += (~is_reach_eos)
|
126 |
+
# filtered_p = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p, device=device)
|
127 |
+
next_token_logits = logits_warper(generated_text_ids, next_token_logits)
|
128 |
+
filtered_p = F.softmax(next_token_logits, dim=-1)
|
129 |
+
next_token = torch.multinomial(filtered_p, num_samples=1)
|
130 |
+
is_reach_eos = is_reach_eos | (next_token[:, 0] == tokenizer.eos_token_id)
|
131 |
+
if early_stopping and torch.all(is_reach_eos):
|
132 |
+
break
|
133 |
+
|
134 |
+
if teacher_forcing:
|
135 |
+
condition_text_ids = target[:, :i+2]
|
136 |
+
else:
|
137 |
+
condition_text_ids = torch.cat((generated_text_ids, next_token), dim=1)
|
138 |
+
|
139 |
+
generated_text_ids = torch.cat((generated_text_ids, next_token), dim=1)
|
140 |
+
if target is not None:
|
141 |
+
return generated_text_ids, torch.exp(nlls / num_tokens)
|
142 |
+
else:
|
143 |
+
return generated_text_ids, torch.exp(nlls / num_tokens)
|
144 |
+
|
145 |
+
def beam_sample(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None,
|
146 |
+
temperature=1.0, length_penalty=1.,
|
147 |
+
num_beams=3, num_return_sequences=1, teacher_forcing=False, early_stopping=False):
|
148 |
+
batch_size = image_tokens.shape[0]
|
149 |
+
device = image_tokens.device
|
150 |
+
input_ids = torch.ones((batch_size, 1), device=device, dtype=torch.long)
|
151 |
+
input_ids = input_ids * tokenizer.bos_token_id
|
152 |
+
|
153 |
+
expanded_return_idx = (
|
154 |
+
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_beams * num_return_sequences).view(-1).to(device)
|
155 |
+
)
|
156 |
+
input_ids = input_ids.index_select(0, expanded_return_idx)
|
157 |
+
|
158 |
+
batch_beam_size, cur_len = input_ids.shape
|
159 |
+
|
160 |
+
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=num_beams)
|
161 |
+
|
162 |
+
beam_scorer = BeamSearchScorer(
|
163 |
+
batch_size=batch_size * num_return_sequences, num_beams=num_beams,
|
164 |
+
device=device,
|
165 |
+
length_penalty=length_penalty,
|
166 |
+
)
|
167 |
+
batch_size = len(beam_scorer._beam_hyps)
|
168 |
+
num_beams = beam_scorer.num_beams
|
169 |
+
|
170 |
+
beam_scores = torch.zeros((batch_size, num_beams)).to(device)
|
171 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
172 |
+
|
173 |
+
is_reach_eos = torch.zeros(batch_beam_size).bool().to(device)
|
174 |
+
with torch.no_grad():
|
175 |
+
for i in range(max_text_length - 1):
|
176 |
+
output_decoder = self.text_decoder(
|
177 |
+
input_ids,
|
178 |
+
encoder_hidden_states=image_tokens.repeat_interleave(num_beams * num_return_sequences, dim=0)
|
179 |
+
)
|
180 |
+
decoded_token_logits = output_decoder.logits
|
181 |
+
next_token_logits = decoded_token_logits[:, -1, :]
|
182 |
+
|
183 |
+
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
184 |
+
# supposed to be the line below, but ignore temporarily
|
185 |
+
# next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
186 |
+
next_token_scores_processed = next_token_scores
|
187 |
+
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
188 |
+
# supposed to be the line below, but do a simple top_k+top_p temporarily
|
189 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
190 |
+
# next_token_scores = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p, device=device)
|
191 |
+
|
192 |
+
vocab_size = next_token_scores.shape[-1]
|
193 |
+
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
194 |
+
|
195 |
+
probs = F.softmax(next_token_scores, dim=-1)
|
196 |
+
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
|
197 |
+
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
|
198 |
+
|
199 |
+
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
|
200 |
+
next_tokens = torch.gather(next_tokens, -1, _indices)
|
201 |
+
|
202 |
+
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
203 |
+
next_tokens = next_tokens % vocab_size
|
204 |
+
|
205 |
+
# stateless
|
206 |
+
beam_outputs = beam_scorer.process(
|
207 |
+
input_ids,
|
208 |
+
next_token_scores,
|
209 |
+
next_tokens,
|
210 |
+
next_indices,
|
211 |
+
pad_token_id=tokenizer.pad_token_id,
|
212 |
+
eos_token_id=tokenizer.eos_token_id,
|
213 |
+
)
|
214 |
+
|
215 |
+
beam_scores = beam_outputs["next_beam_scores"]
|
216 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
217 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
218 |
+
|
219 |
+
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
220 |
+
|
221 |
+
is_reach_eos = is_reach_eos | (input_ids[:, -1] == tokenizer.eos_token_id)
|
222 |
+
if beam_scorer.is_done or torch.all(is_reach_eos):
|
223 |
+
break
|
224 |
+
|
225 |
+
sequence_outputs = beam_scorer.finalize(
|
226 |
+
input_ids,
|
227 |
+
beam_scores,
|
228 |
+
next_tokens,
|
229 |
+
next_indices,
|
230 |
+
pad_token_id=tokenizer.pad_token_id,
|
231 |
+
eos_token_id=tokenizer.eos_token_id,
|
232 |
+
max_length=max_text_length,
|
233 |
+
)
|
234 |
+
|
235 |
+
sequences = sequence_outputs["sequences"]
|
236 |
+
sequence_scores = sequence_outputs["sequence_scores"]
|
237 |
+
return sequences, sequence_scores
|
238 |
+
|
239 |
+
def group_beam_search(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None,
|
240 |
+
temperature=1.0, length_penalty=1.,
|
241 |
+
num_beams=6, num_beam_groups=3,
|
242 |
+
num_return_sequences=1, teacher_forcing=False, early_stopping=False):
|
243 |
+
batch_size = image_tokens.shape[0]
|
244 |
+
device = image_tokens.device
|
245 |
+
input_ids = torch.ones((batch_size, 1), device=device, dtype=torch.long)
|
246 |
+
input_ids = input_ids * tokenizer.bos_token_id
|
247 |
+
|
248 |
+
expanded_return_idx = (
|
249 |
+
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_beams).view(-1).to(device)
|
250 |
+
)
|
251 |
+
input_ids = input_ids.index_select(0, expanded_return_idx)
|
252 |
+
|
253 |
+
batch_beam_size, cur_len = input_ids.shape
|
254 |
+
|
255 |
+
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=num_beams)
|
256 |
+
|
257 |
+
beam_scorer = BeamSearchScorer(
|
258 |
+
batch_size=batch_size, num_beams=num_beams,
|
259 |
+
num_beam_groups=num_beam_groups,
|
260 |
+
num_beam_hyps_to_keep=num_return_sequences, device=device,
|
261 |
+
length_penalty=length_penalty,
|
262 |
+
)
|
263 |
+
num_sub_beams = num_beams // num_beam_groups
|
264 |
+
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
265 |
+
beam_scores[:, ::num_sub_beams] = 0
|
266 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
267 |
+
|
268 |
+
is_reach_eos = torch.zeros(batch_beam_size).bool().to(device)
|
269 |
+
with torch.no_grad():
|
270 |
+
|
271 |
+
# predicted tokens in cur_len step
|
272 |
+
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
273 |
+
|
274 |
+
# indices which will form the beams in the next time step
|
275 |
+
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
276 |
+
|
277 |
+
for i in range(max_text_length - 1):
|
278 |
+
output_decoder = self.text_decoder(
|
279 |
+
input_ids,
|
280 |
+
encoder_hidden_states=image_tokens.repeat_interleave(num_beams, dim=0)
|
281 |
+
)
|
282 |
+
decoded_token_logits = output_decoder.logits
|
283 |
+
|
284 |
+
for beam_group_idx in range(num_beam_groups):
|
285 |
+
group_start_idx = beam_group_idx * num_sub_beams
|
286 |
+
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
287 |
+
group_size = group_end_idx - group_start_idx
|
288 |
+
|
289 |
+
# indices of beams of current group among all sentences in batch
|
290 |
+
batch_group_indices = []
|
291 |
+
|
292 |
+
for batch_idx in range(batch_size):
|
293 |
+
batch_group_indices.extend(
|
294 |
+
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
295 |
+
)
|
296 |
+
group_input_ids = input_ids[batch_group_indices]
|
297 |
+
|
298 |
+
# select outputs of beams of current group only
|
299 |
+
next_token_logits = decoded_token_logits[batch_group_indices, -1, :]
|
300 |
+
|
301 |
+
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
302 |
+
vocab_size = next_token_scores.shape[-1]
|
303 |
+
|
304 |
+
# supposed to be the line below, but ignore temporarily
|
305 |
+
# next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
306 |
+
next_token_scores_processed = next_token_scores
|
307 |
+
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
308 |
+
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
309 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
310 |
+
# next_token_scores = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p, device=device)
|
311 |
+
|
312 |
+
# reshape for beam search
|
313 |
+
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
314 |
+
|
315 |
+
next_token_scores, next_tokens = torch.topk(
|
316 |
+
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
317 |
+
)
|
318 |
+
|
319 |
+
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
320 |
+
next_tokens = next_tokens % vocab_size
|
321 |
+
|
322 |
+
# stateless
|
323 |
+
beam_outputs = beam_scorer.process(
|
324 |
+
group_input_ids,
|
325 |
+
next_token_scores,
|
326 |
+
next_tokens,
|
327 |
+
next_indices,
|
328 |
+
pad_token_id=tokenizer.pad_token_id,
|
329 |
+
eos_token_id=tokenizer.eos_token_id,
|
330 |
+
beam_indices=None
|
331 |
+
)
|
332 |
+
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
333 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
334 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
335 |
+
|
336 |
+
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
337 |
+
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
338 |
+
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
339 |
+
reordering_indices[batch_group_indices] = (
|
340 |
+
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
|
341 |
+
)
|
342 |
+
|
343 |
+
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
344 |
+
|
345 |
+
is_reach_eos = is_reach_eos | (input_ids[:, -1] == tokenizer.eos_token_id)
|
346 |
+
if beam_scorer.is_done or torch.all(is_reach_eos):
|
347 |
+
break
|
348 |
+
|
349 |
+
sequence_outputs = beam_scorer.finalize(
|
350 |
+
input_ids,
|
351 |
+
beam_scores,
|
352 |
+
next_tokens,
|
353 |
+
next_indices,
|
354 |
+
pad_token_id=tokenizer.pad_token_id,
|
355 |
+
eos_token_id=tokenizer.eos_token_id,
|
356 |
+
max_length=max_text_length,
|
357 |
+
beam_indices=None,
|
358 |
+
)
|
359 |
+
|
360 |
+
sequences = sequence_outputs["sequences"]
|
361 |
+
sequence_scores = sequence_outputs["sequence_scores"]
|
362 |
+
return sequences, sequence_scores
|
363 |
+
|
364 |
+
def _get_logits_warper(
|
365 |
+
self, top_k=None, top_p=None, typical_p=None,
|
366 |
+
temperature=None, num_beams=None, renormalize_logits=None,
|
367 |
+
):
|
368 |
+
top_k = top_k if top_k is not None else 0
|
369 |
+
top_p = top_p if top_p is not None else 1.0
|
370 |
+
typical_p = typical_p if typical_p is not None else 1.
|
371 |
+
temperature = temperature if temperature is not None else 1.
|
372 |
+
warpers = LogitsProcessorList()
|
373 |
+
|
374 |
+
if temperature is not None and temperature != 1.0:
|
375 |
+
warpers.append(TemperatureLogitsWarper(temperature))
|
376 |
+
if top_k is not None and top_k != 0:
|
377 |
+
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
378 |
+
if top_p is not None and top_p < 1.0:
|
379 |
+
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
380 |
+
if typical_p is not None and typical_p < 1.0:
|
381 |
+
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
382 |
+
# `LogitNormalization` should always be the last logit processor, when present
|
383 |
+
if renormalize_logits is True:
|
384 |
+
warpers.append(LogitNormalization())
|
385 |
+
return warpers
|
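VCLM_HF wires the pooled visual tokens into the gated GPT-2 decoder, and generate / beam_sample / group_beam_search implement nucleus sampling and (group) beam search on top of the HuggingFace logits warpers. A hedged decoding sketch, continuing the builder example after models.py; the tokenizer choice is an assumption (any GPT-2 tokenizer exposing bos/eos/pad token ids should work, here the stock HuggingFace one with pad aliased to eos):

# Hedged sketch: sample captions from a dummy clip with VCLM_HF.generate.
# Assumption: the stock GPT-2 tokenizer (pad aliased to eos) supplies the
# bos/eos/pad ids this decoding loop expects.
import torch
from transformers import GPT2Tokenizer
from lavila.models import models

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

model = models.VCLM_OPENAI_TIMESFORMER_BASE_GPT2(gated_xattn=True, num_frames=4)
model.eval()
video = torch.zeros(1, 3, 4, 224, 224)           # (B, C, T, H, W) dummy clip

with torch.no_grad():
    image_tokens = model.encode_image(video)
    token_ids, scores = model.generate(
        image_tokens, tokenizer,
        max_text_length=77, top_p=0.95, temperature=0.7,
        num_return_sequences=5, early_stopping=True,
    )
for ids in token_ids:
    print(tokenizer.decode(ids.tolist(), skip_special_tokens=True))

beam_sample and group_beam_search take the same image_tokens and tokenizer and additionally expose num_beams, num_beam_groups and length_penalty.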
lavila/models/openai_clip.py
ADDED
@@ -0,0 +1,237 @@
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Part of the code is from https://github.com/openai/CLIP/blob/main/clip/clip.py
|
8 |
+
# Modified by Yue Zhao
|
9 |
+
# The original code is under MIT License
|
10 |
+
|
11 |
+
import hashlib
|
12 |
+
import os
|
13 |
+
import urllib
|
14 |
+
import warnings
|
15 |
+
from typing import Union, List
|
16 |
+
from pkg_resources import packaging
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from PIL import Image
|
20 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from .openai_model import build_model
|
24 |
+
from .tokenizer import SimpleTokenizer as _Tokenizer
|
25 |
+
|
26 |
+
try:
|
27 |
+
from torchvision.transforms import InterpolationMode
|
28 |
+
BICUBIC = InterpolationMode.BICUBIC
|
29 |
+
except ImportError:
|
30 |
+
BICUBIC = Image.BICUBIC
|
31 |
+
|
32 |
+
|
33 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
34 |
+
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
35 |
+
|
36 |
+
|
37 |
+
__all__ = ["available_models", "load", "tokenize"]
|
38 |
+
_tokenizer = _Tokenizer()
|
39 |
+
|
40 |
+
_MODELS = {
|
41 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
42 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
43 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
44 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
45 |
+
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
46 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
47 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
48 |
+
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
49 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
50 |
+
}
|
51 |
+
|
52 |
+
|
53 |
+
def _download(url: str, root: str):
|
54 |
+
os.makedirs(root, exist_ok=True)
|
55 |
+
filename = os.path.basename(url)
|
56 |
+
|
57 |
+
expected_sha256 = url.split("/")[-2]
|
58 |
+
download_target = os.path.join(root, filename)
|
59 |
+
|
60 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
61 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
62 |
+
|
63 |
+
if os.path.isfile(download_target):
|
64 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
65 |
+
return download_target
|
66 |
+
else:
|
67 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
68 |
+
|
69 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
70 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
71 |
+
while True:
|
72 |
+
buffer = source.read(8192)
|
73 |
+
if not buffer:
|
74 |
+
break
|
75 |
+
|
76 |
+
output.write(buffer)
|
77 |
+
loop.update(len(buffer))
|
78 |
+
|
79 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
80 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
81 |
+
|
82 |
+
return download_target
|
83 |
+
|
84 |
+
|
85 |
+
def _convert_image_to_rgb(image):
|
86 |
+
return image.convert("RGB")
|
87 |
+
|
88 |
+
|
89 |
+
def _transform(n_px):
|
90 |
+
return Compose([
|
91 |
+
Resize(n_px, interpolation=BICUBIC),
|
92 |
+
CenterCrop(n_px),
|
93 |
+
_convert_image_to_rgb,
|
94 |
+
ToTensor(),
|
95 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
96 |
+
])
|
97 |
+
|
98 |
+
|
99 |
+
def available_models() -> List[str]:
|
100 |
+
"""Returns the names of available CLIP models"""
|
101 |
+
return list(_MODELS.keys())
|
102 |
+
|
103 |
+
|
104 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
105 |
+
"""Load a CLIP model
|
106 |
+
Parameters
|
107 |
+
----------
|
108 |
+
name : str
|
109 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
110 |
+
device : Union[str, torch.device]
|
111 |
+
The device to put the loaded model
|
112 |
+
jit : bool
|
113 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
114 |
+
download_root: str
|
115 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
116 |
+
Returns
|
117 |
+
-------
|
118 |
+
model : torch.nn.Module
|
119 |
+
The CLIP model
|
120 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
121 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
122 |
+
"""
|
123 |
+
if name in _MODELS:
|
124 |
+
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
125 |
+
elif os.path.isfile(name):
|
126 |
+
model_path = name
|
127 |
+
else:
|
128 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
129 |
+
|
130 |
+
with open(model_path, 'rb') as opened_file:
|
131 |
+
try:
|
132 |
+
# loading JIT archive
|
133 |
+
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
134 |
+
state_dict = None
|
135 |
+
except RuntimeError:
|
136 |
+
# loading saved state dict
|
137 |
+
if jit:
|
138 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
139 |
+
jit = False
|
140 |
+
state_dict = torch.load(opened_file, map_location="cpu")
|
141 |
+
|
142 |
+
if not jit:
|
143 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
144 |
+
if str(device) == "cpu":
|
145 |
+
model.float()
|
146 |
+
return model, _transform(model.visual.input_resolution)
|
147 |
+
|
148 |
+
# patch the device names
|
149 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
150 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
151 |
+
|
152 |
+
def patch_device(module):
|
153 |
+
try:
|
154 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
155 |
+
except RuntimeError:
|
156 |
+
graphs = []
|
157 |
+
|
158 |
+
if hasattr(module, "forward1"):
|
159 |
+
graphs.append(module.forward1.graph)
|
160 |
+
|
161 |
+
for graph in graphs:
|
162 |
+
for node in graph.findAllNodes("prim::Constant"):
|
163 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
164 |
+
node.copyAttributes(device_node)
|
165 |
+
|
166 |
+
model.apply(patch_device)
|
167 |
+
patch_device(model.encode_image)
|
168 |
+
patch_device(model.encode_text)
|
169 |
+
|
170 |
+
# patch dtype to float32 on CPU
|
171 |
+
if str(device) == "cpu":
|
172 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
173 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
174 |
+
float_node = float_input.node()
|
175 |
+
|
176 |
+
def patch_float(module):
|
177 |
+
try:
|
178 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
179 |
+
except RuntimeError:
|
180 |
+
graphs = []
|
181 |
+
|
182 |
+
if hasattr(module, "forward1"):
|
183 |
+
graphs.append(module.forward1.graph)
|
184 |
+
|
185 |
+
for graph in graphs:
|
186 |
+
for node in graph.findAllNodes("aten::to"):
|
187 |
+
inputs = list(node.inputs())
|
188 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
189 |
+
if inputs[i].node()["value"] == 5:
|
190 |
+
inputs[i].node().copyAttributes(float_node)
|
191 |
+
|
192 |
+
model.apply(patch_float)
|
193 |
+
patch_float(model.encode_image)
|
194 |
+
patch_float(model.encode_text)
|
195 |
+
|
196 |
+
model.float()
|
197 |
+
|
198 |
+
return model, _transform(model.input_resolution.item())
|
199 |
+
|
200 |
+
|
201 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
|
202 |
+
"""
|
203 |
+
Returns the tokenized representation of given input string(s)
|
204 |
+
Parameters
|
205 |
+
----------
|
206 |
+
texts : Union[str, List[str]]
|
207 |
+
An input string or a list of input strings to tokenize
|
208 |
+
context_length : int
|
209 |
+
The context length to use; all CLIP models use 77 as the context length
|
210 |
+
truncate: bool
|
211 |
+
Whether to truncate the text in case its encoding is longer than the context length
|
212 |
+
Returns
|
213 |
+
-------
|
214 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
|
215 |
+
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
|
216 |
+
"""
|
217 |
+
if isinstance(texts, str):
|
218 |
+
texts = [texts]
|
219 |
+
|
220 |
+
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
221 |
+
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
222 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
223 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
|
224 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
225 |
+
else:
|
226 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
|
227 |
+
|
228 |
+
for i, tokens in enumerate(all_tokens):
|
229 |
+
if len(tokens) > context_length:
|
230 |
+
if truncate:
|
231 |
+
tokens = tokens[:context_length]
|
232 |
+
tokens[-1] = eot_token
|
233 |
+
else:
|
234 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
235 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
236 |
+
|
237 |
+
return result
|
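openai_clip.py keeps the original CLIP loading and tokenization utilities so the backbones above can be initialized from OpenAI checkpoints. A short zero-shot scoring sketch using load and tokenize as defined above; the caption strings and the placeholder image are illustrative only, and encode_text lives on the CLIP model returned by load:

# Hedged sketch: standard CLIP zero-shot scoring with the loader above.
# Downloads the ViT-B/16 checkpoint on first use.
import torch
from PIL import Image
from lavila.models.openai_clip import load, tokenize

model, preprocess = load("ViT-B/16", device="cpu")
image = preprocess(Image.new("RGB", (224, 224))).unsqueeze(0)   # placeholder image
text = tokenize(["a chef kneading dough", "a person slicing bread"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print(probs)   # relative match of each caption to the image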
lavila/models/openai_model.py
ADDED
@@ -0,0 +1,485 @@
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Part of the code is from https://github.com/openai/CLIP/blob/main/clip/model.py
|
8 |
+
# Modified by Yue Zhao
|
9 |
+
# The original code is under MIT License
|
10 |
+
|
11 |
+
from collections import OrderedDict
|
12 |
+
from typing import Tuple, Union
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torch.utils.checkpoint as checkpoint
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
|
21 |
+
class Bottleneck(nn.Module):
|
22 |
+
expansion = 4
|
23 |
+
|
24 |
+
def __init__(self, inplanes, planes, stride=1):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
28 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
29 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
30 |
+
self.relu1 = nn.ReLU(inplace=True)
|
31 |
+
|
32 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
33 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
34 |
+
self.relu2 = nn.ReLU(inplace=True)
|
35 |
+
|
36 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
37 |
+
|
38 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
39 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
40 |
+
self.relu3 = nn.ReLU(inplace=True)
|
41 |
+
|
42 |
+
self.downsample = None
|
43 |
+
self.stride = stride
|
44 |
+
|
45 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
46 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
47 |
+
self.downsample = nn.Sequential(OrderedDict([
|
48 |
+
("-1", nn.AvgPool2d(stride)),
|
49 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
50 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
51 |
+
]))
|
52 |
+
|
53 |
+
def forward(self, x: torch.Tensor):
|
54 |
+
identity = x
|
55 |
+
|
56 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
57 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
58 |
+
out = self.avgpool(out)
|
59 |
+
out = self.bn3(self.conv3(out))
|
60 |
+
|
61 |
+
if self.downsample is not None:
|
62 |
+
identity = self.downsample(x)
|
63 |
+
|
64 |
+
out += identity
|
65 |
+
out = self.relu3(out)
|
66 |
+
return out
|
67 |
+
|
68 |
+
|
69 |
+
class AttentionPool2d(nn.Module):
|
70 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
71 |
+
super().__init__()
|
72 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
73 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
74 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
75 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
76 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
77 |
+
self.num_heads = num_heads
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
81 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
82 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
83 |
+
x, _ = F.multi_head_attention_forward(
|
84 |
+
query=x[:1], key=x, value=x,
|
85 |
+
embed_dim_to_check=x.shape[-1],
|
86 |
+
num_heads=self.num_heads,
|
87 |
+
q_proj_weight=self.q_proj.weight,
|
88 |
+
k_proj_weight=self.k_proj.weight,
|
89 |
+
v_proj_weight=self.v_proj.weight,
|
90 |
+
in_proj_weight=None,
|
91 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
92 |
+
bias_k=None,
|
93 |
+
bias_v=None,
|
94 |
+
add_zero_attn=False,
|
95 |
+
dropout_p=0,
|
96 |
+
out_proj_weight=self.c_proj.weight,
|
97 |
+
out_proj_bias=self.c_proj.bias,
|
98 |
+
use_separate_proj_weight=True,
|
99 |
+
training=self.training,
|
100 |
+
need_weights=False
|
101 |
+
)
|
102 |
+
return x.squeeze(0)
|
103 |
+
|
104 |
+
|
105 |
+
class ModifiedResNet(nn.Module):
|
106 |
+
"""
|
107 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
108 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
109 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
110 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
111 |
+
"""
|
112 |
+
|
113 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
114 |
+
super().__init__()
|
115 |
+
self.output_dim = output_dim
|
116 |
+
self.input_resolution = input_resolution
|
117 |
+
|
118 |
+
# the 3-layer stem
|
119 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
120 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
121 |
+
self.relu1 = nn.ReLU(inplace=True)
|
122 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
123 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
124 |
+
self.relu2 = nn.ReLU(inplace=True)
|
125 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
126 |
+
self.bn3 = nn.BatchNorm2d(width)
|
127 |
+
self.relu3 = nn.ReLU(inplace=True)
|
128 |
+
self.avgpool = nn.AvgPool2d(2)
|
129 |
+
|
130 |
+
# residual layers
|
131 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
132 |
+
self.layer1 = self._make_layer(width, layers[0])
|
133 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
134 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
135 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
136 |
+
|
137 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
138 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
139 |
+
|
140 |
+
def _make_layer(self, planes, blocks, stride=1):
|
141 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
142 |
+
|
143 |
+
self._inplanes = planes * Bottleneck.expansion
|
144 |
+
for _ in range(1, blocks):
|
145 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
146 |
+
|
147 |
+
return nn.Sequential(*layers)
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
def stem(x):
|
151 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
152 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
153 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
154 |
+
x = self.avgpool(x)
|
155 |
+
return x
|
156 |
+
|
157 |
+
x = x.type(self.conv1.weight.dtype)
|
158 |
+
x = stem(x)
|
159 |
+
x = self.layer1(x)
|
160 |
+
x = self.layer2(x)
|
161 |
+
x = self.layer3(x)
|
162 |
+
x = self.layer4(x)
|
163 |
+
x = self.attnpool(x)
|
164 |
+
|
165 |
+
return x
|
166 |
+
|
167 |
+
|
168 |
+
class LayerNorm(nn.LayerNorm):
|
169 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
170 |
+
|
171 |
+
def forward(self, x: torch.Tensor):
|
172 |
+
orig_type = x.dtype
|
173 |
+
ret = super().forward(x.type(torch.float32))
|
174 |
+
return ret.type(orig_type)
|
175 |
+
|
176 |
+
|
177 |
+
class QuickGELU(nn.Module):
|
178 |
+
def forward(self, x: torch.Tensor):
|
179 |
+
return x * torch.sigmoid(1.702 * x)
|
180 |
+
|
181 |
+
|
182 |
+
class ResidualAttentionBlock(nn.Module):
|
183 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
184 |
+
super().__init__()
|
185 |
+
|
186 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
187 |
+
self.ln_1 = nn.LayerNorm(d_model) # used to be `models.transformer.LayerNorm`
|
188 |
+
self.mlp = nn.Sequential(OrderedDict([
|
189 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
190 |
+
("gelu", QuickGELU()),
|
191 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
192 |
+
]))
|
193 |
+
self.ln_2 = nn.LayerNorm(d_model) # used to be `models.transformer.LayerNorm`
|
194 |
+
self.attn_mask = attn_mask
|
195 |
+
|
196 |
+
def attention(self, x: torch.Tensor):
|
197 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
198 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
199 |
+
|
200 |
+
def forward_part1(self, x):
|
201 |
+
return self.attention(self.ln_1(x))
|
202 |
+
|
203 |
+
def forward_part2(self, x):
|
204 |
+
return self.mlp(self.ln_2(x))
|
205 |
+
|
206 |
+
def forward(self, x: torch.Tensor, use_checkpoint=False):
|
207 |
+
if use_checkpoint:
|
208 |
+
x = x + checkpoint.checkpoint(self.forward_part1, x)
|
209 |
+
else:
|
210 |
+
x = x + self.forward_part1(x)
|
211 |
+
|
212 |
+
if use_checkpoint:
|
213 |
+
x = x + checkpoint.checkpoint(self.forward_part2, x)
|
214 |
+
else:
|
215 |
+
x = x + self.forward_part2(x)
|
216 |
+
return x
|
217 |
+
|
218 |
+
|
219 |
+
class Transformer(nn.Module):
|
220 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
221 |
+
super().__init__()
|
222 |
+
self.width = width
|
223 |
+
self.layers = layers
|
224 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
225 |
+
|
226 |
+
def forward(self, x: torch.Tensor, use_checkpoint=False):
|
227 |
+
if use_checkpoint:
|
228 |
+
for i in range(self.layers):
|
229 |
+
x = checkpoint.checkpoint(self.resblocks[i], x)
|
230 |
+
return x
|
231 |
+
else:
|
232 |
+
return self.resblocks(x)
|
233 |
+
|
234 |
+
|
235 |
+
class VisionTransformer(nn.Module):
|
236 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
237 |
+
super().__init__()
|
238 |
+
self.input_resolution = input_resolution
|
239 |
+
self.output_dim = output_dim
|
240 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
241 |
+
|
242 |
+
scale = width ** -0.5
|
243 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
244 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
245 |
+
self.ln_pre = LayerNorm(width)
|
246 |
+
|
247 |
+
self.transformer = Transformer(width, layers, heads)
|
248 |
+
|
249 |
+
self.ln_post = LayerNorm(width)
|
250 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
251 |
+
|
252 |
+
def forward(self, x: torch.Tensor, apply_project=True, use_checkpoint=False, cls_at_last=True):
|
253 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
254 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
255 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
256 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
257 |
+
x = x + self.positional_embedding.to(x.dtype)
|
258 |
+
x = self.ln_pre(x)
|
259 |
+
|
260 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
261 |
+
x = self.transformer(x, use_checkpoint=use_checkpoint)
|
262 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
263 |
+
|
264 |
+
if cls_at_last:
|
265 |
+
x = self.ln_post(x[:, 0, :])
|
266 |
+
|
267 |
+
if self.proj is not None and apply_project:
|
268 |
+
x = x @ self.proj
|
269 |
+
|
270 |
+
return x
|
271 |
+
else:
|
272 |
+
return x[:, 1:, :]
|
273 |
+
|
274 |
+
|
275 |
+
class CLIP(nn.Module):
|
276 |
+
def __init__(self,
|
277 |
+
embed_dim: int,
|
278 |
+
# vision
|
279 |
+
image_resolution: int,
|
280 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
281 |
+
vision_width: int,
|
282 |
+
vision_patch_size: int,
|
283 |
+
# text
|
284 |
+
context_length: int,
|
285 |
+
vocab_size: int,
|
286 |
+
transformer_width: int,
|
287 |
+
transformer_heads: int,
|
288 |
+
transformer_layers: int
|
289 |
+
):
|
290 |
+
super().__init__()
|
291 |
+
|
292 |
+
self.context_length = context_length
|
293 |
+
|
294 |
+
if isinstance(vision_layers, (tuple, list)):
|
295 |
+
vision_heads = vision_width * 32 // 64
|
296 |
+
self.visual = ModifiedResNet(
|
297 |
+
layers=vision_layers,
|
298 |
+
output_dim=embed_dim,
|
299 |
+
heads=vision_heads,
|
300 |
+
input_resolution=image_resolution,
|
301 |
+
width=vision_width
|
302 |
+
)
|
303 |
+
else:
|
304 |
+
vision_heads = vision_width // 64
|
305 |
+
self.visual = VisionTransformer(
|
306 |
+
input_resolution=image_resolution,
|
307 |
+
patch_size=vision_patch_size,
|
308 |
+
width=vision_width,
|
309 |
+
layers=vision_layers,
|
310 |
+
heads=vision_heads,
|
311 |
+
output_dim=embed_dim
|
312 |
+
)
|
313 |
+
|
314 |
+
self.transformer = Transformer(
|
315 |
+
width=transformer_width,
|
316 |
+
layers=transformer_layers,
|
317 |
+
heads=transformer_heads,
|
318 |
+
attn_mask=self.build_attention_mask()
|
319 |
+
)
|
320 |
+
|
321 |
+
self.vocab_size = vocab_size
|
322 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
323 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
324 |
+
self.ln_final = LayerNorm(transformer_width)
|
325 |
+
|
326 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
327 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
328 |
+
|
329 |
+
self.initialize_parameters()
|
330 |
+
|
331 |
+
def initialize_parameters(self):
|
332 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
333 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
334 |
+
|
335 |
+
if isinstance(self.visual, ModifiedResNet):
|
336 |
+
if self.visual.attnpool is not None:
|
337 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
338 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
339 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
340 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
341 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
342 |
+
|
343 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
344 |
+
for name, param in resnet_block.named_parameters():
|
345 |
+
if name.endswith("bn3.weight"):
|
346 |
+
nn.init.zeros_(param)
|
347 |
+
|
348 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
349 |
+
attn_std = self.transformer.width ** -0.5
|
350 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
351 |
+
for block in self.transformer.resblocks:
|
352 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
353 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
354 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
355 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
356 |
+
|
357 |
+
if self.text_projection is not None:
|
358 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
359 |
+
|
360 |
+
def build_attention_mask(self):
|
361 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
362 |
+
# pytorch uses additive attention mask; fill with -inf
|
363 |
+
mask = torch.empty(self.context_length, self.context_length)
|
364 |
+
mask.fill_(float("-inf"))
|
365 |
+
mask.triu_(1) # zero out the lower diagonal
|
366 |
+
return mask
|
367 |
+
|
368 |
+
@property
|
369 |
+
def dtype(self):
|
370 |
+
return self.visual.conv1.weight.dtype
|
371 |
+
|
372 |
+
def encode_image(self, image, apply_project=True, use_checkpoint=False):
|
373 |
+
if image.ndim == 4:
|
374 |
+
return self.visual(image.type(self.dtype))
|
375 |
+
else:
|
376 |
+
image = image.permute(0, 2, 1, 3, 4) # BCTHW -> BTCHW
|
377 |
+
bb, tt, _, _, _ = image.shape
|
378 |
+
x = self.visual(image.reshape(-1, *image.shape[2:]), apply_project=apply_project, use_checkpoint=use_checkpoint) # ND
|
379 |
+
x = x.view(bb, tt, -1)
|
380 |
+
image_features = x.mean(1)
|
381 |
+
# image_features = x.max(1).values
|
382 |
+
return image_features
|
383 |
+
|
384 |
+
def encode_text(self, text, use_checkpoint=False):
|
385 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
386 |
+
|
387 |
+
x = x + self.positional_embedding.type(self.dtype)
|
388 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
389 |
+
x = self.transformer(x, use_checkpoint=use_checkpoint)
|
390 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
391 |
+
x = self.ln_final(x).type(self.dtype)
|
392 |
+
|
393 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
394 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
395 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
396 |
+
|
397 |
+
return x
|
398 |
+
|
399 |
+
def forward(self, image, text, use_checkpoint=False, norm_embed=True):
|
400 |
+
image_features = self.encode_image(image, use_checkpoint=use_checkpoint)
|
401 |
+
text_features = self.encode_text(text, use_checkpoint=use_checkpoint)
|
402 |
+
|
403 |
+
# normalized features
|
404 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
405 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
406 |
+
|
407 |
+
# # cosine similarity as logits
|
408 |
+
# logit_scale = self.logit_scale.exp()
|
409 |
+
# logits_per_image = logit_scale * image_features @ text_features.t()
|
410 |
+
# logits_per_text = logits_per_image.t()
|
411 |
+
|
412 |
+
# # shape = [global_batch_size, global_batch_size]
|
413 |
+
# return logits_per_image, logits_per_text
|
414 |
+
|
415 |
+
return {'image_embed': image_features,
|
416 |
+
'text_embed': text_features,
|
417 |
+
'logit_scale': self.logit_scale.exp()}
|
418 |
+
|
419 |
+
|
420 |
+
def convert_weights(model: nn.Module):
|
421 |
+
"""Convert applicable model parameters to fp16"""
|
422 |
+
|
423 |
+
def _convert_weights_to_fp16(l):
|
424 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
425 |
+
l.weight.data = l.weight.data.half()
|
426 |
+
if l.bias is not None:
|
427 |
+
l.bias.data = l.bias.data.half()
|
428 |
+
|
429 |
+
if isinstance(l, nn.MultiheadAttention):
|
430 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
431 |
+
tensor = getattr(l, attr)
|
432 |
+
if tensor is not None:
|
433 |
+
tensor.data = tensor.data.half()
|
434 |
+
|
435 |
+
for name in ["text_projection", "proj"]:
|
436 |
+
if hasattr(l, name):
|
437 |
+
attr = getattr(l, name)
|
438 |
+
if attr is not None:
|
439 |
+
attr.data = attr.data.half()
|
440 |
+
|
441 |
+
model.apply(_convert_weights_to_fp16)
|
442 |
+
|
443 |
+
|
444 |
+
def build_model(state_dict: dict):
|
445 |
+
vit = "visual.proj" in state_dict
|
446 |
+
|
447 |
+
if vit:
|
448 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
449 |
+
vision_layers = len(
|
450 |
+
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]
|
451 |
+
)
|
452 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
453 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
454 |
+
image_resolution = vision_patch_size * grid_size
|
455 |
+
else:
|
456 |
+
counts: list = [
|
457 |
+
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]
|
458 |
+
]
|
459 |
+
vision_layers = tuple(counts)
|
460 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
461 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
462 |
+
vision_patch_size = None
|
463 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
464 |
+
image_resolution = output_width * 32
|
465 |
+
|
466 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
467 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
468 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
469 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
470 |
+
transformer_heads = transformer_width // 64
|
471 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
472 |
+
|
473 |
+
model = CLIP(
|
474 |
+
embed_dim,
|
475 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
476 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
477 |
+
)
|
478 |
+
|
479 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
480 |
+
if key in state_dict:
|
481 |
+
del state_dict[key]
|
482 |
+
|
483 |
+
convert_weights(model)
|
484 |
+
model.load_state_dict(state_dict)
|
485 |
+
return model.eval()
|
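openai_model.py closes with convert_weights() and build_model(). The following hedged sketch shows how they are typically used together: rebuild CLIP from the state_dict of a jit-scripted OpenAI checkpoint (the path below is a placeholder) and score a short clip against a caption with the SimpleTokenizer added later in this upload.

```python
# Sketch only: rebuild CLIP from an OpenAI-style state_dict and embed a clip + caption.
# "clip_vit_b16.pt" is a placeholder for a jit-scripted OpenAI CLIP checkpoint.
import torch
from lavila.models.openai_model import build_model
from lavila.models.tokenizer import SimpleTokenizer

state_dict = torch.jit.load("clip_vit_b16.pt", map_location="cpu").state_dict()
model = build_model(state_dict).float()   # build_model converts to fp16; cast back for CPU use

tokenizer = SimpleTokenizer()
frames = torch.randn(1, 3, 4, 224, 224)   # B C T H W; per-frame embeddings are mean-pooled
text = tokenizer(["#C C kneads the dough"]).unsqueeze(0)
out = model(frames, text)
sim = out["logit_scale"] * out["image_embed"] @ out["text_embed"].t()
print(sim)
```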
lavila/models/timesformer.py
ADDED
@@ -0,0 +1,390 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Part of the code is from https://github.com/m-bain/frozen-in-time/blob/main/model/video_transformer.py
|
8 |
+
# Modified by Yue Zhao
|
9 |
+
# The original code is under MIT License
|
10 |
+
|
11 |
+
"""
|
12 |
+
Implementations of Video Transformers in PyTorch
|
13 |
+
A PyTorch implementation of space-time transformer as described in
|
14 |
+
'Frozen in Time: A Joint Image and Video Encoder for End-to-End Retrieval' - https://arxiv.org/abs/2104.00650
|
15 |
+
A PyTorch implementation of timesformer as described in
|
16 |
+
'Is Space-Time Attention All You Need for Video Understanding?' - https://arxiv.org/abs/2102.05095
|
17 |
+
Acknowledgments:
|
18 |
+
- This code builds on Ross Wightman's vision_transformer code in pytorch-image-models:
|
19 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
20 |
+
- It is also inspired by lucidrains timesformer implementation:
|
21 |
+
https://github.com/lucidrains/TimeSformer-pytorch
|
22 |
+
Hacked together by Max Bain
|
23 |
+
"""
|
24 |
+
|
25 |
+
from collections import OrderedDict
|
26 |
+
from functools import partial
|
27 |
+
|
28 |
+
import torch
|
29 |
+
import torch.utils.checkpoint as checkpoint
|
30 |
+
from einops import rearrange, repeat
|
31 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
32 |
+
from torch import einsum, nn
|
33 |
+
|
34 |
+
|
35 |
+
def attn(q, k, v):
|
36 |
+
sim = einsum('b i d, b j d -> b i j', q, k)
|
37 |
+
attn = sim.softmax(dim=-1)
|
38 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
39 |
+
return out
|
40 |
+
|
41 |
+
|
42 |
+
class Mlp(nn.Module):
|
43 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
44 |
+
super().__init__()
|
45 |
+
out_features = out_features or in_features
|
46 |
+
hidden_features = hidden_features or in_features
|
47 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
48 |
+
self.act = act_layer()
|
49 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
50 |
+
self.drop = nn.Dropout(drop)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x = self.fc1(x)
|
54 |
+
x = self.act(x)
|
55 |
+
x = self.drop(x)
|
56 |
+
x = self.fc2(x)
|
57 |
+
x = self.drop(x)
|
58 |
+
return x
|
59 |
+
|
60 |
+
|
61 |
+
class VideoPatchEmbed(nn.Module):
|
62 |
+
""" Video to Patch Embedding
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
|
66 |
+
num_frames=8, ln_pre=False):
|
67 |
+
super().__init__()
|
68 |
+
img_size = to_2tuple(img_size)
|
69 |
+
patch_size = to_2tuple(patch_size)
|
70 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * num_frames
|
71 |
+
self.img_size = img_size
|
72 |
+
self.patch_size = patch_size
|
73 |
+
self.num_patches = num_patches
|
74 |
+
self.num_frames = num_frames
|
75 |
+
self.embed_dim = embed_dim
|
76 |
+
# ln_pre is inserted to be compatible with CLIP-style model
|
77 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=not ln_pre)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
B, F, C, H, W = x.shape
|
81 |
+
assert F <= self.num_frames
|
82 |
+
x = x.view(-1, C, H, W)
|
83 |
+
x = self.proj(x)
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
class VarAttention(nn.Module):
|
88 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
|
89 |
+
initialize='random'):
|
90 |
+
super().__init__()
|
91 |
+
self.num_heads = num_heads
|
92 |
+
head_dim = dim // num_heads
|
93 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
94 |
+
self.scale = qk_scale or head_dim ** -0.5
|
95 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
96 |
+
self.proj = nn.Linear(dim, dim)
|
97 |
+
if initialize == 'zeros':
|
98 |
+
self.qkv.weight.data.fill_(0)
|
99 |
+
self.qkv.bias.data.fill_(0)
|
100 |
+
# fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs
|
101 |
+
# are multiplied by 0*0, which is hard for the model to move out of.
|
102 |
+
self.proj.weight.data.fill_(1)
|
103 |
+
self.proj.bias.data.fill_(0)
|
104 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
105 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
106 |
+
|
107 |
+
def forward(self, x, einops_from, einops_to, einops_dims):
|
108 |
+
h = self.num_heads
|
109 |
+
# project x to q, k, v values
|
110 |
+
q, k, v = self.qkv(x).chunk(3, dim=-1)
|
111 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
112 |
+
|
113 |
+
q *= self.scale
|
114 |
+
|
115 |
+
# splice out CLS token at index 1
|
116 |
+
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
|
117 |
+
|
118 |
+
# let CLS token attend to key / values of all patches across time and space
|
119 |
+
cls_out = attn(cls_q, k, v)
|
120 |
+
# rearrange across time or space
|
121 |
+
q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))
|
122 |
+
|
123 |
+
# expand cls token keys and values across time or space and concat
|
124 |
+
r = q_.shape[0] // cls_k.shape[0]
|
125 |
+
cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))
|
126 |
+
|
127 |
+
k_ = torch.cat((cls_k, k_), dim=1)
|
128 |
+
v_ = torch.cat((cls_v, v_), dim=1)
|
129 |
+
|
130 |
+
# attention
|
131 |
+
out = attn(q_, k_, v_)
|
132 |
+
|
133 |
+
# merge back time or space
|
134 |
+
out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
|
135 |
+
|
136 |
+
# concat back the cls token
|
137 |
+
out = torch.cat((cls_out, out), dim=1)
|
138 |
+
|
139 |
+
# merge back the heads
|
140 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
141 |
+
# to out
|
142 |
+
x = self.proj(out)
|
143 |
+
x = self.proj_drop(x)
|
144 |
+
return x
|
145 |
+
|
146 |
+
|
147 |
+
class SpaceTimeBlock(nn.Module):
|
148 |
+
|
149 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
150 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, time_init='zeros',
|
151 |
+
attention_style='frozen-in-time', is_tanh_gating=False):
|
152 |
+
super().__init__()
|
153 |
+
self.norm1 = norm_layer(dim)
|
154 |
+
self.attn = VarAttention(
|
155 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
156 |
+
|
157 |
+
self.timeattn = VarAttention(
|
158 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
|
159 |
+
initialize=time_init)
|
160 |
+
|
161 |
+
if is_tanh_gating:
|
162 |
+
self.alpha_timeattn = nn.Parameter(torch.zeros([]))
|
163 |
+
|
164 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
165 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
166 |
+
self.norm2 = norm_layer(dim)
|
167 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
168 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
169 |
+
self.norm3 = norm_layer(dim)
|
170 |
+
|
171 |
+
self.attention_style = attention_style
|
172 |
+
|
173 |
+
def forward(self, x, einops_from_space, einops_to_space, einops_from_time, einops_to_time,
|
174 |
+
time_n, space_f, use_checkpoint=False):
|
175 |
+
if use_checkpoint:
|
176 |
+
time_output = checkpoint.checkpoint(
|
177 |
+
self.timeattn, self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
time_output = self.timeattn(self.norm3(x), einops_from_time, einops_to_time, {"n": time_n})
|
181 |
+
if hasattr(self, "alpha_timeattn"):
|
182 |
+
time_output = torch.tanh(self.alpha_timeattn) * time_output
|
183 |
+
time_residual = x + time_output
|
184 |
+
if use_checkpoint:
|
185 |
+
space_output = checkpoint.checkpoint(
|
186 |
+
self.attn, self.norm1(time_residual), einops_from_space, einops_to_space, {"f": space_f}
|
187 |
+
)
|
188 |
+
else:
|
189 |
+
space_output = self.attn(self.norm1(time_residual), einops_from_space,
|
190 |
+
einops_to_space, {"f": space_f})
|
191 |
+
if self.attention_style == 'frozen-in-time':
|
192 |
+
space_residual = x + self.drop_path(space_output)
|
193 |
+
else:
|
194 |
+
raise NotImplementedError
|
195 |
+
|
196 |
+
x = space_residual + self.drop_path(self.mlp(self.norm2(space_residual)))
|
197 |
+
|
198 |
+
return x
|
199 |
+
|
200 |
+
|
201 |
+
class SpaceTimeTransformer(nn.Module):
|
202 |
+
""" Vision Transformer
|
203 |
+
A PyTorch impl of : `Space-Time Transformer` from Frozen-in-time - by Max Bain.
|
204 |
+
https://arxiv.org/abs/2104.00650
|
205 |
+
Based off:
|
206 |
+
- ViT implementation from the timm library [https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py]
|
207 |
+
lucidrains timesformer implementation [https://github.com/lucidrains/TimeSformer-pytorch].
|
208 |
+
Notable differences:
|
209 |
+
- allows for variable length input frames (<= num_frames)
|
210 |
+
- allows for variable length input resolution (<= (img_size, img_size)) [UNTESTED]
|
211 |
+
- different attention block mechanism
|
212 |
+
"""
|
213 |
+
|
214 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
215 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
216 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
|
217 |
+
num_frames=8, time_init='rand', attention_style='frozen-in-time', ln_pre=False,
|
218 |
+
act_layer=nn.GELU, is_tanh_gating=False):
|
219 |
+
"""
|
220 |
+
Args:
|
221 |
+
img_size (int, tuple): input image size
|
222 |
+
patch_size (int, tuple): patch size
|
223 |
+
in_chans (int): number of input channels
|
224 |
+
num_classes (int): number of classes for classification head
|
225 |
+
embed_dim (int): embedding dimension
|
226 |
+
depth (int): depth of transformer
|
227 |
+
num_heads (int): number of attention heads
|
228 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
229 |
+
qkv_bias (bool): enable bias for qkv if True
|
230 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
231 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
232 |
+
drop_rate (float): dropout rate
|
233 |
+
attn_drop_rate (float): attention dropout rate
|
234 |
+
drop_path_rate (float): stochastic depth rate
|
235 |
+
hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
|
236 |
+
norm_layer: (nn.Module): normalization layer
|
237 |
+
num_frames: (int) maximum number of frames expected as input
|
238 |
+
time_init: (str) how to initialise the time attention layer, 'zeros' allows for the timesformer to start off
|
239 |
+
as ViT.
|
240 |
+
attention_style: (str) how to attend to space and time.
|
241 |
+
"""
|
242 |
+
super().__init__()
|
243 |
+
self.num_classes = num_classes
|
244 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
245 |
+
self.num_frames = num_frames
|
246 |
+
self.embed_dim = embed_dim
|
247 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
248 |
+
print("######USING ATTENTION STYLE: ", attention_style)
|
249 |
+
if hybrid_backbone is not None:
|
250 |
+
raise NotImplementedError('hybrid backbone not implemented')
|
251 |
+
else:
|
252 |
+
self.patch_embed = VideoPatchEmbed(
|
253 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=num_frames, ln_pre=ln_pre)
|
254 |
+
num_patches = self.patch_embed.num_patches
|
255 |
+
self.patches_per_frame = num_patches // num_frames
|
256 |
+
|
257 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
258 |
+
self.pos_embed = nn.Parameter(
|
259 |
+
torch.zeros(1, self.patches_per_frame + 1,
|
260 |
+
embed_dim)) # remember to take pos_embed[1:] for tiling over time
|
261 |
+
self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
|
262 |
+
|
263 |
+
if ln_pre:
|
264 |
+
self.ln_pre = nn.LayerNorm(embed_dim)
|
265 |
+
else:
|
266 |
+
self.ln_pre = None
|
267 |
+
|
268 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
269 |
+
|
270 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
271 |
+
self.blocks = nn.ModuleList([
|
272 |
+
SpaceTimeBlock(
|
273 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
274 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, time_init=time_init,
|
275 |
+
attention_style=attention_style, act_layer=act_layer, is_tanh_gating=is_tanh_gating)
|
276 |
+
for i in range(depth)])
|
277 |
+
self.norm = norm_layer(embed_dim)
|
278 |
+
|
279 |
+
# Representation layer
|
280 |
+
if representation_size:
|
281 |
+
self.num_features = representation_size
|
282 |
+
self.pre_logits = nn.Sequential(OrderedDict([
|
283 |
+
('fc', nn.Linear(embed_dim, representation_size)),
|
284 |
+
('act', nn.Tanh())
|
285 |
+
]))
|
286 |
+
else:
|
287 |
+
self.pre_logits = nn.Identity()
|
288 |
+
|
289 |
+
# Classifier head
|
290 |
+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
291 |
+
|
292 |
+
trunc_normal_(self.pos_embed, std=.02)
|
293 |
+
trunc_normal_(self.cls_token, std=.02)
|
294 |
+
|
295 |
+
# if num_frames > 1, then we perform ViT inflation and initialise time attention to zero so not necessary.
|
296 |
+
if num_frames == 1:
|
297 |
+
self.apply(self._init_weights)
|
298 |
+
|
299 |
+
# einops transformations
|
300 |
+
self.einops_from_space = 'b (f n) d'
|
301 |
+
self.einops_to_space = '(b f) n d'
|
302 |
+
self.einops_from_time = 'b (f n) d'
|
303 |
+
self.einops_to_time = '(b n) f d'
|
304 |
+
|
305 |
+
def _init_weights(self, m):
|
306 |
+
if isinstance(m, nn.Linear):
|
307 |
+
trunc_normal_(m.weight, std=.02)
|
308 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
309 |
+
nn.init.constant_(m.bias, 0)
|
310 |
+
elif isinstance(m, nn.LayerNorm):
|
311 |
+
nn.init.constant_(m.bias, 0)
|
312 |
+
nn.init.constant_(m.weight, 1.0)
|
313 |
+
|
314 |
+
@torch.jit.ignore
|
315 |
+
def no_weight_decay(self):
|
316 |
+
return {'pos_embed', 'cls_token'}
|
317 |
+
|
318 |
+
def get_classifier(self):
|
319 |
+
return self.head
|
320 |
+
|
321 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
322 |
+
self.num_classes = num_classes
|
323 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
324 |
+
|
325 |
+
def freeze_spatial_weights(self):
|
326 |
+
freeze_list = []
|
327 |
+
for n, p in self.named_parameters():
|
328 |
+
if 'temporal_embed' in n or 'timeattn' in n or 'norm3' in n:
|
329 |
+
pass
|
330 |
+
else:
|
331 |
+
p.requires_grad = False
|
332 |
+
freeze_list.append(n)
|
333 |
+
print("Freeze the pretrained parts in vision model: {}".format(freeze_list))
|
334 |
+
|
335 |
+
def freeze_temporal_weights(self):
|
336 |
+
freeze_list = []
|
337 |
+
for n, p in self.named_parameters():
|
338 |
+
if 'temporal_embed' in n or 'timeattn' in n or 'norm3' in n:
|
339 |
+
p.requires_grad = False
|
340 |
+
freeze_list.append(n)
|
341 |
+
else:
|
342 |
+
pass
|
343 |
+
print("Freeze the pretrained parts in vision model: {}".format(freeze_list))
|
344 |
+
|
345 |
+
def forward_features(self, x, use_checkpoint=False, cls_at_last=True):
|
346 |
+
# print(x.shape)
|
347 |
+
b, curr_frames, channels, _, _ = x.shape
|
348 |
+
x = self.patch_embed(x)
|
349 |
+
x = x.flatten(2).transpose(2, 1)
|
350 |
+
x = x.reshape(b, -1, self.patch_embed.embed_dim)
|
351 |
+
|
352 |
+
BF = x.shape[0]
|
353 |
+
cls_tokens = self.cls_token.expand(BF, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
354 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
355 |
+
# positional embed needs to be tiled for each frame (this does [1,2,3] --> [1,2,3,1,2,3]...)
|
356 |
+
cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
|
357 |
+
tile_pos_embed = self.pos_embed[:, 1:, :].repeat(1, self.num_frames, 1)
|
358 |
+
# temporal embed needs to be repeated within each frame (this does [1,2,3] --> [1,1,1,2,2,2,3,3,3]...)
|
359 |
+
tile_temporal_embed = self.temporal_embed.repeat_interleave(self.patches_per_frame, 1)
|
360 |
+
total_pos_embed = tile_pos_embed + tile_temporal_embed
|
361 |
+
total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
|
362 |
+
|
363 |
+
curr_patches = x.shape[1]
|
364 |
+
x = x + total_pos_embed[:, :curr_patches]
|
365 |
+
if self.ln_pre is not None:
|
366 |
+
x = self.ln_pre(x)
|
367 |
+
x = self.pos_drop(x)
|
368 |
+
n = self.patches_per_frame
|
369 |
+
f = curr_frames
|
370 |
+
|
371 |
+
for blk in self.blocks:
|
372 |
+
x = blk(x, self.einops_from_space, self.einops_to_space, self.einops_from_time,
|
373 |
+
self.einops_to_time,
|
374 |
+
time_n=n, space_f=f, use_checkpoint=use_checkpoint)
|
375 |
+
|
376 |
+
if cls_at_last:
|
377 |
+
x = self.norm(x)[:, 0]
|
378 |
+
x = self.pre_logits(x)
|
379 |
+
|
380 |
+
return x
|
381 |
+
else:
|
382 |
+
return self.norm(x)
|
383 |
+
|
384 |
+
def forward(self, x, use_checkpoint=False):
|
385 |
+
# Note: B C T H W => B T C H W
|
386 |
+
# The default input order is different from the one in Frozen-in-Time
|
387 |
+
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
388 |
+
x = self.forward_features(x, use_checkpoint=use_checkpoint)
|
389 |
+
x = self.head(x)
|
390 |
+
return x
|
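To make the expected tensor layout concrete, here is a minimal hedged sketch of a forward pass (default ViT-B/16-style configuration; all values are illustrative):

```python
# Sketch only: forward pass through SpaceTimeTransformer. Input is B C T H W,
# as noted in forward(); the model permutes it to B T C H W internally.
import torch
from lavila.models.timesformer import SpaceTimeTransformer  # import path assumed

model = SpaceTimeTransformer(num_frames=4, time_init='zeros', num_classes=0)
clip = torch.randn(2, 3, 4, 224, 224)   # 2 clips x 3 channels x 4 frames x 224 x 224
feats = model(clip)
print(feats.shape)                       # torch.Size([2, 768]); head is Identity when num_classes=0
```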
lavila/models/tokenizer.py
ADDED
@@ -0,0 +1,239 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Part of the code is from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
|
8 |
+
# Modified by Yue Zhao
|
9 |
+
# The original code is under MIT License
|
10 |
+
|
11 |
+
import gzip
|
12 |
+
import html
|
13 |
+
import os
|
14 |
+
from functools import lru_cache
|
15 |
+
|
16 |
+
import ftfy
|
17 |
+
import regex as re
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from transformers import (BertTokenizer, DistilBertTokenizer, GPT2Tokenizer)
|
21 |
+
|
22 |
+
|
23 |
+
@lru_cache()
|
24 |
+
def default_bpe():
|
25 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
26 |
+
|
27 |
+
|
28 |
+
@lru_cache()
|
29 |
+
def bytes_to_unicode():
|
30 |
+
"""
|
31 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
32 |
+
The reversible bpe codes work on unicode strings.
|
33 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
34 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
35 |
+
This is a significant percentage of your normal, say, 32K bpe vocab.
|
36 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
37 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
38 |
+
"""
|
39 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
40 |
+
cs = bs[:]
|
41 |
+
n = 0
|
42 |
+
for b in range(2**8):
|
43 |
+
if b not in bs:
|
44 |
+
bs.append(b)
|
45 |
+
cs.append(2**8+n)
|
46 |
+
n += 1
|
47 |
+
cs = [chr(n) for n in cs]
|
48 |
+
return dict(zip(bs, cs))
|
49 |
+
|
50 |
+
|
51 |
+
def get_pairs(word):
|
52 |
+
"""Return set of symbol pairs in a word.
|
53 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
54 |
+
"""
|
55 |
+
pairs = set()
|
56 |
+
prev_char = word[0]
|
57 |
+
for char in word[1:]:
|
58 |
+
pairs.add((prev_char, char))
|
59 |
+
prev_char = char
|
60 |
+
return pairs
|
61 |
+
|
62 |
+
|
63 |
+
def basic_clean(text):
|
64 |
+
text = ftfy.fix_text(text)
|
65 |
+
text = html.unescape(html.unescape(text))
|
66 |
+
return text.strip()
|
67 |
+
|
68 |
+
|
69 |
+
def whitespace_clean(text):
|
70 |
+
text = re.sub(r'\s+', ' ', text)
|
71 |
+
text = text.strip()
|
72 |
+
return text
|
73 |
+
|
74 |
+
|
75 |
+
class SimpleTokenizer(object):
|
76 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
77 |
+
self.byte_encoder = bytes_to_unicode()
|
78 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
79 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
80 |
+
merges = merges[1:49152-256-2+1]
|
81 |
+
merges = [tuple(merge.split()) for merge in merges]
|
82 |
+
vocab = list(bytes_to_unicode().values())
|
83 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
84 |
+
for merge in merges:
|
85 |
+
vocab.append(''.join(merge))
|
86 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
87 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
88 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
89 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
90 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
91 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
92 |
+
|
93 |
+
def bpe(self, token):
|
94 |
+
if token in self.cache:
|
95 |
+
return self.cache[token]
|
96 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
97 |
+
pairs = get_pairs(word)
|
98 |
+
|
99 |
+
if not pairs:
|
100 |
+
return token+'</w>'
|
101 |
+
|
102 |
+
while True:
|
103 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
104 |
+
if bigram not in self.bpe_ranks:
|
105 |
+
break
|
106 |
+
first, second = bigram
|
107 |
+
new_word = []
|
108 |
+
i = 0
|
109 |
+
while i < len(word):
|
110 |
+
try:
|
111 |
+
j = word.index(first, i)
|
112 |
+
new_word.extend(word[i:j])
|
113 |
+
i = j
|
114 |
+
except:
|
115 |
+
new_word.extend(word[i:])
|
116 |
+
break
|
117 |
+
|
118 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
119 |
+
new_word.append(first+second)
|
120 |
+
i += 2
|
121 |
+
else:
|
122 |
+
new_word.append(word[i])
|
123 |
+
i += 1
|
124 |
+
new_word = tuple(new_word)
|
125 |
+
word = new_word
|
126 |
+
if len(word) == 1:
|
127 |
+
break
|
128 |
+
else:
|
129 |
+
pairs = get_pairs(word)
|
130 |
+
word = ' '.join(word)
|
131 |
+
self.cache[token] = word
|
132 |
+
return word
|
133 |
+
|
134 |
+
def encode(self, text):
|
135 |
+
bpe_tokens = []
|
136 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
137 |
+
for token in re.findall(self.pat, text):
|
138 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
139 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
140 |
+
return bpe_tokens
|
141 |
+
|
142 |
+
def decode(self, tokens):
|
143 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
144 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
145 |
+
return text
|
146 |
+
|
147 |
+
def __call__(self, texts, context_length=77):
|
148 |
+
if isinstance(texts, str):
|
149 |
+
texts = [texts]
|
150 |
+
|
151 |
+
sot_token = self.encoder["<|startoftext|>"]
|
152 |
+
eot_token = self.encoder["<|endoftext|>"]
|
153 |
+
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
154 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
155 |
+
|
156 |
+
for i, tokens in enumerate(all_tokens):
|
157 |
+
tokens = tokens[:context_length]
|
158 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
159 |
+
|
160 |
+
if len(result) == 1:
|
161 |
+
return result[0]
|
162 |
+
return result
|
163 |
+
|
164 |
+
|
165 |
+
class MyBertTokenizer(object):
|
166 |
+
def __init__(self, name=''):
|
167 |
+
print('=> Initialize MyBertTokenizer ({})'.format(name))
|
168 |
+
self.tokenizer = BertTokenizer.from_pretrained(name)
|
169 |
+
self.bos_token_id, self.eos_token_id = self.tokenizer('').input_ids
|
170 |
+
self.pad_token_id = 0
|
171 |
+
|
172 |
+
def __call__(self, texts, context_length=77):
|
173 |
+
if isinstance(texts, str):
|
174 |
+
texts = [texts]
|
175 |
+
result = torch.zeros(len(texts), context_length, dtype=torch.long)
|
176 |
+
mask = torch.zeros(len(texts), context_length, dtype=torch.float32)
|
177 |
+
for i, text in enumerate(texts):
|
178 |
+
tokens = self.tokenizer(text)
|
179 |
+
input_ids = tokens.input_ids[:context_length]
|
180 |
+
attention_mask = tokens.attention_mask[:context_length]
|
181 |
+
result[i, :len(input_ids)] = torch.tensor(input_ids)
|
182 |
+
mask[i, :len(attention_mask)] = torch.tensor(attention_mask)
|
183 |
+
|
184 |
+
if len(result) == 1:
|
185 |
+
return result[0], mask[0]
|
186 |
+
return result, mask
|
187 |
+
|
188 |
+
|
189 |
+
class MyDistilBertTokenizer(object):
|
190 |
+
def __init__(self, name=''):
|
191 |
+
print('=> Initialize MyDistilBertTokenizer ({})'.format(name))
|
192 |
+
self.tokenizer = DistilBertTokenizer.from_pretrained(name)
|
193 |
+
|
194 |
+
def __call__(self, texts, context_length=77):
|
195 |
+
if isinstance(texts, str):
|
196 |
+
texts = [texts]
|
197 |
+
result = torch.zeros(len(texts), context_length, dtype=torch.long)
|
198 |
+
mask = torch.zeros(len(texts), context_length, dtype=torch.float32)
|
199 |
+
for i, text in enumerate(texts):
|
200 |
+
tokens = self.tokenizer(text)
|
201 |
+
input_ids = tokens.input_ids[:context_length]
|
202 |
+
attention_mask = tokens.attention_mask[:context_length]
|
203 |
+
result[i, :len(input_ids)] = torch.tensor(input_ids)
|
204 |
+
mask[i, :len(attention_mask)] = torch.tensor(attention_mask)
|
205 |
+
|
206 |
+
if len(result) == 1:
|
207 |
+
return result[0], mask[0]
|
208 |
+
return result, mask
|
209 |
+
|
210 |
+
|
211 |
+
class MyGPT2Tokenizer(object):
|
212 |
+
def __init__(self, name='', add_bos=False):
|
213 |
+
print('=> Initialize MyGPT2Tokenizer ({})'.format(name))
|
214 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained(name)
|
215 |
+
self.bos_token_id, self.eos_token_id = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
|
216 |
+
self.pad_token_id = 0
|
217 |
+
self.add_bos = add_bos
|
218 |
+
# num_added_tokens = self.tokenizer.add_special_tokens({'pad_token': "[PAD]"})
|
219 |
+
# print('num_added_tokens={}'.format(len(num_added_tokens)))
|
220 |
+
|
221 |
+
def __call__(self, texts, context_length=77):
|
222 |
+
if isinstance(texts, str):
|
223 |
+
texts = [texts]
|
224 |
+
result = torch.zeros(len(texts), context_length, dtype=torch.long)
|
225 |
+
for i, text in enumerate(texts):
|
226 |
+
tokens = self.tokenizer(text)
|
227 |
+
if not self.add_bos:
|
228 |
+
input_ids = tokens.input_ids[:context_length - 1]
|
229 |
+
input_ids = input_ids + [self.tokenizer.eos_token_id] # add [EOS]
|
230 |
+
else:
|
231 |
+
input_ids = tokens.input_ids[:context_length - 2]
|
232 |
+
input_ids = [self.tokenizer.bos_token_id] + input_ids + [self.tokenizer.eos_token_id] # add [EOS]
|
233 |
+
# attention_mask = tokens.attention_mask[:context_length]
|
234 |
+
# attention_mask = attention_mask + [0.] * pad_length
|
235 |
+
result[i, :len(input_ids)] = torch.tensor(input_ids)
|
236 |
+
|
237 |
+
if len(result) == 1:
|
238 |
+
return result[0]
|
239 |
+
return result
|
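A short hedged usage sketch for the tokenizers above; the GPT-2 example assumes the stock Hugging Face 'gpt2' checkpoint can be downloaded.

```python
# Sketch only: both tokenizers return fixed-length token tensors (default 77 positions).
import torch
from lavila.models.tokenizer import SimpleTokenizer, MyGPT2Tokenizer

clip_tok = SimpleTokenizer()                      # uses bpe_simple_vocab_16e6.txt.gz next to the module
ids = clip_tok(["#C C opens the fridge", "#O a man cuts a loaf"])
print(ids.shape)                                  # torch.Size([2, 77]), zero-padded after <|endoftext|>

gpt2_tok = MyGPT2Tokenizer('gpt2', add_bos=True)  # 'gpt2' assumes the standard HF checkpoint
single = gpt2_tok("#C C opens the fridge")
print(single.shape)                               # torch.Size([77]) for a single string
```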
lavila/models/utils.py
ADDED
@@ -0,0 +1,108 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from collections import OrderedDict
|
8 |
+
import functools
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
def inflate_positional_embeds(
|
14 |
+
current_model_state_dict, new_state_dict,
|
15 |
+
num_frames=4,
|
16 |
+
load_temporal_fix='bilinear',
|
17 |
+
):
|
18 |
+
# allow loading of timesformer with fewer num_frames
|
19 |
+
curr_keys = list(current_model_state_dict.keys())
|
20 |
+
if 'visual.temporal_embed' in new_state_dict and 'visual.temporal_embed' in curr_keys:
|
21 |
+
load_temporal_embed = new_state_dict['visual.temporal_embed']
|
22 |
+
load_num_frames = load_temporal_embed.shape[1]
|
23 |
+
curr_num_frames = num_frames
|
24 |
+
embed_dim = load_temporal_embed.shape[2]
|
25 |
+
|
26 |
+
if load_num_frames != curr_num_frames:
|
27 |
+
if load_num_frames > curr_num_frames:
|
28 |
+
print(f'### loaded SpaceTimeTransformer model has MORE frames than current...'
|
29 |
+
f'### loading weights, filling in the extras via {load_temporal_fix}')
|
30 |
+
new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :]
|
31 |
+
else:
|
32 |
+
print(f'### loaded SpaceTimeTransformer model has FEWER frames than current...'
|
33 |
+
f'### loading weights, filling in the extras via {load_temporal_fix}')
|
34 |
+
if load_temporal_fix == 'zeros':
|
35 |
+
new_temporal_embed = torch.zeros([load_temporal_embed.shape[0], curr_num_frames, embed_dim])
|
36 |
+
new_temporal_embed[:, :load_num_frames] = load_temporal_embed
|
37 |
+
elif load_temporal_fix in ['interp', 'bilinear']:
|
38 |
+
# interpolate
|
39 |
+
# unsqueeze so pytorch thinks its an image
|
40 |
+
mode = 'nearest'
|
41 |
+
if load_temporal_fix == 'bilinear':
|
42 |
+
mode = 'bilinear'
|
43 |
+
load_temporal_embed = load_temporal_embed.unsqueeze(0)
|
44 |
+
new_temporal_embed = F.interpolate(load_temporal_embed,
|
45 |
+
(curr_num_frames, embed_dim), mode=mode).squeeze(0)
|
46 |
+
else:
|
47 |
+
raise NotImplementedError
|
48 |
+
new_state_dict['visual.temporal_embed'] = new_temporal_embed
|
49 |
+
# allow loading with smaller spatial patches. assumes custom border crop, to append the
|
50 |
+
# border patches to the input sequence
|
51 |
+
if 'visual.pos_embed' in new_state_dict and 'visual.pos_embed' in curr_keys:
|
52 |
+
load_pos_embed = new_state_dict['visual.pos_embed']
|
53 |
+
load_num_patches = load_pos_embed.shape[1]
|
54 |
+
curr_pos_embed = current_model_state_dict['visual.pos_embed']
|
55 |
+
if load_num_patches != curr_pos_embed.shape[1]:
|
56 |
+
raise NotImplementedError(
|
57 |
+
'Loading models with different spatial resolution / patch number not yet implemented, sorry.')
|
58 |
+
|
59 |
+
return new_state_dict
|
60 |
+
|
61 |
+
|
62 |
+
def rsetattr(obj, attr, val):
|
63 |
+
pre, _, post = attr.rpartition('.')
|
64 |
+
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
|
65 |
+
|
66 |
+
|
67 |
+
def rgetattr(obj, attr, *args):
|
68 |
+
def _getattr(obj, attr):
|
69 |
+
return getattr(obj, attr, *args)
|
70 |
+
return functools.reduce(_getattr, [obj] + attr.split('.'))
|
71 |
+
|
72 |
+
|
73 |
+
# util functions to convert CLIP-style model keys to TimeSformer-style
|
74 |
+
def remap_keys(clip_state_dict, transformer_layers=12):
|
75 |
+
remapped_state_dict = OrderedDict()
|
76 |
+
key_mapping = {
|
77 |
+
"class_embedding": "cls_token",
|
78 |
+
"positional_embedding": "pos_embed",
|
79 |
+
"conv1.weight": "patch_embed.proj.weight",
|
80 |
+
"ln_pre.weight": "ln_pre.weight",
|
81 |
+
"ln_pre.bias": "ln_pre.bias",
|
82 |
+
"ln_post.weight": "norm.weight",
|
83 |
+
"ln_post.bias": "norm.bias",
|
84 |
+
}
|
85 |
+
for layer in range(transformer_layers):
|
86 |
+
key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_weight"] = f"blocks.{layer}.attn.qkv.weight"
|
87 |
+
key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_bias"] = f"blocks.{layer}.attn.qkv.bias"
|
88 |
+
key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.weight"] = f"blocks.{layer}.attn.proj.weight"
|
89 |
+
key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.bias"] = f"blocks.{layer}.attn.proj.bias"
|
90 |
+
key_mapping[f"transformer.resblocks.{layer}.ln_1.weight"] = f"blocks.{layer}.norm1.weight"
|
91 |
+
key_mapping[f"transformer.resblocks.{layer}.ln_1.bias"] = f"blocks.{layer}.norm1.bias"
|
92 |
+
key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.weight"] = f"blocks.{layer}.mlp.fc1.weight"
|
93 |
+
key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.bias"] = f"blocks.{layer}.mlp.fc1.bias"
|
94 |
+
key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.weight"] = f"blocks.{layer}.mlp.fc2.weight"
|
95 |
+
key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.bias"] = f"blocks.{layer}.mlp.fc2.bias"
|
96 |
+
key_mapping[f"transformer.resblocks.{layer}.ln_2.weight"] = f"blocks.{layer}.norm2.weight"
|
97 |
+
key_mapping[f"transformer.resblocks.{layer}.ln_2.bias"] = f"blocks.{layer}.norm2.bias"
|
98 |
+
|
99 |
+
for key in clip_state_dict:
|
100 |
+
if key == 'proj':
|
101 |
+
continue # due to possible dim mismatch, we load this later
|
102 |
+
if key == "class_embedding":
|
103 |
+
clip_state_dict[key] = clip_state_dict[key].unsqueeze(0).unsqueeze(0)
|
104 |
+
if key == "positional_embedding":
|
105 |
+
clip_state_dict[key] = clip_state_dict[key].unsqueeze(0)
|
106 |
+
remapped_state_dict[key_mapping[key]] = clip_state_dict[key]
|
107 |
+
|
108 |
+
return remapped_state_dict
|
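inflate_positional_embeds() is easiest to see on a concrete frame-count mismatch. The sketch below assumes two SpaceTimeTransformer instances that differ only in num_frames and interpolates the 4-frame temporal embedding up to 16 frames; remap_keys() plays the analogous role when the source weights come from a CLIP visual encoder.

```python
# Sketch only: adapt a temporal embedding saved for 4 frames to a 16-frame model
# via the 'bilinear' path of inflate_positional_embeds.
import torch
from lavila.models.timesformer import SpaceTimeTransformer
from lavila.models.utils import inflate_positional_embeds

src = SpaceTimeTransformer(num_frames=4, num_classes=0)
dst = SpaceTimeTransformer(num_frames=16, num_classes=0)

# the helper expects CLIP-style 'visual.'-prefixed keys
src_sd = {'visual.' + k: v for k, v in src.state_dict().items()}
dst_sd = {'visual.' + k: v for k, v in dst.state_dict().items()}

adapted = inflate_positional_embeds(dst_sd, src_sd, num_frames=16, load_temporal_fix='bilinear')
print(adapted['visual.temporal_embed'].shape)     # torch.Size([1, 16, 768])
dst.load_state_dict({k[len('visual.'):]: v for k, v in adapted.items()}, strict=False)
```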
lavila/utils/distributed.py
ADDED
@@ -0,0 +1,102 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import shutil
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
|
13 |
+
def get_model(model):
|
14 |
+
if isinstance(model, torch.nn.DataParallel) \
|
15 |
+
or isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
16 |
+
return model.module
|
17 |
+
else:
|
18 |
+
return model
|
19 |
+
|
20 |
+
|
21 |
+
def setup_for_distributed(is_master):
|
22 |
+
"""
|
23 |
+
This function disables printing when not in master process
|
24 |
+
"""
|
25 |
+
import builtins as __builtin__
|
26 |
+
builtin_print = __builtin__.print
|
27 |
+
|
28 |
+
def print(*args, **kwargs):
|
29 |
+
force = kwargs.pop('force', False)
|
30 |
+
if is_master or force:
|
31 |
+
builtin_print(*args, **kwargs)
|
32 |
+
|
33 |
+
__builtin__.print = print
|
34 |
+
|
35 |
+
|
36 |
+
def is_dist_avail_and_initialized():
|
37 |
+
if not dist.is_available():
|
38 |
+
return False
|
39 |
+
if not dist.is_initialized():
|
40 |
+
return False
|
41 |
+
return True
|
42 |
+
|
43 |
+
|
44 |
+
def get_world_size():
|
45 |
+
if not is_dist_avail_and_initialized():
|
46 |
+
return 1
|
47 |
+
else:
|
48 |
+
return dist.get_world_size()
|
49 |
+
|
50 |
+
|
51 |
+
def get_rank():
|
52 |
+
if not is_dist_avail_and_initialized():
|
53 |
+
return 0
|
54 |
+
return dist.get_rank()
|
55 |
+
|
56 |
+
|
57 |
+
def is_main_process():
|
58 |
+
return get_rank() == 0
|
59 |
+
|
60 |
+
|
61 |
+
def save_on_master(state, is_best, output_dir, is_epoch=True):
|
62 |
+
if is_main_process():
|
63 |
+
ckpt_path = f'{output_dir}/checkpoint.pt'
|
64 |
+
best_path = f'{output_dir}/checkpoint_best.pt'
|
65 |
+
if is_best:
|
66 |
+
torch.save(state, best_path)
|
67 |
+
if is_epoch:
|
68 |
+
if isinstance(state['epoch'], int):
|
69 |
+
ckpt2_path = '{}/checkpoint_{:04d}.pt'.format(output_dir, state['epoch'])
|
70 |
+
else:
|
71 |
+
ckpt2_path = '{}/checkpoint_{:.4f}.pt'.format(output_dir, state['epoch'])
|
72 |
+
torch.save(state, ckpt_path)
|
73 |
+
shutil.copy(ckpt_path, ckpt2_path)
|
74 |
+
|
75 |
+
|
76 |
+
def init_distributed_mode(args):
|
77 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
78 |
+
args.rank = int(os.environ["RANK"])
|
79 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
80 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
81 |
+
elif 'SLURM_PROCID' in os.environ:
|
82 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
83 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
84 |
+
else:
|
85 |
+
print('Not using distributed mode')
|
86 |
+
args.distributed = False
|
87 |
+
return
|
88 |
+
|
89 |
+
args.distributed = True
|
90 |
+
|
91 |
+
torch.cuda.set_device(args.gpu)
|
92 |
+
args.dist_backend = 'nccl'
|
93 |
+
print('| distributed init (rank {}): {}'.format(
|
94 |
+
args.rank, args.dist_url), flush=True)
|
95 |
+
torch.distributed.init_process_group(
|
96 |
+
backend=args.dist_backend,
|
97 |
+
init_method=args.dist_url,
|
98 |
+
world_size=args.world_size,
|
99 |
+
rank=args.rank
|
100 |
+
)
|
101 |
+
torch.distributed.barrier()
|
102 |
+
setup_for_distributed(args.rank == 0)
|
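A hedged sketch of how init_distributed_mode() is usually driven from a training script launched with torchrun; it reads RANK / WORLD_SIZE / LOCAL_RANK from the environment, and the script and argument names below are placeholders that match what the function expects.

```python
# Sketch only: wiring init_distributed_mode into a script, launched e.g. as
#   torchrun --nproc_per_node=8 train.py --dist-url env://
import argparse
from lavila.utils import distributed as dist_utils

parser = argparse.ArgumentParser()
parser.add_argument('--dist-url', default='env://', type=str)
args = parser.parse_args()

dist_utils.init_distributed_mode(args)   # sets args.rank / args.world_size / args.gpu, or args.distributed=False
if dist_utils.is_main_process():
    print('world size:', dist_utils.get_world_size())
```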
lavila/utils/evaluation.py
ADDED
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+def get_mean_accuracy(cm):
+    list_acc = []
+    for i in range(len(cm)):
+        acc = 0
+        if cm[i, :].sum() > 0:
+            acc = cm[i, i] / cm[i, :].sum()
+        list_acc.append(acc)
+
+    return 100 * np.mean(list_acc), 100 * np.trace(cm) / np.sum(cm)
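For reference, a small hedged example of the two helpers above on toy inputs:

```python
# Sketch only: top-k accuracy on random logits, plus mean class accuracy from a
# toy confusion matrix, using the two helpers above.
import numpy as np
import torch
from lavila.utils.evaluation import accuracy, get_mean_accuracy

logits = torch.randn(8, 10)                  # 8 samples, 10 classes
target = torch.randint(0, 10, (8,))
top1, top5 = accuracy(logits, target, topk=(1, 5))
print(top1.item(), top5.item())              # percentages over the batch

cm = np.array([[5, 1], [2, 4]])              # toy 2-class confusion matrix
mean_cls_acc, overall_acc = get_mean_accuracy(cm)
print(mean_cls_acc, overall_acc)
```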
lavila/utils/evaluation_charades.py
ADDED
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+
+
+def compute_map(submission_array, gt_array):
+    """ Returns mAP, weighted mAP, and AP array """
+    m_aps = []
+    n_classes = submission_array.shape[1]
+    for oc_i in range(n_classes):
+        sorted_idxs = np.argsort(-submission_array[:, oc_i])
+        tp = gt_array[:, oc_i][sorted_idxs] == 1
+        fp = np.invert(tp)
+        n_pos = tp.sum()
+        if n_pos < 0.1:
+            m_aps.append(float('nan'))
+            continue
+        fp.sum()
+        f_pcs = np.cumsum(fp)
+        t_pcs = np.cumsum(tp)
+        prec = t_pcs / (f_pcs+t_pcs).astype(float)
+        avg_prec = 0
+        for i in range(submission_array.shape[0]):
+            if tp[i]:
+                avg_prec += prec[i]
+        m_aps.append(avg_prec / n_pos.astype(float))
+    m_aps = np.array(m_aps)
+    m_ap = np.mean(m_aps)
+    w_ap = (m_aps * gt_array.sum(axis=0) / gt_array.sum().sum().astype(float))
+    return m_ap, w_ap, m_aps
+
+
+def charades_map(submission_array, gt_array):
+    """
+    Approximate version of the charades evaluation function
+    For precise numbers, use the submission file with the official matlab script
+    """
+    fix = submission_array.copy()
+    empty = np.sum(gt_array, axis=1) == 0
+    fix[empty, :] = np.NINF
+    return compute_map(fix, gt_array)
+
+
+def create_submission(video_list, predictions, out_file):
+    assert len(video_list) == predictions.shape[0]
+    with open(out_file, 'w') as f:
+        for i, video_id in enumerate(video_list):
+            pred_str = ' '.join(map(lambda x: str(x), predictions[i].tolist()))
+            f.write('{} {}\n\n'.format(video_id, pred_str))
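A hedged toy example of the approximate Charades metric (class count shrunk from 157 to 4 for brevity):

```python
# Sketch only: mAP on a toy multi-label problem with the approximate Charades metric.
# Rows are clips, columns are classes.
import numpy as np
from lavila.utils.evaluation_charades import charades_map

scores = np.random.rand(6, 4)                      # predicted scores per clip/class
labels = (np.random.rand(6, 4) > 0.5).astype(int)  # binary ground-truth matrix
m_ap, w_ap, per_class_ap = charades_map(scores, labels)
print(m_ap, per_class_ap)                          # per-class AP is NaN for classes with no positives
```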
lavila/utils/evaluation_egomcq.py
ADDED
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def egomcq_accuracy_metrics(preds, labels, types):
+    metrics = {}
+    type_list = torch.unique(types)
+    group_list = ["Intra-video", "Inter-video"]
+    for type_i, group_i in zip(type_list, group_list):
+        correct = 0
+        total = 0
+        for pred, label, type in zip(preds, labels, types):
+            if type == type_i:
+                pred_ = torch.argmax(pred)
+                if pred_.item() == label.item():
+                    correct += 1
+                total += 1
+        accuracy = correct/total
+        metrics[group_i] = accuracy * 100
+    return metrics
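A hedged toy example; the mapping of type ids to the two groups is positional after torch.unique sorts them, so type 1 is reported as Intra-video and type 2 as Inter-video here:

```python
# Sketch only: EgoMCQ accuracy over 5-way multiple-choice predictions.
import torch
from lavila.utils.evaluation_egomcq import egomcq_accuracy_metrics

preds = torch.randn(10, 5)                 # similarity of each query to its 5 candidates
labels = torch.randint(0, 5, (10,))        # index of the correct candidate
types = torch.tensor([1] * 5 + [2] * 5)    # first half intra-video, second half inter-video
print(egomcq_accuracy_metrics(preds, labels, types))
```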
lavila/utils/evaluation_ek100cls.py
ADDED
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Part of the code is from https://github.com/fpv-iplab/rulstm/blob/master/RULSTM/utils.py
+# Modified by Yue Zhao
+
+import numpy as np
+
+
+def get_marginal_indexes(actions, mode):
+    """For each verb/noun retrieve the list of actions containing that verb/noun
+    Input:
+        mode: "verb" or "noun"
+    Output:
+        a list of numpy array of indexes. If verb/noun 3 is contained in actions 2,8,19,
+        then output[3] will be np.array([2,8,19])
+    """
+    vi = []
+    for v in range(actions[mode].max()+1):
+        vals = actions[actions[mode] == v].index.values
+        if len(vals) > 0:
+            vi.append(vals)
+        else:
+            vi.append(np.array([0]))
+    return vi
+
+
+def marginalize(probs, indexes):
+    mprobs = []
+    for ilist in indexes:
+        mprobs.append(probs[:, ilist].sum(1))
+    return np.array(mprobs).T
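A hedged toy example of marginalising action scores into verb scores; the `actions` DataFrame stands in for the EPIC-Kitchens action list that the real evaluation code loads from a CSV:

```python
# Sketch only: collapse per-action probabilities into per-verb probabilities.
import numpy as np
import pandas as pd
from lavila.utils.evaluation_ek100cls import get_marginal_indexes, marginalize

actions = pd.DataFrame({'verb': [0, 0, 1, 2], 'noun': [3, 1, 1, 0]})  # 4 toy actions
verb_indexes = get_marginal_indexes(actions, 'verb')   # e.g. verb 0 -> actions [0, 1]
action_probs = np.random.rand(5, 4)                    # 5 clips x 4 action scores
verb_probs = marginalize(action_probs, verb_indexes)
print(verb_probs.shape)                                # (5, 3): one column per verb
```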
lavila/utils/evaluation_ek100mir.py
ADDED
@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of the code is from
# `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/NDCG.py`
# and
# `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/mAP.py`
# Modified by Yue Zhao

import numpy as np


def calculate_DCG(similarity_matrix, relevancy_matrix, k_counts):
    """
    Calculates the Discounted Cumulative Gain (DCG) between two modalities for
    the first modality.
    DCG = \sum_{i=1}^k \frac{rel_i}{\log_2(i + 1)}
    i.e. the sum over the k relevant retrievals of the scaled relevancy of the
    ith item. The scale is designed such that early retrievals are more
    important than later retrievals.
    Params:
        - similarity_matrix: matrix of size n1 x n2 where n1 is the number of
          items in the first modality and n2 is the number of items in the
          second modality. The [ith, jth] element is the predicted similarity
          between the ith item from the first modality and the jth item from
          the second modality.
        - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix
          above). The [ith, jth] element is the semantic relevancy between the
          ith item from the first modality and the jth item from the second
          modality.
        - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which
          includes information on which items to use to calculate the DCG for
          (see calculate_k_counts for more info on this matrix).
    Returns:
        - The DCG for each item in the first modality, an n1-length vector.
    """
    x_sz, y_sz = similarity_matrix.shape
    ranks = np.argsort(similarity_matrix)[:, ::-1]
    # Create vector of size (n,) where n is the length of the last dimension in
    # the similarity matrix. This vector is of the form log(i+1).
    logs = np.log2(np.arange(y_sz) + 2)
    # Convert logs into the divisor for the DCG calculation, of the same size
    # as the similarity matrix.
    divisors = np.repeat(np.expand_dims(logs, axis=0), x_sz, axis=0)

    # mask out the sorted relevancy matrix to only use the first k relevant
    # retrievals for each item.
    columns = np.repeat(np.expand_dims(np.arange(x_sz), axis=1), y_sz, axis=1)
    numerators = relevancy_matrix[columns, ranks] * k_counts
    # Calculate the final DCG score (note that this isn't expected to sum to 1)
    return np.sum(numerators / divisors, axis=1)


def calculate_k_counts(relevancy_matrix):
    """
    Works out the maximum number of allowed retrievals when working out the
    Discounted Cumulative Gain. For each query the DCG only uses the first k
    items retrieved which constitute the k relevant items for that query
    (otherwise the nDCG scores can be deceptively high for bad rankings).
    Params:
        - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of
          items in the first modality and n2 is the number of items in the
          second modality. The [ith, jth] element is the semantic relevancy
          between the ith item from the first modality and the jth item from
          the second modality.
    Returns:
        - Matrix of size n1 x n2 (see relevancy_matrix for more info). This is
          created as a mask such that if the [ith, jth] element is 1 it
          represents a valid item to use for the calculation of DCG for the
          ith item after sorting. For example, if a relevancy matrix of:
            [[1, 0.5, 0],
             [0, 0,   1]]
          is given, then the k_counts matrix will be:
            [[1, 1, 0],
             [1, 0, 0]]
          i.e. the first row has 2 non-zero items, so the first two retrieved
          items should be used in the calculation. In the second row there is
          only 1 relevant item, therefore only the first retrieved item should
          be used for the DCG calculation.
    """
    return (np.sort(relevancy_matrix)[:, ::-1] > 0).astype(int)


def calculate_IDCG(relevancy_matrix, k_counts):
    """
    Calculates the Ideal Discounted Cumulative Gain (IDCG) which is the value
    of the Discounted Cumulative Gain (DCG) for a perfect retrieval, i.e. the
    items in the second modality were retrieved in order of their descending
    relevancy.
    Params:
        - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of
          items in the first modality and n2 is the number of items in the
          second modality. The [ith, jth] element is the semantic relevancy
          between the ith item from the first modality and the jth item from
          the second modality.
        - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which
          includes information on which items to use to calculate the DCG for
          (see calculate_k_counts for more info on this matrix).
    """
    return calculate_DCG(relevancy_matrix, relevancy_matrix, k_counts)


def calculate_nDCG(similarity_matrix, relevancy_matrix, k_counts=None, IDCG=None, reduction='mean'):
    """
    Calculates the normalised Discounted Cumulative Gain (nDCG) between two
    modalities for the first modality using the Discounted Cumulative Gain
    (DCG) and the Ideal Discounted Cumulative Gain (IDCG).
    nDCG = \frac{DCG}{IDCG}
    Params:
        - similarity_matrix: matrix of size n1 x n2 where n1 is the number of
          items in the first modality and n2 is the number of items in the
          second modality. The [ith, jth] element is the predicted similarity
          between the ith item from the first modality and the jth item from
          the second modality.
        - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix
          above). The [ith, jth] element is the semantic relevancy between the
          ith item from the first modality and the jth item from the second
          modality.
        - k_counts: optional parameter: matrix of size n1 x n2 (see
          similarity_matrix above) which includes information on which items
          to use to calculate the DCG for (see calculate_k_counts for more
          info on this matrix). This will be calculated using
          calculate_k_counts if not present, but should be pre-computed for
          efficiency.
        - IDCG: optional parameter which includes the pre-computed Ideal
          Discounted Cumulative Gain (IDCG). This is a vector of size n1 (see
          similarity_matrix above) which contains the IDCG value for each item
          from the first modality. This will be calculated using
          calculate_IDCG if not present, but should be pre-computed for
          efficiency.
        - reduction: what to use to reduce the different nDCG scores. By
          default this applies np.mean across all different queries.
    Returns:
        - The nDCG values for the first modality.
    """
    if k_counts is None:
        k_counts = calculate_k_counts(relevancy_matrix)
    DCG = calculate_DCG(similarity_matrix, relevancy_matrix, k_counts)
    if IDCG is None:
        IDCG = calculate_IDCG(relevancy_matrix, k_counts)
    if reduction == 'mean':
        return np.mean(DCG / IDCG)
    elif reduction is None:
        return DCG / IDCG


def calculate_mAP(sim_mat, relevancy_matrix):
    """
    Computes the mean average precision according to the following formula of
    average precision:
    \frac{\sum_{k=1}^n p(k) \times rel(k)}{num_rel_docs}
    where p(k) is the precision at k, rel(k) is an indicator function
    determining whether the kth returned item is relevant or not and
    num_rel_docs is the number of relevant items to find within the search.
    The mean average precision is the mean of the average precision for each
    query item (i.e. row in the matrix).
    This function takes in two parameters:
        - sim_mat: an NxM matrix which represents the similarity between two
          modalities (with modality 1 being of size N and modality 2 of size M).
        - relevancy_matrix: an NxM matrix which represents the relevancy
          between two modalities of items (with modality 1 being of size N and
          modality 2 of size M).
    """
    # Find the order of the items in modality 2 according to modality 1
    ranked_order = (-sim_mat).argsort()
    ranked_sim_mat = sim_mat[np.arange(sim_mat.shape[0])[:, None], ranked_order]
    # re-order the relevancy matrix to accommodate the proposals
    ranked_rel_mat = relevancy_matrix[np.arange(relevancy_matrix.shape[0])[:, None], ranked_order]

    # find the number of relevant items found at each k
    cumulative_rel_mat = np.cumsum(ranked_rel_mat, axis=1)
    # Mask this ensuring that it is non zero if the kth term is 1 (rel(k) above)
    cumulative_rel_mat[ranked_rel_mat != 1] = 0
    # find the divisor for p(k)
    divisor = np.arange(ranked_rel_mat.shape[1]) + 1

    # find the number of relevant docs per query item
    number_rel_docs = np.sum(ranked_rel_mat == 1, axis=1)

    # find the average precision per query, within np.sum finds p(k) * rel(k)
    avg_precision = np.sum(cumulative_rel_mat / divisor, axis=1) / number_rel_docs
    mAP = np.mean(avg_precision)
    return mAP


def get_mAP(similarity_matrix, rel_matrix):
    vis_map = calculate_mAP(similarity_matrix, rel_matrix)
    txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
    return vis_map, txt_map, (vis_map + txt_map) / 2


def get_nDCG(similarity_matrix, rel_matrix):
    vis_k_counts = calculate_k_counts(rel_matrix)
    txt_k_counts = calculate_k_counts(rel_matrix.T)
    vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
    txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
    vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
    txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
    return vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2
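The following editor-added sketch (not part of the uploaded repository) illustrates how the retrieval metrics above might be called. The clip-to-caption similarity and relevancy matrices are made up and only serve to show the expected shapes and call signatures; in the actual evaluation they would come from model predictions and the benchmark's relevancy annotations.

import numpy as np

from lavila.utils.evaluation_ek100mir import get_mAP, get_nDCG

# Toy similarity scores: rows are video clips, columns are captions.
similarity_matrix = np.array([[0.9, 0.2, 0.6],
                              [0.3, 0.8, 0.7]])
# Graded relevancy in [0, 1]; every row and column has at least one fully
# relevant item so neither metric divides by zero.
rel_matrix = np.array([[1.0, 0.0, 0.5],
                       [0.5, 1.0, 1.0]])

vis_map, txt_map, avg_map = get_mAP(similarity_matrix, rel_matrix)
vis_ndcg, txt_ndcg, avg_ndcg = get_nDCG(similarity_matrix, rel_matrix)
print('mAP  (V->T, T->V, avg):', vis_map, txt_map, avg_map)
print('nDCG (V->T, T->V, avg):', vis_ndcg, txt_ndcg, avg_ndcg)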
lavila/utils/meter.py
ADDED
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed as dist
from lavila.utils import distributed as dist_utils


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def synchronize(self):
        if not dist_utils.is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.sum = int(t[0])
        self.count = t[1]
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def synchronize(self):
        for meter in self.meters:
            meter.synchronize()

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
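Finally, an editor-added single-process sketch (not part of the uploaded repository) of how AverageMeter and ProgressMeter are typically driven from a training loop; the loss values and batch size below are invented for illustration.

from lavila.utils.meter import AverageMeter, ProgressMeter

num_batches = 3
loss_meter = AverageMeter('Loss', fmt=':.4f')
progress = ProgressMeter(num_batches, [loss_meter], prefix='Epoch: [0]')

for i, loss in enumerate([0.92, 0.85, 0.80]):  # stand-ins for real loss.item() values
    loss_meter.update(loss, n=32)              # n is the batch size
    progress.display(i)                        # e.g. at i=1: "Epoch: [0][1/3]  Loss 0.8500 (0.8850)"

# Under torch.distributed, progress.synchronize() all-reduces the running sums
# across workers before epoch-level averages are logged; in a single process it
# is a no-op because is_dist_avail_and_initialized() returns False.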