d22cs051 committed on
Commit 8273cb9
0 Parent(s):

retrying pushing the code

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +129 -0
  3. Dockerfile +32 -0
  4. README.md +11 -0
  5. app.py +71 -0
  6. config.py +149 -0
  7. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.circleci/config.yml +159 -0
  8. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE.md +3 -0
  9. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
  10. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/documentation.md +15 -0
  11. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
  12. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
  13. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/PULL_REQUEST_TEMPLATE.md +16 -0
  14. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/stale.yml +30 -0
  15. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/workflows/build.yml +60 -0
  16. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/workflows/build_wheels.yml +41 -0
  17. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.gitignore +136 -0
  18. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.gitmodules +4 -0
  19. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.isort.cfg +2 -0
  20. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.pre-commit-config.yaml +40 -0
  21. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/CODE_OF_CONDUCT.md +77 -0
  22. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/CONTRIBUTING.md +82 -0
  23. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/LICENSE +21 -0
  24. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/README.md +236 -0
  25. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/.gitignore +2 -0
  26. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/.gitignore +139 -0
  27. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/CONFIG.md +41 -0
  28. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/DATASET.md +34 -0
  29. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/README.md +166 -0
  30. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/endtask.md +41 -0
  31. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/locallaunch.py +148 -0
  32. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/__init__.py +12 -0
  33. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/__init__.py +10 -0
  34. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/fairseqmmdataset.py +57 -0
  35. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/mmdataset.py +111 -0
  36. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/__init__.py +13 -0
  37. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/evaluator.py +54 -0
  38. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/metric.py +313 -0
  39. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/predictor.py +595 -0
  40. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/__init__.py +16 -0
  41. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/fairseqmmloss.py +63 -0
  42. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/loss.py +87 -0
  43. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/nce.py +156 -0
  44. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/__init__.py +17 -0
  45. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/fairseqmmmodel.py +51 -0
  46. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/mmfusion.py +926 -0
  47. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/mmfusionnlg.py +999 -0
  48. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/transformermodel.py +734 -0
  49. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/modules/__init__.py +10 -0
  50. fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/modules/mm.py +145 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,129 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
Dockerfile ADDED
@@ -0,0 +1,32 @@
+ FROM python:3.8.1-slim-buster
+
+
+ WORKDIR /code
+
+ COPY . /code
+
+ # RUN useradd -m -u 1000 user
+
+ RUN apt-get update
+ RUN apt-get install build-essential -y
+ # RUN pip install --no-cache-dir -r requirements.txt
+ RUN pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
+ # WORKDIR fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1
+ RUN pip install -e fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.
+ RUN pip install -r requirements.txt --no-cache-dir
+ RUN pip install gradio --no-cache-dir
+ RUN pip install protobuf==3.20.* --no-cache-dir
+
+ # Switch to the "user" user
+ # USER user
+
+ # Set home to the user's home directory
+ # ENV HOME=/home/user \
+ #     PATH=/home/user/.local/bin:$PATH
+
+ # Set the working directory to the user's home directory
+ # WORKDIR $HOME/code
+
+ # COPY --chown=user . $HOME/code
+ # RUN ls -la $HOME/code
+ CMD ["python3", "app.py"]
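
Editor's note (not part of the commit): a quick way to sanity-check an image built from this Dockerfile is to confirm that the pinned PyTorch build and the bundled editable fairseq checkout are importable before `app.py` starts. This is a minimal sketch under the assumption that the image was built exactly as above; the file name `sanity_check.py` is hypothetical.

```python
# sanity_check.py -- hypothetical helper, not part of this commit.
# Assumes the container was built from the Dockerfile above, which installs
# torch 1.8.1+cu111 and the local fairseq checkout in editable mode.
import torch
import fairseq

print("torch:", torch.__version__)            # expected to start with "1.8.1"
print("fairseq:", fairseq.__version__)        # version of the editable install
print("CUDA available:", torch.cuda.is_available())
```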
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: Audio Deepfake Detection
+ emoji: 🐢
+ colorFrom: indigo
+ colorTo: purple
+ sdk: docker
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,71 @@
+ import gradio as gr
+ import torch
+ from model import Model
+ from config import Config
+
+ import warnings
+ # warnings.filterwarnings('ignore')
+
+ # making config object
+ config = Config()
+
+
+
+ def inference(audio_file1):
+     print(f"[LOG] Audio file: {audio_file1}")
+
+ class DFSeparationApp:
+     def __init__(self, model_path, device="cpu"):
+         self.device = device
+         self.model = self.load_model(model_path)
+         self.model.to(self.device)
+
+
+     def load_model(self, model_path):
+         checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
+         fine_tuned_model = Model(
+             args=config,
+             device=self.device
+         )
+         fine_tuned_model.load_state_dict(checkpoint["model"])
+         print("[LOG] Model loaded successfully.")
+         return fine_tuned_model
+
+     def predict(self, audio_file):
+         # Load the audio file
+         audio_tensor = torch.tensor(audio_file[1]).to(self.device)
+         with torch.no_grad():
+             # Make prediction
+             output = self.model(audio_tensor)
+             preds = output.argmax(dim=-1)
+             probs = output.softmax(dim=-1)
+         print(f"[LOG] Prediction: {preds.item()}")
+         print(f"[LOG] Probability: {probs.max().item()}")
+         return preds.item(), probs.max().item()
+
+     def run(self):
+         print("[LOG] Running the app...")
+         # gradio interface
+         audio_input1 = gr.Audio(label="Upload or record audio")
+         prediction = gr.Label(label="Prediction:")
+         prob = gr.Label(label="Probability:")
+         gr.Interface(
+             fn=self.predict,
+             inputs=[audio_input1],
+             outputs=[prediction, prob],
+             title="DF Separation",
+             description="This app classifies audio samples as Real or Fake.",
+             examples=[
+                 ["samples/Fake/download (5).wav", "1"],
+                 ["samples/Fake/fake1_1.wav", "1"],
+                 ["samples/Real/Central Avenue 1.wav", "0"],
+                 ["samples/Real/hindi.mp3", "0"],
+             ]
+         ).launch(quiet=False, server_name="0.0.0.0")
+
+ if __name__ == "__main__":
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"[LOG] Device: {device}")
+     model_path = "models/for_trained_model.ckpt"  # Replace with your model path
+     app = DFSeparationApp(model_path, device=device)
+     app.run()
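
Editor's note (not part of the commit): `gr.Audio` passes a recording to the callback as a `(sample_rate, numpy_array)` tuple, which is why `predict` indexes `audio_file[1]`. The snippet below is a self-contained sketch of that conversion step using a dummy waveform; the model itself is not loaded here, and the shapes are illustrative only.

```python
# Hypothetical illustration of the input format DFSeparationApp.predict() receives.
import numpy as np
import torch

sample_rate = 16000
waveform = np.random.randn(sample_rate).astype(np.float32)  # 1 second of fake audio
audio_file = (sample_rate, waveform)        # what Gradio hands to the callback

audio_tensor = torch.tensor(audio_file[1])  # same indexing as in predict()
print(audio_tensor.shape)                   # torch.Size([16000])
```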
config.py ADDED
@@ -0,0 +1,149 @@
+ class Config:
+     def __init__(self):
+         self.custom_data_dir = 'data/Dataset_Speech_Assignment'
+         self.for2sec_data_dir = 'data/for-2seconds'
+         self.batch_size = 32
+         self.num_workers = 4
+         self.num_epochs = 50
+         self.lr = 1e-3
+         self.model_checkpoint_path = 'models/Best_LA_model_for_DF.pth'
+
+         ############################################################################
+         """
+         parser.add_argument('--algo', type=int, default=3,
+                             help='Rawboost algos discriptions. 0: No augmentation 1: LnL_convolutive_noise, 2: ISD_additive_noise, 3: SSI_additive_noise, 4: series algo (1+2+3), \
+                                   5: series algo (1+2), 6: series algo (1+3), 7: series algo(2+3), 8: parallel algo(1,2) .default=0]')
+
+         # LnL_convolutive_noise parameters
+         parser.add_argument('--nBands', type=int, default=5,
+                             help='number of notch filters.The higher the number of bands, the more aggresive the distortions is.[default=5]')
+         parser.add_argument('--minF', type=int, default=20,
+                             help='minimum centre frequency [Hz] of notch filter.[default=20] ')
+         parser.add_argument('--maxF', type=int, default=8000,
+                             help='maximum centre frequency [Hz] (<sr/2) of notch filter.[default=8000]')
+         parser.add_argument('--minBW', type=int, default=100,
+                             help='minimum width [Hz] of filter.[default=100] ')
+         parser.add_argument('--maxBW', type=int, default=1000,
+                             help='maximum width [Hz] of filter.[default=1000] ')
+         parser.add_argument('--minCoeff', type=int, default=10,
+                             help='minimum filter coefficients. More the filter coefficients more ideal the filter slope.[default=10]')
+         parser.add_argument('--maxCoeff', type=int, default=100,
+                             help='maximum filter coefficients. More the filter coefficients more ideal the filter slope.[default=100]')
+         parser.add_argument('--minG', type=int, default=0,
+                             help='minimum gain factor of linear component.[default=0]')
+         parser.add_argument('--maxG', type=int, default=0,
+                             help='maximum gain factor of linear component.[default=0]')
+         parser.add_argument('--minBiasLinNonLin', type=int, default=5,
+                             help=' minimum gain difference between linear and non-linear components.[default=5]')
+         parser.add_argument('--maxBiasLinNonLin', type=int, default=20,
+                             help=' maximum gain difference between linear and non-linear components.[default=20]')
+         parser.add_argument('--N_f', type=int, default=5,
+                             help='order of the (non-)linearity where N_f=1 refers only to linear components.[default=5]')
+
+         # ISD_additive_noise parameters
+         parser.add_argument('--P', type=int, default=10,
+                             help='Maximum number of uniformly distributed samples in [%].[defaul=10]')
+         parser.add_argument('--g_sd', type=int, default=2,
+                             help='gain parameters > 0. [default=2]')
+
+         # SSI_additive_noise parameters
+         parser.add_argument('--SNRmin', type=int, default=10,
+                             help='Minimum SNR value for coloured additive noise.[defaul=10]')
+         parser.add_argument('--SNRmax', type=int, default=40,
+                             help='Maximum SNR value for coloured additive noise.[defaul=40]')
+         """
+         ############################################################################
+         # conversion from argparse to class object
+         self.algo = 3
+         self.nBands = 5
+         self.minF = 20
+         self.maxF = 8000
+         self.minBW = 100
+         self.maxBW = 1000
+         self.minCoeff = 10
+         self.maxCoeff = 100
+         self.minG = 0
+         self.maxG = 0
+         self.minBiasLinNonLin = 5
+         self.maxBiasLinNonLin = 20
+         self.N_f = 5
+         self.P = 10
+         self.g_sd = 2
+         self.SNRmin = 10
+         self.SNRmax = 40
+
+
+         #############################################################################
+         """
+         parser.add_argument('--database_path', type=str, default='/your/path/to/data/ASVspoof_database/DF/', help='Change this to user\'s full directory address of LA database (ASVspoof2019- for training & development (used as validation), ASVspoof2021 DF for evaluation scores). We assume that all three ASVspoof 2019 LA train, LA dev and ASVspoof2021 DF eval data folders are in the same database_path directory.')
+         '''
+         % database_path/
+         %   |- DF
+         %      |- ASVspoof2021_DF_eval/flac
+         %      |- ASVspoof2019_LA_train/flac
+         %      |- ASVspoof2019_LA_dev/flac
+         '''
+
+         parser.add_argument('--protocols_path', type=str, default='database/', help='Change with path to user\'s DF database protocols directory address')
+         '''
+         % protocols_path/
+         %   |- ASVspoof_LA_cm_protocols
+         %      |- ASVspoof2021.LA.cm.eval.trl.txt
+         %      |- ASVspoof2019.LA.cm.dev.trl.txt
+         %      |- ASVspoof2019.LA.cm.train.trn.txt
+
+         %   |- ASVspoof_DF_cm_protocols
+         %      |- ASVspoof2021.DF.cm.eval.trl.txt
+
+         '''
+
+         # Hyperparameters
+         parser.add_argument('--batch_size', type=int, default=14)
+         parser.add_argument('--num_epochs', type=int, default=100)
+         parser.add_argument('--lr', type=float, default=0.000001)
+         parser.add_argument('--weight_decay', type=float, default=0.0001)
+         parser.add_argument('--loss', type=str, default='weighted_CCE')
+         # model
+         parser.add_argument('--seed', type=int, default=1234,
+                             help='random seed (default: 1234)')
+
+         parser.add_argument('--model_path', type=str,
+                             default=None, help='Model checkpoint')
+         parser.add_argument('--comment', type=str, default=None,
+                             help='Comment to describe the saved model')
+         # Auxiliary arguments
+         parser.add_argument('--track', type=str, default='DF', choices=['LA', 'PA', 'DF'], help='LA/PA/DF')
+         parser.add_argument('--eval_output', type=str, default=None,
+                             help='Path to save the evaluation result')
+         parser.add_argument('--eval', action='store_true', default=False,
+                             help='eval mode')
+         parser.add_argument('--is_eval', action='store_true', default=False, help='eval database')
+         parser.add_argument('--eval_part', type=int, default=0)
+         # backend options
+         parser.add_argument('--cudnn-deterministic-toggle', action='store_false', \
+                             default=True,
+                             help='use cudnn-deterministic? (default true)')
+
+         parser.add_argument('--cudnn-benchmark-toggle', action='store_true', \
+                             default=False,
+                             help='use cudnn-benchmark? (default false)')
+         """
+
+         self.weight_decay = 0.0001
+         self.loss = 'weighted_CCE'
+         self.seed = 1234
+         self.model_path = "models/LA_model.pth"
+         self.comment = None
+         self.track = 'DF'
+         self.eval_output = None
+         self.eval = False
+         self.is_eval = False
+         self.eval_part = 0
+         self.cudnn_deterministic_toggle = False
+         self.cudnn_benchmark_toggle = False
+
+         self.wandb_config = {
+             'project': 'Speech Assignment 3',
+             'run_name': 'LA_model',
+         }
+
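
Editor's note (not part of the commit): the "conversion from argparse to class object" comment describes collapsing the original training script's CLI flags into plain attributes. A generic, repository-independent sketch of that pattern is shown below; the flags used are just examples.

```python
# Generic argparse-to-object conversion, analogous to what Config does by hand.
import argparse
from types import SimpleNamespace

parser = argparse.ArgumentParser()
parser.add_argument("--algo", type=int, default=3)
parser.add_argument("--nBands", type=int, default=5)
parser.add_argument("--lr", type=float, default=1e-3)

args = parser.parse_args([])             # use defaults instead of reading sys.argv
config = SimpleNamespace(**vars(args))   # attribute access, like the Config class
print(config.algo, config.nBands, config.lr)
```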
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.circleci/config.yml ADDED
@@ -0,0 +1,159 @@
+ # Use 2.1 for orbs
+ version: 2.1
+
+ # -------------------------------------------------------------------------------------
+ # Environments to run the jobs in
+ # -------------------------------------------------------------------------------------
+ gpu: &gpu
+   environment:
+     CUDA_VERSION: "11.1"
+   machine:
+     image: ubuntu-1604-cuda-11.1:202012-01
+   resource_class: gpu.nvidia.medium.multi
+
+
+ # -------------------------------------------------------------------------------------
+ # Re-usable commands
+ # -------------------------------------------------------------------------------------
+ cache_key: &cache_key cache-key-{{ .Environment.CIRCLE_JOB }}-{{ checksum ".circleci/config.yml" }}-{{ checksum "setup.py"}}
+
+ install_dep_common: &install_dep_common
+   - run:
+       name: Install Common Dependencies
+       command: |
+         source activate fairseq
+         pip install --upgrade setuptools
+         pip install bitarray boto3 deepspeed editdistance fastBPE iopath ipdb ipython pyarrow pytest sacremoses sentencepiece subword-nmt hydra-core==1.0.7 omegaconf==2.0.6
+         pip install --progress-bar off pytest
+         pip install --progress-bar off fairscale
+         pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda111 -U
+         python -c 'import torch; print("Torch version:", torch.__version__)'
+         python -m torch.utils.collect_env
+
+ install_dep_fused_ops: &install_dep_fused_ops
+   - run:
+       name: Install Megatron/Apex Dependencies
+       working_directory: ~/
+       command: |
+         source activate fairseq
+         git clone https://github.com/NVIDIA/apex
+         cd apex
+         git checkout e2083df5eb96643c61613b9df48dd4eea6b07690
+         pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" --global-option="--xentropy" --global-option="--fast_multihead_attn" ./
+         cd ~/
+         git clone --depth=1 --branch v2.4 https://github.com/NVIDIA/Megatron-LM.git
+         cd Megatron-LM
+         pip install -e .
+
+
+ install_dep_pt19: &install_dep_pt19
+   - run:
+       name: Install Pytorch Dependencies
+       command: |
+         source activate fairseq
+         pip install --upgrade setuptools
+         pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
+         python -c 'import torch; print("Torch version:", torch.__version__)'
+
+ install_dep_pt18: &install_dep_pt18
+   - run:
+       name: Install Pytorch Dependencies
+       command: |
+         source activate fairseq
+         pip install --upgrade setuptools
+         pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
+         python -c 'import torch; print("Torch version:", torch.__version__)'
+
+ install_repo: &install_repo
+   - run:
+       name: Install Repository
+       command: |
+         source activate fairseq
+         pip install .
+         python setup.py build_ext --inplace
+
+ run_unittests: &run_unittests
+   - run:
+       name: Run Unit Tests
+       command: |
+         source activate fairseq
+         pytest tests/gpu/test_binaries_gpu.py
+
+ check_nvidia_driver: &check_nvidia_driver
+   - run:
+       name: Check NVIDIA Driver
+       working_directory: ~/
+       command: |
+         pyenv versions
+         nvidia-smi
+
+ create_conda_env: &create_conda_env
+   - run:
+       name: Install and Create Conda Environment
+       command: |
+         curl -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+         chmod +x ~/miniconda.sh
+         ~/miniconda.sh -b -p $HOME/miniconda
+         rm ~/miniconda.sh
+         echo 'export PATH=$HOME/miniconda/bin:$PATH' >> $BASH_ENV
+         source $BASH_ENV
+         if [ ! -d ~/miniconda/envs/fairseq ]
+         then
+           conda create -y -n fairseq python=3.8
+         fi
+         source activate fairseq
+         python --version
+         pip install --upgrade pip
+ # -------------------------------------------------------------------------------------
+ # Jobs to run
+ # -------------------------------------------------------------------------------------
+
+ jobs:
+   gpu_tests_pt19:
+     <<: *gpu
+
+     working_directory: ~/fairseq-py
+
+     steps:
+       - checkout
+       - <<: *check_nvidia_driver
+       - <<: *create_conda_env
+       - restore_cache:
+           key: *cache_key
+       - <<: *install_dep_pt19
+       - <<: *install_dep_common
+       - <<: *install_dep_fused_ops
+       - save_cache:
+           paths:
+             - ~/miniconda/
+           key: *cache_key
+       - <<: *install_repo
+       - <<: *run_unittests
+
+   gpu_tests_pt18:
+     <<: *gpu
+
+     working_directory: ~/fairseq-py
+
+     steps:
+       - checkout
+       - <<: *check_nvidia_driver
+       - <<: *create_conda_env
+       - restore_cache:
+           key: *cache_key
+       - <<: *install_dep_pt18
+       - <<: *install_dep_common
+       - <<: *install_dep_fused_ops
+       - save_cache:
+           paths:
+             - ~/miniconda/
+           key: *cache_key
+       - <<: *install_repo
+       - <<: *run_unittests
+
+ workflows:
+   version: 2
+   build:
+     jobs:
+       - gpu_tests_pt18
+       - gpu_tests_pt19
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,3 @@
+ ## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
+
+ Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,43 @@
+ ---
+ name: 🐛 Bug Report
+ about: Submit a bug report to help us improve
+ labels: 'bug, needs triage'
+ ---
+
+ ## 🐛 Bug
+
+ <!-- A clear and concise description of what the bug is. -->
+
+ ### To Reproduce
+
+ Steps to reproduce the behavior (**always include the command you ran**):
+
+ 1. Run cmd '....'
+ 2. See error
+
+ <!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
+
+
+ #### Code sample
+ <!-- Ideally attach a minimal code sample to reproduce the described issue.
+ Minimal means having the shortest code but still preserving the bug. -->
+
+ ### Expected behavior
+
+ <!-- A clear and concise description of what you expected to happen. -->
+
+ ### Environment
+
+  - fairseq Version (e.g., 1.0 or main):
+  - PyTorch Version (e.g., 1.0)
+  - OS (e.g., Linux):
+  - How you installed fairseq (`pip`, source):
+  - Build command you used (if compiling from source):
+  - Python version:
+  - CUDA/cuDNN version:
+  - GPU models and configuration:
+  - Any other relevant information:
+
+ ### Additional context
+
+ <!-- Add any other context about the problem here. -->
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/documentation.md ADDED
@@ -0,0 +1,15 @@
+ ---
+ name: 📚 Documentation/Typos
+ about: Report an issue related to documentation or a typo
+ labels: 'documentation, needs triage'
+ ---
+
+ ## 📚 Documentation
+
+ For typos and doc fixes, please go ahead and:
+
+ 1. Create an issue.
+ 2. Fix the typo.
+ 3. Submit a PR.
+
+ Thanks!
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,24 @@
+ ---
+ name: 🚀 Feature Request
+ about: Submit a proposal/request for a new feature
+ labels: 'enhancement, help wanted, needs triage'
+ ---
+
+ ## 🚀 Feature Request
+ <!-- A clear and concise description of the feature proposal -->
+
+ ### Motivation
+
+ <!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
+
+ ### Pitch
+
+ <!-- A clear and concise description of what you want to happen. -->
+
+ ### Alternatives
+
+ <!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
+
+ ### Additional context
+
+ <!-- Add any other context or screenshots about the feature request here. -->
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/ISSUE_TEMPLATE/how-to-question.md ADDED
@@ -0,0 +1,33 @@
+ ---
+ name: ❓ Questions/Help
+ about: If you have questions, please first search existing issues and docs
+ labels: 'question, needs triage'
+ ---
+
+ ## ❓ Questions and Help
+
+ ### Before asking:
+ 1. search the issues.
+ 2. search the docs.
+
+ <!-- If you still can't find what you need: -->
+
+ #### What is your question?
+
+ #### Code
+
+ <!-- Please paste a code snippet if your question requires it! -->
+
+ #### What have you tried?
+
+ #### What's your environment?
+
+  - fairseq Version (e.g., 1.0 or main):
+  - PyTorch Version (e.g., 1.0)
+  - OS (e.g., Linux):
+  - How you installed fairseq (`pip`, source):
+  - Build command you used (if compiling from source):
+  - Python version:
+  - CUDA/cuDNN version:
+  - GPU models and configuration:
+  - Any other relevant information:
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,16 @@
+ # Before submitting
+
+ - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
+ - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
+ - [ ] Did you make sure to update the docs?
+ - [ ] Did you write any new necessary tests?
+
+ ## What does this PR do?
+ Fixes # (issue).
+
+ ## PR review
+ Anyone in the community is free to review the PR once the tests have passed.
+ If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
+
+ ## Did you have fun?
+ Make sure you had fun coding 🙃
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/stale.yml ADDED
@@ -0,0 +1,30 @@
+ # Configuration for probot-stale - https://github.com/probot/stale
+ # Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
+ # Number of days of inactivity before an issue becomes stale
+ daysUntilStale: 90
+ # Number of days of inactivity before a stale issue is closed
+ daysUntilClose: 7
+ # Issues with these labels will never be considered stale
+ exemptLabels:
+   - bug
+ # Label to use when marking an issue as stale
+ staleLabel: stale
+ issues:
+   # Comment to post when marking an issue as stale.
+   markComment: >
+     This issue has been automatically marked as stale.
+     **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
+     We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
+   # Comment to post when closing a stale issue.
+   closeComment: >
+     Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
+ pulls:
+   # Comment to post when marking a pull request as stale.
+   markComment: >
+     This pull request has been automatically marked as stale.
+     **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
+     We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
+   # Comment to post when closing a stale pull request.
+   closeComment: >
+     Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
+
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/workflows/build.yml ADDED
@@ -0,0 +1,60 @@
+ name: build
+
+ on:
+   # Trigger the workflow on push to main or any pull request
+   push:
+     branches:
+       - main
+   pull_request:
+
+ jobs:
+   build:
+
+     strategy:
+       max-parallel: 4
+       matrix:
+         platform: [ubuntu-latest, macos-latest]
+         python-version: [3.8, 3.9]
+
+     runs-on: ${{ matrix.platform }}
+
+     steps:
+       - uses: actions/checkout@v2
+
+       - name: Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v2
+         with:
+           python-version: ${{ matrix.python-version }}
+
+       - name: Conditionally install pytorch
+         if: matrix.platform == 'windows-latest'
+         run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
+
+       - name: Install locally
+         run: |
+           python -m pip install --upgrade pip
+           git submodule update --init --recursive
+           python setup.py build_ext --inplace
+           python -m pip install --editable .
+
+       - name: Install optional test requirements
+         run: |
+           python -m pip install iopath transformers pyarrow
+           python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
+
+       - name: Lint with flake8
+         run: |
+           pip install flake8
+           # stop the build if there are Python syntax errors or undefined names
+           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
+           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+           flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
+
+       - name: Run tests
+         run: |
+           python setup.py test
+
+       - name: Lint with black
+         run: |
+           pip install black
+           black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron'
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.github/workflows/build_wheels.yml ADDED
@@ -0,0 +1,41 @@
+ name: build_wheels
+
+ on:
+   push:
+     branches:
+       - v[0-9]+.[0-9]+.[x0-9]+
+     tags:
+       - v*
+
+ jobs:
+   build_wheels:
+     name: Build wheels on ${{ matrix.os }}
+     runs-on: ${{ matrix.os }}
+     strategy:
+       matrix:
+         os: [ubuntu-latest, macos-latest]
+
+     steps:
+       - uses: actions/checkout@v2
+
+       - name: Install Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.7'
+
+       - name: Install cibuildwheel
+         run: |
+           python -m pip install cibuildwheel
+
+       - name: Build wheels for CPython
+         run: |
+           python -m cibuildwheel --output-dir dist
+         env:
+           CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
+           CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
+           CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
+
+       - uses: actions/upload-artifact@v2
+         with:
+           name: wheels
+           path: ./dist/*.whl
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.gitignore ADDED
@@ -0,0 +1,136 @@
+ # JetBrains PyCharm IDE
+ .idea/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # macOS dir files
+ .DS_Store
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Checkpoints
+ checkpoints
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # dotenv
+ .env
+
+ # virtualenv
+ .venv
+ venv/
+ ENV/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ # Generated files
+ /fairseq/temporal_convolution_tbc
+ /fairseq/modules/*_layer/*_forward.cu
+ /fairseq/modules/*_layer/*_backward.cu
+ /fairseq/version.py
+
+ # data
+ data-bin/
+
+ # reranking
+ /examples/reranking/rerank_data
+
+ # Cython-generated C++ source files
+ /fairseq/data/data_utils_fast.cpp
+ /fairseq/data/token_block_utils_fast.cpp
+
+ # VSCODE
+ .vscode/ftp-sync.json
+ .vscode/settings.json
+
+ # Experimental Folder
+ experimental/*
+
+ # Weights and Biases logs
+ wandb/
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.gitmodules ADDED
@@ -0,0 +1,4 @@
+ [submodule "fairseq/model_parallel/megatron"]
+   path = fairseq/model_parallel/megatron
+   url = https://github.com/ngoyal2707/Megatron-LM
+   branch = fairseq
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.isort.cfg ADDED
@@ -0,0 +1,2 @@
+ [settings]
+ known_third_party = _cffi_backend,agg_results,aml,bitarray,boto3,botocore,dump_hubert_feature,dynamicconv_cuda,editdistance,faiss,fasttext,feature_utils,ffmpeg,g2p_en,h5py,hydra,hypothesis,indicnlp,inflect,iopath,joblib,kaldi_io,kenlm,libfb,librosa,lightconv_cuda,matplotlib,misc,mmpt,mmpt_cli,model,nltk,npy_append_array,numpy,omegaconf,pandas,pathbuilder,preprocessing,progressbar,pythainlp,random_sequence_shuffler,regex,sacrebleu,sacremoses,scipy,sentencepiece,setuptools,six,sklearn,soundfile,sweep,sweep_wmt_en2de_transformer_big_common,tabulate,torch,torchaudio,tqdm,unidecode,utils,videoreader,wav2vec_cluster_faiss,wget,yaml
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/.pre-commit-config.yaml ADDED
@@ -0,0 +1,40 @@
+ exclude: 'build|stubs'
+
+ default_language_version:
+   python: python3
+
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.0.1
+     hooks:
+       - id: trailing-whitespace
+       - id: check-ast
+       - id: check-merge-conflict
+       - id: no-commit-to-branch
+         args: ['--branch=master']
+       - id: check-added-large-files
+         args: ['--maxkb=500']
+       - id: end-of-file-fixer
+
+   - repo: https://github.com/ambv/black
+     rev: 21.12b0
+     hooks:
+       - id: black
+         language_version: python3.8
+
+   - repo: https://gitlab.com/pycqa/flake8
+     rev: 3.9.2
+     hooks:
+       - id: flake8
+         args: [
+           # only error for syntax errors and undefined names
+           "--select=E9,F63,F7,F82",
+         ]
+
+   - repo: https://github.com/pycqa/isort
+     rev: 5.10.1
+     hooks:
+       - id: isort
+         exclude: README.md
+         additional_dependencies: [toml]
+         args: ["--profile", "black"]
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,77 @@
+ # Code of Conduct
+
+ ## Our Pledge
+
+ In the interest of fostering an open and welcoming environment, we as
+ contributors and maintainers pledge to make participation in our project and
+ our community a harassment-free experience for everyone, regardless of age, body
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
+ level of experience, education, socio-economic status, nationality, personal
+ appearance, race, religion, or sexual identity and orientation.
+
+ ## Our Standards
+
+ Examples of behavior that contributes to creating a positive environment
+ include:
+
+ * Using welcoming and inclusive language
+ * Being respectful of differing viewpoints and experiences
+ * Gracefully accepting constructive criticism
+ * Focusing on what is best for the community
+ * Showing empathy towards other community members
+
+ Examples of unacceptable behavior by participants include:
+
+ * The use of sexualized language or imagery and unwelcome sexual attention or
+   advances
+ * Trolling, insulting/derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others' private information, such as a physical or electronic
+   address, without explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+   professional setting
+
+ ## Our Responsibilities
+
+ Project maintainers are responsible for clarifying the standards of acceptable
+ behavior and are expected to take appropriate and fair corrective action in
+ response to any instances of unacceptable behavior.
+
+ Project maintainers have the right and responsibility to remove, edit, or
+ reject comments, commits, code, wiki edits, issues, and other contributions
+ that are not aligned to this Code of Conduct, or to ban temporarily or
+ permanently any contributor for other behaviors that they deem inappropriate,
+ threatening, offensive, or harmful.
+
+ ## Scope
+
+ This Code of Conduct applies within all project spaces, and it also applies when
+ an individual is representing the project or its community in public spaces.
+ Examples of representing a project or community include using an official
+ project e-mail address, posting via an official social media account, or acting
+ as an appointed representative at an online or offline event. Representation of
+ a project may be further defined and clarified by project maintainers.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported by contacting the project team at <conduct@pytorch.org>. All
+ complaints will be reviewed and investigated and will result in a response that
+ is deemed necessary and appropriate to the circumstances. The project team is
+ obligated to maintain confidentiality with regard to the reporter of an incident.
+ Further details of specific enforcement policies may be posted separately.
+
+ Project maintainers who do not follow or enforce the Code of Conduct in good
+ faith may face temporary or permanent repercussions as determined by other
+ members of the project's leadership.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+ [homepage]: https://www.contributor-covenant.org
+
+ For answers to common questions about this code of conduct, see
+ https://www.contributor-covenant.org/faq
+
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/CONTRIBUTING.md ADDED
@@ -0,0 +1,82 @@
+ # Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
+ We want to make contributing to this project as easy and transparent as
+ possible.
+
+ ## Pull Requests
+ We actively welcome your pull requests.
+
+ 1. Fork the repo and create your branch from `main`.
+ 2. If you've added code that should be tested, add tests.
+ 3. If you've changed APIs, update the documentation.
+ 4. Ensure the test suite passes.
+ 5. Make sure your code lints.
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+ ## Contributor License Agreement ("CLA")
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Facebook's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ ## Issues
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+ ## License
+ By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
+ you agree that your contributions will be licensed under the LICENSE file in
+ the root directory of this source tree.
+
+ ## Pre-commit hooks
+ In order to ensure your code lints, there are pre-commit hooks configured in the repository which you can install.
+ After installation, they will automatically run each time you commit.
+ An abbreviated guide is given below; for more information, refer to [the official pre-commit documentation](https://pre-commit.com/).
+
+ ### Installation
+ ```
+ pip install pre-commit
+ pre-commit install
+ ```
+
+ ### Usage
+ Just commit your changes:
+ ```
+ git commit -m "My informative commit message"
+ ```
+
+ If there was a failure, you will get feedback
+ ```
+ [INFO] Initializing environment for https://github.com/PyCQA/flake8.
+ [INFO] Installing environment for https://github.com/pre-commit/pre-commit-hooks.
+ [INFO] Once installed this environment will be reused.
+ [INFO] This may take a few minutes...
+ [INFO] Installing environment for https://github.com/PyCQA/flake8.
+ [INFO] Once installed this environment will be reused.
+ [INFO] This may take a few minutes...
+ Trim Trailing Whitespace.................................................Failed
+ - hook id: trailing-whitespace
+ - exit code: 1
+ - files were modified by this hook
+ Fixing examples/nllb/modeling/wmt15_benchmark/eval_langs2.sh
+ Fix End of Files.........................................................Failed
+ - hook id: end-of-file-fixer
+ - exit code: 1
+ - files were modified by this hook
+ Fixing examples/few_shot/scripts/schedule_jobs_few_shot.py
+ flake8...................................................................Passed
+ ```
+
+ Certain hooks modify your files to comply.
+ To include these modifications, you will need to add them (i.e. `git add ...`) and commit again.
+
+ If all is well, you should see something like:
+ ```
+ Trim Trailing Whitespace.................................................Passed
+ Fix End of Files.........................................................Passed
+ flake8...................................................................Passed
+ [gshard-fix-ci 8698644e1] Fix lint, add pre-commit hooks
+  10 files changed, 148 insertions(+), 110 deletions(-)
+  create mode 100644 .flake8
+  create mode 100644 .pre-commit-config.yaml
+  rename examples/nllb/modeling/wmt15_benchmark/{eval_langs2.py => eval_langs2.sh} (99%)
+ ```
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) Facebook, Inc. and its affiliates.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/README.md ADDED
@@ -0,0 +1,236 @@
+ <p align="center">
+   <img src="docs/fairseq_logo.png" width="150">
+   <br />
+   <br />
+   <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
+   <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
+   <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
+   <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
+ </p>
+
+ --------------------------------------------------------------------------------
+
+ Fairseq(-py) is a sequence modeling toolkit that allows researchers and
+ developers to train custom models for translation, summarization, language
+ modeling and other text generation tasks.
+
+ We provide reference implementations of various sequence modeling papers:
+
+ <details><summary>List of implemented papers</summary><p>
+
+ * **Convolutional Neural Networks (CNN)**
+   + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
+   + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+   + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+   + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+   + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+ * **LightConv and DynamicConv models**
+   + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+ * **Long Short-Term Memory (LSTM) networks**
+   + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
+ * **Transformer (self-attention) networks**
+   + Attention Is All You Need (Vaswani et al., 2017)
+   + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+   + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+   + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
+   + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
+   + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
+   + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
+   + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+   + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+   + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+   + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
+   + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+   + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+   + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+   + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+   + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
+   + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
+   + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+   + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
+   + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
+   + [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430)
+   + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
+   + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
+   + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680)
+   + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et al., 2021)](https://arxiv.org/pdf/2109.14084.pdf)
+   + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf)
+   + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et al., 2021)](examples/normformer/README.md)
+ * **Non-autoregressive Transformers**
+   + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+   + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
+   + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
+   + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+   + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+ * **Finetuning**
+   + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
+
+ </p></details>
+
+ ### What's New:
+ * October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md)
+ * October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md)
+ * September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
+ * July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
+ * July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
+ * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
+ * May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
+ * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
+ * February 2021 [Added LASER training code](examples/laser/README.md)
+ * December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
+ * December 2020: [GottBERT model and code released](examples/gottbert/README.md)
+ * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
+   * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
+ * November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
+ * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
+ * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
+ * October 2020: [Added CRISS models and code](examples/criss/README.md)
+
+ <details><summary>Previous updates</summary><p>
+
+ * September 2020: [Added Linformer code](examples/linformer/README.md)
+ * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
+ * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
+ * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
+ * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
+ * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
+ * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
+ * April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
+ * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
+ * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
+ * February 2020: [mBART model and code released](examples/mbart/README.md)
+ * February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
+ * December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
+ * November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
+ * November 2019: [CamemBERT model and code released](examples/camembert/README.md)
+ * November 2019: [BART model and code released](examples/bart/README.md)
+ * November 2019: [XLM-R models and code released](examples/xlmr/README.md)
+ * September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
+ * August 2019: [WMT'19 models released](examples/wmt19/README.md)
+ * July 2019: fairseq relicensed under MIT license
111
+ * July 2019: [RoBERTa models and code released](examples/roberta/README.md)
112
+ * June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
113
+
114
+ </p></details>
115
+
116
+ ### Features:
117
+
118
+ * multi-GPU training on one machine or across multiple machines (data and model parallel)
119
+ * fast generation on both CPU and GPU with multiple search algorithms implemented:
120
+ + beam search
121
+ + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
122
+ + sampling (unconstrained, top-k and top-p/nucleus)
123
+ + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
124
+ * [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
125
+ * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
126
+ * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
127
+ * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
128
+ * [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
129
+ * [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
130
+
131
+ We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
132
+ with a convenient `torch.hub` interface:
133
+
134
+ ``` python
135
+ en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
136
+ en2de.translate('Hello world', beam=5)
137
+ # 'Hallo Welt'
138
+ ```
139
+
140
+ See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
141
+ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
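+
+ Loading RoBERTa through the same interface looks roughly like this (a minimal sketch; see the RoBERTa README and the hub tutorial above for the full API):
+
+ ``` python
+ import torch
+
+ roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
+ roberta.eval()  # disable dropout
+ tokens = roberta.encode('Hello world!')       # BPE-encode into a tensor of token ids
+ features = roberta.extract_features(tokens)   # shape: (1, num_tokens, hidden_dim)
+ ```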
142
+
143
+ # Requirements and Installation
144
+
145
+ * [PyTorch](http://pytorch.org/) version >= 1.5.0
146
+ * Python version >= 3.6
147
+ * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
148
+ * **To install fairseq** and develop locally:
149
+
150
+ ``` bash
151
+ git clone https://github.com/pytorch/fairseq
152
+ cd fairseq
153
+ pip install --editable ./
154
+
155
+ # on MacOS:
156
+ # CFLAGS="-stdlib=libc++" pip install --editable ./
157
+
158
+ # to install the latest stable release (0.10.x)
159
+ # pip install fairseq
160
+ ```
161
+
162
+ * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
163
+
164
+ ``` bash
165
+ git clone https://github.com/NVIDIA/apex
166
+ cd apex
167
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
168
+ --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
169
+ --global-option="--fast_multihead_attn" ./
170
+ ```
171
+
172
+ * **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
173
+ * If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
174
+ as command line options to `nvidia-docker run`.
175
+
176
+ # Getting Started
177
+
178
+ The [full documentation](https://fairseq.readthedocs.io/) contains instructions
179
+ for getting started, training new models and extending fairseq with new model
180
+ types and tasks.
181
+
182
+ # Pre-trained models and examples
183
+
184
+ We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
185
+ as well as example training and evaluation commands.
186
+
187
+ * [Translation](examples/translation/README.md): convolutional and transformer models are available
188
+ * [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
189
+
190
+ We also have more detailed READMEs to reproduce results from specific papers:
191
+
192
+ * [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md)
193
+ * [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
194
+ * [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
195
+ * [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
196
+ * [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
197
+ * [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
198
+ * [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
199
+ * [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
200
+ * [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
201
+ * [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
202
+ * [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
203
+ * [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
204
+ * [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
205
+ * [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
206
+ * [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
207
+ * [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
208
+ * [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
209
+ * [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
210
+ * [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
211
+ * [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
212
+ * [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
213
+
214
+ # Join the fairseq community
215
+
216
+ * Twitter: https://twitter.com/fairseq
217
+ * Facebook page: https://www.facebook.com/groups/fairseq.users
218
+ * Google group: https://groups.google.com/forum/#!forum/fairseq-users
219
+
220
+ # License
221
+
222
+ fairseq(-py) is MIT-licensed.
223
+ The license applies to the pre-trained models as well.
224
+
225
+ # Citation
226
+
227
+ Please cite as:
228
+
229
+ ``` bibtex
230
+ @inproceedings{ott2019fairseq,
231
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
232
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
233
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
234
+ year = {2019},
235
+ }
236
+ ```
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ !*/*.sh
2
+ !*/*.md
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/.gitignore ADDED
@@ -0,0 +1,139 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+ runs
131
+ data
132
+ pretrained_models
133
+ projects/mmfusion_*
134
+ log_test
135
+ third-party
136
+ python_log
137
+ slurm_snapshot_code
138
+ lightning_logs
139
+ demos
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/CONFIG.md ADDED
@@ -0,0 +1,41 @@
1
+ ### Config Files Explained
2
+
3
+ Taking `projects/mfmmlm.yaml` for example, which run pretraining using masked frame model (MFM) and masked language model (MLM) on a single BERT:
4
+
5
+ ```yaml
6
+ project_dir: mfmmlm # specify the project dir for this baseline.
7
+ run_task:
8
+ - how2.yaml # run pretraining on how2 when launching `projects/taskmfmmlm.yaml`
9
+ - [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml] # run fine-tuning tasks.
10
+ base_dir: task # a global template folder to specify each training task.
11
+ task_group:
12
+ pretrain: # section for pretraining. Most baselines differs in this section.
13
+ task_list:
14
+ - how2.yaml # reconfig `projects/task/how2.yaml`
15
+ dataset:
16
+ aligner: MFMMLMAligner # overwrite the aligner for MFMMLM training task.
17
+ model:
18
+ model_cls: MMFusionMFMMLM # overwrite the model, which constructs negative examples for MFM on-the-fly.
19
+ loss:
20
+ loss_cls: MFMMLM # overwrite the loss as MFMMLM, which combines MFM and MLM together.
21
+ fairseq: # all fairseq args can be expecified under this name.
22
+ dataset:
23
+ batch_size: 128
24
+ finetune: # section for fine-tuning tasks, we don't need to change anything here mostly since we want to see how pretraining can contribute to finetuning.
25
+ task_list: # specify the list of downstream tasks, e.g., copy `projects/task/vtt.yaml` to `projects/mfmmlm`.
26
+ - vtt.yaml
27
+ - vttqa.yaml
28
+ - youcook.yaml
29
+ - youcookcap.yaml
30
+ - crosstask.yaml
31
+ - coin.yaml
32
+ test: # section for testing.
33
+ task_list:
34
+ - test_vtt.yaml
35
+ - test_vttqa.yaml
36
+ - test_youcook.yaml
37
+ - test_youcookcap.yaml
38
+ - test_crosstask.yaml
39
+ - test_crosstask_zs.yaml
40
+ - test_coin.yaml
41
+ ```
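+
+ At launch time, each `task_group` section above is merged on top of the corresponding base config under `base_dir` and saved under `projects/mfmmlm/`; this is what `locallaunch.py` does via `recursive_config` and `OmegaConf.merge`. A rough sketch of that merge for the pretraining section (using plain `OmegaConf.load` in place of the toolkit's `recursive_config`, which also resolves includes):
+
+ ```python
+ from omegaconf import OmegaConf
+
+ job = OmegaConf.load("projects/mfmmlm.yaml")
+ overwrite = job.task_group.pretrain
+ overwrite.pop("task_list", None)                  # task_list only selects which base yamls to copy
+ base = OmegaConf.load("projects/task/how2.yaml")  # template from base_dir
+ merged = OmegaConf.merge(base, overwrite)         # aligner/model/loss/fairseq args get overwritten
+ OmegaConf.save(config=merged, f="projects/mfmmlm/how2.yaml")
+ ```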
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/DATASET.md ADDED
@@ -0,0 +1,34 @@
1
+ # Dataset
2
+
3
+ We understand video data are challenging to download and process. For videos, we provide our preprocessing scripts under `scripts/video_feature_extractor` (deeply adapted from `https://github.com/antoine77340/video_feature_extractor`); for text, we provide pre-tokenization scripts under `scripts/text_token_extractor`.
4
+
5
+ ### S3D Feature Extraction
6
+ We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
7
+
8
+ We implement a `PathBuilder` to automatically map video ids and source video paths to their feature locations (you may need `conda install -c anaconda pandas`). Decoding may need `pip install ffmpeg-python`.
9
+
10
+ ### Howto100M
11
+ [Howto100M](https://www.di.ens.fr/willow/research/howto100m/) is a large-scale video pre-training dataset. You may download the videos yourself and run our preprocessing scripts.
12
+
13
+ Our preprocessing differs from existing papers in several key ways: (1) we use `raw_caption.json` instead of `caption.json` to have pure self-supervision on text (`caption.json` has stop words manually removed); (2) we remove partially duplicated texts that were originally designed for real-time readability (see `mmpt/processors/dedupprocessor.py`); (3) we then shard video/text features using `ShardedTensor` in `mmpt/utils/shardedtensor.py` for fast loading during training (faster than `h5py`).
14
+
15
+ #### Steps
16
+ ##### video
17
+ To extract video features, edit and run `bash scripts/video_feature_extractor/how2/s3d.sh` (consider running this on multiple machines; by default, we store features in fp16 to save space and speed up training).
18
+
19
+ Split available video ids as `data/how2/how2_s3d_train.lst` and `data/how2/how2_s3d_val.lst`.
20
+
21
+ Lastly, pack video features into `ShardedTensor` using `python scripts/video_feature_extractor/shard_feature.py`.
22
+
23
+ ##### text
24
+ Clean captions using `python -m mmpt.processors.dedupprocessor`.
25
+
26
+ Tokenize dedupped captions `data/how2/raw_caption_dedup.pkl` into sharded numpy arrays:
27
+ ```
28
+ python scripts/text_token_extractor/pretokenization.py scripts/text_token_extractor/configs/bert-base-uncased.yaml
29
+ ```
30
+
31
+ ### Youcook, MSRVTT etc.
32
+ We use the versions of Youcook and MSRVTT that come with Howto100M and MILNCE. Please download the data to `data/youcook` and `data/msrvtt` accordingly; see `projects/task/youcook.yaml`, `projects/task/vtt.yaml`, etc. for details.
33
+ We extract features for Youcook and MSRVTT similarly to the first step for Howto100M, but we read text directly from the metadata and perform on-the-fly tokenization.
34
+
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/README.md ADDED
@@ -0,0 +1,166 @@
1
+ # VideoCLIP and VLM
2
+
3
+ You have just found the toolkit for multimodal video understanding! It contains implementations of two recent multi-modal video understanding papers, [VideoCLIP](https://arxiv.org/pdf/2109.14084.pdf) (EMNLP, 2021) and [VLM](https://aclanthology.org/2021.findings-acl.370.pdf) (ACL Findings, 2021), along with high-performance toolkits that are typically lacking in existing codebases. The toolkit is designed around generic, performance-tuned components that can potentially be adapted to other frameworks (we initially use fairseq).
4
+
5
+ VideoCLIP is a contrastive learning model for zero-shot transfer to retrieval/classification/sequence labeling style tasks.
6
+
7
+ <img src="videoclip.png" width="350" class="center">
8
+
9
+ VLM is a masked language model style pre-training using only one encoder with masked modality model (MMM) for retrieval/generation/sequence labeling style tasks.
10
+
11
+ <img src="vlm.png" width="350" class="center">
12
+
13
+ ### News
14
+ [Oct. 2021] Initial release of implementation for the following papers:
15
+ [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding](https://arxiv.org/pdf/2109.14084.pdf) (Xu et al., EMNLP 2021)
16
+ [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding](https://aclanthology.org/2021.findings-acl.370.pdf) (Xu et al., ACL Findings 2021)
17
+
18
+
19
+ ### Installation
20
+ We aim to minimize the dependency of this repo on other packages.
21
+ We use fairseq as the main trainer (there is no models/datasets dependency on fairseq; we will support other trainers in the future):
22
+ ```
23
+ git clone https://github.com/pytorch/fairseq
24
+ cd fairseq
25
+ pip install -e . # also optionally follow fairseq README for apex installation for fp16 training.
26
+ export MKL_THREADING_LAYER=GNU # fairseq may need this for numpy.
27
+ ```
28
+
29
+ Then install this toolkit:
30
+ ```
31
+ cd examples/MMPT # MMPT can be in any folder, not necessarily under fairseq/examples.
32
+ pip install -e .
33
+ ```
34
+
35
+ The code was developed under Python=3.8.8, PyTorch=1.8, cuda=11.0 with fairseq=1.0.0a0+af0389f, and tested under Python=3.8.8, PyTorch=1.9, cuda=11.0, fairseq=1.0.0a0+8e7bc73 at the time of code release.
36
+ Most models require `transformers==3.4` for API compatibility (`pip install transformers==3.4`).
37
+ In addition, some downstream tasks may need `conda install pandas`.
38
+
39
+
40
+ ### Usage
41
+ #### Download Checkpoints
42
+ We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
43
+
44
+ Download VideoCLIP checkpoint `https://dl.fbaipublicfiles.com/MMPT/retri/videoclip/checkpoint_best.pt` to `runs/retri/videoclip` or VLM checkpoint `https://dl.fbaipublicfiles.com/MMPT/mtm/vlm/checkpoint_best.pt` to `runs/mtm/vlm`.
45
+
46
+ #### Demo of Inference
47
+ Run `python locallaunch.py projects/retri/videoclip.yaml --dryrun` to generate all the `.yaml` configs for VideoCLIP.
48
+
49
+ ```python
50
+ import torch
51
+
52
+ from mmpt.models import MMPTModel
53
+
54
+
55
+ model, tokenizer, aligner = MMPTModel.from_pretrained(
56
+ "projects/retri/videoclip/how2.yaml")
57
+
58
+ model.eval()
59
+
60
+
61
+ # B, T, FPS, H, W, C (VideoCLIP is trained on 30 fps of s3d)
62
+ video_frames = torch.randn(1, 2, 30, 224, 224, 3)
63
+ caps, cmasks = aligner._build_text_seq(
64
+ tokenizer("some text", add_special_tokens=False)["input_ids"]
65
+ )
66
+
67
+ caps, cmasks = caps[None, :], cmasks[None, :] # bsz=1
68
+
69
+ with torch.no_grad():
70
+ output = model(video_frames, caps, cmasks, return_score=True)
71
+ print(output["score"]) # dot-product
72
+ ```
73
+
74
+ #### Data Preparation
75
+ See [dataset](DATASET.md) for each dataset.
76
+
77
+ #### Global Config for Training Pipeline
78
+ We organize a global config file for each training/testing pipeline under `projects` (see a detailed [explanation](CONFIG.md)). For example, VideoCLIP is in `projects/retri/videoclip.yaml` and VLM is in `projects/mtm/vlm.yaml`.
79
+
80
+ We wrap all commands in `locallaunch.py` and `mmpt_cli/localjob.py`. You can inspect the concrete commands with `--dryrun` and then drop the flag for an actual run.
81
+
82
+ First, running `python locallaunch.py projects/retri/videoclip.yaml --dryrun` generates the configs for pre-training, zero-shot evaluation, fine-tuning and testing of VideoCLIP under `projects/retri/videoclip`.
83
+
84
+ Each training or evaluation process is then configured by a concrete config file (we save all complex arguments, including fairseq args, into the concrete config file for reproducibility). For example, to run zero-shot evaluation on Youcook:
85
+ ```
86
+ python locallaunch.py projects/retri/videoclip/test_youcook_zs.yaml --jobtype local_predict # zero-shot evaluation.
87
+ python locallaunch.py projects/retri/videoclip/youcook_videoclip.yaml --jobtype local_single --dryrun # fine-tuning: use --dryrun to check cmds and drop it to make an actual run; local_small will run on two gpus (as in paper).
88
+ python locallaunch.py projects/retri/videoclip/test_youcook_videoclip.yaml --jobtype local_predict # testing on fine-tuned model.
89
+ ```
90
+
91
+ Pretraining can be run as:
92
+ ```
93
+ python locallaunch.py projects/retri/videoclip/how2.yaml --jobtype local_single --dryrun # check the cmds, then drop --dryrun; the paper was run on local_big with 8 gpus.
94
+ ```
95
+ You may need to change `--jobtype`; check/extend `LocalJob` in `mmpt_cli/localjob.py` for multi-GPU/multi-node pre-training.
96
+
97
+ Detailed instructions for pretraining and fine-tuning can be found in the [pretraining instructions](pretraining.md) and [finetuning instructions](endtask.md).
98
+
99
+
100
+ ### Development
101
+ Several components of this toolkit can be re-used for future research (and also our ongoing research).
102
+
103
+ #### Framework Wrapper
104
+ We currently only support fairseq, but most components can easily be adapted to other frameworks like HuggingFace. This repo is a fairseq `--user-dir` with fairseq wrappers. For example, `mmpt/tasks` includes a `FairseqMMTask`, which manages `mmpt/datasets` with `FairseqDataset`, `mmpt/models` with `FairseqModel`, and `mmpt/losses` with `FairseqCriterion`.
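+
+ As a small illustration of the wrapping (assuming `mmdataset` is an `mmpt.datasets.MMDataset` built from your processors), the fairseq trainer only ever sees a regular `FairseqDataset`:
+
+ ```python
+ from mmpt.datasets import FairseqMMDataset
+
+ fairseq_dataset = FairseqMMDataset(mmdataset)  # wrap an MMDataset for the fairseq trainer
+ fairseq_dataset.set_epoch(0)                   # the epoch is mixed into the per-example numpy seed
+ batch = fairseq_dataset.collater([fairseq_dataset[0], fairseq_dataset[1]])
+ ```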
105
+
106
+ #### Processors
107
+ **Multi**modal research introduces complexity in modality alignment, from the different input sources all the way to the losses. Inspired by [MMF](https://github.com/facebookresearch/mmf), this toolkit leverages `mmpt/processors` to handle the various needs of data preprocessing and loading, **alleviating** the need for multiple `torch.utils.data.Dataset` classes (which can be tricky for ablation studies).
108
+ Processors can also be decoupled from `torch.utils.data.Dataset` for offline preprocessing instead of on-the-fly data preprocessing.
109
+
110
+ We decompose `mmpt.MMDataset` into four types of processors: `MetaProcessor`, `VideoProcessor`, `TextProcessor` and `Aligner`. They can be configured in the `dataset` field of a config file (e.g., see `projects/task/how2.yaml`).
111
+ `MetaProcessor` loads the metadata of a dataset, e.g., all video_ids of the how2 dataset.
112
+ `VideoProcessor` loads the video features of a dataset, e.g., S3D features for each second of a video.
113
+ `TextProcessor` loads the text (features), e.g., BERT pre-tokenized text clips for the how2 dataset (with `start`/`end` timestamps and `cap` for `token_ids`).
114
+ `Aligner` is the core class, specific to each baseline, that prepares one training example, e.g., sampling a clip and masking tokens for MLM.
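+
+ A self-contained sketch of how these four processors compose into one `MMDataset` (the toy processors below are illustrative stand-ins, not the real how2 processors):
+
+ ```python
+ import torch
+ from mmpt.datasets import MMDataset
+
+ class ToyMetaProcessor:
+     split = "train"
+     def __init__(self):
+         self.ids = [("vid0", "vid0"), ("vid1", "vid1")]
+     def __getitem__(self, idx):
+         return self.ids[idx]  # (video_id, text_id)
+     def __len__(self):
+         return len(self.ids)
+
+ class ToyVideoProcessor:
+     def __call__(self, video_id):
+         return torch.randn(8, 512)  # e.g., 8 seconds of video features
+
+ class ToyTextProcessor:
+     def __call__(self, text_id):
+         return [101, 2023, 102]  # pre-tokenized token ids
+
+ class ToyAligner:
+     def __call__(self, video_id, vfeats, caps):
+         # build one training example from the video and text features
+         return {"video_id": video_id, "vfeats": vfeats, "caps": torch.tensor(caps)}
+
+ dataset = MMDataset(ToyMetaProcessor(), ToyVideoProcessor(), ToyTextProcessor(), ToyAligner())
+ print(dataset[0]["vfeats"].shape)  # one aligned training example
+ ```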
115
+
116
+ #### Performance-tuned Components
117
+ To speed up pre-training, this toolkit uses sharded features stored in mmaped numpy, backed by `ShardedTensor` in `mmpt/utils/shardedtensor.py` (adopted from the MARGE paper). This reduces the IO load for multi-GPU training by avoiding loading all features of a video into memory each time, and `ShardedTensor` ensures features are stored in contiguous disk space for near-random access. This is used for both How2 video features and texts in `mmpt/processors/how2processor.py`.
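+
+ The core access pattern can be illustrated with plain memory-mapped numpy (this is only an illustration of the idea, not the actual `ShardedTensor` API):
+
+ ```python
+ import numpy as np
+
+ # write one shard: features of many clips stored contiguously in fp16
+ feats = np.random.rand(1000, 512).astype("float16")
+ np.save("shard0.npy", feats)
+
+ # read with mmap: only the rows that are actually touched are read from disk
+ shard = np.load("shard0.npy", mmap_mode="r")
+ clip = np.array(shard[42])  # copy out a single clip's features
+ ```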
118
+
119
+
120
+ ### Citation
121
+ If this codebase is useful for your work, please cite the following papers:
122
+
123
+ ```BibTeX
124
+ @inproceedings{xu-etal-2021-videoclip,
125
+ title = "{VideoCLIP}: Contrastive Pre-training for\\Zero-shot Video-Text Understanding",
126
+ author = "Xu, Hu and
127
+ Ghosh, Gargi and
128
+ Huang, Po-Yao and
129
+ Okhonko, Dmytro and
130
+ Aghajanyan, Armen and
131
+ Metze, Florian and
132
+ Zettlemoyer, Luke and
133
+ Feichtenhofer, Christoph",
134
+ booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
135
+ month = nov,
136
+ year = "2021",
137
+ address = "Online",
138
+ publisher = "Association for Computational Linguistics",
139
+ }
140
+
141
+ @inproceedings{xu-etal-2021-vlm,
142
+ title = "{VLM}: Task-agnostic Video-Language Model Pre-training for Video Understanding",
143
+ author = "Xu, Hu and
144
+ Ghosh, Gargi and
145
+ Huang, Po-Yao and
146
+ Arora, Prahal and
147
+ Aminzadeh, Masoumeh and
148
+ Feichtenhofer, Christoph and
149
+ Metze, Florian and
150
+ Zettlemoyer, Luke",
151
+ booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
152
+ month = aug,
153
+ year = "2021",
154
+ address = "Online",
155
+ publisher = "Association for Computational Linguistics",
156
+ url = "https://aclanthology.org/2021.findings-acl.370",
157
+ doi = "10.18653/v1/2021.findings-acl.370",
158
+ pages = "4227--4239",
159
+ }
160
+ ```
161
+
162
+ ### Bug Reports
163
+ This repo is in its initial stage; bug reports are welcome at huxu@fb.com
164
+
165
+ ### Copyright
166
+ The majority of Multimodal Pre-training (MMPT) is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: Evaluation Codes/Models: Howto100M and HuggingFace Transformers are licensed under the Apache 2.0 license; COIN and NLG-eval are licensed under the MIT license; CrossTask is licensed under the BSD-3 license; DiDeMo is licensed under the BSD-2 license.
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/endtask.md ADDED
@@ -0,0 +1,41 @@
1
+ # Zero-shot Transfer and Finetuning
2
+
3
+ (If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.)
4
+ All finetuning datasets (specifically `processors`) are defined in `mmpt.processors.dsprocessor`.
5
+ Given the complexity of different types of finetuning tasks, each task may have its own meta/video/text/aligner processors and `mmpt/evaluators/{Predictor,Metric}`.
6
+
7
+ ### Tasks
8
+
9
+ Currently, we support 5 end datasets: `MSRVTT`, `Youcook`, `COIN`, `Crosstask` and `DiDeMo` with the following tasks:
10
+ text-video retrieval: `MSRVTT`, `Youcook`, `DiDeMo`;
11
+ video captioning: `Youcook`;
12
+ video question answering: `MSRVTT-QA`.
13
+
14
+ To add your own dataset, you can specify the corresponding processors and configure them in the `dataset` field of a config file, such as `projects/task/vtt.yaml`.
15
+
16
+ ### Zero-shot Transfer (no Training)
17
+ Zero-shot transfer runs the pre-trained model (e.g., VideoCLIP) directly on testing data. Configs matching the pattern `projects/task/*_zs_*.yaml` are dedicated to zero-shot transfer.
18
+
19
+ ### Fine-tuning
20
+
21
+ Training a downstream task is similar to pretraining, except you may need to specify `restore_file` in `fairseq.checkpoint` and reset optimizers; see `projects/task/ft.yaml`, which is included by `projects/task/vtt.yaml`.
22
+
23
+ We typically do finetuning on 2 gpus (`local_small`).
24
+
25
+ ### Testing
26
+ For each finetuning dataset, you may need to specify a testing config, similar to `projects/task/test_vtt.yaml`.
27
+
28
+ We define `mmpt.evaluators.Predictor` for different types of prediction. For example, `MSRVTT` and `Youcook` are video-retrieval tasks and are expected to use `RetrievalPredictor`. You may need to define a new type of predictor and specify it in the `predictor` field of a testing config.
29
+
30
+ Each task may also have its own metric for evaluation. This can be created under `mmpt.evaluators` (subclassing `Metric`) and specified in the `metric` field of a testing config, for example:
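+
+ A minimal sketch of such a metric (modeled on `QAMetric` in `mmpt/evaluators/metric.py`; the class name and the accuracy computation are illustrative):
+
+ ```python
+ import numpy as np
+ from mmpt.evaluators.metric import Metric
+
+ class MyTaskMetric(Metric):
+     def __init__(self, config, metric_names=["acc"]):
+         super().__init__(config, metric_names)
+
+     def compute_metrics(self, outputs, targets, **kwargs):
+         return {"acc": float(np.mean(np.array(outputs) == np.array(targets)))}
+
+     def print_computed_metrics(self, metrics):
+         print("acc: {:.4f}".format(metrics["acc"]))
+ ```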
31
+
32
+ Launching testing is as simple as training: just specify the path of a testing config:
33
+ ```python locallaunch.py projects/mfmmlm/test_vtt.yaml```
34
+ Testing will be launched locally by default since prediction is computationally less expensive.
35
+
36
+ ### Third-party Libraries
37
+ We list the following finetuning tasks that require third-party libraries.
38
+
39
+ Youcook captioning: `https://github.com/Maluuba/nlg-eval`
40
+
41
+ CrossTask: `https://github.com/DmZhukov/CrossTask`'s `dp` under `third-party/CrossTask` (`python setup.py build_ext --inplace`)
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/locallaunch.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import argparse
6
+ import os
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+ from mmpt.utils import recursive_config, overwrite_dir
11
+ from mmpt_cli.localjob import LocalJob
12
+
13
+
14
+ class JobLauncher(object):
15
+ JOB_CONFIG = {
16
+ "local": LocalJob,
17
+ }
18
+
19
+ def __init__(self, yaml_file):
20
+ self.yaml_file = yaml_file
21
+ job_key = "local"
22
+
23
+ if yaml_file.endswith(".yaml"):
24
+ config = recursive_config(yaml_file)
25
+ if config.task_type is not None:
26
+ job_key = config.task_type.split("_")[0]
27
+ else:
28
+ raise ValueError("unknown extension of job file:", yaml_file)
29
+ self.job_key = job_key
30
+
31
+ def __call__(self, job_type=None, dryrun=False):
32
+ if job_type is not None:
33
+ self.job_key = job_type.split("_")[0]
34
+ print("[JobLauncher] job_key", self.job_key)
35
+ job = JobLauncher.JOB_CONFIG[self.job_key](
36
+ self.yaml_file, job_type=job_type, dryrun=dryrun)
37
+ return job.submit()
38
+
39
+
40
+ class Pipeline(object):
41
+ """a job that loads yaml config."""
42
+
43
+ def __init__(self, fn):
44
+ """
45
+ load a yaml config of a job and save generated configs as yaml for each task.
46
+ return: a list of files to run as specified by `run_task`.
47
+ """
48
+ if fn.endswith(".py"):
49
+ # a python command.
50
+ self.backend = "python"
51
+ self.run_yamls = [fn]
52
+ return
53
+
54
+ job_config = recursive_config(fn)
55
+ if job_config.base_dir is None: # single file job config.
56
+ self.run_yamls = [fn]
57
+ return
58
+
59
+ self.project_dir = os.path.join("projects", job_config.project_dir)
60
+ self.run_dir = os.path.join("runs", job_config.project_dir)
61
+
62
+ if job_config.run_task is not None:
63
+ run_yamls = []
64
+ for stage in job_config.run_task:
65
+ # each stage can have multiple tasks running in parallel.
66
+ if OmegaConf.is_list(stage):
67
+ stage_yamls = []
68
+ for task_file in stage:
69
+ stage_yamls.append(
70
+ os.path.join(self.project_dir, task_file))
71
+ run_yamls.append(stage_yamls)
72
+ else:
73
+ run_yamls.append(os.path.join(self.project_dir, stage))
74
+ self.run_yamls = run_yamls
75
+ configs_to_save = self._overwrite_task(job_config)
76
+ self._save_configs(configs_to_save)
77
+
78
+ def __getitem__(self, idx):
79
+ yaml_files = self.run_yamls[idx]
80
+ if isinstance(yaml_files, list):
81
+ return [JobLauncher(yaml_file) for yaml_file in yaml_files]
82
+ return [JobLauncher(yaml_files)]
83
+
84
+ def __len__(self):
85
+ return len(self.run_yamls)
86
+
87
+ def _save_configs(self, configs_to_save: dict):
88
+ # save
89
+ os.makedirs(self.project_dir, exist_ok=True)
90
+ for config_file in configs_to_save:
91
+ config = configs_to_save[config_file]
92
+ print("saving", config_file)
93
+ OmegaConf.save(config=config, f=config_file)
94
+
95
+ def _overwrite_task(self, job_config):
96
+ configs_to_save = {}
97
+ self.base_project_dir = os.path.join("projects", job_config.base_dir)
98
+ self.base_run_dir = os.path.join("runs", job_config.base_dir)
99
+
100
+ for config_sets in job_config.task_group:
101
+ overwrite_config = job_config.task_group[config_sets]
102
+ if (
103
+ overwrite_config.task_list is None
104
+ or len(overwrite_config.task_list) == 0
105
+ ):
106
+ print(
107
+ "[warning]",
108
+ job_config.task_group,
109
+ "has no task_list specified.")
110
+ # we don't want this added to a final config.
111
+ task_list = overwrite_config.pop("task_list", None)
112
+ for config_file in task_list:
113
+ config_file_path = os.path.join(
114
+ self.base_project_dir, config_file)
115
+ config = recursive_config(config_file_path)
116
+ # overwrite it.
117
+ if overwrite_config:
118
+ config = OmegaConf.merge(config, overwrite_config)
119
+ overwrite_dir(config, self.run_dir, basedir=self.base_run_dir)
120
+ save_file_path = os.path.join(self.project_dir, config_file)
121
+ configs_to_save[save_file_path] = config
122
+ return configs_to_save
123
+
124
+
125
+ def main(args):
126
+ job_type = args.jobtype if args.jobtype else None
127
+ # parse multiple pipelines.
128
+ pipelines = [Pipeline(fn) for fn in args.yamls.split(",")]
129
+
130
+ for pipe_id, pipeline in enumerate(pipelines):
131
+ if not hasattr(pipeline, "project_dir"):
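+ # only single-file or .py pipelines are launched here; a project config (with
+ # `project_dir`) has already had its per-task yamls generated and saved by
+ # Pipeline.__init__, to be launched individually afterwards.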
132
+ for job in pipeline[0]:
133
+ job(job_type=job_type, dryrun=args.dryrun)
134
+
135
+
136
+ if __name__ == "__main__":
137
+ parser = argparse.ArgumentParser()
138
+ parser.add_argument("yamls", type=str)
139
+ parser.add_argument(
140
+ "--dryrun",
141
+ action="store_true",
142
+ help="run config and prepare to submit without launch the job.",
143
+ )
144
+ parser.add_argument(
145
+ "--jobtype", type=str, default="",
146
+ help="force to run jobs as specified.")
147
+ args = parser.parse_args()
148
+ main(args)
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ try:
6
+ # fairseq user dir
7
+ from .datasets import FairseqMMDataset
8
+ from .losses import FairseqCriterion
9
+ from .models import FairseqMMModel
10
+ from .tasks import FairseqMMTask
11
+ except ImportError:
12
+ pass
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from .mmdataset import *
6
+
7
+ try:
8
+ from .fairseqmmdataset import *
9
+ except ImportError:
10
+ pass
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/fairseqmmdataset.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ TODO (huxu): fairseq wrapper class for all dataset you defined: mostly MMDataset.
7
+ """
8
+
9
+ from collections import OrderedDict
10
+
11
+ from torch.utils.data import Dataset
12
+ from torch.utils.data.dataloader import default_collate
13
+ from fairseq.data import FairseqDataset, data_utils
14
+
15
+
16
+ class FairseqMMDataset(FairseqDataset):
17
+ """
18
+ A wrapper class for MMDataset for fairseq.
19
+ """
20
+
21
+ def __init__(self, mmdataset):
22
+ if not isinstance(mmdataset, Dataset):
23
+ raise TypeError("mmdataset must be of type `torch.utils.data.Dataset`.")
24
+ self.mmdataset = mmdataset
25
+
26
+ def set_epoch(self, epoch, **unused):
27
+ super().set_epoch(epoch)
28
+ self.epoch = epoch
29
+
30
+ def __getitem__(self, idx):
31
+ with data_utils.numpy_seed(43211, self.epoch, idx):
32
+ return self.mmdataset[idx]
33
+
34
+ def __len__(self):
35
+ return len(self.mmdataset)
36
+
37
+ def collater(self, samples):
38
+ if hasattr(self.mmdataset, "collator"):
39
+ return self.mmdataset.collator(samples)
40
+ if len(samples) == 0:
41
+ return {}
42
+ if isinstance(samples[0], dict):
43
+ batch = OrderedDict()
44
+ for key in samples[0]:
45
+ if samples[0][key] is not None:
46
+ batch[key] = default_collate([sample[key] for sample in samples])
47
+ return batch
48
+ else:
49
+ return default_collate(samples)
50
+
51
+ def size(self, index):
52
+ """dummy implementation: we don't use --max-tokens"""
53
+ return 1
54
+
55
+ def num_tokens(self, index):
56
+ """dummy implementation: we don't use --max-tokens"""
57
+ return 1
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/datasets/mmdataset.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from collections import OrderedDict
9
+
10
+ from torch.utils.data import Dataset
11
+ from torch.utils.data.dataloader import default_collate
12
+
13
+ from ..utils import set_seed
14
+
15
+
16
+ class MMDataset(Dataset):
17
+ """
18
+ A generic multi-modal dataset.
19
+ Args:
20
+ `meta_processor`: a meta processor,
21
+ handling loading meta data and return video_id and text_id.
22
+ `video_processor`: a video processor,
23
+ handling e.g., decoding, loading .np files.
24
+ `text_processor`: a text processor,
25
+ handling e.g., tokenization.
26
+ `aligner`: combine the video and text feature
27
+ as one training example.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ meta_processor,
33
+ video_processor,
34
+ text_processor,
35
+ align_processor,
36
+ ):
37
+ self.split = meta_processor.split
38
+ self.meta_processor = meta_processor
39
+ self.video_processor = video_processor
40
+ self.text_processor = text_processor
41
+ self.align_processor = align_processor
42
+
43
+ def __len__(self):
44
+ return len(self.meta_processor)
45
+
46
+ def __getitem__(self, idx):
47
+ if self.split == "test":
48
+ set_seed(idx)
49
+ video_id, text_id = self.meta_processor[idx]
50
+ video_feature = self.video_processor(video_id)
51
+ text_feature = self.text_processor(text_id)
52
+ output = self.align_processor(video_id, video_feature, text_feature)
53
+ # TODO (huxu): the following is for debug purpose.
54
+ output.update({"idx": idx})
55
+ return output
56
+
57
+ def collater(self, samples):
58
+ """This collator is deprecated.
59
+ set self.collator = MMDataset.collater.
60
+ see collator in FairseqMMDataset.
61
+ """
62
+
63
+ if len(samples) == 0:
64
+ return {}
65
+ if isinstance(samples[0], dict):
66
+ batch = OrderedDict()
67
+ for key in samples[0]:
68
+ if samples[0][key] is not None:
69
+ batch[key] = default_collate(
70
+ [sample[key] for sample in samples])
71
+ # if torch.is_tensor(batch[key]):
72
+ # print(key, batch[key].size())
73
+ # else:
74
+ # print(key, len(batch[key]))
75
+ return batch
76
+ else:
77
+ return default_collate(samples)
78
+
79
+ def print_example(self, output):
80
+ print("[one example]", output["video_id"])
81
+ if (
82
+ hasattr(self.align_processor, "subsampling")
83
+ and self.align_processor.subsampling is not None
84
+ and self.align_processor.subsampling > 1
85
+ ):
86
+ for key in output:
87
+ if torch.is_tensor(output[key]):
88
+ output[key] = output[key][0]
89
+
90
+ # search tokenizer to translate ids back.
91
+ tokenizer = None
92
+ if hasattr(self.text_processor, "tokenizer"):
93
+ tokenizer = self.text_processor.tokenizer
94
+ elif hasattr(self.align_processor, "tokenizer"):
95
+ tokenizer = self.align_processor.tokenizer
96
+ if tokenizer is not None:
97
+ caps = output["caps"].tolist()
98
+ if isinstance(caps[0], list):
99
+ caps = caps[0]
100
+ print("caps", tokenizer.decode(caps))
101
+ print("caps", tokenizer.convert_ids_to_tokens(caps))
102
+
103
+ for key, value in output.items():
104
+ if torch.is_tensor(value):
105
+ if len(value.size()) >= 3: # attention_mask.
106
+ print(key, value.size())
107
+ print(key, "first", value[0, :, :])
108
+ print(key, "last", value[-1, :, :])
109
+ else:
110
+ print(key, value)
111
+ print("[end of one example]")
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from .metric import *
6
+ from .evaluator import *
7
+
8
+
9
+ # experimental.
10
+ try:
11
+ from .expmetric import *
12
+ except ImportError:
13
+ pass
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/evaluator.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+ import glob
7
+ import numpy as np
8
+
9
+ from . import metric as metric_path
10
+ from . import predictor as predictor_path
11
+
12
+
13
+ class Evaluator(object):
14
+ """
15
+ perform evaluation on a single (downstream) task.
16
+ make this both offline and online.
17
+ TODO(huxu) saving evaluation results.
18
+ """
19
+
20
+ def __init__(self, config, eval_dataloader=None):
21
+ if config.metric is None:
22
+ raise ValueError("config.metric is", config.metric)
23
+ metric_cls = getattr(metric_path, config.metric)
24
+ self.metric = metric_cls(config)
25
+ if config.predictor is None:
26
+ raise ValueError("config.predictor is", config.predictor)
27
+ predictor_cls = getattr(predictor_path, config.predictor)
28
+ self.predictor = predictor_cls(config)
29
+ self.eval_dataloader = eval_dataloader
30
+
31
+ def __call__(self):
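+ # offline mode: load predictions previously saved by the predictor under
+ # pred_dir and compute metrics from them.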
32
+ try:
33
+ print(self.predictor.pred_dir)
34
+ for pred_file in glob.glob(
35
+ self.predictor.pred_dir + "/*_merged.npy"):
36
+ outputs = np.load(pred_file)
37
+ results = self.metric.compute_metrics(outputs)
38
+ self.metric.print_computed_metrics(results)
39
+
40
+ outputs = np.load(os.path.join(
41
+ self.predictor.pred_dir, "merged.npy"))
42
+ results = self.metric.compute_metrics(outputs)
43
+ return {"results": results, "metric": self.metric}
44
+ except FileNotFoundError:
45
+ print("\n[missing]", self.predictor.pred_dir)
46
+ return {}
47
+
48
+ def evaluate(self, model, eval_dataloader=None, output_file="merged"):
49
+ if eval_dataloader is None:
50
+ eval_dataloader = self.eval_dataloader
51
+ outputs = self.predictor.predict_loop(
52
+ model, eval_dataloader, output_file)
53
+ results = self.metric.compute_metrics(**outputs)
54
+ return results
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/metric.py ADDED
@@ -0,0 +1,313 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import json
8
+
9
+
10
+ class Metric(object):
11
+ def __init__(self, config, metric_names):
12
+ self.metric_names = metric_names
13
+
14
+ def best_metric(self, metric):
15
+ return metric[self.metric_names[0]]
16
+
17
+ def save_metrics(self, fn, metrics):
18
+ with open(fn, "w") as fw:
19
+ json.dump(metrics, fw)
20
+
21
+ def print_computed_metrics(self, metrics):
22
+ raise NotImplementedError
23
+
24
+
25
+ class RetrievalMetric(Metric):
26
+ """
27
+ this is modified from `howto100m/metrics.py`.
28
+ History of changes:
29
+ refactor as a class.
30
+ add metric_key in __init__
31
+ """
32
+
33
+ def __init__(self, config, metric_names=["R1", "R5", "R10", "MR"]):
34
+ super().__init__(config, metric_names)
35
+ self.error = False # TODO(huxu): add to config to print error.
36
+
37
+ def compute_metrics(self, outputs, texts, **kwargs):
38
+ x = outputs
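+ # x[i, j] is the similarity between text i and video j; the position of the
+ # ground-truth (diagonal) entry within each row sorted by similarity gives the rank.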
39
+ sx = np.sort(-x, axis=1)
40
+ d = np.diag(-x)
41
+ d = d[:, np.newaxis]
42
+ ind = sx - d
43
+ ind = np.where(ind == 0)
44
+ ind = ind[1]
45
+ metrics = {}
46
+ metrics["R1"] = float(np.sum(ind == 0)) / len(ind)
47
+ metrics["R5"] = float(np.sum(ind < 5)) / len(ind)
48
+ metrics["R10"] = float(np.sum(ind < 10)) / len(ind)
49
+ metrics["MR"] = np.median(ind) + 1
50
+
51
+ max_idx = np.argmax(outputs, axis=1)
52
+ if self.error:
53
+ # print top-20 errors.
54
+ error = []
55
+ for ex_idx in range(20):
56
+ error.append((texts[ex_idx], texts[max_idx[ex_idx]]))
57
+ metrics["error"] = error
58
+ return metrics
59
+
60
+ def print_computed_metrics(self, metrics):
61
+ r1 = metrics["R1"]
62
+ r5 = metrics["R5"]
63
+ r10 = metrics["R10"]
64
+ mr = metrics["MR"]
65
+ print(
66
+ "R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}".format(
67
+ r1, r5, r10, mr
68
+ )
69
+ )
70
+ if "error" in metrics:
71
+ print(metrics["error"])
72
+
73
+
74
+ class DiDeMoMetric(Metric):
75
+ """
76
+ History of changes:
77
+ python 2.x to python 3.x.
78
+ merge utils.py into eval to save one file.
79
+ reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
80
+ Code to evaluate your results on the DiDeMo dataset.
81
+ """
82
+ def __init__(self, config, metric_names=["rank1", "rank5", "miou"]):
83
+ super().__init__(config, metric_names)
84
+
85
+ def compute_metrics(self, outputs, targets, **kwargs):
86
+ assert len(outputs) == len(targets)
87
+ rank1, rank5, miou = self._eval_predictions(outputs, targets)
88
+ metrics = {
89
+ "rank1": rank1,
90
+ "rank5": rank5,
91
+ "miou": miou
92
+ }
93
+ return metrics
94
+
95
+ def print_computed_metrics(self, metrics):
96
+ rank1 = metrics["rank1"]
97
+ rank5 = metrics["rank5"]
98
+ miou = metrics["miou"]
99
+ # print("Average rank@1: %f" % rank1)
100
+ # print("Average rank@5: %f" % rank5)
101
+ # print("Average iou: %f" % miou)
102
+
103
+ print(
104
+ "Average rank@1: {:.4f} Average rank@5: {:.4f} Average iou: {:.4f}".format(
105
+ rank1, rank5, miou
106
+ )
107
+ )
108
+
109
+ def _iou(self, pred, gt):
110
+ intersection = max(0, min(pred[1], gt[1]) + 1 - max(pred[0], gt[0]))
111
+ union = max(pred[1], gt[1]) + 1 - min(pred[0], gt[0])
112
+ return float(intersection)/union
113
+
114
+ def _rank(self, pred, gt):
115
+ return pred.index(tuple(gt)) + 1
116
+
117
+ def _eval_predictions(self, segments, data):
118
+ '''
119
+ Inputs:
120
+ segments: For each item in the ground truth data, rank possible video segments given the description and video.
121
+ In DiDeMo, there are 21 possible moments extracted for each video, so the list of video segments will be of length 21.
122
+ The first video segment should be the video segment that best corresponds to the text query.
123
+ There are 4180 sentences in the validation data, so when evaluating a model on the val dataset,
124
+ segments should be a list of length 4180, and each item in segments should be a list of length 21.
125
+ data: ground truth data
126
+ '''
127
+ average_ranks = []
128
+ average_iou = []
129
+ for s, d in zip(segments, data):
130
+ pred = s[0]
131
+ ious = [self._iou(pred, t) for t in d['times']]
132
+ average_iou.append(np.mean(np.sort(ious)[-3:]))
133
+ ranks = [self._rank(s, t) for t in d['times'] if tuple(t) in s] # if t in s] is added for s, e not in prediction.
134
+ average_ranks.append(np.mean(np.sort(ranks)[:3]))
135
+ rank1 = np.sum(np.array(average_ranks) <= 1)/float(len(average_ranks))
136
+ rank5 = np.sum(np.array(average_ranks) <= 5)/float(len(average_ranks))
137
+ miou = np.mean(average_iou)
138
+
139
+ # print("Average rank@1: %f" % rank1)
140
+ # print("Average rank@5: %f" % rank5)
141
+ # print("Average iou: %f" % miou)
142
+ return rank1, rank5, miou
143
+
144
+
145
+ class NLGMetric(Metric):
146
+ def __init__(
147
+ self,
148
+ config,
149
+ metric_names=[
150
+ "Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4",
151
+ "METEOR", "ROUGE_L", "CIDEr"
152
+ ]
153
+ ):
154
+ super().__init__(config, metric_names)
155
+ # please install NLGEval from `https://github.com/Maluuba/nlg-eval`
156
+ from nlgeval import NLGEval
157
+ self.nlg = NLGEval()
158
+
159
+ def compute_metrics(self, outputs, targets, **kwargs):
160
+ return self.nlg.compute_metrics(
161
+ hyp_list=outputs, ref_list=targets)
162
+
163
+ def print_computed_metrics(self, metrics):
164
+ Bleu_1 = metrics["Bleu_1"]
165
+ Bleu_2 = metrics["Bleu_2"]
166
+ Bleu_3 = metrics["Bleu_3"]
167
+ Bleu_4 = metrics["Bleu_4"]
168
+ METEOR = metrics["METEOR"]
169
+ ROUGE_L = metrics["ROUGE_L"]
170
+ CIDEr = metrics["CIDEr"]
171
+
172
+ print(
173
+ "Bleu_1: {:.4f} - Bleu_2: {:.4f} - Bleu_3: {:.4f} - Bleu_4: {:.4f} - METEOR: {:.4f} - ROUGE_L: {:.4f} - CIDEr: {:.4f}".format(
174
+ Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, CIDEr
175
+ )
176
+ )
177
+
178
+
179
+ class QAMetric(Metric):
180
+ def __init__(
181
+ self,
182
+ config,
183
+ metric_names=["acc"]
184
+ ):
185
+ super().__init__(config, metric_names)
186
+
187
+ def compute_metrics(self, outputs, targets, **kwargs):
188
+ from sklearn.metrics import accuracy_score
189
+ return {"acc": accuracy_score(targets, outputs)}
190
+
191
+ def print_computed_metrics(self, metrics):
192
+ print("acc: {:.4f}".format(metrics["acc"]))
193
+
194
+
195
+ class COINActionSegmentationMetric(Metric):
196
+ """
197
+ COIN dataset listed 3 repos for Action Segmentation.
198
+ Action Sets, NeuralNetwork-Viterbi, TCFPN-ISBA.
199
+ The first and second are the same.
200
+ https://github.com/alexanderrichard/action-sets/blob/master/eval.py
201
+
202
+ Future reference for the third:
203
+ `https://github.com/Zephyr-D/TCFPN-ISBA/blob/master/utils/metrics.py`
204
+ """
205
+ def __init__(self, config, metric_name=["frame_acc"]):
206
+ super().__init__(config, metric_name)
207
+
208
+ def compute_metrics(self, outputs, targets):
209
+ n_frames = 0
210
+ n_errors = 0
211
+ n_errors = sum(outputs != targets)
212
+ n_frames = len(targets)
213
+ return {"frame_acc": 1.0 - float(n_errors) / n_frames}
214
+
215
+ def print_computed_metrics(self, metrics):
216
+ fa = metrics["frame_acc"]
217
+ print("frame accuracy:", fa)
218
+
219
+
220
+ class CrossTaskMetric(Metric):
221
+ def __init__(self, config, metric_names=["recall"]):
222
+ super().__init__(config, metric_names)
223
+
224
+ def compute_metrics(self, outputs, targets, **kwargs):
225
+ """refactored from line 166:
226
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
227
+
228
+ recalls = self._get_recalls(Y_true=targets, Y_pred=outputs)
229
+ results = {}
230
+ for task, rec in recalls.items():
231
+ results[str(task)] = rec
232
+
233
+ avg_recall = np.mean(list(recalls.values()))
234
+ results["recall"] = avg_recall
235
+ return results
236
+
237
+ def print_computed_metrics(self, metrics):
238
+ print('Recall: {0:0.3f}'.format(metrics["recall"]))
239
+ for task in metrics:
240
+ if task != "recall":
241
+ print('Task {0}. Recall = {1:0.3f}'.format(
242
+ task, metrics[task]))
243
+
244
+ def _get_recalls(self, Y_true, Y_pred):
245
+ """refactored from
246
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py"""
247
+
248
+ step_match = {task: 0 for task in Y_true.keys()}
249
+ step_total = {task: 0 for task in Y_true.keys()}
250
+ for task, ys_true in Y_true.items():
251
+ ys_pred = Y_pred[task]
252
+ for vid in set(ys_pred.keys()).intersection(set(ys_true.keys())):
253
+ y_true = ys_true[vid]
254
+ y_pred = ys_pred[vid]
255
+ step_total[task] += (y_true.sum(axis=0) > 0).sum()
256
+ step_match[task] += (y_true*y_pred).sum()
257
+ recalls = {
258
+ task: step_match[task] / n for task, n in step_total.items()}
259
+ return recalls
260
+
261
+
262
+ class ActionRecognitionMetric(Metric):
263
+ def __init__(
264
+ self,
265
+ config,
266
+ metric_names=["acc", "acc_splits", "r1_splits", "r5_splits", "r10_splits"]
267
+ ):
268
+ super().__init__(config, metric_names)
269
+
270
+ def compute_metrics(self, outputs, targets, splits, **kwargs):
271
+ all_video_embd = outputs
272
+ labels = targets
273
+ split1, split2, split3 = splits
274
+ accs = []
275
+ r1s = []
276
+ r5s = []
277
+ r10s = []
278
+ for split in range(3):
279
+ if split == 0:
280
+ s = split1
281
+ elif split == 1:
282
+ s = split2
283
+ else:
284
+ s = split3
285
+
286
+ X_pred = all_video_embd[np.where(s == 2)[0]]
287
+ label_test = labels[np.where(s == 2)[0]]
288
+ logits = X_pred
289
+ X_pred = np.argmax(X_pred, axis=1)
290
+ acc = np.sum(X_pred == label_test) / float(len(X_pred))
291
+ accs.append(acc)
292
+ # compute recall.
293
+ sorted_pred = (-logits).argsort(axis=-1)
294
+ label_test_sp = label_test.reshape(-1, 1)
295
+
296
+ r1 = np.mean((sorted_pred[:, :1] == label_test_sp).sum(axis=1), axis=0)
297
+ r5 = np.mean((sorted_pred[:, :5] == label_test_sp).sum(axis=1), axis=0)
298
+ r10 = np.mean((sorted_pred[:, :10] == label_test_sp).sum(axis=1), axis=0)
299
+ r1s.append(r1)
300
+ r5s.append(r5)
301
+ r10s.append(r10)
302
+
303
+ return {"acc": accs[0], "acc_splits": accs, "r1_splits": r1s, "r5_splits": r5s, "r10_splits": r10s}
304
+
305
+ def print_computed_metrics(self, metrics):
306
+ for split, acc in enumerate(metrics["acc_splits"]):
307
+ print("Top 1 accuracy on split {}: {}; r1 {}; r5 {}; r10 {}".format(
308
+ split + 1, acc,
309
+ metrics["r1_splits"][split],
310
+ metrics["r5_splits"][split],
311
+ metrics["r10_splits"][split],
312
+ )
313
+ )
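For orientation, a minimal usage sketch of the metric classes above on toy arrays. Passing `config=None` to the base constructor is an assumption made purely for illustration; real runs pass the MMPT config object.

    import numpy as np

    outputs = np.array([0, 1, 2, 2, 1])
    targets = np.array([0, 1, 1, 2, 1])

    qa_metric = QAMetric(config=None)                       # hypothetical config, illustration only
    print(qa_metric.compute_metrics(outputs, targets))      # {'acc': 0.8}

    seg_metric = COINActionSegmentationMetric(config=None)
    seg_metric.print_computed_metrics(
        seg_metric.compute_metrics(outputs, targets))       # frame accuracy: 0.8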
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/evaluators/predictor.py ADDED
@@ -0,0 +1,595 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+ import random
7
+ import json
8
+ import numpy as np
9
+ import torch
10
+ import pickle
11
+ import math
12
+
13
+ from tqdm import tqdm
14
+
15
+
16
+ class Predictor(object):
17
+ """this base class is used to save predictions to disk
18
+ (to be consumed by an evaluator later).
19
+ Predictor has minimal support for single-GPU prediction.
20
+ """
21
+ def __init__(self, config):
22
+ self.pred_dir = None # on-the-fly eval does not save the results.
23
+ if hasattr(config, "eval") and config.eval is not None:
24
+ self.pred_dir = config.eval.save_path
25
+ os.makedirs(self.pred_dir, exist_ok=True)
26
+
27
+ def __call__(self, outputs):
28
+ """extract the prediction and save it."""
29
+ raise NotImplementedError
30
+
31
+ def predict_loop(self, model, eval_dataloader, output_file=None):
32
+ """on-the-fly prediction on a single gpu."""
33
+ self.full_scores = []
34
+ model.eval()
35
+ model = model.to(0)
36
+ with torch.no_grad():
37
+ for data in eval_dataloader:
38
+ data = self.to_ctx(data)
39
+ outputs = model(**data)
40
+ outputs.update(data)
41
+ self(outputs)
42
+ return self.finalize(output_file)
43
+
44
+ def finalize(self, output_file):
45
+ pass
46
+
47
+ def to_ctx(self, data, ctx=0, dtype=None):
48
+ if isinstance(data, dict):
49
+ for key in data:
50
+ if torch.is_tensor(data[key]):
51
+ if dtype is not None and data[key].dtype == torch.float32:
52
+ data[key] = data[key].to(dtype)
53
+ data[key] = data[key].to(ctx)
54
+ return data
55
+ else:
56
+ raise ValueError("non-dict type of batch is not supported yet.")
57
+
58
+
59
+ class NLGPredictor(Predictor):
60
+ """Predicting Text from MMFusion models."""
61
+ """TODO: make a context."""
62
+ def __init__(self, config):
63
+ super().__init__(config)
64
+ from transformers import AutoTokenizer
65
+
66
+ self.tokenizer = AutoTokenizer.from_pretrained(
67
+ config.dataset.bert_name,
68
+ bos_token="[CLS]", eos_token="[SEP]")
69
+ self.bos_token_id = self.tokenizer.bos_token_id
70
+ self.eos_token_id = self.tokenizer.eos_token_id
71
+
72
+ def predict_loop(self, model, eval_dataloader, output_file=None):
73
+ """TODO: refactor base classes."""
74
+ ctx = 0
75
+ outputs = {"outputs": [], "targets": [[]]}
76
+ model.eval()
77
+ model = model.to(ctx)
78
+ with torch.no_grad():
79
+ for data in tqdm(eval_dataloader):
80
+ data = self.to_ctx(data, ctx)
81
+ self(data, model, outputs)
82
+ return self.finalize(outputs, output_file)
83
+
84
+ def __call__(self, data, model, outputs):
85
+ data.update({
86
+ "bos_token_id": self.bos_token_id,
87
+ "eos_token_id": self.eos_token_id
88
+ })
89
+
90
+ output = model.generate(**data)
91
+ assert len(output) == len(data["ref"])
92
+ for idx, _output in enumerate(output):
93
+ generated_text = self.tokenizer.decode(
94
+ _output, skip_special_tokens=True)
95
+ if generated_text == "":
96
+ generated_text = "none"
97
+ outputs["outputs"].append(generated_text)
98
+ outputs["targets"][0].append(data["ref"][idx])
99
+ if random.random() < 0.001:
100
+ print("_output", _output)
101
+ print("generated_text", generated_text)
102
+ print("ref", data["ref"][idx])
103
+
104
+ def finalize(self, outputs, output_file=None):
105
+ if output_file is not None:
106
+ with open(os.path.join(
107
+ self.pred_dir, output_file + ".json"), "w") as fw:
108
+ json.dump(outputs, fw, indent=4)
109
+ return outputs
110
+
111
+
112
+ class RetrievalPredictor(Predictor):
113
+ """generated `pooled_video` and `pooled_text`."""
114
+ def __init__(self, config):
115
+ super().__init__(config)
116
+ from transformers import AutoTokenizer
117
+ self.tokenizer = AutoTokenizer.from_pretrained(
118
+ config.dataset.bert_name)
119
+
120
+ def predict_loop(
121
+ self,
122
+ model,
123
+ eval_dataloader,
124
+ output_file="retrieval.npy"
125
+ ):
126
+ """on-the-fly prediction on a single gpu."""
127
+ full_scores = []
128
+ texts = []
129
+ model.eval()
130
+ model = model.cuda()
131
+ with torch.no_grad():
132
+ for data in eval_dataloader:
133
+ # convert to dict.
134
+ if not isinstance(data, dict):
135
+ data = {
136
+ "caps": data[0],
137
+ "cmasks": data[1],
138
+ "vfeats": data[2],
139
+ "vmasks": data[3],
140
+ "video_id": data[4]
141
+ }
142
+ data = self.to_ctx(data)
143
+ outputs = model(**data)
144
+ outputs.update(data)
145
+ self(outputs, full_scores)
146
+ for _cap in data["caps"]:
147
+ texts.append(
148
+ self.tokenizer.decode(_cap, skip_special_tokens=True)
149
+ )
150
+
151
+ return self.finalize(full_scores, texts, output_file)
152
+
153
+ def __call__(self, sample, full_scores):
154
+ scores = self._get_pooled_outputs(sample)
155
+ self._append_scores(scores, full_scores)
156
+
157
+ def finalize(self, full_scores, texts, output_file=None):
158
+ outputs = self._aggregate_scores(full_scores)
159
+ if output_file is not None:
160
+ np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
161
+ return {"outputs": outputs, "texts": texts}
162
+
163
+ def _get_pooled_outputs(self, outputs):
164
+ if "pooled_video" in outputs:
165
+ return outputs["pooled_video"], outputs["pooled_text"]
166
+ else:
167
+ raise ValueError("unknown format of outputs.")
168
+
169
+ def _append_scores(self, scores, full_scores):
170
+ assert len(scores) == 2
171
+ if len(full_scores) == 0:
172
+ full_scores.append([])
173
+ full_scores.append([])
174
+ full_scores[0].append(scores[0].cpu().detach().numpy())
175
+ full_scores[1].append(scores[1].cpu().detach().numpy())
176
+
177
+ def _aggregate_scores(self, scores):
178
+ assert len(scores) == 2
179
+ video_hidden = np.concatenate(scores[0], axis=0)
180
+ text_hidden = np.concatenate(scores[1], axis=0)
181
+ # clear up.
182
+ self.full_scores = []
183
+ return np.matmul(text_hidden, video_hidden.T)
184
+
185
+
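The `(num_text, num_video)` score matrix produced by `_aggregate_scores` is what the retrieval metric ranks. Assuming the usual convention that text i is paired with video i (an assumption here, not shown in this file), recall@k can be read off the matrix as follows:

    import numpy as np

    scores = np.random.rand(100, 100)             # stand-in for text_hidden @ video_hidden.T
    ranking = np.argsort(-scores, axis=1)         # best-matching video first, per text query
    gt_rank = np.where(ranking == np.arange(len(scores))[:, None])[1]
    recall_at_1 = float((gt_rank == 0).mean())
    recall_at_5 = float((gt_rank < 5).mean())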
186
+ class QAPredictor(Predictor):
187
+ """generated `pooled_video` and `pooled_text`."""
188
+ def __init__(self, config):
189
+ super().__init__(config)
190
+ """predictor maintains scores and aggregate them."""
191
+
192
+ def predict_loop(self, model, eval_dataloader, output_file="qa.npy"):
193
+ """on-the-fly prediction on a single gpu."""
194
+ self.full_scores = []
195
+ model.eval()
196
+ model = model.cuda()
197
+ with torch.no_grad():
198
+ for data in eval_dataloader:
199
+ # reshape ans and dup video 5 times.
200
+ v_len = data["vfeats"].size(1)
201
+ hidden_size = data["vfeats"].size(2)
202
+ data["vfeats"] = data["vfeats"].unsqueeze(1).repeat(1, 5, 1, 1).view(-1, v_len, hidden_size)
203
+ data["vmasks"] = data["vmasks"].unsqueeze(1).repeat(1, 5, 1).view(-1, v_len)
204
+
205
+ t_len = data["caps"].size(-1)
206
+ data["caps"] = data["caps"].view(-1, t_len)
207
+ data["cmasks"] = data["cmasks"].view(-1, t_len)
208
+
209
+ data = self.to_ctx(data)
210
+ outputs = model(**data)
211
+ outputs.update(data)
212
+ self(outputs)
213
+ return self.finalize(output_file)
214
+
215
+ def __call__(self, sample):
216
+ hidden_size = sample["pooled_video"].size(-1)
217
+ pooled_video = sample["pooled_video"].view(-1, 5, hidden_size)
218
+ pooled_text = sample["pooled_text"].view(-1, 5, hidden_size)
219
+ scores = torch.bmm(pooled_video, pooled_text.transpose(2, 1))
220
+ scores = scores.argmax(-1)
221
+ self._append_scores(scores[:, 0], sample["answers"], self.full_scores)
222
+
223
+ def finalize(self, output_file=None):
224
+ outputs, targets = self._aggregate_scores(self.full_scores)
225
+ if output_file is not None:
226
+ np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
227
+ return {"outputs": outputs, "targets": targets}
228
+
229
+ def _append_scores(self, scores, answers, full_scores):
230
+ if len(full_scores) == 0:
231
+ full_scores.append([])
232
+ full_scores.append([])
233
+ full_scores[0].append(scores.cpu().detach().numpy())
234
+ full_scores[1].append(answers.cpu().detach().numpy())
235
+
236
+ def _aggregate_scores(self, scores):
237
+ assert len(scores) == 2
238
+ outputs = np.concatenate(scores[0], axis=0)
239
+ targets = np.concatenate(scores[1], axis=0)
240
+ # clear up.
241
+ self.full_scores = []
242
+ return outputs, targets
243
+
244
+
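The reshapes in `QAPredictor.__call__` can be hard to follow; conceptually each video is scored against its 5 candidate answers and the argmax over candidates is kept. A toy sketch with illustrative shapes:

    import torch

    batch_size, hidden_size = 2, 8
    pooled_video = torch.randn(batch_size * 5, hidden_size)   # video was duplicated 5x upstream
    pooled_text = torch.randn(batch_size * 5, hidden_size)    # 5 candidate answers per question

    video = pooled_video.view(-1, 5, hidden_size)
    text = pooled_text.view(-1, 5, hidden_size)
    scores = torch.bmm(video, text.transpose(2, 1))            # (batch, 5, 5)
    pred = scores.argmax(-1)[:, 0]                             # row 0: the single video vs. its 5 answers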
245
+ class CrossTaskPredictor(Predictor):
246
+ """
247
+ CrossTaskPredictor needs to compute the average of logits
248
+ over overlapping sliding windows.
249
+ """
250
+ def __init__(self, config):
251
+ super().__init__(config)
252
+ self.lsm = torch.nn.LogSoftmax(dim=1)
253
+ self.max_video_len = config.dataset.max_video_len
254
+ self.sliding_window = config.dataset.sliding_window
255
+ self.sliding_window_size = config.dataset.sliding_window_size
256
+ self.annotation_path = config.dataset.annotation_path
257
+
258
+ def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
259
+ """refactored from line 144:
260
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py
261
+ """
262
+ ctx = 0
263
+ model.eval()
264
+ model = model.to(ctx)
265
+ # this is not a loss but just compute neg_log_prob.
266
+ Y_pred = {}
267
+ Y_true = {}
268
+ with torch.no_grad():
269
+ for batch in eval_dataloader:
270
+ self(batch, model, Y_pred, Y_true)
271
+ return self.finalize(Y_pred, Y_true, output_file)
272
+
273
+ def __call__(self, sample, model, Y_pred, Y_true):
274
+ # please install dp from `https://github.com/DmZhukov/CrossTask`
275
+ from dp import dp
276
+ vid, task = sample['video_id'][0], sample['task'][0]
277
+ sample = self.to_ctx(sample)
278
+ # compute the average logits over sliding windows.
279
+ output = model(**sample)
280
+ batch_logits = output["logits"].cpu()
281
+
282
+ video_len = sample["video_len"][0]
283
+
284
+ # the following version is slow.
285
+ logits = torch.zeros((video_len, batch_logits.size(1)))
286
+ logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
287
+ # use the same loop as aligner to recover.
288
+ batch_logit_idx = 0
289
+ for window_start in range(0, video_len, self.sliding_window):
290
+ video_end = min(video_len - window_start, self.sliding_window_size)
291
+ logits[window_start: window_start + video_end] += batch_logits[
292
+ batch_logit_idx: batch_logit_idx + video_end]
293
+ batch_logit_idx += video_end
294
+ logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
295
+
296
+ if (video_len - window_start) <= self.sliding_window_size:
297
+ break
298
+
299
+ logits /= logits_counts
300
+ assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
301
+
302
+ O = self.lsm(logits)
303
+ y = np.zeros(O.size(), dtype=np.float32)
304
+ dp(y, -O.detach().cpu().numpy())
305
+ if task not in Y_pred:
306
+ Y_pred[task] = {}
307
+ Y_pred[task][vid] = y
308
+ annot_path = os.path.join(
309
+ self.annotation_path, task+'_'+vid+'.csv')
310
+ if os.path.exists(annot_path):
311
+ if task not in Y_true:
312
+ Y_true[task] = {}
313
+ Y_true[task][vid] = self._read_assignment(
314
+ *y.shape, annot_path)
315
+
316
+ def finalize(self, Y_pred, Y_true, output_file=None):
317
+ if output_file is not None:
318
+ with open(
319
+ os.path.join(self.pred_dir, output_file + ".pkl"),
320
+ "wb") as fw:
321
+ pickle.dump(
322
+ {"Y_pred": Y_pred, "Y_true": Y_true}, fw,
323
+ protocol=pickle.HIGHEST_PROTOCOL)
324
+ return {"outputs": Y_pred, "targets": Y_true}
325
+
326
+ def _read_assignment(self, T, K, path):
327
+ """
328
+ refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py
329
+ How to interpret the constraints on the loss being minimized:
330
+ lambd is a big number;
331
+ self.lambd * C is a big number for every valid position (the csv stores the invalid ones)
332
+
333
+ def forward(self, O, Y, C):
334
+ return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum()
335
+
336
+ This loads the csv file and fills in the step column from the start row to the end row.
337
+ """
338
+
339
+ Y = np.zeros([T, K], dtype=np.uint8)
340
+ with open(path, 'r') as f:
341
+ for line in f:
342
+ step, start, end = line.strip().split(',')
343
+ start = int(math.floor(float(start)))
344
+ end = int(math.ceil(float(end)))
345
+ step = int(step) - 1
346
+ Y[start:end, step] = 1
347
+ return Y
348
+
349
+
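The window-merging loop in `CrossTaskPredictor.__call__` (reused by `COINPredictor._merge_windows` below) recovers one averaged logit row per frame from overlapping windows. A self-contained sketch of the same recovery; `sliding_window=2` and `sliding_window_size=4` are toy values chosen only for illustration:

    import torch

    video_len, num_classes = 6, 3
    sliding_window, sliding_window_size = 2, 4
    batch_logits = torch.randn(8, num_classes)    # 4 frames from window@0 + 4 frames from window@2

    logits = torch.zeros((video_len, num_classes))
    counts = torch.zeros((video_len, 1), dtype=torch.long)
    idx = 0
    for window_start in range(0, video_len, sliding_window):
        video_end = min(video_len - window_start, sliding_window_size)
        logits[window_start: window_start + video_end] += batch_logits[idx: idx + video_end]
        counts[window_start: window_start + video_end] += 1
        idx += video_end
        if (video_len - window_start) <= sliding_window_size:
            break
    logits /= counts                              # frames covered by both windows are averaged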
350
+ class COINPredictor(Predictor):
351
+ """
352
+ COINPredictor is similar to CrossTaskPredictor in how it merges sliding windows.
353
+ """
354
+ def __init__(self, config):
355
+ super().__init__(config)
356
+ self.max_video_len = config.dataset.max_video_len
357
+ self.sliding_window = config.dataset.sliding_window
358
+ self.sliding_window_size = config.dataset.sliding_window_size
359
+
360
+ def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
361
+ """refactored from line 144:
362
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py
363
+ """
364
+ ctx = 0
365
+ model.eval()
366
+ model = model.to(ctx)
367
+ # this is not a loss but just compute neg_log_prob.
368
+ Y_pred = []
369
+ Y_true = []
370
+ with torch.no_grad():
371
+ for batch in eval_dataloader:
372
+ self(batch, model, Y_pred, Y_true)
373
+ return self.finalize(Y_pred, Y_true, output_file)
374
+
375
+ def __call__(self, sample, model, Y_pred, Y_true):
376
+ sample = self.to_ctx(sample)
377
+ # compute the average logits over sliding windows.
378
+ output = model(**sample)
379
+ logits = self._merge_windows(sample, output)
380
+ Y_pred.append(logits.argmax(dim=1))
381
+ Y_true.append(sample["video_targets"].squeeze(0).cpu())
382
+
383
+ def _merge_windows(self, sample, output):
384
+ targets = sample["targets"].reshape(-1).cpu()
385
+ valid_mask = targets != -100
386
+ targets = targets[valid_mask]
387
+ batch_logits = output["logits"].cpu()
388
+ batch_logits = batch_logits.reshape(-1, batch_logits.size(-1))
389
+ batch_logits = batch_logits[valid_mask]
390
+
391
+ video_len = sample["video_len"][0]
392
+
393
+ # the following version is slow.
394
+ logits = torch.zeros((video_len, batch_logits.size(1)))
395
+ logits_counts = torch.zeros((video_len, 1), dtype=torch.long)
396
+ # use the same loop as aligner to recover.
397
+ batch_logit_idx = 0
398
+ for window_start in range(0, video_len, self.sliding_window):
399
+ video_end = min(video_len - window_start, self.sliding_window_size)
400
+ logits[window_start: window_start + video_end] += batch_logits[
401
+ batch_logit_idx: batch_logit_idx + video_end]
402
+ batch_logit_idx += video_end
403
+ logits_counts[window_start: window_start + video_end] += torch.ones((video_end, 1), dtype=torch.long)
404
+ if (video_len - window_start) <= self.sliding_window_size:
405
+ break
406
+ logits /= logits_counts
407
+ assert logits.size() == (video_len, batch_logits.size(1)), "{}, {}".format(logits.size(), video_len)
408
+ return logits
409
+
410
+ def finalize(self, Y_pred, Y_true, output_file=None):
411
+ Y_pred = torch.cat(Y_pred, dim=0).numpy()
412
+ Y_true = torch.cat(Y_true, dim=0).numpy()
413
+ assert len(Y_pred) == len(Y_true)
414
+
415
+ error_mask = Y_pred != Y_true
416
+ print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
417
+ print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
418
+
419
+ if output_file is not None:
420
+ with open(
421
+ os.path.join(self.pred_dir, output_file + ".pkl"),
422
+ "wb") as fw:
423
+ pickle.dump(
424
+ {"Y_pred": Y_pred, "Y_true": Y_true}, fw,
425
+ protocol=pickle.HIGHEST_PROTOCOL)
426
+ return {"outputs": Y_pred, "targets": Y_true}
427
+
428
+
429
+ class COINZSPredictor(COINPredictor):
430
+ """
431
+ COINZSPredictor for COIN zero-shot prediction.
432
+ """
433
+
434
+ def __init__(self, config):
435
+ super().__init__(config)
436
+ self.dataset_config = config.dataset
437
+
438
+ def predict_loop(self, model, eval_dataloader, output_file="result.pkl"):
439
+ """refactored from line 144:
440
+ https://github.com/DmZhukov/CrossTask/blob/master/train.py
441
+ """
442
+ ctx = 0
443
+ model.eval()
444
+ model = model.to(ctx)
445
+
446
+ with torch.no_grad():
447
+ outputs = eval_dataloader.dataset.meta_processor.meta_text_labels(
448
+ self.dataset_config)
449
+ outputs = self.to_ctx(outputs, ctx)
450
+ label_hidden_states = model.forward_text(**outputs).cpu()
451
+ label_sim = label_hidden_states @ label_hidden_states.t()
452
+ num_labels = label_sim.size(0)
453
+ eye_mask = ~torch.eye(num_labels, dtype=torch.bool)
454
+ label_sim = label_sim.masked_select(eye_mask).view(num_labels, num_labels - 1)
455
+ lbd = label_sim.max()
456
+
457
+ # this is not a loss but just compute neg_log_prob.
458
+ Y_pred = []
459
+ Y_true = []
460
+ with torch.no_grad():
461
+ for batch in eval_dataloader:
462
+ self(batch, label_hidden_states, model, lbd, Y_pred, Y_true)
463
+ return self.finalize(Y_pred, Y_true, output_file)
464
+
465
+ def reshape_subsample(self, sample):
466
+ for key in sample:
467
+ if torch.is_tensor(sample[key]):
468
+ sample[key] = self.flat_subsample(sample[key])
469
+ return sample
470
+
471
+ def flat_subsample(self, tensor):
472
+ if len(tensor.size()) > 1 and tensor.size(0) == 1:
473
+ tensor = tensor.squeeze(0)
474
+ return tensor
475
+
476
+ def __call__(self, sample, label_hidden_states, model, lbd, Y_pred, Y_true):
477
+ sample = self.reshape_subsample(sample)
478
+ sample = self.to_ctx(sample)
479
+ # compute the average logits over sliding windows.
480
+ sample["output_hidden_states"] = True
481
+ video_outputs = model.forward_video(**sample).cpu()
482
+ output = {"logits": video_outputs[:, 1:sample["vmasks"].size(1)+1] @ label_hidden_states.t()}
483
+ logits = self._merge_windows(sample, output)
484
+ # logic of zero-shot for sequence labeling.
485
+ logits_argmax = logits.argmax(dim=1) + 1 # 0 is "O" label.
486
+ logits_max = logits.max(dim=1)[0]
487
+
488
+ pred = torch.zeros_like(logits_argmax)
489
+ label_select = logits_max > lbd # 73 or 74
490
+ pred[label_select] = logits_argmax[label_select]
491
+
492
+ Y_pred.append(pred)
493
+ Y_true.append(sample["video_targets"].squeeze(0).cpu())
494
+
495
+ def finalize(self, Y_pred, Y_true, output_file=None):
496
+ Y_pred = torch.cat(Y_pred, dim=0).numpy()
497
+ Y_true = torch.cat(Y_true, dim=0).numpy()
498
+ assert len(Y_pred) == len(Y_true)
499
+
500
+ error_mask = Y_pred != Y_true
501
+ print("sample error", Y_pred[error_mask][:10], Y_true[error_mask][:10])
502
+ print("sample error", Y_pred[error_mask][10:20], Y_true[error_mask][10:20])
503
+
504
+ if output_file is not None:
505
+ with open(
506
+ os.path.join(self.pred_dir, output_file + ".pkl"),
507
+ "wb") as fw:
508
+ pickle.dump(
509
+ {"Y_pred": Y_pred, "Y_true": Y_true}, fw,
510
+ protocol=pickle.HIGHEST_PROTOCOL)
511
+ return {"outputs": Y_pred, "targets": Y_true}
512
+
513
+
514
+ class DiDeMoPredictor(Predictor):
515
+ """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
516
+ https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
517
+ """
518
+ def __init__(self, config):
519
+ super().__init__(config)
520
+ # load targets.
521
+ with open(config.dataset.test_path) as data_file:
522
+ self.test_data = json.load(data_file)
523
+
524
+ def predict_loop(self, model, eval_dataloader, output_file="didemo.npy"):
525
+ """
526
+ TODO: two solutions here.
527
+ """
528
+ import itertools
529
+ # 21 candidate segments: 6 single chunks plus the 15 pairs added below.
530
+ self.possible_segments = [(0,0), (1,1), (2,2), (3,3), (4,4), (5,5)]
531
+ for i in itertools.combinations(range(6), 2):
532
+ self.possible_segments.append(i)
533
+ # pick segments from a video.
534
+
535
+ """on-the-fly prediction on a single gpu."""
536
+ self.full_scores = []
537
+ model.eval()
538
+ model = model.cuda()
539
+ with torch.no_grad():
540
+ for data in eval_dataloader:
541
+ # TODO special forwarding logic here.
542
+ data = self.to_ctx(data)
543
+ data["output_hidden_states"] = True
544
+ hidden_video = model.forward_video(**data)
545
+ data["output_hidden_states"] = False
546
+ pooled_text = model.forward_text(**data)
547
+ outputs = {
548
+ "hidden_video": hidden_video,
549
+ "pooled_text": pooled_text
550
+ }
551
+ outputs.update(data)
552
+ self(outputs)
553
+ return self.finalize(output_file)
554
+
555
+ def __call__(self, sample):
556
+ # TODO: make an index select from self.possible_segments.
557
+ hidden_video = sample["hidden_video"]
558
+ pooled_text = sample["pooled_text"]
559
+ vmasks = sample["vmasks"]
560
+ # probably maintain valid results here.
561
+
562
+ hidden_video = hidden_video[:, 1:-1, :]
563
+ # probably maintain valid results here.
564
+ pooled_video = []
565
+ for s, e in self.possible_segments:
566
+ pooled_video.append(
567
+ torch.mean(
568
+ hidden_video[:, int(s*5):int((e+1)*5), :],
569
+ dim=1, keepdim=True)
570
+ )
571
+ pooled_video = torch.cat(pooled_video, dim=1)
572
+ scores = torch.bmm(
573
+ pooled_video, pooled_text.unsqueeze(-1)).squeeze(-1).cpu()
574
+
575
+ ranks = scores.argsort(dim=-1, descending=True)
576
+
577
+ for batch_idx, rank in enumerate(ranks):
578
+ rank_of_moment = []
579
+ for m_idx, moment in enumerate(rank):
580
+ s, e = self.possible_segments[moment.item()]
581
+ if torch.any(
582
+ vmasks[batch_idx, int(s*5):int((e+1)*5)]
583
+ ):
584
+ rank_of_moment.append((s, e))
585
+ self.full_scores.append(rank_of_moment)
586
+
587
+ def finalize(self, output_file=None):
588
+ outputs = self._aggregate_scores(self.full_scores)
589
+ if output_file is not None:
590
+ np.save(os.path.join(self.pred_dir, output_file + ".npy"), outputs)
591
+ return {"outputs": outputs, "targets": self.test_data}
592
+
593
+ def _aggregate_scores(self, scores):
594
+ self.full_scores = []
595
+ return scores
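End to end, a predictor and its matching metric compose roughly as below; the `config`, `model`, and `eval_dataloader` objects are assumed to come from the surrounding MMPT task setup and are not constructed here:

    predictor = QAPredictor(config)
    results = predictor.predict_loop(model, eval_dataloader, output_file=None)  # skip saving

    metric = QAMetric(config)
    computed = metric.compute_metrics(results["outputs"], results["targets"])
    metric.print_computed_metrics(computed)       # e.g. "acc: 0.7321" (value illustrative)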
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from .loss import *
6
+ from .nce import *
7
+
8
+ try:
9
+ from .fairseqmmloss import *
10
+ except ImportError:
11
+ pass
12
+
13
+ try:
14
+ from .expnce import *
15
+ except ImportError:
16
+ pass
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/fairseqmmloss.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ TODO (huxu): a general fairseq criterion for all your pre-defined losses.
8
+ """
9
+
10
+ from fairseq.criterions import FairseqCriterion, register_criterion
11
+ from fairseq import metrics
12
+
13
+
14
+ @register_criterion("mmloss")
15
+ class MMCriterion(FairseqCriterion):
16
+ def __init__(self, task):
17
+ super().__init__(task)
18
+ # TODO (huxu): wrap forward call of loss_fn and eval_fn into task.
19
+ self.mmtask = task.mmtask
20
+
21
+ def forward(self, model, sample):
22
+ """Compute the loss for the given sample.
23
+ Returns a tuple with three elements:
24
+ 1) the loss
25
+ 2) the sample size, which is used as the denominator for the gradient
26
+ 3) logging outputs to display while training
27
+ """
28
+ outputs = self.mmtask(model, sample)
29
+
30
+ loss, loss_scalar, max_len, batch_size, sample_size = (
31
+ outputs["loss"],
32
+ outputs["loss_scalar"],
33
+ outputs["max_len"],
34
+ outputs["batch_size"],
35
+ outputs["sample_size"],
36
+ )
37
+
38
+ logging_output = {
39
+ "loss": loss_scalar,
40
+ "ntokens": max_len * batch_size, # dummy report.
41
+ "nsentences": batch_size, # dummy report.
42
+ "sample_size": sample_size,
43
+ }
44
+
45
+ return loss, 1, logging_output
46
+
47
+ @staticmethod
48
+ def reduce_metrics(logging_outputs) -> None:
49
+ """Aggregate logging outputs from data parallel training."""
50
+ """since we use NCE, our actual batch_size is 1 per GPU.
51
+ We then take the mean over workers."""
52
+ loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
53
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
54
+ metrics.log_scalar("loss", loss_sum / sample_size, round=3)
55
+
56
+ @staticmethod
57
+ def logging_outputs_can_be_summed() -> bool:
58
+ """
59
+ Whether the logging outputs returned by `forward` can be summed
60
+ across workers prior to calling `reduce_metrics`. Setting this
61
+ to True will improves distributed training speed.
62
+ """
63
+ return True
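The aggregation in `reduce_metrics` is a plain sample-size-weighted mean over workers; a quick plain-Python check with illustrative numbers:

    logging_outputs = [
        {"loss": 0.9, "sample_size": 1},   # worker 0
        {"loss": 0.7, "sample_size": 1},   # worker 1
    ]
    loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
    sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
    assert abs(loss_sum / sample_size - 0.8) < 1e-9   # reported via metrics.log_scalar("loss", ...)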
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/loss.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) Facebook, Inc. All Rights Reserved
2
+
3
+ import torch
4
+
5
+ from torch import nn
6
+
7
+
8
+ class Loss(object):
9
+ def __call__(self, *args, **kwargs):
10
+ raise NotImplementedError
11
+
12
+
13
+ # Dummy Loss for testing.
14
+ class DummyLoss(Loss):
15
+ def __init__(self):
16
+ self.loss = nn.CrossEntropyLoss()
17
+
18
+ def __call__(self, logits, targets, **kwargs):
19
+ return self.loss(logits, targets)
20
+
21
+
22
+ class DummyK400Loss(Loss):
23
+ """dummy k400 loss for MViT."""
24
+ def __init__(self):
25
+ self.loss = nn.CrossEntropyLoss()
26
+
27
+ def __call__(self, logits, targets, **kwargs):
28
+ return self.loss(
29
+ logits, torch.randint(0, 400, (logits.size(0),), device=logits.device))
30
+
31
+
32
+ class CrossEntropy(Loss):
33
+ def __init__(self):
34
+ self.loss = nn.CrossEntropyLoss()
35
+
36
+ def __call__(self, logits, targets, **kwargs):
37
+ return self.loss(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
38
+
39
+
40
+ class ArgmaxCrossEntropy(Loss):
41
+ def __init__(self):
42
+ self.loss = nn.CrossEntropyLoss()
43
+
44
+ def __call__(self, logits, targets, **kwargs):
45
+ return self.loss(logits, targets.argmax(dim=1))
46
+
47
+
48
+ class BCE(Loss):
49
+ def __init__(self):
50
+ self.loss = nn.BCEWithLogitsLoss()
51
+
52
+ def __call__(self, logits, targets, **kwargs):
53
+ targets = targets.squeeze(0)
54
+ return self.loss(logits, targets)
55
+
56
+
57
+ class NLGLoss(Loss):
58
+ def __init__(self):
59
+ self.loss = nn.CrossEntropyLoss()
60
+
61
+ def __call__(self, logits, text_label, **kwargs):
62
+ targets = text_label[text_label != -100]
63
+ return self.loss(logits, targets)
64
+
65
+
66
+ class MSE(Loss):
67
+ def __init__(self):
68
+ self.loss = nn.MSELoss()
69
+
70
+ def __call__(self, logits, targets, **kwargs):
71
+ return self.loss(logits, targets)
72
+
73
+
74
+ class L1(Loss):
75
+ def __init__(self):
76
+ self.loss = nn.L1Loss()
77
+
78
+ def __call__(self, logits, targets, **kwargs):
79
+ return self.loss(logits, targets)
80
+
81
+
82
+ class SmoothL1(Loss):
83
+ def __init__(self):
84
+ self.loss = nn.SmoothL1Loss()
85
+
86
+ def __call__(self, logits, targets, **kwargs):
87
+ return self.loss(logits, targets)
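A short sketch of how these thin wrappers are invoked; the tensors are random and purely illustrative:

    import torch

    logits = torch.randn(4, 10)
    targets = torch.randint(0, 10, (4,))

    print(CrossEntropy()(logits, targets))
    one_hot = torch.nn.functional.one_hot(targets, num_classes=10).float()
    print(ArgmaxCrossEntropy()(logits, one_hot))   # argmax over the one-hot recovers the same targets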
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/losses/nce.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ softmax-based NCE loss, used by this project.
8
+ """
9
+
10
+ import torch
11
+
12
+ from torch import nn
13
+
14
+ from .loss import Loss
15
+
16
+
17
+ class NCE(Loss):
18
+ def __init__(self):
19
+ # TODO (huxu): define temperature.
20
+ self.loss = nn.CrossEntropyLoss()
21
+
22
+ def __call__(self, align_scores, **kargs):
23
+ # note: we reuse the same shape as cls head in BERT (batch_size, 2)
24
+ # but NCE only needs one logits.
25
+ # (so we drop all weights in the second neg logits.)
26
+ align_scores = align_scores[:, :1]
27
+ # duplicate negative examples
28
+ batch_size = align_scores.size(0) // 2
29
+ pos_scores = align_scores[:batch_size]
30
+ neg_scores = align_scores[batch_size:].view(1, batch_size).repeat(
31
+ batch_size, 1)
32
+ scores = torch.cat([pos_scores, neg_scores], dim=1)
33
+ return self.loss(
34
+ scores,
35
+ torch.zeros(
36
+ (batch_size,),
37
+ dtype=torch.long,
38
+ device=align_scores.device),
39
+ )
40
+
41
+
42
+ class T2VContraLoss(Loss):
43
+ """NCE for MM joint space, on softmax text2video matrix.
44
+ """
45
+ def __init__(self):
46
+ # TODO (huxu): define temperature.
47
+ self.loss = nn.CrossEntropyLoss()
48
+
49
+ def __call__(self, pooled_video, pooled_text, **kargs):
50
+ batch_size = pooled_video.size(0)
51
+ logits = torch.mm(pooled_text, pooled_video.transpose(1, 0))
52
+ targets = torch.arange(
53
+ batch_size,
54
+ dtype=torch.long,
55
+ device=pooled_video.device)
56
+ return self.loss(logits, targets)
57
+
58
+
59
+ class V2TContraLoss(Loss):
60
+ """NCE for MM joint space, with softmax on video2text matrix."""
61
+
62
+ def __init__(self):
63
+ # TODO (huxu): define temperature.
64
+ self.loss = nn.CrossEntropyLoss()
65
+
66
+ def __call__(self, pooled_video, pooled_text, **kargs):
67
+ batch_size = pooled_video.size(0)
68
+ logits = torch.mm(pooled_video, pooled_text.transpose(1, 0))
69
+ targets = torch.arange(
70
+ batch_size,
71
+ dtype=torch.long,
72
+ device=pooled_video.device)
73
+ return self.loss(logits, targets)
74
+
75
+
76
+ class MMContraLoss(Loss):
77
+ def __init__(self):
78
+ self.loss = nn.CrossEntropyLoss()
79
+
80
+ def __call__(self, pooled_video, pooled_text, **kwargs):
81
+ logits_per_video = pooled_video @ pooled_text.t()
82
+ logits_per_text = pooled_text @ pooled_video.t()
83
+
84
+ targets = torch.arange(
85
+ pooled_video.size(0),
86
+ dtype=torch.long,
87
+ device=pooled_video.device)
88
+ loss_video = self.loss(logits_per_video, targets)
89
+ loss_text = self.loss(logits_per_text, targets)
90
+ return loss_video + loss_text
91
+
92
+
93
+ class MTM(Loss):
94
+ """Combination of MFM and MLM."""
95
+
96
+ def __init__(self):
97
+ self.loss = nn.CrossEntropyLoss()
98
+
99
+ def __call__(
100
+ self,
101
+ video_logits,
102
+ text_logits,
103
+ video_label,
104
+ text_label,
105
+ **kwargs
106
+ ):
107
+ text_logits = torch.cat([
108
+ text_logits,
109
+ torch.zeros(
110
+ (text_logits.size(0), 1), device=text_logits.device)
111
+ ], dim=1)
112
+ vt_logits = torch.cat([video_logits, text_logits], dim=0)
113
+ # loss for video.
114
+ video_label = torch.zeros(
115
+ (video_logits.size(0),),
116
+ dtype=torch.long,
117
+ device=video_logits.device
118
+ )
119
+
120
+ # loss for text.
121
+ text_label = text_label.reshape(-1)
122
+ labels_mask = text_label != -100
123
+ selected_text_label = text_label[labels_mask]
124
+
125
+ vt_label = torch.cat([video_label, selected_text_label], dim=0)
126
+ return self.loss(vt_logits, vt_label)
127
+
128
+
129
+ class MFMMLM(Loss):
130
+ """Combination of MFM and MLM."""
131
+
132
+ def __init__(self):
133
+ self.loss = nn.CrossEntropyLoss()
134
+
135
+ def __call__(
136
+ self,
137
+ video_logits,
138
+ text_logits,
139
+ video_label,
140
+ text_label,
141
+ **kwargs
142
+ ):
143
+ # loss for video.
144
+ video_label = torch.zeros(
145
+ (video_logits.size(0),),
146
+ dtype=torch.long,
147
+ device=video_logits.device
148
+ )
149
+ masked_frame_loss = self.loss(video_logits, video_label)
150
+
151
+ # loss for text.
152
+ text_label = text_label.reshape(-1)
153
+ labels_mask = text_label != -100
154
+ selected_text_label = text_label[labels_mask]
155
+ masked_lm_loss = self.loss(text_logits, selected_text_label)
156
+ return masked_frame_loss + masked_lm_loss
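All of the contrastive losses above treat the i-th video and i-th text in a batch as the positive pair, i.e. the diagonal of the similarity matrix; a minimal sketch with random embeddings:

    import torch

    pooled_video = torch.randn(8, 768)
    pooled_text = torch.randn(8, 768)

    print(T2VContraLoss()(pooled_video, pooled_text))   # softmax over videos for each text
    print(V2TContraLoss()(pooled_video, pooled_text))   # softmax over texts for each video
    print(MMContraLoss()(pooled_video=pooled_video, pooled_text=pooled_text))  # symmetric sum

Note that none of these define a learnable temperature; the code above flags that as a TODO.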
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from .mmfusion import *
6
+ from .transformermodel import *
7
+ from .mmfusionnlg import *
8
+
9
+ try:
10
+ from .fairseqmmmodel import *
11
+ except ImportError:
12
+ pass
13
+
14
+ try:
15
+ from .expmmfusion import *
16
+ except ImportError:
17
+ pass
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/fairseqmmmodel.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq.models import (
7
+ BaseFairseqModel,
8
+ register_model,
9
+ register_model_architecture
10
+ )
11
+
12
+
13
+ @register_model("mmmodel")
14
+ class FairseqMMModel(BaseFairseqModel):
15
+ """a fairseq wrapper of model built by `task`."""
16
+
17
+ @classmethod
18
+ def build_model(cls, args, task):
19
+ return FairseqMMModel(task.mmtask.model)
20
+
21
+ def __init__(self, mmmodel):
22
+ super().__init__()
23
+ self.mmmodel = mmmodel
24
+
25
+ def forward(self, *args, **kwargs):
26
+ return self.mmmodel(*args, **kwargs)
27
+
28
+ def upgrade_state_dict_named(self, state_dict, name):
29
+
30
+ super().upgrade_state_dict_named(state_dict, name)
31
+
32
+ keys_to_delete = []
33
+
34
+ for key in state_dict:
35
+ if key not in self.state_dict():
36
+ keys_to_delete.append(key)
37
+ for key in keys_to_delete:
38
+ print("[INFO]", key, "not used anymore.")
39
+ del state_dict[key]
40
+
41
+ # copy any newly defined parameters.
42
+ for key in self.state_dict():
43
+ if key not in state_dict:
44
+ print("[INFO] adding", key)
45
+ state_dict[key] = self.state_dict()[key]
46
+
47
+
48
+ # a dummy arch, we config the model.
49
+ @register_model_architecture("mmmodel", "mmarch")
50
+ def mmarch(args):
51
+ pass
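The `upgrade_state_dict_named` hook above amounts to a key diff between the loaded checkpoint and the freshly built module; schematically, with toy dictionaries and illustrative key names:

    checkpoint_state = {"mmmodel.old_head.weight": 0, "mmmodel.encoder.weight": 1}
    current_state = {"mmmodel.new_head.weight": 2, "mmmodel.encoder.weight": 1}

    stale = [k for k in checkpoint_state if k not in current_state]    # logged "not used anymore", then deleted
    missing = [k for k in current_state if k not in checkpoint_state]  # logged "adding", copied from the new module
    assert stale == ["mmmodel.old_head.weight"] and missing == ["mmmodel.new_head.weight"]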
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/mmfusion.py ADDED
@@ -0,0 +1,926 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Copyright (c) Facebook, Inc. All Rights Reserved
17
+
18
+
19
+ import torch
20
+
21
+ from torch import nn
22
+
23
+ try:
24
+ from transformers import AutoConfig, AutoTokenizer
25
+ except ImportError:
26
+ pass
27
+
28
+ from . import transformermodel
29
+
30
+
31
+ class MMPTModel(nn.Module):
32
+ """An e2e wrapper of inference model.
33
+ """
34
+ @classmethod
35
+ def from_pretrained(cls, config, checkpoint="checkpoint_best.pt"):
36
+ import os
37
+ from ..utils import recursive_config
38
+ from ..tasks import Task
39
+ config = recursive_config(config)
40
+ mmtask = Task.config_task(config)
41
+ checkpoint_path = os.path.join(config.eval.save_path, checkpoint)
42
+ mmtask.build_model(checkpoint=checkpoint_path)
43
+ # TODO(huxu): make the video encoder configurable.
44
+ from ..processors.models.s3dg import S3D
45
+ video_encoder = S3D('pretrained_models/s3d_dict.npy', 512)
46
+ video_encoder.load_state_dict(
47
+ torch.load('pretrained_models/s3d_howto100m.pth'))
48
+ from transformers import AutoTokenizer
49
+ tokenizer = AutoTokenizer.from_pretrained(
50
+ config.dataset.bert_name, use_fast=config.dataset.use_fast
51
+ )
52
+ from ..processors import Aligner
53
+ aligner = Aligner(config.dataset)
54
+ return (
55
+ MMPTModel(config, mmtask.model, video_encoder),
56
+ tokenizer,
57
+ aligner
58
+ )
59
+
60
+ def __init__(self, config, model, video_encoder, **kwargs):
61
+ super().__init__()
62
+ self.max_video_len = config.dataset.max_video_len
63
+ self.video_encoder = video_encoder
64
+ self.model = model
65
+
66
+ def forward(self, video_frames, caps, cmasks, return_score=False):
67
+ bsz = video_frames.size(0)
68
+ assert bsz == 1, "only bsz=1 is supported now."
69
+ seq_len = video_frames.size(1)
70
+ video_frames = video_frames.view(-1, *video_frames.size()[2:])
71
+ vfeats = self.video_encoder(video_frames.permute(0, 4, 1, 2, 3))
72
+ vfeats = vfeats['video_embedding']
73
+ vfeats = vfeats.view(bsz, seq_len, vfeats.size(-1))
74
+ padding = torch.zeros(
75
+ bsz, self.max_video_len - seq_len, vfeats.size(-1))
76
+ vfeats = torch.cat([vfeats, padding], dim=1)
77
+ vmasks = torch.cat([
78
+ torch.ones((bsz, seq_len), dtype=torch.bool),
79
+ torch.zeros((bsz, self.max_video_len - seq_len), dtype=torch.bool)
80
+ ],
81
+ dim=1
82
+ )
83
+ output = self.model(caps, cmasks, vfeats, vmasks)
84
+ if return_score:
85
+ output = {"score": torch.bmm(
86
+ output["pooled_video"][:, None, :],
87
+ output["pooled_text"][:, :, None]
88
+ ).squeeze(-1).squeeze(-1)}
89
+ return output
90
+
91
+
92
+ class MMFusion(nn.Module):
93
+ """a MMPT wrapper class for MMBert style models.
94
+ TODO: move isolated mask to a subclass.
95
+ """
96
+ def __init__(self, config, **kwargs):
97
+ super().__init__()
98
+ transformer_config = AutoConfig.from_pretrained(
99
+ config.dataset.bert_name)
100
+ self.hidden_size = transformer_config.hidden_size
101
+ self.is_train = False
102
+ if config.dataset.train_path is not None:
103
+ self.is_train = True
104
+ # 0 means no iso; 1-12 means iso up to that layer.
105
+ self.num_hidden_layers = transformer_config.num_hidden_layers
106
+ self.last_iso_layer = 0
107
+ if config.dataset.num_iso_layer is not None:
108
+ self.last_iso_layer = config.dataset.num_iso_layer - 1 + 1
109
+
110
+ if config.model.mm_encoder_cls is not None:
111
+ mm_encoder_cls = getattr(transformermodel, config.model.mm_encoder_cls)
112
+ model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
113
+ model_config.max_video_len = config.dataset.max_video_len
114
+ # TODO: a general way to add parameter for a model.
115
+ model_config.use_seg_emb = config.model.use_seg_emb
116
+ self.mm_encoder = mm_encoder_cls.from_pretrained(
117
+ config.dataset.bert_name, config=model_config)
118
+ elif config.model.video_encoder_cls is not None\
119
+ and config.model.text_encoder_cls is not None:
120
+ video_encoder_cls = getattr(transformermodel, config.model.video_encoder_cls)
121
+ model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
122
+ model_config.max_video_len = config.dataset.max_video_len
123
+ # TODO: make each model a set of config class.
124
+ if hasattr(model_config, "num_layers"):
125
+ model_config.num_layers = config.model.num_hidden_video_layers
126
+ else:
127
+ model_config.num_hidden_layers = config.model.num_hidden_video_layers
128
+ self.video_encoder = video_encoder_cls.from_pretrained(
129
+ config.dataset.bert_name, config=model_config)
130
+ # exact same NLP model from Huggingface.
131
+ text_encoder_cls = getattr(transformermodel, config.model.text_encoder_cls)
132
+ self.text_encoder = text_encoder_cls.from_pretrained(
133
+ config.dataset.bert_name)
134
+ else:
135
+ raise ValueError("the encoder must be either MM or two backbones.")
136
+
137
+ def forward(
138
+ self,
139
+ caps,
140
+ cmasks,
141
+ vfeats,
142
+ vmasks,
143
+ **kwargs
144
+ ):
145
+ raise NotImplementedError(
146
+ "Please derive MMFusion module."
147
+ )
148
+
149
+ def _mm_on_the_fly(
150
+ self,
151
+ cmasks,
152
+ vmasks,
153
+ attention_mask
154
+ ):
155
+ """helper function for mask, seg_ids and token_type_ids."""
156
+ if attention_mask is None:
157
+ attention_mask = self._mm_attention_mask(cmasks, vmasks)
158
+
159
+ """
160
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
161
+ | first sequence | second sequence |
162
+ """
163
+ token_type_ids = torch.cat(
164
+ [
165
+ torch.zeros(
166
+ (vmasks.size(0), vmasks.size(1) + 2),
167
+ dtype=torch.long,
168
+ device=vmasks.device,
169
+ ),
170
+ torch.ones(
171
+ (cmasks.size(0), cmasks.size(1) - 2),
172
+ dtype=torch.long,
173
+ device=cmasks.device,
174
+ ),
175
+ ],
176
+ dim=1,
177
+ )
178
+ return attention_mask, token_type_ids
179
+
180
+ def _mm_attention_mask(self, cmasks, vmasks):
181
+ assert cmasks.size(0) == vmasks.size(0), "{}, {}, {}, {}".format(
182
+ str(cmasks.size()),
183
+ str(vmasks.size()),
184
+ str(cmasks.size(0)),
185
+ str(vmasks.size(0)),
186
+ )
187
+
188
+ mm_mask = torch.cat([cmasks[:, :1], vmasks, cmasks[:, 1:]], dim=1)
189
+ if self.last_iso_layer == 0:
190
+ # hard attention mask.
191
+ return mm_mask
192
+ else:
193
+ # a gpu iso mask; 0 : num_iso_layer is isolated;
194
+ # num_iso_layer: are MM-fused.
195
+ # make an iso layer
196
+ batch_size = cmasks.size(0)
197
+ iso_mask = self._make_iso_mask(batch_size, cmasks, vmasks)
198
+ mm_mask = mm_mask[:, None, :].repeat(1, mm_mask.size(-1), 1)
199
+ iso_mm_masks = []
200
+ # hard attention mask.
201
+ iso_mask = iso_mask[:, None, :, :].repeat(
202
+ 1, self.last_iso_layer, 1, 1)
203
+ iso_mm_masks.append(iso_mask)
204
+ if self.last_iso_layer < self.num_hidden_layers:
205
+ mm_mask = mm_mask[:, None, :, :].repeat(
206
+ 1, self.num_hidden_layers - self.last_iso_layer, 1, 1
207
+ )
208
+ iso_mm_masks.append(mm_mask)
209
+ iso_mm_masks = torch.cat(iso_mm_masks, dim=1)
210
+ return iso_mm_masks
211
+
212
+ def _make_iso_mask(self, batch_size, cmasks, vmasks):
213
+ cls_self_mask = torch.cat(
214
+ [
215
+ torch.ones(
216
+ (batch_size, 1), dtype=torch.bool, device=cmasks.device),
217
+ torch.zeros(
218
+ (batch_size, cmasks.size(1) + vmasks.size(1) - 1),
219
+ dtype=torch.bool, device=cmasks.device)
220
+ ], dim=1)
221
+
222
+ iso_video_mask = torch.cat(
223
+ [
224
+ # [CLS] is not used.
225
+ torch.zeros(
226
+ (batch_size, 1), dtype=torch.bool, device=cmasks.device
227
+ ),
228
+ vmasks,
229
+ # assume to be 1.
230
+ cmasks[:, 1:2],
231
+ # 2 means [CLS] + [SEP]
232
+ torch.zeros(
233
+ (batch_size, cmasks.size(1) - 2),
234
+ dtype=torch.bool,
235
+ device=cmasks.device,
236
+ ),
237
+ ],
238
+ dim=1,
239
+ )
240
+ iso_text_mask = torch.cat(
241
+ [
242
+ torch.zeros(
243
+ (batch_size, 2 + vmasks.size(1)),
244
+ dtype=torch.bool,
245
+ device=cmasks.device,
246
+ ), # [CLS] is not used.
247
+ cmasks[:, 2:], # assume to be 1.
248
+ ],
249
+ dim=1,
250
+ )
251
+ cls_self_mask = cls_self_mask[:, None, :]
252
+ iso_video_mask = iso_video_mask[:, None, :].repeat(
253
+ 1, vmasks.size(1) + 1, 1)
254
+ iso_text_mask = iso_text_mask[:, None, :].repeat(
255
+ 1, cmasks.size(1) - 2, 1)
256
+ return torch.cat([cls_self_mask, iso_video_mask, iso_text_mask], dim=1)
257
+
258
+ def _pooling_vt_layer(
259
+ self,
260
+ layered_sequence_output,
261
+ cmasks,
262
+ vmasks
263
+ ):
264
+ layer_idx = self.last_iso_layer \
265
+ if self.last_iso_layer > 0 else self.num_hidden_layers
266
+ hidden_state = layered_sequence_output[layer_idx]
267
+ # also output pooled_video and pooled_text.
268
+ batch_size = cmasks.size(0)
269
+ # pool the modality.
270
+ text_offset = vmasks.size(1) + 2 # [CLS] + [SEP]
271
+ # video tokens + [SEP]
272
+ video_outputs = hidden_state[:, 1:text_offset]
273
+ video_attention_mask = torch.cat(
274
+ [
275
+ vmasks,
276
+ torch.ones(
277
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
278
+ ],
279
+ dim=1,
280
+ )
281
+ assert video_outputs.size(1) == video_attention_mask.size(1)
282
+ pooled_video = torch.sum(
283
+ video_outputs * video_attention_mask.unsqueeze(-1), dim=1
284
+ ) / video_attention_mask.sum(1, keepdim=True)
285
+ # pooled_video = torch.mean(video_outputs[0], dim=1)
286
+
287
+ # text tokens + [SEP]
288
+ text_attention_mask = cmasks[:, 2:]
289
+ text_outputs = hidden_state[:, text_offset:]
290
+ assert text_outputs.size(1) == text_attention_mask.size(1)
291
+ pooled_text = torch.sum(
292
+ text_outputs * text_attention_mask.unsqueeze(-1), dim=1
293
+ ) / text_attention_mask.sum(1, keepdim=True)
294
+ return pooled_video, pooled_text
295
+
296
+
297
+ class MMFusionMFMMLM(MMFusion):
298
+ """forward function for MFM and MLM."""
299
+ def forward(
300
+ self,
301
+ caps,
302
+ cmasks,
303
+ vfeats,
304
+ vmasks,
305
+ attention_mask=None,
306
+ video_label=None,
307
+ text_label=None,
308
+ **kwargs
309
+ ):
310
+ output_hidden_states = False if self.is_train else True
311
+
312
+ target_vfeats, non_masked_frame_mask = None, None
313
+ if video_label is not None:
314
+ target_vfeats = vfeats.masked_select(
315
+ video_label.unsqueeze(-1)).view(
316
+ -1, vfeats.size(-1)
317
+ )
318
+ # mask video token.
319
+ vfeats[video_label] = 0.0
320
+ non_masked_frame_mask = vmasks.clone()
321
+ non_masked_frame_mask[video_label] = False
322
+
323
+ attention_mask, token_type_ids = self._mm_on_the_fly(
324
+ cmasks, vmasks, attention_mask)
325
+
326
+ outputs = self.mm_encoder(
327
+ input_ids=caps,
328
+ input_video_embeds=vfeats,
329
+ attention_mask=attention_mask,
330
+ token_type_ids=token_type_ids,
331
+ masked_frame_labels=video_label,
332
+ target_video_hidden_states=target_vfeats,
333
+ non_masked_frame_mask=non_masked_frame_mask,
334
+ masked_lm_labels=text_label,
335
+ output_hidden_states=output_hidden_states,
336
+ )
337
+
338
+ video_logits, text_logits = outputs[0], outputs[1]
339
+
340
+ if self.is_train: # return earlier for training.
341
+ return {
342
+ "video_logits": video_logits,
343
+ "text_logits": text_logits,
344
+ }
345
+
346
+ pooled_video, pooled_text = self._pooling_vt_layer(
347
+ outputs[2], cmasks, vmasks)
348
+ return {"pooled_video": pooled_video, "pooled_text": pooled_text}
349
+
350
+
351
+ class MMFusionMTM(MMFusionMFMMLM):
352
+ def __init__(self, config, **kwargs):
353
+ super().__init__(config)
354
+ """
355
+ For reproducibility:
356
+ self.mm_encoder will be initialized then discarded.
357
+ """
358
+ from .transformermodel import MMBertForMTM
359
+ model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
360
+ model_config.max_video_len = config.dataset.max_video_len
361
+ model_config.use_seg_emb = config.model.use_seg_emb
362
+ self.mm_encoder = MMBertForMTM.from_pretrained(
363
+ config.dataset.bert_name, config=model_config)
364
+
365
+
366
+ class MMFusionShare(MMFusion):
367
+ """A retrival wrapper using mm_encoder as both video/text backbone.
368
+ TODO: move formally.
369
+ """
370
+ def forward(
371
+ self,
372
+ caps,
373
+ cmasks,
374
+ vfeats,
375
+ vmasks,
376
+ attention_mask=None,
377
+ video_label=None,
378
+ text_label=None,
379
+ output_hidden_states=False,
380
+ **kwargs
381
+ ):
382
+ pooled_video = self.forward_video(
383
+ vfeats,
384
+ vmasks,
385
+ caps,
386
+ cmasks,
387
+ output_hidden_states
388
+ )
389
+
390
+ pooled_text = self.forward_text(
391
+ caps,
392
+ cmasks,
393
+ output_hidden_states
394
+ )
395
+
396
+ return {"pooled_video": pooled_video, "pooled_text": pooled_text}
397
+
398
+ def forward_video(
399
+ self,
400
+ vfeats,
401
+ vmasks,
402
+ caps,
403
+ cmasks,
404
+ output_hidden_states=False,
405
+ **kwargs
406
+ ):
407
+ input_ids = caps[:, :2]
408
+
409
+ attention_mask = torch.cat([
410
+ cmasks[:, :1],
411
+ vmasks,
412
+ cmasks[:, 1:2]
413
+ ], dim=1)
414
+
415
+ token_type_ids = torch.zeros(
416
+ (vmasks.size(0), vmasks.size(1) + 2),
417
+ dtype=torch.long,
418
+ device=vmasks.device)
419
+
420
+ outputs = self.mm_encoder(
421
+ input_ids=input_ids,
422
+ input_video_embeds=vfeats,
423
+ attention_mask=attention_mask,
424
+ token_type_ids=token_type_ids,
425
+ output_hidden_states=True
426
+ )
427
+ video_outputs = outputs[0]
428
+
429
+ if output_hidden_states:
430
+ return video_outputs
431
+
432
+ batch_size = cmasks.size(0)
433
+
434
+ video_attention_mask = torch.cat(
435
+ [
436
+ torch.zeros(
437
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
438
+ vmasks,
439
+ torch.ones(
440
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
441
+ ],
442
+ dim=1,
443
+ )
444
+ assert video_outputs.size(1) == video_attention_mask.size(1)
445
+
446
+ video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
447
+ / video_attention_mask.sum(1, keepdim=True)
448
+
449
+ pooled_video = torch.bmm(
450
+ video_outputs.transpose(2, 1),
451
+ video_attention_mask.unsqueeze(2)
452
+ ).squeeze(-1)
453
+ return pooled_video # video_outputs
454
+
455
+ def forward_text(
456
+ self,
457
+ caps,
458
+ cmasks,
459
+ output_hidden_states=False,
460
+ **kwargs
461
+ ):
462
+ input_ids = torch.cat([
463
+ caps[:, :1], caps[:, 2:],
464
+ ], dim=1)
465
+
466
+ attention_mask = torch.cat([
467
+ cmasks[:, :1],
468
+ cmasks[:, 2:]
469
+ ], dim=1)
470
+
471
+ token_type_ids = torch.cat([
472
+ torch.zeros(
473
+ (cmasks.size(0), 1),
474
+ dtype=torch.long,
475
+ device=cmasks.device),
476
+ torch.ones(
477
+ (cmasks.size(0), cmasks.size(1) - 2),
478
+ dtype=torch.long,
479
+ device=cmasks.device)
480
+ ], dim=1)
481
+
482
+ outputs = self.mm_encoder(
483
+ input_ids=input_ids,
484
+ input_video_embeds=None,
485
+ attention_mask=attention_mask,
486
+ token_type_ids=token_type_ids,
487
+ output_hidden_states=True
488
+ )
489
+ text_outputs = outputs[0]
490
+
491
+ if output_hidden_states:
492
+ return text_outputs
493
+
494
+ batch_size = caps.size(0)
495
+ # text tokens + [SEP]
496
+ text_attention_mask = torch.cat([
497
+ torch.zeros(
498
+ (batch_size, 1), dtype=torch.bool, device=cmasks.device),
499
+ cmasks[:, 2:]
500
+ ], dim=1)
501
+
502
+ assert text_outputs.size(1) == text_attention_mask.size(1)
503
+
504
+ text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
505
+ / text_attention_mask.sum(1, keepdim=True)
506
+
507
+ pooled_text = torch.bmm(
508
+ text_outputs.transpose(2, 1),
509
+ text_attention_mask.unsqueeze(2)
510
+ ).squeeze(-1)
511
+ return pooled_text # text_outputs
512
+
513
+
514
+ class MMFusionSeparate(MMFusionShare):
515
+ def forward_video(
516
+ self,
517
+ vfeats,
518
+ vmasks,
519
+ caps,
520
+ cmasks,
521
+ output_hidden_states=False,
522
+ **kwargs
523
+ ):
524
+ input_ids = caps[:, :2]
525
+
526
+ attention_mask = torch.cat([
527
+ cmasks[:, :1],
528
+ vmasks,
529
+ cmasks[:, 1:2]
530
+ ], dim=1)
531
+
532
+ token_type_ids = torch.zeros(
533
+ (vmasks.size(0), vmasks.size(1) + 2),
534
+ dtype=torch.long,
535
+ device=vmasks.device)
536
+
537
+ outputs = self.video_encoder(
538
+ input_ids=input_ids,
539
+ input_video_embeds=vfeats,
540
+ attention_mask=attention_mask,
541
+ token_type_ids=token_type_ids,
542
+ output_hidden_states=True
543
+ )
544
+ video_outputs = outputs[0]
545
+
546
+ if output_hidden_states:
547
+ return video_outputs
548
+
549
+ batch_size = cmasks.size(0)
550
+
551
+ video_attention_mask = torch.cat(
552
+ [
553
+ torch.zeros(
554
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
555
+ vmasks,
556
+ torch.ones(
557
+ (batch_size, 1), dtype=torch.bool, device=vmasks.device),
558
+ ],
559
+ dim=1,
560
+ )
561
+ assert video_outputs.size(1) == video_attention_mask.size(1)
562
+
563
+ video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
564
+ / video_attention_mask.sum(1, keepdim=True)
565
+
566
+ pooled_video = torch.bmm(
567
+ video_outputs.transpose(2, 1),
568
+ video_attention_mask.unsqueeze(2)
569
+ ).squeeze(-1)
570
+ return pooled_video # video_outputs
571
+
572
+ def forward_text(
573
+ self,
574
+ caps,
575
+ cmasks,
576
+ output_hidden_states=False,
577
+ **kwargs
578
+ ):
579
+ input_ids = torch.cat([
580
+ caps[:, :1], caps[:, 2:],
581
+ ], dim=1)
582
+
583
+ attention_mask = torch.cat([
584
+ cmasks[:, :1],
585
+ cmasks[:, 2:]
586
+ ], dim=1)
587
+ # unlike the shared encoder, we use all-zero token type ids.
588
+ token_type_ids = torch.zeros(
589
+ (cmasks.size(0), cmasks.size(1) - 1),
590
+ dtype=torch.long,
591
+ device=cmasks.device)
592
+
593
+ outputs = self.text_encoder(
594
+ input_ids=input_ids,
595
+ attention_mask=attention_mask,
596
+ token_type_ids=token_type_ids,
597
+ output_hidden_states=True
598
+ )
599
+ text_outputs = outputs[0]
600
+
601
+ if output_hidden_states:
602
+ return text_outputs
603
+
604
+ batch_size = caps.size(0)
605
+ # text tokens + [SEP]
606
+ text_attention_mask = torch.cat([
607
+ torch.zeros(
608
+ (batch_size, 1), dtype=torch.bool, device=cmasks.device),
609
+ cmasks[:, 2:]
610
+ ], dim=1)
611
+
612
+ assert text_outputs.size(1) == text_attention_mask.size(1)
613
+
614
+ text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
615
+ / text_attention_mask.sum(1, keepdim=True)
616
+
617
+ pooled_text = torch.bmm(
618
+ text_outputs.transpose(2, 1),
619
+ text_attention_mask.unsqueeze(2)
620
+ ).squeeze(-1)
621
+ return pooled_text # text_outputs
622
+
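With separate encoders, each clip and each caption collapses to a single pooled vector, so retrieval reduces to a similarity matrix between the two batches. A hedged sketch of the scoring step only; whether scores are normalized is decided by the training criterion, so the cosine here is purely illustrative:

import torch
import torch.nn.functional as F

# pooled_video, pooled_text: (batch, hidden_size), e.g. from forward_video / forward_text
pooled_video = torch.randn(4, 768)
pooled_text = torch.randn(4, 768)

video_n = F.normalize(pooled_video, dim=-1)
text_n = F.normalize(pooled_text, dim=-1)
scores = video_n @ text_n.t()              # (num_videos, num_texts)
best_text_per_video = scores.argmax(dim=1) # text-to-video retrieval would argmax dim=0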
623
+
624
+ class MMFusionJoint(MMFusion):
625
+ """fine-tuning wrapper for retrival task."""
626
+
627
+ def forward(
628
+ self,
629
+ caps,
630
+ cmasks,
631
+ vfeats,
632
+ vmasks,
633
+ attention_mask=None,
634
+ video_label=None,
635
+ text_label=None,
636
+ **kwargs
637
+ ):
638
+ # TODO (huxu): other ways to do negative examples; move the following
639
+ # into your criterion forward.
640
+ output_hidden_states = True
641
+
642
+ attention_mask, token_type_ids = self._mm_on_the_fly(
643
+ cmasks, vmasks, attention_mask)
644
+
645
+ separate_forward_split = (
646
+ None if self.is_train else vmasks.size(1) + 2
647
+ ) # [CLS] + [SEP]
648
+
649
+ outputs = self.mm_encoder(
650
+ input_ids=caps,
651
+ input_video_embeds=vfeats,
652
+ attention_mask=attention_mask,
653
+ token_type_ids=token_type_ids,
654
+ output_hidden_states=output_hidden_states,
655
+ separate_forward_split=separate_forward_split,
656
+ )
657
+
658
+ pooled_video, pooled_text = self._pooling_vt_layer(
659
+ outputs[2], cmasks, vmasks)
660
+ return {"pooled_video": pooled_video, "pooled_text": pooled_text}
661
+
662
+
663
+ class MMFusionActionSegmentation(MMFusion):
664
+ """Fine-tuning wrapper for action segmentation.
665
+ TODO: rename this for VLM.
666
+ """
667
+ def forward(
668
+ self,
669
+ caps,
670
+ cmasks,
671
+ vfeats,
672
+ vmasks,
673
+ attention_mask=None,
674
+ **kwargs
675
+ ):
676
+ # Like ActionLocalization, this assumes batch_size=1; flatten the leading dims.
677
+ caps = caps.view(-1, caps.size(-1))
678
+ cmasks = cmasks.view(-1, cmasks.size(-1))
679
+ vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
680
+ vmasks = vmasks.view(-1, vmasks.size(-1))
681
+
682
+ # this may not cover all shapes of attention_mask.
683
+ attention_mask = attention_mask.view(
684
+ -1, attention_mask.size(2), attention_mask.size(3)) \
685
+ if attention_mask is not None else None
686
+
687
+ # TODO (huxu): other ways to do negative examples; move the following
688
+ # into your criterion forward.
689
+ output_hidden_states = True
690
+
691
+ # video forwarding, text is dummy; never use attention_mask.
692
+ attention_mask, token_type_ids = self._mm_on_the_fly(
693
+ cmasks, vmasks, attention_mask)
694
+
695
+ logits = self.mm_encoder(
696
+ input_ids=caps,
697
+ input_video_embeds=vfeats,
698
+ attention_mask=attention_mask,
699
+ token_type_ids=token_type_ids,
700
+ output_hidden_states=output_hidden_states,
701
+ )
702
+ return {"logits": logits[0][:, 1:vmasks.size(1)+1]}
703
+
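The wrapper returns one logit vector per video token; the slice 1:vmasks.size(1)+1 drops the leading [CLS] position so the logits align with frames. A small sketch of consuming that output, assuming (as is typical for this head) that the last dimension indexes action classes:

import torch

logits = torch.randn(2, 32, 48)            # (batch, num_video_tokens, num_classes), illustrative
vmasks = torch.ones(2, 32, dtype=torch.bool)

frame_preds = logits.argmax(dim=-1)        # (batch, num_video_tokens)
valid_preds = frame_preds[vmasks]          # keep only non-padded frames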
704
+
705
+ class MMFusionActionLocalization(MMFusion):
706
+ """fine-tuning model for retrival task."""
707
+
708
+ def __init__(self, config, **kwargs):
709
+ super().__init__(config)
710
+ tokenizer = AutoTokenizer.from_pretrained(
711
+ config.dataset.bert_name)
712
+ self.cls_token_id = tokenizer.cls_token_id
713
+ self.sep_token_id = tokenizer.sep_token_id
714
+ self.pad_token_id = tokenizer.pad_token_id
715
+
716
+ def forward(
717
+ self,
718
+ caps,
719
+ cmasks,
720
+ vfeats,
721
+ vmasks,
722
+ attention_mask=None,
723
+ **kwargs
724
+ ):
725
+ # ActionLocalization assumes batch_size=1; squeeze it.
726
+ caps = caps.squeeze(0)
727
+ cmasks = cmasks.squeeze(0)
728
+ vfeats = vfeats.squeeze(0)
729
+ vmasks = vmasks.squeeze(0)
730
+ attention_mask = attention_mask.squeeze(0) if attention_mask is not None else None
731
+
732
+ # TODO (huxu): other ways to do negative examples; move the following
733
+ # into your criterion forward.
734
+ output_hidden_states = True
735
+
736
+ # a len1 dummy video token.
737
+ dummy_vfeats = torch.zeros(
738
+ (caps.size(0), 1, vfeats.size(-1)), device=vfeats.device, dtype=vfeats.dtype)
739
+ dummy_vmasks = torch.ones(
740
+ (caps.size(0), 1), dtype=torch.bool,
741
+ device=vfeats.device)
742
+
743
+ dummy_caps = torch.LongTensor(
744
+ [[self.cls_token_id, self.sep_token_id,
745
+ self.pad_token_id, self.sep_token_id]],
746
+ ).to(caps.device).repeat(vfeats.size(0), 1)
747
+ dummy_cmasks = torch.BoolTensor(
748
+ [[0, 1, 0, 1]] # pad are valid for attention.
749
+ ).to(caps.device).repeat(vfeats.size(0), 1)
750
+
751
+ # video forwarding, text is dummy; never use attention_mask.
752
+ attention_mask, token_type_ids = self._mm_on_the_fly(
753
+ dummy_cmasks, vmasks, None)
754
+
755
+ outputs = self.mm_encoder(
756
+ input_ids=dummy_caps,
757
+ input_video_embeds=vfeats,
758
+ attention_mask=attention_mask,
759
+ token_type_ids=token_type_ids,
760
+ output_hidden_states=output_hidden_states,
761
+ )
762
+
763
+ layer_idx = self.last_iso_layer \
764
+ if self.last_iso_layer > 0 else self.num_hidden_layers
765
+
766
+ video_seq = outputs[2][layer_idx][:, 1:vmasks.size(1)+1].masked_select(
767
+ vmasks.unsqueeze(-1)
768
+ ).view(-1, self.hidden_size)
769
+
770
+ # text forwarding, video is dummy
771
+ attention_mask, token_type_ids = self._mm_on_the_fly(
772
+ cmasks, dummy_vmasks, None)
773
+
774
+ outputs = self.mm_encoder(
775
+ input_ids=caps,
776
+ input_video_embeds=dummy_vfeats,
777
+ attention_mask=attention_mask,
778
+ token_type_ids=token_type_ids,
779
+ output_hidden_states=output_hidden_states,
780
+ )
781
+
782
+ _, pooled_text = self._pooling_vt_layer(
783
+ outputs[2], cmasks, dummy_vmasks)
784
+ # this line is not right.
785
+ logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
786
+ return {"logits": logits}
787
+
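Here torch.mm scores every valid video token against every pooled caption, so the returned logits form a (num_video_tokens, num_texts) grid; localization then picks, per frame, the candidate text with the highest score. A small sketch with illustrative shapes:

import torch

video_seq = torch.randn(120, 768)          # valid video tokens after masked_select
pooled_text = torch.randn(10, 768)         # one pooled vector per candidate action text

logits = torch.mm(video_seq, pooled_text.transpose(1, 0))  # (120, 10)
per_frame_action = logits.argmax(dim=1)                    # best-matching text per frame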
788
+
789
+ # --------------- MMFusionSeparate for end tasks ---------------
790
+
791
+ class MMFusionSeparateActionSegmentation(MMFusionSeparate):
792
+ """Fine-tuning wrapper for action segmentation."""
793
+ def forward(
794
+ self,
795
+ caps,
796
+ cmasks,
797
+ vfeats,
798
+ vmasks,
799
+ attention_mask=None,
800
+ **kwargs
801
+ ):
802
+ # Like ActionLocalization, this assumes batch_size=1; flatten the leading dims.
803
+ caps = caps.view(-1, caps.size(-1))
804
+ cmasks = cmasks.view(-1, cmasks.size(-1))
805
+ vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
806
+ vmasks = vmasks.view(-1, vmasks.size(-1))
807
+ logits = self.forward_video(
808
+ vfeats,
809
+ vmasks,
810
+ caps,
811
+ cmasks,
812
+ output_hidden_states=True
813
+ )
814
+ return {"logits": logits[:, 1:vmasks.size(1)+1]}
815
+
816
+
817
+ class MMFusionSeparateActionLocalization(MMFusionSeparate):
818
+ def __init__(self, config, **kwargs):
819
+ super().__init__(config)
820
+ tokenizer = AutoTokenizer.from_pretrained(
821
+ config.dataset.bert_name)
822
+ self.cls_token_id = tokenizer.cls_token_id
823
+ self.sep_token_id = tokenizer.sep_token_id
824
+ self.pad_token_id = tokenizer.pad_token_id
825
+
826
+ def forward(
827
+ self,
828
+ caps,
829
+ cmasks,
830
+ vfeats,
831
+ vmasks,
832
+ **kwargs
833
+ ):
834
+ # ActionLocalization assumes batch_size=1; squeeze it.
835
+ caps = caps.squeeze(0)
836
+ cmasks = cmasks.squeeze(0)
837
+ vfeats = vfeats.squeeze(0)
838
+ vmasks = vmasks.squeeze(0)
839
+
840
+ # TODO (huxu): other ways to do negative examples; move the following
841
+ # into your criterion forward.
842
+ dummy_caps = torch.LongTensor(
843
+ [[self.cls_token_id, self.sep_token_id,
844
+ self.pad_token_id, self.sep_token_id]],
845
+ ).to(caps.device).repeat(vfeats.size(0), 1)
846
+ dummy_cmasks = torch.BoolTensor(
847
+ [[0, 1, 0, 1]] # pad are valid for attention.
848
+ ).to(caps.device).repeat(vfeats.size(0), 1)
849
+
850
+ outputs = self.forward_video(
851
+ vfeats,
852
+ vmasks,
853
+ dummy_caps,
854
+ dummy_cmasks,
855
+ output_hidden_states=True
856
+ )
857
+
858
+ video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
859
+ vmasks.unsqueeze(-1)
860
+ ).view(-1, self.hidden_size)
861
+
862
+ pooled_text = self.forward_text(
863
+ caps,
864
+ cmasks,
865
+ output_hidden_states=False
866
+ )
867
+
868
+ # this line is not right.
869
+ logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
870
+ return {"logits": logits}
871
+
872
+
873
+ class MMFusionShareActionLocalization(MMFusionShare):
874
+ def __init__(self, config, **kwargs):
875
+ super().__init__(config)
876
+ tokenizer = AutoTokenizer.from_pretrained(
877
+ config.dataset.bert_name)
878
+ self.cls_token_id = tokenizer.cls_token_id
879
+ self.sep_token_id = tokenizer.sep_token_id
880
+ self.pad_token_id = tokenizer.pad_token_id
881
+
882
+ def forward(
883
+ self,
884
+ caps,
885
+ cmasks,
886
+ vfeats,
887
+ vmasks,
888
+ **kwargs
889
+ ):
890
+ # ActionLocalization assumes batch_size=1; squeeze it.
891
+ caps = caps.squeeze(0)
892
+ cmasks = cmasks.squeeze(0)
893
+ vfeats = vfeats.squeeze(0)
894
+ vmasks = vmasks.squeeze(0)
895
+
896
+ # TODO (huxu): other ways to do negative examples; move the following
897
+ # into your criterion forward.
898
+ dummy_caps = torch.LongTensor(
899
+ [[self.cls_token_id, self.sep_token_id,
900
+ self.pad_token_id, self.sep_token_id]],
901
+ ).to(caps.device).repeat(vfeats.size(0), 1)
902
+ dummy_cmasks = torch.BoolTensor(
903
+ [[0, 1, 0, 1]] # pad are valid for attention.
904
+ ).to(caps.device).repeat(vfeats.size(0), 1)
905
+
906
+ outputs = self.forward_video(
907
+ vfeats,
908
+ vmasks,
909
+ dummy_caps,
910
+ dummy_cmasks,
911
+ output_hidden_states=True
912
+ )
913
+
914
+ video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
915
+ vmasks.unsqueeze(-1)
916
+ ).view(-1, self.hidden_size)
917
+
918
+ pooled_text = self.forward_text(
919
+ caps,
920
+ cmasks,
921
+ output_hidden_states=False
922
+ )
923
+
924
+ # this line is not right.
925
+ logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
926
+ return {"logits": logits}
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/mmfusionnlg.py ADDED
@@ -0,0 +1,999 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Copyright (c) Facebook, Inc. All Rights Reserved
17
+
18
+
19
+ import torch
20
+
21
+ from torch.nn import functional as F
22
+
23
+ from typing import Optional, Iterable
24
+
25
+ try:
26
+ from transformers import BertPreTrainedModel
27
+ from transformers.modeling_bert import BertOnlyMLMHead
28
+
29
+ from transformers.file_utils import ModelOutput
30
+ from transformers.modeling_outputs import CausalLMOutput
31
+ from transformers.generation_utils import (
32
+ BeamHypotheses,
33
+ top_k_top_p_filtering
34
+ )
35
+ except ImportError:
36
+ pass
37
+
38
+ from .mmfusion import MMFusion
39
+ from .transformermodel import MMBertModel
40
+ from ..modules import VideoTokenMLP
41
+
42
+
43
+ class MMFusionNLG(MMFusion):
44
+ def __init__(self, config, **kwargs):
45
+ super().__init__(config)
46
+ if config.model.max_decode_length is not None:
47
+ self.max_length = min(
48
+ config.model.max_decode_length,
49
+ config.dataset.max_len - config.dataset.max_video_len - 3
50
+ )
51
+ else:
52
+ self.max_length = \
53
+ config.dataset.max_len - config.dataset.max_video_len - 3
54
+ self.gen_param = config.gen_param if config.gen_param is not None \
55
+ else {}
56
+
57
+ def forward(
58
+ self,
59
+ caps,
60
+ cmasks,
61
+ vfeats,
62
+ vmasks,
63
+ attention_mask,
64
+ video_label=None,
65
+ text_label=None,
66
+ **kwargs
67
+ ):
68
+ """use pre-trained LM header for generation."""
69
+ attention_mask, token_type_ids = self._mm_on_the_fly(
70
+ cmasks, vmasks, attention_mask)
71
+
72
+ outputs = self.mm_encoder(
73
+ input_ids=caps,
74
+ input_video_embeds=vfeats,
75
+ attention_mask=attention_mask,
76
+ token_type_ids=token_type_ids,
77
+ masked_lm_labels=text_label,
78
+ )
79
+ return {"logits": outputs[0]}
80
+
81
+ @torch.no_grad()
82
+ def generate(
83
+ self,
84
+ caps, cmasks, vfeats, vmasks,
85
+ attention_mask=None,
86
+ bos_token_id=None,
87
+ eos_token_id=None,
88
+ **kwargs
89
+ ):
90
+ # a simplified interface from
91
+ # https://huggingface.co/transformers/v3.4.0/_modules/transformers/generation_utils.html#GenerationMixin.generate
92
+
93
+ # caps now only contains
94
+ # [CLS], [SEP] (for video) and [CLS] (as bos_token)
95
+ assert caps.size(1) == 3
96
+
97
+ attention_mask, token_type_ids = self._mm_on_the_fly(
98
+ cmasks, vmasks, attention_mask)
99
+
100
+ output = self.mm_encoder.generate(
101
+ input_ids=caps,
102
+ input_video_embeds=vfeats,
103
+ attention_mask=attention_mask,
104
+ token_type_ids=token_type_ids,
105
+ bos_token_id=bos_token_id,
106
+ eos_token_id=eos_token_id,
107
+ max_length=self.max_length,
108
+ **self.gen_param
109
+ )
110
+ return output
111
+
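Because the video embeddings occupy the front of the sequence, the text prompt handed to generate is just three special tokens: [CLS], the [SEP] that closes the video span, and a second [CLS] acting as the decoding BOS. A hedged sketch of building such a prompt with a BERT tokenizer; the model construction, the video features, and the exact bos/eos choices come from the task config and are assumed here:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch_size = 1

caps = torch.LongTensor(
    [[tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.cls_token_id]]
).repeat(batch_size, 1)                    # shape (batch, 3), matching the assert above
cmasks = torch.ones(batch_size, 3, dtype=torch.bool)

# generated = model.generate(caps, cmasks, vfeats, vmasks,
#                            bos_token_id=tokenizer.cls_token_id,
#                            eos_token_id=tokenizer.sep_token_id)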
112
+
113
+ class MMBertForNLG(BertPreTrainedModel):
114
+ def __init__(self, config):
115
+ super().__init__(config)
116
+ self.bert = MMBertModel(config)
117
+ self.videomlp = VideoTokenMLP(config)
118
+ # we do not use `BertGenerationOnlyLMHead`
119
+ # because we can reuse pretraining.
120
+ self.cls = BertOnlyMLMHead(config)
121
+ self.hidden_size = config.hidden_size
122
+ self.init_weights()
123
+
124
+ def get_output_embeddings(self):
125
+ return self.cls.predictions.decoder
126
+
127
+ def forward(
128
+ self,
129
+ input_ids=None,
130
+ input_video_embeds=None,
131
+ attention_mask=None,
132
+ token_type_ids=None,
133
+ position_ids=None,
134
+ head_mask=None,
135
+ inputs_embeds=None,
136
+ masked_lm_labels=None,
137
+ output_attentions=None,
138
+ output_hidden_states=None,
139
+ return_dict=None,
140
+ ):
141
+ # similar to MMBertForMFMMLM without MFM.
142
+ video_tokens = self.videomlp(input_video_embeds)
143
+ outputs = self.bert(
144
+ input_ids,
145
+ video_tokens,
146
+ attention_mask=attention_mask,
147
+ token_type_ids=token_type_ids,
148
+ position_ids=position_ids,
149
+ head_mask=head_mask,
150
+ inputs_embeds=inputs_embeds,
151
+ output_attentions=output_attentions,
152
+ output_hidden_states=output_hidden_states,
153
+ return_dict=return_dict,
154
+ )
155
+
156
+ sequence_output = outputs[0]
157
+
158
+ prediction_scores = None
159
+ if masked_lm_labels is not None:
160
+ text_offset = input_video_embeds.size(1) + 1 # [CLS]
161
+ # recover caps format: [CLS] [SEP] text [SEP]
162
+ text_sequence_output = torch.cat(
163
+ [sequence_output[:, :1], sequence_output[:, text_offset:]],
164
+ dim=1
165
+ )
166
+
167
+ # only compute the selected (labeled) tokens during training to speed up.
168
+ hidden_size = text_sequence_output.size(-1)
169
+ # masked_lm_labels = masked_lm_labels.reshape(-1)
170
+ labels_mask = masked_lm_labels != -100
171
+
172
+ selected_text_output = text_sequence_output.masked_select(
173
+ labels_mask.unsqueeze(-1)
174
+ ).view(-1, hidden_size)
175
+ prediction_scores = self.cls(selected_text_output)
176
+
177
+ if not return_dict:
178
+ output = (
179
+ prediction_scores,
180
+ ) + outputs[2:]
181
+ return output
182
+
183
+ # for generation.
184
+ text_offset = input_video_embeds.size(1) + 2 # [CLS]
185
+ text_sequence_output = sequence_output[:, text_offset:]
186
+ prediction_scores = self.cls(text_sequence_output)
187
+ return CausalLMOutput(
188
+ loss=None,
189
+ logits=prediction_scores,
190
+ )
191
+
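During training the LM head is evaluated only at positions whose label is not -100, which keeps the full-vocabulary projection off the untouched tokens. A minimal sketch of that selection pattern outside the model, with illustrative shapes:

import torch

hidden = torch.randn(2, 10, 768)           # text-side hidden states
labels = torch.full((2, 10), -100)
labels[0, 3] = 1012                        # pretend two positions carry targets
labels[1, 7] = 2051

labels_mask = labels != -100               # (2, 10) bool
selected = hidden.masked_select(labels_mask.unsqueeze(-1)).view(-1, 768)
targets = labels.masked_select(labels_mask)

# loss = torch.nn.functional.cross_entropy(lm_head(selected), targets)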
192
+ def prepare_inputs_for_generation(
193
+ self,
194
+ input_ids,
195
+ input_video_embeds,
196
+ attention_mask=None,
197
+ token_type_ids=None,
198
+ **model_kwargs
199
+ ):
200
+ # must return a dictionary.
201
+ seq_len = input_ids.size(1) + input_video_embeds.size(1)
202
+ if attention_mask is not None:
203
+ if len(attention_mask.size()) == 4:
204
+ attention_mask = attention_mask[:, :, :seq_len, :seq_len]
205
+ elif len(attention_mask.size()) == 3:
206
+ attention_mask = attention_mask[:, :seq_len, :seq_len]
207
+ else:
208
+ attention_mask = attention_mask[:, :seq_len]
209
+ if token_type_ids is not None:
210
+ token_type_ids = token_type_ids[:, :seq_len]
211
+
212
+ return {
213
+ "input_ids": input_ids,
214
+ "input_video_embeds": input_video_embeds,
215
+ "attention_mask": attention_mask,
216
+ "token_type_ids": token_type_ids,
217
+ }
218
+
219
+ @torch.no_grad()
220
+ def generate(
221
+ self,
222
+ input_ids: Optional[torch.LongTensor] = None,
223
+ decoder_input_ids: Optional[torch.LongTensor] = None,
224
+ max_length: Optional[int] = None,
225
+ min_length: Optional[int] = None,
226
+ do_sample: Optional[bool] = None,
227
+ early_stopping: Optional[bool] = None,
228
+ num_beams: Optional[int] = None,
229
+ temperature: Optional[float] = None,
230
+ top_k: Optional[int] = None,
231
+ top_p: Optional[float] = None,
232
+ repetition_penalty: Optional[float] = None,
233
+ bad_words_ids: Optional[Iterable[int]] = None,
234
+ bos_token_id: Optional[int] = None,
235
+ pad_token_id: Optional[int] = None,
236
+ eos_token_id: Optional[int] = None,
237
+ length_penalty: Optional[float] = None,
238
+ no_repeat_ngram_size: Optional[int] = None,
239
+ num_return_sequences: Optional[int] = None,
240
+ attention_mask: Optional[torch.LongTensor] = None,
241
+ decoder_start_token_id: Optional[int] = None,
242
+ use_cache: Optional[bool] = None,
243
+ **model_kwargs
244
+ ) -> torch.LongTensor:
245
+ r"""
246
+ Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
247
+ beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
248
+ Adapted in part from `Facebook's XLM beam search code
249
+ <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
250
+ Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
251
+ attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
252
+ indicated are the default values of those config.
253
+ Most of these parameters are explained in more detail in `this blog post
254
+ <https://huggingface.co/blog/how-to-generate>`__.
255
+ Parameters:
256
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
257
+ The sequence used as a prompt for the generation. If :obj:`None` the method initializes
258
+ it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
259
+ decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
260
+ initial input_ids for the decoder of encoder-decoder type models. If :obj:`None` then only
261
+ decoder_start_token_id is passed as the first token to the decoder.
262
+ max_length (:obj:`int`, `optional`, defaults to 20):
263
+ The maximum length of the sequence to be generated.
264
+ min_length (:obj:`int`, `optional`, defaults to 10):
265
+ The minimum length of the sequence to be generated.
266
+ do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
267
+ Whether or not to use sampling ; use greedy decoding otherwise.
268
+ early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
269
+ Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
270
+ num_beams (:obj:`int`, `optional`, defaults to 1):
271
+ Number of beams for beam search. 1 means no beam search.
272
+ temperature (:obj:`float`, `optional`, defaults to 1.0):
273
+ The value used to module the next token probabilities.
274
+ top_k (:obj:`int`, `optional`, defaults to 50):
275
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
276
+ top_p (:obj:`float`, `optional`, defaults to 1.0):
277
+ If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or
278
+ higher are kept for generation.
279
+ repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
280
+ The parameter for repetition penalty. 1.0 means no penalty. See `this paper
281
+ <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
282
+ pad_token_id (:obj:`int`, `optional`):
283
+ The id of the `padding` token.
284
+ bos_token_id (:obj:`int`, `optional`):
285
+ The id of the `beginning-of-sequence` token.
286
+ eos_token_id (:obj:`int`, `optional`):
287
+ The id of the `end-of-sequence` token.
288
+ length_penalty (:obj:`float`, `optional`, defaults to 1.0):
289
+ Exponential penalty to the length. 1.0 means no penalty.
290
+ Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
291
+ order to encourage the model to produce longer sequences.
292
+ no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
293
+ If set to int > 0, all ngrams of that size can only occur once.
294
+ bad_words_ids(:obj:`List[int]`, `optional`):
295
+ List of token ids that are not allowed to be generated. In order to get the tokens of the words that
296
+ should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
297
+ num_return_sequences(:obj:`int`, `optional`, defaults to 1):
298
+ The number of independently computed returned sequences for each element in the batch.
299
+ attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
300
+ Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
301
+ tokens that are not masked, and 0 for masked tokens.
302
+ If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.
303
+ `What are attention masks? <../glossary.html#attention-mask>`__
304
+ decoder_start_token_id (:obj:`int`, `optional`):
305
+ If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
306
+ use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
307
+ Whether or not the model should use the past last key/values attentions (if applicable to the model) to
308
+ speed up decoding.
309
+ model_kwargs:
310
+ Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
311
+ Return:
312
+ :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
313
+ The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
314
+ shorter if all batches finished early due to the :obj:`eos_token_id`.
315
+ Examples::
316
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
317
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
318
+ outputs = model.generate(max_length=40) # do greedy decoding
319
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
320
+ tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
321
+ model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
322
+ input_context = 'The dog'
323
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
324
+ outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
325
+ for i in range(3): # 3 output sequences were generated
326
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
327
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
328
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
329
+ input_context = 'The dog'
330
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
331
+ outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling
332
+ for i in range(3): # 3 output sequences were generated
333
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
334
+ tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
335
+ model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
336
+ input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
337
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
338
+ outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
339
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
340
+ tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
341
+ model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
342
+ input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl
343
+ bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
344
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
345
+ outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
346
+ """
347
+
348
+ # We cannot generate if the model does not have a LM head
349
+ if self.get_output_embeddings() is None:
350
+ raise AttributeError(
351
+ "You tried to generate sequences with a model that does not have a LM Head."
352
+ "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
353
+ )
354
+
355
+ max_length = max_length if max_length is not None else self.config.max_length
356
+ min_length = min_length if min_length is not None else self.config.min_length
357
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
358
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
359
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
360
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
361
+ temperature = temperature if temperature is not None else self.config.temperature
362
+ top_k = top_k if top_k is not None else self.config.top_k
363
+ top_p = top_p if top_p is not None else self.config.top_p
364
+ repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
365
+ bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
366
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
367
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
368
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
369
+ no_repeat_ngram_size = (
370
+ no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
371
+ )
372
+ bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
373
+ num_return_sequences = (
374
+ num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
375
+ )
376
+ decoder_start_token_id = (
377
+ decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
378
+ )
379
+
380
+ if input_ids is not None:
381
+ batch_size = input_ids.shape[0] # overridden by the input batch_size
382
+ else:
383
+ batch_size = 1
384
+
385
+ assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
386
+ assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
387
+ assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
388
+ assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
389
+ assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
390
+ assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
391
+ assert temperature > 0, "`temperature` should be strictly positive."
392
+ assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
393
+ assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
394
+ assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
395
+ assert input_ids is not None or (
396
+ isinstance(bos_token_id, int) and bos_token_id >= 0
397
+ ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
398
+ assert pad_token_id is None or (
399
+ isinstance(pad_token_id, int) and (pad_token_id >= 0)
400
+ ), "`pad_token_id` should be a positive integer."
401
+ assert (eos_token_id is None) or (
402
+ isinstance(eos_token_id, int) and (eos_token_id >= 0)
403
+ ), "`eos_token_id` should be a positive integer."
404
+ assert length_penalty > 0, "`length_penalty` should be strictly positive."
405
+ assert (
406
+ isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
407
+ ), "`no_repeat_ngram_size` should be a positive integer."
408
+ assert (
409
+ isinstance(num_return_sequences, int) and num_return_sequences > 0
410
+ ), "`num_return_sequences` should be a strictly positive integer."
411
+ assert (
412
+ bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
413
+ ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
414
+
415
+ if input_ids is None:
416
+ assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
417
+ "you should either supply a context to complete as `input_ids` input "
418
+ "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
419
+ )
420
+ input_ids = torch.full(
421
+ (batch_size, 1),
422
+ bos_token_id,
423
+ dtype=torch.long,
424
+ device=next(self.parameters()).device,
425
+ )
426
+ else:
427
+ assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
428
+
429
+ # not allow to duplicate outputs when greedy decoding
430
+ if do_sample is False:
431
+ if num_beams == 1:
432
+ # no_beam_search greedy generation conditions
433
+ assert (
434
+ num_return_sequences == 1
435
+ ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
436
+
437
+ else:
438
+ # beam_search greedy generation conditions
439
+ assert (
440
+ num_beams >= num_return_sequences
441
+ ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
442
+
443
+ # create attention mask if necessary
444
+ # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
445
+ if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
446
+ attention_mask = input_ids.ne(pad_token_id).long()
447
+ elif attention_mask is None:
448
+ attention_mask = input_ids.new_ones(input_ids.shape)
449
+
450
+ # set pad_token_id to eos_token_id if not set. Important that this is done after
451
+ # attention_mask is created
452
+ if pad_token_id is None and eos_token_id is not None:
453
+ print(
454
+ "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
455
+ )
456
+ pad_token_id = eos_token_id
457
+
458
+ # vocab size
459
+ if hasattr(self.config, "vocab_size"):
460
+ vocab_size = self.config.vocab_size
461
+ elif (
462
+ self.config.is_encoder_decoder
463
+ and hasattr(self.config, "decoder")
464
+ and hasattr(self.config.decoder, "vocab_size")
465
+ ):
466
+ vocab_size = self.config.decoder.vocab_size
467
+ else:
468
+ raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined")
469
+
470
+ # set effective batch size and effective batch multiplier according to do_sample
471
+ if do_sample:
472
+ effective_batch_size = batch_size * num_return_sequences
473
+ effective_batch_mult = num_return_sequences
474
+ else:
475
+ effective_batch_size = batch_size
476
+ effective_batch_mult = 1
477
+
478
+ if self.config.is_encoder_decoder:
479
+ if decoder_start_token_id is None:
480
+ # see if BOS token can be used for decoder_start_token_id
481
+ if bos_token_id is not None:
482
+ decoder_start_token_id = bos_token_id
483
+ elif (
484
+ hasattr(self.config, "decoder")
485
+ and hasattr(self.config.decoder, "bos_token_id")
486
+ and self.config.decoder.bos_token_id is not None
487
+ ):
488
+ decoder_start_token_id = self.config.decoder.bos_token_id
489
+ else:
490
+ raise ValueError(
491
+ "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
492
+ )
493
+
494
+ assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
495
+ assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
496
+
497
+ # get encoder and store encoder outputs
498
+ encoder = self.get_encoder()
499
+ encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
500
+
501
+ # Expand input ids if num_beams > 1 or num_return_sequences > 1
502
+ if num_return_sequences > 1 or num_beams > 1:
503
+ # TODO: make this a call-back function.
504
+ # input_ids=caps,
505
+ # input_video_embeds=vfeats,
506
+ # attention_mask=attention_mask,
507
+ # token_type_ids=token_type_ids,
508
+ input_video_embeds = model_kwargs.pop("input_video_embeds", None)
509
+ token_type_ids = model_kwargs.pop("token_type_ids", None)
510
+
511
+ input_ids_len = input_ids.shape[-1]
512
+ input_ids = input_ids.unsqueeze(1).expand(
513
+ batch_size, effective_batch_mult * num_beams, input_ids_len)
514
+
515
+ input_video_embeds_len, input_video_embeds_hidden = input_video_embeds.size(1), input_video_embeds.size(2)
516
+ input_video_embeds = input_video_embeds.unsqueeze(1).expand(
517
+ batch_size, effective_batch_mult * num_beams, input_video_embeds_len, input_video_embeds_hidden)
518
+
519
+ attention_mask_from_len, attention_mask_to_len = attention_mask.size(1), attention_mask.size(2)
520
+ attention_mask = attention_mask.unsqueeze(1).expand(
521
+ batch_size, effective_batch_mult * num_beams, attention_mask_from_len, attention_mask_to_len
522
+ )
523
+
524
+ token_type_ids_len = token_type_ids.size(1)
525
+ token_type_ids = token_type_ids.unsqueeze(1).expand(
526
+ batch_size, effective_batch_mult * num_beams, token_type_ids_len
527
+ )
528
+
529
+ # contiguous ...
530
+ input_ids = input_ids.contiguous().view(
531
+ effective_batch_size * num_beams, input_ids_len
532
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
533
+
534
+ input_video_embeds = input_video_embeds.contiguous().view(
535
+ effective_batch_size * num_beams, input_video_embeds_len, input_video_embeds_hidden)
536
+
537
+ attention_mask = attention_mask.contiguous().view(
538
+ effective_batch_size * num_beams, attention_mask_from_len, attention_mask_to_len
539
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
540
+
541
+ token_type_ids = token_type_ids.contiguous().view(
542
+ effective_batch_size * num_beams, token_type_ids_len
543
+ )
544
+
545
+ model_kwargs["input_video_embeds"] = input_video_embeds
546
+ model_kwargs["token_type_ids"] = token_type_ids
547
+
548
+ if self.config.is_encoder_decoder:
549
+ device = next(self.parameters()).device
550
+ if decoder_input_ids is not None:
551
+ # give initial decoder input ids
552
+ input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device)
553
+ else:
554
+ # create empty decoder input_ids
555
+ input_ids = torch.full(
556
+ (effective_batch_size * num_beams, 1),
557
+ decoder_start_token_id,
558
+ dtype=torch.long,
559
+ device=device,
560
+ )
561
+ cur_len = input_ids.shape[-1]
562
+
563
+ assert (
564
+ batch_size == encoder_outputs.last_hidden_state.shape[0]
565
+ ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
566
+
567
+ # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
568
+ expanded_batch_idxs = (
569
+ torch.arange(batch_size)
570
+ .view(-1, 1)
571
+ .repeat(1, num_beams * effective_batch_mult)
572
+ .view(-1)
573
+ .to(input_ids.device)
574
+ )
575
+
576
+ # expand encoder_outputs
577
+ encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
578
+ 0, expanded_batch_idxs
579
+ )
580
+
581
+ # save encoder_outputs in `model_kwargs`
582
+ model_kwargs["encoder_outputs"] = encoder_outputs
583
+
584
+ else:
585
+ cur_len = input_ids.shape[-1]
586
+
587
+ assert (
588
+ cur_len < max_length
589
+ ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
590
+
591
+ if num_beams > 1:
592
+ output = self._generate_beam_search(
593
+ input_ids,
594
+ cur_len=cur_len,
595
+ max_length=max_length,
596
+ min_length=min_length,
597
+ do_sample=do_sample,
598
+ early_stopping=early_stopping,
599
+ temperature=temperature,
600
+ top_k=top_k,
601
+ top_p=top_p,
602
+ repetition_penalty=repetition_penalty,
603
+ no_repeat_ngram_size=no_repeat_ngram_size,
604
+ bad_words_ids=bad_words_ids,
605
+ pad_token_id=pad_token_id,
606
+ eos_token_id=eos_token_id,
607
+ batch_size=effective_batch_size,
608
+ num_return_sequences=num_return_sequences,
609
+ length_penalty=length_penalty,
610
+ num_beams=num_beams,
611
+ vocab_size=vocab_size,
612
+ attention_mask=attention_mask,
613
+ use_cache=use_cache,
614
+ model_kwargs=model_kwargs,
615
+ )
616
+ else:
617
+ output = self._generate_no_beam_search(
618
+ input_ids,
619
+ cur_len=cur_len,
620
+ max_length=max_length,
621
+ min_length=min_length,
622
+ do_sample=do_sample,
623
+ temperature=temperature,
624
+ top_k=top_k,
625
+ top_p=top_p,
626
+ repetition_penalty=repetition_penalty,
627
+ no_repeat_ngram_size=no_repeat_ngram_size,
628
+ bad_words_ids=bad_words_ids,
629
+ pad_token_id=pad_token_id,
630
+ eos_token_id=eos_token_id,
631
+ batch_size=effective_batch_size,
632
+ attention_mask=attention_mask,
633
+ use_cache=use_cache,
634
+ model_kwargs=model_kwargs,
635
+ )
636
+
637
+ return output
638
+
639
+ def _generate_beam_search(
640
+ self,
641
+ input_ids,
642
+ cur_len,
643
+ max_length,
644
+ min_length,
645
+ do_sample,
646
+ early_stopping,
647
+ temperature,
648
+ top_k,
649
+ top_p,
650
+ repetition_penalty,
651
+ no_repeat_ngram_size,
652
+ bad_words_ids,
653
+ pad_token_id,
654
+ eos_token_id,
655
+ batch_size,
656
+ num_return_sequences,
657
+ length_penalty,
658
+ num_beams,
659
+ vocab_size,
660
+ attention_mask,
661
+ use_cache,
662
+ model_kwargs,
663
+ ):
664
+ """Generate sequences for each example with beam search."""
665
+
666
+ # generated hypotheses
667
+ generated_hyps = [
668
+ BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
669
+ for _ in range(batch_size)
670
+ ]
671
+
672
+ # scores for each sentence in the beam
673
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
674
+
675
+ # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
676
+ if do_sample is False:
677
+ beam_scores[:, 1:] = -1e9
678
+ beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
679
+
680
+ # cache compute states
681
+ past = None
682
+
683
+ # done sentences
684
+ done = [False for _ in range(batch_size)]
685
+
686
+ while cur_len < max_length:
687
+ model_inputs = self.prepare_inputs_for_generation(
688
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
689
+ )
690
+ outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
691
+ next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
692
+
693
+ # if model has past, then set the past variable to speed up decoding
694
+ if "past_key_values" in outputs:
695
+ past = outputs.past_key_values
696
+ elif "mems" in outputs:
697
+ past = outputs.mems
698
+
699
+ if self.config.is_encoder_decoder and do_sample is False:
700
+ # TODO (PVP) still a bit hacky here - there might be a better solution
701
+ next_token_logits = self.adjust_logits_during_generation(
702
+ next_token_logits, cur_len=cur_len, max_length=max_length
703
+ )
704
+
705
+ scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
706
+
707
+ scores = self.postprocess_next_token_scores(
708
+ scores=scores,
709
+ input_ids=input_ids,
710
+ no_repeat_ngram_size=no_repeat_ngram_size,
711
+ bad_words_ids=bad_words_ids,
712
+ cur_len=cur_len,
713
+ min_length=min_length,
714
+ max_length=max_length,
715
+ eos_token_id=eos_token_id,
716
+ repetition_penalty=repetition_penalty,
717
+ batch_size=batch_size,
718
+ num_beams=num_beams,
719
+ )
720
+
721
+ assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
722
+ scores.shape, (batch_size * num_beams, vocab_size)
723
+ )
724
+
725
+ if do_sample:
726
+ _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
727
+ # Temperature
728
+ if temperature != 1.0:
729
+ _scores = _scores / temperature
730
+ # Top-p/top-k filtering
731
+ _scores = top_k_top_p_filtering(
732
+ _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
733
+ ) # (batch_size * num_beams, vocab_size)
734
+ # re-organize to group the beam together to sample from all beam_idxs
735
+ _scores = _scores.contiguous().view(
736
+ batch_size, num_beams * vocab_size
737
+ ) # (batch_size, num_beams * vocab_size)
738
+
739
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
740
+ probs = F.softmax(_scores, dim=-1)
741
+ next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
742
+ # Compute next scores
743
+ next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
744
+ # sort the sampled vector to make sure that the first num_beams samples are the best
745
+ next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
746
+ next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
747
+
748
+ else:
749
+ next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
750
+
751
+ # re-organize to group the beams together (we keep the top hypotheses across beams)
752
+ next_scores = next_scores.view(
753
+ batch_size, num_beams * vocab_size
754
+ ) # (batch_size, num_beams * vocab_size)
755
+
756
+ next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
757
+
758
+ assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
759
+
760
+ # next batch beam content
761
+ next_batch_beam = []
762
+
763
+ # for each sentence
764
+ for batch_idx in range(batch_size):
765
+
766
+ # if we are done with this sentence, add a pad token
767
+ if done[batch_idx]:
768
+ assert (
769
+ len(generated_hyps[batch_idx]) >= num_beams
770
+ ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
771
+ assert (
772
+ eos_token_id is not None and pad_token_id is not None
773
+ ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
774
+ next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
775
+ continue
776
+
777
+ # next sentence beam content, this will get added to next_batch_beam
778
+ next_sent_beam = []
779
+
780
+ # next tokens for this sentence
781
+ for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
782
+ zip(next_tokens[batch_idx], next_scores[batch_idx])
783
+ ):
784
+ # get beam and token IDs
785
+ beam_id = beam_token_id // vocab_size
786
+ token_id = beam_token_id % vocab_size
787
+
788
+ effective_beam_id = batch_idx * num_beams + beam_id
789
+ # add to generated hypotheses if end of sentence
790
+ if (eos_token_id is not None) and (token_id.item() == eos_token_id):
791
+ # if beam_token does not belong to top num_beams tokens, it should not be added
792
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
793
+ if is_beam_token_worse_than_top_num_beams:
794
+ continue
795
+ generated_hyps[batch_idx].add(
796
+ input_ids[effective_beam_id].clone(),
797
+ beam_token_score.item(),
798
+ )
799
+ else:
800
+ # add next predicted token since it is not eos_token
801
+ next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
802
+
803
+ # once the beam for next step is full, don't add more tokens to it.
804
+ if len(next_sent_beam) == num_beams:
805
+ break
806
+
807
+ # Check if we are done so that we can save a pad step if all(done)
808
+ done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
809
+ next_scores[batch_idx].max().item(), cur_len
810
+ )
811
+
812
+ # update next beam content
813
+ assert len(next_sent_beam) == num_beams, "Beam should always be full"
814
+ next_batch_beam.extend(next_sent_beam)
815
+ assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
816
+
817
+ # stop when we are done with each sentence
818
+ if all(done):
819
+ break
820
+
821
+ # sanity check / prepare next batch
822
+ assert len(next_batch_beam) == batch_size * num_beams
823
+ beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
824
+ beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
825
+ beam_idx = input_ids.new([x[2] for x in next_batch_beam])
826
+
827
+ # re-order batch and update current length
828
+ input_ids = input_ids[beam_idx, :]
829
+ input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
830
+ cur_len = cur_len + 1
831
+
832
+ # re-order internal states
833
+ if past is not None:
834
+ past = self._reorder_cache(past, beam_idx)
835
+
836
+ # extend attention_mask for new generated input if only decoder
837
+ # (huxu): move out since we trim attention_mask by ourselves.
838
+ # if self.config.is_encoder_decoder is False:
839
+ # attention_mask = torch.cat(
840
+ # [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
841
+ # )
842
+
843
+ # finalize all open beam hypotheses and add to generated hypotheses
844
+ for batch_idx in range(batch_size):
845
+ if done[batch_idx]:
846
+ continue
847
+
848
+ # test that beam scores match previously calculated scores if not eos and batch_idx not done
849
+ if eos_token_id is not None and all(
850
+ (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
851
+ ):
852
+ assert torch.all(
853
+ next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
854
+ ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
855
+ next_scores[:, :num_beams][batch_idx],
856
+ beam_scores.view(batch_size, num_beams)[batch_idx],
857
+ )
858
+
859
+ # need to add best num_beams hypotheses to generated hyps
860
+ for beam_id in range(num_beams):
861
+ effective_beam_id = batch_idx * num_beams + beam_id
862
+ final_score = beam_scores[effective_beam_id].item()
863
+ final_tokens = input_ids[effective_beam_id]
864
+ generated_hyps[batch_idx].add(final_tokens, final_score)
865
+
866
+ # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
867
+ output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
868
+ output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
869
+
870
+ # select the best hypotheses
871
+ sent_lengths = input_ids.new(output_batch_size)
872
+ best = []
873
+
874
+ # retrieve best hypotheses
875
+ for i, hypotheses in enumerate(generated_hyps):
876
+ sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
877
+ for j in range(output_num_return_sequences_per_batch):
878
+ effective_batch_idx = output_num_return_sequences_per_batch * i + j
879
+ best_hyp = sorted_hyps.pop()[1]
880
+ sent_lengths[effective_batch_idx] = len(best_hyp)
881
+ best.append(best_hyp)
882
+
883
+ # prepare for adding eos
884
+ sent_max_len = min(sent_lengths.max().item() + 1, max_length)
885
+ decoded = input_ids.new(output_batch_size, sent_max_len)
886
+ # shorter batches are padded if needed
887
+ if sent_lengths.min().item() != sent_lengths.max().item():
888
+ assert pad_token_id is not None, "`pad_token_id` has to be defined"
889
+ decoded.fill_(pad_token_id)
890
+
891
+ # fill with hypotheses and eos_token_id if the latter fits in
892
+ for i, hypo in enumerate(best):
893
+ decoded[i, : sent_lengths[i]] = hypo
894
+ if sent_lengths[i] < max_length:
895
+ decoded[i, sent_lengths[i]] = eos_token_id
896
+
897
+ return decoded
898
+
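Inside the loop, candidates come from a topk over scores flattened to (batch, num_beams * vocab_size), so each index encodes both its beam and its token: the beam is the quotient by vocab_size and the token is the remainder. A tiny sketch of that decomposition:

import torch

num_beams, vocab_size = 3, 7
scores = torch.randn(1, num_beams * vocab_size)    # one batch element, beams flattened

next_scores, next_tokens = torch.topk(scores, 2 * num_beams, dim=1)
beam_ids = next_tokens // vocab_size               # which beam each candidate extends
token_ids = next_tokens % vocab_size               # which vocabulary item it adds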
899
+ def _generate_no_beam_search(
900
+ self,
901
+ input_ids,
902
+ cur_len,
903
+ max_length,
904
+ min_length,
905
+ do_sample,
906
+ temperature,
907
+ top_k,
908
+ top_p,
909
+ repetition_penalty,
910
+ no_repeat_ngram_size,
911
+ bad_words_ids,
912
+ pad_token_id,
913
+ eos_token_id,
914
+ batch_size,
915
+ attention_mask,
916
+ use_cache,
917
+ model_kwargs,
918
+ ):
919
+ """Generate sequences for each example without beam search (num_beams == 1).
920
+ All returned sequences are generated independently.
921
+ """
922
+ # length of generated sentences / unfinished sentences
923
+ unfinished_sents = input_ids.new(batch_size).fill_(1)
924
+ sent_lengths = input_ids.new(batch_size).fill_(max_length)
925
+
926
+ past = None
927
+ while cur_len < max_length:
928
+ model_inputs = self.prepare_inputs_for_generation(
929
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
930
+ )
931
+
932
+ outputs = self(**model_inputs, return_dict=True)
933
+ next_token_logits = outputs.logits[:, -1, :]
934
+ scores = self.postprocess_next_token_scores(
935
+ scores=next_token_logits,
936
+ input_ids=input_ids,
937
+ no_repeat_ngram_size=no_repeat_ngram_size,
938
+ bad_words_ids=bad_words_ids,
939
+ cur_len=cur_len,
940
+ min_length=min_length,
941
+ max_length=max_length,
942
+ eos_token_id=eos_token_id,
943
+ repetition_penalty=repetition_penalty,
944
+ batch_size=batch_size,
945
+ num_beams=1,
946
+ )
947
+
948
+ # if model has past, then set the past variable to speed up decoding
949
+ if "past_key_values" in outputs:
950
+ past = outputs.past_key_values
951
+ elif "mems" in outputs:
952
+ past = outputs.mems
953
+
954
+ if do_sample:
955
+ # Temperature (higher temperature => more likely to sample low probability tokens)
956
+ if temperature != 1.0:
957
+ scores = scores / temperature
958
+ # Top-p/top-k filtering
959
+ next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
960
+ # Sample
961
+ probs = F.softmax(next_token_logscores, dim=-1)
962
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
963
+ else:
964
+ # Greedy decoding
965
+ next_token = torch.argmax(next_token_logits, dim=-1)
966
+
967
+ # print(next_token_logits[0,next_token[0]], next_token_logits[0,eos_token_id])
968
+
969
+ # update generations and finished sentences
970
+ if eos_token_id is not None:
971
+ # pad finished sentences if eos_token_id exist
972
+ tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
973
+ else:
974
+ tokens_to_add = next_token
975
+
976
+ # add token and increase length by one
977
+ input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
978
+ cur_len = cur_len + 1
979
+
980
+ if eos_token_id is not None:
981
+ eos_in_sents = tokens_to_add == eos_token_id
982
+ # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
983
+ is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
984
+ sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
985
+ # unfinished_sents is set to zero if eos in sentence
986
+ unfinished_sents.mul_((~eos_in_sents).long())
987
+
988
+ # stop when there is a </s> in each sentence, or if we exceed the maximum length
989
+ if unfinished_sents.max() == 0:
990
+ break
991
+
992
+
993
+ # extend attention_mask for new generated input if only decoder
994
+ # if self.config.is_encoder_decoder is False:
995
+ # attention_mask = torch.cat(
996
+ # [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
997
+ # )
998
+
999
+ return input_ids
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/models/transformermodel.py ADDED
@@ -0,0 +1,734 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Copyright (c) Facebook, Inc. All Rights Reserved
17
+
18
+ import torch
19
+
20
+ from torch import nn
21
+
22
+ try:
23
+ from transformers.modeling_bert import (
24
+ BertPreTrainedModel,
25
+ BertModel,
26
+ BertEncoder,
27
+ BertPredictionHeadTransform,
28
+ )
29
+ except ImportError:
30
+ pass
31
+
32
+ from ..modules import VideoTokenMLP, MMBertEmbeddings
33
+
34
+
35
+ # --------------- fine-tuning models ---------------
36
+ class MMBertForJoint(BertPreTrainedModel):
37
+ """A BertModel with isolated attention mask to separate modality."""
38
+
39
+ def __init__(self, config):
40
+ super().__init__(config)
41
+ self.videomlp = VideoTokenMLP(config)
42
+ self.bert = MMBertModel(config)
43
+ self.init_weights()
44
+
45
+ def forward(
46
+ self,
47
+ input_ids=None,
48
+ input_video_embeds=None,
49
+ attention_mask=None,
50
+ token_type_ids=None,
51
+ position_ids=None,
52
+ head_mask=None,
53
+ inputs_embeds=None,
54
+ next_sentence_label=None,
55
+ output_attentions=None,
56
+ output_hidden_states=None,
57
+ return_dict=None,
58
+ separate_forward_split=None,
59
+ ):
60
+ return_dict = (
61
+ return_dict if return_dict is not None
62
+ else self.config.use_return_dict
63
+ )
64
+ video_tokens = self.videomlp(input_video_embeds)
65
+
66
+ outputs = self.bert(
67
+ input_ids,
68
+ video_tokens,
69
+ attention_mask=attention_mask,
70
+ token_type_ids=token_type_ids,
71
+ position_ids=position_ids,
72
+ head_mask=head_mask,
73
+ inputs_embeds=inputs_embeds,
74
+ output_attentions=output_attentions,
75
+ output_hidden_states=output_hidden_states,
76
+ return_dict=return_dict,
77
+ separate_forward_split=separate_forward_split,
78
+ )
79
+
80
+ return outputs
81
+
82
+
83
+ class MMBertForTokenClassification(BertPreTrainedModel):
84
+ """A BertModel similar to MMJointUni, with extra wrapper layer
85
+ to be fine-tuned from other pretrained MMFusion model."""
86
+
87
+ def __init__(self, config):
88
+ super().__init__(config)
89
+ self.videomlp = VideoTokenMLP(config)
90
+ self.bert = MMBertModel(config)
91
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
92
+ # TODO(huxu): 779 is the number of classes for COIN: move to config?
93
+ self.classifier = nn.Linear(config.hidden_size, 779)
94
+ self.init_weights()
95
+
96
+ def forward(
97
+ self,
98
+ input_ids=None,
99
+ input_video_embeds=None,
100
+ attention_mask=None,
101
+ token_type_ids=None,
102
+ position_ids=None,
103
+ head_mask=None,
104
+ inputs_embeds=None,
105
+ next_sentence_label=None,
106
+ output_attentions=None,
107
+ output_hidden_states=None,
108
+ return_dict=None,
109
+ separate_forward_split=None,
110
+ ):
111
+ return_dict = (
112
+ return_dict if return_dict is not None
113
+ else self.config.use_return_dict
114
+ )
115
+
116
+ video_tokens = self.videomlp(input_video_embeds)
117
+ outputs = self.bert(
118
+ input_ids,
119
+ video_tokens,
120
+ attention_mask=attention_mask,
121
+ token_type_ids=token_type_ids,
122
+ position_ids=position_ids,
123
+ head_mask=head_mask,
124
+ inputs_embeds=inputs_embeds,
125
+ output_attentions=output_attentions,
126
+ output_hidden_states=output_hidden_states,
127
+ return_dict=return_dict,
128
+ separate_forward_split=separate_forward_split,
129
+ )
130
+
131
+ return (self.classifier(outputs[0]),)
132
+
133
+
134
+ # ------------ pre-training models ----------------
135
+
136
+ class MMBertForEncoder(BertPreTrainedModel):
137
+ """A BertModel for Contrastive Learning."""
138
+ def __init__(self, config):
139
+ super().__init__(config)
140
+ self.videomlp = VideoTokenMLP(config)
141
+ self.bert = MMBertModel(config)
142
+ self.init_weights()
143
+
144
+ def forward(
145
+ self,
146
+ input_ids=None,
147
+ input_video_embeds=None,
148
+ attention_mask=None,
149
+ token_type_ids=None,
150
+ position_ids=None,
151
+ head_mask=None,
152
+ inputs_embeds=None,
153
+ output_attentions=None,
154
+ output_hidden_states=None,
155
+ return_dict=None,
156
+ ):
157
+ return_dict = (
158
+ return_dict if return_dict is not None
159
+ else self.config.use_return_dict
160
+ )
161
+ if input_video_embeds is not None:
162
+ video_tokens = self.videomlp(input_video_embeds)
163
+ else:
164
+ video_tokens = None
165
+
166
+ outputs = self.bert(
167
+ input_ids,
168
+ video_tokens,
169
+ attention_mask=attention_mask,
170
+ token_type_ids=token_type_ids,
171
+ position_ids=position_ids,
172
+ head_mask=head_mask,
173
+ inputs_embeds=inputs_embeds,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ return_dict=return_dict,
177
+ )
178
+ return outputs
179
+
180
+
181
+ class MMBertForMFMMLM(BertPreTrainedModel):
182
+ """A BertModel with shared prediction head on MFM-MLM."""
183
+ def __init__(self, config):
184
+ super().__init__(config)
185
+ self.videomlp = VideoTokenMLP(config)
186
+ self.bert = MMBertModel(config)
187
+ self.cls = MFMMLMHead(config)
188
+ self.hidden_size = config.hidden_size
189
+ self.init_weights()
190
+
191
+ def get_output_embeddings(self):
192
+ return self.cls.predictions.decoder
193
+
194
+ def forward(
195
+ self,
196
+ input_ids=None,
197
+ input_video_embeds=None,
198
+ attention_mask=None,
199
+ token_type_ids=None,
200
+ position_ids=None,
201
+ head_mask=None,
202
+ inputs_embeds=None,
203
+ masked_frame_labels=None,
204
+ target_video_hidden_states=None,
205
+ non_masked_frame_mask=None,
206
+ masked_lm_labels=None,
207
+ output_attentions=None,
208
+ output_hidden_states=None,
209
+ return_dict=None,
210
+ ):
211
+ return_dict = (
212
+ return_dict if return_dict is not None
213
+ else self.config.use_return_dict
214
+ )
215
+ if input_video_embeds is not None:
216
+ video_tokens = self.videomlp(input_video_embeds)
217
+ else:
218
+ video_tokens = None
219
+
220
+ if target_video_hidden_states is not None:
221
+ target_video_hidden_states = self.videomlp(
222
+ target_video_hidden_states)
223
+
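+ # hidden states of the frames that were NOT masked; the prediction head
+ # scores each masked frame against these as additional comparison targets.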
224
+ non_masked_frame_hidden_states = video_tokens.masked_select(
225
+ non_masked_frame_mask.unsqueeze(-1)
226
+ ).view(-1, self.hidden_size)
227
+
228
+ outputs = self.bert(
229
+ input_ids,
230
+ video_tokens,
231
+ attention_mask=attention_mask,
232
+ token_type_ids=token_type_ids,
233
+ position_ids=position_ids,
234
+ head_mask=head_mask,
235
+ inputs_embeds=inputs_embeds,
236
+ output_attentions=output_attentions,
237
+ output_hidden_states=output_hidden_states,
238
+ return_dict=return_dict,
239
+ )
240
+
241
+ sequence_output = outputs[0]
242
+
243
+ mfm_scores, prediction_scores = None, None
244
+ if masked_frame_labels is not None and masked_lm_labels is not None:
245
+ # split the sequence.
246
+ text_offset = masked_frame_labels.size(1) + 1 # [CLS]
247
+ video_sequence_output = sequence_output[
248
+ :, 1:text_offset
249
+ ] # exclude [SEP], since it is not in video_label.
250
+ text_sequence_output = torch.cat(
251
+ [sequence_output[:, :1], sequence_output[:, text_offset:]],
252
+ dim=1
253
+ )
254
+
255
+ hidden_size = video_sequence_output.size(-1)
256
+ selected_video_output = video_sequence_output.masked_select(
257
+ masked_frame_labels.unsqueeze(-1)
258
+ ).view(-1, hidden_size)
259
+
260
+ # only compute the selected tokens during training, to speed things up.
261
+ hidden_size = text_sequence_output.size(-1)
262
+ # masked_lm_labels = masked_lm_labels.reshape(-1)
263
+ labels_mask = masked_lm_labels != -100
264
+
265
+ selected_text_output = text_sequence_output.masked_select(
266
+ labels_mask.unsqueeze(-1)
267
+ ).view(-1, hidden_size)
268
+ mfm_scores, prediction_scores = self.cls(
269
+ selected_video_output,
270
+ target_video_hidden_states,
271
+ non_masked_frame_hidden_states,
272
+ selected_text_output,
273
+ )
274
+
275
+ output = (
276
+ mfm_scores,
277
+ prediction_scores,
278
+ ) + outputs
279
+ return output
280
+
281
+
282
+ class BertMFMMLMPredictionHead(nn.Module):
283
+ def __init__(self, config):
284
+ super().__init__()
285
+ self.transform = BertPredictionHeadTransform(config)
286
+ # The output weights are the same as the input embeddings, but there is
287
+ # an output-only bias for each token.
288
+ self.decoder = nn.Linear(
289
+ config.hidden_size, config.vocab_size, bias=False)
290
+
291
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
292
+
293
+ # Need a link between the two variables so that the bias is correctly
294
+ # resized with `resize_token_embeddings`
295
+ self.decoder.bias = self.bias
296
+
297
+ def forward(
298
+ self,
299
+ video_hidden_states=None,
300
+ target_video_hidden_states=None,
301
+ non_masked_frame_hidden_states=None,
302
+ text_hidden_states=None,
303
+ ):
304
+ video_logits, text_logits = None, None
305
+ if video_hidden_states is not None:
306
+ video_hidden_states = self.transform(video_hidden_states)
307
+ non_masked_frame_logits = torch.mm(
308
+ video_hidden_states,
309
+ non_masked_frame_hidden_states.transpose(1, 0)
310
+ )
311
+ masked_frame_logits = torch.bmm(
312
+ video_hidden_states.unsqueeze(1),
313
+ target_video_hidden_states.unsqueeze(-1),
314
+ ).squeeze(-1)
315
+ video_logits = torch.cat(
316
+ [masked_frame_logits, non_masked_frame_logits], dim=1
317
+ )
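+ # column 0 holds each masked frame's score against its own target;
+ # the remaining columns are its scores against the unmasked frames.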
318
+
319
+ if text_hidden_states is not None:
320
+ text_hidden_states = self.transform(text_hidden_states)
321
+ text_logits = self.decoder(text_hidden_states)
322
+ return video_logits, text_logits
323
+
324
+
325
+ class MFMMLMHead(nn.Module):
326
+ def __init__(self, config):
327
+ super().__init__()
328
+ self.predictions = BertMFMMLMPredictionHead(config)
329
+
330
+ def forward(
331
+ self,
332
+ video_hidden_states=None,
333
+ target_video_hidden_states=None,
334
+ non_masked_frame_hidden_states=None,
335
+ text_hidden_states=None,
336
+ ):
337
+ video_logits, text_logits = self.predictions(
338
+ video_hidden_states,
339
+ target_video_hidden_states,
340
+ non_masked_frame_hidden_states,
341
+ text_hidden_states,
342
+ )
343
+ return video_logits, text_logits
344
+
345
+
346
+ class MMBertForMTM(MMBertForMFMMLM):
347
+ def __init__(self, config):
348
+ BertPreTrainedModel.__init__(self, config)
349
+ self.videomlp = VideoTokenMLP(config)
350
+ self.bert = MMBertModel(config)
351
+ self.cls = MTMHead(config)
352
+ self.hidden_size = config.hidden_size
353
+ self.init_weights()
354
+
355
+
356
+ class BertMTMPredictionHead(nn.Module):
357
+ def __init__(self, config):
358
+ super().__init__()
359
+ self.transform = BertPredictionHeadTransform(config)
360
+ self.decoder = nn.Linear(
361
+ config.hidden_size, config.vocab_size, bias=False)
362
+
363
+ def forward(
364
+ self,
365
+ video_hidden_states=None,
366
+ target_video_hidden_states=None,
367
+ non_masked_frame_hidden_states=None,
368
+ text_hidden_states=None,
369
+ ):
370
+ non_masked_frame_hidden_states = non_masked_frame_hidden_states.transpose(1, 0)
371
+ video_logits, text_logits = None, None
372
+ if video_hidden_states is not None:
373
+ video_hidden_states = self.transform(video_hidden_states)
374
+
375
+ masked_frame_logits = torch.bmm(
376
+ video_hidden_states.unsqueeze(1),
377
+ target_video_hidden_states.unsqueeze(-1),
378
+ ).squeeze(-1)
379
+
380
+ non_masked_frame_logits = torch.mm(
381
+ video_hidden_states,
382
+ non_masked_frame_hidden_states
383
+ )
384
+ video_on_vocab_logits = self.decoder(video_hidden_states)
385
+ video_logits = torch.cat([
386
+ masked_frame_logits,
387
+ non_masked_frame_logits,
388
+ video_on_vocab_logits], dim=1)
389
+
390
+ if text_hidden_states is not None:
391
+ text_hidden_states = self.transform(text_hidden_states)
392
+ # text first so label does not need to be shifted.
393
+ text_on_vocab_logits = self.decoder(text_hidden_states)
394
+ text_on_video_logits = torch.mm(
395
+ text_hidden_states,
396
+ non_masked_frame_hidden_states
397
+ )
398
+ text_logits = torch.cat([
399
+ text_on_vocab_logits,
400
+ text_on_video_logits
401
+ ], dim=1)
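+ # the text logits above cover the word vocabulary first, then the
+ # unmasked video frames, so a text position can be matched to either.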
402
+
403
+ return video_logits, text_logits
404
+
405
+
406
+ class MTMHead(nn.Module):
407
+ def __init__(self, config):
408
+ super().__init__()
409
+ self.predictions = BertMTMPredictionHead(config)
410
+
411
+ def forward(
412
+ self,
413
+ video_hidden_states=None,
414
+ target_video_hidden_states=None,
415
+ non_masked_frame_hidden_states=None,
416
+ text_hidden_states=None,
417
+ ):
418
+ video_logits, text_logits = self.predictions(
419
+ video_hidden_states,
420
+ target_video_hidden_states,
421
+ non_masked_frame_hidden_states,
422
+ text_hidden_states,
423
+ )
424
+ return video_logits, text_logits
425
+
426
+
427
+ class MMBertModel(BertModel):
428
+ """MMBertModel has MMBertEmbedding to support video tokens."""
429
+
430
+ def __init__(self, config, add_pooling_layer=True):
431
+ super().__init__(config)
432
+ # overwrite embedding
433
+ self.embeddings = MMBertEmbeddings(config)
434
+ self.encoder = MultiLayerAttentionMaskBertEncoder(config)
435
+ self.init_weights()
436
+
437
+ def forward(
438
+ self,
439
+ input_ids=None,
440
+ input_video_embeds=None,
441
+ attention_mask=None,
442
+ token_type_ids=None,
443
+ position_ids=None,
444
+ head_mask=None,
445
+ inputs_embeds=None,
446
+ encoder_hidden_states=None,
447
+ encoder_attention_mask=None,
448
+ output_attentions=None,
449
+ output_hidden_states=None,
450
+ return_dict=None,
451
+ separate_forward_split=None,
452
+ ):
453
+ output_attentions = (
454
+ output_attentions
455
+ if output_attentions is not None
456
+ else self.config.output_attentions
457
+ )
458
+ output_hidden_states = (
459
+ output_hidden_states
460
+ if output_hidden_states is not None
461
+ else self.config.output_hidden_states
462
+ )
463
+ return_dict = (
464
+ return_dict if return_dict is not None
465
+ else self.config.use_return_dict
466
+ )
467
+
468
+ if input_ids is not None and inputs_embeds is not None:
469
+ raise ValueError(
470
+ "You cannot specify both input_ids "
471
+ "and inputs_embeds at the same time"
472
+ )
473
+ elif input_ids is not None:
474
+ if input_video_embeds is not None:
475
+ input_shape = (
476
+ input_ids.size(0),
477
+ input_ids.size(1) + input_video_embeds.size(1),
478
+ )
479
+ else:
480
+ input_shape = (
481
+ input_ids.size(0),
482
+ input_ids.size(1),
483
+ )
484
+ elif inputs_embeds is not None:
485
+ if input_video_embeds is not None:
486
+ input_shape = (
487
+ inputs_embeds.size(0),
488
+ inputs_embeds.size(1) + input_video_embeds.size(1),
489
+ )
490
+ else:
491
+ input_shape = (
492
+ inputs_embeds.size(0),
493
+ inputs_embeds.size(1),
494
+ )
495
+ else:
496
+ raise ValueError(
497
+ "You have to specify either input_ids or inputs_embeds")
498
+
499
+ device = input_ids.device if input_ids is not None \
500
+ else inputs_embeds.device
501
+
502
+ if attention_mask is None:
503
+ attention_mask = torch.ones(input_shape, device=device)
504
+ if token_type_ids is None:
505
+ token_type_ids = torch.zeros(
506
+ input_shape, dtype=torch.long, device=device)
507
+
508
+ # We can provide a self-attention mask of dimensions
509
+ # [batch_size, from_seq_length, to_seq_length]
510
+ # ourselves in which case
511
+ # we just need to make it broadcastable to all heads.
512
+ extended_attention_mask: torch.Tensor = \
513
+ self.get_extended_attention_mask(
514
+ attention_mask, input_shape, device)
515
+
516
+ # If a 2D or 3D attention mask is provided for the cross-attention
517
+ # we need to make broadcastable to
518
+ # [batch_size, num_heads, seq_length, seq_length]
519
+ if self.config.is_decoder and encoder_hidden_states is not None:
520
+ (
521
+ encoder_batch_size,
522
+ encoder_sequence_length,
523
+ _,
524
+ ) = encoder_hidden_states.size()
525
+ encoder_hidden_shape = (
526
+ encoder_batch_size, encoder_sequence_length)
527
+ if encoder_attention_mask is None:
528
+ encoder_attention_mask = torch.ones(
529
+ encoder_hidden_shape, device=device)
530
+ encoder_extended_attention_mask = self.invert_attention_mask(
531
+ encoder_attention_mask
532
+ )
533
+ else:
534
+ encoder_extended_attention_mask = None
535
+
536
+ # Prepare head mask if needed
537
+ # 1.0 in head_mask indicate we keep the head
538
+ # attention_probs has shape bsz x n_heads x N x N
539
+ # input head_mask has shape [num_heads] or
540
+ # [num_hidden_layers x num_heads]
541
+ # and head_mask is converted to shape
542
+ # [num_hidden_layers x batch x num_heads x seq_length x seq_length]
543
+
544
+ head_mask = self.get_head_mask(
545
+ head_mask, self.config.num_hidden_layers)
546
+
547
+ embedding_output = self.embeddings(
548
+ input_ids,
549
+ input_video_embeds,
550
+ position_ids=position_ids,
551
+ token_type_ids=token_type_ids,
552
+ inputs_embeds=inputs_embeds,
553
+ )
554
+
555
+ if separate_forward_split is not None:
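+ # run the encoder on the first `separate_forward_split` positions and on the
+ # remainder separately (block-diagonal attention between the two chunks),
+ # then concatenate the outputs along the sequence dimension.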
556
+ split_embedding_output = \
557
+ embedding_output[:, :separate_forward_split]
558
+ split_extended_attention_mask = extended_attention_mask[
559
+ :, :, :, :separate_forward_split, :separate_forward_split
560
+ ]
561
+ split_encoder_outputs = self.encoder(
562
+ split_embedding_output,
563
+ attention_mask=split_extended_attention_mask,
564
+ head_mask=head_mask,
565
+ encoder_hidden_states=encoder_hidden_states,
566
+ encoder_attention_mask=encoder_extended_attention_mask,
567
+ output_attentions=output_attentions,
568
+ output_hidden_states=output_hidden_states,
569
+ return_dict=return_dict,
570
+ )
571
+ assert (
572
+ len(split_encoder_outputs) <= 2
573
+ ), "we do not support merge on attention for now."
574
+ encoder_outputs = []
575
+ encoder_outputs.append([split_encoder_outputs[0]])
576
+ if len(split_encoder_outputs) == 2:
577
+ encoder_outputs.append([])
578
+ for _all_hidden_states in split_encoder_outputs[1]:
579
+ encoder_outputs[-1].append([_all_hidden_states])
580
+
581
+ split_embedding_output = \
582
+ embedding_output[:, separate_forward_split:]
583
+ split_extended_attention_mask = extended_attention_mask[
584
+ :, :, :, separate_forward_split:, separate_forward_split:
585
+ ]
586
+
587
+ split_encoder_outputs = self.encoder(
588
+ split_embedding_output,
589
+ attention_mask=split_extended_attention_mask,
590
+ head_mask=head_mask,
591
+ encoder_hidden_states=encoder_hidden_states,
592
+ encoder_attention_mask=encoder_extended_attention_mask,
593
+ output_attentions=output_attentions,
594
+ output_hidden_states=output_hidden_states,
595
+ return_dict=return_dict,
596
+ )
597
+
598
+ assert (
599
+ len(split_encoder_outputs) <= 2
600
+ ), "we do not support merge on attention for now."
601
+ encoder_outputs[0].append(split_encoder_outputs[0])
602
+ encoder_outputs[0] = torch.cat(encoder_outputs[0], dim=1)
603
+ if len(split_encoder_outputs) == 2:
604
+ for layer_idx, _all_hidden_states in enumerate(
605
+ split_encoder_outputs[1]
606
+ ):
607
+ encoder_outputs[1][layer_idx].append(_all_hidden_states)
608
+ encoder_outputs[1][layer_idx] = torch.cat(
609
+ encoder_outputs[1][layer_idx], dim=1
610
+ )
611
+ encoder_outputs = tuple(encoder_outputs)
612
+ else:
613
+ encoder_outputs = self.encoder(
614
+ embedding_output,
615
+ attention_mask=extended_attention_mask,
616
+ head_mask=head_mask,
617
+ encoder_hidden_states=encoder_hidden_states,
618
+ encoder_attention_mask=encoder_extended_attention_mask,
619
+ output_attentions=output_attentions,
620
+ output_hidden_states=output_hidden_states,
621
+ return_dict=return_dict,
622
+ )
623
+
624
+ sequence_output = encoder_outputs[0]
625
+ pooled_output = (
626
+ self.pooler(sequence_output) if self.pooler is not None else None
627
+ )
628
+
629
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
630
+
631
+ def get_extended_attention_mask(self, attention_mask, input_shape, device):
632
+ """This is borrowed from `modeling_utils.py` with the support of
633
+ multi-layer attention masks.
634
+ The second dim is expected to be number of layers.
635
+ See `MMAttentionMaskProcessor`.
636
+ Makes broadcastable attention and causal masks so that future
637
+ and masked tokens are ignored.
638
+
639
+ Arguments:
640
+ attention_mask (:obj:`torch.Tensor`):
641
+ Mask with ones indicating tokens to attend to,
642
+ zeros for tokens to ignore.
643
+ input_shape (:obj:`Tuple[int]`):
644
+ The shape of the input to the model.
645
+ device: (:obj:`torch.device`):
646
+ The device of the input to the model.
647
+
648
+ Returns:
649
+ :obj:`torch.Tensor` The extended attention mask, \
650
+ with the same dtype as :obj:`attention_mask.dtype`.
651
+ """
652
+ # We can provide a self-attention mask of dimensions
653
+ # [batch_size, from_seq_length, to_seq_length]
654
+ # ourselves in which case we just need to make it broadcastable
655
+ # to all heads.
656
+ if attention_mask.dim() == 4:
657
+ extended_attention_mask = attention_mask[:, :, None, :, :]
658
+ extended_attention_mask = extended_attention_mask.to(
659
+ dtype=self.dtype
660
+ ) # fp16 compatibility
661
+ extended_attention_mask = (1.0 - extended_attention_mask) \
662
+ * -10000.0
663
+ return extended_attention_mask
664
+ else:
665
+ return super().get_extended_attention_mask(
666
+ attention_mask, input_shape, device
667
+ )
668
+
669
+
670
+ class MultiLayerAttentionMaskBertEncoder(BertEncoder):
671
+ """extend BertEncoder with the capability of
672
+ multiple layers of attention mask."""
673
+
674
+ def forward(
675
+ self,
676
+ hidden_states,
677
+ attention_mask=None,
678
+ head_mask=None,
679
+ encoder_hidden_states=None,
680
+ encoder_attention_mask=None,
681
+ output_attentions=False,
682
+ output_hidden_states=False,
683
+ return_dict=False,
684
+ ):
685
+ all_hidden_states = () if output_hidden_states else None
686
+ all_attentions = () if output_attentions else None
687
+ for i, layer_module in enumerate(self.layer):
688
+ if output_hidden_states:
689
+ all_hidden_states = all_hidden_states + (hidden_states,)
690
+ layer_head_mask = head_mask[i] if head_mask is not None else None
691
+
692
+ layer_attention_mask = (
693
+ attention_mask[:, i, :, :, :]
694
+ if attention_mask.dim() == 5
695
+ else attention_mask
696
+ )
697
+
698
+ if getattr(self.config, "gradient_checkpointing", False):
699
+
700
+ def create_custom_forward(module):
701
+ def custom_forward(*inputs):
702
+ return module(*inputs, output_attentions)
703
+
704
+ return custom_forward
705
+
706
+ layer_outputs = torch.utils.checkpoint.checkpoint(
707
+ create_custom_forward(layer_module),
708
+ hidden_states,
709
+ layer_attention_mask,
710
+ layer_head_mask,
711
+ encoder_hidden_states,
712
+ encoder_attention_mask,
713
+ )
714
+ else:
715
+ layer_outputs = layer_module(
716
+ hidden_states,
717
+ layer_attention_mask,
718
+ layer_head_mask,
719
+ encoder_hidden_states,
720
+ encoder_attention_mask,
721
+ output_attentions,
722
+ )
723
+ hidden_states = layer_outputs[0]
724
+ if output_attentions:
725
+ all_attentions = all_attentions + (layer_outputs[1],)
726
+
727
+ if output_hidden_states:
728
+ all_hidden_states = all_hidden_states + (hidden_states,)
729
+
730
+ return tuple(
731
+ v
732
+ for v in [hidden_states, all_hidden_states, all_attentions]
733
+ if v is not None
734
+ )
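As a shape reference for the multi-layer mask path above, a small sketch (tensor names and sizes are illustrative only, not part of the module):

    import torch
    batch, num_layers, seq_len = 2, 12, 16
    # a 4-D multi-layer mask [batch, num_layers, from_seq, to_seq] is extended to
    # [batch, num_layers, 1, from_seq, to_seq] by MMBertModel.get_extended_attention_mask
    extended_mask = torch.zeros(batch, num_layers, 1, seq_len, seq_len)
    # MultiLayerAttentionMaskBertEncoder then slices one mask per layer:
    layer_mask = extended_mask[:, 0, :, :, :]
    assert layer_mask.shape == (batch, 1, seq_len, seq_len)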
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/modules/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from .mm import *
6
+
7
+ try:
8
+ from .expmm import *
9
+ except ImportError:
10
+ pass
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/examples/MMPT/mmpt/modules/mm.py ADDED
@@ -0,0 +1,145 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Copyright (c) Facebook, Inc. All Rights Reserved
17
+
18
+
19
+ import torch
20
+
21
+ from torch import nn
22
+
23
+ try:
24
+ from transformers.modeling_bert import (
25
+ BertEmbeddings,
26
+ ACT2FN,
27
+ )
28
+ except ImportError:
29
+ pass
30
+
31
+
32
+ class VideoTokenMLP(nn.Module):
33
+ def __init__(self, config):
34
+ super().__init__()
35
+ input_dim = config.input_dim if hasattr(config, "input_dim") else 512
36
+ self.linear1 = nn.Linear(input_dim, config.hidden_size)
37
+ self.LayerNorm = nn.LayerNorm(config.hidden_size)
38
+ self.activation = ACT2FN[config.hidden_act]
39
+ self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)
40
+
41
+ def forward(self, hidden_states):
42
+ hidden_states = self.linear1(hidden_states)
43
+ hidden_states = self.activation(hidden_states)
44
+ hidden_states = self.LayerNorm(hidden_states)
45
+ hidden_states = self.linear2(hidden_states)
46
+ return hidden_states
47
+
48
+
49
+ class MMBertEmbeddings(BertEmbeddings):
50
+ def __init__(self, config):
51
+ super().__init__(config)
52
+ self.max_video_len = config.max_video_len
53
+ if hasattr(config, "use_seg_emb") and config.use_seg_emb:
54
+ """the original VLM paper uses seg_embeddings for temporal space.
55
+ although not used it changed the randomness of initialization.
56
+ we keep it for reproducibility.
57
+ """
58
+ self.seg_embeddings = nn.Embedding(256, config.hidden_size)
59
+
60
+ def forward(
61
+ self,
62
+ input_ids,
63
+ input_video_embeds,
64
+ token_type_ids=None,
65
+ position_ids=None,
66
+ inputs_embeds=None,
67
+ ):
68
+ input_tensor = input_ids if input_ids is not None else inputs_embeds
69
+ if input_video_embeds is not None:
70
+ input_shape = (
71
+ input_tensor.size(0),
72
+ input_tensor.size(1) + input_video_embeds.size(1),
73
+ )
74
+ else:
75
+ input_shape = (input_tensor.size(0), input_tensor.size(1))
76
+
77
+ if position_ids is None:
78
+ """
79
+ Automatically skip position embeddings for the text-only case.
80
+ Use cases:
81
+ (1) action localization and segmentation:
82
+ a len-1 dummy video token is fed in, so the text part must
83
+ skip input_video_embeds.size(1) positions to get the right
84
+ position_ids for the video [SEP] and the remaining text tokens.
85
+ (2) MMFusionShare with two forward passes:
86
+ in `forward_text`, input_video_embeds is None and the
87
+ video [SEP] position must be skipped.
88
+
89
+ # video_len + 1: [CLS] + video_embed
90
+ # self.max_video_len + 1: [SEP] for video.
91
+ # self.max_video_len + 2: [SEP] for video.
92
+ # self.max_video_len + input_ids.size(1): rest for text.
93
+ """
94
+ if input_video_embeds is not None:
95
+ video_len = input_video_embeds.size(1)
96
+ starting_offset = self.max_video_len + 1 # video [SEP]
97
+ ending_offset = self.max_video_len + input_ids.size(1)
98
+ else:
99
+ video_len = 0
100
+ starting_offset = self.max_video_len + 2 # first text token.
101
+ ending_offset = self.max_video_len + input_ids.size(1) + 1
102
+ position_ids = torch.cat([
103
+ self.position_ids[:, :video_len + 1],
104
+ self.position_ids[:, starting_offset:ending_offset]
105
+ ], dim=1)
106
+
107
+ if token_type_ids is None:
108
+ token_type_ids = torch.zeros(
109
+ input_shape, dtype=torch.long, device=self.position_ids.device
110
+ )
111
+
112
+ """
113
+ The format of input_ids is: [CLS] [SEP] caption [SEP] padding.
114
+ The goal is to build: [CLS] video tokens [SEP] caption [SEP].
115
+ """
116
+ if inputs_embeds is None:
117
+ inputs_embeds = self.word_embeddings(input_ids)
118
+ if input_video_embeds is not None:
119
+ inputs_mm_embeds = torch.cat([
120
+ inputs_embeds[:, :1], input_video_embeds, inputs_embeds[:, 1:]
121
+ ], dim=1)
122
+ else:
123
+ # text only for `MMFusionShare`.
124
+ inputs_mm_embeds = inputs_embeds
125
+
126
+ position_embeddings = self.position_embeddings(position_ids)
127
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
128
+ embeddings = inputs_mm_embeds + position_embeddings
129
+ embeddings += token_type_embeddings
130
+
131
+ embeddings = self.LayerNorm(embeddings)
132
+ embeddings = self.dropout(embeddings)
133
+ return embeddings
134
+
135
+
136
+ class AlignHead(nn.Module):
137
+ """this will load pre-trained weights for NSP, which is desirable."""
138
+
139
+ def __init__(self, config):
140
+ super().__init__()
141
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
142
+
143
+ def forward(self, dropout_pooled_output):
144
+ logits = self.seq_relationship(dropout_pooled_output)
145
+ return logits
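As a closing illustration of how MMBertEmbeddings splices position ids for a video + text sequence, a minimal sketch (all sizes are illustrative; `position_ids` stands in for the module's buffer):

    import torch
    max_video_len, video_len, n_text = 32, 4, 8    # n_text plays the role of input_ids.size(1)
    position_ids = torch.arange(512).unsqueeze(0)  # stand-in for self.position_ids
    spliced = torch.cat([
        position_ids[:, :video_len + 1],                            # [CLS] + video frames
        position_ids[:, max_video_len + 1:max_video_len + n_text],  # video [SEP] + caption tokens
    ], dim=1)
    # length matches [CLS] + video + (input_ids without its [CLS])
    assert spliced.size(1) == video_len + n_text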