.gitignore ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio_queue.db*
2
+ pretrained/*
3
+ icetk_models/*
4
+ !*/.gitkeep
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ pip-wheel-metadata/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99
+ __pypackages__/
100
+
101
+ # Celery stuff
102
+ celerybeat-schedule
103
+ celerybeat.pid
104
+
105
+ # SageMath parsed files
106
+ *.sage.py
107
+
108
+ # Environments
109
+ .env
110
+ .venv
111
+ env/
112
+ venv/
113
+ ENV/
114
+ env.bak/
115
+ venv.bak/
116
+
117
+ # Spyder project settings
118
+ .spyderproject
119
+ .spyproject
120
+
121
+ # Rope project settings
122
+ .ropeproject
123
+
124
+ # mkdocs documentation
125
+ /site
126
+
127
+ # mypy
128
+ .mypy_cache/
129
+ .dmypy.json
130
+ dmypy.json
131
+
132
+ # Pyre type checker
133
+ .pyre/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "CogVideo"]
2
+ path = CogVideo
3
+ url = https://github.com/THUDM/CogVideo
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.812
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ - repo: https://github.com/google/yapf
33
+ rev: v0.32.0
34
+ hooks:
35
+ - id: yapf
36
+ args: ['--parallel', '--in-place']
37
+ - repo: https://github.com/kynan/nbstripout
38
+ rev: 0.5.0
39
+ hooks:
40
+ - id: nbstripout
41
+ args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
+ - repo: https://github.com/nbQA-dev/nbQA
43
+ rev: 1.3.1
44
+ hooks:
45
+ - id: nbqa-isort
46
+ - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
CogVideo ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit ff423aa169978fb2f636f761e348631fa3178b03
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LICENSE.CogVideo ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -4,7 +4,8 @@ emoji: 🌍
4
  colorFrom: indigo
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.0.26
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: indigo
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.1.6
8
+ python_version: 3.9.13
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py CHANGED
@@ -1,18 +0,0 @@
1
- import gradio as gr
2
- import os
3
-
4
-
5
-
6
-
7
- os.environ['SAT_HOME'] = '/home/user/app/sharefs/cogview-new'
8
-
9
- def inference(text):
10
- os.system("""bash ./scripts/inference_cogvideo_pipeline.sh""")
11
- return "output/out.mp4"
12
-
13
- gr.Interface(inference,"text","video").launch()
14
-
15
-
16
-
17
-
18
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cluster_label2.npy DELETED
Binary file (160 kB)
 
