d22cs051
commited on
Commit
•
8273cb9
0
Parent(s):
retriying pushing the code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- .gitignore +129 -0
- Dockerfile +32 -0
- README.md +11 -0
- app.py +71 -0
- config.py +149 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.circleci/config.yml +159 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE.md +3 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/documentation.md +15 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/PULL_REQUEST_TEMPLATE.md +16 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/stale.yml +30 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/workflows/build.yml +60 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/workflows/build_wheels.yml +41 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.gitignore +136 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.gitmodules +4 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.isort.cfg +2 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.pre-commit-config.yaml +40 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/CODE_OF_CONDUCT.md +77 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/CONTRIBUTING.md +82 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/LICENSE +21 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/README.md +236 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/.gitignore +2 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/.gitignore +139 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/CONFIG.md +41 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/DATASET.md +34 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/README.md +166 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/endtask.md +41 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/locallaunch.py +148 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/__init__.py +12 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/__init__.py +10 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/fairseqmmdataset.py +57 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/mmdataset.py +111 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/__init__.py +13 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/evaluator.py +54 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/metric.py +313 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/predictor.py +595 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/__init__.py +16 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/fairseqmmloss.py +63 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/loss.py +87 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/nce.py +156 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/__init__.py +17 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/fairseqmmmodel.py +51 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/mmfusion.py +926 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/mmfusionnlg.py +999 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/transformermodel.py +734 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/modules/__init__.py +10 -0
- fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/modules/mm.py +145 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
Dockerfile
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.8.1-slim-buster
|
2 |
+
|
3 |
+
|
4 |
+
WORKDIR /code
|
5 |
+
|
6 |
+
COPY . /code
|
7 |
+
|
8 |
+
# RUN useradd -m -u 1000 user
|
9 |
+
|
10 |
+
RUN apt-get update
|
11 |
+
RUN apt-get install build-essential -y
|
12 |
+
# RUN pip install --no-cache-dir -r requirements.txt
|
13 |
+
RUN pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
|
14 |
+
# WORKDIR fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1
|
15 |
+
RUN pip install -e fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.
|
16 |
+
RUN pip install -r requirements.txt --no-cache-dir
|
17 |
+
RUN pip install gradio --no-cache-dir
|
18 |
+
RUN pip install protobuf==3.20.* --no-cache-dir
|
19 |
+
|
20 |
+
# Switch to the "user" user
|
21 |
+
# USER user
|
22 |
+
|
23 |
+
# Set home to the user's home directory
|
24 |
+
# ENV HOME=/home/user \
|
25 |
+
# PATH=/home/user/.local/bin:$PATH
|
26 |
+
|
27 |
+
# Set the working directory to the user's home directory
|
28 |
+
# WORKDIR $HOME/code
|
29 |
+
|
30 |
+
# COPY --chown=user . $HOME/code
|
31 |
+
# RUN ls -la $HOME/code
|
32 |
+
CMD ["python3", "app.py"]
|
README.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Audio Deepfake Detection
|
3 |
+
emoji: 🐢
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: purple
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
license: mit
|
9 |
+
---
|
10 |
+
|
11 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from model import Model
|
4 |
+
from config import Config
|
5 |
+
|
6 |
+
import warnings
|
7 |
+
# warnings.filterwarnings('ignore')
|
8 |
+
|
9 |
+
# making config object
|
10 |
+
config = Config()
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
def infrence(audio_file1):
|
15 |
+
print(f"[LOG] Audio file: {audio_file1}")
|
16 |
+
|
17 |
+
class DFSeparationApp:
|
18 |
+
def __init__(self, model_path,device="cpu"):
|
19 |
+
self.device = device
|
20 |
+
self.model = self.load_model(model_path)
|
21 |
+
self.model.to(self.device)
|
22 |
+
|
23 |
+
|
24 |
+
def load_model(self, model_path):
|
25 |
+
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
|
26 |
+
fine_tuned_model = Model(
|
27 |
+
args=config,
|
28 |
+
device=self.device
|
29 |
+
)
|
30 |
+
fine_tuned_model.load_state_dict(checkpoint["model"])
|
31 |
+
print("[LOG] Model loaded successfully.")
|
32 |
+
return fine_tuned_model
|
33 |
+
|
34 |
+
def predict(self, audio_file):
|
35 |
+
# Load the audio file
|
36 |
+
audio_tensor = torch.tensor(audio_file[1]).to(self.device)
|
37 |
+
with torch.no_grad():
|
38 |
+
# Make prediction
|
39 |
+
output = self.model(audio_tensor)
|
40 |
+
preds = output.argmax(dim=-1)
|
41 |
+
probs = output.softmax(dim=-1)
|
42 |
+
print(f"[LOG] Prediction: {preds.item()}")
|
43 |
+
print(f"[LOG] Probability: {probs.max().item()}")
|
44 |
+
return preds.item(), probs.max().item()
|
45 |
+
|
46 |
+
def run(self):
|
47 |
+
print(f"[LOG] Running the app...")
|
48 |
+
# gradio interface
|
49 |
+
audio_input1 = gr.Audio(label="Upload or record audio")
|
50 |
+
prediction = gr.Label(label="Prediction:")
|
51 |
+
prob = gr.Label(label="Probability:")
|
52 |
+
gr.Interface(
|
53 |
+
fn=self.predict,
|
54 |
+
inputs=[audio_input1],
|
55 |
+
outputs=[prediction, prob],
|
56 |
+
title="DF Separation",
|
57 |
+
description="This app classify the audio samples into Real and Fake.",
|
58 |
+
examples=[
|
59 |
+
["samples/Fake/download (5).wav","1"],
|
60 |
+
["samples/Fake/fake1_1.wav","1"],
|
61 |
+
["samples/Real/Central Avenue 1.wav","0"],
|
62 |
+
["samples/Real/hindi.mp3","0"],
|
63 |
+
]
|
64 |
+
).launch(quiet=False,server_name="0.0.0.0")
|
65 |
+
|
66 |
+
if __name__ == "__main__":
|
67 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
68 |
+
print(f"[LOG] Device: {device}")
|
69 |
+
model_path = "models/for_trained_model.ckpt" # Replace with your model path
|
70 |
+
app = DFSeparationApp(model_path, device=device)
|
71 |
+
app.run()
|
config.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Config:
|
2 |
+
def __init__(self):
|
3 |
+
self.custom_data_dir = 'data/Dataset_Speech_Assignment'
|
4 |
+
self.for2sec_data_dir = 'data/for-2seconds'
|
5 |
+
self.batch_size = 32
|
6 |
+
self.num_workers = 4
|
7 |
+
self.num_epochs = 50
|
8 |
+
self.lr = 1e-3
|
9 |
+
self.model_checkpoint_path = 'models/Best_LA_model_for_DF.pth'
|
10 |
+
|
11 |
+
############################################################################
|
12 |
+
"""
|
13 |
+
parser.add_argument('--algo', type=int, default=3,
|
14 |
+
help='Rawboost algos discriptions. 0: No augmentation 1: LnL_convolutive_noise, 2: ISD_additive_noise, 3: SSI_additive_noise, 4: series algo (1+2+3), \
|
15 |
+
5: series algo (1+2), 6: series algo (1+3), 7: series algo(2+3), 8: parallel algo(1,2) .default=0]')
|
16 |
+
|
17 |
+
# LnL_convolutive_noise parameters
|
18 |
+
parser.add_argument('--nBands', type=int, default=5,
|
19 |
+
help='number of notch filters.The higher the number of bands, the more aggresive the distortions is.[default=5]')
|
20 |
+
parser.add_argument('--minF', type=int, default=20,
|
21 |
+
help='minimum centre frequency [Hz] of notch filter.[default=20] ')
|
22 |
+
parser.add_argument('--maxF', type=int, default=8000,
|
23 |
+
help='maximum centre frequency [Hz] (<sr/2) of notch filter.[default=8000]')
|
24 |
+
parser.add_argument('--minBW', type=int, default=100,
|
25 |
+
help='minimum width [Hz] of filter.[default=100] ')
|
26 |
+
parser.add_argument('--maxBW', type=int, default=1000,
|
27 |
+
help='maximum width [Hz] of filter.[default=1000] ')
|
28 |
+
parser.add_argument('--minCoeff', type=int, default=10,
|
29 |
+
help='minimum filter coefficients. More the filter coefficients more ideal the filter slope.[default=10]')
|
30 |
+
parser.add_argument('--maxCoeff', type=int, default=100,
|
31 |
+
help='maximum filter coefficients. More the filter coefficients more ideal the filter slope.[default=100]')
|
32 |
+
parser.add_argument('--minG', type=int, default=0,
|
33 |
+
help='minimum gain factor of linear component.[default=0]')
|
34 |
+
parser.add_argument('--maxG', type=int, default=0,
|
35 |
+
help='maximum gain factor of linear component.[default=0]')
|
36 |
+
parser.add_argument('--minBiasLinNonLin', type=int, default=5,
|
37 |
+
help=' minimum gain difference between linear and non-linear components.[default=5]')
|
38 |
+
parser.add_argument('--maxBiasLinNonLin', type=int, default=20,
|
39 |
+
help=' maximum gain difference between linear and non-linear components.[default=20]')
|
40 |
+
parser.add_argument('--N_f', type=int, default=5,
|
41 |
+
help='order of the (non-)linearity where N_f=1 refers only to linear components.[default=5]')
|
42 |
+
|
43 |
+
# ISD_additive_noise parameters
|
44 |
+
parser.add_argument('--P', type=int, default=10,
|
45 |
+
help='Maximum number of uniformly distributed samples in [%].[defaul=10]')
|
46 |
+
parser.add_argument('--g_sd', type=int, default=2,
|
47 |
+
help='gain parameters > 0. [default=2]')
|
48 |
+
|
49 |
+
# SSI_additive_noise parameters
|
50 |
+
parser.add_argument('--SNRmin', type=int, default=10,
|
51 |
+
help='Minimum SNR value for coloured additive noise.[defaul=10]')
|
52 |
+
parser.add_argument('--SNRmax', type=int, default=40,
|
53 |
+
help='Maximum SNR value for coloured additive noise.[defaul=40]')
|
54 |
+
"""
|
55 |
+
############################################################################
|
56 |
+
# conversion from agrparse to class object
|
57 |
+
self.algo = 3
|
58 |
+
self.nBands = 5
|
59 |
+
self.minF = 20
|
60 |
+
self.maxF = 8000
|
61 |
+
self.minBW = 100
|
62 |
+
self.maxBW = 1000
|
63 |
+
self.minCoeff = 10
|
64 |
+
self.maxCoeff = 100
|
65 |
+
self.minG = 0
|
66 |
+
self.maxG = 0
|
67 |
+
self.minBiasLinNonLin = 5
|
68 |
+
self.maxBiasLinNonLin = 20
|
69 |
+
self.N_f = 5
|
70 |
+
self.P = 10
|
71 |
+
self.g_sd = 2
|
72 |
+
self.SNRmin = 10
|
73 |
+
self.SNRmax = 40
|
74 |
+
|
75 |
+
|
76 |
+
#############################################################################
|
77 |
+
"""
|
78 |
+
parser.add_argument('--database_path', type=str, default='/your/path/to/data/ASVspoof_database/DF/', help='Change this to user\'s full directory address of LA database (ASVspoof2019- for training & development (used as validation), ASVspoof2021 DF for evaluation scores). We assume that all three ASVspoof 2019 LA train, LA dev and ASVspoof2021 DF eval data folders are in the same database_path directory.')
|
79 |
+
'''
|
80 |
+
% database_path/
|
81 |
+
% |- DF
|
82 |
+
% |- ASVspoof2021_DF_eval/flac
|
83 |
+
% |- ASVspoof2019_LA_train/flac
|
84 |
+
% |- ASVspoof2019_LA_dev/flac
|
85 |
+
'''
|
86 |
+
|
87 |
+
parser.add_argument('--protocols_path', type=str, default='database/', help='Change with path to user\'s DF database protocols directory address')
|
88 |
+
'''
|
89 |
+
% protocols_path/
|
90 |
+
% |- ASVspoof_LA_cm_protocols
|
91 |
+
% |- ASVspoof2021.LA.cm.eval.trl.txt
|
92 |
+
% |- ASVspoof2019.LA.cm.dev.trl.txt
|
93 |
+
% |- ASVspoof2019.LA.cm.train.trn.txt
|
94 |
+
|
95 |
+
% |- ASVspoof_DF_cm_protocols
|
96 |
+
% |- ASVspoof2021.DF.cm.eval.trl.txt
|
97 |
+
|
98 |
+
'''
|
99 |
+
|
100 |
+
# Hyperparameters
|
101 |
+
parser.add_argument('--batch_size', type=int, default=14)
|
102 |
+
parser.add_argument('--num_epochs', type=int, default=100)
|
103 |
+
parser.add_argument('--lr', type=float, default=0.000001)
|
104 |
+
parser.add_argument('--weight_decay', type=float, default=0.0001)
|
105 |
+
parser.add_argument('--loss', type=str, default='weighted_CCE')
|
106 |
+
# model
|
107 |
+
parser.add_argument('--seed', type=int, default=1234,
|
108 |
+
help='random seed (default: 1234)')
|
109 |
+
|
110 |
+
parser.add_argument('--model_path', type=str,
|
111 |
+
default=None, help='Model checkpoint')
|
112 |
+
parser.add_argument('--comment', type=str, default=None,
|
113 |
+
help='Comment to describe the saved model')
|
114 |
+
# Auxiliary arguments
|
115 |
+
parser.add_argument('--track', type=str, default='DF',choices=['LA', 'PA','DF'], help='LA/PA/DF')
|
116 |
+
parser.add_argument('--eval_output', type=str, default=None,
|
117 |
+
help='Path to save the evaluation result')
|
118 |
+
parser.add_argument('--eval', action='store_true', default=False,
|
119 |
+
help='eval mode')
|
120 |
+
parser.add_argument('--is_eval', action='store_true', default=False,help='eval database')
|
121 |
+
parser.add_argument('--eval_part', type=int, default=0)
|
122 |
+
# backend options
|
123 |
+
parser.add_argument('--cudnn-deterministic-toggle', action='store_false', \
|
124 |
+
default=True,
|
125 |
+
help='use cudnn-deterministic? (default true)')
|
126 |
+
|
127 |
+
parser.add_argument('--cudnn-benchmark-toggle', action='store_true', \
|
128 |
+
default=False,
|
129 |
+
help='use cudnn-benchmark? (default false)')
|
130 |
+
"""
|
131 |
+
|
132 |
+
self.weight_decay = 0.0001
|
133 |
+
self.loss = 'weighted_CCE'
|
134 |
+
self.seed = 1234
|
135 |
+
self.model_path = "models/LA_model.pth"
|
136 |
+
self.comment = None
|
137 |
+
self.track = 'DF'
|
138 |
+
self.eval_output = None
|
139 |
+
self.eval = False
|
140 |
+
self.is_eval = False
|
141 |
+
self.eval_part = 0
|
142 |
+
self.cudnn_deterministic_toggle = False
|
143 |
+
self.cudnn_benchmark_toggle = False
|
144 |
+
|
145 |
+
self.wandb_config = {
|
146 |
+
'project': 'Speech Assignment 3',
|
147 |
+
'run_name': 'LA_model',
|
148 |
+
}
|
149 |
+
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.circleci/config.yml
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use 2.1 for orbs
|
2 |
+
version: 2.1
|
3 |
+
|
4 |
+
# -------------------------------------------------------------------------------------
|
5 |
+
# Environments to run the jobs in
|
6 |
+
# -------------------------------------------------------------------------------------
|
7 |
+
gpu: &gpu
|
8 |
+
environment:
|
9 |
+
CUDA_VERSION: "11.1"
|
10 |
+
machine:
|
11 |
+
image: ubuntu-1604-cuda-11.1:202012-01
|
12 |
+
resource_class: gpu.nvidia.medium.multi
|
13 |
+
|
14 |
+
|
15 |
+
# -------------------------------------------------------------------------------------
|
16 |
+
# Re-usable commands
|
17 |
+
# -------------------------------------------------------------------------------------
|
18 |
+
cache_key: &cache_key cache-key-{{ .Environment.CIRCLE_JOB }}-{{ checksum ".circleci/config.yml" }}-{{ checksum "setup.py"}}
|
19 |
+
|
20 |
+
install_dep_common: &install_dep_common
|
21 |
+
- run:
|
22 |
+
name: Install Common Dependencies
|
23 |
+
command: |
|
24 |
+
source activate fairseq
|
25 |
+
pip install --upgrade setuptools
|
26 |
+
pip install bitarray boto3 deepspeed editdistance fastBPE iopath ipdb ipython pyarrow pytest sacremoses sentencepiece subword-nmt hydra-core==1.0.7 omegaconf==2.0.6
|
27 |
+
pip install --progress-bar off pytest
|
28 |
+
pip install --progress-bar off fairscale
|
29 |
+
pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda111 -U
|
30 |
+
python -c 'import torch; print("Torch version:", torch.__version__)'
|
31 |
+
python -m torch.utils.collect_env
|
32 |
+
|
33 |
+
install_dep_fused_ops: &install_dep_fused_ops
|
34 |
+
- run:
|
35 |
+
name: Install Megatron/Apex Dependencies
|
36 |
+
working_directory: ~/
|
37 |
+
command: |
|
38 |
+
source activate fairseq
|
39 |
+
git clone https://github.com/NVIDIA/apex
|
40 |
+
cd apex
|
41 |
+
git checkout e2083df5eb96643c61613b9df48dd4eea6b07690
|
42 |
+
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" --global-option="--xentropy" --global-option="--fast_multihead_attn" ./
|
43 |
+
cd ~/
|
44 |
+
git clone --depth=1 --branch v2.4 https://github.com/NVIDIA/Megatron-LM.git
|
45 |
+
cd Megatron-LM
|
46 |
+
pip install -e .
|
47 |
+
|
48 |
+
|
49 |
+
install_dep_pt19: &install_dep_pt19
|
50 |
+
- run:
|
51 |
+
name: Install Pytorch Dependencies
|
52 |
+
command: |
|
53 |
+
source activate fairseq
|
54 |
+
pip install --upgrade setuptools
|
55 |
+
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
|
56 |
+
python -c 'import torch; print("Torch version:", torch.__version__)'
|
57 |
+
|
58 |
+
install_dep_pt18: &install_dep_pt18
|
59 |
+
- run:
|
60 |
+
name: Install Pytorch Dependencies
|
61 |
+
command: |
|
62 |
+
source activate fairseq
|
63 |
+
pip install --upgrade setuptools
|
64 |
+
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
|
65 |
+
python -c 'import torch; print("Torch version:", torch.__version__)'
|
66 |
+
|
67 |
+
install_repo: &install_repo
|
68 |
+
- run:
|
69 |
+
name: Install Repository
|
70 |
+
command: |
|
71 |
+
source activate fairseq
|
72 |
+
pip install .
|
73 |
+
python setup.py build_ext --inplace
|
74 |
+
|
75 |
+
run_unittests: &run_unittests
|
76 |
+
- run:
|
77 |
+
name: Run Unit Tests
|
78 |
+
command: |
|
79 |
+
source activate fairseq
|
80 |
+
pytest tests/gpu/test_binaries_gpu.py
|
81 |
+
|
82 |
+
check_nvidia_driver: &check_nvidia_driver
|
83 |
+
- run:
|
84 |
+
name: Check NVIDIA Driver
|
85 |
+
working_directory: ~/
|
86 |
+
command: |
|
87 |
+
pyenv versions
|
88 |
+
nvidia-smi
|
89 |
+
|
90 |
+
create_conda_env: &create_conda_env
|
91 |
+
- run:
|
92 |
+
name: Install and Create Conda Environment
|
93 |
+
command: |
|
94 |
+
curl -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
95 |
+
chmod +x ~/miniconda.sh
|
96 |
+
~/miniconda.sh -b -p $HOME/miniconda
|
97 |
+
rm ~/miniconda.sh
|
98 |
+
echo 'export PATH=$HOME/miniconda/bin:$PATH' >> $BASH_ENV
|
99 |
+
source $BASH_ENV
|
100 |
+
if [ ! -d ~/miniconda/envs/fairseq ]
|
101 |
+
then
|
102 |
+
conda create -y -n fairseq python=3.8
|
103 |
+
fi
|
104 |
+
source activate fairseq
|
105 |
+
python --version
|
106 |
+
pip install --upgrade pip
|
107 |
+
# -------------------------------------------------------------------------------------
|
108 |
+
# Jobs to run
|
109 |
+
# -------------------------------------------------------------------------------------
|
110 |
+
|
111 |
+
jobs:
|
112 |
+
gpu_tests_pt19:
|
113 |
+
<<: *gpu
|
114 |
+
|
115 |
+
working_directory: ~/fairseq-py
|
116 |
+
|
117 |
+
steps:
|
118 |
+
- checkout
|
119 |
+
- <<: *check_nvidia_driver
|
120 |
+
- <<: *create_conda_env
|
121 |
+
- restore_cache:
|
122 |
+
key: *cache_key
|
123 |
+
- <<: *install_dep_pt19
|
124 |
+
- <<: *install_dep_common
|
125 |
+
- <<: *install_dep_fused_ops
|
126 |
+
- save_cache:
|
127 |
+
paths:
|
128 |
+
- ~/miniconda/
|
129 |
+
key: *cache_key
|
130 |
+
- <<: *install_repo
|
131 |
+
- <<: *run_unittests
|
132 |
+
|
133 |
+
gpu_tests_pt18:
|
134 |
+
<<: *gpu
|
135 |
+
|
136 |
+
working_directory: ~/fairseq-py
|
137 |
+
|
138 |
+
steps:
|
139 |
+
- checkout
|
140 |
+
- <<: *check_nvidia_driver
|
141 |
+
- <<: *create_conda_env
|
142 |
+
- restore_cache:
|
143 |
+
key: *cache_key
|
144 |
+
- <<: *install_dep_pt18
|
145 |
+
- <<: *install_dep_common
|
146 |
+
- <<: *install_dep_fused_ops
|
147 |
+
- save_cache:
|
148 |
+
paths:
|
149 |
+
- ~/miniconda/
|
150 |
+
key: *cache_key
|
151 |
+
- <<: *install_repo
|
152 |
+
- <<: *run_unittests
|
153 |
+
|
154 |
+
workflows:
|
155 |
+
version: 2
|
156 |
+
build:
|
157 |
+
jobs:
|
158 |
+
- gpu_tests_pt18
|
159 |
+
- gpu_tests_pt19
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
|
2 |
+
|
3 |
+
Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 🐛 Bug Report
|
3 |
+
about: Submit a bug report to help us improve
|
4 |
+
labels: 'bug, needs triage'
|
5 |
+
---
|
6 |
+
|
7 |
+
## 🐛 Bug
|
8 |
+
|
9 |
+
<!-- A clear and concise description of what the bug is. -->
|
10 |
+
|
11 |
+
### To Reproduce
|
12 |
+
|
13 |
+
Steps to reproduce the behavior (**always include the command you ran**):
|
14 |
+
|
15 |
+
1. Run cmd '....'
|
16 |
+
2. See error
|
17 |
+
|
18 |
+
<!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
|
19 |
+
|
20 |
+
|
21 |
+
#### Code sample
|
22 |
+
<!-- Ideally attach a minimal code sample to reproduce the decried issue.
|
23 |
+
Minimal means having the shortest code but still preserving the bug. -->
|
24 |
+
|
25 |
+
### Expected behavior
|
26 |
+
|
27 |
+
<!-- A clear and concise description of what you expected to happen. -->
|
28 |
+
|
29 |
+
### Environment
|
30 |
+
|
31 |
+
- fairseq Version (e.g., 1.0 or main):
|
32 |
+
- PyTorch Version (e.g., 1.0)
|
33 |
+
- OS (e.g., Linux):
|
34 |
+
- How you installed fairseq (`pip`, source):
|
35 |
+
- Build command you used (if compiling from source):
|
36 |
+
- Python version:
|
37 |
+
- CUDA/cuDNN version:
|
38 |
+
- GPU models and configuration:
|
39 |
+
- Any other relevant information:
|
40 |
+
|
41 |
+
### Additional context
|
42 |
+
|
43 |
+
<!-- Add any other context about the problem here. -->
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/documentation.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 📚 Documentation/Typos
|
3 |
+
about: Report an issue related to documentation or a typo
|
4 |
+
labels: 'documentation, needs triage'
|
5 |
+
---
|
6 |
+
|
7 |
+
## 📚 Documentation
|
8 |
+
|
9 |
+
For typos and doc fixes, please go ahead and:
|
10 |
+
|
11 |
+
1. Create an issue.
|
12 |
+
2. Fix the typo.
|
13 |
+
3. Submit a PR.
|
14 |
+
|
15 |
+
Thanks!
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 🚀 Feature Request
|
3 |
+
about: Submit a proposal/request for a new feature
|
4 |
+
labels: 'enhancement, help wanted, needs triage'
|
5 |
+
---
|
6 |
+
|
7 |
+
## 🚀 Feature Request
|
8 |
+
<!-- A clear and concise description of the feature proposal -->
|
9 |
+
|
10 |
+
### Motivation
|
11 |
+
|
12 |
+
<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
|
13 |
+
|
14 |
+
### Pitch
|
15 |
+
|
16 |
+
<!-- A clear and concise description of what you want to happen. -->
|
17 |
+
|
18 |
+
### Alternatives
|
19 |
+
|
20 |
+
<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
|
21 |
+
|
22 |
+
### Additional context
|
23 |
+
|
24 |
+
<!-- Add any other context or screenshots about the feature request here. -->
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/how-to-question.md
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: ❓ Questions/Help
|
3 |
+
about: If you have questions, please first search existing issues and docs
|
4 |
+
labels: 'question, needs triage'
|
5 |
+
---
|
6 |
+
|
7 |
+
## ❓ Questions and Help
|
8 |
+
|
9 |
+
### Before asking:
|
10 |
+
1. search the issues.
|
11 |
+
2. search the docs.
|
12 |
+
|
13 |
+
<!-- If you still can't find what you need: -->
|
14 |
+
|
15 |
+
#### What is your question?
|
16 |
+
|
17 |
+
#### Code
|
18 |
+
|
19 |
+
<!-- Please paste a code snippet if your question requires it! -->
|
20 |
+
|
21 |
+
#### What have you tried?
|
22 |
+
|
23 |
+
#### What's your environment?
|
24 |
+
|
25 |
+
- fairseq Version (e.g., 1.0 or main):
|
26 |
+
- PyTorch Version (e.g., 1.0)
|
27 |
+
- OS (e.g., Linux):
|
28 |
+
- How you installed fairseq (`pip`, source):
|
29 |
+
- Build command you used (if compiling from source):
|
30 |
+
- Python version:
|
31 |
+
- CUDA/cuDNN version:
|
32 |
+
- GPU models and configuration:
|
33 |
+
- Any other relevant information:
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/PULL_REQUEST_TEMPLATE.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Before submitting
|
2 |
+
|
3 |
+
- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
|
4 |
+
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
|
5 |
+
- [ ] Did you make sure to update the docs?
|
6 |
+
- [ ] Did you write any new necessary tests?
|
7 |
+
|
8 |
+
## What does this PR do?
|
9 |
+
Fixes # (issue).
|
10 |
+
|
11 |
+
## PR review
|
12 |
+
Anyone in the community is free to review the PR once the tests have passed.
|
13 |
+
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
|
14 |
+
|
15 |
+
## Did you have fun?
|
16 |
+
Make sure you had fun coding 🙃
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/stale.yml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for probot-stale - https://github.com/probot/stale
|
2 |
+
# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
|
3 |
+
# Number of days of inactivity before an issue becomes stale
|
4 |
+
daysUntilStale: 90
|
5 |
+
# Number of days of inactivity before a stale issue is closed
|
6 |
+
daysUntilClose: 7
|
7 |
+
# Issues with these labels will never be considered stale
|
8 |
+
exemptLabels:
|
9 |
+
- bug
|
10 |
+
# Label to use when marking an issue as stale
|
11 |
+
staleLabel: stale
|
12 |
+
issues:
|
13 |
+
# Comment to post when marking an issue as stale.
|
14 |
+
markComment: >
|
15 |
+
This issue has been automatically marked as stale.
|
16 |
+
**If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
|
17 |
+
We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
|
18 |
+
# Comment to post when closing a stale issue.
|
19 |
+
closeComment: >
|
20 |
+
Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
|
21 |
+
pulls:
|
22 |
+
# Comment to post when marking a pull request as stale.
|
23 |
+
markComment: >
|
24 |
+
This pull request has been automatically marked as stale.
|
25 |
+
**If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
|
26 |
+
We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
|
27 |
+
# Comment to post when closing a stale pull request.
|
28 |
+
closeComment: >
|
29 |
+
Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
|
30 |
+
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/workflows/build.yml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: build
|
2 |
+
|
3 |
+
on:
|
4 |
+
# Trigger the workflow on push to main or any pull request
|
5 |
+
push:
|
6 |
+
branches:
|
7 |
+
- main
|
8 |
+
pull_request:
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
build:
|
12 |
+
|
13 |
+
strategy:
|
14 |
+
max-parallel: 4
|
15 |
+
matrix:
|
16 |
+
platform: [ubuntu-latest, macos-latest]
|
17 |
+
python-version: [3.8, 3.9]
|
18 |
+
|
19 |
+
runs-on: ${{ matrix.platform }}
|
20 |
+
|
21 |
+
steps:
|
22 |
+
- uses: actions/checkout@v2
|
23 |
+
|
24 |
+
- name: Set up Python ${{ matrix.python-version }}
|
25 |
+
uses: actions/setup-python@v2
|
26 |
+
with:
|
27 |
+
python-version: ${{ matrix.python-version }}
|
28 |
+
|
29 |
+
- name: Conditionally install pytorch
|
30 |
+
if: matrix.platform == 'windows-latest'
|
31 |
+
run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
|
32 |
+
|
33 |
+
- name: Install locally
|
34 |
+
run: |
|
35 |
+
python -m pip install --upgrade pip
|
36 |
+
git submodule update --init --recursive
|
37 |
+
python setup.py build_ext --inplace
|
38 |
+
python -m pip install --editable .
|
39 |
+
|
40 |
+
- name: Install optional test requirements
|
41 |
+
run: |
|
42 |
+
python -m pip install iopath transformers pyarrow
|
43 |
+
python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
|
44 |
+
|
45 |
+
- name: Lint with flake8
|
46 |
+
run: |
|
47 |
+
pip install flake8
|
48 |
+
# stop the build if there are Python syntax errors or undefined names
|
49 |
+
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
|
50 |
+
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
51 |
+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
|
52 |
+
|
53 |
+
- name: Run tests
|
54 |
+
run: |
|
55 |
+
python setup.py test
|
56 |
+
|
57 |
+
- name: Lint with black
|
58 |
+
run: |
|
59 |
+
pip install black
|
60 |
+
black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron'
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/workflows/build_wheels.yml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: build_wheels
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- v[0-9]+.[0-9]+.[x0-9]+
|
7 |
+
tags:
|
8 |
+
- v*
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
build_wheels:
|
12 |
+
name: Build wheels on ${{ matrix.os }}
|
13 |
+
runs-on: ${{ matrix.os }}
|
14 |
+
strategy:
|
15 |
+
matrix:
|
16 |
+
os: [ubuntu-latest, macos-latest]
|
17 |
+
|
18 |
+
steps:
|
19 |
+
- uses: actions/checkout@v2
|
20 |
+
|
21 |
+
- name: Install Python
|
22 |
+
uses: actions/setup-python@v2
|
23 |
+
with:
|
24 |
+
python-version: '3.7'
|
25 |
+
|
26 |
+
- name: Install cibuildwheel
|
27 |
+
run: |
|
28 |
+
python -m pip install cibuildwheel
|
29 |
+
|
30 |
+
- name: Build wheels for CPython
|
31 |
+
run: |
|
32 |
+
python -m cibuildwheel --output-dir dist
|
33 |
+
env:
|
34 |
+
CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
|
35 |
+
CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
|
36 |
+
CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
|
37 |
+
|
38 |
+
- uses: actions/upload-artifact@v2
|
39 |
+
with:
|
40 |
+
name: wheels
|
41 |
+
path: ./dist/*.whl
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.gitignore
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# JetBrains PyCharm IDE
|
2 |
+
.idea/
|
3 |
+
|
4 |
+
# Byte-compiled / optimized / DLL files
|
5 |
+
__pycache__/
|
6 |
+
*.py[cod]
|
7 |
+
*$py.class
|
8 |
+
|
9 |
+
# C extensions
|
10 |
+
*.so
|
11 |
+
|
12 |
+
# macOS dir files
|
13 |
+
.DS_Store
|
14 |
+
|
15 |
+
# Distribution / packaging
|
16 |
+
.Python
|
17 |
+
env/
|
18 |
+
build/
|
19 |
+
develop-eggs/
|
20 |
+
dist/
|
21 |
+
downloads/
|
22 |
+
eggs/
|
23 |
+
.eggs/
|
24 |
+
lib/
|
25 |
+
lib64/
|
26 |
+
parts/
|
27 |
+
sdist/
|
28 |
+
var/
|
29 |
+
wheels/
|
30 |
+
*.egg-info/
|
31 |
+
.installed.cfg
|
32 |
+
*.egg
|
33 |
+
|
34 |
+
# Checkpoints
|
35 |
+
checkpoints
|
36 |
+
|
37 |
+
# PyInstaller
|
38 |
+
# Usually these files are written by a python script from a template
|
39 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
40 |
+
*.manifest
|
41 |
+
*.spec
|
42 |
+
|
43 |
+
# Installer logs
|
44 |
+
pip-log.txt
|
45 |
+
pip-delete-this-directory.txt
|
46 |
+
|
47 |
+
# Unit test / coverage reports
|
48 |
+
htmlcov/
|
49 |
+
.tox/
|
50 |
+
.coverage
|
51 |
+
.coverage.*
|
52 |
+
.cache
|
53 |
+
nosetests.xml
|
54 |
+
coverage.xml
|
55 |
+
*.cover
|
56 |
+
.hypothesis/
|
57 |
+
|
58 |
+
# Translations
|
59 |
+
*.mo
|
60 |
+
*.pot
|
61 |
+
|
62 |
+
# Django stuff:
|
63 |
+
*.log
|
64 |
+
local_settings.py
|
65 |
+
|
66 |
+
# Flask stuff:
|
67 |
+
instance/
|
68 |
+
.webassets-cache
|
69 |
+
|
70 |
+
# Scrapy stuff:
|
71 |
+
.scrapy
|
72 |
+
|
73 |
+
# Sphinx documentation
|
74 |
+
docs/_build/
|
75 |
+
|
76 |
+
# PyBuilder
|
77 |
+
target/
|
78 |
+
|
79 |
+
# Jupyter Notebook
|
80 |
+
.ipynb_checkpoints
|
81 |
+
|
82 |
+
# pyenv
|
83 |
+
.python-version
|
84 |
+
|
85 |
+
# celery beat schedule file
|
86 |
+
celerybeat-schedule
|
87 |
+
|
88 |
+
# SageMath parsed files
|
89 |
+
*.sage.py
|
90 |
+
|
91 |
+
# dotenv
|
92 |
+
.env
|
93 |
+
|
94 |
+
# virtualenv
|
95 |
+
.venv
|
96 |
+
venv/
|
97 |
+
ENV/
|
98 |
+
|
99 |
+
# Spyder project settings
|
100 |
+
.spyderproject
|
101 |
+
.spyproject
|
102 |
+
|
103 |
+
# Rope project settings
|
104 |
+
.ropeproject
|
105 |
+
|
106 |
+
# mkdocs documentation
|
107 |
+
/site
|
108 |
+
|
109 |
+
# mypy
|
110 |
+
.mypy_cache/
|
111 |
+
|
112 |
+
# Generated files
|
113 |
+
/fairseq/temporal_convolution_tbc
|
114 |
+
/fairseq/modules/*_layer/*_forward.cu
|
115 |
+
/fairseq/modules/*_layer/*_backward.cu
|
116 |
+
/fairseq/version.py
|
117 |
+
|
118 |
+
# data
|
119 |
+
data-bin/
|
120 |
+
|
121 |
+
# reranking
|
122 |
+
/examples/reranking/rerank_data
|
123 |
+
|
124 |
+
# Cython-generated C++ source files
|
125 |
+
/fairseq/data/data_utils_fast.cpp
|
126 |
+
/fairseq/data/token_block_utils_fast.cpp
|
127 |
+
|
128 |
+
# VSCODE
|
129 |
+
.vscode/ftp-sync.json
|
130 |
+
.vscode/settings.json
|
131 |
+
|
132 |
+
# Experimental Folder
|
133 |
+
experimental/*
|
134 |
+
|
135 |
+
# Weights and Biases logs
|
136 |
+
wandb/
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.gitmodules
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "fairseq/model_parallel/megatron"]
|
2 |
+
path = fairseq/model_parallel/megatron
|
3 |
+
url = https://github.com/ngoyal2707/Megatron-LM
|
4 |
+
branch = fairseq
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.isort.cfg
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[settings]
|
2 |
+
known_third_party = _cffi_backend,agg_results,aml,bitarray,boto3,botocore,dump_hubert_feature,dynamicconv_cuda,editdistance,faiss,fasttext,feature_utils,ffmpeg,g2p_en,h5py,hydra,hypothesis,indicnlp,inflect,iopath,joblib,kaldi_io,kenlm,libfb,librosa,lightconv_cuda,matplotlib,misc,mmpt,mmpt_cli,model,nltk,npy_append_array,numpy,omegaconf,pandas,pathbuilder,preprocessing,progressbar,pythainlp,random_sequence_shuffler,regex,sacrebleu,sacremoses,scipy,sentencepiece,setuptools,six,sklearn,soundfile,sweep,sweep_wmt_en2de_transformer_big_common,tabulate,torch,torchaudio,tqdm,unidecode,utils,videoreader,wav2vec_cluster_faiss,wget,yaml
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
exclude: 'build|stubs'
|
2 |
+
|
3 |
+
default_language_version:
|
4 |
+
python: python3
|
5 |
+
|
6 |
+
repos:
|
7 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
8 |
+
rev: v4.0.1
|
9 |
+
hooks:
|
10 |
+
- id: trailing-whitespace
|
11 |
+
- id: check-ast
|
12 |
+
- id: check-merge-conflict
|
13 |
+
- id: no-commit-to-branch
|
14 |
+
args: ['--branch=master']
|
15 |
+
- id: check-added-large-files
|
16 |
+
args: ['--maxkb=500']
|
17 |
+
- id: end-of-file-fixer
|
18 |
+
|
19 |
+
- repo: https://github.com/ambv/black
|
20 |
+
rev: 21.12b0
|
21 |
+
hooks:
|
22 |
+
- id: black
|
23 |
+
language_version: python3.8
|
24 |
+
|
25 |
+
- repo: https://gitlab.com/pycqa/flake8
|
26 |
+
rev: 3.9.2
|
27 |
+
hooks:
|
28 |
+
- id: flake8
|
29 |
+
args: [
|
30 |
+
# only error for syntax errors and undefined names
|
31 |
+
"--select=E9,F63,F7,F82",
|
32 |
+
]
|
33 |
+
|
34 |
+
- repo: https://github.com/pycqa/isort
|
35 |
+
rev: 5.10.1
|
36 |
+
hooks:
|
37 |
+
- id: isort
|
38 |
+
exclude: README.md
|
39 |
+
additional_dependencies: [toml]
|
40 |
+
args: ["--profile", "black"]
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code of Conduct
|
2 |
+
|
3 |
+
## Our Pledge
|
4 |
+
|
5 |
+
In the interest of fostering an open and welcoming environment, we as
|
6 |
+
contributors and maintainers pledge to make participation in our project and
|
7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
9 |
+
level of experience, education, socio-economic status, nationality, personal
|
10 |
+
appearance, race, religion, or sexual identity and orientation.
|
11 |
+
|
12 |
+
## Our Standards
|
13 |
+
|
14 |
+
Examples of behavior that contributes to creating a positive environment
|
15 |
+
include:
|
16 |
+
|
17 |
+
* Using welcoming and inclusive language
|
18 |
+
* Being respectful of differing viewpoints and experiences
|
19 |
+
* Gracefully accepting constructive criticism
|
20 |
+
* Focusing on what is best for the community
|
21 |
+
* Showing empathy towards other community members
|
22 |
+
|
23 |
+
Examples of unacceptable behavior by participants include:
|
24 |
+
|
25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
26 |
+
advances
|
27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
28 |
+
* Public or private harassment
|
29 |
+
* Publishing others' private information, such as a physical or electronic
|
30 |
+
address, without explicit permission
|
31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
32 |
+
professional setting
|
33 |
+
|
34 |
+
## Our Responsibilities
|
35 |
+
|
36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
38 |
+
response to any instances of unacceptable behavior.
|
39 |
+
|
40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
44 |
+
threatening, offensive, or harmful.
|
45 |
+
|
46 |
+
## Scope
|
47 |
+
|
48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
49 |
+
an individual is representing the project or its community in public spaces.
|
50 |
+
Examples of representing a project or community include using an official
|
51 |
+
project e-mail address, posting via an official social media account, or acting
|
52 |
+
as an appointed representative at an online or offline event. Representation of
|
53 |
+
a project may be further defined and clarified by project maintainers.
|
54 |
+
|
55 |
+
## Enforcement
|
56 |
+
|
57 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
58 |
+
reported by contacting the project team at <conduct@pytorch.org>. All
|
59 |
+
complaints will be reviewed and investigated and will result in a response that
|
60 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
61 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
62 |
+
Further details of specific enforcement policies may be posted separately.
|
63 |
+
|
64 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
65 |
+
faith may face temporary or permanent repercussions as determined by other
|
66 |
+
members of the project's leadership.
|
67 |
+
|
68 |
+
## Attribution
|
69 |
+
|
70 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
71 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
72 |
+
|
73 |
+
[homepage]: https://www.contributor-covenant.org
|
74 |
+
|
75 |
+
For answers to common questions about this code of conduct, see
|
76 |
+
https://www.contributor-covenant.org/faq
|
77 |
+
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/CONTRIBUTING.md
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
|
2 |
+
We want to make contributing to this project as easy and transparent as
|
3 |
+
possible.
|
4 |
+
|
5 |
+
## Pull Requests
|
6 |
+
We actively welcome your pull requests.
|
7 |
+
|
8 |
+
1. Fork the repo and create your branch from `main`.
|
9 |
+
2. If you've added code that should be tested, add tests.
|
10 |
+
3. If you've changed APIs, update the documentation.
|
11 |
+
4. Ensure the test suite passes.
|
12 |
+
5. Make sure your code lints.
|
13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
14 |
+
|
15 |
+
## Contributor License Agreement ("CLA")
|
16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
17 |
+
to do this once to work on any of Facebook's open source projects.
|
18 |
+
|
19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
20 |
+
|
21 |
+
## Issues
|
22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
24 |
+
|
25 |
+
## License
|
26 |
+
By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
|
27 |
+
you agree that your contributions will be licensed under the LICENSE file in
|
28 |
+
the root directory of this source tree.
|
29 |
+
|
30 |
+
## Pre-commit hooks
|
31 |
+
In order to ensure your code lints, there are pre-commit hooks configured in the repository which you can install.
|
32 |
+
After installation, they will automatically run each time you commit.
|
33 |
+
An abbreviated guide is given below; for more information, refer to [the offical pre-commit documentation](https://pre-commit.com/).
|
34 |
+
|
35 |
+
### Installation
|
36 |
+
```
|
37 |
+
pip install pre-commit
|
38 |
+
pre-commit install
|
39 |
+
```
|
40 |
+
|
41 |
+
### Usage
|
42 |
+
Just commit your changes:
|
43 |
+
```
|
44 |
+
git commit -m "My informative commit message"
|
45 |
+
```
|
46 |
+
|
47 |
+
If there was a failure, you will get feedback
|
48 |
+
```
|
49 |
+
[INFO] Initializing environment for https://github.com/PyCQA/flake8.
|
50 |
+
[INFO] Installing environment for https://github.com/pre-commit/pre-commit-hooks.
|
51 |
+
[INFO] Once installed this environment will be reused.
|
52 |
+
[INFO] This may take a few minutes...
|
53 |
+
[INFO] Installing environment for https://github.com/PyCQA/flake8.
|
54 |
+
[INFO] Once installed this environment will be reused.
|
55 |
+
[INFO] This may take a few minutes...
|
56 |
+
Trim Trailing Whitespace.................................................Failed
|
57 |
+
- hook id: trailing-whitespace
|
58 |
+
- exit code: 1
|
59 |
+
- files were modified by this hook
|
60 |
+
Fixing examples/nllb/modeling/wmt15_benchmark/eval_langs2.sh
|
61 |
+
Fix End of Files.........................................................Failed
|
62 |
+
- hook id: end-of-file-fixer
|
63 |
+
- exit code: 1
|
64 |
+
- files were modified by this hook
|
65 |
+
Fixing examples/few_shot/scripts/schedule_jobs_few_shot.py
|
66 |
+
flake8...................................................................Passed
|
67 |
+
```
|
68 |
+
|
69 |
+
Certain hooks modify your files to comply.
|
70 |
+
To include these modifications, you will need to add them (i.e. `git add ...`) and commit again.
|
71 |
+
|
72 |
+
If all is well, you should see something like:
|
73 |
+
```
|
74 |
+
Trim Trailing Whitespace.................................................Passed
|
75 |
+
Fix End of Files.........................................................Passed
|
76 |
+
flake8...................................................................Passed
|
77 |
+
[gshard-fix-ci 8698644e1] Fix lint, add pre-commit hooks
|
78 |
+
10 files changed, 148 insertions(+), 110 deletions(-)
|
79 |
+
create mode 100644 .flake8
|
80 |
+
create mode 100644 .pre-commit-config.yaml
|
81 |
+
rename examples/nllb/modeling/wmt15_benchmark/{eval_langs2.py => eval_langs2.sh} (99%)
|
82 |
+
```
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/README.md
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<img src="docs/fairseq_logo.png" width="150">
|
3 |
+
<br />
|
4 |
+
<br />
|
5 |
+
<a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
|
6 |
+
<a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
|
7 |
+
<a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
|
8 |
+
<a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
|
9 |
+
</p>
|
10 |
+
|
11 |
+
--------------------------------------------------------------------------------
|
12 |
+
|
13 |
+
Fairseq(-py) is a sequence modeling toolkit that allows researchers and
|
14 |
+
developers to train custom models for translation, summarization, language
|
15 |
+
modeling and other text generation tasks.
|
16 |
+
|
17 |
+
We provide reference implementations of various sequence modeling papers:
|
18 |
+
|
19 |
+
<details><summary>List of implemented papers</summary><p>
|
20 |
+
|
21 |
+
* **Convolutional Neural Networks (CNN)**
|
22 |
+
+ [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
|
23 |
+
+ [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
|
24 |
+
+ [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
|
25 |
+
+ [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
|
26 |
+
+ [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
|
27 |
+
* **LightConv and DynamicConv models**
|
28 |
+
+ [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
|
29 |
+
* **Long Short-Term Memory (LSTM) networks**
|
30 |
+
+ Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
|
31 |
+
* **Transformer (self-attention) networks**
|
32 |
+
+ Attention Is All You Need (Vaswani et al., 2017)
|
33 |
+
+ [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
|
34 |
+
+ [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
|
35 |
+
+ [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
|
36 |
+
+ [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
|
37 |
+
+ [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
|
38 |
+
+ [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
|
39 |
+
+ [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
|
40 |
+
+ [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
|
41 |
+
+ [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
|
42 |
+
+ [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
|
43 |
+
+ [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
|
44 |
+
+ [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
|
45 |
+
+ [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
|
46 |
+
+ [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
|
47 |
+
+ [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
|
48 |
+
+ [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
|
49 |
+
+ [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
|
50 |
+
+ [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
|
51 |
+
+ [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
|
52 |
+
+ [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430)
|
53 |
+
+ [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
|
54 |
+
+ [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
|
55 |
+
+ [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680)
|
56 |
+
+ [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et. al., 2021)](https://arxiv.org/pdf/2109.14084.pdf)
|
57 |
+
+ [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et. al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf)
|
58 |
+
+ [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et. al, 2021)](examples/normformer/README.md)
|
59 |
+
* **Non-autoregressive Transformers**
|
60 |
+
+ Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
|
61 |
+
+ Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
|
62 |
+
+ Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
|
63 |
+
+ Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
|
64 |
+
+ [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
|
65 |
+
* **Finetuning**
|
66 |
+
+ [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
|
67 |
+
|
68 |
+
</p></details>
|
69 |
+
|
70 |
+
### What's New:
|
71 |
+
* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md)
|
72 |
+
* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md)
|
73 |
+
* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
|
74 |
+
* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
|
75 |
+
* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
|
76 |
+
* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
|
77 |
+
* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
|
78 |
+
* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
|
79 |
+
* February 2021 [Added LASER training code](examples/laser/README.md)
|
80 |
+
* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
|
81 |
+
* December 2020: [GottBERT model and code released](examples/gottbert/README.md)
|
82 |
+
* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
|
83 |
+
* [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
|
84 |
+
* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
|
85 |
+
* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
|
86 |
+
* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
|
87 |
+
* October 2020: [Added CRISS models and code](examples/criss/README.md)
|
88 |
+
|
89 |
+
<details><summary>Previous updates</summary><p>
|
90 |
+
|
91 |
+
* September 2020: [Added Linformer code](examples/linformer/README.md)
|
92 |
+
* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
|
93 |
+
* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
|
94 |
+
* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
|
95 |
+
* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
|
96 |
+
* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
|
97 |
+
* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
|
98 |
+
* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
|
99 |
+
* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
|
100 |
+
* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
|
101 |
+
* February 2020: [mBART model and code released](examples/mbart/README.md)
|
102 |
+
* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
|
103 |
+
* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
|
104 |
+
* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
|
105 |
+
* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
|
106 |
+
* November 2019: [BART model and code released](examples/bart/README.md)
|
107 |
+
* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
|
108 |
+
* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
|
109 |
+
* August 2019: [WMT'19 models released](examples/wmt19/README.md)
|
110 |
+
* July 2019: fairseq relicensed under MIT license
|
111 |
+
* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
|
112 |
+
* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
|
113 |
+
|
114 |
+
</p></details>
|
115 |
+
|
116 |
+
### Features:
|
117 |
+
|
118 |
+
* multi-GPU training on one machine or across multiple machines (data and model parallel)
|
119 |
+
* fast generation on both CPU and GPU with multiple search algorithms implemented:
|
120 |
+
+ beam search
|
121 |
+
+ Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
|
122 |
+
+ sampling (unconstrained, top-k and top-p/nucleus)
|
123 |
+
+ [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
|
124 |
+
* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
|
125 |
+
* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
|
126 |
+
* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
|
127 |
+
* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
|
128 |
+
* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
|
129 |
+
* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
|
130 |
+
|
131 |
+
We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
|
132 |
+
with a convenient `torch.hub` interface:
|
133 |
+
|
134 |
+
``` python
|
135 |
+
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
|
136 |
+
en2de.translate('Hello world', beam=5)
|
137 |
+
# 'Hallo Welt'
|
138 |
+
```
|
139 |
+
|
140 |
+
See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
|
141 |
+
and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
|
142 |
+
|
143 |
+
# Requirements and Installation
|
144 |
+
|
145 |
+
* [PyTorch](http://pytorch.org/) version >= 1.5.0
|
146 |
+
* Python version >= 3.6
|
147 |
+
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
|
148 |
+
* **To install fairseq** and develop locally:
|
149 |
+
|
150 |
+
``` bash
|
151 |
+
git clone https://github.com/pytorch/fairseq
|
152 |
+
cd fairseq
|
153 |
+
pip install --editable ./
|
154 |
+
|
155 |
+
# on MacOS:
|
156 |
+
# CFLAGS="-stdlib=libc++" pip install --editable ./
|
157 |
+
|
158 |
+
# to install the latest stable release (0.10.x)
|
159 |
+
# pip install fairseq
|
160 |
+
```
|
161 |
+
|
162 |
+
* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
|
163 |
+
|
164 |
+
``` bash
|
165 |
+
git clone https://github.com/NVIDIA/apex
|
166 |
+
cd apex
|
167 |
+
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
|
168 |
+
--global-option="--deprecated_fused_adam" --global-option="--xentropy" \
|
169 |
+
--global-option="--fast_multihead_attn" ./
|
170 |
+
```
|
171 |
+
|
172 |
+
* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
|
173 |
+
* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
|
174 |
+
as command line options to `nvidia-docker run` .
|
175 |
+
|
176 |
+
# Getting Started
|
177 |
+
|
178 |
+
The [full documentation](https://fairseq.readthedocs.io/) contains instructions
|
179 |
+
for getting started, training new models and extending fairseq with new model
|
180 |
+
types and tasks.
|
181 |
+
|
182 |
+
# Pre-trained models and examples
|
183 |
+
|
184 |
+
We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
|
185 |
+
as well as example training and evaluation commands.
|
186 |
+
|
187 |
+
* [Translation](examples/translation/README.md): convolutional and transformer models are available
|
188 |
+
* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
|
189 |
+
|
190 |
+
We also have more detailed READMEs to reproduce results from specific papers:
|
191 |
+
|
192 |
+
* [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md)
|
193 |
+
* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
|
194 |
+
* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
|
195 |
+
* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
|
196 |
+
* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
|
197 |
+
* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
|
198 |
+
* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
|
199 |
+
* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
|
200 |
+
* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
|
201 |
+
* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
|
202 |
+
* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
|
203 |
+
* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
|
204 |
+
* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
|
205 |
+
* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
|
206 |
+
* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
|
207 |
+
* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
|
208 |
+
* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
|
209 |
+
* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
|
210 |
+
* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
|
211 |
+
* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
|
212 |
+
* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
|
213 |
+
|
214 |
+
# Join the fairseq community
|
215 |
+
|
216 |
+
* Twitter: https://twitter.com/fairseq
|
217 |
+
* Facebook page: https://www.facebook.com/groups/fairseq.users
|
218 |
+
* Google group: https://groups.google.com/forum/#!forum/fairseq-users
|
219 |
+
|
220 |
+
# License
|
221 |
+
|
222 |
+
fairseq(-py) is MIT-licensed.
|
223 |
+
The license applies to the pre-trained models as well.
|
224 |
+
|
225 |
+
# Citation
|
226 |
+
|
227 |
+
Please cite as:
|
228 |
+
|
229 |
+
``` bibtex
|
230 |
+
@inproceedings{ott2019fairseq,
|
231 |
+
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
|
232 |
+
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
|
233 |
+
booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
|
234 |
+
year = {2019},
|
235 |
+
}
|
236 |
+
```
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
!*/*.sh
|
2 |
+
!*/*.md
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/.gitignore
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
runs
|
131 |
+
data
|
132 |
+
pretrained_models
|
133 |
+
projects/mmfusion_*
|
134 |
+
log_test
|
135 |
+
third-party
|
136 |
+
python_log
|
137 |
+
slurm_snapshot_code
|
138 |
+
lightning_logs
|
139 |
+
demos
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/CONFIG.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Config Files Explained
|
2 |
+
|
3 |
+
Taking `projects/mfmmlm.yaml` for example, which run pretraining using masked frame model (MFM) and masked language model (MLM) on a single BERT:
|
4 |
+
|
5 |
+
```yaml
|
6 |
+
project_dir: mfmmlm # specify the project dir for this baseline.
|
7 |
+
run_task:
|
8 |
+
- how2.yaml # run pretraining on how2 when launching `projects/taskmfmmlm.yaml`
|
9 |
+
- [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml] # run fine-tuning tasks.
|
10 |
+
base_dir: task # a global template folder to specify each training task.
|
11 |
+
task_group:
|
12 |
+
pretrain: # section for pretraining. Most baselines differs in this section.
|
13 |
+
task_list:
|
14 |
+
- how2.yaml # reconfig `projects/task/how2.yaml`
|
15 |
+
dataset:
|
16 |
+
aligner: MFMMLMAligner # overwrite the aligner for MFMMLM training task.
|
17 |
+
model:
|
18 |
+
model_cls: MMFusionMFMMLM # overwrite the model, which constructs negative examples for MFM on-the-fly.
|
19 |
+
loss:
|
20 |
+
loss_cls: MFMMLM # overwrite the loss as MFMMLM, which combines MFM and MLM together.
|
21 |
+
fairseq: # all fairseq args can be expecified under this name.
|
22 |
+
dataset:
|
23 |
+
batch_size: 128
|
24 |
+
finetune: # section for fine-tuning tasks, we don't need to change anything here mostly since we want to see how pretraining can contribute to finetuning.
|
25 |
+
task_list: # specify the list of downstream tasks, e.g., copy `projects/task/vtt.yaml` to `projects/mfmmlm`.
|
26 |
+
- vtt.yaml
|
27 |
+
- vttqa.yaml
|
28 |
+
- youcook.yaml
|
29 |
+
- youcookcap.yaml
|
30 |
+
- crosstask.yaml
|
31 |
+
- coin.yaml
|
32 |
+
test: # section for testing.
|
33 |
+
task_list:
|
34 |
+
- test_vtt.yaml
|
35 |
+
- test_vttqa.yaml
|
36 |
+
- test_youcook.yaml
|
37 |
+
- test_youcookcap.yaml
|
38 |
+
- test_crosstask.yaml
|
39 |
+
- test_crosstask_zs.yaml
|
40 |
+
- test_coin.yaml
|
41 |
+
```
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/DATASET.md
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Dataset
|
2 |
+
|
3 |
+
We understand video data are challenging to download and process. For videos, we provide our preprocessing scripts under `scripts/video_feature_extractor` (deeply adapted from `https://github.com/antoine77340/video_feature_extractor`); for text, we pre-tokenizing scripts under `scripts/text_token_extractor`.
|
4 |
+
|
5 |
+
### S3D Feature Extraction
|
6 |
+
We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
|
7 |
+
|
8 |
+
We implement a `PathBuilder` to automatically track video ids, source video paths to their feature locations (you may need `conda install -c anaconda pandas`). Decoding may need `pip install ffmpeg-python`.
|
9 |
+
|
10 |
+
### Howto100M
|
11 |
+
[Howto100M](https://www.di.ens.fr/willow/research/howto100m/) is a large-scale video pre-training datasets. You may download videos by yourself and run preprocessing of our scripts.
|
12 |
+
|
13 |
+
Several key differences of our preprocessing from existing papers: (1) we use `raw_caption.json` instead of `caption.json` to have pure self-supervision on text (`caption.json` has manual removal of stop words); (2) we remove partially duplicated texts that are originally designed for real-time readability (see `mmpt/processors/dedupprocessor.py`); (3) then we shard video/text features using `SharedTensor` in `mmpt/utils/shardedtensor.py` for fast loading during training (faster than `h5py`).
|
14 |
+
|
15 |
+
#### Steps
|
16 |
+
##### video
|
17 |
+
To extract video features: edit and run `bash scripts/video_feature_extractor/how2/s3d.sh`. (consider to run this on multiple machines; by default, we store features in fp16 to save space and also for faster training).
|
18 |
+
|
19 |
+
Split available video ids as `data/how2/how2_s3d_train.lst` and `data/how2/how2_s3d_val.lst`.
|
20 |
+
|
21 |
+
Lastly, pack video features into `ShardedTensor` using `python scripts/video_feature_extractor/shard_feature.py`.
|
22 |
+
|
23 |
+
##### text
|
24 |
+
Clean captions using `python -m mmpt.processors.dedupprocessor`.
|
25 |
+
|
26 |
+
Tokenize dedupped captions `data/how2/raw_caption_dedup.pkl` into sharded numpy arrays:
|
27 |
+
```
|
28 |
+
python scripts/text_token_extractor/pretokenization.py scripts/text_token_extractor/configs/bert-base-uncased.yaml
|
29 |
+
```
|
30 |
+
|
31 |
+
### Youcook, MSRVTT etc.
|
32 |
+
We use the version of Youcook and MSRVTT come with Howto100M and MILNCE. Please download the data to `data/youcook` and `data/msrvtt` accordingly, you can also check `projects/task/youcook.yaml` and `projects/task/vtt.yaml` etc. in details.
|
33 |
+
We extract features for Youcook, MSRVTT similar to the first step of Howto100M but we read text from meta data directly and perform on-the-fly tokenization.
|
34 |
+
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/README.md
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VideoCLIP and VLM
|
2 |
+
|
3 |
+
You just find this toolkit for multimodal video understanding! It contains implementation of two recent multi-modal video understanding papers [VideoCLIP](https://arxiv.org/pdf/2109.14084.pdf) (EMNLP, 2021) and [VLM](https://aclanthology.org/2021.findings-acl.370.pdf) (ACL Findings, 2021), along with high-performance toolkits that are typically lacking in existing codebase. The toolkit is desigend to contain generic performance-tuned components that can be potentially adapted to other frameworks (we initially use fairseq).
|
4 |
+
|
5 |
+
VideoCLIP is a contrastive learning model for zero-shot transfer to retrieval/classification/sequence labeling style tasks.
|
6 |
+
|
7 |
+
<img src="videoclip.png" width="350" class="center">
|
8 |
+
|
9 |
+
VLM is a masked language model style pre-training using only one encoder with masked modality model (MMM) for retrieval/generation/sequence labeling style tasks.
|
10 |
+
|
11 |
+
<img src="vlm.png" width="350" class="center">
|
12 |
+
|
13 |
+
### News
|
14 |
+
[Oct. 2021] Initial release of implementation for the following papers:
|
15 |
+
[VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding](https://arxiv.org/pdf/2109.14084.pdf) (Xu et. al., EMNLP 2021)
|
16 |
+
[VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding](https://aclanthology.org/2021.findings-acl.370.pdf) (Xu et. al., ACL Findings 2021)
|
17 |
+
|
18 |
+
|
19 |
+
### Installation
|
20 |
+
We aim to minimize the dependency of this repo on other packages.
|
21 |
+
We use fairseq as the main trainer (no models/datasets dependency on fairseq. We will support other trainer in future):
|
22 |
+
```
|
23 |
+
git clone https://github.com/pytorch/fairseq
|
24 |
+
cd fairseq
|
25 |
+
pip install -e . # also optionally follow fairseq README for apex installation for fp16 training.
|
26 |
+
export MKL_THREADING_LAYER=GNU # fairseq may need this for numpy.
|
27 |
+
```
|
28 |
+
|
29 |
+
Then install this toolkit:
|
30 |
+
```
|
31 |
+
cd examples/MMPT # MMPT can be in any folder, not necessarily under fairseq/examples.
|
32 |
+
pip install -e .
|
33 |
+
```
|
34 |
+
|
35 |
+
The code is developed under Python=3.8.8, Pytorch=1.8, cuda=11.0 with fairseq=1.0.0a0+af0389f and tested under Python=3.8.8 pytorch=1.9 cuda=11.0 fairseq=1.0.0a0+8e7bc73 during code release.
|
36 |
+
Most models require `transformers==3.4` for API compatibility `pip install transformers==3.4`.
|
37 |
+
In addition, some downstream tasks may need `conda install pandas`.
|
38 |
+
|
39 |
+
|
40 |
+
### Usage
|
41 |
+
#### Download Checkpoints
|
42 |
+
We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
|
43 |
+
|
44 |
+
Download VideoCLIP checkpoint `https://dl.fbaipublicfiles.com/MMPT/retri/videoclip/checkpoint_best.pt` to `runs/retri/videoclip` or VLM checkpoint `https://dl.fbaipublicfiles.com/MMPT/mtm/vlm/checkpoint_best.pt` to `runs/mtm/vlm`.
|
45 |
+
|
46 |
+
#### Demo of Inference
|
47 |
+
run `python locallaunch.py projects/retri/videoclip.yaml --dryrun` to get all `.yaml`s for VideoCLIP.
|
48 |
+
|
49 |
+
```python
|
50 |
+
import torch
|
51 |
+
|
52 |
+
from mmpt.models import MMPTModel
|
53 |
+
|
54 |
+
|
55 |
+
model, tokenizer, aligner = MMPTModel.from_pretrained(
|
56 |
+
"projects/retri/videoclip/how2.yaml")
|
57 |
+
|
58 |
+
model.eval()
|
59 |
+
|
60 |
+
|
61 |
+
# B, T, FPS, H, W, C (VideoCLIP is trained on 30 fps of s3d)
|
62 |
+
video_frames = torch.randn(1, 2, 30, 224, 224, 3)
|
63 |
+
caps, cmasks = aligner._build_text_seq(
|
64 |
+
tokenizer("some text", add_special_tokens=False)["input_ids"]
|
65 |
+
)
|
66 |
+
|
67 |
+
caps, cmasks = caps[None, :], cmasks[None, :] # bsz=1
|
68 |
+
|
69 |
+
with torch.no_grad():
|
70 |
+
output = model(video_frames, caps, cmasks, return_score=True)
|
71 |
+
print(output["score"]) # dot-product
|
72 |
+
```
|
73 |
+
|
74 |
+
#### Data Preparation
|
75 |
+
See [dataset](DATASET.md) for each dataset.
|
76 |
+
|
77 |
+
#### Global Config for Training Pipeline
|
78 |
+
We organize a global config file for a training/testing pipeline under projects (see a detailed [explanation](CONFIG.md)). For example, VideoCLIP in `projects/retri/videoclip.yaml` and VLM is in `projects/mtm/vlm.yaml`.
|
79 |
+
|
80 |
+
We wrap all cmds into `locallaunch.py` and `mmpt_cli/localjob.py`. You can check concrete cmds by `--dryrun` and then drop it for actual run.
|
81 |
+
|
82 |
+
First, run `python locallaunch.py projects/retri/videoclip.yaml --dryrun` will generate configs for all configs of pre-training, zero-shot evaluation, fine-tuning and testing, for VideoCLIP under `projects/retri/videoclip`.
|
83 |
+
|
84 |
+
Then each (either training or evaluation) process will be configed by a concrete config file (we save all complex arguments into the concrete config file for reproducibility, including fairseq args). For example, run zero-shot evaluation on youcook,
|
85 |
+
```
|
86 |
+
python locallaunch.py projects/retri/videoclip/test_youcook_zs.yaml --jobtype local_predict # zero-shot evaluation.
|
87 |
+
python locallaunch.py projects/retri/videoclip/youcook_videoclip.yaml --jobtype local_single --dryrun # fine-tuning: use --dryrun to check cmds and drop it to make an actual run; local_small will run on two gpus (as in paper).
|
88 |
+
python locallaunch.py projects/retri/videoclip/test_youcook_videoclip.yaml --jobtype local_predict # testing on fine-tuned model.
|
89 |
+
```
|
90 |
+
|
91 |
+
Pretraining can be run as:
|
92 |
+
```
|
93 |
+
python locallaunch.py projects/retri/videoclip/how2.yaml --jobtype local_single --dryrun # check then drop dryrun; paper is ran on local_big as 8 gpus.
|
94 |
+
```
|
95 |
+
You may need to change `--jobtype`, check/extend `LocalJob` in `mmpt_cli/localjob.py` for multi-gpu/multi-node pre-training.
|
96 |
+
|
97 |
+
The detailed instructions of pretraining and fine-tuning can be found at [pretraining instruction](pretraining.md) and [finetuning instruction](endtask.md).
|
98 |
+
|
99 |
+
|
100 |
+
### Development
|
101 |
+
Several components of this toolkit can be re-used for future research (and also our ongoing research).
|
102 |
+
|
103 |
+
#### Framework Wrapper
|
104 |
+
We currently only support fairseq, but most components can be easily fit into other frameworks like huggingface. This repo is a `--user-dir` of fairseq with fairseq wrapper. For example, `mmpt/tasks` includes a `FairseqMMTTask`, which manages `mmpt/datasets` with `FairseqDataset`, `mmpt/models` with `FairseqModel`, `mmpt/losses` with `FairseqCriterion`.
|
105 |
+
|
106 |
+
#### Processors
|
107 |
+
**Multi**modal research introduces the complexity on modality alignment from different input sources to losses. Inspired by [MMF](https://github.com/facebookresearch/mmf), this toolkit leverages `mmpt/processors` to handle various needs of data preprocessing and loading, **alleviating** the needs of multiple `torch.data.utils.Dataset` (that can be tricky for ablation study).
|
108 |
+
Processors can also be decoupled from `torch.data.utils.Dataset` for offline preprocessing instead of on-the-fly data preprocessing.
|
109 |
+
|
110 |
+
We decouple a `mmpt.MMDataset` as 3 types of processors: `MetaProcessor`, `VideoProcessor`, `TextProcessor` and `Aligner`. They can be configed in `dataset` field of a config file (e.g., see `projects/task/how2.yaml`).
|
111 |
+
`MetaProcessor` is used to load the meta data about a dataset, aka, all video_ids of how2 dataset.
|
112 |
+
`VideoProcessor` is used to load the video features about a dataset. For example, S3D features for each second of a video.
|
113 |
+
`TextProcessor` is used to load the text (feature). For example, BERT pre-tokenized text clips for how2 dataset (with `start`s, `end`s of timestamps and `cap` for `token_ids`).
|
114 |
+
`Aligner` is the core class for different baselines that prepares the training data. For example, sampling a clip, masking tokens for MLM, etc.
|
115 |
+
|
116 |
+
#### Performance-tuned Components
|
117 |
+
To speed up pre-training, this toolkit uses sharded features stored in mmaped numpy, backed by `ShardedTensor` in `mmpt/utils/shardedtensor.py` (adopted from MARGE paper). This reduces the loads of IO for multi-GPU training without loading all features for a video into the memory each time and `ShardedTensor` ensure features are stored in continuous disk space for near random access. This is used for both How2 video features and texts in `mmpt/processors/how2processor.py`.
|
118 |
+
|
119 |
+
|
120 |
+
### Citation
|
121 |
+
If this codebase is useful for your work, please cite the following papers:
|
122 |
+
|
123 |
+
```BibTeX
|
124 |
+
@inproceedings{xu-etal-2021-videoclip,
|
125 |
+
title = "{VideoCLIP}: Contrastive Pre-training for\\Zero-shot Video-Text Understanding",
|
126 |
+
author = "Xu, Hu and
|
127 |
+
Ghosh, Gargi and
|
128 |
+
Huang, Po-Yao and
|
129 |
+
Okhonko, Dmytro and
|
130 |
+
Aghajanyan, Armen and
|
131 |
+
Metze, Florian and
|
132 |
+
Zettlemoyer, Luke and
|
133 |
+
Feichtenhofer, Christoph",
|
134 |
+
booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
|
135 |
+
month = nov,
|
136 |
+
year = "2021",
|
137 |
+
address = "Online",
|
138 |
+
publisher = "Association for Computational Linguistics",
|
139 |
+
}
|
140 |
+
|
141 |
+
@inproceedings{xu-etal-2021-vlm,
|
142 |
+
title = "{VLM}: Task-agnostic Video-Language Model Pre-training for Video Understanding",
|
143 |
+
author = "Xu, Hu and
|
144 |
+
Ghosh, Gargi and
|
145 |
+
Huang, Po-Yao and
|
146 |
+
Arora, Prahal and
|
147 |
+
Aminzadeh, Masoumeh and
|
148 |
+
Feichtenhofer, Christoph and
|
149 |
+
Metze, Florian and
|
150 |
+
Zettlemoyer, Luke",
|
151 |
+
booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
|
152 |
+
month = aug,
|
153 |
+
year = "2021",
|
154 |
+
address = "Online",
|
155 |
+
publisher = "Association for Computational Linguistics",
|
156 |
+
url = "https://aclanthology.org/2021.findings-acl.370",
|
157 |
+
doi = "10.18653/v1/2021.findings-acl.370",
|
158 |
+
pages = "4227--4239",
|
159 |
+
}
|
160 |
+
```
|
161 |
+
|
162 |
+
### Bug Reports
|
163 |
+
This repo is in its initial stage, welcome bug reports to huxu@fb.com
|
164 |
+
|
165 |
+
### Copyright
|
166 |
+
The majority of Multimodal Pre-training (MMPT) is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Evaluation Codes/Models: Howto100M and HuggingFace Transformers are licensed under the Apache2.0 license; COIN and NLG-eval are licensed under the MIT license; CrossTask is licensed under the BSD-3; DiDeMo is licensed under the BSD-2 license.
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/endtask.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Zero-shot Transfer and Finetuning
|
2 |
+
|
3 |
+
(If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.)
|
4 |
+
All finetuning datasets (specifically `processors`) are defined in `mmpt.processors.dsprocessor`.
|
5 |
+
Given the complexity of different types of finetuning tasks, each task may have their own meta/video/text/aligner processors and `mmpt/evaluators/{Predictor,Metric}`.
|
6 |
+
|
7 |
+
### Tasks
|
8 |
+
|
9 |
+
Currently, we support 5 end datasets: `MSRVTT`, `Youcook`, `COIN`, `Crosstask` and `DiDeMo` with the following tasks:
|
10 |
+
text-video retrieval: `MSRVTT`, `Youcook`, `DiDeMo`;
|
11 |
+
video captioning: `Youcook`;
|
12 |
+
Video Question and Answering: `MSRVTT-QA`.
|
13 |
+
|
14 |
+
To add your own dataset, you can specify the corresponding processors and config them in the `dataset` field of a config file, such as `projects/task/vtt.yaml`.
|
15 |
+
|
16 |
+
### Zero-shot Transfer (no Training)
|
17 |
+
Zero-shot transfer will run the pre-trained model (e.g., VideoCLIP) directly on testing data. Configs with pattern: `projects/task/*_zs_*.yaml` are dedicated for zero-shot transfer.
|
18 |
+
|
19 |
+
### Fine-tuning
|
20 |
+
|
21 |
+
The training of a downstream task is similar to pretraining, execept you may need to specify the `restore_file` in `fairseq.checkpoint` and reset optimizers, see `projects/task/ft.yaml` that is included by `projects/task/vtt.yaml`.
|
22 |
+
|
23 |
+
We typically do finetuning on 2 gpus (`local_small`).
|
24 |
+
|
25 |
+
### Testing
|
26 |
+
For each finetuning dataset, you may need to specify a testing config, similar to `projects/task/test_vtt.yaml`.
|
27 |
+
|
28 |
+
We define `mmpt.evaluators.Predictor` for different types of prediction. For example, `MSRVTT` and `Youcook` are video-retrieval tasks and expecting to use `RetrievalPredictor`. You may need to define your new type of predictors and specify that in `predictor` field of a testing config.
|
29 |
+
|
30 |
+
Each task may also have their own metric for evaluation. This can be created in `mmpt.evaluators.Metric` and specified in the `metric` field of a testing config.
|
31 |
+
|
32 |
+
Launching a testing is as simple as training by specifying the path of a testing config:
|
33 |
+
```python locallaunch.py projects/mfmmlm/test_vtt.yaml```
|
34 |
+
Testing will be launched locally by default since prediction is computationally less expensive.
|
35 |
+
|
36 |
+
### Third-party Libraries
|
37 |
+
We list the following finetuning tasks that require third-party libraries.
|
38 |
+
|
39 |
+
Youcook captioning: `https://github.com/Maluuba/nlg-eval`
|
40 |
+
|
41 |
+
CrossTask: `https://github.com/DmZhukov/CrossTask`'s `dp` under `third-party/CrossTask` (`python setup.py build_ext --inplace`)
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/locallaunch.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
from mmpt.utils import recursive_config, overwrite_dir
|
11 |
+
from mmpt_cli.localjob import LocalJob
|
12 |
+
|
13 |
+
|
14 |
+
class JobLauncher(object):
|
15 |
+
JOB_CONFIG = {
|
16 |
+
"local": LocalJob,
|
17 |
+
}
|
18 |
+
|
19 |
+
def __init__(self, yaml_file):
|
20 |
+
self.yaml_file = yaml_file
|
21 |
+
job_key = "local"
|
22 |
+
|
23 |
+
if yaml_file.endswith(".yaml"):
|
24 |
+
config = recursive_config(yaml_file)
|
25 |
+
if config.task_type is not None:
|
26 |
+
job_key = config.task_type.split("_")[0]
|
27 |
+
else:
|
28 |
+
raise ValueError("unknown extension of job file:", yaml_file)
|
29 |
+
self.job_key = job_key
|
30 |
+
|
31 |
+
def __call__(self, job_type=None, dryrun=False):
|
32 |
+
if job_type is not None:
|
33 |
+
self.job_key = job_type.split("_")[0]
|
34 |
+
print("[JobLauncher] job_key", self.job_key)
|
35 |
+
job = JobLauncher.JOB_CONFIG[self.job_key](
|
36 |
+
self.yaml_file, job_type=job_type, dryrun=dryrun)
|
37 |
+
return job.submit()
|
38 |
+
|
39 |
+
|
40 |
+
class Pipeline(object):
|
41 |
+
"""a job that loads yaml config."""
|
42 |
+
|
43 |
+
def __init__(self, fn):
|
44 |
+
"""
|
45 |
+
load a yaml config of a job and save generated configs as yaml for each task.
|
46 |
+
return: a list of files to run as specified by `run_task`.
|
47 |
+
"""
|
48 |
+
if fn.endswith(".py"):
|
49 |
+
# a python command.
|
50 |
+
self.backend = "python"
|
51 |
+
self.run_yamls = [fn]
|
52 |
+
return
|
53 |
+
|
54 |
+
job_config = recursive_config(fn)
|
55 |
+
if job_config.base_dir is None: # single file job config.
|
56 |
+
self.run_yamls = [fn]
|
57 |
+
return
|
58 |
+
|
59 |
+
self.project_dir = os.path.join("projects", job_config.project_dir)
|
60 |
+
self.run_dir = os.path.join("runs", job_config.project_dir)
|
61 |
+
|
62 |
+
if job_config.run_task is not None:
|
63 |
+
run_yamls = []
|
64 |
+
for stage in job_config.run_task:
|
65 |
+
# each stage can have multiple tasks running in parallel.
|
66 |
+
if OmegaConf.is_list(stage):
|
67 |
+
stage_yamls = []
|
68 |
+
for task_file in stage:
|
69 |
+
stage_yamls.append(
|
70 |
+
os.path.join(self.project_dir, task_file))
|
71 |
+
run_yamls.append(stage_yamls)
|
72 |
+
else:
|
73 |
+
run_yamls.append(os.path.join(self.project_dir, stage))
|
74 |
+
self.run_yamls = run_yamls
|
75 |
+
configs_to_save = self._overwrite_task(job_config)
|
76 |
+
self._save_configs(configs_to_save)
|
77 |
+
|
78 |
+
def __getitem__(self, idx):
|
79 |
+
yaml_files = self.run_yamls[idx]
|
80 |
+
if isinstance(yaml_files, list):
|
81 |
+
return [JobLauncher(yaml_file) for yaml_file in yaml_files]
|
82 |
+
return [JobLauncher(yaml_files)]
|
83 |
+
|
84 |
+
def __len__(self):
|
85 |
+
return len(self.run_yamls)
|
86 |
+
|
87 |
+
def _save_configs(self, configs_to_save: dict):
|
88 |
+
# save
|
89 |
+
os.makedirs(self.project_dir, exist_ok=True)
|
90 |
+
for config_file in configs_to_save:
|
91 |
+
config = configs_to_save[config_file]
|
92 |
+
print("saving", config_file)
|
93 |
+
OmegaConf.save(config=config, f=config_file)
|
94 |
+
|
95 |
+
def _overwrite_task(self, job_config):
|
96 |
+
configs_to_save = {}
|
97 |
+
self.base_project_dir = os.path.join("projects", job_config.base_dir)
|
98 |
+
self.base_run_dir = os.path.join("runs", job_config.base_dir)
|
99 |
+
|
100 |
+
for config_sets in job_config.task_group:
|
101 |
+
overwrite_config = job_config.task_group[config_sets]
|
102 |
+
if (
|
103 |
+
overwrite_config.task_list is None
|
104 |
+
or len(overwrite_config.task_list) == 0
|
105 |
+
):
|
106 |
+
print(
|
107 |
+
"[warning]",
|
108 |
+
job_config.task_group,
|
109 |
+
"has no task_list specified.")
|
110 |
+
# we don't want this added to a final config.
|
111 |
+
task_list = overwrite_config.pop("task_list", None)
|
112 |
+
for config_file in task_list:
|
113 |
+
config_file_path = os.path.join(
|
114 |
+
self.base_project_dir, config_file)
|
115 |
+
config = recursive_config(config_file_path)
|
116 |
+
# overwrite it.
|
117 |
+
if overwrite_config:
|
118 |
+
config = OmegaConf.merge(config, overwrite_config)
|
119 |
+
overwrite_dir(config, self.run_dir, basedir=self.base_run_dir)
|
120 |
+
save_file_path = os.path.join(self.project_dir, config_file)
|
121 |
+
configs_to_save[save_file_path] = config
|
122 |
+
return configs_to_save
|
123 |
+
|
124 |
+
|
125 |
+
def main(args):
|
126 |
+
job_type = args.jobtype if args.jobtype else None
|
127 |
+
# parse multiple pipelines.
|
128 |
+
pipelines = [Pipeline(fn) for fn in args.yamls.split(",")]
|
129 |
+
|
130 |
+
for pipe_id, pipeline in enumerate(pipelines):
|
131 |
+
if not hasattr(pipeline, "project_dir"):
|
132 |
+
for job in pipeline[0]:
|
133 |
+
job(job_type=job_type, dryrun=args.dryrun)
|
134 |
+
|
135 |
+
|
136 |
+
if __name__ == "__main__":
|
137 |
+
parser = argparse.ArgumentParser()
|
138 |
+
parser.add_argument("yamls", type=str)
|
139 |
+
parser.add_argument(
|
140 |
+
"--dryrun",
|
141 |
+
action="store_true",
|
142 |
+
help="run config and prepare to submit without launch the job.",
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--jobtype", type=str, default="",
|
146 |
+
help="force to run jobs as specified.")
|
147 |
+
args = parser.parse_args()
|
148 |
+
main(args)
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
try:
|
6 |
+
# fairseq user dir
|
7 |
+
from .datasets import FairseqMMDataset
|
8 |
+
from .losses import FairseqCriterion
|
9 |
+
from .models import FairseqMMModel
|
10 |
+
from .tasks import FairseqMMTask
|
11 |
+
except ImportError:
|
12 |
+
pass
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from .mmdataset import *
|
6 |
+
|
7 |
+
try:
|
8 |
+
from .fairseqmmdataset import *
|
9 |
+
except ImportError:
|
10 |
+
pass
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/fairseqmmdataset.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
"""
|
6 |
+
TODO (huxu): fairseq wrapper class for all dataset you defined: mostly MMDataset.
|
7 |
+
"""
|
8 |
+
|
9 |
+
from collections import OrderedDict
|
10 |
+
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
from torch.utils.data.dataloader import default_collate
|
13 |
+
from fairseq.data import FairseqDataset, data_utils
|
14 |
+
|
15 |
+
|
16 |
+
class FairseqMMDataset(FairseqDataset):
|
17 |
+
"""
|
18 |
+
A wrapper class for MMDataset for fairseq.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, mmdataset):
|
22 |
+
if not isinstance(mmdataset, Dataset):
|
23 |
+
raise TypeError("mmdataset must be of type `torch.utils.data.dataset`.")
|
24 |
+
self.mmdataset = mmdataset
|
25 |
+
|
26 |
+
def set_epoch(self, epoch, **unused):
|
27 |
+
super().set_epoch(epoch)
|
28 |
+
self.epoch = epoch
|
29 |
+
|
30 |
+
def __getitem__(self, idx):
|
31 |
+
with data_utils.numpy_seed(43211, self.epoch, idx):
|
32 |
+
return self.mmdataset[idx]
|
33 |
+
|
34 |
+
def __len__(self):
|
35 |
+
return len(self.mmdataset)
|
36 |
+
|
37 |
+
def collater(self, samples):
|
38 |
+
if hasattr(self.mmdataset, "collator"):
|
39 |
+
return self.mmdataset.collator(samples)
|
40 |
+
if len(samples) == 0:
|
41 |
+
return {}
|
42 |
+
if isinstance(samples[0], dict):
|
43 |
+
batch = OrderedDict()
|
44 |
+
for key in samples[0]:
|
45 |
+
if samples[0][key] is not None:
|
46 |
+
batch[key] = default_collate([sample[key] for sample in samples])
|
47 |
+
return batch
|
48 |
+
else:
|
49 |
+
return default_collate(samples)
|
50 |
+
|
51 |
+
def size(self, index):
|
52 |
+
"""dummy implementation: we don't use --max-tokens"""
|
53 |
+
return 1
|
54 |
+
|
55 |
+
def num_tokens(self, index):
|
56 |
+
"""dummy implementation: we don't use --max-tokens"""
|
57 |
+
return 1
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/mmdataset.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from collections import OrderedDict
|
9 |
+
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from torch.utils.data.dataloader import default_collate
|
12 |
+
|
13 |
+
from ..utils import set_seed
|
14 |
+
|
15 |
+
|
16 |
+
class MMDataset(Dataset):
|
17 |
+
"""
|
18 |
+
A generic multi-modal dataset.
|
19 |
+
Args:
|
20 |
+
`meta_processor`: a meta processor,
|
21 |
+
handling loading meta data and return video_id and text_id.
|
22 |
+
`video_processor`: a video processor,
|
23 |
+
handling e.g., decoding, loading .np files.
|
24 |
+
`text_processor`: a text processor,
|
25 |
+
handling e.g., tokenization.
|
26 |
+
`aligner`: combine the video and text feature
|
27 |
+
as one training example.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
meta_processor,
|
33 |
+
video_processor,
|
34 |
+
text_processor,
|
35 |
+
align_processor,
|
36 |
+
):
|
37 |
+
self.split = meta_processor.split
|
38 |
+
self.meta_processor = meta_processor
|
39 |
+
self.video_processor = video_processor
|
40 |
+
self.text_processor = text_processor
|
41 |
+
self.align_processor = align_processor
|
42 |
+
|
43 |
+
def __len__(self):
|
44 |
+
return len(self.meta_processor)
|
45 |
+
|
46 |
+
def __getitem__(self, idx):
|
47 |
+
if self.split == "test":
|
48 |
+
set_seed(idx)
|
49 |
+
video_id, text_id = self.meta_processor[idx]
|
50 |
+
video_feature = self.video_processor(video_id)
|
51 |
+
text_feature = self.text_processor(text_id)
|
52 |
+
output = self.align_processor(video_id, video_feature, text_feature)
|
53 |
+
# TODO (huxu): the following is for debug purpose.
|
54 |
+
output.update({"idx": idx})
|
55 |
+
return output
|
56 |
+
|
57 |
+
def collater(self, samples):
|
58 |
+
"""This collator is deprecated.
|
59 |
+
set self.collator = MMDataset.collater.
|
60 |
+
see collator in FairseqMMDataset.
|
61 |
+
"""
|
62 |
+
|
63 |
+
if len(samples) == 0:
|
64 |
+
return {}
|
65 |
+
if isinstance(samples[0], dict):
|
66 |
+
batch = OrderedDict()
|
67 |
+
for key in samples[0]:
|
68 |
+
if samples[0][key] is not None:
|
69 |
+
batch[key] = default_collate(
|
70 |
+
[sample[key] for sample in samples])
|
71 |
+
# if torch.is_tensor(batch[key]):
|
72 |
+
# print(key, batch[key].size())
|
73 |
+
# else:
|
74 |
+
# print(key, len(batch[key]))
|
75 |
+
return batch
|
76 |
+
else:
|
77 |
+
return default_collate(samples)
|
78 |
+
|
79 |
+
def print_example(self, output):
|
80 |
+
print("[one example]", output["video_id"])
|
81 |
+
if (
|
82 |
+
hasattr(self.align_processor, "subsampling")
|
83 |
+
and self.align_processor.subsampling is not None
|
84 |
+
and self.align_processor.subsampling > 1
|
85 |
+
):
|
86 |
+
for key in output:
|
87 |
+
if torch.is_tensor(output[key]):
|
88 |
+
output[key] = output[key][0]
|
89 |
+
|
90 |
+
# search tokenizer to translate ids back.
|
91 |
+
tokenizer = None
|
92 |
+
if hasattr(self.text_processor, "tokenizer"):
|
93 |
+
tokenizer = self.text_processor.tokenizer
|
94 |
+
elif hasattr(self.align_processor, "tokenizer"):
|
95 |
+
tokenizer = self.align_processor.tokenizer
|
96 |
+
if tokenizer is not None:
|
97 |
+
caps = output["caps"].tolist()
|
98 |
+
if isinstance(caps[0], list):
|
99 |
+
caps = caps[0]
|
100 |
+
print("caps", tokenizer.decode(caps))
|
101 |
+
print("caps", tokenizer.convert_ids_to_tokens(caps))
|
102 |
+
|
103 |
+
for key, value in output.items():
|
104 |
+
if torch.is_tensor(value):
|
105 |
+
if len(value.size()) >= 3: # attention_mask.
|
106 |
+
print(key, value.size())
|
107 |
+
print(key, "first", value[0, :, :])
|
108 |
+
print(key, "last", value[-1, :, :])
|
109 |
+
else:
|
110 |
+
print(key, value)
|
111 |
+
print("[end of one example]")
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from .metric import *
|
6 |
+
from .evaluator import *
|
7 |
+
|
8 |
+
|
9 |
+
# experimental.
|
10 |
+
try:
|
11 |
+
from .expmetric import *
|
12 |
+
except ImportError:
|
13 |
+
pass
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/evaluator.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import os
|
6 |
+
import glob
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from . import metric as metric_path
|
10 |
+
from . import predictor as predictor_path
|
11 |
+
|
12 |
+
|
13 |
+
class Evaluator(object):
|
14 |
+
"""
|
15 |
+
perform evaluation on a single (downstream) task.
|
16 |
+
make this both offline and online.
|
17 |
+
TODO(huxu) saving evaluation results.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, config, eval_dataloader=None):
|
21 |
+
if config.metric is None:
|
22 |
+
raise ValueError("config.metric is", config.metric)
|
23 |
+
metric_cls = getattr(metric_path, config.metric)
|
24 |
+
self.metric = metric_cls(config)
|
25 |
+
if config.predictor is None:
|
26 |
+
raise ValueError("config.predictor is", config.predictor)
|
27 |
+
predictor_cls = getattr(predictor_path, config.predictor)
|
28 |
+
self.predictor = predictor_cls(config)
|
29 |
+
self.eval_dataloader = eval_dataloader
|
30 |
+
|
31 |
+
def __call__(self):
|
32 |
+
try:
|
33 |
+
print(self.predictor.pred_dir)
|
34 |
+
for pred_file in glob.glob(
|
35 |
+
self.predictor.pred_dir + "/*_merged.npy"):
|
36 |
+
outputs = np.load(pred_file)
|
37 |
+
results = self.metric.compute_metrics(outputs)
|
38 |
+
self.metric.print_computed_metrics(results)
|
39 |
+
|
40 |
+
outputs = np.load(os.path.join(
|
41 |
+
self.predictor.pred_dir, "merged.npy"))
|
42 |
+
results = self.metric.compute_metrics(outputs)
|
43 |
+
return {"results": results, "metric": self.metric}
|
44 |
+
except FileNotFoundError:
|
45 |
+
print("\n[missing]", self.predictor.pred_dir)
|
46 |
+
return {}
|
47 |
+
|
48 |
+
def evaluate(self, model, eval_dataloader=None, output_file="merged"):
|
49 |
+
if eval_dataloader is None:
|
50 |
+
eval_dataloader = self.eval_dataloader
|
51 |
+
outputs = self.predictor.predict_loop(
|
52 |
+
model, eval_dataloader, output_file)
|
53 |
+
results = self.metric.compute_metrics(**outputs)
|
54 |
+
return results
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/metric.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import json
|
8 |
+
|
9 |
+
|
10 |
+
class Metric(object):
|
11 |
+
def __init__(self, config, metric_names):
|
12 |
+
self.metric_names = metric_names
|
13 |
+
|
14 |
+
def best_metric(self, metric):
|
15 |
+
return metric[self.metric_names[0]]
|
16 |
+
|
17 |
+
def save_metrics(self, fn, metrics):
|
18 |
+
with open(fn, "w") as fw:
|
19 |
+
json.dump(fw, metrics)
|
20 |
+
|
21 |
+
def print_computed_metrics(self, metrics):
|
22 |
+
raise NotImplementedError
|
23 |
+
|
24 |
+
|
25 |
+
class RetrievalMetric(Metric):
|
26 |
+
"""
|
27 |
+
this is modified from `howto100m/metrics.py`.
|
28 |
+
History of changes:
|
29 |
+
refactor as a class.
|
30 |
+
add metric_key in __init__
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, config, metric_names=["R1", "R5", "R10", "MR"]):
|
34 |
+
super().__init__(config, metric_names)
|
35 |
+
self.error = False # TODO(huxu): add to config to print error.
|
36 |
+
|
37 |
+
def compute_metrics(self, outputs, texts, **kwargs):
|
38 |
+
x = outputs
|
39 |
+
sx = np.sort(-x, axis=1)
|
40 |
+
d = np.diag(-x)
|
41 |
+
d = d[:, np.newaxis]
|
42 |
+
ind = sx - d
|
43 |
+
ind = np.where(ind == 0)
|
44 |
+
ind = ind[1]
|
45 |
+
metrics = {}
|
46 |
+
metrics["R1"] = float(np.sum(ind == 0)) / len(ind)
|
47 |
+
metrics["R5"] = float(np.sum(ind < 5)) / len(ind)
|
48 |
+
metrics["R10"] = float(np.sum(ind < 10)) / len(ind)
|
49 |
+
metrics["MR"] = np.median(ind) + 1
|
50 |
+
|
51 |
+
max_idx = np.argmax(outputs, axis=1)
|
52 |
+
if self.error:
|
53 |
+
# print top-20 errors.
|
54 |
+
error = []
|
55 |
+
for ex_idx in range(20):
|
56 |
+
error.append((texts[ex_idx], texts[max_idx[ex_idx]]))
|
57 |
+
metrics["error"] = error
|
58 |
+
return metrics
|
59 |
+
|
60 |
+
def print_computed_metrics(self, metrics):
|
61 |
+
r1 = metrics["R1"]
|
62 |
+
r5 = metrics["R5"]
|
63 |
+
r10 = metrics["R10"]
|
64 |
+
mr = metrics["MR"]
|
65 |
+
print(
|
66 |
+
"R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}".format(
|
67 |
+
r1, r5, r10, mr
|
68 |
+
)
|
69 |
+
)
|
70 |
+
if "error" in metrics:
|
71 |
+
print(metrics["error"])
|
72 |
+
|
73 |
+
|
74 |
+
class DiDeMoMetric(Metric):
|
75 |
+
"""
|
76 |
+
History of changes:
|
77 |
+
python 2.x to python 3.x.
|
78 |
+
merge utils.py into eval to save one file.
|
79 |
+
reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
|
80 |
+
Code to evaluate your results on the DiDeMo dataset.
|
81 |
+
"""
|
82 |
+
def __init__(self, config, metric_names=["rank1", "rank5", "miou"]):
|
83 |
+
super().__init__(config, metric_names)
|
84 |
+
|
85 |
+
def compute_metrics(self, outputs, targets, **kwargs):
|
86 |
+
assert len(outputs) == len(targets)
|
87 |
+
rank1, rank5, miou = self._eval_predictions(outputs, targets)
|
88 |
+
metrics = {
|
89 |
+
"rank1": rank1,
|
90 |
+
"rank5": rank5,
|
91 |
+
"miou": miou
|
92 |
+
}
|
93 |
+
return metrics
|
94 |
+
|
95 |
+
def print_computed_metrics(self, metrics):
|
96 |
+
rank1 = metrics["rank1"]
|
97 |
+
rank5 = metrics["rank5"]
|
98 |
+
miou = metrics["miou"]
|
99 |
+
# print("Average rank@1: %f" % rank1)
|
100 |
+
# print("Average rank@5: %f" % rank5)
|
101 |
+
# print("Average iou: %f" % miou)
|
102 |
+
|
103 |
+
print(
|
104 |
+
"Average rank@1: {:.4f} Average rank@5: {:.4f} Average iou: {:.4f}".format(
|
105 |
+
rank1, rank5, miou
|
106 |
+
)
|
107 |
+
)
|
108 |
+
|
109 |
+
def _iou(self, pred, gt):
|
110 |
+
intersection = max(0, min(pred[1], gt[1]) + 1 - max(pred[0], gt[0]))
|
111 |
+
union = max(pred[1], gt[1]) + 1 - min(pred[0], gt[0])
|
112 |
+
return float(intersection)/union
|
113 |
+
|
114 |
+
def _rank(self, pred, gt):
|
115 |
+
return pred.index(tuple(gt)) + 1
|
116 |
+
|
117 |
+
def _eval_predictions(self, segments, data):
|
118 |
+
'''
|
119 |
+
Inputs:
|
120 |
+
segments: For each item in the ground truth data, rank possible video segments given the description and video.
|
121 |
+
In DiDeMo, there are 21 posible moments extracted for each video so the list of video segments will be of length 21.
|
122 |
+
The first video segment should be the video segment that best corresponds to the text query.
|
123 |
+
There are 4180 sentence in the validation data, so when evaluating a model on the val dataset,
|
124 |
+
segments should be a list of lenght 4180, and each item in segments should be a list of length 21.
|
125 |
+
data: ground truth data
|
126 |
+
'''
|
127 |
+
average_ranks = []
|
128 |
+
average_iou = []
|
129 |
+
for s, d in zip(segments, data):
|
130 |
+
pred = s[0]
|
131 |
+
ious = [self._iou(pred, t) for t in d['times']]
|
132 |
+
average_iou.append(np.mean(np.sort(ious)[-3:]))
|
133 |
+
ranks = [self._rank(s, t) for t in d['times'] if tuple(t) in s] # if t in s] is added for s, e not in prediction.
|
134 |
+
average_ranks.append(np.mean(np.sort(ranks)[:3]))
|
135 |
+
rank1 = np.sum(np.array(average_ranks) <= 1)/float(len(average_ranks))
|
136 |
+
rank5 = np.sum(np.array(average_ranks) <= 5)/float(len(average_ranks))
|
137 |
+
miou = np.mean(average_iou)
|
138 |
+
|
139 |
+
# print("Average rank@1: %f" % rank1)
|
140 |
+
# print("Average rank@5: %f" % rank5)
|
141 |
+
# print("Average iou: %f" % miou)
|
142 |
+
return rank1, rank5, miou
|
143 |
+
|
144 |
+
|
145 |
+
class NLGMetric(Metric):
|
146 |
+
def __init__(
|
147 |
+
self,
|
148 |
+
config,
|
149 |
+
metric_names=[
|
150 |
+
"Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4",
|
151 |
+
"METEOR", "ROUGE_L", "CIDEr"
|
152 |
+
]
|
153 |
+
):
|
154 |
+
super().__init__(config, metric_names)
|
155 |
+
# please install NLGEval from `https://github.com/Maluuba/nlg-eval`
|
156 |
+
from nlgeval import NLGEval
|
157 |
+
self.nlg = NLGEval()
|
158 |
+
|
159 |
+
def compute_metrics(self, outputs, targets, **kwargs):
|
160 |
+
return self.nlg.compute_metrics(
|
161 |
+
hyp_list=outputs, ref_list=targets)
|
162 |
+
|
163 |
+
def print_computed_metrics(self, metrics):
|
164 |
+
Bleu_1 = metrics["Bleu_1"]
|
165 |
+
Bleu_2 = metrics["Bleu_2"]
|
166 |
+
Bleu_3 = metrics["Bleu_3"]
|
167 |
+
Bleu_4 = metrics["Bleu_4"]
|
168 |
+
METEOR = metrics["METEOR"]
|
169 |
+
ROUGE_L = metrics["ROUGE_L"]
|
170 |
+
CIDEr = metrics["CIDEr"]
|
171 |
+
|
172 |
+
print(
|
173 |
+
"Bleu_1: {:.4f} - Bleu_2: {:.4f} - Bleu_3: {:.4f} - Bleu_4: {:.4f} - METEOR: {:.4f} - ROUGE_L: {:.4f} - CIDEr: {:.4f}".format(
|
174 |
+
Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, CIDEr
|
175 |
+
)
|
176 |
+
)
|
177 |
+
|
178 |
+
|
179 |
+
class QAMetric(Metric):
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
config,
|
183 |
+
metric_names=["acc"]
|
184 |
+
):
|
185 |
+
super().__init__(config, metric_names)
|
186 |
+
|
187 |
+
def compute_metrics(self, outputs, targets, **kwargs):
|
188 |
+
from sklearn.metrics import accuracy_score
|
189 |
+
return {"acc": accuracy_score(targets, outputs)}
|
190 |
+
|
191 |
+
def print_computed_metrics(self, metrics):
|
192 |
+
print("acc: {:.4f}".format(metrics["acc"]))
|
193 |
+
|
194 |
+
|
195 |
+
class COINActionSegmentationMetric(Metric):
|
196 |
+
"""
|
197 |
+
COIN dataset listed 3 repos for Action Segmentation.
|
198 |
+
Action Sets, NeuralNetwork-Viterbi, TCFPN-ISBA.
|
199 |
+
The first and second are the same.
|
200 |
+
https://github.com/alexanderrichard/action-sets/blob/master/eval.py
|
201 |
+
|
202 |
+
Future reference for the third:
|
203 |
+
`https://github.com/Zephyr-D/TCFPN-ISBA/blob/master/utils/metrics.py`
|
204 |
+
"""
|
205 |
+
def __init__(self, config, metric_name=["frame_acc"]):
|
206 |
+
super().__init__(config, metric_name)
|
207 |
+
|
208 |
+
def compute_metrics(self, outputs, targets):
|
209 |
+
n_frames = 0
|
210 |
+
n_errors = 0
|
211 |
+
n_errors = sum(outputs != targets)
|
212 |
+
n_frames = len(targets)
|
213 |
+
return {"frame_acc": 1.0 - float(n_errors) / n_frames}
|
214 |
+
|
215 |
+
def print_computed_metrics(self, metrics):
|
216 |
+
fa = metrics["frame_acc"]
|
217 |
+
print("frame accuracy:", fa)
|
218 |
+
|
219 |
+
|
220 |
+
class CrossTaskMetric(Metric):
|
221 |
+
def __init__(self, config, metric_names=["recall"]):
|
222 |
+
super().__init__(config, metric_names)
|
223 |
+
|
224 |
+
def compute_metrics(self, outputs, targets, **kwargs):
|
225 |
+
"""refactored from line 166:
|
226 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
|
227 |
+
|
228 |
+
recalls = self._get_recalls(Y_true=targets, Y_pred=outputs)
|
229 |
+
results = {}
|
230 |
+
for task, rec in recalls.items():
|
231 |
+
results[str(task)] = rec
|
232 |
+
|
233 |
+
avg_recall = np.mean(list(recalls.values()))
|
234 |
+
results["recall"] = avg_recall
|
235 |
+
return results
|
236 |
+
|
237 |
+
def print_computed_metrics(self, metrics):
|
238 |
+
print('Recall: {0:0.3f}'.format(metrics["recall"]))
|
239 |
+
for task in metrics:
|
240 |
+
if task != "recall":
|
241 |
+
print('Task {0}. Recall = {1:0.3f}'.format(
|
242 |
+
task, metrics[task]))
|
243 |
+
|
244 |
+
def _get_recalls(self, Y_true, Y_pred):
|
245 |
+
"""refactored from
|
246 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
|
247 |
+
|
248 |
+
step_match = {task: 0 for task in Y_true.keys()}
|
249 |
+
step_total = {task: 0 for task in Y_true.keys()}
|
250 |
+
for task, ys_true in Y_true.items():
|
251 |
+
ys_pred = Y_pred[task]
|
252 |
+
for vid in set(ys_pred.keys()).intersection(set(ys_true.keys())):
|
253 |
+
y_true = ys_true[vid]
|
254 |
+
y_pred = ys_pred[vid]
|
255 |
+
step_total[task] += (y_true.sum(axis=0) > 0).sum()
|
256 |
+
step_match[task] += (y_true*y_pred).sum()
|
257 |
+
recalls = {
|
258 |
+
task: step_match[task] / n for task, n in step_total.items()}
|
259 |
+
return recalls
|
260 |
+
|
261 |
+
|
262 |
+
class ActionRecognitionMetric(Metric):
|
263 |
+
def __init__(
|
264 |
+
self,
|
265 |
+
config,
|
266 |
+
metric_names=["acc", "acc_splits", "r1_splits", "r5_splits", "r10_splits"]
|
267 |
+
):
|
268 |
+
super().__init__(config, metric_names)
|
269 |
+
|
270 |
+
def compute_metrics(self, outputs, targets, splits, **kwargs):
|
271 |
+
all_video_embd = outputs
|
272 |
+
labels = targets
|
273 |
+
split1, split2, split3 = splits
|
274 |
+
accs = []
|
275 |
+
r1s = []
|
276 |
+
r5s = []
|
277 |
+
r10s = []
|
278 |
+
for split in range(3):
|
279 |
+
if split == 0:
|
280 |
+
s = split1
|
281 |
+
elif split == 1:
|
282 |
+
s = split2
|
283 |
+
else:
|
284 |
+
s = split3
|
285 |
+
|
286 |
+
X_pred = all_video_embd[np.where(s == 2)[0]]
|
287 |
+
label_test = labels[np.where(s == 2)[0]]
|
288 |
+
logits = X_pred
|
289 |
+
X_pred = np.argmax(X_pred, axis=1)
|
290 |
+
acc = np.sum(X_pred == label_test) / float(len(X_pred))
|
291 |
+
accs.append(acc)
|
292 |
+
# compute recall.
|
293 |
+
sorted_pred = (-logits).argsort(axis=-1)
|
294 |
+
label_test_sp = label_test.reshape(-1, 1)
|
295 |
+
|
296 |
+
r1 = np.mean((sorted_pred[:, :1] == label_test_sp).sum(axis=1), axis=0)
|
297 |
+
r5 = np.mean((sorted_pred[:, :5] == label_test_sp).sum(axis=1), axis=0)
|
298 |
+
r10 = np.mean((sorted_pred[:, :10] == label_test_sp).sum(axis=1), axis=0)
|
299 |
+
r1s.append(r1)
|
300 |
+
r5s.append(r5)
|
301 |
+
r10s.append(r10)
|
302 |
+
|
303 |
+
return {"acc": accs[0], "acc_splits": accs, "r1_splits": r1s, "r5_splits": r5s, "r10_splits": r10s}
|
304 |
+
|
305 |
+
def print_computed_metrics(self, metrics):
|
306 |
+
for split, acc in enumerate(metrics["acc_splits"]):
|
307 |
+
print("Top 1 accuracy on split {}: {}; r1 {}; r5 {}; r10 {}".format(
|
308 |
+
split + 1, acc,
|
309 |
+
metrics["r1_splits"][split],
|
310 |
+
metrics["r5_splits"][split],
|
311 |
+
metrics["r10_splits"][split],
|
312 |
+
)
|
313 |
+
)
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/predictor.py
ADDED
@@ -0,0 +1,595 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import json
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import pickle
|
11 |
+
import math
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
class Predictor(object):
|
17 |
+
"""this base class is used to save predictions to disk
|
18 |
+
(and being called by a evaluator later).
|
19 |
+
Predictor has minimum support of single gpu prediction.
|
20 |
+
"""
|
21 |
+
def __init__(self, config):
|
22 |
+
self.pred_dir = None # on-the-fly eval does not save the results.
|
23 |
+
if hasattr(config, "eval") and config.eval is not None:
|
24 |
+
self.pred_dir = config.eval.save_path
|
25 |
+
os.makedirs(self.pred_dir, exist_ok=True)
|
26 |
+
|
27 |
+
def __call__(self, outputs):
|
28 |
+
"""extract the prediction and save it."""
|
29 |
+
raise NotImplementedError
|
30 |
+
|
31 |
+
def predict_loop(self, model, eval_dataloader, output_file=None):
|
32 |
+
"""on-the-fly prediction on a single gpu."""
|
33 |
+
self.full_scores = []
|
34 |
+
model.eval()
|
35 |
+
model = model.to(0)
|
36 |
+
with torch.no_grad():
|
37 |
+
for data in eval_dataloader:
|
38 |
+
data = self.to_ctx(data)
|
39 |
+
outputs = model(**data)
|
40 |
+
outputs.update(data)
|
41 |
+
self(outputs)
|
42 |
+
return self.finalize(output_file)
|
43 |
+
|
44 |
+
def finalize(self, output_file):
|
45 |
+
pass
|
46 |
+
|
47 |
+
def to_ctx(self, data, ctx=0, dtype=None):
|
48 |
+
if isinstance(data, dict):
|
49 |
+
for key in data:
|
50 |
+
if torch.is_tensor(data[key]):
|
51 |
+
if dtype is not None and data[key].dtype == torch.float32:
|
52 |
+
data[key] = data[key].to(dtype)
|
53 |
+
data[key] = data[key].to(ctx)
|
54 |
+
return data
|
55 |
+
else:
|
56 |
+
raise ValueError("non-dict type of batch is not supported yet.")
|
57 |
+
|
58 |
+
|
59 |
+
class NLGPredictor(Predictor):
|
60 |
+
"""Predicting Text from MMFusion models."""
|
61 |
+
"""TODO: make a context."""
|
62 |
+
def __init__(self, config):
|
63 |
+
super().__init__(config)
|
64 |
+
from transformers import AutoTokenizer
|
65 |
+
|
66 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
67 |
+
config.dataset.bert_name,
|
68 |
+
bos_token="[CLS]", eos_token="[SEP]")
|
69 |
+
self.bos_token_id = self.tokenizer.bos_token_id
|
70 |
+
self.eos_token_id = self.tokenizer.eos_token_id
|
71 |
+
|
72 |
+
def predict_loop(self, model, eval_dataloader, output_file=None):
|
73 |
+
"""TODO: refactor base classes."""
|
74 |
+
ctx = 0
|
75 |
+
outputs = {"outputs": [], "targets": [[]]}
|
76 |
+
model.eval()
|
77 |
+
model = model.to(ctx)
|
78 |
+
with torch.no_grad():
|
79 |
+
for data in tqdm(eval_dataloader):
|
80 |
+
data = self.to_ctx(data, ctx)
|
81 |
+
self(data, model, outputs)
|
82 |
+
return self.finalize(outputs, output_file)
|
83 |
+
|
84 |
+
def __call__(self, data, model, outputs):
|
85 |
+
data.update({
|
86 |
+
"bos_token_id": self.bos_token_id,
|
87 |
+
"eos_token_id": self.eos_token_id
|
88 |
+
})
|
89 |
+
|
90 |
+
output = model.generate(**data)
|
91 |
+
assert len(output) == len(data["ref"])
|
92 |
+
for idx, _output in enumerate(output):
|
93 |
+
generated_text = self.tokenizer.decode(
|
94 |
+
_output, skip_special_tokens=True)
|
95 |
+
if generated_text == "":
|
96 |
+
generated_text = "none"
|
97 |
+
outputs["outputs"].append(generated_text)
|
98 |
+
outputs["targets"][0].append(data["ref"][idx])
|
99 |
+
if random.random() < 0.001:
|
100 |
+
print("_output", _output)
|
101 |
+
print("generated_text", generated_text)
|
102 |
+
print("ref", data["ref"][idx])
|
103 |
+
|
104 |
+
def finalize(self, outputs, output_file=None):
|
105 |
+
if output_file is not None:
|
106 |
+
with open(os.path.join(
|
107 |
+
self.pred_dir, output_file + ".json"), "w") as fw:
|
108 |
+
json.dump(outputs, fw, indent=4)
|
109 |
+
return outputs
|
110 |
+
|
111 |
+
|
112 |
+
class RetrievalPredictor(Predictor):
|
113 |
+
"""generated `pooled_video` and `pooled_text`."""
|
114 |
+
def __init__(self, config):
|
115 |
+
super().__init__(config)
|
116 |
+
from transformers import AutoTokenizer
|
117 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
118 |
+
config.dataset.bert_name)
|
119 |
+
|
120 |
+
def predict_loop(
|
121 |
+
self,
|
122 |
+
model,
|
123 |
+
eval_dataloader,
|
124 |
+
output_file="retrieval.npy"
|
125 |
+
):
|
126 |
+
"""on-the-fly prediction on a single gpu."""
|
127 |
+
full_scores = []
|
128 |
+
texts = []
|
129 |
+
model.eval()
|
130 |
+
model = model.cuda()
|
131 |
+
with torch.no_grad():
|
132 |
+
for data in eval_dataloader:
|
133 |
+
# convert to dict.
|
134 |
+
if not isinstance(data, dict):
|
135 |
+
data = {
|
136 |
+
"caps": data[0],
|
137 |
+
"cmasks": data[1],
|
138 |
+
"vfeats": data[2],
|
139 |
+
"vmasks": data[3],
|
140 |
+
"video_id": data[4]
|
141 |
+
}
|
142 |
+
data = self.to_ctx(data)
|
143 |
+
outputs = model(**data)
|
144 |
+
outputs.update(data)
|
145 |
+
self(outputs, full_scores)
|
146 |
+
for _cap in data["caps"]:
|
147 |
+
texts.append(
|
148 |
+
self.tokenizer.decode(_cap, skip_special_tokens=True)
|
149 |
+
)
|
150 |
+
|
151 |
+
return self.finalize(full_scores, texts, output_file)
|
152 |
+
|
153 |
+
def __call__(self, sample, full_scores):
|
154 |
+
scores = self._get_pooled_outputs(sample)
|
155 |
+
self._append_scores(scores, full_scores)
|
156 |
+
|
157 |
+
def finalize(self, full_scores, texts, output_file=None):
|
158 |
+
outputs = self._aggregate_scores(full_scores)
|
159 |
+
if output_file is not None:
|
160 |
+
np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
|
161 |
+
return {"outputs": outputs, "texts": texts}
|
162 |
+
|
163 |
+
def _get_pooled_outputs(self, outputs):
|
164 |
+
if "pooled_video" in outputs:
|
165 |
+
return outputs["pooled_video"], outputs["pooled_text"]
|
166 |
+
else:
|
167 |
+
raise ValueError("unknown format of outputs.")
|
168 |
+
|
169 |
+
def _append_scores(self, scores, full_scores):
|
170 |
+
assert len(scores) == 2
|
171 |
+
if len(full_scores) == 0:
|
172 |
+
full_scores.append([])
|
173 |
+
full_scores.append([])
|
174 |
+
full_scores[0].append(scores[0].cpu().detach().numpy())
|
175 |
+
full_scores[1].append(scores[1].cpu().detach().numpy())
|
176 |
+
|
177 |
+
def _aggregate_scores(self, scores):
|
178 |
+
assert len(scores) == 2
|
179 |
+
video_hidden = np.concatenate(scores[0], axis=0)
|
180 |
+
text_hidden = np.concatenate(scores[1], axis=0)
|
181 |
+
# clear up.
|
182 |
+
self.full_scores = []
|
183 |
+
return np.matmul(text_hidden, video_hidden.T)
|
184 |
+
|
185 |
+
|
186 |
+
class QAPredictor(Predictor):
|
187 |
+
"""generated `pooled_video` and `pooled_text`."""
|
188 |
+
def __init__(self, config):
|
189 |
+
super().__init__(config)
|
190 |
+
"""predictor maintains scores and aggregate them."""
|
191 |
+
|
192 |
+
def predict_loop(self, model, eval_dataloader, output_file="qa.npy"):
|
193 |
+
"""on-the-fly prediction on a single gpu."""
|
194 |
+
self.full_scores = []
|
195 |
+
model.eval()
|
196 |
+
model = model.cuda()
|
197 |
+
with torch.no_grad():
|
198 |
+
for data in eval_dataloader:
|
199 |
+
# reshape ans and dup video 5 times.
|
200 |
+
v_len = data["vfeats"].size(1)
|
201 |
+
hidden_size = data["vfeats"].size(2)
|
202 |
+
data["vfeats"] = data["vfeats"].unsqueeze(1).repeat(1, 5, 1, 1).view(-1, v_len, hidden_size)
|
203 |
+
data["vmasks"] = data["vmasks"].unsqueeze(1).repeat(1, 5, 1).view(-1, v_len)
|
204 |
+
|
205 |
+
t_len = data["caps"].size(-1)
|
206 |
+
data["caps"] = data["caps"].view(-1, t_len)
|
207 |
+
data["cmasks"] = data["cmasks"].view(-1, t_len)
|
208 |
+
|
209 |
+
data = self.to_ctx(data)
|
210 |
+
outputs = model(**data)
|
211 |
+
outputs.update(data)
|
212 |
+
self(outputs)
|
213 |
+
return self.finalize(output_file)
|
214 |
+
|
215 |
+
def __call__(self, sample):
|
216 |
+
hidden_size = sample["pooled_video"].size(-1)
|
217 |
+
pooled_video = sample["pooled_video"].view(-1, 5, hidden_size)
|
218 |
+
pooled_text = sample["pooled_text"].view(-1, 5, hidden_size)
|
219 |
+
scores = torch.bmm(pooled_video, pooled_text.transpose(2, 1))
|
220 |
+
scores = scores.argmax(-1)
|
221 |
+
self._append_scores(scores[:, 0], sample["answers"], self.full_scores)
|
222 |
+
|
223 |
+
def finalize(self, output_file=None):
|
224 |
+
outputs, targets = self._aggregate_scores(self.full_scores)
|
225 |
+
if output_file is not None:
|
226 |
+
np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
|
227 |
+
return {"outputs": outputs, "targets": targets}
|
228 |
+
|
229 |
+
def _append_scores(self, scores, answers, full_scores):
|
230 |
+
if len(full_scores) == 0:
|
231 |
+
full_scores.append([])
|
232 |
+
full_scores.append([])
|
233 |
+
full_scores[0].append(scores.cpu().detach().numpy())
|
234 |
+
full_scores[1].append(answers.cpu().detach().numpy())
|
235 |
+
|
236 |
+
def _aggregate_scores(self, scores):
|
237 |
+
assert len(scores) == 2
|
238 |
+
outputs = np.concatenate(scores[0], axis=0)
|
239 |
+
targets = np.concatenate(scores[1], axis=0)
|
240 |
+
# clear up.
|
241 |
+
self.full_scores = []
|
242 |
+
return outputs, targets
|
243 |
+
|
244 |
+
|
245 |
+
class CrossTaskPredictor(Predictor):
|
246 |
+
"""
|
247 |
+
CrossTaskPredictor needs to compute the average of logits
|
248 |
+
for overlapped sliding-window.
|
249 |
+
"""
|
250 |
+
def __init__(self, config):
|
251 |
+
super().__init__(config)
|
252 |
+
self.lsm = torch.nn.LogSoftmax(dim=1)
|
253 |
+
self.max_video_len = config.dataset.max_video_len
|
254 |
+
self.sliding_window = config.dataset.sliding_window
|
255 |
+
self.sliding_window_size = config.dataset.sliding_window_size
|
256 |
+
self.annotation_path = config.dataset.annotation_path
|
257 |
+
|
258 |
+
def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
|
259 |
+
"""refactored from line 144:
|
260 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py
|
261 |
+
"""
|
262 |
+
ctx = 0
|
263 |
+
model.eval()
|
264 |
+
model = model.to(ctx)
|
265 |
+
# this is not a loss but just compute neg_log_prob.
|
266 |
+
Y_pred = {}
|
267 |
+
Y_true = {}
|
268 |
+
with torch.no_grad():
|
269 |
+
for batch in eval_dataloader:
|
270 |
+
self(batch, model, Y_pred, Y_true)
|
271 |
+
return self.finalize(Y_pred, Y_true, output_file)
|
272 |
+
|
273 |
+
def __call__(self, sample, model, Y_pred, Y_true):
|
274 |
+
# please install dp from `https://github.com/DmZhukov/CrossTask`
|
275 |
+
from dp import dp
|
276 |
+
vid, task = sample['video_id'][0], sample['task'][0]
|
277 |
+
sample = self.to_ctx(sample)
|
278 |
+
# compute the average logits over sliding windows.
|
279 |
+
output = model(**sample)
|
280 |
+
batch_logits = output["logits"].cpu()
|
281 |
+
|
282 |
+
video_len = sample["video_len"][0]
|
283 |
+
|
284 |
+
# the following version is slow.
|
285 |
+
logits = torch.zeros((video_len, batch_logits.size(1)))
|
286 |
+
logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
|
287 |
+
# use the same loop as aligner to recover.
|
288 |
+
batch_logit_idx = 0
|
289 |
+
for window_start in range(0, video_len, self.sliding_window):
|
290 |
+
video_end = min(video_len - window_start, self.sliding_window_size)
|
291 |
+
logits[window_start: window_start + video_end] += batch_logits[
|
292 |
+
batch_logit_idx: batch_logit_idx + video_end]
|
293 |
+
batch_logit_idx += video_end
|
294 |
+
logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
|
295 |
+
|
296 |
+
if (video_len - window_start) <= self.sliding_window_size:
|
297 |
+
break
|
298 |
+
|
299 |
+
logits /= logits_counts
|
300 |
+
assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
|
301 |
+
|
302 |
+
O = self.lsm(logits)
|
303 |
+
y = np.zeros(O.size(), dtype=np.float32)
|
304 |
+
dp(y, -O.detach().cpu().numpy())
|
305 |
+
if task not in Y_pred:
|
306 |
+
Y_pred[task] = {}
|
307 |
+
Y_pred[task][vid] = y
|
308 |
+
annot_path = os.path.join(
|
309 |
+
self.annotation_path, task+'_'+vid+'.csv')
|
310 |
+
if os.path.exists(annot_path):
|
311 |
+
if task not in Y_true:
|
312 |
+
Y_true[task] = {}
|
313 |
+
Y_true[task][vid] = self._read_assignment(
|
314 |
+
*y.shape, annot_path)
|
315 |
+
|
316 |
+
def finalize(self, Y_pred, Y_true, output_file=None):
|
317 |
+
if output_file is not None:
|
318 |
+
with open(
|
319 |
+
os.path.join(self.pred_dir, output_file + ".pkl"),
|
320 |
+
"wb") as fw:
|
321 |
+
pickle.dump(
|
322 |
+
{"Y_pred": Y_pred, "Y_true": Y_true}, fw,
|
323 |
+
protocol=pickle.HIGHEST_PROTOCOL)
|
324 |
+
return {"outputs": Y_pred, "targets": Y_true}
|
325 |
+
|
326 |
+
def _read_assignment(self, T, K, path):
|
327 |
+
"""
|
328 |
+
refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py
|
329 |
+
Howto interpret contraints on loss that is going to be minimized:
|
330 |
+
lambd is a big number;
|
331 |
+
self.lambd * C is a big number for all valid position (csv stores invalids)
|
332 |
+
|
333 |
+
def forward(self, O, Y, C):
|
334 |
+
return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum()
|
335 |
+
|
336 |
+
This will load the csv file and fill-in the step col from start to end rows.
|
337 |
+
"""
|
338 |
+
|
339 |
+
Y = np.zeros([T, K], dtype=np.uint8)
|
340 |
+
with open(path, 'r') as f:
|
341 |
+
for line in f:
|
342 |
+
step, start, end = line.strip().split(',')
|
343 |
+
start = int(math.floor(float(start)))
|
344 |
+
end = int(math.ceil(float(end)))
|
345 |
+
step = int(step) - 1
|
346 |
+
Y[start:end, step] = 1
|
347 |
+
return Y
|
348 |
+
|
349 |
+
|
350 |
+
class COINPredictor(Predictor):
|
351 |
+
"""
|
352 |
+
COINPredictor is similar to CrossTask on sliding windows.
|
353 |
+
"""
|
354 |
+
def __init__(self, config):
|
355 |
+
super().__init__(config)
|
356 |
+
self.max_video_len = config.dataset.max_video_len
|
357 |
+
self.sliding_window = config.dataset.sliding_window
|
358 |
+
self.sliding_window_size = config.dataset.sliding_window_size
|
359 |
+
|
360 |
+
def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
|
361 |
+
"""refactored from line 144:
|
362 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py
|
363 |
+
"""
|
364 |
+
ctx = 0
|
365 |
+
model.eval()
|
366 |
+
model = model.to(ctx)
|
367 |
+
# this is not a loss but just compute neg_log_prob.
|
368 |
+
Y_pred = []
|
369 |
+
Y_true = []
|
370 |
+
with torch.no_grad():
|
371 |
+
for batch in eval_dataloader:
|
372 |
+
self(batch, model, Y_pred, Y_true)
|
373 |
+
return self.finalize(Y_pred, Y_true, output_file)
|
374 |
+
|
375 |
+
def __call__(self, sample, model, Y_pred, Y_true):
|
376 |
+
sample = self.to_ctx(sample)
|
377 |
+
# compute the average logits over sliding windows.
|
378 |
+
output = model(**sample)
|
379 |
+
logits = self._merge_windows(sample, output)
|
380 |
+
Y_pred.append(logits.argmax(dim=1))
|
381 |
+
Y_true.append(sample["video_targets"].squeeze(0).cpu())
|
382 |
+
|
383 |
+
def _merge_windows(self, sample, output):
|
384 |
+
targets = sample["targets"].reshape(-1).cpu()
|
385 |
+
valid_mask = targets != -100
|
386 |
+
targets = targets[valid_mask]
|
387 |
+
batch_logits = output["logits"].cpu()
|
388 |
+
batch_logits = batch_logits.reshape(-1, batch_logits.size(-1))
|
389 |
+
batch_logits = batch_logits[valid_mask]
|
390 |
+
|
391 |
+
video_len = sample["video_len"][0]
|
392 |
+
|
393 |
+
# the following version is slow.
|
394 |
+
logits = torch.zeros((video_len, batch_logits.size(1)))
|
395 |
+
logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
|
396 |
+
# use the same loop as aligner to recover.
|
397 |
+
batch_logit_idx = 0
|
398 |
+
for window_start in range(0, video_len, self.sliding_window):
|
399 |
+
video_end = min(video_len - window_start, self.sliding_window_size)
|
400 |
+
logits[window_start: window_start + video_end] += batch_logits[
|
401 |
+
batch_logit_idx: batch_logit_idx + video_end]
|
402 |
+
batch_logit_idx += video_end
|
403 |
+
logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
|
404 |
+
if (video_len - window_start) <= self.sliding_window_size:
|
405 |
+
break
|
406 |
+
logits /= logits_counts
|
407 |
+
assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
|
408 |
+
return logits
|
409 |
+
|
410 |
+
def finalize(self, Y_pred, Y_true, output_file=None):
|
411 |
+
Y_pred = torch.cat(Y_pred, dim=0).numpy()
|
412 |
+
Y_true = torch.cat(Y_true, dim=0).numpy()
|
413 |
+
assert len(Y_pred) == len(Y_true)
|
414 |
+
|
415 |
+
error_mask = Y_pred != Y_true
|
416 |
+
print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
|
417 |
+
print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
|
418 |
+
|
419 |
+
if output_file is not None:
|
420 |
+
with open(
|
421 |
+
os.path.join(self.pred_dir, output_file + ".pkl"),
|
422 |
+
"wb") as fw:
|
423 |
+
pickle.dump(
|
424 |
+
{"Y_pred": Y_pred, "Y_true": Y_true}, fw,
|
425 |
+
protocol=pickle.HIGHEST_PROTOCOL)
|
426 |
+
return {"outputs": Y_pred, "targets": Y_true}
|
427 |
+
|
428 |
+
|
429 |
+
class COINZSPredictor(COINPredictor):
|
430 |
+
"""
|
431 |
+
COINZSPredictor for COIN zero-shot prediction.
|
432 |
+
"""
|
433 |
+
|
434 |
+
def __init__(self, config):
|
435 |
+
super().__init__(config)
|
436 |
+
self.dataset_config = config.dataset
|
437 |
+
|
438 |
+
def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
|
439 |
+
"""refactored from line 144:
|
440 |
+
https://github.com/DmZhukov/CrossTask/blob/master/train.py
|
441 |
+
"""
|
442 |
+
ctx = 0
|
443 |
+
model.eval()
|
444 |
+
model = model.to(ctx)
|
445 |
+
|
446 |
+
with torch.no_grad():
|
447 |
+
outputs = eval_dataloader.dataset.meta_processor.meta_text_labels(
|
448 |
+
self.dataset_config)
|
449 |
+
outputs = self.to_ctx(outputs, ctx)
|
450 |
+
label_hidden_states = model.forward_text(**outputs).cpu()
|
451 |
+
label_sim = label_hidden_states @ label_hidden_states.t()
|
452 |
+
num_labels = label_sim.size(0)
|
453 |
+
eye_mask = ~torch.eye(num_labels, dtype=torch.bool)
|
454 |
+
label_sim = label_sim.masked_select(eye_mask).view(num_labels, num_labels - 1)
|
455 |
+
lbd = label_sim.max()
|
456 |
+
|
457 |
+
# this is not a loss but just compute neg_log_prob.
|
458 |
+
Y_pred = []
|
459 |
+
Y_true = []
|
460 |
+
with torch.no_grad():
|
461 |
+
for batch in eval_dataloader:
|
462 |
+
self(batch, label_hidden_states, model, lbd, Y_pred, Y_true)
|
463 |
+
return self.finalize(Y_pred, Y_true, output_file)
|
464 |
+
|
465 |
+
def reshape_subsample(self, sample):
|
466 |
+
for key in sample:
|
467 |
+
if torch.is_tensor(sample[key]):
|
468 |
+
sample[key] = self.flat_subsample(sample[key])
|
469 |
+
return sample
|
470 |
+
|
471 |
+
def flat_subsample(self, tensor):
|
472 |
+
if len(tensor.size()) > 1 and tensor.size(0) == 1:
|
473 |
+
tensor = tensor.squeeze(0)
|
474 |
+
return tensor
|
475 |
+
|
476 |
+
def __call__(self, sample, label_hidden_states, model, lbd, Y_pred, Y_true):
|
477 |
+
sample = self.reshape_subsample(sample)
|
478 |
+
sample = self.to_ctx(sample)
|
479 |
+
# compute the average logits over sliding windows.
|
480 |
+
sample["output_hidden_states"] = True
|
481 |
+
video_outputs = model.forward_video(**sample).cpu()
|
482 |
+
output = {"logits": video_outputs[:, 1:sample["vmasks"].size(1)+1] @ label_hidden_states.t()}
|
483 |
+
logits = self._merge_windows(sample, output)
|
484 |
+
# logic of zero-shot for sequence labeling.
|
485 |
+
logits_argmax = logits.argmax(dim=1) + 1 # 0 is "O" label.
|
486 |
+
logits_max = logits.max(dim=1)[0]
|
487 |
+
|
488 |
+
pred = torch.zeros_like(logits_argmax)
|
489 |
+
label_select = logits_max > lbd # 73 or 74
|
490 |
+
pred[label_select] = logits_argmax[label_select]
|
491 |
+
|
492 |
+
Y_pred.append(pred)
|
493 |
+
Y_true.append(sample["video_targets"].squeeze(0).cpu())
|
494 |
+
|
495 |
+
def finalize(self, Y_pred, Y_true, output_file=None):
|
496 |
+
Y_pred = torch.cat(Y_pred, dim=0).numpy()
|
497 |
+
Y_true = torch.cat(Y_true, dim=0).numpy()
|
498 |
+
assert len(Y_pred) == len(Y_true)
|
499 |
+
|
500 |
+
error_mask = Y_pred != Y_true
|
501 |
+
print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
|
502 |
+
print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
|
503 |
+
|
504 |
+
if output_file is not None:
|
505 |
+
with open(
|
506 |
+
os.path.join(self.pred_dir, output_file + ".pkl"),
|
507 |
+
"wb") as fw:
|
508 |
+
pickle.dump(
|
509 |
+
{"Y_pred": Y_pred, "Y_true": Y_true}, fw,
|
510 |
+
protocol=pickle.HIGHEST_PROTOCOL)
|
511 |
+
return {"outputs": Y_pred, "targets": Y_true}
|
512 |
+
|
513 |
+
|
514 |
+
class DiDeMoPredictor(Predictor):
|
515 |
+
"""reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
|
516 |
+
https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
|
517 |
+
"""
|
518 |
+
def __init__(self, config):
|
519 |
+
super().__init__(config)
|
520 |
+
# load targets.
|
521 |
+
with open(config.dataset.test_path) as data_file:
|
522 |
+
self.test_data = json.load(data_file)
|
523 |
+
|
524 |
+
def predict_loop(self, model, eval_dataloader, output_file="didemo.npy"):
|
525 |
+
"""
|
526 |
+
TODO: two solutions here.
|
527 |
+
"""
|
528 |
+
import itertools
|
529 |
+
# 21 chunks.
|
530 |
+
self.possible_segments = [(0,0), (1,1), (2,2), (3,3), (4,4), (5,5)]
|
531 |
+
for i in itertools.combinations(range(6), 2):
|
532 |
+
self.possible_segments.append(i)
|
533 |
+
# pick segments from a video.
|
534 |
+
|
535 |
+
"""on-the-fly prediction on a single gpu."""
|
536 |
+
self.full_scores = []
|
537 |
+
model.eval()
|
538 |
+
model = model.cuda()
|
539 |
+
with torch.no_grad():
|
540 |
+
for data in eval_dataloader:
|
541 |
+
# TODO special forwarding logic here.
|
542 |
+
data = self.to_ctx(data)
|
543 |
+
data["output_hidden_states"] = True
|
544 |
+
hidden_video = model.forward_video(**data)
|
545 |
+
data["output_hidden_states"] = False
|
546 |
+
pooled_text = model.forward_text(**data)
|
547 |
+
outputs = {
|
548 |
+
"hidden_video": hidden_video,
|
549 |
+
"pooled_text": pooled_text
|
550 |
+
}
|
551 |
+
outputs.update(data)
|
552 |
+
self(outputs)
|
553 |
+
return self.finalize(output_file)
|
554 |
+
|
555 |
+
def __call__(self, sample):
|
556 |
+
# TODO: make an index select from self.possible_segments.
|
557 |
+
hidden_video = sample["hidden_video"]
|
558 |
+
pooled_text = sample["pooled_text"]
|
559 |
+
vmasks = sample["vmasks"]
|
560 |
+
# probably maintain valid results here.
|
561 |
+
|
562 |
+
hidden_video = hidden_video[:, 1:-1, :]
|
563 |
+
# probably maintain valid results here.
|
564 |
+
pooled_video = []
|
565 |
+
for s, e in self.possible_segments:
|
566 |
+
pooled_video.append(
|
567 |
+
torch.mean(
|
568 |
+
hidden_video[:, int(s*5):int((e+1)*5), :],
|
569 |
+
dim=1, keepdim=True)
|
570 |
+
)
|
571 |
+
pooled_video = torch.cat(pooled_video, dim=1)
|
572 |
+
scores = torch.bmm(
|
573 |
+
pooled_video, pooled_text.unsqueeze(-1)).squeeze(-1).cpu()
|
574 |
+
|
575 |
+
ranks = scores.argsort(dim=-1, descending=True)
|
576 |
+
|
577 |
+
for batch_idx, rank in enumerate(ranks):
|
578 |
+
rank_of_moment = []
|
579 |
+
for m_idx, moment in enumerate(rank):
|
580 |
+
s, e = self.possible_segments[moment.item()]
|
581 |
+
if torch.any(
|
582 |
+
vmasks[batch_idx, int(s*5):int((e+1)*5)]
|
583 |
+
):
|
584 |
+
rank_of_moment.append((s, e))
|
585 |
+
self.full_scores.append(rank_of_moment)
|
586 |
+
|
587 |
+
def finalize(self, output_file=None):
|
588 |
+
outputs = self._aggregate_scores(self.full_scores)
|
589 |
+
if output_file is not None:
|
590 |
+
np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
|
591 |
+
return {"outputs": outputs, "targets": self.test_data}
|
592 |
+
|
593 |
+
def _aggregate_scores(self, scores):
|
594 |
+
self.full_scores = []
|
595 |
+
return scores
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from .loss import *
|
6 |
+
from .nce import *
|
7 |
+
|
8 |
+
try:
|
9 |
+
from .fairseqmmloss import *
|
10 |
+
except ImportError:
|
11 |
+
pass
|
12 |
+
|
13 |
+
try:
|
14 |
+
from .expnce import *
|
15 |
+
except ImportError:
|
16 |
+
pass
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/fairseqmmloss.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
"""
|
7 |
+
TODO (huxu): a general fairseq criterion for all your pre-defined losses.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
11 |
+
from fairseq import metrics
|
12 |
+
|
13 |
+
|
14 |
+
@register_criterion("mmloss")
|
15 |
+
class MMCriterion(FairseqCriterion):
|
16 |
+
def __init__(self, task):
|
17 |
+
super().__init__(task)
|
18 |
+
# TODO (huxu): wrap forward call of loss_fn and eval_fn into task.
|
19 |
+
self.mmtask = task.mmtask
|
20 |
+
|
21 |
+
def forward(self, model, sample):
|
22 |
+
"""Compute the loss for the given sample.
|
23 |
+
Returns a tuple with three elements:
|
24 |
+
1) the loss
|
25 |
+
2) the sample size, which is used as the denominator for the gradient
|
26 |
+
3) logging outputs to display while training
|
27 |
+
"""
|
28 |
+
outputs = self.mmtask(model, sample)
|
29 |
+
|
30 |
+
loss, loss_scalar, max_len, batch_size, sample_size = (
|
31 |
+
outputs["loss"],
|
32 |
+
outputs["loss_scalar"],
|
33 |
+
outputs["max_len"],
|
34 |
+
outputs["batch_size"],
|
35 |
+
outputs["sample_size"],
|
36 |
+
)
|
37 |
+
|
38 |
+
logging_output = {
|
39 |
+
"loss": loss_scalar,
|
40 |
+
"ntokens": max_len * batch_size, # dummy report.
|
41 |
+
"nsentences": batch_size, # dummy report.
|
42 |
+
"sample_size": sample_size,
|
43 |
+
}
|
44 |
+
|
45 |
+
return loss, 1, logging_output
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
def reduce_metrics(logging_outputs) -> None:
|
49 |
+
"""Aggregate logging outputs from data parallel training."""
|
50 |
+
"""since we use NCE, our actual batch_size is 1 per GPU.
|
51 |
+
Then we take the mean of each worker."""
|
52 |
+
loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
|
53 |
+
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
54 |
+
metrics.log_scalar("loss", loss_sum / sample_size, round=3)
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def logging_outputs_can_be_summed() -> bool:
|
58 |
+
"""
|
59 |
+
Whether the logging outputs returned by `forward` can be summed
|
60 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
61 |
+
to True will improves distributed training speed.
|
62 |
+
"""
|
63 |
+
return True
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/loss.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. All Rights Reserved
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
|
8 |
+
class Loss(object):
|
9 |
+
def __call__(self, *args, **kwargs):
|
10 |
+
raise NotImplementedError
|
11 |
+
|
12 |
+
|
13 |
+
# Dummy Loss for testing.
|
14 |
+
class DummyLoss(Loss):
|
15 |
+
def __init__(self):
|
16 |
+
self.loss = nn.CrossEntropyLoss()
|
17 |
+
|
18 |
+
def __call__(self, logits, targets, **kwargs):
|
19 |
+
return self.loss(logits, targets)
|
20 |
+
|
21 |
+
|
22 |
+
class DummyK400Loss(Loss):
|
23 |
+
"""dummy k400 loss for MViT."""
|
24 |
+
def __init__(self):
|
25 |
+
self.loss = nn.CrossEntropyLoss()
|
26 |
+
|
27 |
+
def __call__(self, logits, targets, **kwargs):
|
28 |
+
return self.loss(
|
29 |
+
logits, torch.randint(0, 400, (logits.size(0),), device=logits.device))
|
30 |
+
|
31 |
+
|
32 |
+
class CrossEntropy(Loss):
|
33 |
+
def __init__(self):
|
34 |
+
self.loss = nn.CrossEntropyLoss()
|
35 |
+
|
36 |
+
def __call__(self, logits, targets, **kwargs):
|
37 |
+
return self.loss(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
|
38 |
+
|
39 |
+
|
40 |
+
class ArgmaxCrossEntropy(Loss):
|
41 |
+
def __init__(self):
|
42 |
+
self.loss = nn.CrossEntropyLoss()
|
43 |
+
|
44 |
+
def __call__(self, logits, targets, **kwargs):
|
45 |
+
return self.loss(logits, targets.argmax(dim=1))
|
46 |
+
|
47 |
+
|
48 |
+
class BCE(Loss):
|
49 |
+
def __init__(self):
|
50 |
+
self.loss = nn.BCEWithLogitsLoss()
|
51 |
+
|
52 |
+
def __call__(self, logits, targets, **kwargs):
|
53 |
+
targets = targets.squeeze(0)
|
54 |
+
return self.loss(logits, targets)
|
55 |
+
|
56 |
+
|
57 |
+
class NLGLoss(Loss):
|
58 |
+
def __init__(self):
|
59 |
+
self.loss = nn.CrossEntropyLoss()
|
60 |
+
|
61 |
+
def __call__(self, logits, text_label, **kwargs):
|
62 |
+
targets = text_label[text_label != -100]
|
63 |
+
return self.loss(logits, targets)
|
64 |
+
|
65 |
+
|
66 |
+
class MSE(Loss):
|
67 |
+
def __init__(self):
|
68 |
+
self.loss = nn.MSELoss()
|
69 |
+
|
70 |
+
def __call__(self, logits, targets, **kwargs):
|
71 |
+
return self.loss(logits, targets)
|
72 |
+
|
73 |
+
|
74 |
+
class L1(Loss):
|
75 |
+
def __init__(self):
|
76 |
+
self.loss = nn.L1Loss()
|
77 |
+
|
78 |
+
def __call__(self, logits, targets, **kwargs):
|
79 |
+
return self.loss(logits, targets)
|
80 |
+
|
81 |
+
|
82 |
+
class SmoothL1(Loss):
|
83 |
+
def __init__(self):
|
84 |
+
self.loss = nn.SmoothL1Loss()
|
85 |
+
|
86 |
+
def __call__(self, logits, targets, **kwargs):
|
87 |
+
return self.loss(logits, targets)
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/nce.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
"""
|
7 |
+
softmax-based NCE loss, used by this project.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
from .loss import Loss
|
15 |
+
|
16 |
+
|
17 |
+
class NCE(Loss):
|
18 |
+
def __init__(self):
|
19 |
+
# TODO (huxu): define temperature.
|
20 |
+
self.loss = nn.CrossEntropyLoss()
|
21 |
+
|
22 |
+
def __call__(self, align_scores, **kargs):
|
23 |
+
# note: we reuse the same shape as cls head in BERT (batch_size, 2)
|
24 |
+
# but NCE only needs one logits.
|
25 |
+
# (so we drop all weights in the second neg logits.)
|
26 |
+
align_scores = align_scores[:, :1]
|
27 |
+
# duplicate negative examples
|
28 |
+
batch_size = align_scores.size(0) // 2
|
29 |
+
pos_scores = align_scores[:batch_size]
|
30 |
+
neg_scores = align_scores[batch_size:].view(1, batch_size).repeat(
|
31 |
+
batch_size, 1)
|
32 |
+
scores = torch.cat([pos_scores, neg_scores], dim=1)
|
33 |
+
return self.loss(
|
34 |
+
scores,
|
35 |
+
torch.zeros(
|
36 |
+
(batch_size,),
|
37 |
+
dtype=torch.long,
|
38 |
+
device=align_scores.device),
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
class T2VContraLoss(Loss):
|
43 |
+
"""NCE for MM joint space, on softmax text2video matrix.
|
44 |
+
"""
|
45 |
+
def __init__(self):
|
46 |
+
# TODO (huxu): define temperature.
|
47 |
+
self.loss = nn.CrossEntropyLoss()
|
48 |
+
|
49 |
+
def __call__(self, pooled_video, pooled_text, **kargs):
|
50 |
+
batch_size = pooled_video.size(0)
|
51 |
+
logits = torch.mm(pooled_text, pooled_video.transpose(1, 0))
|
52 |
+
targets = torch.arange(
|
53 |
+
batch_size,
|
54 |
+
dtype=torch.long,
|
55 |
+
device=pooled_video.device)
|
56 |
+
return self.loss(logits, targets)
|
57 |
+
|
58 |
+
|
59 |
+
class V2TContraLoss(Loss):
|
60 |
+
"""NCE for MM joint space, with softmax on video2text matrix."""
|
61 |
+
|
62 |
+
def __init__(self):
|
63 |
+
# TODO (huxu): define temperature.
|
64 |
+
self.loss = nn.CrossEntropyLoss()
|
65 |
+
|
66 |
+
def __call__(self, pooled_video, pooled_text, **kargs):
|
67 |
+
batch_size = pooled_video.size(0)
|
68 |
+
logits = torch.mm(pooled_video, pooled_text.transpose(1, 0))
|
69 |
+
targets = torch.arange(
|
70 |
+
batch_size,
|
71 |
+
dtype=torch.long,
|
72 |
+
device=pooled_video.device)
|
73 |
+
return self.loss(logits, targets)
|
74 |
+
|
75 |
+
|
76 |
+
class MMContraLoss(Loss):
|
77 |
+
def __init__(self):
|
78 |
+
self.loss = nn.CrossEntropyLoss()
|
79 |
+
|
80 |
+
def __call__(self, pooled_video, pooled_text, **kwargs):
|
81 |
+
logits_per_video = pooled_video @ pooled_text.t()
|
82 |
+
logits_per_text = pooled_text @ pooled_video.t()
|
83 |
+
|
84 |
+
targets = torch.arange(
|
85 |
+
pooled_video.size(0),
|
86 |
+
dtype=torch.long,
|
87 |
+
device=pooled_video.device)
|
88 |
+
loss_video = self.loss(logits_per_video, targets)
|
89 |
+
loss_text = self.loss(logits_per_text, targets)
|
90 |
+
return loss_video + loss_text
|
91 |
+
|
92 |
+
|
93 |
+
class MTM(Loss):
|
94 |
+
"""Combination of MFM and MLM."""
|
95 |
+
|
96 |
+
def __init__(self):
|
97 |
+
self.loss = nn.CrossEntropyLoss()
|
98 |
+
|
99 |
+
def __call__(
|
100 |
+
self,
|
101 |
+
video_logits,
|
102 |
+
text_logits,
|
103 |
+
video_label,
|
104 |
+
text_label,
|
105 |
+
**kwargs
|
106 |
+
):
|
107 |
+
text_logits = torch.cat([
|
108 |
+
text_logits,
|
109 |
+
torch.zeros(
|
110 |
+
(text_logits.size(0), 1), device=text_logits.device)
|
111 |
+
], dim=1)
|
112 |
+
vt_logits = torch.cat([video_logits, text_logits], dim=0)
|
113 |
+
# loss for video.
|
114 |
+
video_label = torch.zeros(
|
115 |
+
(video_logits.size(0),),
|
116 |
+
dtype=torch.long,
|
117 |
+
device=video_logits.device
|
118 |
+
)
|
119 |
+
|
120 |
+
# loss for text.
|
121 |
+
text_label = text_label.reshape(-1)
|
122 |
+
labels_mask = text_label != -100
|
123 |
+
selected_text_label = text_label[labels_mask]
|
124 |
+
|
125 |
+
vt_label = torch.cat([video_label, selected_text_label], dim=0)
|
126 |
+
return self.loss(vt_logits, vt_label)
|
127 |
+
|
128 |
+
|
129 |
+
class MFMMLM(Loss):
|
130 |
+
"""Combination of MFM and MLM."""
|
131 |
+
|
132 |
+
def __init__(self):
|
133 |
+
self.loss = nn.CrossEntropyLoss()
|
134 |
+
|
135 |
+
def __call__(
|
136 |
+
self,
|
137 |
+
video_logits,
|
138 |
+
text_logits,
|
139 |
+
video_label,
|
140 |
+
text_label,
|
141 |
+
**kwargs
|
142 |
+
):
|
143 |
+
# loss for video.
|
144 |
+
video_label = torch.zeros(
|
145 |
+
(video_logits.size(0),),
|
146 |
+
dtype=torch.long,
|
147 |
+
device=video_logits.device
|
148 |
+
)
|
149 |
+
masked_frame_loss = self.loss(video_logits, video_label)
|
150 |
+
|
151 |
+
# loss for text.
|
152 |
+
text_label = text_label.reshape(-1)
|
153 |
+
labels_mask = text_label != -100
|
154 |
+
selected_text_label = text_label[labels_mask]
|
155 |
+
masked_lm_loss = self.loss(text_logits, selected_text_label)
|
156 |
+
return masked_frame_loss + masked_lm_loss
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from .mmfusion import *
|
6 |
+
from .transformermodel import *
|
7 |
+
from .mmfusionnlg import *
|
8 |
+
|
9 |
+
try:
|
10 |
+
from .fairseqmmmodel import *
|
11 |
+
except ImportError:
|
12 |
+
pass
|
13 |
+
|
14 |
+
try:
|
15 |
+
from .expmmfusion import *
|
16 |
+
except ImportError:
|
17 |
+
pass
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/fairseqmmmodel.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from fairseq.models import (
|
7 |
+
BaseFairseqModel,
|
8 |
+
register_model,
|
9 |
+
register_model_architecture
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
@register_model("mmmodel")
|
14 |
+
class FairseqMMModel(BaseFairseqModel):
|
15 |
+
"""a fairseq wrapper of model built by `task`."""
|
16 |
+
|
17 |
+
@classmethod
|
18 |
+
def build_model(cls, args, task):
|
19 |
+
return FairseqMMModel(task.mmtask.model)
|
20 |
+
|
21 |
+
def __init__(self, mmmodel):
|
22 |
+
super().__init__()
|
23 |
+
self.mmmodel = mmmodel
|
24 |
+
|
25 |
+
def forward(self, *args, **kwargs):
|
26 |
+
return self.mmmodel(*args, **kwargs)
|
27 |
+
|
28 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
29 |
+
|
30 |
+
super().upgrade_state_dict_named(state_dict, name)
|
31 |
+
|
32 |
+
keys_to_delete = []
|
33 |
+
|
34 |
+
for key in state_dict:
|
35 |
+
if key not in self.state_dict():
|
36 |
+
keys_to_delete.append(key)
|
37 |
+
for key in keys_to_delete:
|
38 |
+
print("[INFO]", key, "not used anymore.")
|
39 |
+
del state_dict[key]
|
40 |
+
|
41 |
+
# copy any newly defined parameters.
|
42 |
+
for key in self.state_dict():
|
43 |
+
if key not in state_dict:
|
44 |
+
print("[INFO] adding", key)
|
45 |
+
state_dict[key] = self.state_dict()[key]
|
46 |
+
|
47 |
+
|
48 |
+
# a dummy arch, we config the model.
|
49 |
+
@register_model_architecture("mmmodel", "mmarch")
|
50 |
+
def mmarch(args):
|
51 |
+
pass
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/mmfusion.py
ADDED
@@ -0,0 +1,926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Copyright (c) Facebook, Inc. All Rights Reserved
|
17 |
+
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
try:
|
24 |
+
from transformers import AutoConfig, AutoTokenizer
|
25 |
+
except ImportError:
|
26 |
+
pass
|
27 |
+
|
28 |
+
from . import transformermodel
|
29 |
+
|
30 |
+
|
31 |
+
class MMPTModel(nn.Module):
|
32 |
+
"""An e2e wrapper of inference model.
|
33 |
+
"""
|
34 |
+
@classmethod
|
35 |
+
def from_pretrained(cls, config, checkpoint="checkpoint_best.pt"):
|
36 |
+
import os
|
37 |
+
from ..utils import recursive_config
|
38 |
+
from ..tasks import Task
|
39 |
+
config = recursive_config(config)
|
40 |
+
mmtask = Task.config_task(config)
|
41 |
+
checkpoint_path = os.path.join(config.eval.save_path, checkpoint)
|
42 |
+
mmtask.build_model(checkpoint=checkpoint_path)
|
43 |
+
# TODO(huxu): make the video encoder configurable.
|
44 |
+
from ..processors.models.s3dg import S3D
|
45 |
+
video_encoder = S3D('pretrained_models/s3d_dict.npy', 512)
|
46 |
+
video_encoder.load_state_dict(
|
47 |
+
torch.load('pretrained_models/s3d_howto100m.pth'))
|
48 |
+
from transformers import AutoTokenizer
|
49 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
50 |
+
config.dataset.bert_name, use_fast=config.dataset.use_fast
|
51 |
+
)
|
52 |
+
from ..processors import Aligner
|
53 |
+
aligner = Aligner(config.dataset)
|
54 |
+
return (
|
55 |
+
MMPTModel(config, mmtask.model, video_encoder),
|
56 |
+
tokenizer,
|
57 |
+
aligner
|
58 |
+
)
|
59 |
+
|
60 |
+
def __init__(self, config, model, video_encoder, **kwargs):
|
61 |
+
super().__init__()
|
62 |
+
self.max_video_len = config.dataset.max_video_len
|
63 |
+
self.video_encoder = video_encoder
|
64 |
+
self.model = model
|
65 |
+
|
66 |
+
def forward(self, video_frames, caps, cmasks, return_score=False):
|
67 |
+
bsz = video_frames.size(0)
|
68 |
+
assert bsz == 1, "only bsz=1 is supported now."
|
69 |
+
seq_len = video_frames.size(1)
|
70 |
+
video_frames = video_frames.view(-1, *video_frames.size()[2:])
|
71 |
+
vfeats = self.video_encoder(video_frames.permute(0, 4, 1, 2, 3))
|
72 |
+
vfeats = vfeats['video_embedding']
|
73 |
+
vfeats = vfeats.view(bsz, seq_len, vfeats.size(-1))
|
74 |
+
padding = torch.zeros(
|
75 |
+
bsz, self.max_video_len - seq_len, vfeats.size(-1))
|
76 |
+
vfeats = torch.cat([vfeats, padding], dim=1)
|
77 |
+
vmasks = torch.cat([
|
78 |
+
torch.ones((bsz, seq_len), dtype=torch.bool),
|
79 |
+
torch.zeros((bsz, self.max_video_len - seq_len), dtype=torch.bool)
|
80 |
+
],
|
81 |
+
dim=1
|
82 |
+
)
|
83 |
+
output = self.model(caps, cmasks, vfeats, vmasks)
|
84 |
+
if return_score:
|
85 |
+
output = {"score": torch.bmm(
|
86 |
+
output["pooled_video"][:, None, :],
|
87 |
+
output["pooled_text"][:, :, None]
|
88 |
+
).squeeze(-1).squeeze(-1)}
|
89 |
+
return output
|
90 |
+
|
91 |
+
|
92 |
+
class MMFusion(nn.Module):
|
93 |
+
"""a MMPT wrapper class for MMBert style models.
|
94 |
+
TODO: move isolated mask to a subclass.
|
95 |
+
"""
|
96 |
+
def __init__(self, config, **kwargs):
|
97 |
+
super().__init__()
|
98 |
+
transformer_config = AutoConfig.from_pretrained(
|
99 |
+
config.dataset.bert_name)
|
100 |
+
self.hidden_size = transformer_config.hidden_size
|
101 |
+
self.is_train = False
|
102 |
+
if config.dataset.train_path is not None:
|
103 |
+
self.is_train = True
|
104 |
+
# 0 means no iso; 1-12 means iso up to that layer.
|
105 |
+
self.num_hidden_layers = transformer_config.num_hidden_layers
|
106 |
+
self.last_iso_layer = 0
|
107 |
+
if config.dataset.num_iso_layer is not None:
|
108 |
+
self.last_iso_layer = config.dataset.num_iso_layer - 1 + 1
|
109 |
+
|
110 |
+
if config.model.mm_encoder_cls is not None:
|
111 |
+
mm_encoder_cls = getattr(transformermodel, config.model.mm_encoder_cls)
|
112 |
+
model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
|
113 |
+
model_config.max_video_len = config.dataset.max_video_len
|
114 |
+
# TODO: a general way to add parameter for a model.
|
115 |
+
model_config.use_seg_emb = config.model.use_seg_emb
|
116 |
+
self.mm_encoder = mm_encoder_cls.from_pretrained(
|
117 |
+
config.dataset.bert_name, config=model_config)
|
118 |
+
elif config.model.video_encoder_cls is not None\
|
119 |
+
and config.model.text_encoder_cls is not None:
|
120 |
+
video_encoder_cls = getattr(transformermodel, config.model.video_encoder_cls)
|
121 |
+
model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
|
122 |
+
model_config.max_video_len = config.dataset.max_video_len
|
123 |
+
# TODO: make each model a set of config class.
|
124 |
+
if hasattr(model_config, "num_layers"):
|
125 |
+
model_config.num_layers = config.model.num_hidden_video_layers
|
126 |
+
else:
|
127 |
+
model_config.num_hidden_layers = config.model.num_hidden_video_layers
|
128 |
+
self.video_encoder = video_encoder_cls.from_pretrained(
|
129 |
+
config.dataset.bert_name, config=model_config)
|
130 |
+
# exact same NLP model from Huggingface.
|
131 |
+
text_encoder_cls = getattr(transformermodel, config.model.text_encoder_cls)
|
132 |
+
self.text_encoder = text_encoder_cls.from_pretrained(
|
133 |
+
config.dataset.bert_name)
|
134 |
+
else:
|
135 |
+
raise ValueError("the encoder must be either MM or two backbones.")
|
136 |
+
|
137 |
+
def forward(
|
138 |
+
self,
|
139 |
+
caps,
|
140 |
+
cmasks,
|
141 |
+
vfeats,
|
142 |
+
vmasks,
|
143 |
+
**kwargs
|
144 |
+
):
|
145 |
+
raise NotImplementedError(
|
146 |
+
"Please derive MMFusion module."
|
147 |
+
)
|
148 |
+
|
149 |
+
def _mm_on_the_fly(
|
150 |
+
self,
|
151 |
+
cmasks,
|
152 |
+
vmasks,
|
153 |
+
attention_mask
|
154 |
+
):
|
155 |
+
"""helper function for mask, seg_ids and token_type_ids."""
|
156 |
+
if attention_mask is None:
|
157 |
+
attention_mask = self._mm_attention_mask(cmasks, vmasks)
|
158 |
+
|
159 |
+
"""
|
160 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
161 |
+
| first sequence | second sequence |
|
162 |
+
"""
|
163 |
+
token_type_ids = torch.cat(
|
164 |
+
[
|
165 |
+
torch.zeros(
|
166 |
+
(vmasks.size(0), vmasks.size(1) + 2),
|
167 |
+
dtype=torch.long,
|
168 |
+
device=vmasks.device,
|
169 |
+
),
|
170 |
+
torch.ones(
|
171 |
+
(cmasks.size(0), cmasks.size(1) - 2),
|
172 |
+
dtype=torch.long,
|
173 |
+
device=cmasks.device,
|
174 |
+
),
|
175 |
+
],
|
176 |
+
dim=1,
|
177 |
+
)
|
178 |
+
return attention_mask, token_type_ids
|
179 |
+
|
180 |
+
def _mm_attention_mask(self, cmasks, vmasks):
|
181 |
+
assert cmasks.size(0) == vmasks.size(0), "{}, {}, {}, {}".format(
|
182 |
+
str(cmasks.size()),
|
183 |
+
str(vmasks.size()),
|
184 |
+
str(cmasks.size(0)),
|
185 |
+
str(vmasks.size(0)),
|
186 |
+
)
|
187 |
+
|
188 |
+
mm_mask = torch.cat([cmasks[:, :1], vmasks, cmasks[:, 1:]], dim=1)
|
189 |
+
if self.last_iso_layer == 0:
|
190 |
+
# hard attention mask.
|
191 |
+
return mm_mask
|
192 |
+
else:
|
193 |
+
# a gpu iso mask; 0 : num_iso_layer is isolated;
|
194 |
+
# num_iso_layer: are MM-fused.
|
195 |
+
# make an iso layer
|
196 |
+
batch_size = cmasks.size(0)
|
197 |
+
iso_mask = self._make_iso_mask(batch_size, cmasks, vmasks)
|
198 |
+
mm_mask = mm_mask[:, None, :].repeat(1, mm_mask.size(-1), 1)
|
199 |
+
iso_mm_masks = []
|
200 |
+
# hard attention mask.
|
201 |
+
iso_mask = iso_mask[:, None, :, :].repeat(
|
202 |
+
1, self.last_iso_layer, 1, 1)
|
203 |
+
iso_mm_masks.append(iso_mask)
|
204 |
+
if self.last_iso_layer < self.num_hidden_layers:
|
205 |
+
mm_mask = mm_mask[:, None, :, :].repeat(
|
206 |
+
1, self.num_hidden_layers - self.last_iso_layer, 1, 1
|
207 |
+
)
|
208 |
+
iso_mm_masks.append(mm_mask)
|
209 |
+
iso_mm_masks = torch.cat(iso_mm_masks, dim=1)
|
210 |
+
return iso_mm_masks
|
211 |
+
|
212 |
+
def _make_iso_mask(self, batch_size, cmasks, vmasks):
|
213 |
+
cls_self_mask = torch.cat(
|
214 |
+
[
|
215 |
+
torch.ones(
|
216 |
+
(batch_size, 1), dtype=torch.bool, device=cmasks.device),
|
217 |
+
torch.zeros(
|
218 |
+
(batch_size, cmasks.size(1) + vmasks.size(1) - 1),
|
219 |
+
dtype=torch.bool, device=cmasks.device)
|
220 |
+
], dim=1)
|
221 |
+
|
222 |
+
iso_video_mask = torch.cat(
|
223 |
+
[
|
224 |
+
# [CLS] is not used.
|
225 |
+
torch.zeros(
|
226 |
+
(batch_size, 1), dtype=torch.bool, device=cmasks.device
|
227 |
+
),
|
228 |
+
vmasks,
|
229 |
+
# assume to be 1.
|
230 |
+
cmasks[:, 1:2],
|
231 |
+
# 2 means [CLS] + [SEP]
|
232 |
+
torch.zeros(
|
233 |
+
(batch_size, cmasks.size(1) - 2),
|
234 |
+
dtype=torch.bool,
|
235 |
+
device=cmasks.device,
|
236 |
+
),
|
237 |
+
],
|
238 |
+
dim=1,
|
239 |
+
)
|
240 |
+
iso_text_mask = torch.cat(
|
241 |
+
[
|
242 |
+
torch.zeros(
|
243 |
+
(batch_size, 2 + vmasks.size(1)),
|
244 |
+
dtype=torch.bool,
|
245 |
+
device=cmasks.device,
|
246 |
+
), # [CLS] is not used.
|
247 |
+
cmasks[:, 2:], # assume to be 1.
|
248 |
+
],
|
249 |
+
dim=1,
|
250 |
+
)
|
251 |
+
cls_self_mask = cls_self_mask[:, None, :]
|
252 |
+
iso_video_mask = iso_video_mask[:, None, :].repeat(
|
253 |
+
1, vmasks.size(1) + 1, 1)
|
254 |
+
iso_text_mask = iso_text_mask[:, None, :].repeat(
|
255 |
+
1, cmasks.size(1) - 2, 1)
|
256 |
+
return torch.cat([cls_self_mask, iso_video_mask, iso_text_mask], dim=1)
|
257 |
+
|
258 |
+
def _pooling_vt_layer(
|
259 |
+
self,
|
260 |
+
layered_sequence_output,
|
261 |
+
cmasks,
|
262 |
+
vmasks
|
263 |
+
):
|
264 |
+
layer_idx = self.last_iso_layer \
|
265 |
+
if self.last_iso_layer > 0 else self.num_hidden_layers
|
266 |
+
hidden_state = layered_sequence_output[layer_idx]
|
267 |
+
# also output pooled_video and pooled_text.
|
268 |
+
batch_size = cmasks.size(0)
|
269 |
+
# pool the modality.
|
270 |
+
text_offset = vmasks.size(1) + 2 # [CLS] + [SEP]
|
271 |
+
# video tokens + [SEP]
|
272 |
+
video_outputs = hidden_state[:, 1:text_offset]
|
273 |
+
video_attention_mask = torch.cat(
|
274 |
+
[
|
275 |
+
vmasks,
|
276 |
+
torch.ones(
|
277 |
+
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
|
278 |
+
],
|
279 |
+
dim=1,
|
280 |
+
)
|
281 |
+
assert video_outputs.size(1) == video_attention_mask.size(1)
|
282 |
+
pooled_video = torch.sum(
|
283 |
+
video_outputs * video_attention_mask.unsqueeze(-1), dim=1
|
284 |
+
) / video_attention_mask.sum(1, keepdim=True)
|
285 |
+
# pooled_video = torch.mean(video_outputs[0], dim=1)
|
286 |
+
|
287 |
+
# text tokens + [SEP]
|
288 |
+
text_attention_mask = cmasks[:, 2:]
|
289 |
+
text_outputs = hidden_state[:, text_offset:]
|
290 |
+
assert text_outputs.size(1) == text_attention_mask.size(1)
|
291 |
+
pooled_text = torch.sum(
|
292 |
+
text_outputs * text_attention_mask.unsqueeze(-1), dim=1
|
293 |
+
) / text_attention_mask.sum(1, keepdim=True)
|
294 |
+
return pooled_video, pooled_text
|
295 |
+
|
296 |
+
|
297 |
+
class MMFusionMFMMLM(MMFusion):
|
298 |
+
"""forward function for MFM and MLM."""
|
299 |
+
def forward(
|
300 |
+
self,
|
301 |
+
caps,
|
302 |
+
cmasks,
|
303 |
+
vfeats,
|
304 |
+
vmasks,
|
305 |
+
attention_mask=None,
|
306 |
+
video_label=None,
|
307 |
+
text_label=None,
|
308 |
+
**kwargs
|
309 |
+
):
|
310 |
+
output_hidden_states = False if self.is_train else True
|
311 |
+
|
312 |
+
target_vfeats, non_masked_frame_mask = None, None
|
313 |
+
if video_label is not None:
|
314 |
+
target_vfeats = vfeats.masked_select(
|
315 |
+
video_label.unsqueeze(-1)).view(
|
316 |
+
-1, vfeats.size(-1)
|
317 |
+
)
|
318 |
+
# mask video token.
|
319 |
+
vfeats[video_label] = 0.0
|
320 |
+
non_masked_frame_mask = vmasks.clone()
|
321 |
+
non_masked_frame_mask[video_label] = False
|
322 |
+
|
323 |
+
attention_mask, token_type_ids = self._mm_on_the_fly(
|
324 |
+
cmasks, vmasks, attention_mask)
|
325 |
+
|
326 |
+
outputs = self.mm_encoder(
|
327 |
+
input_ids=caps,
|
328 |
+
input_video_embeds=vfeats,
|
329 |
+
attention_mask=attention_mask,
|
330 |
+
token_type_ids=token_type_ids,
|
331 |
+
masked_frame_labels=video_label,
|
332 |
+
target_video_hidden_states=target_vfeats,
|
333 |
+
non_masked_frame_mask=non_masked_frame_mask,
|
334 |
+
masked_lm_labels=text_label,
|
335 |
+
output_hidden_states=output_hidden_states,
|
336 |
+
)
|
337 |
+
|
338 |
+
video_logits, text_logits = outputs[0], outputs[1]
|
339 |
+
|
340 |
+
if self.is_train: # return earlier for training.
|
341 |
+
return {
|
342 |
+
"video_logits": video_logits,
|
343 |
+
"text_logits": text_logits,
|
344 |
+
}
|
345 |
+
|
346 |
+
pooled_video, pooled_text = self._pooling_vt_layer(
|
347 |
+
outputs[2], cmasks, vmasks)
|
348 |
+
return {"pooled_video": pooled_video, "pooled_text": pooled_text}
|
349 |
+
|
350 |
+
|
351 |
+
class MMFusionMTM(MMFusionMFMMLM):
|
352 |
+
def __init__(self, config, **kwargs):
|
353 |
+
super().__init__(config)
|
354 |
+
"""
|
355 |
+
For reproducibility:
|
356 |
+
self.mm_encoder will be initialized then discarded.
|
357 |
+
"""
|
358 |
+
from .transformermodel import MMBertForMTM
|
359 |
+
model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
|
360 |
+
model_config.max_video_len = config.dataset.max_video_len
|
361 |
+
model_config.use_seg_emb = config.model.use_seg_emb
|
362 |
+
self.mm_encoder = MMBertForMTM.from_pretrained(
|
363 |
+
config.dataset.bert_name, config=model_config)
|
364 |
+
|
365 |
+
|
366 |
+
class MMFusionShare(MMFusion):
|
367 |
+
"""A retrival wrapper using mm_encoder as both video/text backbone.
|
368 |
+
TODO: move formally.
|
369 |
+
"""
|
370 |
+
def forward(
|
371 |
+
self,
|
372 |
+
caps,
|
373 |
+
cmasks,
|
374 |
+
vfeats,
|
375 |
+
vmasks,
|
376 |
+
attention_mask=None,
|
377 |
+
video_label=None,
|
378 |
+
text_label=None,
|
379 |
+
output_hidden_states=False,
|
380 |
+
**kwargs
|
381 |
+
):
|
382 |
+
pooled_video = self.forward_video(
|
383 |
+
vfeats,
|
384 |
+
vmasks,
|
385 |
+
caps,
|
386 |
+
cmasks,
|
387 |
+
output_hidden_states
|
388 |
+
)
|
389 |
+
|
390 |
+
pooled_text = self.forward_text(
|
391 |
+
caps,
|
392 |
+
cmasks,
|
393 |
+
output_hidden_states
|
394 |
+
)
|
395 |
+
|
396 |
+
return {"pooled_video": pooled_video, "pooled_text": pooled_text}
|
397 |
+
|
398 |
+
def forward_video(
|
399 |
+
self,
|
400 |
+
vfeats,
|
401 |
+
vmasks,
|
402 |
+
caps,
|
403 |
+
cmasks,
|
404 |
+
output_hidden_states=False,
|
405 |
+
**kwargs
|
406 |
+
):
|
407 |
+
input_ids = caps[:, :2]
|
408 |
+
|
409 |
+
attention_mask = torch.cat([
|
410 |
+
cmasks[:, :1],
|
411 |
+
vmasks,
|
412 |
+
cmasks[:, 1:2]
|
413 |
+
], dim=1)
|
414 |
+
|
415 |
+
token_type_ids = torch.zeros(
|
416 |
+
(vmasks.size(0), vmasks.size(1) + 2),
|
417 |
+
dtype=torch.long,
|
418 |
+
device=vmasks.device)
|
419 |
+
|
420 |
+
outputs = self.mm_encoder(
|
421 |
+
input_ids=input_ids,
|
422 |
+
input_video_embeds=vfeats,
|
423 |
+
attention_mask=attention_mask,
|
424 |
+
token_type_ids=token_type_ids,
|
425 |
+
output_hidden_states=True
|
426 |
+
)
|
427 |
+
video_outputs = outputs[0]
|
428 |
+
|
429 |
+
if output_hidden_states:
|
430 |
+
return video_outputs
|
431 |
+
|
432 |
+
batch_size = cmasks.size(0)
|
433 |
+
|
434 |
+
video_attention_mask = torch.cat(
|
435 |
+
[
|
436 |
+
torch.zeros(
|
437 |
+
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
|
438 |
+
vmasks,
|
439 |
+
torch.ones(
|
440 |
+
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
|
441 |
+
],
|
442 |
+
dim=1,
|
443 |
+
)
|
444 |
+
assert video_outputs.size(1) == video_attention_mask.size(1)
|
445 |
+
|
446 |
+
video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
|
447 |
+
/ video_attention_mask.sum(1, keepdim=True)
|
448 |
+
|
449 |
+
pooled_video = torch.bmm(
|
450 |
+
video_outputs.transpose(2, 1),
|
451 |
+
video_attention_mask.unsqueeze(2)
|
452 |
+
).squeeze(-1)
|
453 |
+
return pooled_video # video_outputs
|
454 |
+
|
455 |
+
def forward_text(
|
456 |
+
self,
|
457 |
+
caps,
|
458 |
+
cmasks,
|
459 |
+
output_hidden_states=False,
|
460 |
+
**kwargs
|
461 |
+
):
|
462 |
+
input_ids = torch.cat([
|
463 |
+
caps[:, :1], caps[:, 2:],
|
464 |
+
], dim=1)
|
465 |
+
|
466 |
+
attention_mask = torch.cat([
|
467 |
+
cmasks[:, :1],
|
468 |
+
cmasks[:, 2:]
|
469 |
+
], dim=1)
|
470 |
+
|
471 |
+
token_type_ids = torch.cat([
|
472 |
+
torch.zeros(
|
473 |
+
(cmasks.size(0), 1),
|
474 |
+
dtype=torch.long,
|
475 |
+
device=cmasks.device),
|
476 |
+
torch.ones(
|
477 |
+
(cmasks.size(0), cmasks.size(1) - 2),
|
478 |
+
dtype=torch.long,
|
479 |
+
device=cmasks.device)
|
480 |
+
], dim=1)
|
481 |
+
|
482 |
+
outputs = self.mm_encoder(
|
483 |
+
input_ids=input_ids,
|
484 |
+
input_video_embeds=None,
|
485 |
+
attention_mask=attention_mask,
|
486 |
+
token_type_ids=token_type_ids,
|
487 |
+
output_hidden_states=True
|
488 |
+
)
|
489 |
+
text_outputs = outputs[0]
|
490 |
+
|
491 |
+
if output_hidden_states:
|
492 |
+
return text_outputs
|
493 |
+
|
494 |
+
batch_size = caps.size(0)
|
495 |
+
# text tokens + [SEP]
|
496 |
+
text_attention_mask = torch.cat([
|
497 |
+
torch.zeros(
|
498 |
+
(batch_size, 1), dtype=torch.bool, device=cmasks.device),
|
499 |
+
cmasks[:, 2:]
|
500 |
+
], dim=1)
|
501 |
+
|
502 |
+
assert text_outputs.size(1) == text_attention_mask.size(1)
|
503 |
+
|
504 |
+
text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
|
505 |
+
/ text_attention_mask.sum(1, keepdim=True)
|
506 |
+
|
507 |
+
pooled_text = torch.bmm(
|
508 |
+
text_outputs.transpose(2, 1),
|
509 |
+
text_attention_mask.unsqueeze(2)
|
510 |
+
).squeeze(-1)
|
511 |
+
return pooled_text # text_outputs
|
512 |
+
|
513 |
+
|
514 |
+
class MMFusionSeparate(MMFusionShare):
|
515 |
+
def forward_video(
|
516 |
+
self,
|
517 |
+
vfeats,
|
518 |
+
vmasks,
|
519 |
+
caps,
|
520 |
+
cmasks,
|
521 |
+
output_hidden_states=False,
|
522 |
+
**kwargs
|
523 |
+
):
|
524 |
+
input_ids = caps[:, :2]
|
525 |
+
|
526 |
+
attention_mask = torch.cat([
|
527 |
+
cmasks[:, :1],
|
528 |
+
vmasks,
|
529 |
+
cmasks[:, 1:2]
|
530 |
+
], dim=1)
|
531 |
+
|
532 |
+
token_type_ids = torch.zeros(
|
533 |
+
(vmasks.size(0), vmasks.size(1) + 2),
|
534 |
+
dtype=torch.long,
|
535 |
+
device=vmasks.device)
|
536 |
+
|
537 |
+
outputs = self.video_encoder(
|
538 |
+
input_ids=input_ids,
|
539 |
+
input_video_embeds=vfeats,
|
540 |
+
attention_mask=attention_mask,
|
541 |
+
token_type_ids=token_type_ids,
|
542 |
+
output_hidden_states=True
|
543 |
+
)
|
544 |
+
video_outputs = outputs[0]
|
545 |
+
|
546 |
+
if output_hidden_states:
|
547 |
+
return video_outputs
|
548 |
+
|
549 |
+
batch_size = cmasks.size(0)
|
550 |
+
|
551 |
+
video_attention_mask = torch.cat(
|
552 |
+
[
|
553 |
+
torch.zeros(
|
554 |
+
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
|
555 |
+
vmasks,
|
556 |
+
torch.ones(
|
557 |
+
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
|
558 |
+
],
|
559 |
+
dim=1,
|
560 |
+
)
|
561 |
+
assert video_outputs.size(1) == video_attention_mask.size(1)
|
562 |
+
|
563 |
+
video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
|
564 |
+
/ video_attention_mask.sum(1, keepdim=True)
|
565 |
+
|
566 |
+
pooled_video = torch.bmm(
|
567 |
+
video_outputs.transpose(2, 1),
|
568 |
+
video_attention_mask.unsqueeze(2)
|
569 |
+
).squeeze(-1)
|
570 |
+
return pooled_video # video_outputs
|
571 |
+
|
572 |
+
def forward_text(
|
573 |
+
self,
|
574 |
+
caps,
|
575 |
+
cmasks,
|
576 |
+
output_hidden_states=False,
|
577 |
+
**kwargs
|
578 |
+
):
|
579 |
+
input_ids = torch.cat([
|
580 |
+
caps[:, :1], caps[:, 2:],
|
581 |
+
], dim=1)
|
582 |
+
|
583 |
+
attention_mask = torch.cat([
|
584 |
+
cmasks[:, :1],
|
585 |
+
cmasks[:, 2:]
|
586 |
+
], dim=1)
|
587 |
+
# different from sharing, we use all-0 type.
|
588 |
+
token_type_ids = torch.zeros(
|
589 |
+
(cmasks.size(0), cmasks.size(1) - 1),
|
590 |
+
dtype=torch.long,
|
591 |
+
device=cmasks.device)
|
592 |
+
|
593 |
+
outputs = self.text_encoder(
|
594 |
+
input_ids=input_ids,
|
595 |
+
attention_mask=attention_mask,
|
596 |
+
token_type_ids=token_type_ids,
|
597 |
+
output_hidden_states=True
|
598 |
+
)
|
599 |
+
text_outputs = outputs[0]
|
600 |
+
|
601 |
+
if output_hidden_states:
|
602 |
+
return text_outputs
|
603 |
+
|
604 |
+
batch_size = caps.size(0)
|
605 |
+
# text tokens + [SEP]
|
606 |
+
text_attention_mask = torch.cat([
|
607 |
+
torch.zeros(
|
608 |
+
(batch_size, 1), dtype=torch.bool, device=cmasks.device),
|
609 |
+
cmasks[:, 2:]
|
610 |
+
], dim=1)
|
611 |
+
|
612 |
+
assert text_outputs.size(1) == text_attention_mask.size(1)
|
613 |
+
|
614 |
+
text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
|
615 |
+
/ text_attention_mask.sum(1, keepdim=True)
|
616 |
+
|
617 |
+
pooled_text = torch.bmm(
|
618 |
+
text_outputs.transpose(2, 1),
|
619 |
+
text_attention_mask.unsqueeze(2)
|
620 |
+
).squeeze(-1)
|
621 |
+
return pooled_text # text_outputs
|
622 |
+
|
623 |
+
|
624 |
+
class MMFusionJoint(MMFusion):
|
625 |
+
"""fine-tuning wrapper for retrival task."""
|
626 |
+
|
627 |
+
def forward(
|
628 |
+
self,
|
629 |
+
caps,
|
630 |
+
cmasks,
|
631 |
+
vfeats,
|
632 |
+
vmasks,
|
633 |
+
attention_mask=None,
|
634 |
+
video_label=None,
|
635 |
+
text_label=None,
|
636 |
+
**kwargs
|
637 |
+
):
|
638 |
+
# TODO (huxu): other ways to do negative examples; move the following
|
639 |
+
# into your criterion forward.
|
640 |
+
output_hidden_states = True
|
641 |
+
|
642 |
+
attention_mask, token_type_ids = self._mm_on_the_fly(
|
643 |
+
cmasks, vmasks, attention_mask)
|
644 |
+
|
645 |
+
separate_forward_split = (
|
646 |
+
None if self.is_train else vmasks.size(1) + 2
|
647 |
+
) # [CLS] + [SEP]
|
648 |
+
|
649 |
+
outputs = self.mm_encoder(
|
650 |
+
input_ids=caps,
|
651 |
+
input_video_embeds=vfeats,
|
652 |
+
attention_mask=attention_mask,
|
653 |
+
token_type_ids=token_type_ids,
|
654 |
+
output_hidden_states=output_hidden_states,
|
655 |
+
separate_forward_split=separate_forward_split,
|
656 |
+
)
|
657 |
+
|
658 |
+
pooled_video, pooled_text = self._pooling_vt_layer(
|
659 |
+
outputs[2], cmasks, vmasks)
|
660 |
+
return {"pooled_video": pooled_video, "pooled_text": pooled_text}
|
661 |
+
|
662 |
+
|
663 |
+
class MMFusionActionSegmentation(MMFusion):
|
664 |
+
"""Fine-tuning wrapper for action segmentation.
|
665 |
+
TODO: rename this for VLM.
|
666 |
+
"""
|
667 |
+
def forward(
|
668 |
+
self,
|
669 |
+
caps,
|
670 |
+
cmasks,
|
671 |
+
vfeats,
|
672 |
+
vmasks,
|
673 |
+
attention_mask=None,
|
674 |
+
**kwargs
|
675 |
+
):
|
676 |
+
# ActionLocalization assume of batch_size=1, squeeze it.
|
677 |
+
caps = caps.view(-1, caps.size(-1))
|
678 |
+
cmasks = cmasks.view(-1, cmasks.size(-1))
|
679 |
+
vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
|
680 |
+
vmasks = vmasks.view(-1, vmasks.size(-1))
|
681 |
+
|
682 |
+
# this may not cover all shapes of attention_mask.
|
683 |
+
attention_mask = attention_mask.view(
|
684 |
+
-1, attention_mask.size(2), attention_mask.size(3)) \
|
685 |
+
if attention_mask is not None else None
|
686 |
+
|
687 |
+
# TODO (huxu): other ways to do negative examples; move the following
|
688 |
+
# into your criterion forward.
|
689 |
+
output_hidden_states = True
|
690 |
+
|
691 |
+
# video forwarding, text is dummy; never use attention_mask.
|
692 |
+
attention_mask, token_type_ids = self._mm_on_the_fly(
|
693 |
+
cmasks, vmasks, attention_mask)
|
694 |
+
|
695 |
+
logits = self.mm_encoder(
|
696 |
+
input_ids=caps,
|
697 |
+
input_video_embeds=vfeats,
|
698 |
+
attention_mask=attention_mask,
|
699 |
+
token_type_ids=token_type_ids,
|
700 |
+
output_hidden_states=output_hidden_states,
|
701 |
+
)
|
702 |
+
return {"logits": logits[0][:, 1:vmasks.size(1)+1]}
|
703 |
+
|
704 |
+
|
705 |
+
class MMFusionActionLocalization(MMFusion):
|
706 |
+
"""fine-tuning model for retrival task."""
|
707 |
+
|
708 |
+
def __init__(self, config, **kwargs):
|
709 |
+
super().__init__(config)
|
710 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
711 |
+
config.dataset.bert_name)
|
712 |
+
self.cls_token_id = tokenizer.cls_token_id
|
713 |
+
self.sep_token_id = tokenizer.sep_token_id
|
714 |
+
self.pad_token_id = tokenizer.pad_token_id
|
715 |
+
|
716 |
+
def forward(
|
717 |
+
self,
|
718 |
+
caps,
|
719 |
+
cmasks,
|
720 |
+
vfeats,
|
721 |
+
vmasks,
|
722 |
+
attention_mask=None,
|
723 |
+
**kwargs
|
724 |
+
):
|
725 |
+
# ActionLocalization assume of batch_size=1, squeeze it.
|
726 |
+
caps = caps.squeeze(0)
|
727 |
+
cmasks = cmasks.squeeze(0)
|
728 |
+
vfeats = vfeats.squeeze(0)
|
729 |
+
vmasks = vmasks.squeeze(0)
|
730 |
+
attention_mask = attention_mask.squeeze(0) if attention_mask is not None else None
|
731 |
+
|
732 |
+
# TODO (huxu): other ways to do negative examples; move the following
|
733 |
+
# into your criterion forward.
|
734 |
+
output_hidden_states = True
|
735 |
+
|
736 |
+
# a len1 dummy video token.
|
737 |
+
dummy_vfeats = torch.zeros(
|
738 |
+
(caps.size(0), 1, vfeats.size(-1)), device=vfeats.device, dtype=vfeats.dtype)
|
739 |
+
dummy_vmasks = torch.ones(
|
740 |
+
(caps.size(0), 1), dtype=torch.bool,
|
741 |
+
device=vfeats.device)
|
742 |
+
|
743 |
+
dummy_caps = torch.LongTensor(
|
744 |
+
[[self.cls_token_id, self.sep_token_id,
|
745 |
+
self.pad_token_id, self.sep_token_id]],
|
746 |
+
).to(caps.device).repeat(vfeats.size(0), 1)
|
747 |
+
dummy_cmasks = torch.BoolTensor(
|
748 |
+
[[0, 1, 0, 1]] # pad are valid for attention.
|
749 |
+
).to(caps.device).repeat(vfeats.size(0), 1)
|
750 |
+
|
751 |
+
# video forwarding, text is dummy; never use attention_mask.
|
752 |
+
attention_mask, token_type_ids = self._mm_on_the_fly(
|
753 |
+
dummy_cmasks, vmasks, None)
|
754 |
+
|
755 |
+
outputs = self.mm_encoder(
|
756 |
+
input_ids=dummy_caps,
|
757 |
+
input_video_embeds=vfeats,
|
758 |
+
attention_mask=attention_mask,
|
759 |
+
token_type_ids=token_type_ids,
|
760 |
+
output_hidden_states=output_hidden_states,
|
761 |
+
)
|
762 |
+
|
763 |
+
layer_idx = self.last_iso_layer \
|
764 |
+
if self.last_iso_layer > 0 else self.num_hidden_layers
|
765 |
+
|
766 |
+
video_seq = outputs[2][layer_idx][:, 1:vmasks.size(1)+1].masked_select(
|
767 |
+
vmasks.unsqueeze(-1)
|
768 |
+
).view(-1, self.hidden_size)
|
769 |
+
|
770 |
+
# text forwarding, video is dummy
|
771 |
+
attention_mask, token_type_ids = self._mm_on_the_fly(
|
772 |
+
cmasks, dummy_vmasks, None)
|
773 |
+
|
774 |
+
outputs = self.mm_encoder(
|
775 |
+
input_ids=caps,
|
776 |
+
input_video_embeds=dummy_vfeats,
|
777 |
+
attention_mask=attention_mask,
|
778 |
+
token_type_ids=token_type_ids,
|
779 |
+
output_hidden_states=output_hidden_states,
|
780 |
+
)
|
781 |
+
|
782 |
+
_, pooled_text = self._pooling_vt_layer(
|
783 |
+
outputs[2], cmasks, dummy_vmasks)
|
784 |
+
# this line is not right.
|
785 |
+
logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
|
786 |
+
return {"logits": logits}
|
787 |
+
|
788 |
+
|
789 |
+
# --------------- MMFusionSeparate for end tasks ---------------
|
790 |
+
|
791 |
+
class MMFusionSeparateActionSegmentation(MMFusionSeparate):
|
792 |
+
"""Fine-tuning wrapper for action segmentation."""
|
793 |
+
def forward(
|
794 |
+
self,
|
795 |
+
caps,
|
796 |
+
cmasks,
|
797 |
+
vfeats,
|
798 |
+
vmasks,
|
799 |
+
attention_mask=None,
|
800 |
+
**kwargs
|
801 |
+
):
|
802 |
+
# ActionLocalization assume of batch_size=1, squeeze it.
|
803 |
+
caps = caps.view(-1, caps.size(-1))
|
804 |
+
cmasks = cmasks.view(-1, cmasks.size(-1))
|
805 |
+
vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
|
806 |
+
vmasks = vmasks.view(-1, vmasks.size(-1))
|
807 |
+
logits = self.forward_video(
|
808 |
+
vfeats,
|
809 |
+
vmasks,
|
810 |
+
caps,
|
811 |
+
cmasks,
|
812 |
+
output_hidden_states=True
|
813 |
+
)
|
814 |
+
return {"logits": logits[:, 1:vmasks.size(1)+1]}
|
815 |
+
|
816 |
+
|
817 |
+
class MMFusionSeparateActionLocalization(MMFusionSeparate):
|
818 |
+
def __init__(self, config, **kwargs):
|
819 |
+
super().__init__(config)
|
820 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
821 |
+
config.dataset.bert_name)
|
822 |
+
self.cls_token_id = tokenizer.cls_token_id
|
823 |
+
self.sep_token_id = tokenizer.sep_token_id
|
824 |
+
self.pad_token_id = tokenizer.pad_token_id
|
825 |
+
|
826 |
+
def forward(
|
827 |
+
self,
|
828 |
+
caps,
|
829 |
+
cmasks,
|
830 |
+
vfeats,
|
831 |
+
vmasks,
|
832 |
+
**kwargs
|
833 |
+
):
|
834 |
+
# ActionLocalization assume of batch_size=1, squeeze it.
|
835 |
+
caps = caps.squeeze(0)
|
836 |
+
cmasks = cmasks.squeeze(0)
|
837 |
+
vfeats = vfeats.squeeze(0)
|
838 |
+
vmasks = vmasks.squeeze(0)
|
839 |
+
|
840 |
+
# TODO (huxu): other ways to do negative examples; move the following
|
841 |
+
# into your criterion forward.
|
842 |
+
dummy_caps = torch.LongTensor(
|
843 |
+
[[self.cls_token_id, self.sep_token_id,
|
844 |
+
self.pad_token_id, self.sep_token_id]],
|
845 |
+
).to(caps.device).repeat(vfeats.size(0), 1)
|
846 |
+
dummy_cmasks = torch.BoolTensor(
|
847 |
+
[[0, 1, 0, 1]] # pad are valid for attention.
|
848 |
+
).to(caps.device).repeat(vfeats.size(0), 1)
|
849 |
+
|
850 |
+
outputs = self.forward_video(
|
851 |
+
vfeats,
|
852 |
+
vmasks,
|
853 |
+
dummy_caps,
|
854 |
+
dummy_cmasks,
|
855 |
+
output_hidden_states=True
|
856 |
+
)
|
857 |
+
|
858 |
+
video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
|
859 |
+
vmasks.unsqueeze(-1)
|
860 |
+
).view(-1, self.hidden_size)
|
861 |
+
|
862 |
+
pooled_text = self.forward_text(
|
863 |
+
caps,
|
864 |
+
cmasks,
|
865 |
+
output_hidden_states=False
|
866 |
+
)
|
867 |
+
|
868 |
+
# this line is not right.
|
869 |
+
logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
|
870 |
+
return {"logits": logits}
|
871 |
+
|
872 |
+
|
873 |
+
class MMFusionShareActionLocalization(MMFusionShare):
|
874 |
+
def __init__(self, config, **kwargs):
|
875 |
+
super().__init__(config)
|
876 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
877 |
+
config.dataset.bert_name)
|
878 |
+
self.cls_token_id = tokenizer.cls_token_id
|
879 |
+
self.sep_token_id = tokenizer.sep_token_id
|
880 |
+
self.pad_token_id = tokenizer.pad_token_id
|
881 |
+
|
882 |
+
def forward(
|
883 |
+
self,
|
884 |
+
caps,
|
885 |
+
cmasks,
|
886 |
+
vfeats,
|
887 |
+
vmasks,
|
888 |
+
**kwargs
|
889 |
+
):
|
890 |
+
# ActionLocalization assume of batch_size=1, squeeze it.
|
891 |
+
caps = caps.squeeze(0)
|
892 |
+
cmasks = cmasks.squeeze(0)
|
893 |
+
vfeats = vfeats.squeeze(0)
|
894 |
+
vmasks = vmasks.squeeze(0)
|
895 |
+
|
896 |
+
# TODO (huxu): other ways to do negative examples; move the following
|
897 |
+
# into your criterion forward.
|
898 |
+
dummy_caps = torch.LongTensor(
|
899 |
+
[[self.cls_token_id, self.sep_token_id,
|
900 |
+
self.pad_token_id, self.sep_token_id]],
|
901 |
+
).to(caps.device).repeat(vfeats.size(0), 1)
|
902 |
+
dummy_cmasks = torch.BoolTensor(
|
903 |
+
[[0, 1, 0, 1]] # pad are valid for attention.
|
904 |
+
).to(caps.device).repeat(vfeats.size(0), 1)
|
905 |
+
|
906 |
+
outputs = self.forward_video(
|
907 |
+
vfeats,
|
908 |
+
vmasks,
|
909 |
+
dummy_caps,
|
910 |
+
dummy_cmasks,
|
911 |
+
output_hidden_states=True
|
912 |
+
)
|
913 |
+
|
914 |
+
video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
|
915 |
+
vmasks.unsqueeze(-1)
|
916 |
+
).view(-1, self.hidden_size)
|
917 |
+
|
918 |
+
pooled_text = self.forward_text(
|
919 |
+
caps,
|
920 |
+
cmasks,
|
921 |
+
output_hidden_states=False
|
922 |
+
)
|
923 |
+
|
924 |
+
# this line is not right.
|
925 |
+
logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
|
926 |
+
return {"logits": logits}
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/mmfusionnlg.py
ADDED
@@ -0,0 +1,999 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Copyright (c) Facebook, Inc. All Rights Reserved
|
17 |
+
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from torch.nn import functional as F
|
22 |
+
|
23 |
+
from typing import Optional, Iterable
|
24 |
+
|
25 |
+
try:
|
26 |
+
from transformers import BertPreTrainedModel
|
27 |
+
from transformers.modeling_bert import BertOnlyMLMHead
|
28 |
+
|
29 |
+
from transformers.file_utils import ModelOutput
|
30 |
+
from transformers.modeling_outputs import CausalLMOutput
|
31 |
+
from transformers.generation_utils import (
|
32 |
+
BeamHypotheses,
|
33 |
+
top_k_top_p_filtering
|
34 |
+
)
|
35 |
+
except ImportError:
|
36 |
+
pass
|
37 |
+
|
38 |
+
from .mmfusion import MMFusion
|
39 |
+
from .transformermodel import MMBertModel
|
40 |
+
from ..modules import VideoTokenMLP
|
41 |
+
|
42 |
+
|
43 |
+
class MMFusionNLG(MMFusion):
|
44 |
+
def __init__(self, config, **kwargs):
|
45 |
+
super().__init__(config)
|
46 |
+
if config.model.max_decode_length is not None:
|
47 |
+
self.max_length = min(
|
48 |
+
config.model.max_decode_length,
|
49 |
+
config.dataset.max_len - config.dataset.max_video_len - 3
|
50 |
+
)
|
51 |
+
else:
|
52 |
+
self.max_length = \
|
53 |
+
config.dataset.max_len - config.dataset.max_video_len - 3
|
54 |
+
self.gen_param = config.gen_param if config.gen_param is not None \
|
55 |
+
else {}
|
56 |
+
|
57 |
+
def forward(
|
58 |
+
self,
|
59 |
+
caps,
|
60 |
+
cmasks,
|
61 |
+
vfeats,
|
62 |
+
vmasks,
|
63 |
+
attention_mask,
|
64 |
+
video_label=None,
|
65 |
+
text_label=None,
|
66 |
+
**kwargs
|
67 |
+
):
|
68 |
+
"""use pre-trained LM header for generation."""
|
69 |
+
attention_mask, token_type_ids = self._mm_on_the_fly(
|
70 |
+
cmasks, vmasks, attention_mask)
|
71 |
+
|
72 |
+
outputs = self.mm_encoder(
|
73 |
+
input_ids=caps,
|
74 |
+
input_video_embeds=vfeats,
|
75 |
+
attention_mask=attention_mask,
|
76 |
+
token_type_ids=token_type_ids,
|
77 |
+
masked_lm_labels=text_label,
|
78 |
+
)
|
79 |
+
return {"logits": outputs[0]}
|
80 |
+
|
81 |
+
@torch.no_grad()
|
82 |
+
def generate(
|
83 |
+
self,
|
84 |
+
caps, cmasks, vfeats, vmasks,
|
85 |
+
attention_mask=None,
|
86 |
+
bos_token_id=None,
|
87 |
+
eos_token_id=None,
|
88 |
+
**kwargs
|
89 |
+
):
|
90 |
+
# a simplified interface from
|
91 |
+
# https://huggingface.co/transformers/v3.4.0/_modules/transformers/generation_utils.html#GenerationMixin.generate
|
92 |
+
|
93 |
+
# caps now only have
|
94 |
+
# [CLS], [SEP] (for video) and [CLS] (as bos_token)
|
95 |
+
assert caps.size(1) == 3
|
96 |
+
|
97 |
+
attention_mask, token_type_ids = self._mm_on_the_fly(
|
98 |
+
cmasks, vmasks, attention_mask)
|
99 |
+
|
100 |
+
output = self.mm_encoder.generate(
|
101 |
+
input_ids=caps,
|
102 |
+
input_video_embeds=vfeats,
|
103 |
+
attention_mask=attention_mask,
|
104 |
+
token_type_ids=token_type_ids,
|
105 |
+
bos_token_id=bos_token_id,
|
106 |
+
eos_token_id=eos_token_id,
|
107 |
+
max_length=self.max_length,
|
108 |
+
**self.gen_param
|
109 |
+
)
|
110 |
+
return output
|
111 |
+
|
112 |
+
|
113 |
+
class MMBertForNLG(BertPreTrainedModel):
|
114 |
+
def __init__(self, config):
|
115 |
+
super().__init__(config)
|
116 |
+
self.bert = MMBertModel(config)
|
117 |
+
self.videomlp = VideoTokenMLP(config)
|
118 |
+
# we do not use `BertGenerationOnlyLMHead`
|
119 |
+
# because we can reuse pretraining.
|
120 |
+
self.cls = BertOnlyMLMHead(config)
|
121 |
+
self.hidden_size = config.hidden_size
|
122 |
+
self.init_weights()
|
123 |
+
|
124 |
+
def get_output_embeddings(self):
|
125 |
+
return self.cls.predictions.decoder
|
126 |
+
|
127 |
+
def forward(
|
128 |
+
self,
|
129 |
+
input_ids=None,
|
130 |
+
input_video_embeds=None,
|
131 |
+
attention_mask=None,
|
132 |
+
token_type_ids=None,
|
133 |
+
position_ids=None,
|
134 |
+
head_mask=None,
|
135 |
+
inputs_embeds=None,
|
136 |
+
masked_lm_labels=None,
|
137 |
+
output_attentions=None,
|
138 |
+
output_hidden_states=None,
|
139 |
+
return_dict=None,
|
140 |
+
):
|
141 |
+
# similar to MMBertForMFMMLM without MFM.
|
142 |
+
video_tokens = self.videomlp(input_video_embeds)
|
143 |
+
outputs = self.bert(
|
144 |
+
input_ids,
|
145 |
+
video_tokens,
|
146 |
+
attention_mask=attention_mask,
|
147 |
+
token_type_ids=token_type_ids,
|
148 |
+
position_ids=position_ids,
|
149 |
+
head_mask=head_mask,
|
150 |
+
inputs_embeds=inputs_embeds,
|
151 |
+
output_attentions=output_attentions,
|
152 |
+
output_hidden_states=output_hidden_states,
|
153 |
+
return_dict=return_dict,
|
154 |
+
)
|
155 |
+
|
156 |
+
sequence_output = outputs[0]
|
157 |
+
|
158 |
+
prediction_scores = None
|
159 |
+
if masked_lm_labels is not None:
|
160 |
+
text_offset = input_video_embeds.size(1) + 1 # [CLS]
|
161 |
+
# recover caps format: [CLS] [SEP] text [SEP]
|
162 |
+
text_sequence_output = torch.cat(
|
163 |
+
[sequence_output[:, :1], sequence_output[:, text_offset:]],
|
164 |
+
dim=1
|
165 |
+
)
|
166 |
+
|
167 |
+
# only compute select tokens to training to speed up.
|
168 |
+
hidden_size = text_sequence_output.size(-1)
|
169 |
+
# masked_lm_labels = masked_lm_labels.reshape(-1)
|
170 |
+
labels_mask = masked_lm_labels != -100
|
171 |
+
|
172 |
+
selected_text_output = text_sequence_output.masked_select(
|
173 |
+
labels_mask.unsqueeze(-1)
|
174 |
+
).view(-1, hidden_size)
|
175 |
+
prediction_scores = self.cls(selected_text_output)
|
176 |
+
|
177 |
+
if not return_dict:
|
178 |
+
output = (
|
179 |
+
prediction_scores,
|
180 |
+
) + outputs[2:]
|
181 |
+
return output
|
182 |
+
|
183 |
+
# for generation.
|
184 |
+
text_offset = input_video_embeds.size(1) + 2 # [CLS]
|
185 |
+
text_sequence_output = sequence_output[:, text_offset:]
|
186 |
+
prediction_scores = self.cls(text_sequence_output)
|
187 |
+
return CausalLMOutput(
|
188 |
+
loss=None,
|
189 |
+
logits=prediction_scores,
|
190 |
+
)
|
191 |
+
|
192 |
+
def prepare_inputs_for_generation(
|
193 |
+
self,
|
194 |
+
input_ids,
|
195 |
+
input_video_embeds,
|
196 |
+
attention_mask=None,
|
197 |
+
token_type_ids=None,
|
198 |
+
**model_kwargs
|
199 |
+
):
|
200 |
+
# must return a dictionary.
|
201 |
+
seq_len = input_ids.size(1) + input_video_embeds.size(1)
|
202 |
+
if attention_mask is not None:
|
203 |
+
if len(attention_mask.size()) == 4:
|
204 |
+
attention_mask = attention_mask[:, :, :seq_len, :seq_len]
|
205 |
+
elif len(attention_mask.size()) == 3:
|
206 |
+
attention_mask = attention_mask[:, :seq_len, :seq_len]
|
207 |
+
else:
|
208 |
+
attention_mask = attention_mask[:, :seq_len]
|
209 |
+
if token_type_ids is not None:
|
210 |
+
token_type_ids = token_type_ids[:, :seq_len]
|
211 |
+
|
212 |
+
return {
|
213 |
+
"input_ids": input_ids,
|
214 |
+
"input_video_embeds": input_video_embeds,
|
215 |
+
"attention_mask": attention_mask,
|
216 |
+
"token_type_ids": token_type_ids,
|
217 |
+
}
|
218 |
+
|
219 |
+
@torch.no_grad()
|
220 |
+
def generate(
|
221 |
+
self,
|
222 |
+
input_ids: Optional[torch.LongTensor] = None,
|
223 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
224 |
+
max_length: Optional[int] = None,
|
225 |
+
min_length: Optional[int] = None,
|
226 |
+
do_sample: Optional[bool] = None,
|
227 |
+
early_stopping: Optional[bool] = None,
|
228 |
+
num_beams: Optional[int] = None,
|
229 |
+
temperature: Optional[float] = None,
|
230 |
+
top_k: Optional[int] = None,
|
231 |
+
top_p: Optional[float] = None,
|
232 |
+
repetition_penalty: Optional[float] = None,
|
233 |
+
bad_words_ids: Optional[Iterable[int]] = None,
|
234 |
+
bos_token_id: Optional[int] = None,
|
235 |
+
pad_token_id: Optional[int] = None,
|
236 |
+
eos_token_id: Optional[int] = None,
|
237 |
+
length_penalty: Optional[float] = None,
|
238 |
+
no_repeat_ngram_size: Optional[int] = None,
|
239 |
+
num_return_sequences: Optional[int] = None,
|
240 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
241 |
+
decoder_start_token_id: Optional[int] = None,
|
242 |
+
use_cache: Optional[bool] = None,
|
243 |
+
**model_kwargs
|
244 |
+
) -> torch.LongTensor:
|
245 |
+
r"""
|
246 |
+
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
247 |
+
beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
|
248 |
+
Adapted in part from `Facebook's XLM beam search code
|
249 |
+
<https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
|
250 |
+
Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
|
251 |
+
attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
|
252 |
+
indicated are the default values of those config.
|
253 |
+
Most of these parameters are explained in more detail in `this blog post
|
254 |
+
<https://huggingface.co/blog/how-to-generate>`__.
|
255 |
+
Parameters:
|
256 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
257 |
+
The sequence used as a prompt for the generation. If :obj:`None` the method initializes
|
258 |
+
it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
|
259 |
+
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
260 |
+
initial input_ids for the decoder of encoder-decoder type models. If :obj:`None` then only
|
261 |
+
decoder_start_token_id is passed as the first token to the decoder.
|
262 |
+
max_length (:obj:`int`, `optional`, defaults to 20):
|
263 |
+
The maximum length of the sequence to be generated.
|
264 |
+
min_length (:obj:`int`, `optional`, defaults to 10):
|
265 |
+
The minimum length of the sequence to be generated.
|
266 |
+
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
267 |
+
Whether or not to use sampling ; use greedy decoding otherwise.
|
268 |
+
early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
269 |
+
Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
|
270 |
+
num_beams (:obj:`int`, `optional`, defaults to 1):
|
271 |
+
Number of beams for beam search. 1 means no beam search.
|
272 |
+
temperature (:obj:`float`, `optional`, defaults tp 1.0):
|
273 |
+
The value used to module the next token probabilities.
|
274 |
+
top_k (:obj:`int`, `optional`, defaults to 50):
|
275 |
+
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
276 |
+
top_p (:obj:`float`, `optional`, defaults to 1.0):
|
277 |
+
If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or
|
278 |
+
higher are kept for generation.
|
279 |
+
repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
|
280 |
+
The parameter for repetition penalty. 1.0 means no penalty. See `this paper
|
281 |
+
<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
|
282 |
+
pad_token_id (:obj:`int`, `optional`):
|
283 |
+
The id of the `padding` token.
|
284 |
+
bos_token_id (:obj:`int`, `optional`):
|
285 |
+
The id of the `beginning-of-sequence` token.
|
286 |
+
eos_token_id (:obj:`int`, `optional`):
|
287 |
+
The id of the `end-of-sequence` token.
|
288 |
+
length_penalty (:obj:`float`, `optional`, defaults to 1.0):
|
289 |
+
Exponential penalty to the length. 1.0 means no penalty.
|
290 |
+
Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
|
291 |
+
order to encourage the model to produce longer sequences.
|
292 |
+
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
|
293 |
+
If set to int > 0, all ngrams of that size can only occur once.
|
294 |
+
bad_words_ids(:obj:`List[int]`, `optional`):
|
295 |
+
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
|
296 |
+
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
|
297 |
+
num_return_sequences(:obj:`int`, `optional`, defaults to 1):
|
298 |
+
The number of independently computed returned sequences for each element in the batch.
|
299 |
+
attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
300 |
+
Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
|
301 |
+
tokens that are not masked, and 0 for masked tokens.
|
302 |
+
If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.
|
303 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
304 |
+
decoder_start_token_id (:obj:`int`, `optional`):
|
305 |
+
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
|
306 |
+
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
307 |
+
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
308 |
+
speed up decoding.
|
309 |
+
model_kwargs:
|
310 |
+
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
311 |
+
Return:
|
312 |
+
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
|
313 |
+
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
|
314 |
+
shorter if all batches finished early due to the :obj:`eos_token_id`.
|
315 |
+
Examples::
|
316 |
+
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
317 |
+
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
318 |
+
outputs = model.generate(max_length=40) # do greedy decoding
|
319 |
+
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
320 |
+
tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
|
321 |
+
model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
|
322 |
+
input_context = 'The dog'
|
323 |
+
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
324 |
+
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
|
325 |
+
for i in range(3): # 3 output sequences were generated
|
326 |
+
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
327 |
+
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
328 |
+
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
329 |
+
input_context = 'The dog'
|
330 |
+
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
331 |
+
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling
|
332 |
+
for i in range(3): # 3 output sequences were generated
|
333 |
+
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
334 |
+
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
|
335 |
+
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
|
336 |
+
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
|
337 |
+
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
338 |
+
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
|
339 |
+
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
340 |
+
tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
|
341 |
+
model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
|
342 |
+
input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl
|
343 |
+
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
|
344 |
+
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
345 |
+
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
|
346 |
+
"""
|
347 |
+
|
348 |
+
# We cannot generate if the model does not have a LM head
|
349 |
+
if self.get_output_embeddings() is None:
|
350 |
+
raise AttributeError(
|
351 |
+
"You tried to generate sequences with a model that does not have a LM Head."
|
352 |
+
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
|
353 |
+
)
|
354 |
+
|
355 |
+
max_length = max_length if max_length is not None else self.config.max_length
|
356 |
+
min_length = min_length if min_length is not None else self.config.min_length
|
357 |
+
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
358 |
+
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
359 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
360 |
+
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
361 |
+
temperature = temperature if temperature is not None else self.config.temperature
|
362 |
+
top_k = top_k if top_k is not None else self.config.top_k
|
363 |
+
top_p = top_p if top_p is not None else self.config.top_p
|
364 |
+
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
365 |
+
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
366 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
367 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
368 |
+
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
369 |
+
no_repeat_ngram_size = (
|
370 |
+
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
371 |
+
)
|
372 |
+
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
373 |
+
num_return_sequences = (
|
374 |
+
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
375 |
+
)
|
376 |
+
decoder_start_token_id = (
|
377 |
+
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
378 |
+
)
|
379 |
+
|
380 |
+
if input_ids is not None:
|
381 |
+
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
382 |
+
else:
|
383 |
+
batch_size = 1
|
384 |
+
|
385 |
+
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
386 |
+
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
387 |
+
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
388 |
+
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
389 |
+
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
|
390 |
+
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
391 |
+
assert temperature > 0, "`temperature` should be strictly positive."
|
392 |
+
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
393 |
+
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
394 |
+
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
395 |
+
assert input_ids is not None or (
|
396 |
+
isinstance(bos_token_id, int) and bos_token_id >= 0
|
397 |
+
), "If input_ids is not defined, `bos_token_id` should be a positive integer."
|
398 |
+
assert pad_token_id is None or (
|
399 |
+
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
400 |
+
), "`pad_token_id` should be a positive integer."
|
401 |
+
assert (eos_token_id is None) or (
|
402 |
+
isinstance(eos_token_id, int) and (eos_token_id >= 0)
|
403 |
+
), "`eos_token_id` should be a positive integer."
|
404 |
+
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
405 |
+
assert (
|
406 |
+
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
|
407 |
+
), "`no_repeat_ngram_size` should be a positive integer."
|
408 |
+
assert (
|
409 |
+
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
410 |
+
), "`num_return_sequences` should be a strictly positive integer."
|
411 |
+
assert (
|
412 |
+
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
|
413 |
+
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
|
414 |
+
|
415 |
+
if input_ids is None:
|
416 |
+
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
|
417 |
+
"you should either supply a context to complete as `input_ids` input "
|
418 |
+
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
|
419 |
+
)
|
420 |
+
input_ids = torch.full(
|
421 |
+
(batch_size, 1),
|
422 |
+
bos_token_id,
|
423 |
+
dtype=torch.long,
|
424 |
+
device=next(self.parameters()).device,
|
425 |
+
)
|
426 |
+
else:
|
427 |
+
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
428 |
+
|
429 |
+
# not allow to duplicate outputs when greedy decoding
|
430 |
+
if do_sample is False:
|
431 |
+
if num_beams == 1:
|
432 |
+
# no_beam_search greedy generation conditions
|
433 |
+
assert (
|
434 |
+
num_return_sequences == 1
|
435 |
+
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
|
436 |
+
|
437 |
+
else:
|
438 |
+
# beam_search greedy generation conditions
|
439 |
+
assert (
|
440 |
+
num_beams >= num_return_sequences
|
441 |
+
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
|
442 |
+
|
443 |
+
# create attention mask if necessary
|
444 |
+
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
|
445 |
+
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
|
446 |
+
attention_mask = input_ids.ne(pad_token_id).long()
|
447 |
+
elif attention_mask is None:
|
448 |
+
attention_mask = input_ids.new_ones(input_ids.shape)
|
449 |
+
|
450 |
+
# set pad_token_id to eos_token_id if not set. Important that this is done after
|
451 |
+
# attention_mask is created
|
452 |
+
if pad_token_id is None and eos_token_id is not None:
|
453 |
+
print(
|
454 |
+
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
|
455 |
+
)
|
456 |
+
pad_token_id = eos_token_id
|
457 |
+
|
458 |
+
# vocab size
|
459 |
+
if hasattr(self.config, "vocab_size"):
|
460 |
+
vocab_size = self.config.vocab_size
|
461 |
+
elif (
|
462 |
+
self.config.is_encoder_decoder
|
463 |
+
and hasattr(self.config, "decoder")
|
464 |
+
and hasattr(self.config.decoder, "vocab_size")
|
465 |
+
):
|
466 |
+
vocab_size = self.config.decoder.vocab_size
|
467 |
+
else:
|
468 |
+
raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined")
|
469 |
+
|
470 |
+
# set effective batch size and effective batch multiplier according to do_sample
|
471 |
+
if do_sample:
|
472 |
+
effective_batch_size = batch_size * num_return_sequences
|
473 |
+
effective_batch_mult = num_return_sequences
|
474 |
+
else:
|
475 |
+
effective_batch_size = batch_size
|
476 |
+
effective_batch_mult = 1
|
477 |
+
|
478 |
+
if self.config.is_encoder_decoder:
|
479 |
+
if decoder_start_token_id is None:
|
480 |
+
# see if BOS token can be used for decoder_start_token_id
|
481 |
+
if bos_token_id is not None:
|
482 |
+
decoder_start_token_id = bos_token_id
|
483 |
+
elif (
|
484 |
+
hasattr(self.config, "decoder")
|
485 |
+
and hasattr(self.config.decoder, "bos_token_id")
|
486 |
+
and self.config.decoder.bos_token_id is not None
|
487 |
+
):
|
488 |
+
decoder_start_token_id = self.config.decoder.bos_token_id
|
489 |
+
else:
|
490 |
+
raise ValueError(
|
491 |
+
"decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
492 |
+
)
|
493 |
+
|
494 |
+
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
495 |
+
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
496 |
+
|
497 |
+
# get encoder and store encoder outputs
|
498 |
+
encoder = self.get_encoder()
|
499 |
+
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
|
500 |
+
|
501 |
+
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
502 |
+
if num_return_sequences > 1 or num_beams > 1:
|
503 |
+
# TODO: make this a call-back function.
|
504 |
+
# input_ids=caps,
|
505 |
+
# input_video_embeds=vfeats,
|
506 |
+
# attention_mask=attention_mask,
|
507 |
+
# token_type_ids=token_type_ids,
|
508 |
+
input_video_embeds = model_kwargs.pop("input_video_embeds", None)
|
509 |
+
token_type_ids = model_kwargs.pop("token_type_ids", None)
|
510 |
+
|
511 |
+
input_ids_len = input_ids.shape[-1]
|
512 |
+
input_ids = input_ids.unsqueeze(1).expand(
|
513 |
+
batch_size, effective_batch_mult * num_beams, input_ids_len)
|
514 |
+
|
515 |
+
input_video_embeds_len, input_video_embeds_hidden = input_video_embeds.size(1), input_video_embeds.size(2)
|
516 |
+
input_video_embeds = input_video_embeds.unsqueeze(1).expand(
|
517 |
+
batch_size, effective_batch_mult * num_beams, input_video_embeds_len, input_video_embeds_hidden)
|
518 |
+
|
519 |
+
attention_mask_from_len, attention_mask_to_len = attention_mask.size(1), attention_mask.size(2)
|
520 |
+
attention_mask = attention_mask.unsqueeze(1).expand(
|
521 |
+
batch_size, effective_batch_mult * num_beams, attention_mask_from_len, attention_mask_to_len
|
522 |
+
)
|
523 |
+
|
524 |
+
token_type_ids_len = token_type_ids.size(1)
|
525 |
+
token_type_ids = token_type_ids.unsqueeze(1).expand(
|
526 |
+
batch_size, effective_batch_mult * num_beams, token_type_ids_len
|
527 |
+
)
|
528 |
+
|
529 |
+
# contiguous ...
|
530 |
+
input_ids = input_ids.contiguous().view(
|
531 |
+
effective_batch_size * num_beams, input_ids_len
|
532 |
+
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
533 |
+
|
534 |
+
input_video_embeds = input_video_embeds.contiguous().view(
|
535 |
+
effective_batch_size * num_beams, input_video_embeds_len, input_video_embeds_hidden)
|
536 |
+
|
537 |
+
attention_mask = attention_mask.contiguous().view(
|
538 |
+
effective_batch_size * num_beams, attention_mask_from_len, attention_mask_to_len
|
539 |
+
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
540 |
+
|
541 |
+
token_type_ids = token_type_ids.contiguous().view(
|
542 |
+
effective_batch_size * num_beams, token_type_ids_len
|
543 |
+
)
|
544 |
+
|
545 |
+
model_kwargs["input_video_embeds"] = input_video_embeds
|
546 |
+
model_kwargs["token_type_ids"] = token_type_ids
|
547 |
+
|
548 |
+
if self.config.is_encoder_decoder:
|
549 |
+
device = next(self.parameters()).device
|
550 |
+
if decoder_input_ids is not None:
|
551 |
+
# give initial decoder input ids
|
552 |
+
input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device)
|
553 |
+
else:
|
554 |
+
# create empty decoder input_ids
|
555 |
+
input_ids = torch.full(
|
556 |
+
(effective_batch_size * num_beams, 1),
|
557 |
+
decoder_start_token_id,
|
558 |
+
dtype=torch.long,
|
559 |
+
device=device,
|
560 |
+
)
|
561 |
+
cur_len = input_ids.shape[-1]
|
562 |
+
|
563 |
+
assert (
|
564 |
+
batch_size == encoder_outputs.last_hidden_state.shape[0]
|
565 |
+
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
|
566 |
+
|
567 |
+
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
568 |
+
expanded_batch_idxs = (
|
569 |
+
torch.arange(batch_size)
|
570 |
+
.view(-1, 1)
|
571 |
+
.repeat(1, num_beams * effective_batch_mult)
|
572 |
+
.view(-1)
|
573 |
+
.to(input_ids.device)
|
574 |
+
)
|
575 |
+
|
576 |
+
# expand encoder_outputs
|
577 |
+
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
|
578 |
+
0, expanded_batch_idxs
|
579 |
+
)
|
580 |
+
|
581 |
+
# save encoder_outputs in `model_kwargs`
|
582 |
+
model_kwargs["encoder_outputs"] = encoder_outputs
|
583 |
+
|
584 |
+
else:
|
585 |
+
cur_len = input_ids.shape[-1]
|
586 |
+
|
587 |
+
assert (
|
588 |
+
cur_len < max_length
|
589 |
+
), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
|
590 |
+
|
591 |
+
if num_beams > 1:
|
592 |
+
output = self._generate_beam_search(
|
593 |
+
input_ids,
|
594 |
+
cur_len=cur_len,
|
595 |
+
max_length=max_length,
|
596 |
+
min_length=min_length,
|
597 |
+
do_sample=do_sample,
|
598 |
+
early_stopping=early_stopping,
|
599 |
+
temperature=temperature,
|
600 |
+
top_k=top_k,
|
601 |
+
top_p=top_p,
|
602 |
+
repetition_penalty=repetition_penalty,
|
603 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
604 |
+
bad_words_ids=bad_words_ids,
|
605 |
+
pad_token_id=pad_token_id,
|
606 |
+
eos_token_id=eos_token_id,
|
607 |
+
batch_size=effective_batch_size,
|
608 |
+
num_return_sequences=num_return_sequences,
|
609 |
+
length_penalty=length_penalty,
|
610 |
+
num_beams=num_beams,
|
611 |
+
vocab_size=vocab_size,
|
612 |
+
attention_mask=attention_mask,
|
613 |
+
use_cache=use_cache,
|
614 |
+
model_kwargs=model_kwargs,
|
615 |
+
)
|
616 |
+
else:
|
617 |
+
output = self._generate_no_beam_search(
|
618 |
+
input_ids,
|
619 |
+
cur_len=cur_len,
|
620 |
+
max_length=max_length,
|
621 |
+
min_length=min_length,
|
622 |
+
do_sample=do_sample,
|
623 |
+
temperature=temperature,
|
624 |
+
top_k=top_k,
|
625 |
+
top_p=top_p,
|
626 |
+
repetition_penalty=repetition_penalty,
|
627 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
628 |
+
bad_words_ids=bad_words_ids,
|
629 |
+
pad_token_id=pad_token_id,
|
630 |
+
eos_token_id=eos_token_id,
|
631 |
+
batch_size=effective_batch_size,
|
632 |
+
attention_mask=attention_mask,
|
633 |
+
use_cache=use_cache,
|
634 |
+
model_kwargs=model_kwargs,
|
635 |
+
)
|
636 |
+
|
637 |
+
return output
|
638 |
+
|
639 |
+
def _generate_beam_search(
|
640 |
+
self,
|
641 |
+
input_ids,
|
642 |
+
cur_len,
|
643 |
+
max_length,
|
644 |
+
min_length,
|
645 |
+
do_sample,
|
646 |
+
early_stopping,
|
647 |
+
temperature,
|
648 |
+
top_k,
|
649 |
+
top_p,
|
650 |
+
repetition_penalty,
|
651 |
+
no_repeat_ngram_size,
|
652 |
+
bad_words_ids,
|
653 |
+
pad_token_id,
|
654 |
+
eos_token_id,
|
655 |
+
batch_size,
|
656 |
+
num_return_sequences,
|
657 |
+
length_penalty,
|
658 |
+
num_beams,
|
659 |
+
vocab_size,
|
660 |
+
attention_mask,
|
661 |
+
use_cache,
|
662 |
+
model_kwargs,
|
663 |
+
):
|
664 |
+
"""Generate sequences for each example with beam search."""
|
665 |
+
|
666 |
+
# generated hypotheses
|
667 |
+
generated_hyps = [
|
668 |
+
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
|
669 |
+
for _ in range(batch_size)
|
670 |
+
]
|
671 |
+
|
672 |
+
# scores for each sentence in the beam
|
673 |
+
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
674 |
+
|
675 |
+
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
|
676 |
+
if do_sample is False:
|
677 |
+
beam_scores[:, 1:] = -1e9
|
678 |
+
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
679 |
+
|
680 |
+
# cache compute states
|
681 |
+
past = None
|
682 |
+
|
683 |
+
# done sentences
|
684 |
+
done = [False for _ in range(batch_size)]
|
685 |
+
|
686 |
+
while cur_len < max_length:
|
687 |
+
model_inputs = self.prepare_inputs_for_generation(
|
688 |
+
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
|
689 |
+
)
|
690 |
+
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
|
691 |
+
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
|
692 |
+
|
693 |
+
# if model has past, then set the past variable to speed up decoding
|
694 |
+
if "past_key_values" in outputs:
|
695 |
+
past = outputs.past_key_values
|
696 |
+
elif "mems" in outputs:
|
697 |
+
past = outputs.mems
|
698 |
+
|
699 |
+
if self.config.is_encoder_decoder and do_sample is False:
|
700 |
+
# TODO (PVP) still a bit hacky here - there might be a better solution
|
701 |
+
next_token_logits = self.adjust_logits_during_generation(
|
702 |
+
next_token_logits, cur_len=cur_len, max_length=max_length
|
703 |
+
)
|
704 |
+
|
705 |
+
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
706 |
+
|
707 |
+
scores = self.postprocess_next_token_scores(
|
708 |
+
scores=scores,
|
709 |
+
input_ids=input_ids,
|
710 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
711 |
+
bad_words_ids=bad_words_ids,
|
712 |
+
cur_len=cur_len,
|
713 |
+
min_length=min_length,
|
714 |
+
max_length=max_length,
|
715 |
+
eos_token_id=eos_token_id,
|
716 |
+
repetition_penalty=repetition_penalty,
|
717 |
+
batch_size=batch_size,
|
718 |
+
num_beams=num_beams,
|
719 |
+
)
|
720 |
+
|
721 |
+
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
|
722 |
+
scores.shape, (batch_size * num_beams, vocab_size)
|
723 |
+
)
|
724 |
+
|
725 |
+
if do_sample:
|
726 |
+
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
727 |
+
# Temperature
|
728 |
+
if temperature != 1.0:
|
729 |
+
_scores = _scores / temperature
|
730 |
+
# Top-p/top-k filtering
|
731 |
+
_scores = top_k_top_p_filtering(
|
732 |
+
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
733 |
+
) # (batch_size * num_beams, vocab_size)
|
734 |
+
# re-organize to group the beam together to sample from all beam_idxs
|
735 |
+
_scores = _scores.contiguous().view(
|
736 |
+
batch_size, num_beams * vocab_size
|
737 |
+
) # (batch_size, num_beams * vocab_size)
|
738 |
+
|
739 |
+
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
|
740 |
+
probs = F.softmax(_scores, dim=-1)
|
741 |
+
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
|
742 |
+
# Compute next scores
|
743 |
+
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
|
744 |
+
# sort the sampled vector to make sure that the first num_beams samples are the best
|
745 |
+
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
|
746 |
+
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
|
747 |
+
|
748 |
+
else:
|
749 |
+
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
750 |
+
|
751 |
+
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
752 |
+
next_scores = next_scores.view(
|
753 |
+
batch_size, num_beams * vocab_size
|
754 |
+
) # (batch_size, num_beams * vocab_size)
|
755 |
+
|
756 |
+
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
|
757 |
+
|
758 |
+
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
|
759 |
+
|
760 |
+
# next batch beam content
|
761 |
+
next_batch_beam = []
|
762 |
+
|
763 |
+
# for each sentence
|
764 |
+
for batch_idx in range(batch_size):
|
765 |
+
|
766 |
+
# if we are done with this sentence, add a pad token
|
767 |
+
if done[batch_idx]:
|
768 |
+
assert (
|
769 |
+
len(generated_hyps[batch_idx]) >= num_beams
|
770 |
+
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
|
771 |
+
assert (
|
772 |
+
eos_token_id is not None and pad_token_id is not None
|
773 |
+
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
774 |
+
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
775 |
+
continue
|
776 |
+
|
777 |
+
# next sentence beam content, this will get added to next_batch_beam
|
778 |
+
next_sent_beam = []
|
779 |
+
|
780 |
+
# next tokens for this sentence
|
781 |
+
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
|
782 |
+
zip(next_tokens[batch_idx], next_scores[batch_idx])
|
783 |
+
):
|
784 |
+
# get beam and token IDs
|
785 |
+
beam_id = beam_token_id // vocab_size
|
786 |
+
token_id = beam_token_id % vocab_size
|
787 |
+
|
788 |
+
effective_beam_id = batch_idx * num_beams + beam_id
|
789 |
+
# add to generated hypotheses if end of sentence
|
790 |
+
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
|
791 |
+
# if beam_token does not belong to top num_beams tokens, it should not be added
|
792 |
+
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
793 |
+
if is_beam_token_worse_than_top_num_beams:
|
794 |
+
continue
|
795 |
+
generated_hyps[batch_idx].add(
|
796 |
+
input_ids[effective_beam_id].clone(),
|
797 |
+
beam_token_score.item(),
|
798 |
+
)
|
799 |
+
else:
|
800 |
+
# add next predicted token since it is not eos_token
|
801 |
+
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
802 |
+
|
803 |
+
# once the beam for next step is full, don't add more tokens to it.
|
804 |
+
if len(next_sent_beam) == num_beams:
|
805 |
+
break
|
806 |
+
|
807 |
+
# Check if we are done so that we can save a pad step if all(done)
|
808 |
+
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
809 |
+
next_scores[batch_idx].max().item(), cur_len
|
810 |
+
)
|
811 |
+
|
812 |
+
# update next beam content
|
813 |
+
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
814 |
+
next_batch_beam.extend(next_sent_beam)
|
815 |
+
assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
|
816 |
+
|
817 |
+
# stop when we are done with each sentence
|
818 |
+
if all(done):
|
819 |
+
break
|
820 |
+
|
821 |
+
# sanity check / prepare next batch
|
822 |
+
assert len(next_batch_beam) == batch_size * num_beams
|
823 |
+
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
824 |
+
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
|
825 |
+
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
826 |
+
|
827 |
+
# re-order batch and update current length
|
828 |
+
input_ids = input_ids[beam_idx, :]
|
829 |
+
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
|
830 |
+
cur_len = cur_len + 1
|
831 |
+
|
832 |
+
# re-order internal states
|
833 |
+
if past is not None:
|
834 |
+
past = self._reorder_cache(past, beam_idx)
|
835 |
+
|
836 |
+
# extend attention_mask for new generated input if only decoder
|
837 |
+
# (huxu): move out since we trim attention_mask by ourselves.
|
838 |
+
# if self.config.is_encoder_decoder is False:
|
839 |
+
# attention_mask = torch.cat(
|
840 |
+
# [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
841 |
+
# )
|
842 |
+
|
843 |
+
# finalize all open beam hypotheses and add to generated hypotheses
|
844 |
+
for batch_idx in range(batch_size):
|
845 |
+
if done[batch_idx]:
|
846 |
+
continue
|
847 |
+
|
848 |
+
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
849 |
+
if eos_token_id is not None and all(
|
850 |
+
(token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
|
851 |
+
):
|
852 |
+
assert torch.all(
|
853 |
+
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
|
854 |
+
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
|
855 |
+
next_scores[:, :num_beams][batch_idx],
|
856 |
+
beam_scores.view(batch_size, num_beams)[batch_idx],
|
857 |
+
)
|
858 |
+
|
859 |
+
# need to add best num_beams hypotheses to generated hyps
|
860 |
+
for beam_id in range(num_beams):
|
861 |
+
effective_beam_id = batch_idx * num_beams + beam_id
|
862 |
+
final_score = beam_scores[effective_beam_id].item()
|
863 |
+
final_tokens = input_ids[effective_beam_id]
|
864 |
+
generated_hyps[batch_idx].add(final_tokens, final_score)
|
865 |
+
|
866 |
+
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
|
867 |
+
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
|
868 |
+
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
|
869 |
+
|
870 |
+
# select the best hypotheses
|
871 |
+
sent_lengths = input_ids.new(output_batch_size)
|
872 |
+
best = []
|
873 |
+
|
874 |
+
# retrieve best hypotheses
|
875 |
+
for i, hypotheses in enumerate(generated_hyps):
|
876 |
+
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
|
877 |
+
for j in range(output_num_return_sequences_per_batch):
|
878 |
+
effective_batch_idx = output_num_return_sequences_per_batch * i + j
|
879 |
+
best_hyp = sorted_hyps.pop()[1]
|
880 |
+
sent_lengths[effective_batch_idx] = len(best_hyp)
|
881 |
+
best.append(best_hyp)
|
882 |
+
|
883 |
+
# prepare for adding eos
|
884 |
+
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
885 |
+
decoded = input_ids.new(output_batch_size, sent_max_len)
|
886 |
+
# shorter batches are padded if needed
|
887 |
+
if sent_lengths.min().item() != sent_lengths.max().item():
|
888 |
+
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
889 |
+
decoded.fill_(pad_token_id)
|
890 |
+
|
891 |
+
# fill with hypotheses and eos_token_id if the latter fits in
|
892 |
+
for i, hypo in enumerate(best):
|
893 |
+
decoded[i, : sent_lengths[i]] = hypo
|
894 |
+
if sent_lengths[i] < max_length:
|
895 |
+
decoded[i, sent_lengths[i]] = eos_token_id
|
896 |
+
|
897 |
+
return decoded
|
898 |
+
|
899 |
+
def _generate_no_beam_search(
|
900 |
+
self,
|
901 |
+
input_ids,
|
902 |
+
cur_len,
|
903 |
+
max_length,
|
904 |
+
min_length,
|
905 |
+
do_sample,
|
906 |
+
temperature,
|
907 |
+
top_k,
|
908 |
+
top_p,
|
909 |
+
repetition_penalty,
|
910 |
+
no_repeat_ngram_size,
|
911 |
+
bad_words_ids,
|
912 |
+
pad_token_id,
|
913 |
+
eos_token_id,
|
914 |
+
batch_size,
|
915 |
+
attention_mask,
|
916 |
+
use_cache,
|
917 |
+
model_kwargs,
|
918 |
+
):
|
919 |
+
"""Generate sequences for each example without beam search (num_beams == 1).
|
920 |
+
All returned sequence are generated independantly.
|
921 |
+
"""
|
922 |
+
# length of generated sentences / unfinished sentences
|
923 |
+
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
924 |
+
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
925 |
+
|
926 |
+
past = None
|
927 |
+
while cur_len < max_length:
|
928 |
+
model_inputs = self.prepare_inputs_for_generation(
|
929 |
+
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
|
930 |
+
)
|
931 |
+
|
932 |
+
outputs = self(**model_inputs, return_dict=True)
|
933 |
+
next_token_logits = outputs.logits[:, -1, :]
|
934 |
+
scores = self.postprocess_next_token_scores(
|
935 |
+
scores=next_token_logits,
|
936 |
+
input_ids=input_ids,
|
937 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
938 |
+
bad_words_ids=bad_words_ids,
|
939 |
+
cur_len=cur_len,
|
940 |
+
min_length=min_length,
|
941 |
+
max_length=max_length,
|
942 |
+
eos_token_id=eos_token_id,
|
943 |
+
repetition_penalty=repetition_penalty,
|
944 |
+
batch_size=batch_size,
|
945 |
+
num_beams=1,
|
946 |
+
)
|
947 |
+
|
948 |
+
# if model has past, then set the past variable to speed up decoding
|
949 |
+
if "past_key_values" in outputs:
|
950 |
+
past = outputs.past_key_values
|
951 |
+
elif "mems" in outputs:
|
952 |
+
past = outputs.mems
|
953 |
+
|
954 |
+
if do_sample:
|
955 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
956 |
+
if temperature != 1.0:
|
957 |
+
scores = scores / temperature
|
958 |
+
# Top-p/top-k filtering
|
959 |
+
next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
|
960 |
+
# Sample
|
961 |
+
probs = F.softmax(next_token_logscores, dim=-1)
|
962 |
+
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
963 |
+
else:
|
964 |
+
# Greedy decoding
|
965 |
+
next_token = torch.argmax(next_token_logits, dim=-1)
|
966 |
+
|
967 |
+
# print(next_token_logits[0,next_token[0]], next_token_logits[0,eos_token_id])
|
968 |
+
|
969 |
+
# update generations and finished sentences
|
970 |
+
if eos_token_id is not None:
|
971 |
+
# pad finished sentences if eos_token_id exist
|
972 |
+
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
973 |
+
else:
|
974 |
+
tokens_to_add = next_token
|
975 |
+
|
976 |
+
# add token and increase length by one
|
977 |
+
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
978 |
+
cur_len = cur_len + 1
|
979 |
+
|
980 |
+
if eos_token_id is not None:
|
981 |
+
eos_in_sents = tokens_to_add == eos_token_id
|
982 |
+
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
983 |
+
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
984 |
+
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
|
985 |
+
# unfinished_sents is set to zero if eos in sentence
|
986 |
+
unfinished_sents.mul_((~eos_in_sents).long())
|
987 |
+
|
988 |
+
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
989 |
+
if unfinished_sents.max() == 0:
|
990 |
+
break
|
991 |
+
|
992 |
+
|
993 |
+
# extend attention_mask for new generated input if only decoder
|
994 |
+
# if self.config.is_encoder_decoder is False:
|
995 |
+
# attention_mask = torch.cat(
|
996 |
+
# [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
997 |
+
# )
|
998 |
+
|
999 |
+
return input_ids
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/transformermodel.py
ADDED
@@ -0,0 +1,734 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Copyright (c) Facebook, Inc. All Rights Reserved
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
try:
|
23 |
+
from transformers.modeling_bert import (
|
24 |
+
BertPreTrainedModel,
|
25 |
+
BertModel,
|
26 |
+
BertEncoder,
|
27 |
+
BertPredictionHeadTransform,
|
28 |
+
)
|
29 |
+
except ImportError:
|
30 |
+
pass
|
31 |
+
|
32 |
+
from ..modules import VideoTokenMLP, MMBertEmbeddings
|
33 |
+
|
34 |
+
|
35 |
+
# --------------- fine-tuning models ---------------
|
36 |
+
class MMBertForJoint(BertPreTrainedModel):
|
37 |
+
"""A BertModel with isolated attention mask to separate modality."""
|
38 |
+
|
39 |
+
def __init__(self, config):
|
40 |
+
super().__init__(config)
|
41 |
+
self.videomlp = VideoTokenMLP(config)
|
42 |
+
self.bert = MMBertModel(config)
|
43 |
+
self.init_weights()
|
44 |
+
|
45 |
+
def forward(
|
46 |
+
self,
|
47 |
+
input_ids=None,
|
48 |
+
input_video_embeds=None,
|
49 |
+
attention_mask=None,
|
50 |
+
token_type_ids=None,
|
51 |
+
position_ids=None,
|
52 |
+
head_mask=None,
|
53 |
+
inputs_embeds=None,
|
54 |
+
next_sentence_label=None,
|
55 |
+
output_attentions=None,
|
56 |
+
output_hidden_states=None,
|
57 |
+
return_dict=None,
|
58 |
+
separate_forward_split=None,
|
59 |
+
):
|
60 |
+
return_dict = (
|
61 |
+
return_dict if return_dict is not None
|
62 |
+
else self.config.use_return_dict
|
63 |
+
)
|
64 |
+
video_tokens = self.videomlp(input_video_embeds)
|
65 |
+
|
66 |
+
outputs = self.bert(
|
67 |
+
input_ids,
|
68 |
+
video_tokens,
|
69 |
+
attention_mask=attention_mask,
|
70 |
+
token_type_ids=token_type_ids,
|
71 |
+
position_ids=position_ids,
|
72 |
+
head_mask=head_mask,
|
73 |
+
inputs_embeds=inputs_embeds,
|
74 |
+
output_attentions=output_attentions,
|
75 |
+
output_hidden_states=output_hidden_states,
|
76 |
+
return_dict=return_dict,
|
77 |
+
separate_forward_split=separate_forward_split,
|
78 |
+
)
|
79 |
+
|
80 |
+
return outputs
|
81 |
+
|
82 |
+
|
83 |
+
class MMBertForTokenClassification(BertPreTrainedModel):
|
84 |
+
"""A BertModel similar to MMJointUni, with extra wrapper layer
|
85 |
+
to be fine-tuned from other pretrained MMFusion model."""
|
86 |
+
|
87 |
+
def __init__(self, config):
|
88 |
+
super().__init__(config)
|
89 |
+
self.videomlp = VideoTokenMLP(config)
|
90 |
+
self.bert = MMBertModel(config)
|
91 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
92 |
+
# TODO(huxu): 779 is the number of classes for COIN: move to config?
|
93 |
+
self.classifier = nn.Linear(config.hidden_size, 779)
|
94 |
+
self.init_weights()
|
95 |
+
|
96 |
+
def forward(
|
97 |
+
self,
|
98 |
+
input_ids=None,
|
99 |
+
input_video_embeds=None,
|
100 |
+
attention_mask=None,
|
101 |
+
token_type_ids=None,
|
102 |
+
position_ids=None,
|
103 |
+
head_mask=None,
|
104 |
+
inputs_embeds=None,
|
105 |
+
next_sentence_label=None,
|
106 |
+
output_attentions=None,
|
107 |
+
output_hidden_states=None,
|
108 |
+
return_dict=None,
|
109 |
+
separate_forward_split=None,
|
110 |
+
):
|
111 |
+
return_dict = (
|
112 |
+
return_dict if return_dict is not None
|
113 |
+
else self.config.use_return_dict
|
114 |
+
)
|
115 |
+
|
116 |
+
video_tokens = self.videomlp(input_video_embeds)
|
117 |
+
outputs = self.bert(
|
118 |
+
input_ids,
|
119 |
+
video_tokens,
|
120 |
+
attention_mask=attention_mask,
|
121 |
+
token_type_ids=token_type_ids,
|
122 |
+
position_ids=position_ids,
|
123 |
+
head_mask=head_mask,
|
124 |
+
inputs_embeds=inputs_embeds,
|
125 |
+
output_attentions=output_attentions,
|
126 |
+
output_hidden_states=output_hidden_states,
|
127 |
+
return_dict=return_dict,
|
128 |
+
separate_forward_split=separate_forward_split,
|
129 |
+
)
|
130 |
+
|
131 |
+
return (self.classifier(outputs[0]),)
|
132 |
+
|
133 |
+
|
134 |
+
# ------------ pre-training models ----------------
|
135 |
+
|
136 |
+
class MMBertForEncoder(BertPreTrainedModel):
|
137 |
+
"""A BertModel for Contrastive Learning."""
|
138 |
+
def __init__(self, config):
|
139 |
+
super().__init__(config)
|
140 |
+
self.videomlp = VideoTokenMLP(config)
|
141 |
+
self.bert = MMBertModel(config)
|
142 |
+
self.init_weights()
|
143 |
+
|
144 |
+
def forward(
|
145 |
+
self,
|
146 |
+
input_ids=None,
|
147 |
+
input_video_embeds=None,
|
148 |
+
attention_mask=None,
|
149 |
+
token_type_ids=None,
|
150 |
+
position_ids=None,
|
151 |
+
head_mask=None,
|
152 |
+
inputs_embeds=None,
|
153 |
+
output_attentions=None,
|
154 |
+
output_hidden_states=None,
|
155 |
+
return_dict=None,
|
156 |
+
):
|
157 |
+
return_dict = (
|
158 |
+
return_dict if return_dict is not None
|
159 |
+
else self.config.use_return_dict
|
160 |
+
)
|
161 |
+
if input_video_embeds is not None:
|
162 |
+
video_tokens = self.videomlp(input_video_embeds)
|
163 |
+
else:
|
164 |
+
video_tokens = None
|
165 |
+
|
166 |
+
outputs = self.bert(
|
167 |
+
input_ids,
|
168 |
+
video_tokens,
|
169 |
+
attention_mask=attention_mask,
|
170 |
+
token_type_ids=token_type_ids,
|
171 |
+
position_ids=position_ids,
|
172 |
+
head_mask=head_mask,
|
173 |
+
inputs_embeds=inputs_embeds,
|
174 |
+
output_attentions=output_attentions,
|
175 |
+
output_hidden_states=output_hidden_states,
|
176 |
+
return_dict=return_dict,
|
177 |
+
)
|
178 |
+
return outputs
|
179 |
+
|
180 |
+
|
181 |
+
class MMBertForMFMMLM(BertPreTrainedModel):
|
182 |
+
"""A BertModel with shared prediction head on MFM-MLM."""
|
183 |
+
def __init__(self, config):
|
184 |
+
super().__init__(config)
|
185 |
+
self.videomlp = VideoTokenMLP(config)
|
186 |
+
self.bert = MMBertModel(config)
|
187 |
+
self.cls = MFMMLMHead(config)
|
188 |
+
self.hidden_size = config.hidden_size
|
189 |
+
self.init_weights()
|
190 |
+
|
191 |
+
def get_output_embeddings(self):
|
192 |
+
return self.cls.predictions.decoder
|
193 |
+
|
194 |
+
def forward(
|
195 |
+
self,
|
196 |
+
input_ids=None,
|
197 |
+
input_video_embeds=None,
|
198 |
+
attention_mask=None,
|
199 |
+
token_type_ids=None,
|
200 |
+
position_ids=None,
|
201 |
+
head_mask=None,
|
202 |
+
inputs_embeds=None,
|
203 |
+
masked_frame_labels=None,
|
204 |
+
target_video_hidden_states=None,
|
205 |
+
non_masked_frame_mask=None,
|
206 |
+
masked_lm_labels=None,
|
207 |
+
output_attentions=None,
|
208 |
+
output_hidden_states=None,
|
209 |
+
return_dict=None,
|
210 |
+
):
|
211 |
+
return_dict = (
|
212 |
+
return_dict if return_dict is not None
|
213 |
+
else self.config.use_return_dict
|
214 |
+
)
|
215 |
+
if input_video_embeds is not None:
|
216 |
+
video_tokens = self.videomlp(input_video_embeds)
|
217 |
+
else:
|
218 |
+
video_tokens = None
|
219 |
+
|
220 |
+
if target_video_hidden_states is not None:
|
221 |
+
target_video_hidden_states = self.videomlp(
|
222 |
+
target_video_hidden_states)
|
223 |
+
|
224 |
+
non_masked_frame_hidden_states = video_tokens.masked_select(
|
225 |
+
non_masked_frame_mask.unsqueeze(-1)
|
226 |
+
).view(-1, self.hidden_size)
|
227 |
+
|
228 |
+
outputs = self.bert(
|
229 |
+
input_ids,
|
230 |
+
video_tokens,
|
231 |
+
attention_mask=attention_mask,
|
232 |
+
token_type_ids=token_type_ids,
|
233 |
+
position_ids=position_ids,
|
234 |
+
head_mask=head_mask,
|
235 |
+
inputs_embeds=inputs_embeds,
|
236 |
+
output_attentions=output_attentions,
|
237 |
+
output_hidden_states=output_hidden_states,
|
238 |
+
return_dict=return_dict,
|
239 |
+
)
|
240 |
+
|
241 |
+
sequence_output = outputs[0]
|
242 |
+
|
243 |
+
mfm_scores, prediction_scores = None, None
|
244 |
+
if masked_frame_labels is not None and masked_lm_labels is not None:
|
245 |
+
# split the sequence.
|
246 |
+
text_offset = masked_frame_labels.size(1) + 1 # [CLS]
|
247 |
+
video_sequence_output = sequence_output[
|
248 |
+
:, 1:text_offset
|
249 |
+
] # remove [SEP] as not in video_label.
|
250 |
+
text_sequence_output = torch.cat(
|
251 |
+
[sequence_output[:, :1], sequence_output[:, text_offset:]],
|
252 |
+
dim=1
|
253 |
+
)
|
254 |
+
|
255 |
+
hidden_size = video_sequence_output.size(-1)
|
256 |
+
selected_video_output = video_sequence_output.masked_select(
|
257 |
+
masked_frame_labels.unsqueeze(-1)
|
258 |
+
).view(-1, hidden_size)
|
259 |
+
|
260 |
+
# only compute select tokens to training to speed up.
|
261 |
+
hidden_size = text_sequence_output.size(-1)
|
262 |
+
# masked_lm_labels = masked_lm_labels.reshape(-1)
|
263 |
+
labels_mask = masked_lm_labels != -100
|
264 |
+
|
265 |
+
selected_text_output = text_sequence_output.masked_select(
|
266 |
+
labels_mask.unsqueeze(-1)
|
267 |
+
).view(-1, hidden_size)
|
268 |
+
mfm_scores, prediction_scores = self.cls(
|
269 |
+
selected_video_output,
|
270 |
+
target_video_hidden_states,
|
271 |
+
non_masked_frame_hidden_states,
|
272 |
+
selected_text_output,
|
273 |
+
)
|
274 |
+
|
275 |
+
output = (
|
276 |
+
mfm_scores,
|
277 |
+
prediction_scores,
|
278 |
+
) + outputs
|
279 |
+
return output
|
280 |
+
|
281 |
+
|
282 |
+
class BertMFMMLMPredictionHead(nn.Module):
|
283 |
+
def __init__(self, config):
|
284 |
+
super().__init__()
|
285 |
+
self.transform = BertPredictionHeadTransform(config)
|
286 |
+
# The output weights are the same as the input embeddings, but there is
|
287 |
+
# an output-only bias for each token.
|
288 |
+
self.decoder = nn.Linear(
|
289 |
+
config.hidden_size, config.vocab_size, bias=False)
|
290 |
+
|
291 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
292 |
+
|
293 |
+
# Need a link between the two variables so that the bias is correctly
|
294 |
+
# resized with `resize_token_embeddings`
|
295 |
+
self.decoder.bias = self.bias
|
296 |
+
|
297 |
+
def forward(
|
298 |
+
self,
|
299 |
+
video_hidden_states=None,
|
300 |
+
target_video_hidden_states=None,
|
301 |
+
non_masked_frame_hidden_states=None,
|
302 |
+
text_hidden_states=None,
|
303 |
+
):
|
304 |
+
video_logits, text_logits = None, None
|
305 |
+
if video_hidden_states is not None:
|
306 |
+
video_hidden_states = self.transform(video_hidden_states)
|
307 |
+
non_masked_frame_logits = torch.mm(
|
308 |
+
video_hidden_states,
|
309 |
+
non_masked_frame_hidden_states.transpose(1, 0)
|
310 |
+
)
|
311 |
+
masked_frame_logits = torch.bmm(
|
312 |
+
video_hidden_states.unsqueeze(1),
|
313 |
+
target_video_hidden_states.unsqueeze(-1),
|
314 |
+
).squeeze(-1)
|
315 |
+
video_logits = torch.cat(
|
316 |
+
[masked_frame_logits, non_masked_frame_logits], dim=1
|
317 |
+
)
|
318 |
+
|
319 |
+
if text_hidden_states is not None:
|
320 |
+
text_hidden_states = self.transform(text_hidden_states)
|
321 |
+
text_logits = self.decoder(text_hidden_states)
|
322 |
+
return video_logits, text_logits
|
323 |
+
|
324 |
+
|
325 |
+
class MFMMLMHead(nn.Module):
|
326 |
+
def __init__(self, config):
|
327 |
+
super().__init__()
|
328 |
+
self.predictions = BertMFMMLMPredictionHead(config)
|
329 |
+
|
330 |
+
def forward(
|
331 |
+
self,
|
332 |
+
video_hidden_states=None,
|
333 |
+
target_video_hidden_states=None,
|
334 |
+
non_masked_frame_hidden_states=None,
|
335 |
+
text_hidden_states=None,
|
336 |
+
):
|
337 |
+
video_logits, text_logits = self.predictions(
|
338 |
+
video_hidden_states,
|
339 |
+
target_video_hidden_states,
|
340 |
+
non_masked_frame_hidden_states,
|
341 |
+
text_hidden_states,
|
342 |
+
)
|
343 |
+
return video_logits, text_logits
|
344 |
+
|
345 |
+
|
346 |
+
class MMBertForMTM(MMBertForMFMMLM):
|
347 |
+
def __init__(self, config):
|
348 |
+
BertPreTrainedModel.__init__(self, config)
|
349 |
+
self.videomlp = VideoTokenMLP(config)
|
350 |
+
self.bert = MMBertModel(config)
|
351 |
+
self.cls = MTMHead(config)
|
352 |
+
self.hidden_size = config.hidden_size
|
353 |
+
self.init_weights()
|
354 |
+
|
355 |
+
|
356 |
+
class BertMTMPredictionHead(nn.Module):
|
357 |
+
def __init__(self, config):
|
358 |
+
super().__init__()
|
359 |
+
self.transform = BertPredictionHeadTransform(config)
|
360 |
+
self.decoder = nn.Linear(
|
361 |
+
config.hidden_size, config.vocab_size, bias=False)
|
362 |
+
|
363 |
+
def forward(
|
364 |
+
self,
|
365 |
+
video_hidden_states=None,
|
366 |
+
target_video_hidden_states=None,
|
367 |
+
non_masked_frame_hidden_states=None,
|
368 |
+
text_hidden_states=None,
|
369 |
+
):
|
370 |
+
non_masked_frame_hidden_states = non_masked_frame_hidden_states.transpose(1, 0)
|
371 |
+
video_logits, text_logits = None, None
|
372 |
+
if video_hidden_states is not None:
|
373 |
+
video_hidden_states = self.transform(video_hidden_states)
|
374 |
+
|
375 |
+
masked_frame_logits = torch.bmm(
|
376 |
+
video_hidden_states.unsqueeze(1),
|
377 |
+
target_video_hidden_states.unsqueeze(-1),
|
378 |
+
).squeeze(-1)
|
379 |
+
|
380 |
+
non_masked_frame_logits = torch.mm(
|
381 |
+
video_hidden_states,
|
382 |
+
non_masked_frame_hidden_states
|
383 |
+
)
|
384 |
+
video_on_vocab_logits = self.decoder(video_hidden_states)
|
385 |
+
video_logits = torch.cat([
|
386 |
+
masked_frame_logits,
|
387 |
+
non_masked_frame_logits,
|
388 |
+
video_on_vocab_logits], dim=1)
|
389 |
+
|
390 |
+
if text_hidden_states is not None:
|
391 |
+
text_hidden_states = self.transform(text_hidden_states)
|
392 |
+
# text first so label does not need to be shifted.
|
393 |
+
text_on_vocab_logits = self.decoder(text_hidden_states)
|
394 |
+
text_on_video_logits = torch.mm(
|
395 |
+
text_hidden_states,
|
396 |
+
non_masked_frame_hidden_states
|
397 |
+
)
|
398 |
+
text_logits = torch.cat([
|
399 |
+
text_on_vocab_logits,
|
400 |
+
text_on_video_logits
|
401 |
+
], dim=1)
|
402 |
+
|
403 |
+
return video_logits, text_logits
|
404 |
+
|
405 |
+
|
406 |
+
class MTMHead(nn.Module):
|
407 |
+
def __init__(self, config):
|
408 |
+
super().__init__()
|
409 |
+
self.predictions = BertMTMPredictionHead(config)
|
410 |
+
|
411 |
+
def forward(
|
412 |
+
self,
|
413 |
+
video_hidden_states=None,
|
414 |
+
target_video_hidden_states=None,
|
415 |
+
non_masked_frame_hidden_states=None,
|
416 |
+
text_hidden_states=None,
|
417 |
+
):
|
418 |
+
video_logits, text_logits = self.predictions(
|
419 |
+
video_hidden_states,
|
420 |
+
target_video_hidden_states,
|
421 |
+
non_masked_frame_hidden_states,
|
422 |
+
text_hidden_states,
|
423 |
+
)
|
424 |
+
return video_logits, text_logits
|
425 |
+
|
426 |
+
|
427 |
+
class MMBertModel(BertModel):
|
428 |
+
"""MMBertModel has MMBertEmbedding to support video tokens."""
|
429 |
+
|
430 |
+
def __init__(self, config, add_pooling_layer=True):
|
431 |
+
super().__init__(config)
|
432 |
+
# overwrite embedding
|
433 |
+
self.embeddings = MMBertEmbeddings(config)
|
434 |
+
self.encoder = MultiLayerAttentionMaskBertEncoder(config)
|
435 |
+
self.init_weights()
|
436 |
+
|
437 |
+
def forward(
|
438 |
+
self,
|
439 |
+
input_ids=None,
|
440 |
+
input_video_embeds=None,
|
441 |
+
attention_mask=None,
|
442 |
+
token_type_ids=None,
|
443 |
+
position_ids=None,
|
444 |
+
head_mask=None,
|
445 |
+
inputs_embeds=None,
|
446 |
+
encoder_hidden_states=None,
|
447 |
+
encoder_attention_mask=None,
|
448 |
+
output_attentions=None,
|
449 |
+
output_hidden_states=None,
|
450 |
+
return_dict=None,
|
451 |
+
separate_forward_split=None,
|
452 |
+
):
|
453 |
+
output_attentions = (
|
454 |
+
output_attentions
|
455 |
+
if output_attentions is not None
|
456 |
+
else self.config.output_attentions
|
457 |
+
)
|
458 |
+
output_hidden_states = (
|
459 |
+
output_hidden_states
|
460 |
+
if output_hidden_states is not None
|
461 |
+
else self.config.output_hidden_states
|
462 |
+
)
|
463 |
+
return_dict = (
|
464 |
+
return_dict if return_dict is not None
|
465 |
+
else self.config.use_return_dict
|
466 |
+
)
|
467 |
+
|
468 |
+
if input_ids is not None and inputs_embeds is not None:
|
469 |
+
raise ValueError(
|
470 |
+
"You cannot specify both input_ids "
|
471 |
+
"and inputs_embeds at the same time"
|
472 |
+
)
|
473 |
+
elif input_ids is not None:
|
474 |
+
if input_video_embeds is not None:
|
475 |
+
input_shape = (
|
476 |
+
input_ids.size(0),
|
477 |
+
input_ids.size(1) + input_video_embeds.size(1),
|
478 |
+
)
|
479 |
+
else:
|
480 |
+
input_shape = (
|
481 |
+
input_ids.size(0),
|
482 |
+
input_ids.size(1),
|
483 |
+
)
|
484 |
+
elif inputs_embeds is not None:
|
485 |
+
if input_video_embeds is not None:
|
486 |
+
input_shape = (
|
487 |
+
inputs_embeds.size(0),
|
488 |
+
inputs_embeds.size(1) + input_video_embeds.size(1),
|
489 |
+
)
|
490 |
+
else:
|
491 |
+
input_shape = (
|
492 |
+
input_ids.size(0),
|
493 |
+
input_ids.size(1),
|
494 |
+
)
|
495 |
+
else:
|
496 |
+
raise ValueError(
|
497 |
+
"You have to specify either input_ids or inputs_embeds")
|
498 |
+
|
499 |
+
device = input_ids.device if input_ids is not None \
|
500 |
+
else inputs_embeds.device
|
501 |
+
|
502 |
+
if attention_mask is None:
|
503 |
+
attention_mask = torch.ones(input_shape, device=device)
|
504 |
+
if token_type_ids is None:
|
505 |
+
token_type_ids = torch.zeros(
|
506 |
+
input_shape, dtype=torch.long, device=device)
|
507 |
+
|
508 |
+
# We can provide a self-attention mask of dimensions
|
509 |
+
# [batch_size, from_seq_length, to_seq_length]
|
510 |
+
# ourselves in which case
|
511 |
+
# we just need to make it broadcastable to all heads.
|
512 |
+
extended_attention_mask: torch.Tensor = \
|
513 |
+
self.get_extended_attention_mask(
|
514 |
+
attention_mask, input_shape, device)
|
515 |
+
|
516 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
517 |
+
# we need to make broadcastable to
|
518 |
+
# [batch_size, num_heads, seq_length, seq_length]
|
519 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
520 |
+
(
|
521 |
+
encoder_batch_size,
|
522 |
+
encoder_sequence_length,
|
523 |
+
_,
|
524 |
+
) = encoder_hidden_states.size()
|
525 |
+
encoder_hidden_shape = (
|
526 |
+
encoder_batch_size, encoder_sequence_length)
|
527 |
+
if encoder_attention_mask is None:
|
528 |
+
encoder_attention_mask = torch.ones(
|
529 |
+
encoder_hidden_shape, device=device)
|
530 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
531 |
+
encoder_attention_mask
|
532 |
+
)
|
533 |
+
else:
|
534 |
+
encoder_extended_attention_mask = None
|
535 |
+
|
536 |
+
# Prepare head mask if needed
|
537 |
+
# 1.0 in head_mask indicate we keep the head
|
538 |
+
# attention_probs has shape bsz x n_heads x N x N
|
539 |
+
# input head_mask has shape [num_heads] or
|
540 |
+
# [num_hidden_layers x num_heads]
|
541 |
+
# and head_mask is converted to shape
|
542 |
+
# [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
543 |
+
|
544 |
+
head_mask = self.get_head_mask(
|
545 |
+
head_mask, self.config.num_hidden_layers)
|
546 |
+
|
547 |
+
embedding_output = self.embeddings(
|
548 |
+
input_ids,
|
549 |
+
input_video_embeds,
|
550 |
+
position_ids=position_ids,
|
551 |
+
token_type_ids=token_type_ids,
|
552 |
+
inputs_embeds=inputs_embeds,
|
553 |
+
)
|
554 |
+
|
555 |
+
if separate_forward_split is not None:
|
556 |
+
split_embedding_output = \
|
557 |
+
embedding_output[:, :separate_forward_split]
|
558 |
+
split_extended_attention_mask = extended_attention_mask[
|
559 |
+
:, :, :, :separate_forward_split, :separate_forward_split
|
560 |
+
]
|
561 |
+
split_encoder_outputs = self.encoder(
|
562 |
+
split_embedding_output,
|
563 |
+
attention_mask=split_extended_attention_mask,
|
564 |
+
head_mask=head_mask,
|
565 |
+
encoder_hidden_states=encoder_hidden_states,
|
566 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
567 |
+
output_attentions=output_attentions,
|
568 |
+
output_hidden_states=output_hidden_states,
|
569 |
+
return_dict=return_dict,
|
570 |
+
)
|
571 |
+
assert (
|
572 |
+
len(split_encoder_outputs) <= 2
|
573 |
+
), "we do not support merge on attention for now."
|
574 |
+
encoder_outputs = []
|
575 |
+
encoder_outputs.append([split_encoder_outputs[0]])
|
576 |
+
if len(split_encoder_outputs) == 2:
|
577 |
+
encoder_outputs.append([])
|
578 |
+
for _all_hidden_states in split_encoder_outputs[1]:
|
579 |
+
encoder_outputs[-1].append([_all_hidden_states])
|
580 |
+
|
581 |
+
split_embedding_output = \
|
582 |
+
embedding_output[:, separate_forward_split:]
|
583 |
+
split_extended_attention_mask = extended_attention_mask[
|
584 |
+
:, :, :, separate_forward_split:, separate_forward_split:
|
585 |
+
]
|
586 |
+
|
587 |
+
split_encoder_outputs = self.encoder(
|
588 |
+
split_embedding_output,
|
589 |
+
attention_mask=split_extended_attention_mask,
|
590 |
+
head_mask=head_mask,
|
591 |
+
encoder_hidden_states=encoder_hidden_states,
|
592 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
593 |
+
output_attentions=output_attentions,
|
594 |
+
output_hidden_states=output_hidden_states,
|
595 |
+
return_dict=return_dict,
|
596 |
+
)
|
597 |
+
|
598 |
+
assert (
|
599 |
+
len(split_encoder_outputs) <= 2
|
600 |
+
), "we do not support merge on attention for now."
|
601 |
+
encoder_outputs[0].append(split_encoder_outputs[0])
|
602 |
+
encoder_outputs[0] = torch.cat(encoder_outputs[0], dim=1)
|
603 |
+
if len(split_encoder_outputs) == 2:
|
604 |
+
for layer_idx, _all_hidden_states in enumerate(
|
605 |
+
split_encoder_outputs[1]
|
606 |
+
):
|
607 |
+
encoder_outputs[1][layer_idx].append(_all_hidden_states)
|
608 |
+
encoder_outputs[1][layer_idx] = torch.cat(
|
609 |
+
encoder_outputs[1][layer_idx], dim=1
|
610 |
+
)
|
611 |
+
encoder_outputs = tuple(encoder_outputs)
|
612 |
+
else:
|
613 |
+
encoder_outputs = self.encoder(
|
614 |
+
embedding_output,
|
615 |
+
attention_mask=extended_attention_mask,
|
616 |
+
head_mask=head_mask,
|
617 |
+
encoder_hidden_states=encoder_hidden_states,
|
618 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
619 |
+
output_attentions=output_attentions,
|
620 |
+
output_hidden_states=output_hidden_states,
|
621 |
+
return_dict=return_dict,
|
622 |
+
)
|
623 |
+
|
624 |
+
sequence_output = encoder_outputs[0]
|
625 |
+
pooled_output = (
|
626 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
627 |
+
)
|
628 |
+
|
629 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
630 |
+
|
631 |
+
def get_extended_attention_mask(self, attention_mask, input_shape, device):
|
632 |
+
"""This is borrowed from `modeling_utils.py` with the support of
|
633 |
+
multi-layer attention masks.
|
634 |
+
The second dim is expected to be number of layers.
|
635 |
+
See `MMAttentionMaskProcessor`.
|
636 |
+
Makes broadcastable attention and causal masks so that future
|
637 |
+
and masked tokens are ignored.
|
638 |
+
|
639 |
+
Arguments:
|
640 |
+
attention_mask (:obj:`torch.Tensor`):
|
641 |
+
Mask with ones indicating tokens to attend to,
|
642 |
+
zeros for tokens to ignore.
|
643 |
+
input_shape (:obj:`Tuple[int]`):
|
644 |
+
The shape of the input to the model.
|
645 |
+
device: (:obj:`torch.device`):
|
646 |
+
The device of the input to the model.
|
647 |
+
|
648 |
+
Returns:
|
649 |
+
:obj:`torch.Tensor` The extended attention mask, \
|
650 |
+
with a the same dtype as :obj:`attention_mask.dtype`.
|
651 |
+
"""
|
652 |
+
# We can provide a self-attention mask of dimensions
|
653 |
+
# [batch_size, from_seq_length, to_seq_length]
|
654 |
+
# ourselves in which case we just need to make it broadcastable
|
655 |
+
# to all heads.
|
656 |
+
if attention_mask.dim() == 4:
|
657 |
+
extended_attention_mask = attention_mask[:, :, None, :, :]
|
658 |
+
extended_attention_mask = extended_attention_mask.to(
|
659 |
+
dtype=self.dtype
|
660 |
+
) # fp16 compatibility
|
661 |
+
extended_attention_mask = (1.0 - extended_attention_mask) \
|
662 |
+
* -10000.0
|
663 |
+
return extended_attention_mask
|
664 |
+
else:
|
665 |
+
return super().get_extended_attention_mask(
|
666 |
+
attention_mask, input_shape, device
|
667 |
+
)
|
668 |
+
|
669 |
+
|
670 |
+
class MultiLayerAttentionMaskBertEncoder(BertEncoder):
|
671 |
+
"""extend BertEncoder with the capability of
|
672 |
+
multiple layers of attention mask."""
|
673 |
+
|
674 |
+
def forward(
|
675 |
+
self,
|
676 |
+
hidden_states,
|
677 |
+
attention_mask=None,
|
678 |
+
head_mask=None,
|
679 |
+
encoder_hidden_states=None,
|
680 |
+
encoder_attention_mask=None,
|
681 |
+
output_attentions=False,
|
682 |
+
output_hidden_states=False,
|
683 |
+
return_dict=False,
|
684 |
+
):
|
685 |
+
all_hidden_states = () if output_hidden_states else None
|
686 |
+
all_attentions = () if output_attentions else None
|
687 |
+
for i, layer_module in enumerate(self.layer):
|
688 |
+
if output_hidden_states:
|
689 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
690 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
691 |
+
|
692 |
+
layer_attention_mask = (
|
693 |
+
attention_mask[:, i, :, :, :]
|
694 |
+
if attention_mask.dim() == 5
|
695 |
+
else attention_mask
|
696 |
+
)
|
697 |
+
|
698 |
+
if getattr(self.config, "gradient_checkpointing", False):
|
699 |
+
|
700 |
+
def create_custom_forward(module):
|
701 |
+
def custom_forward(*inputs):
|
702 |
+
return module(*inputs, output_attentions)
|
703 |
+
|
704 |
+
return custom_forward
|
705 |
+
|
706 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
707 |
+
create_custom_forward(layer_module),
|
708 |
+
hidden_states,
|
709 |
+
layer_attention_mask,
|
710 |
+
layer_head_mask,
|
711 |
+
encoder_hidden_states,
|
712 |
+
encoder_attention_mask,
|
713 |
+
)
|
714 |
+
else:
|
715 |
+
layer_outputs = layer_module(
|
716 |
+
hidden_states,
|
717 |
+
layer_attention_mask,
|
718 |
+
layer_head_mask,
|
719 |
+
encoder_hidden_states,
|
720 |
+
encoder_attention_mask,
|
721 |
+
output_attentions,
|
722 |
+
)
|
723 |
+
hidden_states = layer_outputs[0]
|
724 |
+
if output_attentions:
|
725 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
726 |
+
|
727 |
+
if output_hidden_states:
|
728 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
729 |
+
|
730 |
+
return tuple(
|
731 |
+
v
|
732 |
+
for v in [hidden_states, all_hidden_states, all_attentions]
|
733 |
+
if v is not None
|
734 |
+
)
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/modules/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from .mm import *
|
6 |
+
|
7 |
+
try:
|
8 |
+
from .expmm import *
|
9 |
+
except ImportError:
|
10 |
+
pass
|
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/modules/mm.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Copyright (c) Facebook, Inc. All Rights Reserved
|
17 |
+
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
try:
|
24 |
+
from transformers.modeling_bert import (
|
25 |
+
BertEmbeddings,
|
26 |
+
ACT2FN,
|
27 |
+
)
|
28 |
+
except ImportError:
|
29 |
+
pass
|
30 |
+
|
31 |
+
|
32 |
+
class VideoTokenMLP(nn.Module):
|
33 |
+
def __init__(self, config):
|
34 |
+
super().__init__()
|
35 |
+
input_dim = config.input_dim if hasattr(config, "input_dim") else 512
|
36 |
+
self.linear1 = nn.Linear(input_dim, config.hidden_size)
|
37 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size)
|
38 |
+
self.activation = ACT2FN[config.hidden_act]
|
39 |
+
self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)
|
40 |
+
|
41 |
+
def forward(self, hidden_states):
|
42 |
+
hidden_states = self.linear1(hidden_states)
|
43 |
+
hidden_states = self.activation(hidden_states)
|
44 |
+
hidden_states = self.LayerNorm(hidden_states)
|
45 |
+
hidden_states = self.linear2(hidden_states)
|
46 |
+
return hidden_states
|
47 |
+
|
48 |
+
|
49 |
+
class MMBertEmbeddings(BertEmbeddings):
|
50 |
+
def __init__(self, config):
|
51 |
+
super().__init__(config)
|
52 |
+
self.max_video_len = config.max_video_len
|
53 |
+
if hasattr(config, "use_seg_emb") and config.use_seg_emb:
|
54 |
+
"""the original VLM paper uses seg_embeddings for temporal space.
|
55 |
+
although not used it changed the randomness of initialization.
|
56 |
+
we keep it for reproducibility.
|
57 |
+
"""
|
58 |
+
self.seg_embeddings = nn.Embedding(256, config.hidden_size)
|
59 |
+
|
60 |
+
def forward(
|
61 |
+
self,
|
62 |
+
input_ids,
|
63 |
+
input_video_embeds,
|
64 |
+
token_type_ids=None,
|
65 |
+
position_ids=None,
|
66 |
+
inputs_embeds=None,
|
67 |
+
):
|
68 |
+
input_tensor = input_ids if input_ids is not None else inputs_embeds
|
69 |
+
if input_video_embeds is not None:
|
70 |
+
input_shape = (
|
71 |
+
input_tensor.size(0),
|
72 |
+
input_tensor.size(1) + input_video_embeds.size(1),
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
input_shape = (input_tensor.size(0), input_tensor.size(1))
|
76 |
+
|
77 |
+
if position_ids is None:
|
78 |
+
"""
|
79 |
+
Auto skip position embeddings for text only case.
|
80 |
+
use cases:
|
81 |
+
(1) action localization and segmentation:
|
82 |
+
feed in len-1 dummy video token needs text part to
|
83 |
+
skip input_video_embeds.size(1) for the right
|
84 |
+
position_ids for video [SEP] and rest text tokens.
|
85 |
+
(2) MMFusionShare for two forward passings:
|
86 |
+
in `forward_text`: input_video_embeds is None.
|
87 |
+
need to skip video [SEP] token.
|
88 |
+
|
89 |
+
# video_len + 1: [CLS] + video_embed
|
90 |
+
# self.max_video_len + 1: [SEP] for video.
|
91 |
+
# self.max_video_len + 2: [SEP] for video.
|
92 |
+
# self.max_video_len + input_ids.size(1): rest for text.
|
93 |
+
"""
|
94 |
+
if input_video_embeds is not None:
|
95 |
+
video_len = input_video_embeds.size(1)
|
96 |
+
starting_offset = self.max_video_len + 1 # video [SEP]
|
97 |
+
ending_offset = self.max_video_len + input_ids.size(1)
|
98 |
+
else:
|
99 |
+
video_len = 0
|
100 |
+
starting_offset = self.max_video_len + 2 # first text token.
|
101 |
+
ending_offset = self.max_video_len + input_ids.size(1) + 1
|
102 |
+
position_ids = torch.cat([
|
103 |
+
self.position_ids[:, :video_len + 1],
|
104 |
+
self.position_ids[:, starting_offset:ending_offset]
|
105 |
+
], dim=1)
|
106 |
+
|
107 |
+
if token_type_ids is None:
|
108 |
+
token_type_ids = torch.zeros(
|
109 |
+
input_shape, dtype=torch.long, device=self.position_ids.device
|
110 |
+
)
|
111 |
+
|
112 |
+
"""
|
113 |
+
the format of input_ids is [CLS] [SEP] caption [SEP] padding.
|
114 |
+
the goal is to build [CLS] video tokens [SEP] caption [SEP] .
|
115 |
+
"""
|
116 |
+
if inputs_embeds is None:
|
117 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
118 |
+
if input_video_embeds is not None:
|
119 |
+
inputs_mm_embeds = torch.cat([
|
120 |
+
inputs_embeds[:, :1], input_video_embeds, inputs_embeds[:, 1:]
|
121 |
+
], dim=1)
|
122 |
+
else:
|
123 |
+
# text only for `MMFusionShare`.
|
124 |
+
inputs_mm_embeds = inputs_embeds
|
125 |
+
|
126 |
+
position_embeddings = self.position_embeddings(position_ids)
|
127 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
128 |
+
embeddings = inputs_mm_embeds + position_embeddings
|
129 |
+
embeddings += token_type_embeddings
|
130 |
+
|
131 |
+
embeddings = self.LayerNorm(embeddings)
|
132 |
+
embeddings = self.dropout(embeddings)
|
133 |
+
return embeddings
|
134 |
+
|
135 |
+
|
136 |
+
class AlignHead(nn.Module):
|
137 |
+
"""this will load pre-trained weights for NSP, which is desirable."""
|
138 |
+
|
139 |
+
def __init__(self, config):
|
140 |
+
super().__init__()
|
141 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
142 |
+
|
143 |
+
def forward(self, dropout_pooled_output):
|
144 |
+
logits = self.seq_relationship(dropout_pooled_output)
|
145 |
+
return logits
|