hysts commited on
Commit
8aac645
1 Parent(s): 55efca8
Files changed (11) hide show
  1. .gitmodules +3 -0
  2. .pre-commit-config.yaml +46 -0
  3. .style.yapf +5 -0
  4. CogVideo +1 -0
  5. app.py +93 -0
  6. icetk_models/.gitkeep +0 -0
  7. model.py +1180 -0
  8. patch +51 -0
  9. pretrained/.gitkeep +0 -0
  10. requirements.txt +10 -0
  11. style.css +7 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ [submodule "CogVideo"]
2
+ path = CogVideo
3
+ url = https://github.com/THUDM/CogVideo
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.812
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ - repo: https://github.com/google/yapf
33
+ rev: v0.32.0
34
+ hooks:
35
+ - id: yapf
36
+ args: ['--parallel', '--in-place']
37
+ - repo: https://github.com/kynan/nbstripout
38
+ rev: 0.5.0
39
+ hooks:
40
+ - id: nbstripout
41
+ args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
+ - repo: https://github.com/nbQA-dev/nbQA
43
+ rev: 1.3.1
44
+ hooks:
45
+ - id: nbqa-isort
46
+ - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
CogVideo ADDED
@@ -0,0 +1 @@
 
1
+ Subproject commit ff423aa169978fb2f636f761e348631fa3178b03
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+
7
+ import gradio as gr
8
+
9
+ from model import AppModel
10
+
11
+ DESCRIPTION = '''# <a href="https://github.com/THUDM/CogVideo">CogVideo</a>
12
+
13
+ The model takes only Chinese as input.
14
+ If you check the "Translate to Chinese" checkbox, the app will use the English to Chinese translation results with [this Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) as input.
15
+ But the translation model may mistranslate and the results could be poor.
16
+ So, it is also a good idea to input the translation results from other translation services.
17
+ '''
18
+
19
+
20
+ def parse_args() -> argparse.Namespace:
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument('--only-first-stage', action='store_true')
23
+ parser.add_argument('--share', action='store_true')
24
+ return parser.parse_args()
25
+
26
+
27
+ def set_example_text(example: list) -> dict:
28
+ return gr.Textbox.update(value=example[0])
29
+
30
+
31
+ def main():
32
+ args = parse_args()
33
+ model = AppModel(args.only_first_stage)
34
+
35
+ with gr.Blocks(css='style.css') as demo:
36
+ gr.Markdown(DESCRIPTION)
37
+
38
+ with gr.Row():
39
+ with gr.Column():
40
+ with gr.Group():
41
+ text = gr.Textbox(label='Input Text')
42
+ translate = gr.Checkbox(label='Translate to Chinese',
43
+ value=False)
44
+ seed = gr.Slider(0,
45
+ 100000,
46
+ step=1,
47
+ value=1234,
48
+ label='Seed')
49
+ only_first_stage = gr.Checkbox(
50
+ label='Only First Stage',
51
+ value=args.only_first_stage,
52
+ visible=not args.only_first_stage)
53
+ run_button = gr.Button('Run')
54
+
55
+ with open('samples.txt') as f:
56
+ samples = [
57
+ line.strip().split('\t') for line in f.readlines()
58
+ ]
59
+ examples = gr.Dataset(components=[text], samples=samples)
60
+
61
+ with gr.Column():
62
+ with gr.Group():
63
+ translated_text = gr.Textbox(label='Translated Text')
64
+ with gr.Tabs():
65
+ with gr.TabItem('Output (Video)'):
66
+ result_video = gr.Video(show_label=False)
67
+ with gr.TabItem('Output (Gallery)'):
68
+ result_gallery = gr.Gallery(show_label=False)
69
+
70
+ run_button.click(fn=model.run_with_translation,
71
+ inputs=[
72
+ text,
73
+ translate,
74
+ seed,
75
+ only_first_stage,
76
+ ],
77
+ outputs=[
78
+ translated_text,
79
+ result_video,
80
+ result_gallery,
81
+ ])
82
+ examples.click(fn=set_example_text,
83
+ inputs=examples,
84
+ outputs=examples.components)
85
+
86
+ demo.launch(
87
+ enable_queue=True,
88
+ share=args.share,
89
+ )
90
+
91
+
92
+ if __name__ == '__main__':
93
+ main()
icetk_models/.gitkeep ADDED
File without changes
model.py ADDED
@@ -0,0 +1,1180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is adapted from https://github.com/THUDM/CogView2/blob/4e55cce981eb94b9c8c1f19ba9f632fd3ee42ba8/cogview2_text2image.py
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import logging
8
+ import pathlib
9
+ import sys
10
+ import tempfile
11
+ import time
12
+ from typing import Any
13
+
14
+ import gradio as gr
15
+ import imageio.v2 as iio
16
+ import numpy as np
17
+ import torch
18
+ from icetk import IceTokenizer
19
+ from SwissArmyTransformer import get_args
20
+ from SwissArmyTransformer.arguments import set_random_seed
21
+ from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
22
+ from SwissArmyTransformer.resources import auto_create
23
+
24
+ app_dir = pathlib.Path(__file__).parent
25
+ submodule_dir = app_dir / 'CogVideo'
26
+ sys.path.insert(0, submodule_dir.as_posix())
27
+
28
+ from coglm_strategy import CoglmStrategy
29
+ from models.cogvideo_cache_model import CogVideoCacheModel
30
+ from sr_pipeline import DirectSuperResolution
31
+
32
+ formatter = logging.Formatter(
33
+ '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
34
+ datefmt='%Y-%m-%d %H:%M:%S')
35
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
36
+ stream_handler.setLevel(logging.INFO)
37
+ stream_handler.setFormatter(formatter)
38
+ logger = logging.getLogger(__name__)
39
+ logger.setLevel(logging.INFO)
40
+ logger.propagate = False
41
+ logger.addHandler(stream_handler)
42
+
43
+ ICETK_MODEL_DIR = app_dir / 'icetk_models'
44
+
45
+
46
+ def get_masks_and_position_ids_stage1(data, textlen, framelen):
47
+ # Extract batch size and sequence length.
48
+ tokens = data
49
+ seq_length = len(data[0])
50
+ # Attention mask (lower triangular).
51
+ attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
52
+ device=data.device)
53
+ attention_mask[:, :textlen, textlen:] = 0
54
+ attention_mask[:, textlen:, textlen:].tril_()
55
+ attention_mask.unsqueeze_(1)
56
+ # Unaligned version
57
+ position_ids = torch.zeros(seq_length,
58
+ dtype=torch.long,
59
+ device=data.device)
60
+ torch.arange(textlen,
61
+ out=position_ids[:textlen],
62
+ dtype=torch.long,
63
+ device=data.device)
64
+ torch.arange(512,
65
+ 512 + seq_length - textlen,
66
+ out=position_ids[textlen:],
67
+ dtype=torch.long,
68
+ device=data.device)
69
+ position_ids = position_ids.unsqueeze(0)
70
+
71
+ return tokens, attention_mask, position_ids
72
+
73
+
74
+ def get_masks_and_position_ids_stage2(data, textlen, framelen):
75
+ # Extract batch size and sequence length.
76
+ tokens = data
77
+ seq_length = len(data[0])
78
+
79
+ # Attention mask (lower triangular).
80
+ attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
81
+ device=data.device)
82
+ attention_mask[:, :textlen, textlen:] = 0
83
+ attention_mask[:, textlen:, textlen:].tril_()
84
+ attention_mask.unsqueeze_(1)
85
+
86
+ # Unaligned version
87
+ position_ids = torch.zeros(seq_length,
88
+ dtype=torch.long,
89
+ device=data.device)
90
+ torch.arange(textlen,
91
+ out=position_ids[:textlen],
92
+ dtype=torch.long,
93
+ device=data.device)
94
+ frame_num = (seq_length - textlen) // framelen
95
+ assert frame_num == 5
96
+ torch.arange(512,
97
+ 512 + framelen,
98
+ out=position_ids[textlen:textlen + framelen],
99
+ dtype=torch.long,
100
+ device=data.device)
101
+ torch.arange(512 + framelen * 2,
102
+ 512 + framelen * 3,
103
+ out=position_ids[textlen + framelen:textlen + framelen * 2],
104
+ dtype=torch.long,
105
+ device=data.device)
106
+ torch.arange(512 + framelen * (frame_num - 1),
107
+ 512 + framelen * frame_num,
108
+ out=position_ids[textlen + framelen * 2:textlen +
109
+ framelen * 3],
110
+ dtype=torch.long,
111
+ device=data.device)
112
+ torch.arange(512 + framelen * 1,
113
+ 512 + framelen * 2,
114
+ out=position_ids[textlen + framelen * 3:textlen +
115
+ framelen * 4],
116
+ dtype=torch.long,
117
+ device=data.device)
118
+ torch.arange(512 + framelen * 3,
119
+ 512 + framelen * 4,
120
+ out=position_ids[textlen + framelen * 4:textlen +
121
+ framelen * 5],
122
+ dtype=torch.long,
123
+ device=data.device)
124
+
125
+ position_ids = position_ids.unsqueeze(0)
126
+
127
+ return tokens, attention_mask, position_ids
128
+
129
+
130
+ def my_update_mems(hiddens, mems_buffers, mems_indexs,
131
+ limited_spatial_channel_mem, text_len, frame_len):
132
+ if hiddens is None:
133
+ return None, mems_indexs
134
+ mem_num = len(hiddens)
135
+ ret_mem = []
136
+ with torch.no_grad():
137
+ for id in range(mem_num):
138
+ if hiddens[id][0] is None:
139
+ ret_mem.append(None)
140
+ else:
141
+ if id == 0 and limited_spatial_channel_mem and mems_indexs[
142
+ id] + hiddens[0][0].shape[1] >= text_len + frame_len:
143
+ if mems_indexs[id] == 0:
144
+ for layer, hidden in enumerate(hiddens[id]):
145
+ mems_buffers[id][
146
+ layer, :, :text_len] = hidden.expand(
147
+ mems_buffers[id].shape[1], -1,
148
+ -1)[:, :text_len]
149
+ new_mem_len_part2 = (mems_indexs[id] +
150
+ hiddens[0][0].shape[1] -
151
+ text_len) % frame_len
152
+ if new_mem_len_part2 > 0:
153
+ for layer, hidden in enumerate(hiddens[id]):
154
+ mems_buffers[id][
155
+ layer, :, text_len:text_len +
156
+ new_mem_len_part2] = hidden.expand(
157
+ mems_buffers[id].shape[1], -1,
158
+ -1)[:, -new_mem_len_part2:]
159
+ mems_indexs[id] = text_len + new_mem_len_part2
160
+ else:
161
+ for layer, hidden in enumerate(hiddens[id]):
162
+ mems_buffers[id][layer, :,
163
+ mems_indexs[id]:mems_indexs[id] +
164
+ hidden.shape[1]] = hidden.expand(
165
+ mems_buffers[id].shape[1], -1, -1)
166
+ mems_indexs[id] += hidden.shape[1]
167
+ ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]])
168
+ return ret_mem, mems_indexs
169
+
170
+
171
+ def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
172
+ # The fisrt token's position id of the frame that the next token belongs to;
173
+ if total_len < text_len:
174
+ return None
175
+ return (total_len - text_len) // frame_len * frame_len + text_len
176
+
177
+
178
+ def my_filling_sequence(
179
+ model,
180
+ tokenizer,
181
+ args,
182
+ seq,
183
+ batch_size,
184
+ get_masks_and_position_ids,
185
+ text_len,
186
+ frame_len,
187
+ strategy=BaseStrategy(),
188
+ strategy2=BaseStrategy(),
189
+ mems=None,
190
+ log_text_attention_weights=0, # default to 0: no artificial change
191
+ mode_stage1=True,
192
+ enforce_no_swin=False,
193
+ guider_seq=None,
194
+ guider_text_len=0,
195
+ guidance_alpha=1,
196
+ limited_spatial_channel_mem=False, # 空间通道的存储限制在本帧内
197
+ **kw_args):
198
+ '''
199
+ seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
200
+ mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
201
+ cache, should be first mems.shape[1] parts of context_tokens.
202
+ mems are the first-level citizens here, but we don't assume what is memorized.
203
+ input mems are used when multi-phase generation.
204
+ '''
205
+ if guider_seq is not None:
206
+ logger.debug('Using Guidance In Inference')
207
+ if limited_spatial_channel_mem:
208
+ logger.debug("Limit spatial-channel's mem to current frame")
209
+ assert len(seq.shape) == 2
210
+
211
+ # building the initial tokens, attention_mask, and position_ids
212
+ actual_context_length = 0
213
+
214
+ while seq[-1][
215
+ actual_context_length] >= 0: # the last seq has least given tokens
216
+ actual_context_length += 1 # [0, context_length-1] are given
217
+ assert actual_context_length > 0
218
+ current_frame_num = (actual_context_length - text_len) // frame_len
219
+ assert current_frame_num >= 0
220
+ context_length = text_len + current_frame_num * frame_len
221
+
222
+ tokens, attention_mask, position_ids = get_masks_and_position_ids(
223
+ seq, text_len, frame_len)
224
+ tokens = tokens[..., :context_length]
225
+ input_tokens = tokens.clone()
226
+
227
+ if guider_seq is not None:
228
+ guider_index_delta = text_len - guider_text_len
229
+ guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(
230
+ guider_seq, guider_text_len, frame_len)
231
+ guider_tokens = guider_tokens[..., :context_length -
232
+ guider_index_delta]
233
+ guider_input_tokens = guider_tokens.clone()
234
+
235
+ for fid in range(current_frame_num):
236
+ input_tokens[:, text_len + 400 * fid] = tokenizer['<start_of_image>']
237
+ if guider_seq is not None:
238
+ guider_input_tokens[:, guider_text_len +
239
+ 400 * fid] = tokenizer['<start_of_image>']
240
+
241
+ attention_mask = attention_mask.type_as(next(
242
+ model.parameters())) # if fp16
243
+ # initialize generation
244
+ counter = context_length - 1 # Last fixed index is ``counter''
245
+ index = 0 # Next forward starting index, also the length of cache.
246
+ mems_buffers_on_GPU = False
247
+ mems_indexs = [0, 0]
248
+ mems_len = [(400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74,
249
+ 5 * 400 + 74]
250
+ mems_buffers = [
251
+ torch.zeros(args.num_layers,
252
+ batch_size,
253
+ mem_len,
254
+ args.hidden_size * 2,
255
+ dtype=next(model.parameters()).dtype)
256
+ for mem_len in mems_len
257
+ ]
258
+
259
+ if guider_seq is not None:
260
+ guider_attention_mask = guider_attention_mask.type_as(
261
+ next(model.parameters())) # if fp16
262
+ guider_mems_buffers = [
263
+ torch.zeros(args.num_layers,
264
+ batch_size,
265
+ mem_len,
266
+ args.hidden_size * 2,
267
+ dtype=next(model.parameters()).dtype)
268
+ for mem_len in mems_len
269
+ ]
270
+ guider_mems_indexs = [0, 0]
271
+ guider_mems = None
272
+
273
+ torch.cuda.empty_cache()
274
+ # step-by-step generation
275
+ while counter < len(seq[0]) - 1:
276
+ # we have generated counter+1 tokens
277
+ # Now, we want to generate seq[counter + 1],
278
+ # token[:, index: counter+1] needs forwarding.
279
+ if index == 0:
280
+ group_size = 2 if (input_tokens.shape[0] == batch_size
281
+ and not mode_stage1) else batch_size
282
+
283
+ logits_all = None
284
+ for batch_idx in range(0, input_tokens.shape[0], group_size):
285
+ logits, *output_per_layers = model(
286
+ input_tokens[batch_idx:batch_idx + group_size, index:],
287
+ position_ids[..., index:counter + 1],
288
+ attention_mask, # TODO memlen
289
+ mems=mems,
290
+ text_len=text_len,
291
+ frame_len=frame_len,
292
+ counter=counter,
293
+ log_text_attention_weights=log_text_attention_weights,
294
+ enforce_no_swin=enforce_no_swin,
295
+ **kw_args)
296
+ logits_all = torch.cat(
297
+ (logits_all,
298
+ logits), dim=0) if logits_all is not None else logits
299
+ mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers],
300
+ [o['mem_kv'][1] for o in output_per_layers]]
301
+ next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
302
+ text_len, frame_len, mem_kv01[0][0].shape[1])
303
+ for id, mem_kv in enumerate(mem_kv01):
304
+ for layer, mem_kv_perlayer in enumerate(mem_kv):
305
+ if limited_spatial_channel_mem and id == 0:
306
+ mems_buffers[id][
307
+ layer, batch_idx:batch_idx + group_size, :
308
+ text_len] = mem_kv_perlayer.expand(
309
+ min(group_size,
310
+ input_tokens.shape[0] - batch_idx), -1,
311
+ -1)[:, :text_len]
312
+ mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\
313
+ mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:]
314
+ else:
315
+ mems_buffers[id][
316
+ layer, batch_idx:batch_idx +
317
+ group_size, :mem_kv_perlayer.
318
+ shape[1]] = mem_kv_perlayer.expand(
319
+ min(group_size,
320
+ input_tokens.shape[0] - batch_idx), -1,
321
+ -1)
322
+ mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[
323
+ 1], mem_kv01[1][0].shape[1]
324
+ if limited_spatial_channel_mem:
325
+ mems_indexs[0] -= (next_tokens_frame_begin_id - text_len)
326
+
327
+ mems = [
328
+ mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)
329
+ ]
330
+ logits = logits_all
331
+
332
+ # Guider
333
+ if guider_seq is not None:
334
+ guider_logits_all = None
335
+ for batch_idx in range(0, guider_input_tokens.shape[0],
336
+ group_size):
337
+ guider_logits, *guider_output_per_layers = model(
338
+ guider_input_tokens[batch_idx:batch_idx + group_size,
339
+ max(index -
340
+ guider_index_delta, 0):],
341
+ guider_position_ids[
342
+ ...,
343
+ max(index - guider_index_delta, 0):counter + 1 -
344
+ guider_index_delta],
345
+ guider_attention_mask,
346
+ mems=guider_mems,
347
+ text_len=guider_text_len,
348
+ frame_len=frame_len,
349
+ counter=counter - guider_index_delta,
350
+ log_text_attention_weights=log_text_attention_weights,
351
+ enforce_no_swin=enforce_no_swin,
352
+ **kw_args)
353
+ guider_logits_all = torch.cat(
354
+ (guider_logits_all, guider_logits), dim=0
355
+ ) if guider_logits_all is not None else guider_logits
356
+ guider_mem_kv01 = [[
357
+ o['mem_kv'][0] for o in guider_output_per_layers
358
+ ], [o['mem_kv'][1] for o in guider_output_per_layers]]
359
+ for id, guider_mem_kv in enumerate(guider_mem_kv01):
360
+ for layer, guider_mem_kv_perlayer in enumerate(
361
+ guider_mem_kv):
362
+ if limited_spatial_channel_mem and id == 0:
363
+ guider_mems_buffers[id][
364
+ layer, batch_idx:batch_idx + group_size, :
365
+ guider_text_len] = guider_mem_kv_perlayer.expand(
366
+ min(group_size,
367
+ input_tokens.shape[0] - batch_idx),
368
+ -1, -1)[:, :guider_text_len]
369
+ guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
370
+ guider_text_len, frame_len,
371
+ guider_mem_kv_perlayer.shape[1])
372
+ guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\
373
+ guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:]
374
+ else:
375
+ guider_mems_buffers[id][
376
+ layer, batch_idx:batch_idx +
377
+ group_size, :guider_mem_kv_perlayer.
378
+ shape[1]] = guider_mem_kv_perlayer.expand(
379
+ min(group_size,
380
+ input_tokens.shape[0] - batch_idx),
381
+ -1, -1)
382
+ guider_mems_indexs[0], guider_mems_indexs[
383
+ 1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[
384
+ 1][0].shape[1]
385
+ if limited_spatial_channel_mem:
386
+ guider_mems_indexs[0] -= (
387
+ guider_next_tokens_frame_begin_id -
388
+ guider_text_len)
389
+ guider_mems = [
390
+ guider_mems_buffers[id][:, :, :guider_mems_indexs[id]]
391
+ for id in range(2)
392
+ ]
393
+ guider_logits = guider_logits_all
394
+ else:
395
+ if not mems_buffers_on_GPU:
396
+ if not mode_stage1:
397
+ torch.cuda.empty_cache()
398
+ for idx, mem in enumerate(mems):
399
+ mems[idx] = mem.to(next(model.parameters()).device)
400
+ if guider_seq is not None:
401
+ for idx, mem in enumerate(guider_mems):
402
+ guider_mems[idx] = mem.to(
403
+ next(model.parameters()).device)
404
+ else:
405
+ torch.cuda.empty_cache()
406
+ for idx, mem_buffer in enumerate(mems_buffers):
407
+ mems_buffers[idx] = mem_buffer.to(
408
+ next(model.parameters()).device)
409
+ mems = [
410
+ mems_buffers[id][:, :, :mems_indexs[id]]
411
+ for id in range(2)
412
+ ]
413
+ if guider_seq is not None:
414
+ for idx, guider_mem_buffer in enumerate(
415
+ guider_mems_buffers):
416
+ guider_mems_buffers[idx] = guider_mem_buffer.to(
417
+ next(model.parameters()).device)
418
+ guider_mems = [
419
+ guider_mems_buffers[id]
420
+ [:, :, :guider_mems_indexs[id]] for id in range(2)
421
+ ]
422
+ mems_buffers_on_GPU = True
423
+
424
+ logits, *output_per_layers = model(
425
+ input_tokens[:, index:],
426
+ position_ids[..., index:counter + 1],
427
+ attention_mask, # TODO memlen
428
+ mems=mems,
429
+ text_len=text_len,
430
+ frame_len=frame_len,
431
+ counter=counter,
432
+ log_text_attention_weights=log_text_attention_weights,
433
+ enforce_no_swin=enforce_no_swin,
434
+ limited_spatial_channel_mem=limited_spatial_channel_mem,
435
+ **kw_args)
436
+ mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers
437
+ ], [o['mem_kv'][1] for o in output_per_layers]
438
+
439
+ if guider_seq is not None:
440
+ guider_logits, *guider_output_per_layers = model(
441
+ guider_input_tokens[:,
442
+ max(index - guider_index_delta, 0):],
443
+ guider_position_ids[...,
444
+ max(index -
445
+ guider_index_delta, 0):counter +
446
+ 1 - guider_index_delta],
447
+ guider_attention_mask,
448
+ mems=guider_mems,
449
+ text_len=guider_text_len,
450
+ frame_len=frame_len,
451
+ counter=counter - guider_index_delta,
452
+ log_text_attention_weights=0,
453
+ enforce_no_swin=enforce_no_swin,
454
+ limited_spatial_channel_mem=limited_spatial_channel_mem,
455
+ **kw_args)
456
+ guider_mem_kv0, guider_mem_kv1 = [
457
+ o['mem_kv'][0] for o in guider_output_per_layers
458
+ ], [o['mem_kv'][1] for o in guider_output_per_layers]
459
+
460
+ if not mems_buffers_on_GPU:
461
+ torch.cuda.empty_cache()
462
+ for idx, mem_buffer in enumerate(mems_buffers):
463
+ mems_buffers[idx] = mem_buffer.to(
464
+ next(model.parameters()).device)
465
+ if guider_seq is not None:
466
+ for idx, guider_mem_buffer in enumerate(
467
+ guider_mems_buffers):
468
+ guider_mems_buffers[idx] = guider_mem_buffer.to(
469
+ next(model.parameters()).device)
470
+ mems_buffers_on_GPU = True
471
+
472
+ mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1],
473
+ mems_buffers, mems_indexs,
474
+ limited_spatial_channel_mem,
475
+ text_len, frame_len)
476
+ if guider_seq is not None:
477
+ guider_mems, guider_mems_indexs = my_update_mems(
478
+ [guider_mem_kv0, guider_mem_kv1], guider_mems_buffers,
479
+ guider_mems_indexs, limited_spatial_channel_mem,
480
+ guider_text_len, frame_len)
481
+
482
+ counter += 1
483
+ index = counter
484
+
485
+ logits = logits[:, -1].expand(batch_size,
486
+ -1) # [batch size, vocab size]
487
+ tokens = tokens.expand(batch_size, -1)
488
+ if guider_seq is not None:
489
+ guider_logits = guider_logits[:, -1].expand(batch_size, -1)
490
+ guider_tokens = guider_tokens.expand(batch_size, -1)
491
+
492
+ if seq[-1][counter].item() < 0:
493
+ # sampling
494
+ guided_logits = guider_logits + (
495
+ logits - guider_logits
496
+ ) * guidance_alpha if guider_seq is not None else logits
497
+ if mode_stage1 and counter < text_len + 400:
498
+ tokens, mems = strategy.forward(guided_logits, tokens, mems)
499
+ else:
500
+ tokens, mems = strategy2.forward(guided_logits, tokens, mems)
501
+ if guider_seq is not None:
502
+ guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]),
503
+ dim=1)
504
+
505
+ if seq[0][counter].item() >= 0:
506
+ for si in range(seq.shape[0]):
507
+ if seq[si][counter].item() >= 0:
508
+ tokens[si, -1] = seq[si, counter]
509
+ if guider_seq is not None:
510
+ guider_tokens[si,
511
+ -1] = guider_seq[si, counter -
512
+ guider_index_delta]
513
+
514
+ else:
515
+ tokens = torch.cat(
516
+ (tokens, seq[:, counter:counter + 1].clone().expand(
517
+ tokens.shape[0], 1).to(device=tokens.device,
518
+ dtype=tokens.dtype)),
519
+ dim=1)
520
+ if guider_seq is not None:
521
+ guider_tokens = torch.cat(
522
+ (guider_tokens,
523
+ guider_seq[:, counter - guider_index_delta:counter + 1 -
524
+ guider_index_delta].clone().expand(
525
+ guider_tokens.shape[0], 1).to(
526
+ device=guider_tokens.device,
527
+ dtype=guider_tokens.dtype)),
528
+ dim=1)
529
+
530
+ input_tokens = tokens.clone()
531
+ if guider_seq is not None:
532
+ guider_input_tokens = guider_tokens.clone()
533
+ if (index - text_len - 1) // 400 < (input_tokens.shape[-1] - text_len -
534
+ 1) // 400:
535
+ boi_idx = ((index - text_len - 1) // 400 + 1) * 400 + text_len
536
+ while boi_idx < input_tokens.shape[-1]:
537
+ input_tokens[:, boi_idx] = tokenizer['<start_of_image>']
538
+ if guider_seq is not None:
539
+ guider_input_tokens[:, boi_idx -
540
+ guider_index_delta] = tokenizer[
541
+ '<start_of_image>']
542
+ boi_idx += 400
543
+
544
+ if strategy.is_done:
545
+ break
546
+ return strategy.finalize(tokens, mems)
547
+
548
+
549
+ class InferenceModel_Sequential(CogVideoCacheModel):
550
+ def __init__(self, args, transformer=None, parallel_output=True):
551
+ super().__init__(args,
552
+ transformer=transformer,
553
+ parallel_output=parallel_output,
554
+ window_size=-1,
555
+ cogvideo_stage=1)
556
+
557
+ # TODO: check it
558
+
559
+ def final_forward(self, logits, **kwargs):
560
+ logits_parallel = logits
561
+ logits_parallel = torch.nn.functional.linear(
562
+ logits_parallel.float(),
563
+ self.transformer.word_embeddings.weight[:20000].float())
564
+ return logits_parallel
565
+
566
+
567
+ class InferenceModel_Interpolate(CogVideoCacheModel):
568
+ def __init__(self, args, transformer=None, parallel_output=True):
569
+ super().__init__(args,
570
+ transformer=transformer,
571
+ parallel_output=parallel_output,
572
+ window_size=10,
573
+ cogvideo_stage=2)
574
+
575
+ # TODO: check it
576
+
577
+ def final_forward(self, logits, **kwargs):
578
+ logits_parallel = logits
579
+ logits_parallel = torch.nn.functional.linear(
580
+ logits_parallel.float(),
581
+ self.transformer.word_embeddings.weight[:20000].float())
582
+ return logits_parallel
583
+
584
+
585
+ def get_default_args() -> argparse.Namespace:
586
+ known = argparse.Namespace(generate_frame_num=5,
587
+ coglm_temperature2=0.89,
588
+ use_guidance_stage1=True,
589
+ use_guidance_stage2=False,
590
+ guidance_alpha=3.0,
591
+ stage_1=True,
592
+ stage_2=False,
593
+ both_stages=False,
594
+ parallel_size=1,
595
+ stage1_max_inference_batch_size=-1,
596
+ multi_gpu=False,
597
+ layout='64, 464, 2064',
598
+ window_size=10,
599
+ additional_seqlen=2000,
600
+ cogvideo_stage=1)
601
+
602
+ args_list = [
603
+ '--tokenizer-type',
604
+ 'fake',
605
+ '--mode',
606
+ 'inference',
607
+ '--distributed-backend',
608
+ 'nccl',
609
+ '--fp16',
610
+ '--model-parallel-size',
611
+ '1',
612
+ '--temperature',
613
+ '1.05',
614
+ '--top_k',
615
+ '12',
616
+ '--sandwich-ln',
617
+ '--seed',
618
+ '1234',
619
+ '--num-workers',
620
+ '0',
621
+ '--batch-size',
622
+ '1',
623
+ '--max-inference-batch-size',
624
+ '8',
625
+ ]
626
+ args = get_args(args_list)
627
+ args = argparse.Namespace(**vars(args), **vars(known))
628
+ args.layout = [int(x) for x in args.layout.split(',')]
629
+ args.do_train = False
630
+ return args
631
+
632
+
633
+ class Model:
634
+ def __init__(self, only_first_stage: bool = False):
635
+ self.args = get_default_args()
636
+ if only_first_stage:
637
+ self.args.stage_1 = True
638
+ self.args.both_stages = False
639
+ else:
640
+ self.args.stage_1 = False
641
+ self.args.both_stages = True
642
+
643
+ self.tokenizer = self.load_tokenizer()
644
+
645
+ self.model_stage1, self.args = self.load_model_stage1()
646
+ self.model_stage2, self.args = self.load_model_stage2()
647
+
648
+ self.strategy_cogview2, self.strategy_cogvideo = self.load_strategies()
649
+ self.dsr = self.load_dsr()
650
+
651
+ self.device = torch.device(self.args.device)
652
+
653
+ def load_tokenizer(self) -> IceTokenizer:
654
+ logger.info('--- load_tokenizer ---')
655
+ start = time.perf_counter()
656
+
657
+ tokenizer = IceTokenizer(ICETK_MODEL_DIR.as_posix())
658
+ tokenizer.add_special_tokens(
659
+ ['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
660
+
661
+ elapsed = time.perf_counter() - start
662
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
663
+ return tokenizer
664
+
665
+ def load_model_stage1(
666
+ self) -> tuple[CogVideoCacheModel, argparse.Namespace]:
667
+ logger.info('--- load_model_stage1 ---')
668
+ start = time.perf_counter()
669
+
670
+ args = self.args
671
+ model_stage1, args = InferenceModel_Sequential.from_pretrained(
672
+ args, 'cogvideo-stage1')
673
+ model_stage1.eval()
674
+ if args.both_stages:
675
+ model_stage1 = model_stage1.cpu()
676
+
677
+ elapsed = time.perf_counter() - start
678
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
679
+ return model_stage1, args
680
+
681
+ def load_model_stage2(
682
+ self) -> tuple[CogVideoCacheModel | None, argparse.Namespace]:
683
+ logger.info('--- load_model_stage2 ---')
684
+ start = time.perf_counter()
685
+
686
+ args = self.args
687
+ if args.both_stages:
688
+ model_stage2, args = InferenceModel_Interpolate.from_pretrained(
689
+ args, 'cogvideo-stage2')
690
+ model_stage2.eval()
691
+ if args.both_stages:
692
+ model_stage2 = model_stage2.cpu()
693
+ else:
694
+ model_stage2 = None
695
+
696
+ elapsed = time.perf_counter() - start
697
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
698
+ return model_stage2, args
699
+
700
+ def load_strategies(self) -> tuple[CoglmStrategy, CoglmStrategy]:
701
+ logger.info('--- load_strategies ---')
702
+ start = time.perf_counter()
703
+
704
+ invalid_slices = [slice(self.tokenizer.num_image_tokens, None)]
705
+ strategy_cogview2 = CoglmStrategy(invalid_slices,
706
+ temperature=1.0,
707
+ top_k=16)
708
+ strategy_cogvideo = CoglmStrategy(
709
+ invalid_slices,
710
+ temperature=self.args.temperature,
711
+ top_k=self.args.top_k,
712
+ temperature2=self.args.coglm_temperature2)
713
+
714
+ elapsed = time.perf_counter() - start
715
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
716
+ return strategy_cogview2, strategy_cogvideo
717
+
718
+ def load_dsr(self) -> DirectSuperResolution | None:
719
+ logger.info('--- load_dsr ---')
720
+ start = time.perf_counter()
721
+
722
+ if self.args.both_stages:
723
+ path = auto_create('cogview2-dsr', path=None)
724
+ dsr = DirectSuperResolution(self.args,
725
+ path,
726
+ max_bz=12,
727
+ onCUDA=False)
728
+ else:
729
+ dsr = None
730
+
731
+ elapsed = time.perf_counter() - start
732
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
733
+ return dsr
734
+
735
+ @torch.inference_mode()
736
+ def process_stage1(self,
737
+ model,
738
+ seq_text,
739
+ duration,
740
+ video_raw_text=None,
741
+ video_guidance_text='视频',
742
+ image_text_suffix='',
743
+ batch_size=1):
744
+ process_start_time = time.perf_counter()
745
+
746
+ generate_frame_num = self.args.generate_frame_num
747
+ tokenizer = self.tokenizer
748
+ use_guide = self.args.use_guidance_stage1
749
+
750
+ if next(model.parameters()).device != self.device:
751
+ move_start_time = time.perf_counter()
752
+ logger.debug('moving stage 1 model to cuda')
753
+
754
+ model = model.to(self.device)
755
+
756
+ elapsed = time.perf_counter() - move_start_time
757
+ logger.debug(f'moving in model1 takes time: {elapsed:.2f}')
758
+
759
+ if video_raw_text is None:
760
+ video_raw_text = seq_text
761
+ mbz = self.args.stage1_max_inference_batch_size if self.args.stage1_max_inference_batch_size > 0 else self.args.max_inference_batch_size
762
+ assert batch_size < mbz or batch_size % mbz == 0
763
+ frame_len = 400
764
+
765
+ # generate the first frame:
766
+ enc_text = tokenizer.encode(seq_text + image_text_suffix)
767
+ seq_1st = enc_text + [tokenizer['<start_of_image>']] + [-1] * 400
768
+ logger.info(
769
+ f'[Generating First Frame with CogView2] Raw text: {tokenizer.decode(enc_text):s}'
770
+ )
771
+ text_len_1st = len(seq_1st) - frame_len * 1 - 1
772
+
773
+ seq_1st = torch.tensor(seq_1st, dtype=torch.long,
774
+ device=self.device).unsqueeze(0)
775
+ output_list_1st = []
776
+ for tim in range(max(batch_size // mbz, 1)):
777
+ start_time = time.perf_counter()
778
+ output_list_1st.append(
779
+ my_filling_sequence(
780
+ model,
781
+ tokenizer,
782
+ self.args,
783
+ seq_1st.clone(),
784
+ batch_size=min(batch_size, mbz),
785
+ get_masks_and_position_ids=
786
+ get_masks_and_position_ids_stage1,
787
+ text_len=text_len_1st,
788
+ frame_len=frame_len,
789
+ strategy=self.strategy_cogview2,
790
+ strategy2=self.strategy_cogvideo,
791
+ log_text_attention_weights=1.4,
792
+ enforce_no_swin=True,
793
+ mode_stage1=True,
794
+ )[0])
795
+ elapsed = time.perf_counter() - start_time
796
+ logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
797
+ output_tokens_1st = torch.cat(output_list_1st, dim=0)
798
+ given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
799
+ 401].unsqueeze(
800
+ 1
801
+ ) # given_tokens.shape: [bs, frame_num, 400]
802
+
803
+ # generate subsequent frames:
804
+ total_frames = generate_frame_num
805
+ enc_duration = tokenizer.encode(f'{float(duration)}秒')
806
+ if use_guide:
807
+ video_raw_text = video_raw_text + ' 视频'
808
+ enc_text_video = tokenizer.encode(video_raw_text)
809
+ seq = enc_duration + [tokenizer['<n>']] + enc_text_video + [
810
+ tokenizer['<start_of_image>']
811
+ ] + [-1] * 400 * generate_frame_num
812
+ guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(
813
+ video_guidance_text) + [tokenizer['<start_of_image>']
814
+ ] + [-1] * 400 * generate_frame_num
815
+ logger.info(
816
+ f'[Stage1: Generating Subsequent Frames, Frame Rate {4/duration:.1f}] raw text: {tokenizer.decode(enc_text_video):s}'
817
+ )
818
+
819
+ text_len = len(seq) - frame_len * generate_frame_num - 1
820
+ guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
821
+ seq = torch.tensor(seq, dtype=torch.long,
822
+ device=self.device).unsqueeze(0).repeat(
823
+ batch_size, 1)
824
+ guider_seq = torch.tensor(guider_seq,
825
+ dtype=torch.long,
826
+ device=self.device).unsqueeze(0).repeat(
827
+ batch_size, 1)
828
+
829
+ for given_frame_id in range(given_tokens.shape[1]):
830
+ seq[:, text_len + 1 + given_frame_id * 400:text_len + 1 +
831
+ (given_frame_id + 1) * 400] = given_tokens[:, given_frame_id]
832
+ guider_seq[:, guider_text_len + 1 +
833
+ given_frame_id * 400:guider_text_len + 1 +
834
+ (given_frame_id + 1) *
835
+ 400] = given_tokens[:, given_frame_id]
836
+ output_list = []
837
+
838
+ if use_guide:
839
+ video_log_text_attention_weights = 0
840
+ else:
841
+ guider_seq = None
842
+ video_log_text_attention_weights = 1.4
843
+
844
+ for tim in range(max(batch_size // mbz, 1)):
845
+ input_seq = seq[:min(batch_size, mbz)].clone(
846
+ ) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
847
+ guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone()
848
+ if tim == 0 else guider_seq[mbz * tim:mbz *
849
+ (tim + 1)].clone()
850
+ ) if guider_seq is not None else None
851
+ output_list.append(
852
+ my_filling_sequence(
853
+ model,
854
+ tokenizer,
855
+ self.args,
856
+ input_seq,
857
+ batch_size=min(batch_size, mbz),
858
+ get_masks_and_position_ids=
859
+ get_masks_and_position_ids_stage1,
860
+ text_len=text_len,
861
+ frame_len=frame_len,
862
+ strategy=self.strategy_cogview2,
863
+ strategy2=self.strategy_cogvideo,
864
+ log_text_attention_weights=video_log_text_attention_weights,
865
+ guider_seq=guider_seq2,
866
+ guider_text_len=guider_text_len,
867
+ guidance_alpha=self.args.guidance_alpha,
868
+ limited_spatial_channel_mem=True,
869
+ mode_stage1=True,
870
+ )[0])
871
+
872
+ output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len:]
873
+
874
+ if self.args.both_stages:
875
+ move_start_time = time.perf_counter()
876
+ logger.debug('moving stage 1 model to cpu')
877
+ model = model.cpu()
878
+ torch.cuda.empty_cache()
879
+ elapsed = time.perf_counter() - move_start_time
880
+ logger.debug(f'moving in model1 takes time: {elapsed:.2f}')
881
+
882
+ # decoding
883
+ res = []
884
+ for seq in output_tokens:
885
+ decoded_imgs = [
886
+ self.postprocess(
887
+ torch.nn.functional.interpolate(tokenizer.decode(
888
+ image_ids=seq.tolist()[i * 400:(i + 1) * 400]),
889
+ size=(480, 480))[0])
890
+ for i in range(total_frames)
891
+ ]
892
+ res.append(decoded_imgs) # only the last image (target)
893
+
894
+ assert len(res) == batch_size
895
+ tokens = output_tokens[:, :+total_frames * 400].reshape(
896
+ -1, total_frames, 400).cpu()
897
+
898
+ elapsed = time.perf_counter() - process_start_time
899
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
900
+ return tokens, res[0]
901
+
902
+ @torch.inference_mode()
903
+ def process_stage2(self,
904
+ model,
905
+ seq_text,
906
+ duration,
907
+ parent_given_tokens,
908
+ video_raw_text=None,
909
+ video_guidance_text='视频',
910
+ gpu_rank=0,
911
+ gpu_parallel_size=1):
912
+ process_start_time = time.perf_counter()
913
+
914
+ generate_frame_num = self.args.generate_frame_num
915
+ tokenizer = self.tokenizer
916
+ use_guidance = self.args.use_guidance_stage2
917
+
918
+ stage2_start_time = time.perf_counter()
919
+
920
+ if next(model.parameters()).device != self.device:
921
+ move_start_time = time.perf_counter()
922
+ logger.debug('moving stage-2 model to cuda')
923
+
924
+ model = model.to(self.device)
925
+
926
+ elapsed = time.perf_counter() - move_start_time
927
+ logger.debug(f'moving in stage-2 model takes time: {elapsed:.2f}')
928
+
929
+ try:
930
+ sample_num_allgpu = parent_given_tokens.shape[0]
931
+ sample_num = sample_num_allgpu // gpu_parallel_size
932
+ assert sample_num * gpu_parallel_size == sample_num_allgpu
933
+ parent_given_tokens = parent_given_tokens[gpu_rank *
934
+ sample_num:(gpu_rank +
935
+ 1) *
936
+ sample_num]
937
+ except:
938
+ logger.critical('No frame_tokens found in interpolation, skip')
939
+ return False, []
940
+
941
+ # CogVideo Stage2 Generation
942
+ while duration >= 0.5: # TODO: You can change the boundary to change the frame rate
943
+ parent_given_tokens_num = parent_given_tokens.shape[1]
944
+ generate_batchsize_persample = (parent_given_tokens_num - 1) // 2
945
+ generate_batchsize_total = generate_batchsize_persample * sample_num
946
+ total_frames = generate_frame_num
947
+ frame_len = 400
948
+ enc_text = tokenizer.encode(seq_text)
949
+ enc_duration = tokenizer.encode(str(float(duration)) + '秒')
950
+ seq = enc_duration + [tokenizer['<n>']] + enc_text + [
951
+ tokenizer['<start_of_image>']
952
+ ] + [-1] * 400 * generate_frame_num
953
+ text_len = len(seq) - frame_len * generate_frame_num - 1
954
+
955
+ logger.info(
956
+ f'[Stage2: Generating Frames, Frame Rate {int(4/duration):d}] raw text: {tokenizer.decode(enc_text):s}'
957
+ )
958
+
959
+ # generation
960
+ seq = torch.tensor(seq, dtype=torch.long,
961
+ device=self.device).unsqueeze(0).repeat(
962
+ generate_batchsize_total, 1)
963
+ for sample_i in range(sample_num):
964
+ for i in range(generate_batchsize_persample):
965
+ seq[sample_i * generate_batchsize_persample +
966
+ i][text_len + 1:text_len + 1 +
967
+ 400] = parent_given_tokens[sample_i][2 * i]
968
+ seq[sample_i * generate_batchsize_persample +
969
+ i][text_len + 1 + 400:text_len + 1 +
970
+ 800] = parent_given_tokens[sample_i][2 * i + 1]
971
+ seq[sample_i * generate_batchsize_persample +
972
+ i][text_len + 1 + 800:text_len + 1 +
973
+ 1200] = parent_given_tokens[sample_i][2 * i + 2]
974
+
975
+ if use_guidance:
976
+ guider_seq = enc_duration + [
977
+ tokenizer['<n>']
978
+ ] + tokenizer.encode(video_guidance_text) + [
979
+ tokenizer['<start_of_image>']
980
+ ] + [-1] * 400 * generate_frame_num
981
+ guider_text_len = len(
982
+ guider_seq) - frame_len * generate_frame_num - 1
983
+ guider_seq = torch.tensor(
984
+ guider_seq, dtype=torch.long,
985
+ device=self.device).unsqueeze(0).repeat(
986
+ generate_batchsize_total, 1)
987
+ for sample_i in range(sample_num):
988
+ for i in range(generate_batchsize_persample):
989
+ guider_seq[sample_i * generate_batchsize_persample +
990
+ i][text_len + 1:text_len + 1 +
991
+ 400] = parent_given_tokens[sample_i][2 *
992
+ i]
993
+ guider_seq[sample_i * generate_batchsize_persample +
994
+ i][text_len + 1 + 400:text_len + 1 +
995
+ 800] = parent_given_tokens[sample_i][2 *
996
+ i +
997
+ 1]
998
+ guider_seq[sample_i * generate_batchsize_persample +
999
+ i][text_len + 1 + 800:text_len + 1 +
1000
+ 1200] = parent_given_tokens[sample_i][2 *
1001
+ i +
1002
+ 2]
1003
+ video_log_text_attention_weights = 0
1004
+ else:
1005
+ guider_seq = None
1006
+ guider_text_len = 0
1007
+ video_log_text_attention_weights = 1.4
1008
+
1009
+ mbz = self.args.max_inference_batch_size
1010
+
1011
+ assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
1012
+ output_list = []
1013
+ start_time = time.perf_counter()
1014
+ for tim in range(max(generate_batchsize_total // mbz, 1)):
1015
+ input_seq = seq[:min(generate_batchsize_total, mbz)].clone(
1016
+ ) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
1017
+ guider_seq2 = (
1018
+ guider_seq[:min(generate_batchsize_total, mbz)].clone()
1019
+ if tim == 0 else guider_seq[mbz * tim:mbz *
1020
+ (tim + 1)].clone()
1021
+ ) if guider_seq is not None else None
1022
+ output_list.append(
1023
+ my_filling_sequence(
1024
+ model,
1025
+ tokenizer,
1026
+ self.args,
1027
+ input_seq,
1028
+ batch_size=min(generate_batchsize_total, mbz),
1029
+ get_masks_and_position_ids=
1030
+ get_masks_and_position_ids_stage2,
1031
+ text_len=text_len,
1032
+ frame_len=frame_len,
1033
+ strategy=self.strategy_cogview2,
1034
+ strategy2=self.strategy_cogvideo,
1035
+ log_text_attention_weights=
1036
+ video_log_text_attention_weights,
1037
+ mode_stage1=False,
1038
+ guider_seq=guider_seq2,
1039
+ guider_text_len=guider_text_len,
1040
+ guidance_alpha=self.args.guidance_alpha,
1041
+ limited_spatial_channel_mem=True,
1042
+ )[0])
1043
+ elapsed = time.perf_counter() - start_time
1044
+ logger.info(f'Duration {duration:.2f}, Elapsed: {elapsed:.2f}\n')
1045
+
1046
+ output_tokens = torch.cat(output_list, dim=0)
1047
+ output_tokens = output_tokens[:, text_len + 1:text_len + 1 +
1048
+ (total_frames) * 400].reshape(
1049
+ sample_num, -1,
1050
+ 400 * total_frames)
1051
+ output_tokens_merge = torch.cat(
1052
+ (output_tokens[:, :, :1 * 400], output_tokens[:, :,
1053
+ 400 * 3:4 * 400],
1054
+ output_tokens[:, :, 400 * 1:2 * 400],
1055
+ output_tokens[:, :, 400 * 4:(total_frames) * 400]),
1056
+ dim=2).reshape(sample_num, -1, 400)
1057
+
1058
+ output_tokens_merge = torch.cat(
1059
+ (output_tokens_merge, output_tokens[:, -1:, 400 * 2:3 * 400]),
1060
+ dim=1)
1061
+ duration /= 2
1062
+ parent_given_tokens = output_tokens_merge
1063
+
1064
+ if self.args.both_stages:
1065
+ move_start_time = time.perf_counter()
1066
+ logger.debug('moving stage 2 model to cpu')
1067
+ model = model.cpu()
1068
+ torch.cuda.empty_cache()
1069
+ elapsed = time.perf_counter() - move_start_time
1070
+ logger.debug(f'moving out model2 takes time: {elapsed:.2f}')
1071
+
1072
+ elapsed = time.perf_counter() - stage2_start_time
1073
+ logger.info(f'CogVideo Stage2 completed. Elapsed: {elapsed:.2f}\n')
1074
+
1075
+ # direct super-resolution by CogView2
1076
+ logger.info('[Direct super-resolution]')
1077
+ dsr_start_time = time.perf_counter()
1078
+
1079
+ enc_text = tokenizer.encode(seq_text)
1080
+ frame_num_per_sample = parent_given_tokens.shape[1]
1081
+ parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
1082
+ text_seq = torch.tensor(enc_text, dtype=torch.long,
1083
+ device=self.device).unsqueeze(0).repeat(
1084
+ parent_given_tokens_2d.shape[0], 1)
1085
+ sred_tokens = self.dsr(text_seq, parent_given_tokens_2d)
1086
+
1087
+ decoded_sr_videos = []
1088
+ for sample_i in range(sample_num):
1089
+ decoded_sr_imgs = []
1090
+ for frame_i in range(frame_num_per_sample):
1091
+ decoded_sr_img = tokenizer.decode(
1092
+ image_ids=sred_tokens[frame_i + sample_i *
1093
+ frame_num_per_sample][-3600:])
1094
+ decoded_sr_imgs.append(
1095
+ self.postprocess(
1096
+ torch.nn.functional.interpolate(decoded_sr_img,
1097
+ size=(480, 480))[0]))
1098
+ decoded_sr_videos.append(decoded_sr_imgs)
1099
+
1100
+ elapsed = time.perf_counter() - dsr_start_time
1101
+ logger.info(
1102
+ f'Direct super-resolution completed. Elapsed: {elapsed:.2f}')
1103
+
1104
+ elapsed = time.perf_counter() - process_start_time
1105
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
1106
+ return True, decoded_sr_videos[0]
1107
+
1108
+ @staticmethod
1109
+ def postprocess(tensor: torch.Tensor) -> np.ndarray:
1110
+ return tensor.cpu().mul(255).add_(0.5).clamp_(0, 255).permute(
1111
+ 1, 2, 0).to(torch.uint8).numpy()
1112
+
1113
+ def run(self, text: str, seed: int,
1114
+ only_first_stage: bool) -> list[np.ndarray]:
1115
+ logger.info('==================== run ====================')
1116
+ start = time.perf_counter()
1117
+
1118
+ set_random_seed(seed)
1119
+
1120
+ if only_first_stage:
1121
+ self.args.stage_1 = True
1122
+ self.args.both_stages = False
1123
+ else:
1124
+ self.args.stage_1 = False
1125
+ self.args.both_stages = True
1126
+
1127
+ parent_given_tokens, res = self.process_stage1(
1128
+ self.model_stage1,
1129
+ text,
1130
+ duration=4.0,
1131
+ video_raw_text=text,
1132
+ video_guidance_text='视频',
1133
+ image_text_suffix=' 高清摄影',
1134
+ batch_size=self.args.batch_size)
1135
+ if not only_first_stage:
1136
+ _, res = self.process_stage2(
1137
+ self.model_stage2,
1138
+ text,
1139
+ duration=2.0,
1140
+ parent_given_tokens=parent_given_tokens,
1141
+ video_raw_text=text + ' 视频',
1142
+ video_guidance_text='视频',
1143
+ gpu_rank=0,
1144
+ gpu_parallel_size=1) # TODO: 修改
1145
+
1146
+ elapsed = time.perf_counter() - start
1147
+ logger.info(f'Elapsed: {elapsed:.3f}')
1148
+ logger.info('==================== done ====================')
1149
+ return res
1150
+
1151
+
1152
+ class AppModel(Model):
1153
+ def __init__(self, only_first_stage: bool):
1154
+ super().__init__(only_first_stage)
1155
+ self.translator = gr.Interface.load(
1156
+ 'spaces/chinhon/translation_eng2ch')
1157
+
1158
+ def to_video(self, frames: list[np.ndarray]) -> str:
1159
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
1160
+ if self.args.stage_1:
1161
+ fps = 4
1162
+ else:
1163
+ fps = 8
1164
+ writer = iio.get_writer(out_file.name, fps=fps)
1165
+ for frame in frames:
1166
+ writer.append_data(frame)
1167
+ writer.close()
1168
+ return out_file.name
1169
+
1170
+ def run_with_translation(
1171
+ self, text: str, translate: bool, seed: int, only_first_stage: bool
1172
+ ) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
1173
+ logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=}')
1174
+ if translate:
1175
+ text = translated_text = self.translator(text)
1176
+ else:
1177
+ translated_text = None
1178
+ frames = self.run(text, seed, only_first_stage)
1179
+ video_path = self.to_video(frames)
1180
+ return translated_text, video_path, frames
patch ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/coglm_strategy.py b/coglm_strategy.py
2
+ index d485715..a9eab3b 100644
3
+ --- a/coglm_strategy.py
4
+ +++ b/coglm_strategy.py
5
+ @@ -8,6 +8,7 @@
6
+
7
+ # here put the import lib
8
+ import os
9
+ +import pathlib
10
+ import sys
11
+ import math
12
+ import random
13
+ @@ -58,7 +59,8 @@ class CoglmStrategy:
14
+ self._is_done = False
15
+ self.outlier_count_down = torch.zeros(16)
16
+ self.vis_list = [[]for i in range(16)]
17
+ - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
18
+ + cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label2.npy'
19
+ + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
20
+ self.start_pos = -1
21
+ self.white_cluster = []
22
+ # self.fout = open('tmp.txt', 'w')
23
+ @@ -98,4 +100,4 @@ class CoglmStrategy:
24
+
25
+ def finalize(self, tokens, mems):
26
+ self._is_done = False
27
+ - return tokens, mems
28
+
29
+ + return tokens, mems
30
+ diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py
31
+ index 5b8dded..07e97fd 100644
32
+ --- a/sr_pipeline/dsr_sampling.py
33
+ +++ b/sr_pipeline/dsr_sampling.py
34
+ @@ -8,6 +8,7 @@
35
+
36
+ # here put the import lib
37
+ import os
38
+ +import pathlib
39
+ import sys
40
+ import math
41
+ import random
42
+ @@ -28,7 +29,8 @@ class IterativeEntfilterStrategy:
43
+ self.invalid_slices = invalid_slices
44
+ self.temperature = temperature
45
+ self.topk = topk
46
+ - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
47
+ + cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label2.npy'
48
+ + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
49
+
50
+
51
+ def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
pretrained/.gitkeep ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/Sleepychord/Image-Local-Attention@43fee31
2
+ gradio==3.1.0
3
+ icetk==0.0.4
4
+ imageio==2.19.5
5
+ imageio-ffmpeg==0.4.7
6
+ numpy==1.22.4
7
+ opencv-python-headless==4.6.0.66
8
+ SwissArmyTransformer==0.2.9
9
+ torch==1.12.0
10
+ torchvision==0.13.0
style.css ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ img#visitor-badge {
5
+ display: block;
6
+ margin: auto;
7
+ }