coglm_strategy.py DELETED
@@ -1,101 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : coglm_strategy.py
4
- @Time : 2021/10/08 22:22:42
5
- @Author : Ming Ding
6
- @Contact : dm18@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
- import numpy as np
16
- import torch.nn.functional as F
17
-
18
-
19
- def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
20
- # This function has been mostly taken from huggingface conversational ai code at
21
- # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
22
-
23
- if top_k > 0:
24
- # Remove all tokens with a probability less than the last token of the top-k
25
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
26
- logits[indices_to_remove] = filter_value
27
-
28
- if top_p > 0.0:
29
- # convert to 1D
30
- logits = logits.view(logits.size()[1]).contiguous()
31
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
32
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
33
-
34
- # Remove tokens with cumulative probability above the threshold
35
- sorted_indices_to_remove = cumulative_probs > top_p
36
- # Shift the indices to the right to keep also the first token above the threshold
37
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
38
- sorted_indices_to_remove[..., 0] = 0
39
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
40
- logits[indices_to_remove] = filter_value
41
- # going back to 2D
42
- logits = logits.view(1, -1).contiguous()
43
-
44
- return logits
45
-
46
-
47
- class CoglmStrategy:
48
- def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None, temperature2=0.89):
49
- self.invalid_slices = invalid_slices
50
- self.temperature = temperature
51
- self.temperature2 = temperature2
52
- self.topk = top_k
53
- self.top_p = top_p
54
- self.eps = eps
55
- if end_tokens is None:
56
- end_tokens = []
57
- self.end_tokens = end_tokens
58
- self._is_done = False
59
- self.outlier_count_down = torch.zeros(16)
60
- self.vis_list = [[]for i in range(16)]
61
- self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
62
- self.start_pos = -1
63
- self.white_cluster = []
64
- # self.fout = open('tmp.txt', 'w')
65
-
66
- @property
67
- def is_done(self) -> bool:
68
- return self._is_done
69
-
70
- def forward(self, logits, tokens, mems, temperature=None, temperature2=None):
71
- if temperature is None:
72
- temperature = self.temperature
73
- if temperature2 is None:
74
- temperature2 = self.temperature2
75
- logits = logits / temperature
76
- for invalid_slice in self.invalid_slices:
77
- logits[..., invalid_slice] = -65504
78
-
79
- rprobs = F.softmax(logits.float(), dim=-1)
80
- c = self.cluster_labels.expand(*rprobs.shape)
81
- cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
82
- # self.fout.write(str(tokens.shape[-1])+ ' ' + str(cprobs.topk(10)) + '\n')
83
- # self.fout.flush()
84
- best_scores, best_clusters = cprobs.topk(self.topk)
85
- bz = logits.shape[0]
86
- for i in range(bz):
87
- selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
88
- logits[i, self.cluster_labels != selected_cluster] = -65504
89
-
90
- # logits = top_k_logits(logits, self.topk, self.top_p)
91
- probs = F.softmax(logits.float()/temperature2, dim=-1) # float is essetial, due to a bug in Pytorch
92
- pred = torch.multinomial(probs, num_samples=1)
93
-
94
- if pred.numel() == 1 and pred.item() in self.end_tokens:
95
- self._is_done = True
96
- tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
97
- return tokens, mems
98
-
99
- def finalize(self, tokens, mems):
100
- self._is_done = False
101
- return tokens, mems
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cogvideo_pipeline.py DELETED
@@ -1,793 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cogvideo_pipeline.py
4
- @Time : 2022/07/15 11:24:56
5
- @Author : Wenyi Hong
6
- @Version : 1.0
7
- @Contact : hwy22@mails.tsinghua.edu.cn
8
- '''
9
-
10
- # here put the import lib
11
-
12
- import os
13
- import sys
14
- import torch
15
- import argparse
16
- import time
17
- from torchvision.utils import save_image
18
- import stat
19
- from icetk import icetk as tokenizer
20
- import logging, sys
21
-
22
- import torch.distributed as dist
23
- tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
24
-
25
-
26
- from SwissArmyTransformer import get_args
27
- from SwissArmyTransformer.data_utils import BinaryDataset, make_loaders
28
- from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
29
- from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
30
- from SwissArmyTransformer.resources import auto_create
31
-
32
- from models.cogvideo_cache_model import CogVideoCacheModel
33
- from coglm_strategy import CoglmStrategy
34
-
35
-
36
- def get_masks_and_position_ids_stage1(data, textlen, framelen):
37
- # Extract batch size and sequence length.
38
- tokens = data
39
- seq_length = len(data[0])
40
- # Attention mask (lower triangular).
41
- attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device)
42
- attention_mask[:, :textlen, textlen:] = 0
43
- attention_mask[:, textlen:, textlen:].tril_()
44
- attention_mask.unsqueeze_(1)
45
- # Unaligned version
46
- position_ids = torch.zeros(seq_length, dtype=torch.long,
47
- device=data.device)
48
- torch.arange(textlen, out=position_ids[:textlen],
49
- dtype=torch.long, device=data.device)
50
- torch.arange(512, 512+seq_length-textlen, out=position_ids[textlen:],
51
- dtype=torch.long, device=data.device)
52
- position_ids = position_ids.unsqueeze(0)
53
-
54
- return tokens, attention_mask, position_ids
55
-
56
- def get_masks_and_position_ids_stage2(data, textlen, framelen):
57
- # Extract batch size and sequence length.
58
- tokens = data
59
- seq_length = len(data[0])
60
-
61
- # Attention mask (lower triangular).
62
- attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device)
63
- attention_mask[:, :textlen, textlen:] = 0
64
- attention_mask[:, textlen:, textlen:].tril_()
65
- attention_mask.unsqueeze_(1)
66
-
67
- # Unaligned version
68
- position_ids = torch.zeros(seq_length, dtype=torch.long,
69
- device=data.device)
70
- torch.arange(textlen, out=position_ids[:textlen],
71
- dtype=torch.long, device=data.device)
72
- frame_num = (seq_length-textlen)//framelen
73
- assert frame_num == 5
74
- torch.arange(512, 512+framelen, out=position_ids[textlen:textlen+framelen],
75
- dtype=torch.long, device=data.device)
76
- torch.arange(512+framelen*2, 512+framelen*3, out=position_ids[textlen+framelen:textlen+framelen*2],
77
- dtype=torch.long, device=data.device)
78
- torch.arange(512+framelen*(frame_num-1), 512+framelen*frame_num, out=position_ids[textlen+framelen*2:textlen+framelen*3],
79
- dtype=torch.long, device=data.device)
80
- torch.arange(512+framelen*1, 512+framelen*2, out=position_ids[textlen+framelen*3:textlen+framelen*4],
81
- dtype=torch.long, device=data.device)
82
- torch.arange(512+framelen*3, 512+framelen*4, out=position_ids[textlen+framelen*4:textlen+framelen*5],
83
- dtype=torch.long, device=data.device)
84
-
85
- position_ids = position_ids.unsqueeze(0)
86
-
87
- return tokens, attention_mask, position_ids
88
-
89
- def my_update_mems(hiddens, mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len):
90
- if hiddens is None:
91
- return None, mems_indexs
92
- mem_num = len(hiddens)
93
- ret_mem = []
94
- with torch.no_grad():
95
- for id in range(mem_num):
96
- if hiddens[id][0] is None:
97
- ret_mem.append(None)
98
- else:
99
- if id == 0 and limited_spatial_channel_mem and mems_indexs[id]+hiddens[0][0].shape[1] >= text_len+frame_len:
100
- if mems_indexs[id] == 0:
101
- for layer, hidden in enumerate(hiddens[id]):
102
- mems_buffers[id][layer, :, :text_len] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, :text_len]
103
- new_mem_len_part2 = (mems_indexs[id]+hiddens[0][0].shape[1]-text_len)%frame_len
104
- if new_mem_len_part2 > 0:
105
- for layer, hidden in enumerate(hiddens[id]):
106
- mems_buffers[id][layer, :, text_len:text_len+new_mem_len_part2] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, -new_mem_len_part2:]
107
- mems_indexs[id] = text_len+new_mem_len_part2
108
- else:
109
- for layer, hidden in enumerate(hiddens[id]):
110
- mems_buffers[id][layer, :, mems_indexs[id]:mems_indexs[id]+hidden.shape[1]] = hidden.expand(mems_buffers[id].shape[1], -1, -1)
111
- mems_indexs[id] += hidden.shape[1]
112
- ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]])
113
- return ret_mem, mems_indexs
114
-
115
-
116
- def my_save_multiple_images(imgs, path, subdir, debug=True):
117
- # imgs: list of tensor images
118
- if debug:
119
- imgs = torch.cat(imgs, dim=0)
120
- print("\nSave to: ", path, flush=True)
121
- save_image(imgs, path, normalize=True)
122
- else:
123
- print("\nSave to: ", path, flush=True)
124
- single_frame_path = os.path.join(path, subdir)
125
- os.makedirs(single_frame_path, exist_ok=True)
126
- for i in range(len(imgs)):
127
- save_image(imgs[i], os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), normalize=True)
128
- os.chmod(os.path.join(single_frame_path,f'{str(i).rjust(4,"0")}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
129
- save_image(torch.cat(imgs, dim=0), os.path.join(single_frame_path,f'frame_concat.jpg'), normalize=True)
130
- os.chmod(os.path.join(single_frame_path,f'frame_concat.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
131
-
132
- def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
133
- # The fisrt token's position id of the frame that the next token belongs to;
134
- if total_len < text_len:
135
- return None
136
- return (total_len-text_len)//frame_len * frame_len + text_len
137
-
138
- def my_filling_sequence(
139
- model,
140
- args,
141
- seq,
142
- batch_size,
143
- get_masks_and_position_ids,
144
- text_len,
145
- frame_len,
146
- strategy=BaseStrategy(),
147
- strategy2=BaseStrategy(),
148
- mems=None,
149
- log_text_attention_weights=0, # default to 0: no artificial change
150
- mode_stage1=True,
151
- enforce_no_swin=False,
152
- guider_seq=None,
153
- guider_text_len=0,
154
- guidance_alpha=1,
155
- limited_spatial_channel_mem=False, # 空间通道的存储限制在本帧内
156
- **kw_args
157
- ):
158
- '''
159
- seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
160
- mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
161
- cache, should be first mems.shape[1] parts of context_tokens.
162
- mems are the first-level citizens here, but we don't assume what is memorized.
163
- input mems are used when multi-phase generation.
164
- '''
165
- if guider_seq is not None:
166
- logging.debug("Using Guidance In Inference")
167
- if limited_spatial_channel_mem:
168
- logging.debug("Limit spatial-channel's mem to current frame")
169
- assert len(seq.shape) == 2
170
-
171
- # building the initial tokens, attention_mask, and position_ids
172
- actual_context_length = 0
173
-
174
- while seq[-1][actual_context_length] >= 0: # the last seq has least given tokens
175
- actual_context_length += 1 # [0, context_length-1] are given
176
- assert actual_context_length > 0
177
- current_frame_num = (actual_context_length-text_len) // frame_len
178
- assert current_frame_num >= 0
179
- context_length = text_len + current_frame_num * frame_len
180
-
181
- tokens, attention_mask, position_ids = get_masks_and_position_ids(seq, text_len, frame_len)
182
- tokens = tokens[..., :context_length]
183
- input_tokens = tokens.clone()
184
-
185
- if guider_seq is not None:
186
- guider_index_delta = text_len - guider_text_len
187
- guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(guider_seq, guider_text_len, frame_len)
188
- guider_tokens = guider_tokens[..., :context_length-guider_index_delta]
189
- guider_input_tokens = guider_tokens.clone()
190
-
191
- for fid in range(current_frame_num):
192
- input_tokens[:, text_len+400*fid] = tokenizer['<start_of_image>']
193
- if guider_seq is not None:
194
- guider_input_tokens[:, guider_text_len+400*fid] = tokenizer['<start_of_image>']
195
-
196
- attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
197
- # initialize generation
198
- counter = context_length - 1 # Last fixed index is ``counter''
199
- index = 0 # Next forward starting index, also the length of cache.
200
- mems_buffers_on_GPU = False
201
- mems_indexs = [0, 0]
202
- mems_len = [(400+74) if limited_spatial_channel_mem else 5*400+74, 5*400+74]
203
- mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype)
204
- for mem_len in mems_len]
205
-
206
-
207
- if guider_seq is not None:
208
- guider_attention_mask = guider_attention_mask.type_as(next(model.parameters())) # if fp16
209
- guider_mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype)
210
- for mem_len in mems_len]
211
- guider_mems_indexs = [0, 0]
212
- guider_mems = None
213
-
214
- torch.cuda.empty_cache()
215
- # step-by-step generation
216
- while counter < len(seq[0]) - 1:
217
- # we have generated counter+1 tokens
218
- # Now, we want to generate seq[counter + 1],
219
- # token[:, index: counter+1] needs forwarding.
220
- if index == 0:
221
- group_size = 2 if (input_tokens.shape[0] == batch_size and not mode_stage1) else batch_size
222
-
223
- logits_all = None
224
- for batch_idx in range(0, input_tokens.shape[0], group_size):
225
- logits, *output_per_layers = model(
226
- input_tokens[batch_idx:batch_idx+group_size, index:],
227
- position_ids[..., index: counter+1],
228
- attention_mask, # TODO memlen
229
- mems=mems,
230
- text_len=text_len,
231
- frame_len=frame_len,
232
- counter=counter,
233
- log_text_attention_weights=log_text_attention_weights,
234
- enforce_no_swin=enforce_no_swin,
235
- **kw_args
236
- )
237
- logits_all = torch.cat((logits_all, logits), dim=0) if logits_all is not None else logits
238
- mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]]
239
- next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(text_len, frame_len, mem_kv01[0][0].shape[1])
240
- for id, mem_kv in enumerate(mem_kv01):
241
- for layer, mem_kv_perlayer in enumerate(mem_kv):
242
- if limited_spatial_channel_mem and id == 0:
243
- mems_buffers[id][layer, batch_idx:batch_idx+group_size, :text_len] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :text_len]
244
- mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\
245
- mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:]
246
- else:
247
- mems_buffers[id][layer, batch_idx:batch_idx+group_size, :mem_kv_perlayer.shape[1]] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)
248
- mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[1], mem_kv01[1][0].shape[1]
249
- if limited_spatial_channel_mem:
250
- mems_indexs[0] -= (next_tokens_frame_begin_id - text_len)
251
-
252
- mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)]
253
- logits = logits_all
254
-
255
- # Guider
256
- if guider_seq is not None:
257
- guider_logits_all = None
258
- for batch_idx in range(0, guider_input_tokens.shape[0], group_size):
259
- guider_logits, *guider_output_per_layers = model(
260
- guider_input_tokens[batch_idx:batch_idx+group_size, max(index-guider_index_delta, 0):],
261
- guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta],
262
- guider_attention_mask,
263
- mems=guider_mems,
264
- text_len=guider_text_len,
265
- frame_len=frame_len,
266
- counter=counter-guider_index_delta,
267
- log_text_attention_weights=log_text_attention_weights,
268
- enforce_no_swin=enforce_no_swin,
269
- **kw_args
270
- )
271
- guider_logits_all = torch.cat((guider_logits_all, guider_logits), dim=0) if guider_logits_all is not None else guider_logits
272
- guider_mem_kv01 = [[o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers]]
273
- for id, guider_mem_kv in enumerate(guider_mem_kv01):
274
- for layer, guider_mem_kv_perlayer in enumerate(guider_mem_kv):
275
- if limited_spatial_channel_mem and id == 0:
276
- guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_text_len] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :guider_text_len]
277
- guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(guider_text_len, frame_len, guider_mem_kv_perlayer.shape[1])
278
- guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\
279
- guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:]
280
- else:
281
- guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_mem_kv_perlayer.shape[1]] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)
282
- guider_mems_indexs[0], guider_mems_indexs[1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[1][0].shape[1]
283
- if limited_spatial_channel_mem:
284
- guider_mems_indexs[0] -= (guider_next_tokens_frame_begin_id-guider_text_len)
285
- guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)]
286
- guider_logits = guider_logits_all
287
- else:
288
- if not mems_buffers_on_GPU:
289
- if not mode_stage1:
290
- torch.cuda.empty_cache()
291
- for idx, mem in enumerate(mems):
292
- mems[idx] = mem.to(next(model.parameters()).device)
293
- if guider_seq is not None:
294
- for idx, mem in enumerate(guider_mems):
295
- guider_mems[idx] = mem.to(next(model.parameters()).device)
296
- else:
297
- torch.cuda.empty_cache()
298
- for idx, mem_buffer in enumerate(mems_buffers):
299
- mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
300
- mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)]
301
- if guider_seq is not None:
302
- for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
303
- guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device)
304
- guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)]
305
- mems_buffers_on_GPU = True
306
-
307
- logits, *output_per_layers = model(
308
- input_tokens[:, index:],
309
- position_ids[..., index: counter+1],
310
- attention_mask, # TODO memlen
311
- mems=mems,
312
- text_len=text_len,
313
- frame_len=frame_len,
314
- counter=counter,
315
- log_text_attention_weights=log_text_attention_weights,
316
- enforce_no_swin=enforce_no_swin,
317
- limited_spatial_channel_mem=limited_spatial_channel_mem,
318
- **kw_args
319
- )
320
- mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]
321
-
322
- if guider_seq is not None:
323
- guider_logits, *guider_output_per_layers = model(
324
- guider_input_tokens[:, max(index-guider_index_delta, 0):],
325
- guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta],
326
- guider_attention_mask,
327
- mems=guider_mems,
328
- text_len=guider_text_len,
329
- frame_len=frame_len,
330
- counter=counter-guider_index_delta,
331
- log_text_attention_weights=0,
332
- enforce_no_swin=enforce_no_swin,
333
- limited_spatial_channel_mem=limited_spatial_channel_mem,
334
- **kw_args
335
- )
336
- guider_mem_kv0, guider_mem_kv1 = [o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers]
337
-
338
- if not mems_buffers_on_GPU:
339
- torch.cuda.empty_cache()
340
- for idx, mem_buffer in enumerate(mems_buffers):
341
- mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
342
- if guider_seq is not None:
343
- for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
344
- guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device)
345
- mems_buffers_on_GPU = True
346
-
347
- mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1], mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len)
348
- if guider_seq is not None:
349
- guider_mems, guider_mems_indexs = my_update_mems([guider_mem_kv0, guider_mem_kv1], guider_mems_buffers, guider_mems_indexs, limited_spatial_channel_mem, guider_text_len, frame_len)
350
-
351
-
352
- counter += 1
353
- index = counter
354
-
355
- logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
356
- tokens = tokens.expand(batch_size, -1)
357
- if guider_seq is not None:
358
- guider_logits = guider_logits[:, -1].expand(batch_size, -1)
359
- guider_tokens = guider_tokens.expand(batch_size, -1)
360
-
361
- if seq[-1][counter].item() < 0:
362
- # sampling
363
- guided_logits = guider_logits+(logits-guider_logits)*guidance_alpha if guider_seq is not None else logits
364
- if mode_stage1 and counter < text_len + 400:
365
- tokens, mems = strategy.forward(guided_logits, tokens, mems)
366
- else:
367
- tokens, mems = strategy2.forward(guided_logits, tokens, mems)
368
- if guider_seq is not None:
369
- guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]), dim=1)
370
-
371
- if seq[0][counter].item() >= 0:
372
- for si in range(seq.shape[0]):
373
- if seq[si][counter].item() >= 0:
374
- tokens[si, -1] = seq[si, counter]
375
- if guider_seq is not None:
376
- guider_tokens[si, -1] = guider_seq[si, counter-guider_index_delta]
377
-
378
- else:
379
- tokens = torch.cat((tokens, seq[:, counter:counter+1].clone().expand(tokens.shape[0], 1).to(device=tokens.device, dtype=tokens.dtype)), dim=1)
380
- if guider_seq is not None:
381
- guider_tokens = torch.cat((guider_tokens,
382
- guider_seq[:, counter-guider_index_delta:counter+1-guider_index_delta]
383
- .clone().expand(guider_tokens.shape[0], 1).to(device=guider_tokens.device, dtype=guider_tokens.dtype)), dim=1)
384
-
385
- input_tokens = tokens.clone()
386
- if guider_seq is not None:
387
- guider_input_tokens = guider_tokens.clone()
388
- if (index-text_len-1)//400 < (input_tokens.shape[-1]-text_len-1)//400:
389
- boi_idx = ((index-text_len-1)//400 +1)*400+text_len
390
- while boi_idx < input_tokens.shape[-1]:
391
- input_tokens[:, boi_idx] = tokenizer['<start_of_image>']
392
- if guider_seq is not None:
393
- guider_input_tokens[:, boi_idx-guider_index_delta] = tokenizer['<start_of_image>']
394
- boi_idx += 400
395
-
396
- if strategy.is_done:
397
- break
398
- return strategy.finalize(tokens, mems)
399
-
400
- class InferenceModel_Sequential(CogVideoCacheModel):
401
- def __init__(self, args, transformer=None, parallel_output=True):
402
- super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=-1, cogvideo_stage=1)
403
- # TODO: check it
404
-
405
- def final_forward(self, logits, **kwargs):
406
- logits_parallel = logits
407
- logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
408
- return logits_parallel
409
-
410
- class InferenceModel_Interpolate(CogVideoCacheModel):
411
- def __init__(self, args, transformer=None, parallel_output=True):
412
- super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=10, cogvideo_stage=2)
413
- # TODO: check it
414
-
415
- def final_forward(self, logits, **kwargs):
416
- logits_parallel = logits
417
- logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
418
- return logits_parallel
419
-
420
- def main(args):
421
- assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1
422
- rank_id = args.device % args.parallel_size
423
- generate_frame_num = args.generate_frame_num
424
-
425
- if args.stage_1 or args.both_stages:
426
- model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
427
- model_stage1.eval()
428
- if args.both_stages:
429
- model_stage1 = model_stage1.cpu()
430
-
431
- if args.stage_2 or args.both_stages:
432
- model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2')
433
- model_stage2.eval()
434
- if args.both_stages:
435
- model_stage2 = model_stage2.cpu()
436
-
437
- invalid_slices = [slice(tokenizer.num_image_tokens, None)]
438
- strategy_cogview2 = CoglmStrategy(invalid_slices,
439
- temperature=1.0, top_k=16)
440
- strategy_cogvideo = CoglmStrategy(invalid_slices,
441
- temperature=args.temperature, top_k=args.top_k,
442
- temperature2=args.coglm_temperature2)
443
- if not args.stage_1:
444
- from sr_pipeline import DirectSuperResolution
445
- dsr_path = auto_create('cogview2-dsr', path=None) # path=os.getenv('SAT_HOME', '~/.sat_models')
446
- dsr = DirectSuperResolution(args, dsr_path,
447
- max_bz=12, onCUDA=False)
448
-
449
- def process_stage2(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", parent_given_tokens=None, conddir=None, outputdir=None, gpu_rank=0, gpu_parallel_size=1):
450
- stage2_starttime = time.time()
451
- use_guidance = args.use_guidance_stage2
452
- if args.both_stages:
453
- move_start_time = time.time()
454
- logging.debug("moving stage-2 model to cuda")
455
- model = model.cuda()
456
- logging.debug("moving in stage-2 model takes time: {:.2f}".format(time.time()-move_start_time))
457
-
458
- try:
459
- if parent_given_tokens is None:
460
- assert conddir is not None
461
- parent_given_tokens = torch.load(os.path.join(conddir, 'frame_tokens.pt'), map_location='cpu')
462
- sample_num_allgpu = parent_given_tokens.shape[0]
463
- sample_num = sample_num_allgpu // gpu_parallel_size
464
- assert sample_num * gpu_parallel_size == sample_num_allgpu
465
- parent_given_tokens = parent_given_tokens[gpu_rank*sample_num:(gpu_rank+1)*sample_num]
466
- except:
467
- logging.critical("No frame_tokens found in interpolation, skip")
468
- return False
469
-
470
- # CogVideo Stage2 Generation
471
- while duration >= 0.5: # TODO: You can change the boundary to change the frame rate
472
- parent_given_tokens_num = parent_given_tokens.shape[1]
473
- generate_batchsize_persample = (parent_given_tokens_num-1)//2
474
- generate_batchsize_total = generate_batchsize_persample * sample_num
475
- total_frames = generate_frame_num
476
- frame_len = 400
477
- enc_text = tokenizer.encode(seq_text)
478
- enc_duration = tokenizer.encode(str(float(duration))+"秒")
479
- seq = enc_duration + [tokenizer['<n>']] + enc_text + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
480
- text_len = len(seq) - frame_len*generate_frame_num - 1
481
-
482
- logging.info("[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(int(4/duration), tokenizer.decode(enc_text)))
483
-
484
- # generation
485
- seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1)
486
- for sample_i in range(sample_num):
487
- for i in range(generate_batchsize_persample):
488
- seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i]
489
- seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1]
490
- seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2]
491
-
492
- if use_guidance:
493
- guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(video_guidance_text) + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
494
- guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1
495
- guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1)
496
- for sample_i in range(sample_num):
497
- for i in range(generate_batchsize_persample):
498
- guider_seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i]
499
- guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1]
500
- guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2]
501
- video_log_text_attention_weights = 0
502
- else:
503
- guider_seq=None
504
- guider_text_len=0
505
- video_log_text_attention_weights = 1.4
506
-
507
- mbz = args.max_inference_batch_size
508
-
509
- assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
510
- output_list = []
511
- start_time = time.time()
512
- for tim in range(max(generate_batchsize_total // mbz, 1)):
513
- input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
514
- guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
515
- output_list.append(
516
- my_filling_sequence(model, args, input_seq,
517
- batch_size=min(generate_batchsize_total, mbz),
518
- get_masks_and_position_ids=get_masks_and_position_ids_stage2,
519
- text_len=text_len, frame_len=frame_len,
520
- strategy=strategy_cogview2,
521
- strategy2=strategy_cogvideo,
522
- log_text_attention_weights=video_log_text_attention_weights,
523
- mode_stage1=False,
524
- guider_seq=guider_seq2,
525
- guider_text_len=guider_text_len,
526
- guidance_alpha=args.guidance_alpha,
527
- limited_spatial_channel_mem=True,
528
- )[0]
529
- )
530
- logging.info("Duration {:.2f}, Taken time {:.2f}\n".format(duration, time.time() - start_time))
531
-
532
- output_tokens = torch.cat(output_list, dim=0)
533
- output_tokens = output_tokens[:, text_len+1:text_len+1+(total_frames)*400].reshape(sample_num, -1, 400*total_frames)
534
- output_tokens_merge = torch.cat((output_tokens[:, :, :1*400],
535
- output_tokens[:, :, 400*3:4*400],
536
- output_tokens[:, :, 400*1:2*400],
537
- output_tokens[:, :, 400*4:(total_frames)*400]), dim=2).reshape(sample_num, -1, 400)
538
-
539
- output_tokens_merge = torch.cat((output_tokens_merge, output_tokens[:, -1:, 400*2:3*400]), dim=1)
540
- duration /= 2
541
- parent_given_tokens = output_tokens_merge
542
-
543
- if args.both_stages:
544
- move_start_time = time.time()
545
- logging.debug("moving stage 2 model to cpu")
546
- model = model.cpu()
547
- torch.cuda.empty_cache()
548
- logging.debug("moving out model2 takes time: {:.2f}".format(time.time()-move_start_time))
549
-
550
- logging.info("CogVideo Stage2 completed. Taken time {:.2f}\n".format(time.time() - stage2_starttime))
551
-
552
- # decoding
553
- # imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
554
- # os.makedirs(output_dir_full_path, exist_ok=True)
555
- # my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False)
556
- # torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt'))
557
- # os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2")
558
-
559
- # direct super-resolution by CogView2
560
- logging.info("[Direct super-resolution]")
561
- dsr_starttime = time.time()
562
- enc_text = tokenizer.encode(seq_text)
563
- frame_num_per_sample = parent_given_tokens.shape[1]
564
- parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
565
- text_seq = torch.cuda.LongTensor(enc_text, device=args.device).unsqueeze(0).repeat(parent_given_tokens_2d.shape[0], 1)
566
- sred_tokens = dsr(text_seq, parent_given_tokens_2d)
567
- decoded_sr_videos = []
568
-
569
- for sample_i in range(sample_num):
570
- decoded_sr_imgs = []
571
- for frame_i in range(frame_num_per_sample):
572
- decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:])
573
- decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480)))
574
- decoded_sr_videos.append(decoded_sr_imgs)
575
-
576
- for sample_i in range(sample_num):
577
- my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
578
- os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
579
-
580
- logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime))
581
-
582
- return True
583
-
584
-
585
- def process_stage1(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1):
586
- process_start_time = time.time()
587
- use_guide = args.use_guidance_stage1
588
- if args.both_stages:
589
- move_start_time = time.time()
590
- logging.debug("moving stage 1 model to cuda")
591
- model = model.cuda()
592
- logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time))
593
-
594
- if video_raw_text is None:
595
- video_raw_text = seq_text
596
- mbz = args.stage1_max_inference_batch_size if args.stage1_max_inference_batch_size > 0 else args.max_inference_batch_size
597
- assert batch_size < mbz or batch_size % mbz == 0
598
- frame_len = 400
599
-
600
- # generate the first frame:
601
- enc_text = tokenizer.encode(seq_text+image_text_suffix)
602
- seq_1st = enc_text + [tokenizer['<start_of_image>']] + [-1]*400 # IV!! # test local!!! # test randboi!!!
603
- logging.info("[Generating First Frame with CogView2]Raw text: {:s}".format(tokenizer.decode(enc_text)))
604
- text_len_1st = len(seq_1st) - frame_len*1 - 1
605
-
606
- seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
607
- output_list_1st = []
608
- for tim in range(max(batch_size // mbz, 1)):
609
- start_time = time.time()
610
- output_list_1st.append(
611
- my_filling_sequence(model, args,seq_1st.clone(),
612
- batch_size=min(batch_size, mbz),
613
- get_masks_and_position_ids=get_masks_and_position_ids_stage1,
614
- text_len=text_len_1st,
615
- frame_len=frame_len,
616
- strategy=strategy_cogview2,
617
- strategy2=strategy_cogvideo,
618
- log_text_attention_weights=1.4,
619
- enforce_no_swin=True,
620
- mode_stage1=True,
621
- )[0]
622
- )
623
- logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
624
- output_tokens_1st = torch.cat(output_list_1st, dim=0)
625
- given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400]
626
-
627
- # generate subsequent frames:
628
- total_frames = generate_frame_num
629
- enc_duration = tokenizer.encode(str(float(duration))+"秒")
630
- if use_guide:
631
- video_raw_text = video_raw_text + " 视频"
632
- enc_text_video = tokenizer.encode(video_raw_text)
633
- seq = enc_duration + [tokenizer['<n>']] + enc_text_video + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
634
- guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(video_guidance_text) + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
635
- logging.info("[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(4/duration, tokenizer.decode(enc_text_video)))
636
-
637
- text_len = len(seq) - frame_len*generate_frame_num - 1
638
- guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1
639
- seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(batch_size, 1)
640
- guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(batch_size, 1)
641
-
642
- for given_frame_id in range(given_tokens.shape[1]):
643
- seq[:, text_len+1+given_frame_id*400: text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id]
644
- guider_seq[:, guider_text_len+1+given_frame_id*400:guider_text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id]
645
- output_list = []
646
-
647
- if use_guide:
648
- video_log_text_attention_weights = 0
649
- else:
650
- guider_seq = None
651
- video_log_text_attention_weights = 1.4
652
-
653
- for tim in range(max(batch_size // mbz, 1)):
654
- start_time = time.time()
655
- input_seq = seq[:min(batch_size, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
656
- guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
657
- output_list.append(
658
- my_filling_sequence(model, args,input_seq,
659
- batch_size=min(batch_size, mbz),
660
- get_masks_and_position_ids=get_masks_and_position_ids_stage1,
661
- text_len=text_len, frame_len=frame_len,
662
- strategy=strategy_cogview2,
663
- strategy2=strategy_cogvideo,
664
- log_text_attention_weights=video_log_text_attention_weights,
665
- guider_seq=guider_seq2,
666
- guider_text_len=guider_text_len,
667
- guidance_alpha=args.guidance_alpha,
668
- limited_spatial_channel_mem=True,
669
- mode_stage1=True,
670
- )[0]
671
- )
672
-
673
- output_tokens = torch.cat(output_list, dim=0)[:, 1+text_len:]
674
-
675
- if args.both_stages:
676
- move_start_time = time.time()
677
- logging.debug("moving stage 1 model to cpu")
678
- model = model.cpu()
679
- torch.cuda.empty_cache()
680
- logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time))
681
-
682
- # decoding
683
- imgs, sred_imgs, txts = [], [], []
684
- for seq in output_tokens:
685
- decoded_imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()[i*400: (i+1)*400]), size=(480, 480)) for i in range(total_frames)]
686
- imgs.append(decoded_imgs) # only the last image (target)
687
-
688
- assert len(imgs) == batch_size
689
- save_tokens = output_tokens[:, :+total_frames*400].reshape(-1, total_frames, 400).cpu()
690
- if outputdir is not None:
691
- for clip_i in range(len(imgs)):
692
- # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True)
693
- my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False)
694
- os.system(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25")
695
- torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt'))
696
-
697
- logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time))
698
-
699
- return save_tokens
700
-
701
- # ======================================================================================================
702
-
703
- if args.stage_1 or args.both_stages:
704
- if args.input_source != "interactive":
705
- with open(args.input_source, 'r') as fin:
706
- promptlist = fin.readlines()
707
- promptlist = [p.strip() for p in promptlist]
708
- else:
709
- promptlist = None
710
-
711
- now_qi = -1
712
- while True:
713
- now_qi += 1
714
-
715
- if promptlist is not None: # with input-source
716
- if args.multi_gpu:
717
- if now_qi % dist.get_world_size() != dist.get_rank():
718
- continue
719
- rk = dist.get_rank()
720
- else:
721
- rk = 0
722
- raw_text = promptlist[now_qi]
723
- raw_text = raw_text.strip()
724
- print(f'Working on Line No. {now_qi} on {rk}... [{raw_text}]')
725
- else: # interactive
726
- raw_text = input("\nPlease Input Query (stop to exit) >>> ")
727
- raw_text = raw_text.strip()
728
- if not raw_text:
729
- print('Query should not be empty!')
730
- continue
731
- if raw_text == "stop":
732
- return
733
-
734
- try:
735
- path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
736
- parent_given_tokens = process_stage1(model_stage1, raw_text, duration=4.0, video_raw_text=raw_text, video_guidance_text="视频",
737
- image_text_suffix=" 高清摄影",
738
- outputdir=path if args.stage_1 else None, batch_size=args.batch_size)
739
- if args.both_stages:
740
- process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
741
- video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
742
- outputdir=path,
743
- gpu_rank=0, gpu_parallel_size=1) # TODO: 修改
744
- except (ValueError, FileNotFoundError) as e:
745
- print(e)
746
- continue
747
-
748
- elif args.stage_2:
749
- sample_dirs = os.listdir(args.output_path)
750
- for sample in sample_dirs:
751
- raw_text = sample.split('_')[-1]
752
- path = os.path.join(args.output_path, sample, 'Interp')
753
- parent_given_tokens = torch.load(os.path.join(args.output_path, sample, "frame_tokens.pt"))
754
-
755
- process_stage2(raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
756
- video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
757
- outputdir=path,
758
- gpu_rank=0, gpu_parallel_size=1) # TODO: 修改
759
-
760
- else:
761
- assert False
762
-
763
-
764
- if __name__ == "__main__":
765
- logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
766
-
767
- py_parser = argparse.ArgumentParser(add_help=False)
768
- py_parser.add_argument('--generate-frame-num', type=int, default=5)
769
- py_parser.add_argument('--coglm-temperature2', type=float, default=0.89)
770
- # py_parser.add_argument("--interp-duration", type=float, default=-1) # -1是顺序生成,0是超分,0.5/1/2是插帧
771
- # py_parser.add_argument("--total-duration", type=float, default=4.0) # 整个的时间
772
- py_parser.add_argument('--use-guidance-stage1', action='store_true')
773
- py_parser.add_argument('--use-guidance-stage2', action='store_false')
774
- py_parser.add_argument('--guidance-alpha', type=float, default=3.0)
775
- py_parser.add_argument('--stage-1', action='store_true') # stage 1: sequential generation
776
- py_parser.add_argument('--stage-2', action='store_false') # stage 2: interp + dsr
777
- py_parser.add_argument('--both-stages', action='store_false') # stage 1&2: sequential generation; interp + dsr
778
- py_parser.add_argument('--parallel-size', type=int, default=1)
779
- py_parser.add_argument('--stage1-max-inference-batch-size', type=int, default=1) # -1: use max-inference-batch-size
780
- py_parser.add_argument('--multi-gpu', action='store_false')
781
-
782
- CogVideoCacheModel.add_model_specific_args(py_parser)
783
-
784
- known, args_list = py_parser.parse_known_args()
785
- args = get_args(args_list)
786
- args = argparse.Namespace(**vars(args), **vars(known))
787
- args.layout = [int(x) for x in args.layout.split(',')]
788
- args.do_train = False
789
-
790
- torch.cuda.set_device(args.device)
791
-
792
- with torch.no_grad():
793
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
icetk_models/.gitkeep ADDED
File without changes
model.py ADDED
@@ -0,0 +1,1243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is adapted from https://github.com/THUDM/CogVideo/blob/ff423aa169978fb2f636f761e348631fa3178b03/cogvideo_pipeline.py
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import logging
7
+ import os
8
+ import pathlib
9
+ import shutil
10
+ import subprocess
11
+ import sys
12
+ import tempfile
13
+ import time
14
+ import zipfile
15
+ from typing import Any
16
+
17
+ if os.getenv('SYSTEM') == 'spaces':
18
+ subprocess.run('pip install icetk==0.0.4'.split())
19
+ subprocess.run('pip install SwissArmyTransformer==0.2.9'.split())
20
+ subprocess.run(
21
+ 'pip install git+https://github.com/Sleepychord/Image-Local-Attention@43fee31'
22
+ .split())
23
+ #subprocess.run('git clone https://github.com/NVIDIA/apex'.split())
24
+ #subprocess.run('git checkout 1403c21'.split(), cwd='apex')
25
+ #with open('patch.apex') as f:
26
+ # subprocess.run('patch -p1'.split(), cwd='apex', stdin=f)
27
+ #subprocess.run(
28
+ # 'pip install -v --disable-pip-version-check --no-cache-dir --global-option --cpp_ext --global-option --cuda_ext ./'
29
+ # .split(),
30
+ # cwd='apex')
31
+ #subprocess.run('rm -rf apex'.split())
32
+ with open('patch') as f:
33
+ subprocess.run('patch -p1'.split(), cwd='CogVideo', stdin=f)
34
+
35
+ from huggingface_hub import hf_hub_download
36
+
37
+ def download_and_extract_icetk_models() -> None:
38
+ icetk_model_dir = pathlib.Path('/home/user/.icetk_models')
39
+ icetk_model_dir.mkdir()
40
+ path = hf_hub_download('THUDM/icetk',
41
+ 'models.zip',
42
+ use_auth_token=os.getenv('HF_TOKEN'))
43
+ with zipfile.ZipFile(path) as f:
44
+ f.extractall(path=icetk_model_dir.as_posix())
45
+
46
+ def download_and_extract_cogvideo_models(name: str) -> None:
47
+ path = hf_hub_download('THUDM/CogVideo',
48
+ name,
49
+ use_auth_token=os.getenv('HF_TOKEN'))
50
+ with zipfile.ZipFile(path) as f:
51
+ f.extractall('pretrained')
52
+ os.remove(path)
53
+
54
+ def download_and_extract_cogview2_models(name: str) -> None:
55
+ path = hf_hub_download('THUDM/CogView2', name)
56
+ with zipfile.ZipFile(path) as f:
57
+ f.extractall()
58
+ shutil.move('/home/user/app/sharefs/cogview-new/cogview2-dsr',
59
+ 'pretrained')
60
+ shutil.rmtree('/home/user/app/sharefs/')
61
+ os.remove(path)
62
+
63
+ download_and_extract_icetk_models()
64
+ download_and_extract_cogvideo_models('cogvideo-stage1.zip')
65
+ #download_and_extract_cogvideo_models('cogvideo-stage2.zip')
66
+ #download_and_extract_cogview2_models('cogview2-dsr.zip')
67
+
68
+ os.environ['SAT_HOME'] = '/home/user/app/pretrained'
69
+
70
+ import gradio as gr
71
+ import imageio.v2 as iio
72
+ import numpy as np
73
+ import torch
74
+ from icetk import IceTokenizer
75
+ from SwissArmyTransformer import get_args
76
+ from SwissArmyTransformer.arguments import set_random_seed
77
+ from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
78
+ from SwissArmyTransformer.resources import auto_create
79
+
80
+ app_dir = pathlib.Path(__file__).parent
81
+ submodule_dir = app_dir / 'CogVideo'
82
+ sys.path.insert(0, submodule_dir.as_posix())
83
+
84
+ from coglm_strategy import CoglmStrategy
85
+ from models.cogvideo_cache_model import CogVideoCacheModel
86
+ from sr_pipeline import DirectSuperResolution
87
+
88
+ formatter = logging.Formatter(
89
+ '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
90
+ datefmt='%Y-%m-%d %H:%M:%S')
91
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
92
+ stream_handler.setLevel(logging.INFO)
93
+ stream_handler.setFormatter(formatter)
94
+ logger = logging.getLogger(__name__)
95
+ logger.setLevel(logging.INFO)
96
+ logger.propagate = False
97
+ logger.addHandler(stream_handler)
98
+
99
+ ICETK_MODEL_DIR = app_dir / 'icetk_models'
100
+
101
+
102
+ def get_masks_and_position_ids_stage1(data, textlen, framelen):
103
+ # Extract batch size and sequence length.
104
+ tokens = data
105
+ seq_length = len(data[0])
106
+ # Attention mask (lower triangular).
107
+ attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
108
+ device=data.device)
109
+ attention_mask[:, :textlen, textlen:] = 0
110
+ attention_mask[:, textlen:, textlen:].tril_()
111
+ attention_mask.unsqueeze_(1)
112
+ # Unaligned version
113
+ position_ids = torch.zeros(seq_length,
114
+ dtype=torch.long,
115
+ device=data.device)
116
+ torch.arange(textlen,
117
+ out=position_ids[:textlen],
118
+ dtype=torch.long,
119
+ device=data.device)
120
+ torch.arange(512,
121
+ 512 + seq_length - textlen,
122
+ out=position_ids[textlen:],
123
+ dtype=torch.long,
124
+ device=data.device)
125
+ position_ids = position_ids.unsqueeze(0)
126
+
127
+ return tokens, attention_mask, position_ids
128
+
129
+
130
+ def get_masks_and_position_ids_stage2(data, textlen, framelen):
131
+ # Extract batch size and sequence length.
132
+ tokens = data
133
+ seq_length = len(data[0])
134
+
135
+ # Attention mask (lower triangular).
136
+ attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
137
+ device=data.device)
138
+ attention_mask[:, :textlen, textlen:] = 0
139
+ attention_mask[:, textlen:, textlen:].tril_()
140
+ attention_mask.unsqueeze_(1)
141
+
142
+ # Unaligned version
143
+ position_ids = torch.zeros(seq_length,
144
+ dtype=torch.long,
145
+ device=data.device)
146
+ torch.arange(textlen,
147
+ out=position_ids[:textlen],
148
+ dtype=torch.long,
149
+ device=data.device)
150
+ frame_num = (seq_length - textlen) // framelen
151
+ assert frame_num == 5
152
+ torch.arange(512,
153
+ 512 + framelen,
154
+ out=position_ids[textlen:textlen + framelen],
155
+ dtype=torch.long,
156
+ device=data.device)
157
+ torch.arange(512 + framelen * 2,
158
+ 512 + framelen * 3,
159
+ out=position_ids[textlen + framelen:textlen + framelen * 2],
160
+ dtype=torch.long,
161
+ device=data.device)
162
+ torch.arange(512 + framelen * (frame_num - 1),
163
+ 512 + framelen * frame_num,
164
+ out=position_ids[textlen + framelen * 2:textlen +
165
+ framelen * 3],
166
+ dtype=torch.long,
167
+ device=data.device)
168
+ torch.arange(512 + framelen * 1,
169
+ 512 + framelen * 2,
170
+ out=position_ids[textlen + framelen * 3:textlen +
171
+ framelen * 4],
172
+ dtype=torch.long,
173
+ device=data.device)
174
+ torch.arange(512 + framelen * 3,
175
+ 512 + framelen * 4,
176
+ out=position_ids[textlen + framelen * 4:textlen +
177
+ framelen * 5],
178
+ dtype=torch.long,
179
+ device=data.device)
180
+
181
+ position_ids = position_ids.unsqueeze(0)
182
+
183
+ return tokens, attention_mask, position_ids
184
+
185
+
186
+ def my_update_mems(hiddens, mems_buffers, mems_indexs,
187
+ limited_spatial_channel_mem, text_len, frame_len):
188
+ if hiddens is None:
189
+ return None, mems_indexs
190
+ mem_num = len(hiddens)
191
+ ret_mem = []
192
+ with torch.no_grad():
193
+ for id in range(mem_num):
194
+ if hiddens[id][0] is None:
195
+ ret_mem.append(None)
196
+ else:
197
+ if id == 0 and limited_spatial_channel_mem and mems_indexs[
198
+ id] + hiddens[0][0].shape[1] >= text_len + frame_len:
199
+ if mems_indexs[id] == 0:
200
+ for layer, hidden in enumerate(hiddens[id]):
201
+ mems_buffers[id][
202
+ layer, :, :text_len] = hidden.expand(
203
+ mems_buffers[id].shape[1], -1,
204
+ -1)[:, :text_len]
205
+ new_mem_len_part2 = (mems_indexs[id] +
206
+ hiddens[0][0].shape[1] -
207
+ text_len) % frame_len
208
+ if new_mem_len_part2 > 0:
209
+ for layer, hidden in enumerate(hiddens[id]):
210
+ mems_buffers[id][
211
+ layer, :, text_len:text_len +
212
+ new_mem_len_part2] = hidden.expand(
213
+ mems_buffers[id].shape[1], -1,
214
+ -1)[:, -new_mem_len_part2:]
215
+ mems_indexs[id] = text_len + new_mem_len_part2
216
+ else:
217
+ for layer, hidden in enumerate(hiddens[id]):
218
+ mems_buffers[id][layer, :,
219
+ mems_indexs[id]:mems_indexs[id] +
220
+ hidden.shape[1]] = hidden.expand(
221
+ mems_buffers[id].shape[1], -1, -1)
222
+ mems_indexs[id] += hidden.shape[1]
223
+ ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]])
224
+ return ret_mem, mems_indexs
225
+
226
+
227
+ def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
228
+ # The fisrt token's position id of the frame that the next token belongs to;
229
+ if total_len < text_len:
230
+ return None
231
+ return (total_len - text_len) // frame_len * frame_len + text_len
232
+
233
+
234
+ def my_filling_sequence(
235
+ model,
236
+ tokenizer,
237
+ args,
238
+ seq,
239
+ batch_size,
240
+ get_masks_and_position_ids,
241
+ text_len,
242
+ frame_len,
243
+ strategy=BaseStrategy(),
244
+ strategy2=BaseStrategy(),
245
+ mems=None,
246
+ log_text_attention_weights=0, # default to 0: no artificial change
247
+ mode_stage1=True,
248
+ enforce_no_swin=False,
249
+ guider_seq=None,
250
+ guider_text_len=0,
251
+ guidance_alpha=1,
252
+ limited_spatial_channel_mem=False, # 空间通道的存储限制在本帧内
253
+ **kw_args):
254
+ '''
255
+ seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
256
+ mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
257
+ cache, should be first mems.shape[1] parts of context_tokens.
258
+ mems are the first-level citizens here, but we don't assume what is memorized.
259
+ input mems are used when multi-phase generation.
260
+ '''
261
+ if guider_seq is not None:
262
+ logger.debug('Using Guidance In Inference')
263
+ if limited_spatial_channel_mem:
264
+ logger.debug("Limit spatial-channel's mem to current frame")
265
+ assert len(seq.shape) == 2
266
+
267
+ # building the initial tokens, attention_mask, and position_ids
268
+ actual_context_length = 0
269
+
270
+ while seq[-1][
271
+ actual_context_length] >= 0: # the last seq has least given tokens
272
+ actual_context_length += 1 # [0, context_length-1] are given
273
+ assert actual_context_length > 0
274
+ current_frame_num = (actual_context_length - text_len) // frame_len
275
+ assert current_frame_num >= 0
276
+ context_length = text_len + current_frame_num * frame_len
277
+
278
+ tokens, attention_mask, position_ids = get_masks_and_position_ids(
279
+ seq, text_len, frame_len)
280
+ tokens = tokens[..., :context_length]
281
+ input_tokens = tokens.clone()
282
+
283
+ if guider_seq is not None:
284
+ guider_index_delta = text_len - guider_text_len
285
+ guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(
286
+ guider_seq, guider_text_len, frame_len)
287
+ guider_tokens = guider_tokens[..., :context_length -
288
+ guider_index_delta]
289
+ guider_input_tokens = guider_tokens.clone()
290
+
291
+ for fid in range(current_frame_num):
292
+ input_tokens[:, text_len + 400 * fid] = tokenizer['<start_of_image>']
293
+ if guider_seq is not None:
294
+ guider_input_tokens[:, guider_text_len +
295
+ 400 * fid] = tokenizer['<start_of_image>']
296
+
297
+ attention_mask = attention_mask.type_as(next(
298
+ model.parameters())) # if fp16
299
+ # initialize generation
300
+ counter = context_length - 1 # Last fixed index is ``counter''
301
+ index = 0 # Next forward starting index, also the length of cache.
302
+ mems_buffers_on_GPU = False
303
+ mems_indexs = [0, 0]
304
+ mems_len = [(400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74,
305
+ 5 * 400 + 74]
306
+ mems_buffers = [
307
+ torch.zeros(args.num_layers,
308
+ batch_size,
309
+ mem_len,
310
+ args.hidden_size * 2,
311
+ dtype=next(model.parameters()).dtype)
312
+ for mem_len in mems_len
313
+ ]
314
+
315
+ if guider_seq is not None:
316
+ guider_attention_mask = guider_attention_mask.type_as(
317
+ next(model.parameters())) # if fp16
318
+ guider_mems_buffers = [
319
+ torch.zeros(args.num_layers,
320
+ batch_size,
321
+ mem_len,
322
+ args.hidden_size * 2,
323
+ dtype=next(model.parameters()).dtype)
324
+ for mem_len in mems_len
325
+ ]
326
+ guider_mems_indexs = [0, 0]
327
+ guider_mems = None
328
+
329
+ torch.cuda.empty_cache()
330
+ # step-by-step generation
331
+ while counter < len(seq[0]) - 1:
332
+ # we have generated counter+1 tokens
333
+ # Now, we want to generate seq[counter + 1],
334
+ # token[:, index: counter+1] needs forwarding.
335
+ if index == 0:
336
+ group_size = 2 if (input_tokens.shape[0] == batch_size
337
+ and not mode_stage1) else batch_size
338
+
339
+ logits_all = None
340
+ for batch_idx in range(0, input_tokens.shape[0], group_size):
341
+ logits, *output_per_layers = model(
342
+ input_tokens[batch_idx:batch_idx + group_size, index:],
343
+ position_ids[..., index:counter + 1],
344
+ attention_mask, # TODO memlen
345
+ mems=mems,
346
+ text_len=text_len,
347
+ frame_len=frame_len,
348
+ counter=counter,
349
+ log_text_attention_weights=log_text_attention_weights,
350
+ enforce_no_swin=enforce_no_swin,
351
+ **kw_args)
352
+ logits_all = torch.cat(
353
+ (logits_all,
354
+ logits), dim=0) if logits_all is not None else logits
355
+ mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers],
356
+ [o['mem_kv'][1] for o in output_per_layers]]
357
+ next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
358
+ text_len, frame_len, mem_kv01[0][0].shape[1])
359
+ for id, mem_kv in enumerate(mem_kv01):
360
+ for layer, mem_kv_perlayer in enumerate(mem_kv):
361
+ if limited_spatial_channel_mem and id == 0:
362
+ mems_buffers[id][
363
+ layer, batch_idx:batch_idx + group_size, :
364
+ text_len] = mem_kv_perlayer.expand(
365
+ min(group_size,
366
+ input_tokens.shape[0] - batch_idx), -1,
367
+ -1)[:, :text_len]
368
+ mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\
369
+ mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:]
370
+ else:
371
+ mems_buffers[id][
372
+ layer, batch_idx:batch_idx +
373
+ group_size, :mem_kv_perlayer.
374
+ shape[1]] = mem_kv_perlayer.expand(
375
+ min(group_size,
376
+ input_tokens.shape[0] - batch_idx), -1,
377
+ -1)
378
+ mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[
379
+ 1], mem_kv01[1][0].shape[1]
380
+ if limited_spatial_channel_mem:
381
+ mems_indexs[0] -= (next_tokens_frame_begin_id - text_len)
382
+
383
+ mems = [
384
+ mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)
385
+ ]
386
+ logits = logits_all
387
+
388
+ # Guider
389
+ if guider_seq is not None:
390
+ guider_logits_all = None
391
+ for batch_idx in range(0, guider_input_tokens.shape[0],
392
+ group_size):
393
+ guider_logits, *guider_output_per_layers = model(
394
+ guider_input_tokens[batch_idx:batch_idx + group_size,
395
+ max(index -
396
+ guider_index_delta, 0):],
397
+ guider_position_ids[
398
+ ...,
399
+ max(index - guider_index_delta, 0):counter + 1 -
400
+ guider_index_delta],
401
+ guider_attention_mask,
402
+ mems=guider_mems,
403
+ text_len=guider_text_len,
404
+ frame_len=frame_len,
405
+ counter=counter - guider_index_delta,
406
+ log_text_attention_weights=log_text_attention_weights,
407
+ enforce_no_swin=enforce_no_swin,
408
+ **kw_args)
409
+ guider_logits_all = torch.cat(
410
+ (guider_logits_all, guider_logits), dim=0
411
+ ) if guider_logits_all is not None else guider_logits
412
+ guider_mem_kv01 = [[
413
+ o['mem_kv'][0] for o in guider_output_per_layers
414
+ ], [o['mem_kv'][1] for o in guider_output_per_layers]]
415
+ for id, guider_mem_kv in enumerate(guider_mem_kv01):
416
+ for layer, guider_mem_kv_perlayer in enumerate(
417
+ guider_mem_kv):
418
+ if limited_spatial_channel_mem and id == 0:
419
+ guider_mems_buffers[id][
420
+ layer, batch_idx:batch_idx + group_size, :
421
+ guider_text_len] = guider_mem_kv_perlayer.expand(
422
+ min(group_size,
423
+ input_tokens.shape[0] - batch_idx),
424
+ -1, -1)[:, :guider_text_len]
425
+ guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
426
+ guider_text_len, frame_len,
427
+ guider_mem_kv_perlayer.shape[1])
428
+ guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\
429
+ guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:]
430
+ else:
431
+ guider_mems_buffers[id][
432
+ layer, batch_idx:batch_idx +
433
+ group_size, :guider_mem_kv_perlayer.
434
+ shape[1]] = guider_mem_kv_perlayer.expand(
435
+ min(group_size,
436
+ input_tokens.shape[0] - batch_idx),
437
+ -1, -1)
438
+ guider_mems_indexs[0], guider_mems_indexs[
439
+ 1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[
440
+ 1][0].shape[1]
441
+ if limited_spatial_channel_mem:
442
+ guider_mems_indexs[0] -= (
443
+ guider_next_tokens_frame_begin_id -
444
+ guider_text_len)
445
+ guider_mems = [
446
+ guider_mems_buffers[id][:, :, :guider_mems_indexs[id]]
447
+ for id in range(2)
448
+ ]
449
+ guider_logits = guider_logits_all
450
+ else:
451
+ if not mems_buffers_on_GPU:
452
+ if not mode_stage1:
453
+ torch.cuda.empty_cache()
454
+ for idx, mem in enumerate(mems):
455
+ mems[idx] = mem.to(next(model.parameters()).device)
456
+ if guider_seq is not None:
457
+ for idx, mem in enumerate(guider_mems):
458
+ guider_mems[idx] = mem.to(
459
+ next(model.parameters()).device)
460
+ else:
461
+ torch.cuda.empty_cache()
462
+ for idx, mem_buffer in enumerate(mems_buffers):
463
+ mems_buffers[idx] = mem_buffer.to(
464
+ next(model.parameters()).device)
465
+ mems = [
466
+ mems_buffers[id][:, :, :mems_indexs[id]]
467
+ for id in range(2)
468
+ ]
469
+ if guider_seq is not None:
470
+ for idx, guider_mem_buffer in enumerate(
471
+ guider_mems_buffers):
472
+ guider_mems_buffers[idx] = guider_mem_buffer.to(
473
+ next(model.parameters()).device)
474
+ guider_mems = [
475
+ guider_mems_buffers[id]
476
+ [:, :, :guider_mems_indexs[id]] for id in range(2)
477
+ ]
478
+ mems_buffers_on_GPU = True
479
+
480
+ logits, *output_per_layers = model(
481
+ input_tokens[:, index:],
482
+ position_ids[..., index:counter + 1],
483
+ attention_mask, # TODO memlen
484
+ mems=mems,
485
+ text_len=text_len,
486
+ frame_len=frame_len,
487
+ counter=counter,
488
+ log_text_attention_weights=log_text_attention_weights,
489
+ enforce_no_swin=enforce_no_swin,
490
+ limited_spatial_channel_mem=limited_spatial_channel_mem,
491
+ **kw_args)
492
+ mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers
493
+ ], [o['mem_kv'][1] for o in output_per_layers]
494
+
495
+ if guider_seq is not None:
496
+ guider_logits, *guider_output_per_layers = model(
497
+ guider_input_tokens[:,
498
+ max(index - guider_index_delta, 0):],
499
+ guider_position_ids[...,
500
+ max(index -
501
+ guider_index_delta, 0):counter +
502
+ 1 - guider_index_delta],
503
+ guider_attention_mask,
504
+ mems=guider_mems,
505
+ text_len=guider_text_len,
506
+ frame_len=frame_len,
507
+ counter=counter - guider_index_delta,
508
+ log_text_attention_weights=0,
509
+ enforce_no_swin=enforce_no_swin,
510
+ limited_spatial_channel_mem=limited_spatial_channel_mem,
511
+ **kw_args)
512
+ guider_mem_kv0, guider_mem_kv1 = [
513
+ o['mem_kv'][0] for o in guider_output_per_layers
514
+ ], [o['mem_kv'][1] for o in guider_output_per_layers]
515
+
516
+ if not mems_buffers_on_GPU:
517
+ torch.cuda.empty_cache()
518
+ for idx, mem_buffer in enumerate(mems_buffers):
519
+ mems_buffers[idx] = mem_buffer.to(
520
+ next(model.parameters()).device)
521
+ if guider_seq is not None:
522
+ for idx, guider_mem_buffer in enumerate(
523
+ guider_mems_buffers):
524
+ guider_mems_buffers[idx] = guider_mem_buffer.to(
525
+ next(model.parameters()).device)
526
+ mems_buffers_on_GPU = True
527
+
528
+ mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1],
529
+ mems_buffers, mems_indexs,
530
+ limited_spatial_channel_mem,
531
+ text_len, frame_len)
532
+ if guider_seq is not None:
533
+ guider_mems, guider_mems_indexs = my_update_mems(
534
+ [guider_mem_kv0, guider_mem_kv1], guider_mems_buffers,
535
+ guider_mems_indexs, limited_spatial_channel_mem,
536
+ guider_text_len, frame_len)
537
+
538
+ counter += 1
539
+ index = counter
540
+
541
+ logits = logits[:, -1].expand(batch_size,
542
+ -1) # [batch size, vocab size]
543
+ tokens = tokens.expand(batch_size, -1)
544
+ if guider_seq is not None:
545
+ guider_logits = guider_logits[:, -1].expand(batch_size, -1)
546
+ guider_tokens = guider_tokens.expand(batch_size, -1)
547
+
548
+ if seq[-1][counter].item() < 0:
549
+ # sampling
550
+ guided_logits = guider_logits + (
551
+ logits - guider_logits
552
+ ) * guidance_alpha if guider_seq is not None else logits
553
+ if mode_stage1 and counter < text_len + 400:
554
+ tokens, mems = strategy.forward(guided_logits, tokens, mems)
555
+ else:
556
+ tokens, mems = strategy2.forward(guided_logits, tokens, mems)
557
+ if guider_seq is not None:
558
+ guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]),
559
+ dim=1)
560
+
561
+ if seq[0][counter].item() >= 0:
562
+ for si in range(seq.shape[0]):
563
+ if seq[si][counter].item() >= 0:
564
+ tokens[si, -1] = seq[si, counter]
565
+ if guider_seq is not None:
566
+ guider_tokens[si,
567
+ -1] = guider_seq[si, counter -
568
+ guider_index_delta]
569
+
570
+ else:
571
+ tokens = torch.cat(
572
+ (tokens, seq[:, counter:counter + 1].clone().expand(
573
+ tokens.shape[0], 1).to(device=tokens.device,
574
+ dtype=tokens.dtype)),
575
+ dim=1)
576
+ if guider_seq is not None:
577
+ guider_tokens = torch.cat(
578
+ (guider_tokens,
579
+ guider_seq[:, counter - guider_index_delta:counter + 1 -
580
+ guider_index_delta].clone().expand(
581
+ guider_tokens.shape[0], 1).to(
582
+ device=guider_tokens.device,
583
+ dtype=guider_tokens.dtype)),
584
+ dim=1)
585
+
586
+ input_tokens = tokens.clone()
587
+ if guider_seq is not None:
588
+ guider_input_tokens = guider_tokens.clone()
589
+ if (index - text_len - 1) // 400 < (input_tokens.shape[-1] - text_len -
590
+ 1) // 400:
591
+ boi_idx = ((index - text_len - 1) // 400 + 1) * 400 + text_len
592
+ while boi_idx < input_tokens.shape[-1]:
593
+ input_tokens[:, boi_idx] = tokenizer['<start_of_image>']
594
+ if guider_seq is not None:
595
+ guider_input_tokens[:, boi_idx -
596
+ guider_index_delta] = tokenizer[
597
+ '<start_of_image>']
598
+ boi_idx += 400
599
+
600
+ if strategy.is_done:
601
+ break
602
+ return strategy.finalize(tokens, mems)
603
+
604
+
605
+ class InferenceModel_Sequential(CogVideoCacheModel):
606
+ def __init__(self, args, transformer=None, parallel_output=True):
607
+ super().__init__(args,
608
+ transformer=transformer,
609
+ parallel_output=parallel_output,
610
+ window_size=-1,
611
+ cogvideo_stage=1)
612
+
613
+ # TODO: check it
614
+
615
+ def final_forward(self, logits, **kwargs):
616
+ logits_parallel = logits
617
+ logits_parallel = torch.nn.functional.linear(
618
+ logits_parallel.float(),
619
+ self.transformer.word_embeddings.weight[:20000].float())
620
+ return logits_parallel
621
+
622
+
623
+ class InferenceModel_Interpolate(CogVideoCacheModel):
624
+ def __init__(self, args, transformer=None, parallel_output=True):
625
+ super().__init__(args,
626
+ transformer=transformer,
627
+ parallel_output=parallel_output,
628
+ window_size=10,
629
+ cogvideo_stage=2)
630
+
631
+ # TODO: check it
632
+
633
+ def final_forward(self, logits, **kwargs):
634
+ logits_parallel = logits
635
+ logits_parallel = torch.nn.functional.linear(
636
+ logits_parallel.float(),
637
+ self.transformer.word_embeddings.weight[:20000].float())
638
+ return logits_parallel
639
+
640
+
641
+ def get_default_args() -> argparse.Namespace:
642
+ known = argparse.Namespace(generate_frame_num=5,
643
+ coglm_temperature2=0.89,
644
+ use_guidance_stage1=True,
645
+ use_guidance_stage2=False,
646
+ guidance_alpha=3.0,
647
+ stage_1=True,
648
+ stage_2=False,
649
+ both_stages=False,
650
+ parallel_size=1,
651
+ stage1_max_inference_batch_size=-1,
652
+ multi_gpu=False,
653
+ layout='64, 464, 2064',
654
+ window_size=10,
655
+ additional_seqlen=2000,
656
+ cogvideo_stage=1)
657
+
658
+ args_list = [
659
+ '--tokenizer-type',
660
+ 'fake',
661
+ '--mode',
662
+ 'inference',
663
+ '--distributed-backend',
664
+ 'nccl',
665
+ '--fp16',
666
+ '--model-parallel-size',
667
+ '1',
668
+ '--temperature',
669
+ '1.05',
670
+ '--top_k',
671
+ '12',
672
+ '--sandwich-ln',
673
+ '--seed',
674
+ '1234',
675
+ '--num-workers',
676
+ '0',
677
+ '--batch-size',
678
+ '1',
679
+ '--max-inference-batch-size',
680
+ '8',
681
+ ]
682
+ args = get_args(args_list)
683
+ args = argparse.Namespace(**vars(args), **vars(known))
684
+ args.layout = [int(x) for x in args.layout.split(',')]
685
+ args.do_train = False
686
+ return args
687
+
688
+
689
+ class Model:
690
+ def __init__(self, only_first_stage: bool = False):
691
+ self.args = get_default_args()
692
+ if only_first_stage:
693
+ self.args.stage_1 = True
694
+ self.args.both_stages = False
695
+ else:
696
+ self.args.stage_1 = False
697
+ self.args.both_stages = True
698
+
699
+ self.tokenizer = self.load_tokenizer()
700
+
701
+ self.model_stage1, self.args = self.load_model_stage1()
702
+ self.model_stage2, self.args = self.load_model_stage2()
703
+
704
+ self.strategy_cogview2, self.strategy_cogvideo = self.load_strategies()
705
+ self.dsr = self.load_dsr()
706
+
707
+ self.device = torch.device(self.args.device)
708
+
709
+ def load_tokenizer(self) -> IceTokenizer:
710
+ logger.info('--- load_tokenizer ---')
711
+ start = time.perf_counter()
712
+
713
+ tokenizer = IceTokenizer(ICETK_MODEL_DIR.as_posix())
714
+ tokenizer.add_special_tokens(
715
+ ['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
716
+
717
+ elapsed = time.perf_counter() - start
718
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
719
+ return tokenizer
720
+
721
+ def load_model_stage1(
722
+ self) -> tuple[CogVideoCacheModel, argparse.Namespace]:
723
+ logger.info('--- load_model_stage1 ---')
724
+ start = time.perf_counter()
725
+
726
+ args = self.args
727
+ model_stage1, args = InferenceModel_Sequential.from_pretrained(
728
+ args, 'cogvideo-stage1')
729
+ model_stage1.eval()
730
+ if args.both_stages:
731
+ model_stage1 = model_stage1.cpu()
732
+
733
+ elapsed = time.perf_counter() - start
734
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
735
+ return model_stage1, args
736
+
737
+ def load_model_stage2(
738
+ self) -> tuple[CogVideoCacheModel | None, argparse.Namespace]:
739
+ logger.info('--- load_model_stage2 ---')
740
+ start = time.perf_counter()
741
+
742
+ args = self.args
743
+ if args.both_stages:
744
+ model_stage2, args = InferenceModel_Interpolate.from_pretrained(
745
+ args, 'cogvideo-stage2')
746
+ model_stage2.eval()
747
+ if args.both_stages:
748
+ model_stage2 = model_stage2.cpu()
749
+ else:
750
+ model_stage2 = None
751
+
752
+ elapsed = time.perf_counter() - start
753
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
754
+ return model_stage2, args
755
+
756
+ def load_strategies(self) -> tuple[CoglmStrategy, CoglmStrategy]:
757
+ logger.info('--- load_strategies ---')
758
+ start = time.perf_counter()
759
+
760
+ invalid_slices = [slice(self.tokenizer.num_image_tokens, None)]
761
+ strategy_cogview2 = CoglmStrategy(invalid_slices,
762
+ temperature=1.0,
763
+ top_k=16)
764
+ strategy_cogvideo = CoglmStrategy(
765
+ invalid_slices,
766
+ temperature=self.args.temperature,
767
+ top_k=self.args.top_k,
768
+ temperature2=self.args.coglm_temperature2)
769
+
770
+ elapsed = time.perf_counter() - start
771
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
772
+ return strategy_cogview2, strategy_cogvideo
773
+
774
+ def load_dsr(self) -> DirectSuperResolution | None:
775
+ logger.info('--- load_dsr ---')
776
+ start = time.perf_counter()
777
+
778
+ if self.args.both_stages:
779
+ path = auto_create('cogview2-dsr', path=None)
780
+ dsr = DirectSuperResolution(self.args,
781
+ path,
782
+ max_bz=12,
783
+ onCUDA=False)
784
+ else:
785
+ dsr = None
786
+
787
+ elapsed = time.perf_counter() - start
788
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
789
+ return dsr
790
+
791
+ @torch.inference_mode()
792
+ def process_stage1(self,
793
+ model,
794
+ seq_text,
795
+ duration,
796
+ video_raw_text=None,
797
+ video_guidance_text='视频',
798
+ image_text_suffix='',
799
+ batch_size=1,
800
+ image_prompt=None):
801
+ process_start_time = time.perf_counter()
802
+
803
+ generate_frame_num = self.args.generate_frame_num
804
+ tokenizer = self.tokenizer
805
+ use_guide = self.args.use_guidance_stage1
806
+
807
+ if next(model.parameters()).device != self.device:
808
+ move_start_time = time.perf_counter()
809
+ logger.debug('moving stage 1 model to cuda')
810
+
811
+ model = model.to(self.device)
812
+
813
+ elapsed = time.perf_counter() - move_start_time
814
+ logger.debug(f'moving in model1 takes time: {elapsed:.2f}')
815
+
816
+ if video_raw_text is None:
817
+ video_raw_text = seq_text
818
+ mbz = self.args.stage1_max_inference_batch_size if self.args.stage1_max_inference_batch_size > 0 else self.args.max_inference_batch_size
819
+ assert batch_size < mbz or batch_size % mbz == 0
820
+ frame_len = 400
821
+
822
+ # generate the first frame:
823
+ enc_text = tokenizer.encode(seq_text + image_text_suffix)
824
+ seq_1st = enc_text + [tokenizer['<start_of_image>']] + [-1] * 400
825
+ logger.info(
826
+ f'[Generating First Frame with CogView2] Raw text: {tokenizer.decode(enc_text):s}'
827
+ )
828
+ text_len_1st = len(seq_1st) - frame_len * 1 - 1
829
+
830
+ seq_1st = torch.tensor(seq_1st, dtype=torch.long,
831
+ device=self.device).unsqueeze(0)
832
+ if image_prompt is None:
833
+ output_list_1st = []
834
+ for tim in range(max(batch_size // mbz, 1)):
835
+ start_time = time.perf_counter()
836
+ output_list_1st.append(
837
+ my_filling_sequence(
838
+ model,
839
+ tokenizer,
840
+ self.args,
841
+ seq_1st.clone(),
842
+ batch_size=min(batch_size, mbz),
843
+ get_masks_and_position_ids=
844
+ get_masks_and_position_ids_stage1,
845
+ text_len=text_len_1st,
846
+ frame_len=frame_len,
847
+ strategy=self.strategy_cogview2,
848
+ strategy2=self.strategy_cogvideo,
849
+ log_text_attention_weights=1.4,
850
+ enforce_no_swin=True,
851
+ mode_stage1=True,
852
+ )[0])
853
+ elapsed = time.perf_counter() - start_time
854
+ logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
855
+ output_tokens_1st = torch.cat(output_list_1st, dim=0)
856
+ given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
857
+ 401].unsqueeze(
858
+ 1
859
+ ) # given_tokens.shape: [bs, frame_num, 400]
860
+ else:
861
+ given_tokens = tokenizer.encode(image_path=image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)
862
+
863
+ # generate subsequent frames:
864
+ total_frames = generate_frame_num
865
+ enc_duration = tokenizer.encode(f'{float(duration)}秒')
866
+ if use_guide:
867
+ video_raw_text = video_raw_text + ' 视频'
868
+ enc_text_video = tokenizer.encode(video_raw_text)
869
+ seq = enc_duration + [tokenizer['<n>']] + enc_text_video + [
870
+ tokenizer['<start_of_image>']
871
+ ] + [-1] * 400 * generate_frame_num
872
+ guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(
873
+ video_guidance_text) + [tokenizer['<start_of_image>']
874
+ ] + [-1] * 400 * generate_frame_num
875
+ logger.info(
876
+ f'[Stage1: Generating Subsequent Frames, Frame Rate {4/duration:.1f}] raw text: {tokenizer.decode(enc_text_video):s}'
877
+ )
878
+
879
+ text_len = len(seq) - frame_len * generate_frame_num - 1
880
+ guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
881
+ seq = torch.tensor(seq, dtype=torch.long,
882
+ device=self.device).unsqueeze(0).repeat(
883
+ batch_size, 1)
884
+ guider_seq = torch.tensor(guider_seq,
885
+ dtype=torch.long,
886
+ device=self.device).unsqueeze(0).repeat(
887
+ batch_size, 1)
888
+
889
+ for given_frame_id in range(given_tokens.shape[1]):
890
+ seq[:, text_len + 1 + given_frame_id * 400:text_len + 1 +
891
+ (given_frame_id + 1) * 400] = given_tokens[:, given_frame_id]
892
+ guider_seq[:, guider_text_len + 1 +
893
+ given_frame_id * 400:guider_text_len + 1 +
894
+ (given_frame_id + 1) *
895
+ 400] = given_tokens[:, given_frame_id]
896
+ output_list = []
897
+
898
+ if use_guide:
899
+ video_log_text_attention_weights = 0
900
+ else:
901
+ guider_seq = None
902
+ video_log_text_attention_weights = 1.4
903
+
904
+ for tim in range(max(batch_size // mbz, 1)):
905
+ input_seq = seq[:min(batch_size, mbz)].clone(
906
+ ) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
907
+ guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone()
908
+ if tim == 0 else guider_seq[mbz * tim:mbz *
909
+ (tim + 1)].clone()
910
+ ) if guider_seq is not None else None
911
+ output_list.append(
912
+ my_filling_sequence(
913
+ model,
914
+ tokenizer,
915
+ self.args,
916
+ input_seq,
917
+ batch_size=min(batch_size, mbz),
918
+ get_masks_and_position_ids=
919
+ get_masks_and_position_ids_stage1,
920
+ text_len=text_len,
921
+ frame_len=frame_len,
922
+ strategy=self.strategy_cogview2,
923
+ strategy2=self.strategy_cogvideo,
924
+ log_text_attention_weights=video_log_text_attention_weights,
925
+ guider_seq=guider_seq2,
926
+ guider_text_len=guider_text_len,
927
+ guidance_alpha=self.args.guidance_alpha,
928
+ limited_spatial_channel_mem=True,
929
+ mode_stage1=True,
930
+ )[0])
931
+
932
+ output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len:]
933
+
934
+ if self.args.both_stages:
935
+ move_start_time = time.perf_counter()
936
+ logger.debug('moving stage 1 model to cpu')
937
+ model = model.cpu()
938
+ torch.cuda.empty_cache()
939
+ elapsed = time.perf_counter() - move_start_time
940
+ logger.debug(f'moving in model1 takes time: {elapsed:.2f}')
941
+
942
+ # decoding
943
+ res = []
944
+ for seq in output_tokens:
945
+ decoded_imgs = [
946
+ self.postprocess(
947
+ torch.nn.functional.interpolate(tokenizer.decode(
948
+ image_ids=seq.tolist()[i * 400:(i + 1) * 400]),
949
+ size=(480, 480))[0])
950
+ for i in range(total_frames)
951
+ ]
952
+ res.append(decoded_imgs) # only the last image (target)
953
+
954
+ assert len(res) == batch_size
955
+ tokens = output_tokens[:, :+total_frames * 400].reshape(
956
+ -1, total_frames, 400).cpu()
957
+
958
+ elapsed = time.perf_counter() - process_start_time
959
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
960
+ return tokens, res[0]
961
+
962
+ @torch.inference_mode()
963
+ def process_stage2(self,
964
+ model,
965
+ seq_text,
966
+ duration,
967
+ parent_given_tokens,
968
+ video_raw_text=None,
969
+ video_guidance_text='视频',
970
+ gpu_rank=0,
971
+ gpu_parallel_size=1):
972
+ process_start_time = time.perf_counter()
973
+
974
+ generate_frame_num = self.args.generate_frame_num
975
+ tokenizer = self.tokenizer
976
+ use_guidance = self.args.use_guidance_stage2
977
+
978
+ stage2_start_time = time.perf_counter()
979
+
980
+ if next(model.parameters()).device != self.device:
981
+ move_start_time = time.perf_counter()
982
+ logger.debug('moving stage-2 model to cuda')
983
+
984
+ model = model.to(self.device)
985
+
986
+ elapsed = time.perf_counter() - move_start_time
987
+ logger.debug(f'moving in stage-2 model takes time: {elapsed:.2f}')
988
+
989
+ try:
990
+ sample_num_allgpu = parent_given_tokens.shape[0]
991
+ sample_num = sample_num_allgpu // gpu_parallel_size
992
+ assert sample_num * gpu_parallel_size == sample_num_allgpu
993
+ parent_given_tokens = parent_given_tokens[gpu_rank *
994
+ sample_num:(gpu_rank +
995
+ 1) *
996
+ sample_num]
997
+ except:
998
+ logger.critical('No frame_tokens found in interpolation, skip')
999
+ return False, []
1000
+
1001
+ # CogVideo Stage2 Generation
1002
+ while duration >= 0.5: # TODO: You can change the boundary to change the frame rate
1003
+ parent_given_tokens_num = parent_given_tokens.shape[1]
1004
+ generate_batchsize_persample = (parent_given_tokens_num - 1) // 2
1005
+ generate_batchsize_total = generate_batchsize_persample * sample_num
1006
+ total_frames = generate_frame_num
1007
+ frame_len = 400
1008
+ enc_text = tokenizer.encode(seq_text)
1009
+ enc_duration = tokenizer.encode(str(float(duration)) + '秒')
1010
+ seq = enc_duration + [tokenizer['<n>']] + enc_text + [
1011
+ tokenizer['<start_of_image>']
1012
+ ] + [-1] * 400 * generate_frame_num
1013
+ text_len = len(seq) - frame_len * generate_frame_num - 1
1014
+
1015
+ logger.info(
1016
+ f'[Stage2: Generating Frames, Frame Rate {int(4/duration):d}] raw text: {tokenizer.decode(enc_text):s}'
1017
+ )
1018
+
1019
+ # generation
1020
+ seq = torch.tensor(seq, dtype=torch.long,
1021
+ device=self.device).unsqueeze(0).repeat(
1022
+ generate_batchsize_total, 1)
1023
+ for sample_i in range(sample_num):
1024
+ for i in range(generate_batchsize_persample):
1025
+ seq[sample_i * generate_batchsize_persample +
1026
+ i][text_len + 1:text_len + 1 +
1027
+ 400] = parent_given_tokens[sample_i][2 * i]
1028
+ seq[sample_i * generate_batchsize_persample +
1029
+ i][text_len + 1 + 400:text_len + 1 +
1030
+ 800] = parent_given_tokens[sample_i][2 * i + 1]
1031
+ seq[sample_i * generate_batchsize_persample +
1032
+ i][text_len + 1 + 800:text_len + 1 +
1033
+ 1200] = parent_given_tokens[sample_i][2 * i + 2]
1034
+
1035
+ if use_guidance:
1036
+ guider_seq = enc_duration + [
1037
+ tokenizer['<n>']
1038
+ ] + tokenizer.encode(video_guidance_text) + [
1039
+ tokenizer['<start_of_image>']
1040
+ ] + [-1] * 400 * generate_frame_num
1041
+ guider_text_len = len(
1042
+ guider_seq) - frame_len * generate_frame_num - 1
1043
+ guider_seq = torch.tensor(
1044
+ guider_seq, dtype=torch.long,
1045
+ device=self.device).unsqueeze(0).repeat(
1046
+ generate_batchsize_total, 1)
1047
+ for sample_i in range(sample_num):
1048
+ for i in range(generate_batchsize_persample):
1049
+ guider_seq[sample_i * generate_batchsize_persample +
1050
+ i][text_len + 1:text_len + 1 +
1051
+ 400] = parent_given_tokens[sample_i][2 *
1052
+ i]
1053
+ guider_seq[sample_i * generate_batchsize_persample +
1054
+ i][text_len + 1 + 400:text_len + 1 +
1055
+ 800] = parent_given_tokens[sample_i][2 *
1056
+ i +
1057
+ 1]
1058
+ guider_seq[sample_i * generate_batchsize_persample +
1059
+ i][text_len + 1 + 800:text_len + 1 +
1060
+ 1200] = parent_given_tokens[sample_i][2 *
1061
+ i +
1062
+ 2]
1063
+ video_log_text_attention_weights = 0
1064
+ else:
1065
+ guider_seq = None
1066
+ guider_text_len = 0
1067
+ video_log_text_attention_weights = 1.4
1068
+
1069
+ mbz = self.args.max_inference_batch_size
1070
+
1071
+ assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
1072
+ output_list = []
1073
+ start_time = time.perf_counter()
1074
+ for tim in range(max(generate_batchsize_total // mbz, 1)):
1075
+ input_seq = seq[:min(generate_batchsize_total, mbz)].clone(
1076
+ ) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
1077
+ guider_seq2 = (
1078
+ guider_seq[:min(generate_batchsize_total, mbz)].clone()
1079
+ if tim == 0 else guider_seq[mbz * tim:mbz *
1080
+ (tim + 1)].clone()
1081
+ ) if guider_seq is not None else None
1082
+ output_list.append(
1083
+ my_filling_sequence(
1084
+ model,
1085
+ tokenizer,
1086
+ self.args,
1087
+ input_seq,
1088
+ batch_size=min(generate_batchsize_total, mbz),
1089
+ get_masks_and_position_ids=
1090
+ get_masks_and_position_ids_stage2,
1091
+ text_len=text_len,
1092
+ frame_len=frame_len,
1093
+ strategy=self.strategy_cogview2,
1094
+ strategy2=self.strategy_cogvideo,
1095
+ log_text_attention_weights=
1096
+ video_log_text_attention_weights,
1097
+ mode_stage1=False,
1098
+ guider_seq=guider_seq2,
1099
+ guider_text_len=guider_text_len,
1100
+ guidance_alpha=self.args.guidance_alpha,
1101
+ limited_spatial_channel_mem=True,
1102
+ )[0])
1103
+ elapsed = time.perf_counter() - start_time
1104
+ logger.info(f'Duration {duration:.2f}, Elapsed: {elapsed:.2f}\n')
1105
+
1106
+ output_tokens = torch.cat(output_list, dim=0)
1107
+ output_tokens = output_tokens[:, text_len + 1:text_len + 1 +
1108
+ (total_frames) * 400].reshape(
1109
+ sample_num, -1,
1110
+ 400 * total_frames)
1111
+ output_tokens_merge = torch.cat(
1112
+ (output_tokens[:, :, :1 * 400], output_tokens[:, :,
1113
+ 400 * 3:4 * 400],
1114
+ output_tokens[:, :, 400 * 1:2 * 400],
1115
+ output_tokens[:, :, 400 * 4:(total_frames) * 400]),
1116
+ dim=2).reshape(sample_num, -1, 400)
1117
+
1118
+ output_tokens_merge = torch.cat(
1119
+ (output_tokens_merge, output_tokens[:, -1:, 400 * 2:3 * 400]),
1120
+ dim=1)
1121
+ duration /= 2
1122
+ parent_given_tokens = output_tokens_merge
1123
+
1124
+ if self.args.both_stages:
1125
+ move_start_time = time.perf_counter()
1126
+ logger.debug('moving stage 2 model to cpu')
1127
+ model = model.cpu()
1128
+ torch.cuda.empty_cache()
1129
+ elapsed = time.perf_counter() - move_start_time
1130
+ logger.debug(f'moving out model2 takes time: {elapsed:.2f}')
1131
+
1132
+ elapsed = time.perf_counter() - stage2_start_time
1133
+ logger.info(f'CogVideo Stage2 completed. Elapsed: {elapsed:.2f}\n')
1134
+
1135
+ # direct super-resolution by CogView2
1136
+ logger.info('[Direct super-resolution]')
1137
+ dsr_start_time = time.perf_counter()
1138
+
1139
+ enc_text = tokenizer.encode(seq_text)
1140
+ frame_num_per_sample = parent_given_tokens.shape[1]
1141
+ parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
1142
+ text_seq = torch.tensor(enc_text, dtype=torch.long,
1143
+ device=self.device).unsqueeze(0).repeat(
1144
+ parent_given_tokens_2d.shape[0], 1)
1145
+ sred_tokens = self.dsr(text_seq, parent_given_tokens_2d)
1146
+
1147
+ decoded_sr_videos = []
1148
+ for sample_i in range(sample_num):
1149
+ decoded_sr_imgs = []
1150
+ for frame_i in range(frame_num_per_sample):
1151
+ decoded_sr_img = tokenizer.decode(
1152
+ image_ids=sred_tokens[frame_i + sample_i *
1153
+ frame_num_per_sample][-3600:])
1154
+ decoded_sr_imgs.append(
1155
+ self.postprocess(
1156
+ torch.nn.functional.interpolate(decoded_sr_img,
1157
+ size=(480, 480))[0]))
1158
+ decoded_sr_videos.append(decoded_sr_imgs)
1159
+
1160
+ elapsed = time.perf_counter() - dsr_start_time
1161
+ logger.info(
1162
+ f'Direct super-resolution completed. Elapsed: {elapsed:.2f}')
1163
+
1164
+ elapsed = time.perf_counter() - process_start_time
1165
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
1166
+ return True, decoded_sr_videos[0]
1167
+
1168
+ @staticmethod
1169
+ def postprocess(tensor: torch.Tensor) -> np.ndarray:
1170
+ return tensor.cpu().mul(255).add_(0.5).clamp_(0, 255).permute(
1171
+ 1, 2, 0).to(torch.uint8).numpy()
1172
+
1173
+ def run(self, text: str, seed: int,
1174
+ only_first_stage: bool,image_prompt: None) -> list[np.ndarray]:
1175
+ logger.info('==================== run ====================')
1176
+ start = time.perf_counter()
1177
+
1178
+ set_random_seed(seed)
1179
+ self.args.seed = seed
1180
+
1181
+ if only_first_stage:
1182
+ self.args.stage_1 = True
1183
+ self.args.both_stages = False
1184
+ else:
1185
+ self.args.stage_1 = False
1186
+ self.args.both_stages = True
1187
+
1188
+ parent_given_tokens, res = self.process_stage1(
1189
+ self.model_stage1,
1190
+ text,
1191
+ duration=4.0,
1192
+ video_raw_text=text,
1193
+ video_guidance_text='视频',
1194
+ image_text_suffix=' 高清摄影',
1195
+ batch_size=self.args.batch_size,
1196
+ image_prompt=image_prompt)
1197
+ if not only_first_stage:
1198
+ _, res = self.process_stage2(
1199
+ self.model_stage2,
1200
+ text,
1201
+ duration=2.0,
1202
+ parent_given_tokens=parent_given_tokens,
1203
+ video_raw_text=text + ' 视频',
1204
+ video_guidance_text='视频',
1205
+ gpu_rank=0,
1206
+ gpu_parallel_size=1) # TODO: 修改
1207
+
1208
+ elapsed = time.perf_counter() - start
1209
+ logger.info(f'Elapsed: {elapsed:.3f}')
1210
+ logger.info('==================== done ====================')
1211
+ return res
1212
+
1213
+
1214
+ class AppModel(Model):
1215
+ def __init__(self, only_first_stage: bool):
1216
+ super().__init__(only_first_stage)
1217
+ self.translator = gr.Interface.load(
1218
+ 'spaces/chinhon/translation_eng2ch')
1219
+
1220
+ def to_video(self, frames: list[np.ndarray]) -> str:
1221
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
1222
+ if self.args.stage_1:
1223
+ fps = 4
1224
+ else:
1225
+ fps = 8
1226
+ writer = iio.get_writer(out_file.name, fps=fps)
1227
+ for frame in frames:
1228
+ writer.append_data(frame)
1229
+ writer.close()
1230
+ return out_file.name
1231
+
1232
+ def run_with_translation(
1233
+ self, text: str, translate: bool, seed: int,
1234
+ only_first_stage: bool,image_prompt: None) -> tuple[str | None, str | None]:
1235
+
1236
+ logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=},{image_prompt=}')
1237
+ if translate:
1238
+ text = translated_text = self.translator(text)
1239
+ else:
1240
+ translated_text = None
1241
+ frames = self.run(text, seed, only_first_stage,image_prompt)
1242
+ video_path = self.to_video(frames)
1243
+ return translated_text, video_path
models/cogvideo_cache_model.py DELETED
@@ -1,695 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cogvideo_cache_model.py
4
- @Time : 2022/07/15 11:22:19
5
- @Author : Wenyi Hong
6
- @Version : 1.0
7
- @Contact : hwy22@mails.tsinghua.edu.cn
8
- '''
9
-
10
- # here put the import lib
11
-
12
- from multiprocessing import context
13
- from tkinter import E
14
- import torch
15
- from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
16
-
17
- from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
18
- from SwissArmyTransformer.model.transformer import unscaled_init_method
19
- from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
20
- import torch.nn.functional as F
21
- from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
22
- import math
23
-
24
-
25
- class PositionEmbeddingMixin(BaseMixin):
26
- def __init__(self, additional_sequence_length, hidden_size,
27
- init_method_std=0.02, reinit_slice=slice(512, 912),
28
- ):
29
- super(PositionEmbeddingMixin, self).__init__()
30
- self.reinit_slice = reinit_slice
31
- self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
32
- torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
33
-
34
- def reinit(self, parent_model=None):
35
- old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
36
- old_len, hidden_size = old_weights.shape
37
- assert hidden_size == self.position_embeddings.weight.shape[-1]
38
- self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
39
-
40
-
41
- def window_partition(x, window_size):
42
- """
43
- Args:
44
- x: (B, framenum, H, W, C)
45
- window_size (int): window size
46
- Returns:
47
- windows: (num_windows*B, frame_num, window_size, window_size, C)
48
- """
49
- B, framenum, H, W, C = x.shape
50
- x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
51
- windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
52
- return windows
53
-
54
- def window_reverse(windows, window_size, H, W):
55
- """
56
- Args:
57
- windows: (num_windows*B, frame_num, window_size, window_size, C)
58
- window_size (int): Window size
59
- H (int): Height of image
60
- W (int): Width of image
61
- Returns:
62
- x: (B, frame_num, H, W, C)
63
- """
64
- B = int(windows.shape[0] / (H * W / window_size / window_size))
65
- framenum = windows.shape[1]
66
- x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
67
- x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
68
- return x
69
-
70
- class WindowAttentionMixin(BaseMixin):
71
- def __init__(self, num_layers,
72
- hidden_size,
73
- frame_resolution,
74
- window_size,
75
- shift_size,
76
- n_head,
77
- frame_num,
78
- init_method=unscaled_init_method(0.02),
79
- output_layer_init_method=unscaled_init_method(0.02),
80
- time_dim_attend_length=0
81
- ):
82
- super(WindowAttentionMixin, self).__init__()
83
- self.num_layers = num_layers # replace attention in the LAST n layers
84
- self.query_key_value = torch.nn.ModuleList(
85
- [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
86
- gather_output=False,init_method=init_method)
87
- for layer_id in range(num_layers)
88
- ])
89
- self.dense = torch.nn.ModuleList(
90
- [RowParallelLinear(
91
- hidden_size,
92
- hidden_size,
93
- input_is_parallel=True,
94
- init_method=output_layer_init_method,
95
- bias=True,
96
- module=self,
97
- name="dense")
98
- for layer_id in range(num_layers)
99
- ])
100
-
101
- self.n_head = n_head
102
- self.window_size = window_size
103
- self.frame_resolution = frame_resolution
104
- self.frame_len = frame_resolution * frame_resolution
105
- self.time_dim_attend_length = time_dim_attend_length
106
- assert frame_resolution % window_size == 0
107
- assert 0 < shift_size < window_size
108
- nW = (self.frame_resolution // self.window_size) ** 2
109
- ws_squre = self.window_size * self.window_size
110
-
111
- # odd non-shift, even shift
112
- img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
113
- h_slices = (slice(0, -shift_size),
114
- slice(-shift_size, None))
115
- w_slices = (slice(0, -shift_size),
116
- slice(-shift_size, None))
117
- cnt = 0
118
- for h in h_slices:
119
- for w in w_slices:
120
- img_mask[:, :, h, w, :] = cnt
121
- cnt += 1
122
- mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
123
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
124
- sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
125
- sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
126
- attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
127
- attn_mask = attn_mask.tril()
128
-
129
- causal_mask = torch.ones(ws_squre*frame_num, ws_squre*frame_num)
130
- causal_mask = causal_mask.tril()
131
-
132
- self.shift_sizes = [0, shift_size]
133
- self.attn_mask = attn_mask
134
- self.causal_mask = causal_mask
135
- self.mask_initialized = False
136
-
137
- self.attn_distribution = torch.nn.ParameterList([
138
- torch.nn.Parameter(torch.zeros(hidden_size))
139
- for _ in range(num_layers)
140
- ])
141
-
142
- def reinit(self, *pre_mixins):
143
- start_layer = len(self.transformer.layers) - self.num_layers
144
- assert start_layer >= 0
145
- for layer_id in range(self.num_layers):
146
- old_attention = self.transformer.layers[start_layer + layer_id].attention
147
- self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
148
- self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
149
-
150
- def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
151
- # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
152
- if not self.mask_initialized:
153
- self.attn_mask = self.attn_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
154
- self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
155
- self.mask_initialized = True
156
- b0, s1, h0 = frame_hidden_state.shape
157
- h = h0 // self.n_head
158
- frame_len = self.frame_resolution * self.frame_resolution
159
- frame_num = s1 // frame_len
160
- if stage == 2:
161
- assert frame_num == 3
162
- assert frame_num*frame_len == s1
163
- wind_square = self.window_size * self.window_size
164
- nW = frame_len // wind_square
165
- bswin = b0 * nW
166
-
167
- if memkv_text is not None:
168
- s0 = memkv_text.shape[-2]
169
- k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
170
- v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
171
-
172
- # shift
173
- frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
174
- if self.shift_sizes[layer_id%2] > 0:
175
- frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
176
- # window partition
177
- frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
178
- qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
179
- .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
180
- q, k, v = qkv[0], qkv[1], qkv[2]
181
- attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
182
-
183
- if stage == 1:
184
- if self.shift_sizes[layer_id%2] > 0:
185
- attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square),
186
- self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))\
187
- - 10000.0 * (1.0 - self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))
188
- attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
189
- else:
190
- attn = torch.mul(attn, self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))\
191
- - 10000.0 * (1.0 - self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))
192
-
193
- if memkv_text is None:
194
- attn = F.softmax(attn, dim=-1)
195
- if attn_dropout is not None:
196
- with get_cuda_rng_tracker().fork():
197
- attn = attn_dropout(attn)
198
- context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
199
- else:
200
- attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
201
- attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
202
- attn = torch.cat((attn, attn_frame2text), dim=-1)
203
- attn = F.softmax(attn, dim=-1)
204
-
205
- if attn_dropout is not None:
206
- with get_cuda_rng_tracker().fork():
207
- attn = attn_dropout(attn)
208
-
209
- context_swin = (torch.matmul(attn[..., :-s0], v) +
210
- torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
211
- .reshape(bswin, self.n_head, frame_num*wind_square, h))\
212
- .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
213
-
214
- context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
215
-
216
- # reverse cycle shift
217
- if self.shift_sizes[layer_id%2] > 0:
218
- context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
219
- ret_context = context_swin.reshape(b0, s1, h0)
220
-
221
- # for mem
222
- memk = k.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
223
- memv = v.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
224
- memk = window_reverse(memk, self.window_size, self.frame_resolution, self.frame_resolution)
225
- memv = window_reverse(memv, self.window_size, self.frame_resolution, self.frame_resolution)
226
- if self.shift_sizes[layer_id%2] > 0:
227
- memk = torch.roll(memk, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
228
- memv = torch.roll(memv, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
229
- memk, memv = memk.reshape(b0, s1, h0), memv.reshape(b0, s1, h0)
230
-
231
- ret_mem = torch.cat((memk, memv), dim=-1)
232
- return ret_context, ret_mem
233
-
234
- def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
235
- # frame_hidden_state [batchsize, 1, n_head*hiddensize_perhead]
236
- # memkv [batchsize, pos, hidden_size*2] (include frames only)
237
- # if memkv_text is not None: will attend to text
238
- # pos: token's pos
239
- b0, sin, h0 = frame_hidden_state.shape
240
- h = h0 // self.n_head
241
- assert sin == 1
242
- this_qkv = self.query_key_value[layer_id](frame_hidden_state)
243
- thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
244
- s1 = memkv.shape[1] if memkv is not None else 0
245
- frame_len = self.frame_resolution * self.frame_resolution
246
- frame_num_before = s1 // frame_len
247
-
248
-
249
- if memkv is not None:
250
- pos_inframe = pos - frame_num_before * frame_len
251
-
252
- xpos = pos_inframe // self.frame_resolution # pos = xpos*self.frame_resolution + ypos
253
- ypos = pos_inframe % self.frame_resolution
254
- # [start, end)
255
- if self.shift_sizes[layer_id%2] > 0:
256
- xstart = ((xpos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
257
- ystart = ((ypos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
258
- xend = xstart + self.window_size
259
- yend = ystart + self.window_size
260
- xstart, ystart = max(0, xstart), max(0, ystart)
261
- xend, yend = min(xend, self.frame_resolution), min(yend, self.frame_resolution)
262
- else:
263
- xstart = (xpos // self.window_size) * self.window_size
264
- ystart = (ypos // self.window_size) * self.window_size
265
- xend, yend = xstart + self.window_size, ystart+self.window_size
266
-
267
- # select index
268
- selected_index = list()
269
- if frame_num_before > 0:
270
- # frames before
271
- frame_attended_start = max(0, frame_num_before-self.time_dim_attend_length+1) if self.time_dim_attend_length > 0 else 0
272
- for x in range(xstart, xend):
273
- for y in range(ystart, yend):
274
- selected_index.append(x*self.frame_resolution+y+frame_len*frame_attended_start)
275
- cnt_per_frame = len(selected_index)
276
- for _ in range((frame_num_before-frame_attended_start-1)*cnt_per_frame):
277
- selected_index.append(selected_index[-cnt_per_frame]+frame_len)
278
-
279
- # the last frame
280
- for x in range(xstart, xend):
281
- for y in range(ystart, yend):
282
- tmppos = x*self.frame_resolution+y + frame_num_before * frame_len
283
- if tmppos < pos:
284
- selected_index.append(tmppos)
285
- else:
286
- break
287
- cnt_all = len(selected_index)+1
288
- selected_index = torch.tensor(selected_index, device=memkv.device)
289
- used_memkv = torch.index_select(memkv, 1, selected_index)
290
- used_k, used_v = used_memkv[..., :h0], used_memkv[..., h0:]
291
- used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
292
- used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
293
- if memkv_text is not None:
294
- cnt_all += memkv_text.shape[-2]
295
- used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
296
- used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
297
- used_k = used_k.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
298
- used_v = used_v.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
299
- else:
300
- used_k = thisk
301
- used_v = thisv
302
-
303
- if memkv_text is not None:
304
- used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
305
- used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
306
- used_k = used_k.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
307
- used_v = used_v.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
308
- else:
309
- used_k = used_k.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
310
- used_v = used_v.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
311
-
312
- thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
313
- attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
314
- if memkv_text is not None:
315
- attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
316
- attn = F.softmax(attn, dim=-1)
317
- context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
318
-
319
- return context_swin, this_qkv[..., h0:]
320
-
321
- class FullAttentionMixin(BaseMixin):
322
- def __init__(self, num_layers,
323
- hidden_size,
324
- frame_resolution,
325
- n_head,
326
- frame_num,
327
- init_method=unscaled_init_method(0.02),
328
- output_layer_init_method=unscaled_init_method(0.02),
329
- **kwargs,
330
- ):
331
- super(FullAttentionMixin, self).__init__()
332
- self.num_layers = num_layers # replace attention in the LAST n layers
333
- self.query_key_value = torch.nn.ModuleList(
334
- [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
335
- gather_output=False,init_method=init_method)
336
- for layer_id in range(num_layers)
337
- ])
338
- self.dense = torch.nn.ModuleList(
339
- [RowParallelLinear(
340
- hidden_size,
341
- hidden_size,
342
- input_is_parallel=True,
343
- init_method=output_layer_init_method,
344
- bias=True,
345
- module=self,
346
- name="dense")
347
- for layer_id in range(num_layers)
348
- ])
349
-
350
- self.n_head = n_head
351
- self.frame_resolution = frame_resolution
352
- self.frame_len = frame_resolution * frame_resolution
353
-
354
- self.attn_distribution = torch.nn.ParameterList([
355
- torch.nn.Parameter(torch.zeros(hidden_size))
356
- for _ in range(num_layers)
357
- ])
358
-
359
- def reinit(self, *pre_mixins):
360
- start_layer = len(self.transformer.layers) - self.num_layers
361
- assert start_layer >= 0
362
- for layer_id in range(self.num_layers):
363
- old_attention = self.transformer.layers[start_layer + layer_id].attention
364
- self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
365
- self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
366
-
367
-
368
- def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
369
- # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
370
- assert stage == 1
371
-
372
- b0, s1, h0 = frame_hidden_state.shape
373
- h = h0 // self.n_head
374
- frame_len = self.frame_resolution * self.frame_resolution
375
- frame_num = s1 // frame_len
376
- assert frame_num*frame_len == s1
377
-
378
- if memkv_text is not None:
379
- s0 = memkv_text.shape[-2]
380
- k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
381
- v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
382
- qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
383
- .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
384
- q, k, v = qkv[0], qkv[1], qkv[2]
385
- attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
386
- attn = attn - 10000.0 * (1.0-torch.ones(b0, self.n_head, s1, s1, device=attn.device, dtype=attn.dtype).tril())
387
-
388
- if memkv_text is None:
389
- attn = F.softmax(attn, dim=-1)
390
- if attn_dropout is not None:
391
- with get_cuda_rng_tracker().fork():
392
- attn = attn_dropout(attn)
393
- context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
394
- else:
395
- attn_frame2text = torch.matmul(q / math.sqrt(h), k_text.transpose(-1, -2)) #[b0, s1, s0]
396
- attn = torch.cat((attn, attn_frame2text), dim=-1)
397
- attn = F.softmax(attn, dim=-1)
398
- if attn_dropout is not None:
399
- with get_cuda_rng_tracker().fork():
400
- attn = attn_dropout(attn)
401
- context_swin = (torch.matmul(attn[..., :-s0], v) + torch.matmul(attn[..., -s0:], v_text))\
402
- .permute(0, 2, 1, 3).reshape(b0, s1, h0)
403
-
404
- # for mem
405
- memk = k.permute(0, 2, 1, 3).reshape(b0, s1, h0)
406
- memv = v.permute(0, 2, 1, 3).reshape(b0, s1, h0)
407
- ret_mem = torch.cat((memk, memv), dim=-1)
408
-
409
- return context_swin, ret_mem
410
-
411
- def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
412
- # pos: current token's pos
413
- b0, sin, h0 = frame_hidden_state.shape
414
- h = h0 // self.n_head
415
- assert sin == 1
416
- assert stage == 1
417
-
418
- this_qkv = self.query_key_value[layer_id](frame_hidden_state)
419
- thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
420
-
421
- if memkv is not None:
422
- used_k, used_v = memkv[..., :h0], memkv[..., h0:]
423
- used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
424
- used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
425
- else:
426
- used_k, used_v = thisk, thisv
427
-
428
- if memkv_text is not None:
429
- used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
430
- used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
431
-
432
- used_k = used_k.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
433
- used_v = used_v.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
434
- thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
435
- attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
436
- if memkv_text is not None:
437
- attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
438
- attn = F.softmax(attn, dim=-1)
439
-
440
- context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
441
-
442
- return context_swin, this_qkv[..., h0:]
443
-
444
-
445
- def attention_localframe_and_text_NAR(q0, k0, v0, attention_mask,
446
- n_head, text_len, frame_len, frame_num,
447
- attention_dropout=None, log_text_attention_weights=0, stage=1, **kwargs):
448
- b, s0, h0 = q0.shape
449
- s1 = s0 - text_len
450
- h = h0 // n_head
451
- assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
452
- # attention_mask.shape [4, b or 1, 1, text_len+frame_len, text_len+frame_len]
453
- if stage == 2:
454
- assert frame_num == 3
455
-
456
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
457
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
458
- k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
459
- k0T = k0.transpose(-1, -2)
460
-
461
- score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
462
- score_any2text += log_text_attention_weights
463
- score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask[..., :text_len, :text_len]) \
464
- - 10000.0 * (1.0 - attention_mask[..., :text_len, :text_len])
465
- # context for text
466
- attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
467
- if attention_dropout is not None:
468
- with get_cuda_rng_tracker().fork():
469
- attention_probs_text = attention_dropout(attention_probs_text)
470
- context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
471
- context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
472
-
473
- if frame_num > 0:
474
- score_any2text_part2 = score_any2text[..., text_len:, :]
475
-
476
- # score: frame local
477
- q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
478
- v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
479
- k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
480
- score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
481
- if stage == 1:
482
- score_frame_local0 = torch.mul(score_frame_local0, attention_mask[..., text_len:, text_len:].unsqueeze(1)) \
483
- - 10000.0 * (1.0 - attention_mask[..., text_len:, text_len:].unsqueeze(1))
484
-
485
- # context for frame
486
- score_frame_all = torch.cat((score_any2text_part2,
487
- score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
488
- attention_probs_frame = F.softmax(score_frame_all, dim=-1)
489
- if attention_dropout is not None:
490
- with get_cuda_rng_tracker().fork():
491
- attention_probs_frame = attention_dropout(attention_probs_frame)
492
- context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
493
- context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
494
- view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
495
-
496
- context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
497
- else:
498
- context_frame = None
499
-
500
- return context_text2text, context_frame
501
-
502
- def attention_localframe_and_text_AR(q0, k0, v0, n_head, text_len, frame_len, frame_num,
503
- attention_dropout=None, log_text_attention_weights=0, layer_id=None, limited_spatial_channel_mem=False, stage=1, **kwargs):
504
- # limited_spatial_channel_mem=True means: mems in spatial channel is consisted of {mem_text, mem_current_frame}
505
- b, s0, h0 = k0.shape
506
- frame_num_before = (s0-text_len-1) // frame_len # frame_num == frame_num_before or frame_num == frame_num_before+1
507
- h = h0 // n_head
508
- assert q0.shape[1] == 1
509
- assert v0.shape[1] == k0.shape[1]
510
-
511
- q0 = q0.reshape(b, 1, n_head, h).permute(0, 2, 1, 3)
512
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
513
- k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
514
-
515
- if limited_spatial_channel_mem:
516
- assert frame_num_before == 0
517
- assert stage == 1 # not implemented for stage-2 yet
518
- score = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
519
- score[..., :text_len] += log_text_attention_weights
520
- attention_probs_frame = F.softmax(score, dim=-1)
521
- context_frame = torch.matmul(attention_probs_frame, v0).transpose(1, 2).reshape(b, 1, h0)
522
-
523
- else:
524
- score_token2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
525
- score_token2text += log_text_attention_weights
526
- score_frame_local0 = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., text_len+frame_num_before*frame_len:])
527
- score_frame_all = torch.cat((score_token2text,
528
- score_frame_local0), dim=-1)
529
- attention_probs_frame = F.softmax(score_frame_all, dim=-1)
530
-
531
- context_token2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
532
- context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:], \
533
- v0[:, :, text_len+frame_num_before*frame_len:, :])
534
- context_frame = (context_token2text + context_frame_local0).transpose(1, 2).reshape(b, 1, h0)
535
-
536
- return context_frame
537
-
538
-
539
- class CogVideoCacheModel(BaseModel):
540
- def __init__(self, args, transformer=None, parallel_output=True, window_size=None, cogvideo_stage=None):
541
- super().__init__(args, transformer=transformer, parallel_output=parallel_output)
542
- self.layout = args.layout # [64, 64+1024, 64+6*1024]
543
- self.stage = cogvideo_stage if cogvideo_stage is not None else args.cogvideo_stage # 1 or 2
544
- self.n_head = args.num_attention_heads
545
- self.window_size = window_size if window_size is not None else args.window_size
546
-
547
- frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
548
- self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
549
- args.additional_seqlen, args.hidden_size
550
- ))
551
-
552
- if self.stage == 1:
553
- self.add_mixin('attention_plus', FullAttentionMixin(
554
- num_layers=args.num_layers,
555
- hidden_size=args.hidden_size,
556
- frame_resolution=frame_resolution,
557
- n_head=args.num_attention_heads,
558
- frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
559
- ))
560
- else:
561
- self.add_mixin('attention_plus', WindowAttentionMixin(
562
- num_layers=args.num_layers,
563
- hidden_size=args.hidden_size,
564
- frame_resolution=frame_resolution,
565
- window_size=self.window_size,
566
- shift_size=self.window_size//2,
567
- n_head=args.num_attention_heads,
568
- frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
569
- ))
570
-
571
-
572
- @classmethod
573
- def add_model_specific_args(cls, parser):
574
- group = parser.add_argument_group('VideoSwinLocalModel', 'video swin local model configurations')
575
- group.add_argument("--layout", type=str, default='64, 464, 2064')
576
- group.add_argument("--window-size", type=int, default=10) # 优先级在直接参数赋值之后
577
- group.add_argument("--additional-seqlen", type=int, default=2000)
578
- group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) # 优先级在直接参数赋值之后
579
- return parser
580
-
581
- def disable_untrainable_params(self):
582
- pass
583
-
584
- def position_embedding_forward(self, position_ids, **kw_args):
585
- if position_ids.shape[-1] > 1:
586
- if self.stage == 1:
587
- if position_ids[0,-1] >= (512+400):
588
- frame_num = position_ids.shape[-1] // 400
589
- position_embeddings = torch.cat(
590
- (
591
- self.transformer.position_embeddings(position_ids[..., :-400*(frame_num-1)]),
592
- self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -400*(frame_num-1):]-(512+400))
593
- ),
594
- dim=-2
595
- )
596
- else:
597
- position_embeddings = self.transformer.position_embeddings(position_ids)
598
- else:
599
- # given 3, interpolate 2
600
- position_embeddings = torch.cat(
601
- (
602
- self.transformer.position_embeddings(position_ids[..., :-800]),
603
- self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -800:]-(512+400))
604
- ),
605
- dim=-2
606
- )
607
- else:
608
- if position_ids[0, 0] >= (512+400):
609
- position_embeddings = self.get_mixin('extra_position_embedding').position_embeddings(position_ids-(512+400))
610
- else:
611
- position_embeddings = self.transformer.position_embeddings(position_ids)
612
- return position_embeddings
613
-
614
- def attention_forward(self, hidden_states, mask, layer_id, mems=None, log_text_attention_weights=0, text_len=0, frame_len=0, counter=0, enforce_no_swin=False, limited_spatial_channel_mem=False, **kw_args):
615
- attn_module = self.transformer.layers[layer_id].attention
616
- hidden_size = hidden_states.shape[-1]
617
-
618
- # base model qkv
619
- if mems is None:
620
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
621
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
622
- assert (q0.shape[1]-text_len) % frame_len == 0
623
- memkv0 = torch.cat((k0, v0), dim=-1)
624
- context_text, context_frame_local_text = attention_localframe_and_text_NAR(
625
- q0, k0, v0,
626
- mask,
627
- n_head=attn_module.num_attention_heads_per_partition,
628
- text_len=text_len,
629
- frame_len=frame_len,
630
- frame_num=(q0.shape[1]-text_len)//frame_len,
631
- log_text_attention_weights=log_text_attention_weights,
632
- stage=self.stage
633
- )
634
-
635
- # change: self.swin_attend_to_text默认为True:
636
- memkv1_text = self.get_mixin('attention_plus').query_key_value[layer_id](hidden_states[..., :text_len, :])[..., hidden_size:]
637
- output_text = attn_module.dense(context_text)
638
-
639
- if (q0.shape[1]-text_len)//frame_len > 0:
640
- assert (q0.shape[1]-text_len) % frame_len == 0
641
- context_frame_swin, memkv1_frame = self.get_mixin('attention_plus').attention_extra_NAR_inference(
642
- hidden_states[:,text_len:], layer_id, memkv_text=memkv1_text, stage=self.stage)
643
- if not enforce_no_swin:
644
- attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
645
- attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
646
- output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
647
- +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
648
- else:
649
- output_frame = attn_module.dense(context_frame_local_text[..., :frame_len, :])
650
- output = torch.cat((output_text, output_frame), dim=-2)
651
- memkv1 = torch.cat((memkv1_text, memkv1_frame), dim=-2) if memkv1_text is not None else memkv1_frame
652
- else:
653
- output = output_text
654
- memkv1 = memkv1_text
655
- kw_args['output_this_layer']['mem_kv'] = (memkv0, memkv1)
656
-
657
-
658
- else:
659
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
660
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
661
- new_memkv0 = torch.cat((k0, v0), dim=-1)
662
- old_k0, old_v0 = mems[0][layer_id][..., :hidden_size], mems[0][layer_id][..., hidden_size:]
663
-
664
- context_frame_local_text = attention_localframe_and_text_AR(
665
- q0,
666
- torch.cat((old_k0.expand(k0.shape[0], -1, -1), k0), dim=-2),
667
- torch.cat((old_v0.expand(v0.shape[0], -1, -1), v0), dim=-2),
668
- n_head=attn_module.num_attention_heads_per_partition,
669
- text_len=text_len,
670
- frame_len=frame_len,
671
- frame_num=None,
672
- log_text_attention_weights=log_text_attention_weights,
673
- layer_id=layer_id,
674
- limited_spatial_channel_mem=limited_spatial_channel_mem,
675
- )
676
-
677
- old_memkv1 = mems[1][layer_id] if mems[1] is not None else None
678
-
679
- context_frame_swin, new_memkv1 = self.get_mixin('attention_plus').attention_extra_AR_inference(hidden_states,
680
- old_memkv1[..., text_len:, :] if old_memkv1.shape[-2]>text_len else None,
681
- counter-text_len,
682
- layer_id,
683
- memkv_text=old_memkv1[..., :text_len, :],
684
- log_text_attention_weights=log_text_attention_weights)
685
- if not enforce_no_swin:
686
- attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
687
- attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
688
- output = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
689
- +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
690
- else:
691
- output = attn_module.dense(context_frame_local_text)
692
-
693
- kw_args['output_this_layer']['mem_kv'] = (new_memkv0, new_memkv1)
694
-
695
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/cogvideo_model.py DELETED
@@ -1,543 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cogvideo_model.py
4
- @Time : 2022/07/11 16:12:05
5
- @Author : Wenyi Hong
6
- @Version : 1.0
7
- @Contact : hwy22@mails.tsinghua.edu.cn
8
- '''
9
-
10
- # here put the import lib
11
-
12
- import torch
13
- from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
14
-
15
- from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
16
- from SwissArmyTransformer.model.transformer import unscaled_init_method
17
- from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
18
- import torch.nn.functional as F
19
- from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
20
- import math
21
-
22
- class PositionEmbeddingMixin(BaseMixin):
23
- def __init__(self, additional_sequence_length, hidden_size,
24
- init_method_std=0.02, reinit_slice=slice(512, 912),
25
- ):
26
- super(PositionEmbeddingMixin, self).__init__()
27
- self.reinit_slice = reinit_slice
28
- self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
29
- torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
30
-
31
- def reinit(self, parent_model=None):
32
- old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
33
- old_len, hidden_size = old_weights.shape
34
- assert hidden_size == self.position_embeddings.weight.shape[-1]
35
- self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
36
-
37
- def window_partition(x, window_size):
38
- """
39
- Args:
40
- x: (B, framenum, H, W, C)
41
- window_size (int): window size
42
- Returns:
43
- windows: (num_windows*B, frame_num, window_size, window_size, C)
44
- """
45
- B, framenum, H, W, C = x.shape
46
- x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
47
- windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
48
- return windows
49
-
50
- def window_reverse(windows, window_size, H, W):
51
- """
52
- Args:
53
- windows: (num_windows*B, frame_num, window_size, window_size, C)
54
- window_size (int): Window size
55
- H (int): Height of image
56
- W (int): Width of image
57
- Returns:
58
- x: (B, frame_num, H, W, C)
59
- """
60
- B = int(windows.shape[0] / (H * W / window_size / window_size))
61
- framenum = windows.shape[1]
62
- x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
63
- x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
64
- return x
65
-
66
- class WindowAttentionMixin(BaseMixin):
67
- def __init__(self, num_layers,
68
- hidden_size,
69
- frame_resolution,
70
- window_size,
71
- shift_size,
72
- n_head,
73
- frame_num,
74
- init_method=unscaled_init_method(0.02),
75
- output_layer_init_method=unscaled_init_method(0.02),
76
- ):
77
- super(WindowAttentionMixin, self).__init__()
78
- self.num_layers = num_layers # replace attention in the LAST n layers
79
- self.query_key_value = torch.nn.ModuleList(
80
- [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
81
- gather_output=False,init_method=init_method)
82
- for layer_id in range(num_layers)
83
- ])
84
- self.dense = torch.nn.ModuleList(
85
- [RowParallelLinear(
86
- hidden_size,
87
- hidden_size,
88
- input_is_parallel=True,
89
- init_method=output_layer_init_method,
90
- bias=True,
91
- module=self,
92
- name="dense",
93
- )
94
- for layer_id in range(num_layers)
95
- ])
96
-
97
- self.n_head = n_head
98
- self.window_size = window_size
99
- self.frame_resolution = frame_resolution
100
- self.frame_len = frame_resolution * frame_resolution
101
- assert frame_resolution % window_size == 0
102
- assert 0 < shift_size < window_size
103
- nW = (self.frame_resolution // self.window_size) ** 2
104
- ws_squre = self.window_size * self.window_size
105
-
106
- # odd non-shift, even shift
107
- img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
108
- h_slices = (slice(0, -shift_size),
109
- slice(-shift_size, None))
110
- w_slices = (slice(0, -shift_size),
111
- slice(-shift_size, None))
112
- cnt = 0
113
- for h in h_slices:
114
- for w in w_slices:
115
- img_mask[:, :, h, w, :] = cnt
116
- cnt += 1
117
- mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
118
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
119
- sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
120
- sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
121
- attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
122
-
123
- self.attn_mask_sequential = attn_mask.clone().tril()
124
- self.causal_mask_sequential = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num).tril()
125
-
126
- self.causal_mask_interp = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num)
127
- self.attn_mask_interp = attn_mask.clone()
128
-
129
- # bi-dir
130
- for bi_idx in range(0, frame_num, 2):
131
- for uni_idx in range(1, frame_num, 2):
132
- self.attn_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
133
- self.causal_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
134
- # uni-dir
135
- for uni_idx in range(1, frame_num, 2):
136
- self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
137
- self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
138
- for uni_idx2 in range(uni_idx+2, frame_num, 2):
139
- self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
140
- self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
141
-
142
- # expand dim
143
- self.attn_mask_sequential = self.attn_mask_sequential[None, None, :, None]
144
- self.attn_mask_interp = self.attn_mask_interp[None, None, :, None]
145
- self.causal_mask_sequential = self.causal_mask_sequential[None, None, :, None]
146
- self.causal_mask_interp = self.causal_mask_interp[None, None, :, None]
147
-
148
- self.shift_sizes = [0, shift_size]
149
- # self.register_buffer("attn_mask", attn_mask)
150
- # self.register_buffer("causal_mask", causal_mask)
151
- self.mask_initialized = False
152
-
153
- self.attn_distribution = torch.nn.ParameterList([
154
- torch.nn.Parameter(torch.zeros(hidden_size))
155
- for _ in range(num_layers)
156
- ])
157
-
158
- def reinit(self, *pre_mixins):
159
- start_layer = len(self.transformer.layers) - self.num_layers
160
- assert start_layer >= 0
161
- for layer_id in range(self.num_layers):
162
- old_attention = self.transformer.layers[start_layer + layer_id].attention
163
- self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
164
- self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
165
-
166
- def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
167
- text_attn_mask=None, mode_sequential=True):
168
- # pb relax
169
- swin_pb_relax = True
170
- alpha = 16
171
-
172
- # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
173
- if not self.mask_initialized:
174
- self.attn_mask_sequential = self.attn_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
175
- self.causal_mask_sequential = self.causal_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
176
- self.attn_mask_interp = self.attn_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
177
- self.causal_mask_interp = self.causal_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
178
- self.mask_initialized = True
179
- b0, s1, h0 = frame_hidden_state.shape
180
- h = h0 // self.n_head
181
- frame_len = self.frame_resolution * self.frame_resolution
182
- frame_num = s1 // frame_len
183
- assert frame_num*frame_len == s1
184
- wind_square = self.window_size * self.window_size
185
- nW = frame_len // wind_square
186
- bswin = b0 * nW
187
-
188
- causal_mask = self.causal_mask_sequential if mode_sequential else self.causal_mask_interp
189
- attn_mask = self.attn_mask_sequential if mode_sequential else self.attn_mask_interp
190
- if text_hidden_state is not None:
191
- s0 = text_hidden_state.shape[1]
192
- qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
193
- q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
194
-
195
- # shift
196
- frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
197
- if self.shift_sizes[layer_id%2] > 0:
198
- frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
199
- # window partition
200
- frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
201
- qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
202
- .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
203
- q, k, v = qkv[0], qkv[1], qkv[2]
204
-
205
- # pb-relax
206
- if swin_pb_relax:
207
- attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
208
- else:
209
- attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
210
-
211
- if self.shift_sizes[layer_id%2] > 0:
212
- # attn = attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square) + self.attn_mask.unsqueeze(1).unsqueeze(0)
213
- attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), attn_mask)\
214
- - 10000.0 * (1.0 - attn_mask)
215
- attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
216
- else:
217
- attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), causal_mask)\
218
- - 10000.0 * (1.0 - causal_mask)
219
- attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
220
- if swin_pb_relax:
221
- swin_pb_relax_const = torch.max(attn.reshape(bswin, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
222
- attn = (attn - swin_pb_relax_const)*alpha
223
-
224
- if text_hidden_state is None:
225
- attn = F.softmax(attn, dim=-1)
226
- if attn_dropout is not None:
227
- with get_cuda_rng_tracker().fork():
228
- attn = attn_dropout(attn)
229
- context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
230
- else:
231
- assert text_attn_mask is not None
232
- text_attn_mask = text_attn_mask.unsqueeze(2).unsqueeze(2)
233
- # pb-relax
234
- if swin_pb_relax:
235
- attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / (math.sqrt(h)*alpha), k_text.unsqueeze(1).transpose(-1, -2))
236
- attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, -1, self.n_head, 1, 1))*alpha
237
- else:
238
- attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
239
-
240
- attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
241
- attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
242
- attn = torch.cat((attn, attn_frame2text), dim=-1)
243
- attn = F.softmax(attn, dim=-1)
244
-
245
- if attn_dropout is not None:
246
- with get_cuda_rng_tracker().fork():
247
- attn = attn_dropout(attn)
248
-
249
- context_swin = (torch.matmul(attn[..., :-s0], v) +
250
- torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
251
- .reshape(bswin, self.n_head, frame_num*wind_square, h))\
252
- .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
253
-
254
- context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
255
- # reverse cycle shift
256
- if self.shift_sizes[layer_id%2] > 0:
257
- context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
258
- context_swin = context_swin.reshape(b0, s1, h0)
259
-
260
- return context_swin
261
-
262
-
263
- class FullAttentionMixin(BaseMixin):
264
- def __init__(self, num_layers,
265
- hidden_size,
266
- frame_resolution,
267
- n_head,
268
- frame_num,
269
- init_method=unscaled_init_method(0.02),
270
- output_layer_init_method=unscaled_init_method(0.02),
271
- ):
272
- super(FullAttentionMixin, self).__init__()
273
- self.num_layers = num_layers # replace attention in the LAST n layers
274
- self.query_key_value = torch.nn.ModuleList(
275
- [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
276
- gather_output=False,init_method=init_method)
277
- for layer_id in range(num_layers)
278
- ])
279
- self.dense = torch.nn.ModuleList(
280
- [RowParallelLinear(
281
- hidden_size,
282
- hidden_size,
283
- input_is_parallel=True,
284
- init_method=output_layer_init_method,
285
- bias=True,
286
- module=self,
287
- name="dense",)
288
- for layer_id in range(num_layers)
289
- ])
290
-
291
- self.n_head = n_head
292
- self.frame_resolution = frame_resolution
293
- self.frame_len = frame_resolution * frame_resolution
294
- self.causal_mask = torch.ones(1, 1, self.frame_len*frame_num, self.frame_len*frame_num).tril()
295
-
296
- self.mask_initialized = False
297
-
298
- self.attn_distribution = torch.nn.ParameterList([
299
- torch.nn.Parameter(torch.zeros(hidden_size))
300
- for _ in range(num_layers)
301
- ])
302
-
303
- def reinit(self, *pre_mixins):
304
- start_layer = len(self.transformer.layers) - self.num_layers
305
- assert start_layer >= 0
306
- for layer_id in range(self.num_layers):
307
- base_attention = self.transformer.layers[start_layer + layer_id].attention
308
- self.query_key_value[layer_id].weight.data.copy_(base_attention.query_key_value.weight.data)
309
- self.query_key_value[layer_id].bias.data.copy_(base_attention.query_key_value.bias.data)
310
-
311
- def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
312
- text_attn_mask=None, mode_sequential=False):
313
- # pb relax
314
- # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
315
- assert mode_sequential == True # only
316
- swin_pb_relax = True
317
- alpha = 16
318
-
319
- if not self.mask_initialized:
320
- self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
321
- self.mask_initialized = True
322
- b0, s1, h0 = frame_hidden_state.shape
323
- h = h0 // self.n_head
324
- frame_len = self.frame_resolution * self.frame_resolution
325
- frame_num = s1 // frame_len
326
- assert frame_num*frame_len == s1
327
-
328
- qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
329
- .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
330
- q, k, v = qkv[0], qkv[1], qkv[2]
331
-
332
- # frames-to-frames
333
- if swin_pb_relax:
334
- attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
335
- else:
336
- attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
337
- attn = torch.mul(attn, self.causal_mask) - 10000.0 * (1.0 - self.causal_mask)
338
- if swin_pb_relax:
339
- swin_pb_relax_const = torch.max(attn.reshape(b0, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
340
- attn = (attn - swin_pb_relax_const)*alpha
341
-
342
- if text_hidden_state is None:
343
- attn = F.softmax(attn, dim=-1)
344
- if attn_dropout is not None:
345
- with get_cuda_rng_tracker().fork():
346
- attn = attn_dropout(attn)
347
- context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
348
- else:
349
- # frame-to-text
350
- assert text_attn_mask is not None
351
- s0 = text_hidden_state.shape[1]
352
- qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
353
- q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
354
- text_attn_mask = text_attn_mask.unsqueeze(2)
355
- if swin_pb_relax:
356
- attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / (math.sqrt(h)*alpha), k_text.transpose(-1, -2))
357
- attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, self.n_head, 1, 1))*alpha
358
- else:
359
- attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / math.sqrt(h), k_text.transpose(-1, -2))
360
- attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
361
- attn_frame2text = attn_frame2text.reshape(b0, self.n_head, s1, s0)
362
-
363
- attn = torch.cat((attn, attn_frame2text), dim=-1)
364
- attn = F.softmax(attn, dim=-1)
365
-
366
- if attn_dropout is not None:
367
- with get_cuda_rng_tracker().fork():
368
- attn = attn_dropout(attn)
369
-
370
- context_frame = (torch.matmul(attn[..., :-s0], v) +
371
- torch.matmul(attn[..., -s0:].reshape(b0, self.n_head,s1, s0), v_text))\
372
- .permute(0, 2, 1, 3).reshape(b0, s1, h0)
373
-
374
- return context_frame
375
-
376
-
377
- def attention_localframe_and_text(q0, k0, v0, attention_mask_totxt, attention_mask_local,
378
- n_head, text_len, frame_len, frame_num, attention_dropout=None, layer_id=0, **kwargs):
379
- b, s0, h0 = q0.shape
380
- s1 = s0 - text_len
381
- h = h0 // n_head
382
- assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
383
- # attention_mask_totxt [b, 1, 1, text_len]
384
- # attention_mask_local [1, 1, frame_num, frame_len, frame_len]
385
- # attention_mask: [1, 1, text_len+frame_len, text_len+frame_len]
386
-
387
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
388
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
389
- k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
390
- k0T = k0.transpose(-1, -2)
391
-
392
- # score: any2text
393
- score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
394
- score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask_totxt) \
395
- - 10000.0 * (1.0 - attention_mask_totxt)
396
- score_any2text_part2 = torch.mul(score_any2text[..., text_len:, :], attention_mask_totxt) - \
397
- 10000.0 * (1.0 - attention_mask_totxt)
398
-
399
- # score: frame local
400
- q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
401
- v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
402
- k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
403
- score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
404
- score_frame_local0 = torch.mul(score_frame_local0, attention_mask_local) \
405
- - 10000.0 * (1.0 - attention_mask_local)
406
-
407
- # context for frame
408
- score_frame_all = torch.cat((score_any2text_part2,
409
- score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
410
- attention_probs_frame = F.softmax(score_frame_all, dim=-1)
411
-
412
- if attention_dropout is not None:
413
- with get_cuda_rng_tracker().fork():
414
- attention_probs_frame = attention_dropout(attention_probs_frame)
415
-
416
- context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
417
- context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
418
- view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
419
- context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
420
-
421
- # context for text
422
- attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
423
- if attention_dropout is not None:
424
- with get_cuda_rng_tracker().fork():
425
- attention_probs_text = attention_dropout(attention_probs_text)
426
- context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
427
- context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
428
-
429
- return context_text2text, context_frame
430
-
431
-
432
- class CogVideoModel(BaseModel):
433
- def __init__(self, args, transformer=None, parallel_output=True):
434
- super().__init__(args, transformer=transformer, parallel_output=parallel_output)
435
- self.stage = args.cogvideo_stage # 1 or 2
436
- self.mode_sequential = True if self.stage==1 else False
437
- self.layout = args.layout # [64, 64+400, 64+5*400]
438
- self.n_head = args.num_attention_heads
439
- frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
440
- frame_num = (args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0])
441
- frame_len = self.layout[1]-self.layout[0]
442
-
443
- self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
444
- args.additional_seqlen, args.hidden_size
445
- ))
446
-
447
- if args.window_size == -1:
448
- # full attention
449
- assert self.stage == 1
450
- self.add_mixin('attention_plus', FullAttentionMixin(
451
- num_layers=args.num_layers,
452
- hidden_size=args.hidden_size,
453
- frame_resolution=frame_resolution,
454
- n_head=args.num_attention_heads,
455
- frame_num=frame_num,
456
- ))
457
- else:
458
- self.add_mixin('attention_plus', WindowAttentionMixin(
459
- num_layers=args.num_layers,
460
- hidden_size=args.hidden_size,
461
- frame_resolution=frame_resolution,
462
- window_size=args.window_size,
463
- shift_size=args.window_size//2,
464
- n_head=args.num_attention_heads,
465
- frame_num=frame_num,
466
- ))
467
- # attention_mask_local
468
- self.attention_mask_local_sequential = torch.ones(1, 1, frame_num, frame_len, frame_len).tril().unsqueeze(0)
469
- self.attention_mask_local_interp = torch.ones(1, 1, frame_num, frame_len, frame_len)
470
-
471
- for idx in range(1, frame_num, 2):
472
- self.attention_mask_local_interp[:, :, idx:idx+1].tril_()
473
- self.attention_mask_local_interp = self.attention_mask_local_interp.unsqueeze(0)
474
- self.mask_initialized = False
475
-
476
- @classmethod
477
- def add_model_specific_args(cls, parser):
478
- group = parser.add_argument_group('CogVideoModel', 'CogVideo model configurations')
479
- group.add_argument("--layout", type=str, default='64, 464, 2064', help='text_len, textlen+frame_len, textlen+frame_len*frame_num')
480
- group.add_argument("--window-size", type=int, default=10, help="swin attention's window size in temperal channel, -1 represents full attention")
481
- group.add_argument("--additional-seqlen", type=int, default=2000)
482
- group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2])
483
- return parser
484
-
485
- def disable_untrainable_params(self):
486
- self.transformer.requires_grad_(False)
487
-
488
- def position_embedding_forward(self, position_ids, **kw_args):
489
- position = position_ids[..., :(64+400)]
490
- position_plus = position_ids[..., (64+400):]
491
- position_embeddings = torch.cat(
492
- (
493
- self.transformer.position_embeddings(position),
494
- self.get_mixin('extra_position_embedding').position_embeddings(position_plus-(512+400))
495
- ),
496
- dim=-2
497
- )
498
- return position_embeddings
499
-
500
- def attention_forward(self, hidden_states, mask, layer_id, **kw_args):
501
- # mask.shape=[bs, 1, 1, 64]
502
- if not self.mask_initialized:
503
- self.attention_mask_local_sequential = self.attention_mask_local_sequential.to(device=hidden_states.device, dtype=hidden_states.dtype)
504
- self.attention_mask_local_interp = self.attention_mask_local_interp.to(device=hidden_states.device, dtype=hidden_states.dtype)
505
- self.mask_initialized = True
506
-
507
- attn_module = self.transformer.layers[layer_id].attention
508
- hidden_size = hidden_states.shape[-1]
509
- bs = hidden_states.shape[0]
510
-
511
- # base model qkv
512
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
513
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
514
- dropout_fn = self.transformer.layers[layer_id].attention.attention_dropout if self.training else None
515
-
516
- attention_mask_local = self.attention_mask_local_sequential if self.mode_sequential else self.attention_mask_local_interp
517
- context_text, context_frame_local_text = attention_localframe_and_text(
518
- q0, k0, v0,
519
- attention_mask_totxt=mask,
520
- attention_mask_local=attention_mask_local,
521
- n_head=attn_module.num_attention_heads_per_partition,
522
- text_len=self.layout[0],
523
- frame_len=self.layout[1]-self.layout[0],
524
- frame_num=(self.layout[2]-self.layout[0])//(self.layout[1]-self.layout[0]),
525
- attention_dropout=dropout_fn,
526
- layer_id=layer_id,
527
- )
528
-
529
- context_frame_swin = self.get_mixin('attention_plus').attention_extra(
530
- hidden_states[:, self.layout[0]:], layer_id, dropout_fn,
531
- text_hidden_state=hidden_states[:, :self.layout[0]],
532
- text_attn_mask=mask[..., 0, :],
533
- mode_sequential=self.mode_sequential)
534
-
535
- attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
536
- attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
537
-
538
- output_text = attn_module.dense(context_text)
539
- output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
540
- +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
541
- output = torch.cat((output_text, output_frame), dim=-2)
542
-
543
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
patch ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/coglm_strategy.py b/coglm_strategy.py
2
+ index d485715..a9eab3b 100644
3
+ --- a/coglm_strategy.py
4
+ +++ b/coglm_strategy.py
5
+ @@ -8,6 +8,7 @@
6
+
7
+ # here put the import lib
8
+ import os
9
+ +import pathlib
10
+ import sys
11
+ import math
12
+ import random
13
+ @@ -58,7 +59,8 @@ class CoglmStrategy:
14
+ self._is_done = False
15
+ self.outlier_count_down = torch.zeros(16)
16
+ self.vis_list = [[]for i in range(16)]
17
+ - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
18
+ + cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label2.npy'
19
+ + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
20
+ self.start_pos = -1
21
+ self.white_cluster = []
22
+ # self.fout = open('tmp.txt', 'w')
23
+ @@ -98,4 +100,4 @@ class CoglmStrategy:
24
+
25
+ def finalize(self, tokens, mems):
26
+ self._is_done = False
27
+ - return tokens, mems
28
+
29
+ + return tokens, mems
30
+ diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py
31
+ index 5b8dded..07e97fd 100644
32
+ --- a/sr_pipeline/dsr_sampling.py
33
+ +++ b/sr_pipeline/dsr_sampling.py
34
+ @@ -8,6 +8,7 @@
35
+
36
+ # here put the import lib
37
+ import os
38
+ +import pathlib
39
+ import sys
40
+ import math
41
+ import random
42
+ @@ -28,7 +29,8 @@ class IterativeEntfilterStrategy:
43
+ self.invalid_slices = invalid_slices
44
+ self.temperature = temperature
45
+ self.topk = topk
46
+ - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
47
+ + cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label2.npy'
48
+ + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
49
+
50
+
51
+ def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
pretrain_cogvideo.py DELETED
@@ -1,184 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : pretrain_cogvideo.py
4
- @Time : 2021/10/06 00:58:32
5
- @Author : Wenyi Hong
6
- @Contact : hwy22@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
- import argparse
16
- import numpy as np
17
- from icetk import icetk as tokenizer
18
- tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
19
-
20
- from models.cogvideo_model import CogVideoModel
21
- from SwissArmyTransformer import mpu, get_args
22
- from SwissArmyTransformer.training.deepspeed_training import training_main
23
- from SwissArmyTransformer.data_utils import BinaryDataset
24
-
25
- def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
26
- # Extract batch size and sequence length.
27
- batch_size, seq_length = data.size()
28
- assert attention_mask_totxt is not None
29
- layout = args.layout
30
- assert seq_length == layout[-1]
31
- n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long()
32
- frame_len = layout[1]-layout[0]
33
- position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long,
34
- device=data.device)
35
- for i in range(batch_size):
36
- torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]],
37
- dtype=torch.long, device=data.device)
38
- torch.arange(512, 512+layout[2]-layout[0],
39
- out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device)
40
- return position_ids
41
-
42
-
43
- def get_batch(data_iterator, args, timers):
44
- # Items and their type.
45
- keys = ['text', 'loss_mask', 'attention_mask_totxt']
46
- datatype = torch.int64
47
-
48
- # Broadcast data.
49
- timers('data loader').start()
50
- if data_iterator is not None:
51
- data = next(data_iterator)
52
- else:
53
- data = None
54
- timers('data loader').stop()
55
-
56
- data_b = mpu.broadcast_data(keys, data, datatype)
57
- # Unpack.
58
- tokens_ = data_b['text'].long()
59
- loss_mask = data_b['loss_mask'].float()
60
- attention_mask_totxt = data_b['attention_mask_totxt'].float()
61
-
62
- labels = tokens_[:, 1:].clone().contiguous()
63
- loss_mask = loss_mask[:, 1:].contiguous()
64
- tokens = tokens_[:, :-1].clone().contiguous()
65
-
66
- for idx in range(args.layout[0], args.layout[2], 400):
67
- tokens[:, idx] = tokenizer['<start_of_image>']
68
- # Get the masks and postition ids.
69
- position_ids = get_masks_and_position_ids_video(
70
- tokens,
71
- attention_mask_totxt=attention_mask_totxt,
72
- args=args
73
- )
74
- attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1)
75
- # Convert
76
- if args.fp16:
77
- attention_mask_totxt = attention_mask_totxt.half()
78
- return tokens, labels, loss_mask, attention_mask_totxt, position_ids
79
-
80
-
81
- def forward_step(data_iterator, model, args, timers):
82
- """Forward step."""
83
-
84
- # Get the batch.
85
- timers('batch generator').start()
86
- tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch(
87
- data_iterator, args, timers)
88
- timers('batch generator').stop()
89
-
90
- # Forward model.
91
- logits, *mems = model(tokens, position_ids, attention_mask_totxt)
92
- # ======= hyper params =======#
93
- perframe_len = 400
94
- text_len=64
95
- frame_num = 5
96
- logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
97
- losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:])
98
- # scaling loss mask
99
- loss_mask = loss_mask[:, text_len:].reshape(-1)
100
-
101
- losses_1d = losses.reshape(-1) * loss_mask
102
- loss = torch.sum(losses_1d) / loss_mask.sum()
103
- # ===================== Log partial losses ======================== #
104
- log_loss_dict = {}
105
- bs = losses.shape[0]
106
-
107
- if args.cogvideo_stage == 1:
108
- for i in range(frame_num):
109
- log_loss_dict[f'AR_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
110
- else:
111
- for i in range(1, frame_num-1):
112
- log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
113
-
114
- # ===================== END OF BLOCK ======================= #
115
- return loss, log_loss_dict
116
-
117
-
118
- def create_dataset_function(path, args):
119
- dataset_layout = [64, 464, 2064]
120
- input_layout = [64, 464, 2064]
121
- # frame_num = 6
122
- # frame_interval = 2 # DEBUG!!!
123
- def process_fn(row):
124
- row = row.astype(np.int64)
125
- text = row[:dataset_layout[0]]
126
- frames = row[dataset_layout[0]:]
127
-
128
- if text[0] == tokenizer['<pad>']:
129
- text = text[1:] # due to our way of data processing
130
- if args.cogvideo_stage == 1:
131
- text, loss_mask, frames = make_text_video_generation(text, frames)
132
- else:
133
- text, loss_mask, frames = mask_video_frame_interpolation(text, frames)
134
-
135
- n_pad = input_layout[0] - len(text)
136
- parts = [
137
- np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64),
138
- text,
139
- np.array([tokenizer['<start_of_image>']], dtype=np.int64),
140
- frames,
141
- ]
142
- ret = np.concatenate(parts, axis=0)
143
-
144
- attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0]-n_pad))
145
- return {'text': ret,
146
- 'loss_mask': loss_mask,
147
- 'attention_mask_totxt': attention_mask_totxt,
148
- }
149
- return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])
150
-
151
- def make_text_video_generation(text, frames):
152
- input_layout = [64, 464, 2064]
153
- text = text[text!= tokenizer['<pad>']][:input_layout[0]] # dataset format: 1.0秒<n>{text}<pad><pad> ...
154
- loss_mask = np.array([0] * (input_layout[1]+1) + [1] * (input_layout[2] - input_layout[1])) # 按照input的,之后loss_mask会左移一位
155
- return text, loss_mask, frames
156
-
157
- def mask_video_frame_interpolation(text, frames):
158
- input_layout = [64, 464, 2064]
159
- frame_len = input_layout[1]-input_layout[0]
160
- # text format: <pad> 1.0秒 <n> {text} <pad> <pad>
161
- text = text[text!= tokenizer['<pad>']][:input_layout[0]]
162
- loss_mask = np.array([0] * (input_layout[1]+1)
163
- + [1] * (input_layout[1]-input_layout[0])
164
- + [0] * (input_layout[1]-input_layout[0])
165
- + [1] * (input_layout[1]-input_layout[0])
166
- + [0] * (input_layout[1]-input_layout[0]) )# 按照input的,之后loss_mask会左移一位
167
-
168
- return text, loss_mask, frames
169
-
170
-
171
-
172
- if __name__ == '__main__':
173
- py_parser = argparse.ArgumentParser(add_help=False)
174
- py_parser.add_argument('--txt-loss-scale', type=float, default=1)
175
- CogVideoModel.add_model_specific_args(py_parser)
176
-
177
- known, args_list = py_parser.parse_known_args()
178
-
179
- args = get_args(args_list)
180
- args = argparse.Namespace(**vars(args), **vars(known))
181
-
182
- args.layout = [int(x) for x in args.layout.split(',')]
183
-
184
- training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pretrained/.gitkeep ADDED
File without changes
requirements.txt CHANGED
@@ -1,4 +1,7 @@
1
- SwissArmyTransformer>=0.2.9
2
- icetk
3
- gifmaker
4
- torchvision
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ imageio==2.19.5
3
+ imageio-ffmpeg==0.4.7
4
+ numpy==1.22.4
5
+ opencv-python-headless==4.6.0.66
6
+ torch==1.12.0+cu113
7
+ torchvision==0.13.0+cu113
prompt.txt → samples.txt RENAMED
@@ -1 +1,2 @@
1
  骑滑板的皮卡丘
 
 
1
  骑滑板的皮卡丘
2
+ a cat playing chess
scripts/ds_brain_pretrain_cogvideo_stage1.sh DELETED
@@ -1,108 +0,0 @@
1
- #! /bin/bash
2
-
3
- # Change for multinode config
4
-
5
- NUM_WORKERS=1
6
- NUM_GPUS_PER_WORKER=8
7
- MP_SIZE=1
8
-
9
- script_path=$(realpath $0)
10
- script_dir=$(dirname $script_path)
11
- main_dir=$(dirname $script_dir)
12
-
13
- OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
14
- HOST_FILE_PATH="hostfile"
15
- # HOST_FILE_PATH="hostfile_single"
16
-
17
- video_data_test="" # TODO
18
- CHECKPOINT_PATH="" # TODO: CogView2 ckpt
19
-
20
- config_json="$script_dir/ds_config_zero.json"
21
- gpt_options=" \
22
- --experiment-name pretrain-cogvideo-stage1 \
23
- --tokenizer-type fake \
24
- --vocab-size 150010 \
25
- --model-parallel-size ${MP_SIZE} \
26
- --mode finetune \
27
- --num-workers 0 \
28
- --num-layers 48 \
29
- --hidden-size 3072 \
30
- --num-attention-heads 48 \
31
- --layout 64,464,2064 \
32
- --window-size -1 \
33
- --cogvideo-stage 1 \
34
- --additional-seqlen 2000 \
35
- --train-iters 500000 \
36
- --resume-dataloader \
37
- --train-data ${video_data_test} \
38
- --train-data-weights 1 \
39
- --split 949,50,1 \
40
- --distributed-backend nccl \
41
- --lr-decay-style cosine \
42
- --warmup .001 \
43
- --checkpoint-activations \
44
- --max-sequence-length 1024 \
45
- --fp16 \
46
- --save-interval 2000 \
47
- --eval-interval 500 \
48
- --eval-iters 15 \
49
- --log-interval 50 \
50
- --save $main_dir/checkpoints \
51
- --sandwich-ln \
52
- --load $CHECKPOINT_PATH \
53
- "
54
- # --load $CHECKPOINT_PATH \
55
- # \ --sandwich-ln
56
-
57
-
58
- gpt_options="${gpt_options}
59
- --deepspeed \
60
- --deepspeed_config ${config_json} \
61
- "
62
-
63
- #!/bin/bash
64
-
65
- # Distribute Example
66
- #export NCCL_SOCKET_IFNAME=eth0
67
- export NCCL_IB_DISABLE=0
68
- export NCCL_NET_GDR_LEVEL=2
69
- #export NCCL_IB_CUDA_SUPPORT=1
70
- #export NCCL_IB_GID_INDEX=3
71
- #export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null)
72
- export NCCL_DEBUG=info
73
- export OMP_NUM_THREADS=4
74
-
75
- if [ $RLAUNCH_REPLICA == "0" ]; then
76
- ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip
77
- fi
78
-
79
- function finish {
80
- rm -rf master_ip
81
- }
82
-
83
- trap finish EXIT INT TERM
84
-
85
- while [ ! -f master_ip ]; do
86
- echo "wait master_ip..."
87
- ls > /dev/null && sleep 1;
88
- done
89
-
90
- export MASTER_ADDR=$(cat master_ip)
91
- echo "master_ip: $MASTER_ADDR"
92
-
93
- MP_SIZE=1
94
- task_set=$2
95
- source $1
96
- DATESTR=$(date +"%m-%d-%H-%M")
97
-
98
- mkdir logs
99
- run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \
100
- --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \
101
- --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt"
102
-
103
-
104
- # run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}"
105
- echo ${run_cmd}
106
- eval ${run_cmd}
107
-
108
- set +x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/ds_brain_pretrain_cogvideo_stage2.sh DELETED
@@ -1,108 +0,0 @@
1
- #! /bin/bash
2
-
3
- # Change for multinode config
4
-
5
- NUM_WORKERS=1
6
- NUM_GPUS_PER_WORKER=8
7
- MP_SIZE=1
8
-
9
- script_path=$(realpath $0)
10
- script_dir=$(dirname $script_path)
11
- main_dir=$(dirname $script_dir)
12
-
13
- OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
14
- HOST_FILE_PATH="hostfile"
15
- # HOST_FILE_PATH="hostfile_single"
16
-
17
- video_data_test="" # TODO
18
- CHECKPOINT_PATH="" # TODO: CogView2 ckpt
19
-
20
- config_json="$script_dir/ds_config_zero.json"
21
- gpt_options=" \
22
- --experiment-name pretrain-cogvideo-stage2 \
23
- --tokenizer-type fake \
24
- --vocab-size 150010 \
25
- --model-parallel-size ${MP_SIZE} \
26
- --mode finetune \
27
- --num-workers 0 \
28
- --num-layers 48 \
29
- --hidden-size 3072 \
30
- --num-attention-heads 48 \
31
- --layout 64,464,2064 \
32
- --window-size 10 \
33
- --cogvideo-stage 2 \
34
- --additional-seqlen 2000 \
35
- --train-iters 500000 \
36
- --resume-dataloader \
37
- --train-data ${video_data_test} \
38
- --train-data-weights 1 \
39
- --split 949,50,1 \
40
- --distributed-backend nccl \
41
- --lr-decay-style cosine \
42
- --warmup .001 \
43
- --checkpoint-activations \
44
- --max-sequence-length 1024 \
45
- --fp16 \
46
- --save-interval 2000 \
47
- --eval-interval 500 \
48
- --eval-iters 15 \
49
- --log-interval 50 \
50
- --save $main_dir/checkpoints \
51
- --sandwich-ln \
52
- --load $CHECKPOINT_PATH \
53
- "
54
- # --load $CHECKPOINT_PATH \
55
- # \ --sandwich-ln
56
-
57
-
58
- gpt_options="${gpt_options}
59
- --deepspeed \
60
- --deepspeed_config ${config_json} \
61
- "
62
-
63
- #!/bin/bash
64
-
65
- # Distribute Example
66
- #export NCCL_SOCKET_IFNAME=eth0
67
- export NCCL_IB_DISABLE=0
68
- export NCCL_NET_GDR_LEVEL=2
69
- #export NCCL_IB_CUDA_SUPPORT=1
70
- #export NCCL_IB_GID_INDEX=3
71
- #export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null)
72
- export NCCL_DEBUG=info
73
- export OMP_NUM_THREADS=4
74
-
75
- if [ $RLAUNCH_REPLICA == "0" ]; then
76
- ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip
77
- fi
78
-
79
- function finish {
80
- rm -rf master_ip
81
- }
82
-
83
- trap finish EXIT INT TERM
84
-
85
- while [ ! -f master_ip ]; do
86
- echo "wait master_ip..."
87
- ls > /dev/null && sleep 1;
88
- done
89
-
90
- export MASTER_ADDR=$(cat master_ip)
91
- echo "master_ip: $MASTER_ADDR"
92
-
93
- MP_SIZE=1
94
- task_set=$2
95
- source $1
96
- DATESTR=$(date +"%m-%d-%H-%M")
97
-
98
- mkdir logs
99
- run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \
100
- --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \
101
- --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt"
102
-
103
-
104
- # run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}"
105
- echo ${run_cmd}
106
- eval ${run_cmd}
107
-
108
- set +x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/ds_config_zero.json DELETED
@@ -1,42 +0,0 @@
1
- {
2
- "train_micro_batch_size_per_gpu": 4,
3
- "gradient_accumulation_steps": 1,
4
- "steps_per_print": 1,
5
- "gradient_clipping": 0.1,
6
- "zero_optimization": {
7
- "stage": 2,
8
- "cpu_offload": true,
9
- "contiguous_gradients": false,
10
- "overlap_comm": true,
11
- "reduce_scatter": false,
12
- "reduce_bucket_size": 100000000,
13
- "allgather_bucket_size": 1000000000,
14
- "load_from_fp32_weights": false
15
- },
16
- "zero_allow_untested_optimizer": true,
17
- "fp16": {
18
- "enabled": true,
19
- "loss_scale": 0,
20
- "loss_scale_window": 400,
21
- "hysteresis": 2,
22
- "min_loss_scale": 1
23
- },
24
- "optimizer": {
25
- "type": "Adam",
26
- "params": {
27
- "lr": 0.0002,
28
- "betas": [
29
- 0.9,
30
- 0.95
31
- ],
32
- "eps": 1e-8,
33
- "weight_decay": 1e-4
34
- }
35
- },
36
- "activation_checkpointing": {
37
- "partition_activations": false,
38
- "contiguous_memory_optimization": false
39
- },
40
- "wall_clock_breakdown": false
41
- }
42
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/inference_cogvideo_pipeline.sh DELETED
@@ -1,38 +0,0 @@
1
- #!/bin/bash
2
-
3
- NLAYERS=48
4
- NHIDDEN=3072
5
- NATT=48
6
- MAXSEQLEN=1024
7
- MASTER_PORT=$(shuf -n 1 -i 10000-65535)
8
- MPSIZE=1
9
-
10
- #SAMPLING ARGS
11
- TEMP=1.05
12
- TOPK=12
13
-
14
- script_path=$(realpath $0)
15
- script_dir=$(dirname $script_path)
16
-
17
- MASTER_PORT=${MASTER_PORT} SAT_HOME=/sharefs/cogview-new python cogvideo_pipeline.py \
18
- --input-source /home/user/app/CogVideo/prompt.txt \
19
- --output-path ./output \
20
- --parallel-size 1 \
21
- --both-stages \
22
- --use-guidance-stage1 \
23
- --guidance-alpha 3.0 \
24
- --generate-frame-num 5 \
25
- --tokenizer-type fake \
26
- --mode inference \
27
- --distributed-backend nccl \
28
- --fp16 \
29
- --model-parallel-size $MPSIZE \
30
- --temperature $TEMP \
31
- --coglm-temperature2 0.89 \
32
- --top_k $TOPK \
33
- --sandwich-ln \
34
- --seed 1234 \
35
- --num-workers 0 \
36
- --batch-size 1 \
37
- --max-inference-batch-size 1 \
38
- $@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : __init__.py
4
- @Time : 2022/03/02 13:57:09
5
- @Author : Ming Ding
6
- @Contact : dm18@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
-
15
- from .direct_sr import DirectSuperResolution
16
- from .iterative_sr import IterativeSuperResolution
17
- from .sr_group import SRGroup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/direct_sr.py DELETED
@@ -1,117 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : direct_sr.py
4
- @Time : 2022/03/02 13:58:11
5
- @Author : Ming Ding
6
- @Contact : dm18@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
-
16
- # -*- encoding: utf-8 -*-
17
- '''
18
- @File : inference_cogview2.py
19
- @Time : 2021/10/10 16:31:34
20
- @Author : Ming Ding
21
- @Contact : dm18@mails.tsinghua.edu.cn
22
- '''
23
-
24
- # here put the import lib
25
- import os
26
- import sys
27
- import math
28
- import random
29
- from PIL import ImageEnhance, Image
30
-
31
- import torch
32
- import argparse
33
- from torchvision import transforms
34
-
35
- from SwissArmyTransformer import get_args
36
- from SwissArmyTransformer.training.model_io import load_checkpoint
37
- from .dsr_sampling import filling_sequence_dsr, IterativeEntfilterStrategy
38
- from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
39
-
40
- from .dsr_model import DsrModel
41
-
42
- from icetk import icetk as tokenizer
43
-
44
- class DirectSuperResolution:
45
- def __init__(self, args, path, max_bz=4, topk=6, onCUDA=False):
46
- args.load = path
47
- args.kernel_size = 5
48
- args.kernel_size2 = 5
49
- args.new_sequence_length = 4624
50
- args.layout = [96,496,4096]
51
-
52
- model = DsrModel(args)
53
- if args.fp16:
54
- model = model.half()
55
-
56
- load_checkpoint(model, args) # on cpu
57
- model.eval()
58
- self.model = model
59
- self.onCUDA = onCUDA
60
- if onCUDA:
61
- self.model = self.model.cuda()
62
-
63
- invalid_slices = [slice(tokenizer.num_image_tokens, None)]
64
-
65
- self.strategy = IterativeEntfilterStrategy(invalid_slices,
66
- temperature=1.0, topk=topk) # temperature not used # Temperature Freezed Here!!
67
- self.max_bz = max_bz
68
-
69
- def __call__(self, text_tokens, image_tokens, enhance=False):
70
- if len(text_tokens.shape) == 1:
71
- text_tokens.unsqueeze_(0)
72
- if len(image_tokens.shape) == 1:
73
- image_tokens.unsqueeze_(0)
74
- # ===================== Debug ======================== #
75
- # new_image_tokens = []
76
- # for small_img in image_tokens:
77
- # decoded = tokenizer.decode(image_ids=small_img)
78
- # decoded = torch.nn.functional.interpolate(decoded, size=(480, 480)).squeeze(0)
79
- # ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
80
- # image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
81
- # small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
82
- # new_image_tokens.append(small_img2)
83
- # image_tokens = torch.stack(new_image_tokens)
84
- # return image_tokens
85
- # ===================== END OF BLOCK ======================= #
86
- if enhance:
87
- new_image_tokens = []
88
- for small_img in image_tokens:
89
- decoded = tokenizer.decode(image_ids=small_img).squeeze(0)
90
- ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
91
- image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
92
- small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.), image_size=160).view(-1)
93
- new_image_tokens.append(small_img2)
94
- image_tokens = torch.stack(new_image_tokens)
95
-
96
- seq = torch.cat((text_tokens,image_tokens), dim=1)
97
- seq1 = torch.tensor([tokenizer['<start_of_image>']]*3601, device=image_tokens.device).unsqueeze(0).expand(text_tokens.shape[0], -1)
98
- if not self.onCUDA:
99
- print('Converting Dsr model...')
100
- model = self.model.cuda()
101
- else:
102
- model = self.model
103
- print('Direct super-resolution...')
104
- output_list = []
105
- for tim in range(max((text_tokens.shape[0]+self.max_bz-1) // self.max_bz, 1)):
106
- output1 = filling_sequence_dsr(model,
107
- seq[tim*self.max_bz:(tim+1)*self.max_bz],
108
- seq1[tim*self.max_bz:(tim+1)*self.max_bz],
109
- warmup_steps=1, block_hw=(1, 0),
110
- strategy=self.strategy
111
- )
112
- output_list.extend(output1[1:])
113
- if not self.onCUDA:
114
- print('Moving back Dsr to cpu...')
115
- model = model.cpu()
116
- torch.cuda.empty_cache()
117
- return torch.cat(output_list, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/dsr_model.py DELETED
@@ -1,225 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cuda2d_model.py
4
- @Time : 2021/10/02 01:36:32
5
- @Author : Ming Ding
6
- @Contact : dm18@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
- import torch.nn.functional as F
16
-
17
-
18
- from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
19
-
20
- from SwissArmyTransformer.model.transformer import split_tensor_along_last_dim, unscaled_init_method
21
- from SwissArmyTransformer.mpu.utils import sqrt
22
- from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
23
- from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
24
-
25
- class PositionEmbeddingMixin(BaseMixin):
26
- def __init__(self, additional_sequence_length, hidden_size,
27
- init_method_std=0.02, reinit_slice=slice(512, 512+400)
28
- ):
29
- super(PositionEmbeddingMixin, self).__init__()
30
- self.reinit_slice = reinit_slice
31
- self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
32
- torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
33
-
34
- def reinit(self, parent_model=None):
35
- old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
36
- old_len, hidden_size = old_weights.shape
37
- assert hidden_size == self.position_embeddings.weight.shape[-1]
38
- old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2])
39
- assert new_edge % old_edge == 0
40
- self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size))
41
- # self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
42
-
43
-
44
- class AttentionMixin(BaseMixin):
45
- def __init__(self, num_layers,
46
- hidden_size,
47
- init_method=unscaled_init_method(0.02),
48
- output_layer_init_method=unscaled_init_method(0.02)
49
- ):
50
- super(AttentionMixin, self).__init__()
51
- self.num_layers = num_layers # replace attention in the LAST n layers
52
- self.query_key_value = torch.nn.ModuleList(
53
- [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
54
- gather_output=False, init_method=init_method)
55
- for layer_id in range(num_layers)
56
- ])
57
- self.dense = torch.nn.ModuleList(
58
- [RowParallelLinear(hidden_size,
59
- hidden_size,
60
- input_is_parallel=True,
61
- init_method=output_layer_init_method)
62
- for layer_id in range(num_layers)
63
- ])
64
-
65
- def reinit(self, parent_model=None):
66
- start_layer = len(self.transformer.layers) - self.num_layers
67
- assert start_layer >= 0
68
- for layer_id in range(self.num_layers):
69
- old_attention = self.transformer.layers[start_layer + layer_id].attention
70
- self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
71
- self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
72
- self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
73
- self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data)
74
-
75
- class DsrModel(BaseModel):
76
- def __init__(self, args, transformer=None):
77
- super().__init__(args, transformer=transformer)
78
- self.original_sequence_length = args.max_sequence_length
79
- additional_seqlen = args.new_sequence_length - args.max_sequence_length
80
- self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
81
- additional_seqlen, args.hidden_size
82
- ))
83
- self.add_mixin('attention_plus', AttentionMixin(
84
- num_layers=args.num_layers,
85
- hidden_size=args.hidden_size
86
- ))
87
- self.layout = args.layout
88
- # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
89
- self.kernel_size = args.kernel_size
90
- self.kernel_size2 = args.kernel_size2
91
- self.log_attention_weights = None
92
-
93
- def position_embedding_forward(self, position_ids, **kw_args):
94
- position = position_ids[..., :self.layout[1]]
95
- position_plus = position_ids[..., self.layout[1]:] - self.original_sequence_length
96
- position_embeddings = torch.cat(
97
- (
98
- self.transformer.position_embeddings(position),
99
- self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
100
- ),
101
- dim=-2
102
- )
103
- return position_embeddings
104
-
105
- def attention_forward(self, hidden_states, mask,
106
- layer_id=None, log_attention_weights=None, **kw_args):
107
- attn_module = self.transformer.layers[layer_id].attention
108
- # attention_plus on all layers
109
- query_key_value_plus = self.get_mixin('attention_plus').query_key_value[layer_id]
110
- dense_plus = self.get_mixin('attention_plus').dense[layer_id]
111
- # split two parts
112
- hidden_states_plus = hidden_states[:, self.layout[1]:]
113
- hidden_states = hidden_states[:, :self.layout[1]]
114
- # base model qkv
115
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
116
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
117
- # cuda2d model qkv
118
- mixed_raw_layer = query_key_value_plus(hidden_states_plus)
119
- q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3)
120
-
121
- dropout_fn = attn_module.attention_dropout if self.training else None
122
-
123
- # cuda2d attention
124
- context_layer0, context_layer1 = sparse_attention_2d_light(
125
- q0, k0, v0,
126
- q1, k1, v1,
127
- mask,
128
- n_head=attn_module.num_attention_heads_per_partition,
129
- text_len=self.layout[0],
130
- kernel_size=self.kernel_size,
131
- kernel_size2=self.kernel_size2,
132
- attention_dropout=dropout_fn,
133
- log_attention_weights=log_attention_weights,
134
- add_scalar=(kw_args['add_scalar'] if 'add_scalar' in kw_args else 0)
135
- )
136
-
137
- output_0 = attn_module.dense(context_layer0)
138
- output_1 = dense_plus(context_layer1)
139
- output = torch.cat((output_0, output_1), dim=1)
140
-
141
- return output
142
-
143
- def final_forward(self, logits, **kwargs):
144
- logits_parallel = logits
145
- logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
146
- # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
147
- return logits_parallel
148
-
149
- def disable_untrainable_params(self):
150
- self.transformer.requires_grad_(False)
151
-
152
- @classmethod
153
- def add_model_specific_args(cls, parser):
154
- group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
155
- group.add_argument("--kernel-size", type=int, default=5)
156
- group.add_argument("--kernel-size2", type=int, default=5)
157
- group.add_argument("--layout", type=str, default='96,496,4096')
158
- group.add_argument("--new-sequence-length", type=int, default=4096)
159
- return parser
160
-
161
- def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, add_scalar=0, **kwargs):
162
- '''
163
- q0, k0, v0: [batch_size, 1088, hidden_size]
164
- q1, k1, v1: [batch_size, 4096, h2]
165
- n_head: int
166
- attention_mask: [batch_size, 1088, 1088]
167
- '''
168
- from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
169
-
170
- b, s0, h0 = q0.shape
171
- b, s1, h1 = q1.shape
172
- h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1)
173
-
174
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
175
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
176
- k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
177
-
178
- # standard attention for level 0
179
- attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
180
-
181
- if log_attention_weights is not None:
182
- attention_scores += log_attention_weights
183
- attention_scores = torch.mul(attention_scores, attention_mask) - \
184
- 10000.0 * (1.0 - attention_mask)
185
-
186
- attention_probs0 = F.softmax(attention_scores, dim=-1)
187
-
188
- # local attention for level 1
189
- q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
190
- k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
191
- v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
192
- # scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)
193
- scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
194
-
195
- # cross attention
196
- k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous()
197
- scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field]
198
- scores_1 = torch.cat(
199
- (
200
- scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]) + add_scalar,
201
- scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
202
- ),
203
- dim=-1)
204
- attention_probs1 = F.softmax(scores_1, dim=-1)
205
-
206
- if attention_dropout is not None:
207
- # with get_cuda_rng_tracker().fork():
208
- attention_probs0 = attention_dropout(attention_probs0)
209
- attention_probs1 = attention_dropout(attention_probs1)
210
-
211
- # weighting for level 0
212
- context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
213
- # weighting for level 1
214
- probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
215
- # context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
216
- context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
217
-
218
- context1 = context1_to_1.view(b, n_head * h, l1**2)
219
- # weighting for cross attention
220
- probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0)
221
- v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0)
222
- context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False)
223
- context1_to_0 = context1_to_0.view(b, n_head * h, l1**2)
224
- context1 = context1 + context1_to_0
225
- return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/dsr_sampling.py DELETED
@@ -1,159 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : cuda2d_sampling.py
4
- @Time : 2021/10/09 00:46:04
5
- @Author : Ming Ding
6
- @Contact : dm18@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- from cv2 import reduce
15
- import torch
16
-
17
- import torch
18
- import torch.nn.functional as F
19
- import numpy as np
20
-
21
- def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
22
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
23
- logits[indices_to_remove] = filter_value
24
- return logits
25
-
26
- class IterativeEntfilterStrategy:
27
- def __init__(self, invalid_slices=[], temperature=1., topk=6):
28
- self.invalid_slices = invalid_slices
29
- self.temperature = temperature
30
- self.topk = topk
31
- self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
32
-
33
-
34
- def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
35
- # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
36
- if temperature is None:
37
- temperature = self.temperature
38
-
39
- logits = logits_.float() / temperature
40
- for invalid_slice in self.invalid_slices:
41
- logits[..., invalid_slice] = -float('Inf')
42
- logits = logits.view(-1, logits.shape[-1])
43
-
44
- rprobs = F.softmax(logits.float(), dim=-1)
45
- c = self.cluster_labels.expand(*rprobs.shape)
46
- cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
47
-
48
- best_scores, best_clusters = cprobs.topk(self.topk)
49
- bz = logits.shape[0]
50
- best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
51
- sampled_ids = torch.multinomial(best_scores, num_samples=1)
52
- selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
53
- selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500)
54
- logits[selected_mask] = -65504
55
- # for i in range(bz):
56
- # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
57
- # logits[i, self.cluster_labels != selected_cluster] = -65504
58
-
59
- # logits = top_k_logits(logits, self.topk, self.top_p)
60
- probs = F.softmax(logits.float()/0.6, dim=-1) # float is essetial, due to a bug in Pytorch
61
- pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
62
-
63
- assert tokens.shape[1] == pred.shape[1] + 1
64
- tokens = torch.cat((tokens[:, :1], pred), dim=1)
65
- return tokens
66
-
67
- def filling_sequence_dsr(
68
- model,
69
- seq0,
70
- seq1,
71
- warmup_steps=3,
72
- block_hw=(4, 4),
73
- strategy=IterativeEntfilterStrategy(topk=10),
74
- ):
75
- '''
76
- seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
77
- 4095 {layout[2]} final_token.
78
- Attention:
79
- The sampling temperature are changing, temporally we hard code them here.
80
- The temperature in the strategy is not used.
81
- '''
82
- assert hasattr(model, 'layout')
83
- layout = model.layout
84
- assert len(seq0.shape) == 2 and len(seq1.shape) == 2 \
85
- and seq0.shape[0] == seq1.shape[0]
86
- assert len(layout) == 3
87
- assert seq1.shape[1] == layout[-1] - layout[-2] + 1
88
- assert (seq1 >= 0).all() and (seq0 >= 0).all()
89
- device = seq0.device
90
- # concat and pad sequences
91
- batch_size = seq0.shape[0]
92
- n_pad = layout[1] - seq0.shape[1]
93
- assert n_pad > 0, "You should truncate long input before filling."
94
- seq = torch.cat((
95
- torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
96
- .unsqueeze(0).expand(batch_size, n_pad),
97
- seq0, seq1), dim=1) # [b, layout[-1]+1]
98
- assert seq.shape[1] == layout[-1] + 1
99
-
100
- # build initial tokens, attention_mask, and position_ids
101
- tokens = seq.clone()
102
- attention_mask = torch.ones(layout[1], layout[1]).to(device)
103
- attention_mask[:layout[0], layout[0]:] = 0
104
- attention_mask[n_pad:, :n_pad] = 0
105
- attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
106
- position_ids = torch.cat((
107
- torch.zeros(n_pad, dtype=torch.long),
108
- torch.arange(0, layout[0] - n_pad),
109
- torch.arange(513, 513 + layout[1] - layout[0]),
110
- torch.arange(1024, 1024+layout[2]-layout[1]))).to(device)
111
- log_attention_weights = torch.zeros(layout[1], layout[1],
112
- device=device).type_as(next(model.parameters()))
113
- log_attention_weights[layout[0]:, n_pad:layout[0]] = 0.
114
-
115
- # prepare for interation
116
- unfixed = (tokens < 0) # just init an all-False tensor
117
- unfixed[:, -layout[-1] + layout[-2]:] = True
118
-
119
- ll, rr = block_hw
120
- edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
121
- num_steps = warmup_steps + ll - 1 + rr
122
- # interative refining
123
-
124
- # unfixed[..., -(layout[-1] - layout[-2]):].view(
125
- # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
126
-
127
-
128
- ret = []
129
- ret.append(tokens[:, layout[-2]+1:].clone())
130
- for step_cnt in range(1, num_steps+1):
131
- if step_cnt <= warmup_steps:
132
- logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights)
133
- real_temp = 1.
134
- new_tokens = strategy.forward(logits, tokens, real_temp)
135
- tokens[unfixed] = new_tokens[unfixed]
136
- else:
137
- logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights)
138
- real_temp = 1.
139
- new_tokens = strategy.forward(
140
- logits, tokens, real_temp,
141
- entfilter=1.3,
142
- filter_topk=5,
143
- temperature2=0.6
144
- )
145
- # tokens[unfixed] = new_tokens[unfixed]
146
- # fixed tokens (update unfixed)
147
- unfixed2 = (tokens > 10000000)
148
- for x in range(min(ll, step_cnt - warmup_steps)):
149
- y = step_cnt - warmup_steps - x - 1
150
- if y < rr:
151
- unfixed[..., -(layout[-1] - layout[-2]):].view(
152
- batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
153
- unfixed2[..., -(layout[-1] - layout[-2]):].view(
154
- batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = True
155
- tokens[unfixed2] = new_tokens[unfixed2]
156
-
157
- ret.append(tokens[:, layout[-2]+1:].clone())
158
-
159
- return ret
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/iterative_sr.py DELETED
@@ -1,118 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : iterative_sr.py
4
- @Time : 2022/03/02 15:57:45
5
- @Author : Ming Ding
6
- @Contact : dm18@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
-
15
- # here put the import lib
16
- import os
17
- import sys
18
- import math
19
- import random
20
- from PIL import ImageEnhance, Image
21
-
22
- import torch
23
- import argparse
24
- from torchvision import transforms
25
-
26
- from SwissArmyTransformer.training.model_io import load_checkpoint
27
- from SwissArmyTransformer import get_args
28
- from .itersr_sampling import filling_sequence_itersr, IterativeEntfilterStrategy
29
- from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
30
-
31
- from .itersr_model import ItersrModel
32
-
33
- from icetk import icetk as tokenizer
34
-
35
- class IterativeSuperResolution:
36
- def __init__(self, args, path, max_bz=4, shared_transformer=None):
37
- args.load = path
38
- args.kernel_size = 5
39
- args.kernel_size2 = 5
40
- args.new_sequence_length = 4624
41
- args.layout = [16,3616]
42
-
43
- model = ItersrModel(args, transformer=shared_transformer)
44
- if args.fp16:
45
- model = model.half()
46
-
47
- load_checkpoint(model, args) # on cpu
48
- model.eval()
49
- self.model = model.cuda()
50
-
51
- # save cpu weights
52
- self.saved_weights = dict((k,v.cpu())
53
- for k, v in model.named_parameters()
54
- if 'transformer' in k
55
- )
56
-
57
- invalid_slices = [slice(tokenizer.num_image_tokens, None)]
58
-
59
- self.strategy = IterativeEntfilterStrategy(invalid_slices,
60
- temperature=args.temp_all_itersr, topk=args.topk_itersr)
61
- self.max_bz = max_bz
62
-
63
- def _restore_transformer_from_cpu(self, non_blocking=False):
64
- for k, v in self.model.named_parameters():
65
- if k in self.saved_weights:
66
- v.copy_(self.saved_weights[k])
67
-
68
- def __call__(self, text_tokens, image_tokens, enhance=False, input_mask=None):
69
- if len(text_tokens.shape) == 1:
70
- text_tokens.unsqueeze_(0)
71
- text_tokens = text_tokens.clone()[..., :16]
72
- if len(image_tokens.shape) == 1:
73
- image_tokens.unsqueeze_(0)
74
- if enhance:
75
- new_image_tokens = []
76
- for big_img in image_tokens:
77
- decoded = tokenizer.decode(image_ids=big_img).squeeze(0)
78
- ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
79
- image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
80
- big_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
81
- new_image_tokens.append(big_img2)
82
- image_tokens = torch.stack(new_image_tokens)
83
- print('Converting Itersr model...')
84
- self._restore_transformer_from_cpu()
85
- model = self.model
86
- print('iterative super-resolution...')
87
- output_list = []
88
- for tim in range(max(text_tokens.shape[0] // self.max_bz, 1)):
89
- big_img = image_tokens[tim*self.max_bz:(tim+1)*self.max_bz]
90
- text_seq = text_tokens[tim*self.max_bz:(tim+1)*self.max_bz]
91
- mask_raw = torch.tensor(
92
- [
93
- -1, 0, 1, 2, 3, 4,
94
- 0, -1, 2, -1, -2, 5,
95
- 1, -2, 3, 4, 5, 6,
96
- 2, 3, 4, 5, -1, 1,
97
- 3, -1, -2, 0, -1, 2,
98
- 4, 5, 6, 1, 3, -2
99
- ]
100
- ).view(1, 6, 1, 6).expand(10, 6, 10, 6).reshape(-1).contiguous()
101
-
102
- topks = [60, 40, 40, 40, 20, 20, 10]
103
-
104
- for mask_ratio in range(1, 7):
105
- self.strategy.topk = topks[mask_ratio]
106
- mask = (mask_raw.to(big_img.device) >= mask_ratio)
107
- if input_mask is not None:
108
- mask = mask & input_mask
109
- big_img.masked_fill_(mask, tokenizer['<start_of_image>'])
110
- seq1 = big_img
111
- output1 = filling_sequence_itersr(model, text_seq, seq1,
112
- warmup_steps=1, block_hw=(1, 0),
113
- strategy=self.strategy
114
- )
115
- big_img = output1
116
- print(f'Iter {mask_ratio} times.')
117
- output_list.append(output1.clone())
118
- return torch.cat(output_list, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/itersr_model.py DELETED
@@ -1,232 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : itersr_model.py
4
- @Time : 2021/10/02 01:36:32
5
- @Author : Ming Ding
6
- @Contact : dm18@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import torch
15
- import torch.nn.functional as F
16
-
17
-
18
- from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
19
-
20
- from SwissArmyTransformer.mpu.utils import sqrt
21
- from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
22
- from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
23
- from SwissArmyTransformer.model.transformer import unscaled_init_method, split_tensor_along_last_dim
24
-
25
- class PositionEmbeddingMixin(BaseMixin):
26
- def __init__(self, additional_sequence_length, hidden_size,
27
- init_method_std=0.02, reinit_slice=slice(512, 512+400)
28
- ):
29
- super(PositionEmbeddingMixin, self).__init__()
30
- self.reinit_slice = reinit_slice
31
- self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
32
- torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
33
-
34
- def reinit(self, parent_model=None):
35
- old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
36
- old_len, hidden_size = old_weights.shape
37
- assert hidden_size == self.position_embeddings.weight.shape[-1]
38
- old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2])
39
- assert new_edge % old_edge == 0
40
- self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size))
41
-
42
- class ItersrModel(BaseModel):
43
- def __init__(self, args, transformer=None):
44
- super().__init__(args, transformer=transformer)
45
- self.original_sequence_length = args.max_sequence_length
46
- additional_seqlen = args.new_sequence_length - args.max_sequence_length
47
- self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
48
- additional_seqlen, args.hidden_size
49
- ))
50
- # self.add_mixin('attention_plus', AttentionMixin(
51
- # num_layers=args.num_layers,
52
- # hidden_size=args.hidden_size
53
- # ))
54
- self.layout = args.layout
55
- # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
56
- self.kernel_size = args.kernel_size
57
- self.kernel_size2 = args.kernel_size2
58
- self.log_attention_weights = None
59
-
60
- def position_embedding_forward(self, position_ids, **kw_args):
61
- position = position_ids[..., :self.layout[0]]
62
- position_plus = position_ids[..., self.layout[0]:] - self.original_sequence_length
63
- position_embeddings = torch.cat(
64
- (
65
- self.transformer.position_embeddings(position),
66
- self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
67
- ),
68
- dim=-2
69
- )
70
- return position_embeddings
71
-
72
- def attention_forward(self, hidden_states, mask,
73
- layer_id=None, log_attention_weights=None, **kw_args):
74
- attn_module = self.transformer.layers[layer_id].attention
75
- # base model qkv
76
- mixed_raw_layer = attn_module.query_key_value(hidden_states)
77
- q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer[:, :self.layout[0]], 3)
78
- # cuda2d model qkv
79
- q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer[:, self.layout[0]:], 3)
80
-
81
- dropout_fn = attn_module.attention_dropout if self.training else None
82
-
83
- # cuda2d attention
84
- context_layer = sparse_attention_2d_text(
85
- q0, k0, v0,
86
- q1, k1, v1,
87
- mask,
88
- n_head=attn_module.num_attention_heads_per_partition,
89
- text_len=self.layout[0],
90
- kernel_size=self.kernel_size,
91
- attention_dropout=dropout_fn,
92
- log_attention_weights=log_attention_weights,
93
- )
94
-
95
- output = attn_module.dense(context_layer)
96
-
97
- return output
98
-
99
- def final_forward(self, logits, **kwargs):
100
- logits_parallel = logits
101
- logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]).float()
102
- # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
103
- return logits_parallel
104
-
105
- # def disable_untrainable_params(self):
106
- # self.transformer.requires_grad_(False)
107
-
108
- @classmethod
109
- def add_model_specific_args(cls, parser):
110
- group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
111
- group.add_argument("--kernel-size", type=int, default=5)
112
- group.add_argument("--kernel-size2", type=int, default=5)
113
- group.add_argument("--layout", type=str, default='16,3616')
114
- group.add_argument("--new-sequence-length", type=int, default=4096)
115
- return parser
116
-
117
- def sparse_attention_2d_text(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
118
- '''
119
- q0, k0, v0: [batch_size, 16, hidden_size]
120
- q1, k1, v1: [batch_size, 3600, hidden_size]
121
- n_head: int
122
- attention_mask: [batch_size, 16]
123
- '''
124
- from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
125
- b, s0, h0 = q0.shape
126
- b, s1, h1 = q1.shape
127
- h, l1 = h0 // n_head, sqrt(s1)
128
- assert attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"
129
-
130
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
131
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
132
- k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
133
-
134
- # standard attention for level 0
135
- attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
136
-
137
- attention_scores = torch.mul(attention_scores, attention_mask) - \
138
- 10000.0 * (1.0 - attention_mask)
139
-
140
- attention_probs0 = F.softmax(attention_scores, dim=-1)
141
-
142
- # local attention for level 1
143
- q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
144
- k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
145
- v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
146
- scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
147
-
148
- # cross attention
149
- scores_1_to_0 = torch.matmul(q1.view(b, n_head, h, s1).transpose(-1, -2), k0T)
150
- if log_attention_weights is not None:
151
- scores_1_to_0 += log_attention_weights
152
- scores_1_to_0 = torch.mul(scores_1_to_0, attention_mask) - \
153
- 10000.0 * (1.0 - attention_mask)
154
- scores_1 = torch.cat(
155
- (
156
- scores_1_to_0.view(b*n_head, s1, s0),
157
- scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3])
158
- ),
159
- dim=-1)
160
- attention_probs1 = F.softmax(scores_1, dim=-1)
161
-
162
- if attention_dropout is not None:
163
- with get_cuda_rng_tracker().fork():
164
- attention_probs1 = attention_dropout(attention_probs1)
165
-
166
- # weighting for level 0
167
- context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
168
- # weighting for level 1
169
- probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1)
170
- context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
171
-
172
- context1 = context1_to_1.view(b, n_head, h, l1**2)
173
- # weighting for cross attention
174
- probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view(b, n_head, -1, scores_1_to_0.shape[3])
175
-
176
- context1_to_0 = torch.matmul(probs_1_to_0, v0)
177
- context1 = context1.transpose(-1, -2) + context1_to_0
178
-
179
- output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)
180
-
181
- return output
182
-
183
- def sparse_attention_2d_notext(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
184
- '''
185
- q0, k0, v0: [batch_size, 16, hidden_size]
186
- q1, k1, v1: [batch_size, 3600, hidden_size]
187
- n_head: int
188
- attention_mask: [batch_size, 16]
189
- '''
190
- from SwissArmyTransformer.mpu.local_attention_function import f_similar, f_weighting
191
- b, s0, h0 = q0.shape
192
- b, s1, h1 = q1.shape
193
- h, l1 = h0 // n_head, sqrt(s1)
194
- assert len(attention_mask.shape) == 4 and attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"
195
-
196
- q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
197
- v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
198
- k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
199
-
200
- # standard attention for level 0
201
- attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
202
-
203
- attention_scores = torch.mul(attention_scores, attention_mask) - \
204
- 10000.0 * (1.0 - attention_mask)
205
-
206
- attention_probs0 = F.softmax(attention_scores, dim=-1)
207
-
208
- # local attention for level 1
209
- q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
210
- k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
211
- v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
212
- scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
213
-
214
- attention_probs1 = F.softmax(scores_1_to_1, dim=-1)
215
-
216
- if attention_dropout is not None:
217
- with get_cuda_rng_tracker().fork():
218
- attention_probs1 = attention_dropout(attention_probs1)
219
-
220
- # weighting for level 0
221
- context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h]
222
- # weighting for level 1
223
- probs_1_to_1 = attention_probs1
224
- context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)
225
-
226
- context1 = context1_to_1.view(b, n_head, h, l1**2)
227
- # weighting for cross attention
228
- context1 = context1.transpose(-1, -2)
229
-
230
- output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)
231
-
232
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/itersr_sampling.py DELETED
@@ -1,168 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : itersr_sampling.py
4
- @Time : 2022/03/03 14:24:28
5
- @Author : Ming Ding
6
- @Contact : dm18@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
- import numpy as np
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from icetk import icetk as tokenizer
19
-
20
- def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
21
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
22
- logits[indices_to_remove] = filter_value
23
- return logits
24
-
25
- # class IterativeEntfilterStrategy:
26
- # def __init__(self, invalid_slices=[], temperature=1., topk=10):
27
- # self.invalid_slices = invalid_slices
28
- # self.temperature = temperature
29
- # self.topk = topk
30
- # self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
31
-
32
-
33
- # def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
34
- # # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
35
- # if temperature is None:
36
- # temperature = self.temperature
37
-
38
- # logits = logits_.float() / temperature
39
- # for invalid_slice in self.invalid_slices:
40
- # logits[..., invalid_slice] = -float('Inf')
41
- # logits = logits.view(-1, logits.shape[-1])
42
-
43
- # rprobs = F.softmax(logits.float(), dim=-1)
44
- # c = self.cluster_labels.expand(*rprobs.shape)
45
- # cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
46
-
47
- # best_scores, best_clusters = cprobs.topk(self.topk)
48
- # bz = logits.shape[0]
49
- # best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
50
- # sampled_ids = torch.multinomial(best_scores, num_samples=1)
51
- # selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
52
- # selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500)
53
- # logits[selected_mask] = -65504
54
- # # for i in range(bz):
55
- # # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
56
- # # logits[i, self.cluster_labels != selected_cluster] = -65504
57
-
58
- # # logits = top_k_logits(logits, self.topk, self.top_p)
59
- # probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
60
- # pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
61
-
62
- # assert tokens.shape[1] == pred.shape[1]
63
- # tokens = pred
64
- # return tokens
65
-
66
- class IterativeEntfilterStrategy:
67
- def __init__(self, invalid_slices=[], temperature=1., topk=10):
68
- self.invalid_slices = invalid_slices
69
- self.temperature = temperature
70
- self.topk = topk
71
-
72
- def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
73
- # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
74
- if temperature is None:
75
- temperature = self.temperature
76
- # check entropy filter
77
- # if entfilter is not None:
78
- # assert temperature2 is not None
79
- # topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1)
80
- # ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length]
81
- # temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2
82
-
83
- logits = logits.float() / temperature
84
- for invalid_slice in self.invalid_slices:
85
- logits[..., invalid_slice] = -float('Inf')
86
-
87
- # debiased topk
88
- # probs = F.softmax(logits, dim=-1)
89
- # tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
90
- # pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
91
- # edge_idx = tk_idx[:, :, -1:]
92
- # edge_value = tk_value[:, :, -1:]
93
- # edge_mask = probs.gather(dim=-1, index=pred) < edge_value
94
- # pred[edge_mask] = edge_idx[edge_mask] # replace outliers as the "filter_topk"-th token
95
- # pred.squeeze_(-1) # [batch_size, seq_length]
96
-
97
- top_k_logits_(logits, self.topk)
98
- probs = F.softmax(logits, dim=-1)
99
- pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
100
- pred.squeeze_(-1)
101
-
102
- assert tokens.shape[1] == pred.shape[1]
103
- tokens = pred
104
- return tokens
105
-
106
- def filling_sequence_itersr(
107
- model,
108
- seq0,
109
- seq1,
110
- warmup_steps=3,
111
- block_hw=(4, 4),
112
- strategy=IterativeEntfilterStrategy(topk=10),
113
- ):
114
- '''
115
- seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
116
- 4095 {layout[2]} final_token.
117
- Attention:
118
- The sampling temperature are changing, temporally we hard code them here.
119
- The temperature in the strategy is not used.
120
- '''
121
- assert hasattr(model, 'layout')
122
- layout = model.layout
123
-
124
- device = seq0.device
125
- # concat and pad sequences
126
- batch_size = seq0.shape[0]
127
- n_pad = layout[0] - seq0.shape[1]
128
- assert n_pad >= 0, "You should truncate long input before filling."
129
- seq = torch.cat((
130
- torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
131
- .unsqueeze(0).expand(batch_size, n_pad),
132
- seq0, seq1), dim=1) # [b, layout[-1]+1]
133
- assert seq.shape[1] == layout[-1]
134
-
135
- # build initial tokens, attention_mask, and position_ids
136
- tokens = seq.clone()
137
- attention_mask = torch.ones(layout[0]).to(device)
138
- attention_mask[:n_pad] = 0
139
- attention_mask = attention_mask.unsqueeze(0).type_as(next(model.parameters())) # if fp16
140
- position_ids = torch.cat((
141
- torch.zeros(n_pad, dtype=torch.long),
142
- torch.arange(0, layout[0] - n_pad),
143
- torch.arange(1024, 1024+layout[1]-layout[0]))).to(device)
144
- log_attention_weights = torch.zeros(layout[0], device=device).type_as(next(model.parameters()))
145
- log_attention_weights[n_pad:layout[0]] = 0.
146
- log_attention_weights = log_attention_weights.unsqueeze(0)
147
-
148
- # prepare for interation
149
- unfixed = (tokens == tokenizer['<start_of_image>'])
150
- ll, rr = block_hw
151
- edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
152
- num_steps = 1
153
- # interative refining
154
-
155
- # unfixed[..., -(layout[-1] - layout[-2]):].view(
156
- # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
157
-
158
-
159
- ret = []
160
- # ret.append(tokens[:, layout[-2]:-1].clone())
161
- for step_cnt in range(1, num_steps+1):
162
- logits, *_dump = model(tokens, position_ids, attention_mask, log_attention_weights=log_attention_weights)
163
- real_temp = 1.
164
- new_tokens = strategy.forward(logits, tokens, real_temp)
165
- tokens[unfixed] = new_tokens[unfixed]
166
-
167
- ret.append(tokens[:, layout[-2]:].clone())
168
- return torch.cat(ret, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sr_pipeline/sr_group.py DELETED
@@ -1,49 +0,0 @@
1
- # -*- encoding: utf-8 -*-
2
- '''
3
- @File : sr_group.py
4
- @Time : 2022/04/02 01:17:21
5
- @Author : Ming Ding
6
- @Contact : dm18@mails.tsinghua.edu.cn
7
- '''
8
-
9
- # here put the import lib
10
- import os
11
- import sys
12
- import math
13
- import random
14
-
15
- import numpy as np
16
- import torch
17
- import torch.nn.functional as F
18
- from SwissArmyTransformer.resources import auto_create
19
- from .direct_sr import DirectSuperResolution
20
- from .iterative_sr import IterativeSuperResolution
21
-
22
- class SRGroup:
23
- def __init__(self, args, home_path=None,):
24
- dsr_path = auto_create('cogview2-dsr', path=home_path)
25
- itersr_path = auto_create('cogview2-itersr', path=home_path)
26
- dsr = DirectSuperResolution(args, dsr_path)
27
- itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer)
28
- self.dsr = dsr
29
- self.itersr = itersr
30
-
31
- def sr_base(self, img_tokens, txt_tokens):
32
- assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2
33
- batch_size = img_tokens.shape[0]
34
- txt_len = txt_tokens.shape[-1]
35
- if len(txt_tokens.shape) == 1:
36
- txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
37
- sred_tokens = self.dsr(txt_tokens, img_tokens)
38
- iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone())
39
- return iter_tokens[-batch_size:]
40
-
41
- # def sr_patch(self, img_tokens, txt_tokens):
42
- # assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
43
- # batch_size = img_tokens.shape[0] * 9
44
- # txt_len = txt_tokens.shape[-1]
45
- # if len(txt_tokens.shape) == 1:
46
- # txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
47
- # img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
48
- # iter_tokens = self.sr_base(img_tokens, txt_tokens)
49
- # return iter_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
style.css ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ img#visitor-badge {
5
+ display: block;
6
+ margin: auto;
7
+ }