akhaliq HF staff commited on
Commit
89dc200
1 Parent(s): 05de68d
cluster_label2.npy ADDED
Binary file (160 kB). View file
 
coglm_strategy.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : coglm_strategy.py
4
+ @Time : 2021/10/08 22:22:42
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+ import numpy as np
16
+ import torch.nn.functional as F
17
+
18
+
19
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
    """Mask logits that fall outside the top-k and/or nucleus (top-p) set.

    Adapted from the Hugging Face conversational-AI example:
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    Args:
        logits: Float tensor whose last dimension indexes the vocabulary.
            It is modified in place. Any number of leading (batch)
            dimensions is accepted; each row is filtered independently.
        top_k: Keep only the ``top_k`` largest logits per row. ``0`` disables
            the filter. Values larger than the vocabulary size are clamped.
        top_p: Keep the smallest prefix of the descending-sorted tokens whose
            cumulative softmax probability exceeds ``top_p`` (the first token
            that crosses the threshold is kept). ``0.0`` disables the filter.
        filter_value: Value written into every removed position.

    Returns:
        The filtered ``logits`` tensor (same shape as the input).
    """
    if top_k > 0:
        # torch.topk raises when k exceeds the size of the last dimension.
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens scoring below the k-th best token of their row.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token above the threshold is kept as well.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order, row by row,
        # instead of flattening to 1-D (which only worked for batch size 1).
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits
45
+
46
+
47
class CoglmStrategy:
    """Two-step sampling strategy: pick a token cluster, then a token in it.

    ``forward`` first aggregates the token probabilities per cluster (using
    the labels loaded from ``cluster_label2.npy``), samples one cluster among
    the ``top_k`` most probable ones, masks every token outside that cluster
    and finally samples a token with a second temperature.
    """

    # Width of the per-cluster probability table built in ``forward``.
    NUM_CLUSTERS = 500

    def __init__(self, invalid_slices=None, temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None, temperature2=0.89):
        """Store the sampling hyper-parameters and load the cluster labels.

        Args:
            invalid_slices: Slices of the vocabulary dimension that must never
                be sampled. Defaults to no restriction.
            temperature: Divisor applied to the logits before clustering.
            top_k: Number of most probable clusters to sample from.
            eps: Stored but not read by the code in this class.
            top_p: Stored but not read by the code in this class.
            end_tokens: Token ids that mark the end of generation.
            temperature2: Divisor applied to the logits for the in-cluster
                sampling step.
        """
        # A fresh list per instance; a shared ``[]`` default would leak
        # mutations between strategies.
        self.invalid_slices = [] if invalid_slices is None else invalid_slices
        self.temperature = temperature
        self.temperature2 = temperature2
        self.topk = top_k
        self.top_p = top_p
        self.eps = eps
        if end_tokens is None:
            end_tokens = []
        self.end_tokens = end_tokens
        self._is_done = False
        self.outlier_count_down = torch.zeros(16)
        self.vis_list = [[] for _ in range(16)]
        # Prefer the GPU, but stay usable on CPU-only hosts.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device=device, dtype=torch.long)
        self.start_pos = -1
        self.white_cluster = []

    @property
    def is_done(self) -> bool:
        """Whether the last single-sample ``forward`` produced an end token."""
        return self._is_done

    def forward(self, logits, tokens, mems, temperature=None, temperature2=None):
        """Sample one token per row of ``logits`` and append it to ``tokens``.

        Args:
            logits: ``[batch, vocab]`` scores for the next token.
            tokens: ``[batch, length]`` tokens generated so far.
            mems: Passed through untouched.
            temperature: Overrides ``self.temperature`` when given.
            temperature2: Overrides ``self.temperature2`` when given.

        Returns:
            ``(tokens, mems)`` where ``tokens`` has one extra column.
        """
        if temperature is None:
            temperature = self.temperature
        if temperature2 is None:
            temperature2 = self.temperature2
        logits = logits / temperature
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504

        # Labels are loaded once in __init__; follow the logits' device so the
        # indexing below never mixes devices.
        cluster_labels = self.cluster_labels.to(logits.device)

        # Sum the token probabilities of every cluster.
        rprobs = F.softmax(logits.float(), dim=-1)
        c = cluster_labels.expand(*rprobs.shape)
        cprobs = torch.zeros(logits.shape[0], self.NUM_CLUSTERS, device=logits.device).scatter_add_(1, c, rprobs)

        # topk raises when asked for more entries than the table holds.
        best_scores, best_clusters = cprobs.topk(min(self.topk, self.NUM_CLUSTERS))
        for i in range(logits.shape[0]):
            choice = torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)
            selected_cluster = best_clusters[i][choice]
            # Only tokens of the selected cluster stay eligible for this row.
            logits[i, cluster_labels != selected_cluster] = -65504

        # float is essential, due to a bug in PyTorch
        probs = F.softmax(logits.float() / temperature2, dim=-1)
        pred = torch.multinomial(probs, num_samples=1)

        if pred.numel() == 1 and pred.item() in self.end_tokens:
            self._is_done = True
        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
        return tokens, mems

    def finalize(self, tokens, mems):
        """Reset the done flag and return ``(tokens, mems)`` unchanged."""
        self._is_done = False
        return tokens, mems
cogvideo_pipeline.py ADDED
@@ -0,0 +1,793 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : cogvideo_pipeline.py
4
+ @Time : 2022/07/15 11:24:56
5
+ @Author : Wenyi Hong
6
+ @Version : 1.0
7
+ @Contact : hwy22@mails.tsinghua.edu.cn
8
+ '''
9
+
10
+ # here put the import lib
11
+
12
+ import os
13
+ import sys
14
+ import torch
15
+ import argparse
16
+ import time
17
+ from torchvision.utils import save_image
18
+ import stat
19
+ from icetk import icetk as tokenizer
20
+ import logging, sys
21
+
22
+ import torch.distributed as dist
23
+ tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
24
+
25
+
26
+ from SwissArmyTransformer import get_args
27
+ from SwissArmyTransformer.data_utils import BinaryDataset, make_loaders
28
+ from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
29
+ from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
30
+ from SwissArmyTransformer.resources import auto_create
31
+
32
+ from models.cogvideo_cache_model import CogVideoCacheModel
33
+ from coglm_strategy import CoglmStrategy
34
+
35
+
36
def get_masks_and_position_ids_stage1(data, textlen, framelen):
    """Build the stage-1 attention mask and position ids for ``data``.

    The mask has shape ``(1, 1, textlen+framelen, textlen+framelen)``: the
    text rows have their frame columns zeroed, and the frame/frame quadrant
    is lower triangular. Position ids run ``0..textlen-1`` over the text part
    and start again at 512 for everything after it.

    Returns:
        ``(data, attention_mask, position_ids)``; ``position_ids`` has shape
        ``(1, len(data[0]))``.
    """
    device = data.device
    total_len = len(data[0])
    span = textlen + framelen

    # Start from all-ones, then carve out the blocked regions.
    mask = torch.ones((1, span, span), device=device)
    mask[:, :textlen, textlen:] = 0
    mask[:, textlen:, textlen:] = torch.tril(mask[:, textlen:, textlen:])
    mask = mask.unsqueeze(1)

    # Unaligned version: text and frame positions use separate ranges.
    pos = torch.zeros(total_len, dtype=torch.long, device=device)
    pos[:textlen] = torch.arange(textlen, dtype=torch.long, device=device)
    pos[textlen:] = torch.arange(512, 512 + total_len - textlen,
                                 dtype=torch.long, device=device)

    return data, mask, pos.unsqueeze(0)
55
+
56
def get_masks_and_position_ids_stage2(data, textlen, framelen):
    """Build the stage-2 attention mask and position ids for ``data``.

    The mask is built exactly as in stage 1 (shape
    ``(1, 1, textlen+framelen, textlen+framelen)``). The sequence must hold
    exactly five frames after the text; the k-th frame slot of the sequence
    receives the position range of frame ``(0, 2, 4, 1, 3)[k]``, each range
    being ``framelen`` wide and offset from 512.

    Returns:
        ``(data, attention_mask, position_ids)``; ``position_ids`` has shape
        ``(1, len(data[0]))``.
    """
    device = data.device
    total_len = len(data[0])
    span = textlen + framelen

    # Start from all-ones, then carve out the blocked regions.
    mask = torch.ones((1, span, span), device=device)
    mask[:, :textlen, textlen:] = 0
    mask[:, textlen:, textlen:] = torch.tril(mask[:, textlen:, textlen:])
    mask = mask.unsqueeze(1)

    # Unaligned version: text and frame positions use separate ranges.
    pos = torch.zeros(total_len, dtype=torch.long, device=device)
    pos[:textlen] = torch.arange(textlen, dtype=torch.long, device=device)

    frame_num = (total_len - textlen) // framelen
    assert frame_num == 5

    # Sequence slot -> index of the frame whose positions it carries.
    frame_order = (0, 2, frame_num - 1, 1, 3)
    for slot, frame in enumerate(frame_order):
        start = textlen + framelen * slot
        first_pos = 512 + framelen * frame
        pos[start:start + framelen] = torch.arange(
            first_pos, first_pos + framelen, dtype=torch.long, device=device)

    return data, mask, pos.unsqueeze(0)
88
+
89
def my_update_mems(hiddens, mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len):
    """Append freshly computed hidden states to the preallocated mem buffers.

    Args:
        hiddens: One list of per-layer tensors per buffer, or ``None``.
        mems_buffers: Preallocated buffers, indexed
            ``[layer, batch, position, hidden]``; written in place.
        mems_indexs: Current fill length of each buffer; updated in place.
        limited_spatial_channel_mem: When true, buffer 0 keeps only the text
            part plus the tokens of the current frame: once its length would
            reach ``text_len + frame_len`` it is rewound to
            ``text_len + (overflow % frame_len)``.
        text_len: Number of leading text positions.
        frame_len: Number of positions per frame.

    Returns:
        ``(mems, mems_indexs)`` where ``mems`` holds, per buffer, the filled
        prefix view (or ``None`` when that buffer had no new hidden state).
        ``(None, mems_indexs)`` when ``hiddens`` is ``None``.
    """
    if hiddens is None:
        return None, mems_indexs

    filled = []
    with torch.no_grad():
        for slot, per_layer in enumerate(hiddens):
            if per_layer[0] is None:
                filled.append(None)
                continue

            buffer = mems_buffers[slot]
            batch = buffer.shape[1]
            if (slot == 0 and limited_spatial_channel_mem
                    and mems_indexs[slot] + hiddens[0][0].shape[1] >= text_len + frame_len):
                # First write into an empty buffer: store the text prefix.
                if mems_indexs[slot] == 0:
                    for layer, hidden in enumerate(per_layer):
                        buffer[layer, :, :text_len] = hidden.expand(batch, -1, -1)[:, :text_len]
                # Tokens that already belong to the frame being started.
                tail = (mems_indexs[slot] + hiddens[0][0].shape[1] - text_len) % frame_len
                if tail > 0:
                    for layer, hidden in enumerate(per_layer):
                        buffer[layer, :, text_len:text_len + tail] = hidden.expand(batch, -1, -1)[:, -tail:]
                mems_indexs[slot] = text_len + tail
            else:
                # Plain append behind the current fill position.
                for layer, hidden in enumerate(per_layer):
                    begin = mems_indexs[slot]
                    buffer[layer, :, begin:begin + hidden.shape[1]] = hidden.expand(batch, -1, -1)
                mems_indexs[slot] += hidden.shape[1]

            filled.append(buffer[:, :, :mems_indexs[slot]])
    return filled, mems_indexs
114
+
115
+
116
def my_save_multiple_images(imgs, path, subdir, debug=True):
    """Write a list of image tensors to disk.

    Args:
        imgs: List of image tensors.
        path: In debug mode the output file; otherwise the parent directory.
        subdir: Sub-directory of ``path`` that receives the frames
            (non-debug mode only).
        debug: When true, concatenate all images along dim 0 and save them as
            one file at ``path``. Otherwise save every image as
            ``<path>/<subdir>/NNNN.jpg`` plus a ``frame_concat.jpg`` holding
            the concatenation, and make each written file accessible to
            everyone.
    """
    if debug:
        merged = torch.cat(imgs, dim=0)
        print("\nSave to: ", path, flush=True)
        save_image(merged, path, normalize=True)
        return

    print("\nSave to: ", path, flush=True)
    frame_dir = os.path.join(path, subdir)
    os.makedirs(frame_dir, exist_ok=True)
    # rwx for user, group and others.
    open_mode = stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU

    for i, img in enumerate(imgs):
        frame_file = os.path.join(frame_dir, f'{str(i).rjust(4,"0")}.jpg')
        save_image(img, frame_file, normalize=True)
        os.chmod(frame_file, open_mode)

    concat_file = os.path.join(frame_dir, 'frame_concat.jpg')
    save_image(torch.cat(imgs, dim=0), concat_file, normalize=True)
    os.chmod(concat_file, open_mode)
131
+
132
def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
    """Return the position of the first token of the next token's frame.

    ``total_len`` tokens exist so far; the next token therefore belongs to
    frame ``(total_len - text_len) // frame_len``. Returns ``None`` while
    ``total_len`` is still shorter than the text part.
    """
    if total_len < text_len:
        return None
    complete_frames, _ = divmod(total_len - text_len, frame_len)
    return text_len + complete_frames * frame_len
137
+
138
def my_filling_sequence(
        model,
        args,
        seq,
        batch_size,
        get_masks_and_position_ids,
        text_len,
        frame_len,
        strategy=BaseStrategy(),
        strategy2=BaseStrategy(),
        mems=None,
        log_text_attention_weights=0,  # default to 0: no artificial change
        mode_stage1=True,
        enforce_no_swin=False,
        guider_seq=None,
        guider_text_len=0,
        guidance_alpha=1,
        limited_spatial_channel_mem=False,  # restrict the spatial channel's memory to the current frame
        **kw_args
        ):
    '''
        Fill the negative (to-be-generated) slots of ``seq`` token by token,
        caching per-layer key/values in preallocated buffers.

        seq: 2-D tensor, [2, 3, 5, ..., -1(to be generated), -1, ...]; the last
            row is scanned to find how many leading tokens are given.
        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
            cache, should be first mems.shape[1] parts of context_tokens.
            mems are the first-level citizens here, but we don't assume what is memorized.
            input mems are used when multi-phase generation.
        model: called as ``model(tokens, position_ids, attention_mask, mems=..., ...)``
            and expected to return ``(logits, *per_layer)`` where every
            per-layer entry exposes a two-element ``'mem_kv'``.
        args: read for ``num_layers`` and ``hidden_size`` only.
        get_masks_and_position_ids: called as ``f(seq, text_len, frame_len)``,
            returns ``(tokens, attention_mask, position_ids)``.
        strategy / strategy2: sampling strategies; ``strategy`` is used while
            ``mode_stage1`` holds and ``counter < text_len + 400``,
            ``strategy2`` otherwise.
        guider_seq / guider_text_len / guidance_alpha: optional second
            sequence run through the same model; the sampled logits become
            ``guider + (logits - guider) * guidance_alpha``.

        Returns ``strategy.finalize(tokens, mems)``.

        NOTE(review): the frame length is hard-coded as 400 (and 74 in the
        buffer sizes) in several places instead of using ``frame_len`` /
        ``text_len`` -- presumably callers always pass frame_len == 400;
        confirm before reusing with other sizes.
    '''
    if guider_seq is not None:
        logging.debug("Using Guidance In Inference")
    if limited_spatial_channel_mem:
        logging.debug("Limit spatial-channel's mem to current frame")
    assert len(seq.shape) == 2

    # building the initial tokens, attention_mask, and position_ids
    actual_context_length = 0

    while seq[-1][actual_context_length] >= 0:  # the last seq has least given tokens
        actual_context_length += 1  # [0, context_length-1] are given
    assert actual_context_length > 0
    # Round the given context down to the text plus a whole number of frames.
    current_frame_num = (actual_context_length-text_len) // frame_len
    assert current_frame_num >= 0
    context_length = text_len + current_frame_num * frame_len

    tokens, attention_mask, position_ids = get_masks_and_position_ids(seq, text_len, frame_len)
    tokens = tokens[..., :context_length]
    input_tokens = tokens.clone()

    if guider_seq is not None:
        # Offset between an index in ``seq`` and the same token in ``guider_seq``.
        guider_index_delta = text_len - guider_text_len
        guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(guider_seq, guider_text_len, frame_len)
        guider_tokens = guider_tokens[..., :context_length-guider_index_delta]
        guider_input_tokens = guider_tokens.clone()

    # The model input gets <start_of_image> at the first slot of every given
    # frame; ``tokens`` itself keeps the original values.
    for fid in range(current_frame_num):
        input_tokens[:, text_len+400*fid] = tokenizer['<start_of_image>']
        if guider_seq is not None:
            guider_input_tokens[:, guider_text_len+400*fid] = tokenizer['<start_of_image>']

    attention_mask = attention_mask.type_as(next(model.parameters()))  # if fp16
    # initialize generation
    counter = context_length - 1  # Last fixed index is ``counter''
    index = 0  # Next forward starting index, also the length of cache.
    mems_buffers_on_GPU = False
    mems_indexs = [0, 0]
    # Buffer 0 only needs room for the text plus one frame when limited.
    mems_len = [(400+74) if limited_spatial_channel_mem else 5*400+74, 5*400+74]
    # No device is given here; the buffers are moved with ``.to(...)`` later.
    mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype)
                    for mem_len in mems_len]


    if guider_seq is not None:
        guider_attention_mask = guider_attention_mask.type_as(next(model.parameters()))  # if fp16
        guider_mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype)
                               for mem_len in mems_len]
        guider_mems_indexs = [0, 0]
        guider_mems = None

    torch.cuda.empty_cache()
    # step-by-step generation
    while counter < len(seq[0]) - 1:
        # we have generated counter+1 tokens
        # Now, we want to generate seq[counter + 1],
        # token[:, index: counter+1] needs forwarding.
        if index == 0:
            # First pass: forward the whole given context, in groups of rows.
            group_size = 2 if (input_tokens.shape[0] == batch_size and not mode_stage1) else batch_size

            logits_all = None
            for batch_idx in range(0, input_tokens.shape[0], group_size):
                logits, *output_per_layers = model(
                    input_tokens[batch_idx:batch_idx+group_size, index:],
                    position_ids[..., index: counter+1],
                    attention_mask,  # TODO memlen
                    mems=mems,
                    text_len=text_len,
                    frame_len=frame_len,
                    counter=counter,
                    log_text_attention_weights=log_text_attention_weights,
                    enforce_no_swin=enforce_no_swin,
                    **kw_args
                )
                logits_all = torch.cat((logits_all, logits), dim=0) if logits_all is not None else logits
                mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]]
                next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(text_len, frame_len, mem_kv01[0][0].shape[1])
                for id, mem_kv in enumerate(mem_kv01):
                    for layer, mem_kv_perlayer in enumerate(mem_kv):
                        if limited_spatial_channel_mem and id == 0:
                            # Keep the text part and only the tokens of the frame in progress.
                            mems_buffers[id][layer, batch_idx:batch_idx+group_size, :text_len] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :text_len]
                            mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\
                                mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:]
                        else:
                            mems_buffers[id][layer, batch_idx:batch_idx+group_size, :mem_kv_perlayer.shape[1]] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)
                mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[1], mem_kv01[1][0].shape[1]
                if limited_spatial_channel_mem:
                    mems_indexs[0] -= (next_tokens_frame_begin_id - text_len)

            mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)]
            logits = logits_all

            # Guider
            if guider_seq is not None:
                guider_logits_all = None
                for batch_idx in range(0, guider_input_tokens.shape[0], group_size):
                    guider_logits, *guider_output_per_layers = model(
                        guider_input_tokens[batch_idx:batch_idx+group_size, max(index-guider_index_delta, 0):],
                        guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta],
                        guider_attention_mask,
                        mems=guider_mems,
                        text_len=guider_text_len,
                        frame_len=frame_len,
                        counter=counter-guider_index_delta,
                        log_text_attention_weights=log_text_attention_weights,
                        enforce_no_swin=enforce_no_swin,
                        **kw_args
                    )
                    guider_logits_all = torch.cat((guider_logits_all, guider_logits), dim=0) if guider_logits_all is not None else guider_logits
                    guider_mem_kv01 = [[o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers]]
                    for id, guider_mem_kv in enumerate(guider_mem_kv01):
                        for layer, guider_mem_kv_perlayer in enumerate(guider_mem_kv):
                            if limited_spatial_channel_mem and id == 0:
                                # NOTE(review): the expand size below uses input_tokens.shape[0],
                                # not guider_input_tokens.shape[0] -- looks intentional only if
                                # both always have the same number of rows; confirm.
                                guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_text_len] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :guider_text_len]
                                guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(guider_text_len, frame_len, guider_mem_kv_perlayer.shape[1])
                                guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\
                                    guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:]
                            else:
                                guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_mem_kv_perlayer.shape[1]] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)
                    guider_mems_indexs[0], guider_mems_indexs[1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[1][0].shape[1]
                    if limited_spatial_channel_mem:
                        guider_mems_indexs[0] -= (guider_next_tokens_frame_begin_id-guider_text_len)
                guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)]
                guider_logits = guider_logits_all
        else:
            # Incremental pass: only tokens from ``index`` on are forwarded.
            if not mems_buffers_on_GPU:
                if not mode_stage1:
                    # Move only the filled views now; the full buffers follow
                    # after the forward below.
                    torch.cuda.empty_cache()
                    for idx, mem in enumerate(mems):
                        mems[idx] = mem.to(next(model.parameters()).device)
                    if guider_seq is not None:
                        for idx, mem in enumerate(guider_mems):
                            guider_mems[idx] = mem.to(next(model.parameters()).device)
                else:
                    torch.cuda.empty_cache()
                    for idx, mem_buffer in enumerate(mems_buffers):
                        mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
                    mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)]
                    if guider_seq is not None:
                        for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
                            guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device)
                        guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)]
                    mems_buffers_on_GPU = True

            # Unlike the first pass, this call also forwards
            # ``limited_spatial_channel_mem`` to the model.
            logits, *output_per_layers = model(
                input_tokens[:, index:],
                position_ids[..., index: counter+1],
                attention_mask,  # TODO memlen
                mems=mems,
                text_len=text_len,
                frame_len=frame_len,
                counter=counter,
                log_text_attention_weights=log_text_attention_weights,
                enforce_no_swin=enforce_no_swin,
                limited_spatial_channel_mem=limited_spatial_channel_mem,
                **kw_args
            )
            mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]

            if guider_seq is not None:
                # NOTE(review): log_text_attention_weights is fixed to 0 here but
                # passed through in the first-pass guider call -- confirm intended.
                guider_logits, *guider_output_per_layers = model(
                    guider_input_tokens[:, max(index-guider_index_delta, 0):],
                    guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta],
                    guider_attention_mask,
                    mems=guider_mems,
                    text_len=guider_text_len,
                    frame_len=frame_len,
                    counter=counter-guider_index_delta,
                    log_text_attention_weights=0,
                    enforce_no_swin=enforce_no_swin,
                    limited_spatial_channel_mem=limited_spatial_channel_mem,
                    **kw_args
                )
                guider_mem_kv0, guider_mem_kv1 = [o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers]

            if not mems_buffers_on_GPU:
                torch.cuda.empty_cache()
                for idx, mem_buffer in enumerate(mems_buffers):
                    mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
                if guider_seq is not None:
                    for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
                        guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device)
                mems_buffers_on_GPU = True

            mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1], mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len)
            if guider_seq is not None:
                guider_mems, guider_mems_indexs = my_update_mems([guider_mem_kv0, guider_mem_kv1], guider_mems_buffers, guider_mems_indexs, limited_spatial_channel_mem, guider_text_len, frame_len)


        counter += 1
        index = counter

        # Only the logits of the last forwarded position are used for sampling.
        logits = logits[:, -1].expand(batch_size, -1)  # [batch size, vocab size]
        tokens = tokens.expand(batch_size, -1)
        if guider_seq is not None:
            guider_logits = guider_logits[:, -1].expand(batch_size, -1)
            guider_tokens = guider_tokens.expand(batch_size, -1)

        if seq[-1][counter].item() < 0:
            # sampling
            guided_logits = guider_logits+(logits-guider_logits)*guidance_alpha if guider_seq is not None else logits
            if mode_stage1 and counter < text_len + 400:
                tokens, mems = strategy.forward(guided_logits, tokens, mems)
            else:
                tokens, mems = strategy2.forward(guided_logits, tokens, mems)
            if guider_seq is not None:
                guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]), dim=1)

            # Rows that already have a given token at this position keep it.
            if seq[0][counter].item() >= 0:
                for si in range(seq.shape[0]):
                    if seq[si][counter].item() >= 0:
                        tokens[si, -1] = seq[si, counter]
                        if guider_seq is not None:
                            guider_tokens[si, -1] = guider_seq[si, counter-guider_index_delta]

        else:
            # The token is given for the last row: copy it instead of sampling.
            tokens = torch.cat((tokens, seq[:, counter:counter+1].clone().expand(tokens.shape[0], 1).to(device=tokens.device, dtype=tokens.dtype)), dim=1)
            if guider_seq is not None:
                guider_tokens = torch.cat((guider_tokens,
                                           guider_seq[:, counter-guider_index_delta:counter+1-guider_index_delta]
                                           .clone().expand(guider_tokens.shape[0], 1).to(device=guider_tokens.device, dtype=guider_tokens.dtype)), dim=1)

        input_tokens = tokens.clone()
        if guider_seq is not None:
            guider_input_tokens = guider_tokens.clone()
        # When a frame boundary lies at or after ``index``, overwrite the first
        # slot of that frame in the model input with <start_of_image>.
        if (index-text_len-1)//400 < (input_tokens.shape[-1]-text_len-1)//400:
            boi_idx = ((index-text_len-1)//400 +1)*400+text_len
            while boi_idx < input_tokens.shape[-1]:
                input_tokens[:, boi_idx] = tokenizer['<start_of_image>']
                if guider_seq is not None:
                    guider_input_tokens[:, boi_idx-guider_index_delta] = tokenizer['<start_of_image>']
                boi_idx += 400

        # NOTE(review): only ``strategy`` is consulted here (and finalized
        # below), never ``strategy2`` -- confirm that is intended.
        if strategy.is_done:
            break
    return strategy.finalize(tokens, mems)
399
+
400
class InferenceModel_Sequential(CogVideoCacheModel):
    """Stage-1 cache model (``window_size=-1``, ``cogvideo_stage=1``)."""

    def __init__(self, args, transformer=None, parallel_output=True):
        super().__init__(
            args,
            transformer=transformer,
            parallel_output=parallel_output,
            window_size=-1,
            cogvideo_stage=1,
        )
        # TODO: check it

    def final_forward(self, logits, **kwargs):
        """Project ``logits`` onto the first 20000 word-embedding rows, in float32."""
        hidden = logits.float()
        embedding_rows = self.transformer.word_embeddings.weight[:20000].float()
        return torch.nn.functional.linear(hidden, embedding_rows)
409
+
410
class InferenceModel_Interpolate(CogVideoCacheModel):
    """Stage-2 cache model (``window_size=10``, ``cogvideo_stage=2``)."""

    def __init__(self, args, transformer=None, parallel_output=True):
        super().__init__(
            args,
            transformer=transformer,
            parallel_output=parallel_output,
            window_size=10,
            cogvideo_stage=2,
        )
        # TODO: check it

    def final_forward(self, logits, **kwargs):
        """Project ``logits`` onto the first 20000 word-embedding rows, in float32."""
        hidden = logits.float()
        embedding_rows = self.transformer.word_embeddings.weight[:20000].float()
        return torch.nn.functional.linear(hidden, embedding_rows)
419
+
420
+ def main(args):
421
+ assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1
422
+ rank_id = args.device % args.parallel_size
423
+ generate_frame_num = args.generate_frame_num
424
+
425
+ if args.stage_1 or args.both_stages:
426
+ model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
427
+ model_stage1.eval()
428
+ if args.both_stages:
429
+ model_stage1 = model_stage1.cpu()
430
+
431
+ if args.stage_2 or args.both_stages:
432
+ model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2')
433
+ model_stage2.eval()
434
+ if args.both_stages:
435
+ model_stage2 = model_stage2.cpu()
436
+
437
+ invalid_slices = [slice(tokenizer.num_image_tokens, None)]
438
+ strategy_cogview2 = CoglmStrategy(invalid_slices,
439
+ temperature=1.0, top_k=16)
440
+ strategy_cogvideo = CoglmStrategy(invalid_slices,
441
+ temperature=args.temperature, top_k=args.top_k,
442
+ temperature2=args.coglm_temperature2)
443
+ if not args.stage_1:
444
+ from sr_pipeline import DirectSuperResolution
445
+ dsr_path = auto_create('cogview2-dsr', path=None) # path=os.getenv('SAT_HOME', '~/.sat_models')
446
+ dsr = DirectSuperResolution(args, dsr_path,
447
+ max_bz=12, onCUDA=False)
448
+
449
+ def process_stage2(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", parent_given_tokens=None, conddir=None, outputdir=None, gpu_rank=0, gpu_parallel_size=1):
450
+ stage2_starttime = time.time()
451
+ use_guidance = args.use_guidance_stage2
452
+ if args.both_stages:
453
+ move_start_time = time.time()
454
+ logging.debug("moving stage-2 model to cuda")
455
+ model = model.cuda()
456
+ logging.debug("moving in stage-2 model takes time: {:.2f}".format(time.time()-move_start_time))
457
+
458
+ try:
459
+ if parent_given_tokens is None:
460
+ assert conddir is not None
461
+ parent_given_tokens = torch.load(os.path.join(conddir, 'frame_tokens.pt'), map_location='cpu')
462
+ sample_num_allgpu = parent_given_tokens.shape[0]
463
+ sample_num = sample_num_allgpu // gpu_parallel_size
464
+ assert sample_num * gpu_parallel_size == sample_num_allgpu
465
+ parent_given_tokens = parent_given_tokens[gpu_rank*sample_num:(gpu_rank+1)*sample_num]
466
+ except:
467
+ logging.critical("No frame_tokens found in interpolation, skip")
468
+ return False
469
+
470
+ # CogVideo Stage2 Generation
471
+ while duration >= 0.5: # TODO: You can change the boundary to change the frame rate
472
+ parent_given_tokens_num = parent_given_tokens.shape[1]
473
+ generate_batchsize_persample = (parent_given_tokens_num-1)//2
474
+ generate_batchsize_total = generate_batchsize_persample * sample_num
475
+ total_frames = generate_frame_num
476
+ frame_len = 400
477
+ enc_text = tokenizer.encode(seq_text)
478
+ enc_duration = tokenizer.encode(str(float(duration))+"秒")
479
+ seq = enc_duration + [tokenizer['<n>']] + enc_text + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
480
+ text_len = len(seq) - frame_len*generate_frame_num - 1
481
+
482
+ logging.info("[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(int(4/duration), tokenizer.decode(enc_text)))
483
+
484
+ # generation
485
+ seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1)
486
+ for sample_i in range(sample_num):
487
+ for i in range(generate_batchsize_persample):
488
+ seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i]
489
+ seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1]
490
+ seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2]
491
+
492
+ if use_guidance:
493
+ guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(video_guidance_text) + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
494
+ guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1
495
+ guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1)
496
+ for sample_i in range(sample_num):
497
+ for i in range(generate_batchsize_persample):
498
+ guider_seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i]
499
+ guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1]
500
+ guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2]
501
+ video_log_text_attention_weights = 0
502
+ else:
503
+ guider_seq=None
504
+ guider_text_len=0
505
+ video_log_text_attention_weights = 1.4
506
+
507
+ mbz = args.max_inference_batch_size
508
+
509
+ assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
510
+ output_list = []
511
+ start_time = time.time()
512
+ for tim in range(max(generate_batchsize_total // mbz, 1)):
513
+ input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
514
+ guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
515
+ output_list.append(
516
+ my_filling_sequence(model, args, input_seq,
517
+ batch_size=min(generate_batchsize_total, mbz),
518
+ get_masks_and_position_ids=get_masks_and_position_ids_stage2,
519
+ text_len=text_len, frame_len=frame_len,
520
+ strategy=strategy_cogview2,
521
+ strategy2=strategy_cogvideo,
522
+ log_text_attention_weights=video_log_text_attention_weights,
523
+ mode_stage1=False,
524
+ guider_seq=guider_seq2,
525
+ guider_text_len=guider_text_len,
526
+ guidance_alpha=args.guidance_alpha,
527
+ limited_spatial_channel_mem=True,
528
+ )[0]
529
+ )
530
+ logging.info("Duration {:.2f}, Taken time {:.2f}\n".format(duration, time.time() - start_time))
531
+
532
+ output_tokens = torch.cat(output_list, dim=0)
533
+ output_tokens = output_tokens[:, text_len+1:text_len+1+(total_frames)*400].reshape(sample_num, -1, 400*total_frames)
534
+ output_tokens_merge = torch.cat((output_tokens[:, :, :1*400],
535
+ output_tokens[:, :, 400*3:4*400],
536
+ output_tokens[:, :, 400*1:2*400],
537
+ output_tokens[:, :, 400*4:(total_frames)*400]), dim=2).reshape(sample_num, -1, 400)
538
+
539
+ output_tokens_merge = torch.cat((output_tokens_merge, output_tokens[:, -1:, 400*2:3*400]), dim=1)
540
+ duration /= 2
541
+ parent_given_tokens = output_tokens_merge
542
+
543
+ if args.both_stages:
544
+ move_start_time = time.time()
545
+ logging.debug("moving stage 2 model to cpu")
546
+ model = model.cpu()
547
+ torch.cuda.empty_cache()
548
+ logging.debug("moving out model2 takes time: {:.2f}".format(time.time()-move_start_time))
549
+
550
+ logging.info("CogVideo Stage2 completed. Taken time {:.2f}\n".format(time.time() - stage2_starttime))
551
+
552
+ # decoding
553
+ # imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
554
+ # os.makedirs(output_dir_full_path, exist_ok=True)
555
+ # my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False)
556
+ # torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt'))
557
+ # os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2")
558
+
559
+ # direct super-resolution by CogView2
560
+ logging.info("[Direct super-resolution]")
561
+ dsr_starttime = time.time()
562
+ enc_text = tokenizer.encode(seq_text)
563
+ frame_num_per_sample = parent_given_tokens.shape[1]
564
+ parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
565
+ text_seq = torch.cuda.LongTensor(enc_text, device=args.device).unsqueeze(0).repeat(parent_given_tokens_2d.shape[0], 1)
566
+ sred_tokens = dsr(text_seq, parent_given_tokens_2d)
567
+ decoded_sr_videos = []
568
+
569
+ for sample_i in range(sample_num):
570
+ decoded_sr_imgs = []
571
+ for frame_i in range(frame_num_per_sample):
572
+ decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:])
573
+ decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480)))
574
+ decoded_sr_videos.append(decoded_sr_imgs)
575
+
576
+ for sample_i in range(sample_num):
577
+ my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
578
+ os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
579
+
580
+ logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime))
581
+
582
+ return True
583
+
584
+
585
def process_stage1(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1):
    """Run stage 1: generate a first frame, then fill in the remaining frames.

    The first frame is generated from ``seq_text + image_text_suffix``; its
    400 image tokens are then written into the beginning of a longer sequence
    (duration prompt + video text + ``generate_frame_num`` frames) which is
    completed by a second round of ``my_filling_sequence`` calls.

    Args:
        model: stage-1 model. When ``args.both_stages`` is set it is moved to
            CUDA on entry and back to the CPU before decoding.
        seq_text: prompt used for the first frame.
        duration: number formatted as ``str(float(duration)) + "秒"`` and
            prepended to the video prompt; ``4 / duration`` is logged as the
            frame rate.
        video_raw_text: prompt for the subsequent frames; defaults to
            ``seq_text`` when ``None``.
        video_guidance_text: prompt of the guidance sequence (only used when
            ``args.use_guidance_stage1`` is set).
        image_text_suffix: text appended to ``seq_text`` for the first frame.
        outputdir: when not ``None``, decoded frames, one GIF per sample and
            ``frame_tokens.pt`` are written below this directory.
        batch_size: number of samples; must be smaller than, or a multiple
            of, the maximum inference batch size.

    Returns:
        CPU tensor holding the generated image tokens, reshaped to
        ``(-1, generate_frame_num, 400)``.
    """
    import shlex  # local import: only needed to quote paths for the gifmaker command

    process_start_time = time.time()
    use_guide = args.use_guidance_stage1
    if args.both_stages:
        move_start_time = time.time()
        logging.debug("moving stage 1 model to cuda")
        model = model.cuda()
        logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time))

    if video_raw_text is None:
        video_raw_text = seq_text
    mbz = args.stage1_max_inference_batch_size if args.stage1_max_inference_batch_size > 0 else args.max_inference_batch_size
    assert batch_size < mbz or batch_size % mbz == 0
    frame_len = 400
    num_chunks = max(batch_size // mbz, 1)
    chunk_size = min(batch_size, mbz)

    # ---- first frame -------------------------------------------------------
    enc_text = tokenizer.encode(seq_text+image_text_suffix)
    # -1 marks the image-token positions that are to be generated.
    seq_1st = enc_text + [tokenizer['<start_of_image>']] + [-1]*frame_len
    logging.info("[Generating First Frame with CogView2]Raw text: {:s}".format(tokenizer.decode(enc_text)))
    text_len_1st = len(seq_1st) - frame_len*1 - 1

    seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
    output_list_1st = []
    for tim in range(num_chunks):
        start_time = time.time()
        output_list_1st.append(
            my_filling_sequence(model, args, seq_1st.clone(),
                                batch_size=chunk_size,
                                get_masks_and_position_ids=get_masks_and_position_ids_stage1,
                                text_len=text_len_1st,
                                frame_len=frame_len,
                                strategy=strategy_cogview2,
                                strategy2=strategy_cogvideo,
                                log_text_attention_weights=1.4,
                                enforce_no_swin=True,
                                mode_stage1=True,
                                )[0]
        )
        logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
    output_tokens_1st = torch.cat(output_list_1st, dim=0)
    # given_tokens.shape: [bs, frame_num, 400]
    given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+1+frame_len].unsqueeze(1)

    # ---- subsequent frames -------------------------------------------------
    total_frames = generate_frame_num
    enc_duration = tokenizer.encode(str(float(duration))+"秒")
    if use_guide:
        video_raw_text = video_raw_text + " 视频"
    enc_text_video = tokenizer.encode(video_raw_text)
    image_placeholder = [tokenizer['<start_of_image>']] + [-1]*frame_len*generate_frame_num
    seq = enc_duration + [tokenizer['<n>']] + enc_text_video + image_placeholder
    guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(video_guidance_text) + image_placeholder
    logging.info("[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(4/duration, tokenizer.decode(enc_text_video)))

    text_len = len(seq) - frame_len*generate_frame_num - 1
    # guider_text_len is forwarded to my_filling_sequence even without guidance.
    guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1
    seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(batch_size, 1)
    if use_guide:
        guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(batch_size, 1)
        video_log_text_attention_weights = 0
    else:
        guider_seq = None
        video_log_text_attention_weights = 1.4

    # Copy the already generated frame(s) into the image slots of the sequences.
    for given_frame_id in range(given_tokens.shape[1]):
        frame_tokens = given_tokens[:, given_frame_id]
        lo, hi = given_frame_id*frame_len, (given_frame_id+1)*frame_len
        seq[:, text_len+1+lo:text_len+1+hi] = frame_tokens
        if guider_seq is not None:
            guider_seq[:, guider_text_len+1+lo:guider_text_len+1+hi] = frame_tokens

    output_list = []
    for tim in range(num_chunks):
        # Slicing past the end is clamped, so this also covers batch_size < mbz.
        rows = slice(mbz*tim, mbz*(tim+1))
        input_seq = seq[rows].clone()
        guider_seq2 = guider_seq[rows].clone() if guider_seq is not None else None
        output_list.append(
            my_filling_sequence(model, args, input_seq,
                                batch_size=chunk_size,
                                get_masks_and_position_ids=get_masks_and_position_ids_stage1,
                                text_len=text_len, frame_len=frame_len,
                                strategy=strategy_cogview2,
                                strategy2=strategy_cogvideo,
                                log_text_attention_weights=video_log_text_attention_weights,
                                guider_seq=guider_seq2,
                                guider_text_len=guider_text_len,
                                guidance_alpha=args.guidance_alpha,
                                limited_spatial_channel_mem=True,
                                mode_stage1=True,
                                )[0]
        )

    # Drop the text prefix and the <start_of_image> token.
    output_tokens = torch.cat(output_list, dim=0)[:, 1+text_len:]

    if args.both_stages:
        move_start_time = time.time()
        logging.debug("moving stage 1 model to cpu")
        model = model.cpu()
        torch.cuda.empty_cache()
        # The model is being moved *out* of the GPU here (message used to say "in").
        logging.debug("moving out model1 takes time: {:.2f}".format(time.time()-move_start_time))

    # ---- decoding ----------------------------------------------------------
    imgs = []
    for sample_tokens in output_tokens:
        token_list = sample_tokens.tolist()  # convert once per sample, not once per frame
        imgs.append([
            torch.nn.functional.interpolate(
                tokenizer.decode(image_ids=token_list[i*frame_len:(i+1)*frame_len]),
                size=(480, 480))
            for i in range(total_frames)
        ])

    assert len(imgs) == batch_size
    save_tokens = output_tokens[:, :total_frames*frame_len].reshape(-1, total_frames, frame_len).cpu()
    if outputdir is not None:
        for clip_i, clip_imgs in enumerate(imgs):
            my_save_multiple_images(clip_imgs, outputdir, subdir=f"frames/{clip_i}", debug=False)
            # outputdir is derived from the user prompt: quote it so quotes or shell
            # metacharacters cannot break (or inject into) the command. The glob
            # stays outside the quotes so the shell still expands it.
            frames_dir = shlex.quote(os.path.join(outputdir, "frames", str(clip_i)))
            gif_path = shlex.quote(os.path.join(outputdir, f"{clip_i}.gif"))
            os.system(f"gifmaker -i {frames_dir}/0*.jpg -o {gif_path} -d 0.25")
        torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt'))

    logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time))

    return save_tokens
700
+
701
+ # ======================================================================================================
702
+
703
+ if args.stage_1 or args.both_stages:
704
+ if args.input_source != "interactive":
705
+ with open(args.input_source, 'r') as fin:
706
+ promptlist = fin.readlines()
707
+ promptlist = [p.strip() for p in promptlist]
708
+ else:
709
+ promptlist = None
710
+
711
+ now_qi = -1
712
+ while True:
713
+ now_qi += 1
714
+
715
+ if promptlist is not None: # with input-source
716
+ if args.multi_gpu:
717
+ if now_qi % dist.get_world_size() != dist.get_rank():
718
+ continue
719
+ rk = dist.get_rank()
720
+ else:
721
+ rk = 0
722
+ raw_text = promptlist[now_qi]
723
+ raw_text = raw_text.strip()
724
+ print(f'Working on Line No. {now_qi} on {rk}... [{raw_text}]')
725
+ else: # interactive
726
+ raw_text = input("\nPlease Input Query (stop to exit) >>> ")
727
+ raw_text = raw_text.strip()
728
+ if not raw_text:
729
+ print('Query should not be empty!')
730
+ continue
731
+ if raw_text == "stop":
732
+ return
733
+
734
+ try:
735
+ path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
736
+ parent_given_tokens = process_stage1(model_stage1, raw_text, duration=4.0, video_raw_text=raw_text, video_guidance_text="视频",
737
+ image_text_suffix=" 高清摄影",
738
+ outputdir=path if args.stage_1 else None, batch_size=args.batch_size)
739
+ if args.both_stages:
740
+ process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
741
+ video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
742
+ outputdir=path,
743
+ gpu_rank=0, gpu_parallel_size=1) # TODO: 修改
744
+ except (ValueError, FileNotFoundError) as e:
745
+ print(e)
746
+ continue
747
+
748
+ elif args.stage_2:
749
+ sample_dirs = os.listdir(args.output_path)
750
+ for sample in sample_dirs:
751
+ raw_text = sample.split('_')[-1]
752
+ path = os.path.join(args.output_path, sample, 'Interp')
753
+ parent_given_tokens = torch.load(os.path.join(args.output_path, sample, "frame_tokens.pt"))
754
+
755
+ process_stage2(raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
756
+ video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
757
+ outputdir=path,
758
+ gpu_rank=0, gpu_parallel_size=1) # TODO: 修改
759
+
760
+ else:
761
+ assert False
762
+
763
+
764
if __name__ == "__main__":
    # Script entry point: parse the pipeline-specific options, merge them with
    # the options understood by get_args(), then run main() without gradients.
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)

    py_parser = argparse.ArgumentParser(add_help=False)

    # (flag, add_argument keyword arguments), registered in this order.
    pipeline_options = (
        ('--generate-frame-num', dict(type=int, default=5)),
        ('--coglm-temperature2', dict(type=float, default=0.89)),
        # Disabled options kept for reference:
        #   --interp-duration (float, default -1): -1 = sequential generation,
        #       0 = super-resolution, 0.5/1/2 = frame interpolation
        #   --total-duration (float, default 4.0): total duration
        ('--use-guidance-stage1', dict(action='store_true')),
        ('--use-guidance-stage2', dict(action='store_true')),
        ('--guidance-alpha', dict(type=float, default=3.0)),
        ('--stage-1', dict(action='store_true')),      # stage 1: sequential generation
        ('--stage-2', dict(action='store_true')),      # stage 2: interp + dsr
        ('--both-stages', dict(action='store_true')),  # stage 1&2: sequential generation; interp + dsr
        ('--parallel-size', dict(type=int, default=1)),
        ('--stage1-max-inference-batch-size', dict(type=int, default=-1)),  # -1: use max-inference-batch-size
        ('--multi-gpu', dict(action='store_true')),
    )
    for flag, spec in pipeline_options:
        py_parser.add_argument(flag, **spec)

    CogVideoCacheModel.add_model_specific_args(py_parser)

    # Options unknown to py_parser are handed over to get_args().
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    # Keyword expansion raises if both namespaces define the same option.
    args = argparse.Namespace(**vars(args), **vars(known))
    args.layout = list(map(int, args.layout.split(',')))
    args.do_train = False

    torch.cuda.set_device(args.device)

    with torch.no_grad():
        main(args)
models/cogvideo_cache_model.py ADDED
@@ -0,0 +1,695 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : cogvideo_cache_model.py
4
+ @Time : 2022/07/15 11:22:19
5
+ @Author : Wenyi Hong
6
+ @Version : 1.0
7
+ @Contact : hwy22@mails.tsinghua.edu.cn
8
+ '''
9
+
10
+ # here put the import lib
11
+
12
+ from multiprocessing import context
13
+ from tkinter import E
14
+ import torch
15
+ from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
16
+
17
+ from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
18
+ from SwissArmyTransformer.model.transformer import unscaled_init_method
19
+ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
20
+ import torch.nn.functional as F
21
+ from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
22
+ import math
23
+
24
+
25
class PositionEmbeddingMixin(BaseMixin):
    """Mixin that owns an additional position-embedding table.

    The table has ``additional_sequence_length`` rows of width
    ``hidden_size`` and is initialised from a normal distribution;
    ``reinit`` overwrites it with a slice of the transformer's own
    position embeddings.
    """

    def __init__(self, additional_sequence_length, hidden_size,
                 init_method_std=0.02, reinit_slice=slice(512, 912),
                 ):
        super().__init__()
        self.reinit_slice = reinit_slice
        table = torch.nn.Embedding(additional_sequence_length, hidden_size)
        self.position_embeddings = table
        torch.nn.init.normal_(table.weight, mean=0.0, std=init_method_std)

    def reinit(self, parent_model=None):
        """Copy ``transformer.position_embeddings[reinit_slice]`` into this table.

        The table is viewed as ``(-1, slice_len, hidden_size)`` so the copied
        slice is broadcast over every ``slice_len``-row group of the table.
        """
        source_rows = self.transformer.position_embeddings.weight.data[self.reinit_slice]
        slice_len, width = source_rows.shape
        assert width == self.position_embeddings.weight.shape[-1]
        target = self.position_embeddings.weight.data
        target.view(-1, slice_len, width).copy_(source_rows)
39
+
40
+
41
def window_partition(x, window_size):
    """Cut every frame of ``x`` into non-overlapping square windows.

    Args:
        x: (B, framenum, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, frame_num, window_size, window_size, C),
        ordered by batch, then window row, then window column.
    """
    batch, frames, height, width, channels = x.shape
    grid_rows = height // window_size
    grid_cols = width // window_size
    tiled = x.view(batch, frames, grid_rows, window_size, grid_cols, window_size, channels)
    # Move the two window-grid axes in front of the frame axis so that each
    # window (with all its frames) becomes one leading-dimension entry.
    tiled = tiled.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
    return tiled.view(-1, frames, window_size, window_size, channels)
53
+
54
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partition`` into full frames.

    Args:
        windows: (num_windows*B, frame_num, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, frame_num, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    B = int(windows.shape[0] / windows_per_image)
    frames = windows.shape[1]
    grid = windows.view(B, H // window_size, W // window_size,
                        frames, window_size, window_size, -1)
    # Interleave window-grid and in-window axes again: (B, frame, row, i, col, j, C).
    grid = grid.permute(0, 3, 1, 4, 2, 5, 6).contiguous()
    return grid.view(B, frames, H, W, -1)
69
+
70
class WindowAttentionMixin(BaseMixin):
    """Mixin providing windowed attention over frame tokens.

    Each frame is a ``frame_resolution x frame_resolution`` grid of tokens
    that is cut into ``window_size x window_size`` windows; a token attends
    to tokens of the same window (across frames), optionally after a cyclic
    shift of the grid by ``shift_size``. Two entry points are offered:
    ``attention_extra_NAR_inference`` processes whole frames at once,
    ``attention_extra_AR_inference`` processes a single new token against a
    key/value memory.
    """

    def __init__(self, num_layers,
                hidden_size,
                frame_resolution,
                window_size,
                shift_size,
                n_head,
                frame_num,
                init_method=unscaled_init_method(0.02),
                output_layer_init_method=unscaled_init_method(0.02),
                time_dim_attend_length=0
        ):
        """Create the per-layer projections and precompute the attention masks.

        Args:
            num_layers: number of layers this mixin provides projections for.
            hidden_size: model width; the qkv projection maps it to 3x.
            frame_resolution: side length (in tokens) of a square frame.
            window_size: side length of a window; must divide frame_resolution.
            shift_size: cyclic shift used by every second layer,
                ``0 < shift_size < window_size``.
            n_head: number of attention heads.
            frame_num: number of frames the precomputed masks cover.
            init_method: initializer for the qkv projections.
            output_layer_init_method: initializer for the dense projections.
            time_dim_attend_length: when > 0, limits how many frames the
                single-token path looks back over; 0 means all previous frames.
        """
        super(WindowAttentionMixin, self).__init__()
        self.num_layers = num_layers # replace attention in the LAST n layers
        self.query_key_value = torch.nn.ModuleList(
            [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
                gather_output=False,init_method=init_method)
                for layer_id in range(num_layers)
            ])
        # Output projections; not referenced by the methods of this class.
        self.dense = torch.nn.ModuleList(
            [RowParallelLinear(
                hidden_size,
                hidden_size,
                input_is_parallel=True,
                init_method=output_layer_init_method,
                bias=True,
                module=self,
                name="dense")
                for layer_id in range(num_layers)
            ])

        self.n_head = n_head
        self.window_size = window_size
        self.frame_resolution = frame_resolution
        self.frame_len = frame_resolution * frame_resolution
        self.time_dim_attend_length = time_dim_attend_length
        assert frame_resolution % window_size == 0
        assert 0 < shift_size < window_size
        # NOTE: nW is computed here but not used in the rest of __init__.
        nW = (self.frame_resolution // self.window_size) ** 2
        ws_squre = self.window_size * self.window_size

        # Layers are shifted according to self.shift_sizes[layer_id % 2]:
        # even layer_id -> no shift, odd layer_id -> shift by shift_size.
        # Label the four regions created by the shift (0..3) so that, inside a
        # window, only positions carrying the same label may attend to each other.
        img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
        h_slices = (slice(0, -shift_size),
                    slice(-shift_size, None))
        w_slices = (slice(0, -shift_size),
                    slice(-shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, :, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)  # nW, 1, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
        # Both masked_fill calls test the label *difference* (the name is only
        # rebound afterwards): different labels -> 0.0, equal labels -> 1.0.
        sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
        # Tile the per-frame mask over frame_num x frame_num frame pairs, then
        # keep the lower triangle so a position never attends to a later one.
        attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
        attn_mask = attn_mask.tril()

        # Mask for the non-shifted layers: plain lower-triangular over all
        # positions of one window across all frames.
        causal_mask = torch.ones(ws_squre*frame_num, ws_squre*frame_num)
        causal_mask = causal_mask.tril()

        self.shift_sizes = [0, shift_size]
        # Kept as plain tensor attributes; they are moved to the right
        # device/dtype on the first attention_extra_NAR_inference call.
        self.attn_mask = attn_mask
        self.causal_mask = causal_mask
        self.mask_initialized = False

        # One zero-initialised vector per layer; not referenced by the methods
        # of this class.
        self.attn_distribution = torch.nn.ParameterList([
            torch.nn.Parameter(torch.zeros(hidden_size))
            for _ in range(num_layers)
        ])

    def reinit(self, *pre_mixins):
        """Copy the qkv weights and biases of the transformer's last
        ``num_layers`` attention layers into this mixin's qkv projections."""
        start_layer = len(self.transformer.layers) - self.num_layers
        assert start_layer >= 0
        for layer_id in range(self.num_layers):
            old_attention = self.transformer.layers[start_layer + layer_id].attention
            self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
            self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)

    def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
        """Windowed attention over all tokens of one or more complete frames.

        Args:
            frame_hidden_state: [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
            layer_id: selects the qkv projection and, via ``layer_id % 2``,
                whether the grid is cyclically shifted.
            attn_dropout: optional callable applied to the attention
                probabilities inside the CUDA RNG tracker fork.
            memkv_text: optional text keys/values concatenated on the last
                axis (first half keys, second half values); when given, frame
                tokens additionally attend to them.
            stage: the masks are applied only when ``stage == 1``; for
                ``stage == 2`` exactly 3 frames are required.

        Returns:
            (ret_context, ret_mem): the attention output
            ``[batchsize, s1, hidden]`` and the keys/values of all frame
            tokens concatenated on the last axis ``[batchsize, s1, 2*hidden]``,
            both in the original (un-shifted, un-windowed) token order.
        """
        # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
        if not self.mask_initialized:
            # Lazily move the precomputed masks to the input's device/dtype.
            self.attn_mask = self.attn_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
            self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
            self.mask_initialized = True
        b0, s1, h0 = frame_hidden_state.shape
        h = h0 // self.n_head
        frame_len = self.frame_resolution * self.frame_resolution
        frame_num = s1 // frame_len
        if stage == 2:
            assert frame_num == 3
        assert frame_num*frame_len == s1
        wind_square = self.window_size * self.window_size
        nW = frame_len // wind_square
        bswin = b0 * nW

        if memkv_text is not None:
            # Split text memory into keys/values and bring them to
            # [b0, n_head, s0, h].
            s0 = memkv_text.shape[-2]
            k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
            v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)

        # shift
        frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
        if self.shift_sizes[layer_id%2] > 0:
            frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
        # window partition
        frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
        qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
            .permute(2, 0, 3, 1, 4)  #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))

        if stage == 1:
            # Masked scores become (0 - 10000), i.e. effectively excluded by
            # the softmax below.
            if self.shift_sizes[layer_id%2] > 0:
                # Shifted layer: per-window mask (region labels + lower triangle).
                attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square),
                                self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))\
                        - 10000.0 * (1.0 - self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))
                attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
            else:
                # Non-shifted layer: plain lower-triangular mask.
                attn = torch.mul(attn, self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))\
                        - 10000.0 * (1.0 - self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))

        if memkv_text is None:
            attn = F.softmax(attn, dim=-1)
            if attn_dropout is not None:
                with get_cuda_rng_tracker().fork():
                    attn = attn_dropout(attn)
            context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
        else:
            # Frame-to-text scores are appended after the frame-to-frame
            # scores so that one softmax normalises over both.
            attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
            attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
            attn = torch.cat((attn, attn_frame2text), dim=-1)
            attn = F.softmax(attn, dim=-1)

            if attn_dropout is not None:
                with get_cuda_rng_tracker().fork():
                    attn = attn_dropout(attn)

            # The last s0 columns are the text part, everything before is the
            # frame part. NOTE(review): the negative slicing assumes s0 > 0.
            context_swin = (torch.matmul(attn[..., :-s0], v) +
                            torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
                                .reshape(bswin, self.n_head, frame_num*wind_square, h))\
                .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)

        context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)

        # reverse cycle shift
        if self.shift_sizes[layer_id%2] > 0:
            context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
        ret_context = context_swin.reshape(b0, s1, h0)

        # for mem
        # Keys/values are un-windowed and un-shifted the same way as the
        # context so the returned memory is in original token order.
        memk = k.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
        memv = v.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
        memk = window_reverse(memk, self.window_size, self.frame_resolution, self.frame_resolution)
        memv = window_reverse(memv, self.window_size, self.frame_resolution, self.frame_resolution)
        if self.shift_sizes[layer_id%2] > 0:
            memk = torch.roll(memk, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
            memv = torch.roll(memv, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
        memk, memv = memk.reshape(b0, s1, h0), memv.reshape(b0, s1, h0)

        ret_mem = torch.cat((memk, memv), dim=-1)
        return ret_context, ret_mem

    def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
        """Windowed attention for a single new token against a key/value memory.

        Args:
            frame_hidden_state: [batchsize, 1, n_head*hiddensize_perhead]
            memkv: [batchsize, pos, hidden_size*2] keys/values of earlier frame
                tokens (frames only), or ``None`` if there are none.
            pos: position of the current token within the frame tokens.
            layer_id: selects the qkv projection and the shift.
            log_text_attention_weights: value added to the scores of the text
                positions before the softmax.
            attn_dropout: accepted but not used in this method.
            memkv_text: optional text keys/values; when given the token also
                attends to them.
            stage: accepted but not used in this method.

        Returns:
            (context, kv): the attention output ``[batchsize, 1, hidden]`` and
            the current token's key and value concatenated on the last axis.
        """
        # frame_hidden_state [batchsize, 1, n_head*hiddensize_perhead]
        # memkv [batchsize, pos, hidden_size*2] (include frames only)
        # if memkv_text is not None: will attend to text
        # pos: token's pos
        b0, sin, h0 = frame_hidden_state.shape
        h = h0 // self.n_head
        assert sin == 1
        this_qkv = self.query_key_value[layer_id](frame_hidden_state)
        thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
        s1 = memkv.shape[1] if memkv is not None else 0
        frame_len = self.frame_resolution * self.frame_resolution
        # Number of frames the memory's length covers completely.
        frame_num_before = s1 // frame_len


        if memkv is not None:
            pos_inframe = pos - frame_num_before * frame_len

            xpos = pos_inframe // self.frame_resolution # pos = xpos*self.frame_resolution + ypos
            ypos = pos_inframe % self.frame_resolution
            # [start, end)
            if self.shift_sizes[layer_id%2] > 0:
                # Window bounds on the shifted grid, mapped back to un-shifted
                # coordinates and clamped to the frame, so border windows get
                # smaller instead of wrapping around.
                xstart = ((xpos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
                ystart = ((ypos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
                xend = xstart + self.window_size
                yend = ystart + self.window_size
                xstart, ystart = max(0, xstart), max(0, ystart)
                xend, yend = min(xend, self.frame_resolution), min(yend, self.frame_resolution)
            else:
                xstart = (xpos // self.window_size) * self.window_size
                ystart = (ypos // self.window_size) * self.window_size
                xend, yend = xstart + self.window_size, ystart+self.window_size

            # select index
            selected_index = list()
            if frame_num_before > 0:
                # frames before
                frame_attended_start = max(0, frame_num_before-self.time_dim_attend_length+1) if self.time_dim_attend_length > 0 else 0
                for x in range(xstart, xend):
                    for y in range(ystart, yend):
                        selected_index.append(x*self.frame_resolution+y+frame_len*frame_attended_start)
                cnt_per_frame = len(selected_index)
                # Repeat the same window positions for each following complete
                # frame by offsetting the entry one frame back by frame_len.
                for _ in range((frame_num_before-frame_attended_start-1)*cnt_per_frame):
                    selected_index.append(selected_index[-cnt_per_frame]+frame_len)

            # the last frame
            # Only positions strictly before the current token are taken; the
            # break leaves the inner (y) loop only.
            for x in range(xstart, xend):
                for y in range(ystart, yend):
                    tmppos = x*self.frame_resolution+y + frame_num_before * frame_len
                    if tmppos < pos:
                        selected_index.append(tmppos)
                    else:
                        break
            # +1 accounts for the current token's own key/value.
            cnt_all = len(selected_index)+1
            selected_index = torch.tensor(selected_index, device=memkv.device)
            used_memkv = torch.index_select(memkv, 1, selected_index)
            used_k, used_v = used_memkv[..., :h0], used_memkv[..., h0:]
            used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
            used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
            if memkv_text is not None:
                # Text keys/values go first so their scores occupy the leading
                # columns (see the log_text_attention_weights offset below).
                cnt_all += memkv_text.shape[-2]
                used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
                used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
            used_k = used_k.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
            used_v = used_v.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
        else:
            # No frame memory: the token only sees itself (and the text, if any).
            used_k = thisk
            used_v = thisv

            if memkv_text is not None:
                used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
                used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
                used_k = used_k.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
                used_v = used_v.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
            else:
                used_k = used_k.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
                used_v = used_v.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)

        thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
        attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
        if memkv_text is not None:
            # Bias the scores of the (leading) text positions before the softmax.
            attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
        attn = F.softmax(attn, dim=-1)
        context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)

        # The second element is this token's key and value, for the caller's memory.
        return context_swin, this_qkv[..., h0:]
320
+
321
+ class FullAttentionMixin(BaseMixin):
322
def __init__(self, num_layers,
            hidden_size,
            frame_resolution,
            n_head,
            frame_num,
            init_method=unscaled_init_method(0.02),
            output_layer_init_method=unscaled_init_method(0.02),
            **kwargs,
    ):
    """Create per-layer qkv / dense projections for full attention.

    Args:
        num_layers: number of layers this mixin provides projections for.
        hidden_size: model width; the qkv projection maps it to 3x.
        frame_resolution: side length (in tokens) of a square frame.
        n_head: number of attention heads.
        frame_num: accepted but not used in this constructor.
        init_method: initializer for the qkv projections.
        output_layer_init_method: initializer for the dense projections.
        **kwargs: accepted and ignored.
    """
    super(FullAttentionMixin, self).__init__()
    self.num_layers = num_layers  # replace attention in the LAST n layers

    # Build all qkv projections first, then all dense projections, so the
    # order in which parameters are created stays the same.
    qkv_projections = [
        ColumnParallelLinear(hidden_size, 3*hidden_size, stride=3,
                             gather_output=False, init_method=init_method)
        for _ in range(num_layers)
    ]
    self.query_key_value = torch.nn.ModuleList(qkv_projections)

    dense_projections = [
        RowParallelLinear(hidden_size,
                          hidden_size,
                          input_is_parallel=True,
                          init_method=output_layer_init_method,
                          bias=True,
                          module=self,
                          name="dense")
        for _ in range(num_layers)
    ]
    self.dense = torch.nn.ModuleList(dense_projections)

    self.n_head = n_head
    self.frame_resolution = frame_resolution
    self.frame_len = frame_resolution * frame_resolution

    # One zero-initialised vector of size hidden_size per layer.
    per_layer_vectors = [torch.nn.Parameter(torch.zeros(hidden_size))
                         for _ in range(num_layers)]
    self.attn_distribution = torch.nn.ParameterList(per_layer_vectors)
358
+
359
def reinit(self, *pre_mixins):
    """Initialise the qkv projections from the transformer's last layers.

    The weight and bias of ``attention.query_key_value`` of each of the last
    ``self.num_layers`` transformer layers are copied into the projection
    with the same relative index. ``pre_mixins`` is not used.
    """
    first_replaced = len(self.transformer.layers) - self.num_layers
    assert first_replaced >= 0
    for offset in range(self.num_layers):
        source = self.transformer.layers[first_replaced + offset].attention.query_key_value
        target = self.query_key_value[offset]
        target.weight.data.copy_(source.weight.data)
        target.bias.data.copy_(source.bias.data)
366
+
367
+
368
def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
    """Full (non-windowed) attention over all tokens of complete frames.

    Every frame token attends to itself and all earlier frame tokens
    (lower-triangular mask) and, when ``memkv_text`` is given, to all text
    positions.

    Args:
        frame_hidden_state: [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
        layer_id: index of the qkv projection to use.
        attn_dropout: optional callable applied to the attention
            probabilities inside the CUDA RNG tracker fork.
        memkv_text: optional text keys/values concatenated on the last axis
            (first half keys, second half values). An empty text memory
            (zero positions) is handled like ``None``.
        stage: must be 1.

    Returns:
        (context, ret_mem): attention output ``[batchsize, s1, hidden]`` and
        the frame tokens' keys/values concatenated on the last axis
        ``[batchsize, s1, 2*hidden]``.
    """
    assert stage == 1

    b0, s1, h0 = frame_hidden_state.shape
    h = h0 // self.n_head
    frame_len = self.frame_resolution * self.frame_resolution
    frame_num = s1 // frame_len
    assert frame_num*frame_len == s1

    if memkv_text is not None:
        # Split text memory into keys/values shaped [b0, n_head, s0, h].
        s0 = memkv_text.shape[-2]
        k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
        v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)

    qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
        .permute(2, 0, 3, 1, 4)  # [3, b0, n_head, s1, h]
    q, k, v = qkv[0], qkv[1], qkv[2]
    scaled_q = q / math.sqrt(h)
    attn = torch.matmul(scaled_q, k.transpose(-1, -2))
    # Causal penalty: a single (s1, s1) matrix broadcast over batch and heads
    # instead of a full (b0, n_head, s1, s1) tensor; the values are the same.
    future_penalty = 10000.0 * (1.0 - torch.ones(s1, s1, device=attn.device, dtype=attn.dtype).tril())
    attn = attn - future_penalty

    if memkv_text is None:
        attn = F.softmax(attn, dim=-1)
        if attn_dropout is not None:
            with get_cuda_rng_tracker().fork():
                attn = attn_dropout(attn)
        context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
    else:
        # Frame-to-text scores are appended so one softmax normalises over both.
        attn_frame2text = torch.matmul(scaled_q, k_text.transpose(-1, -2))  # [b0, n_head, s1, s0]
        attn = torch.cat((attn, attn_frame2text), dim=-1)
        attn = F.softmax(attn, dim=-1)
        if attn_dropout is not None:
            with get_cuda_rng_tracker().fork():
                attn = attn_dropout(attn)
        # Split at s1 rather than with negative offsets: `[:-s0]` would be
        # empty (and the matmul would fail) when the text memory is empty.
        frame_part, text_part = attn[..., :s1], attn[..., s1:]
        context_swin = (torch.matmul(frame_part, v) + torch.matmul(text_part, v_text))\
            .permute(0, 2, 1, 3).reshape(b0, s1, h0)

    # for mem: merge the heads back so keys/values are [b0, s1, h0] each.
    memk = k.permute(0, 2, 1, 3).reshape(b0, s1, h0)
    memv = v.permute(0, 2, 1, 3).reshape(b0, s1, h0)
    ret_mem = torch.cat((memk, memv), dim=-1)

    return context_swin, ret_mem
410
+
411
def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
    """Single-token (autoregressive) step of the extra full-attention branch.

    The new token attends to, in key order: the cached text keys
    (``memkv_text``), the cached frame keys (``memkv``), then itself.
    ``pos`` and ``attn_dropout`` are accepted but not used here.

    Returns:
        (context, new_kv): ``context`` is [b0, 1, h0]; ``new_kv`` is this
        token's key and value concatenated on the last dim, [b0, 1, 2*h0].
    """
    b0, sin, h0 = frame_hidden_state.shape
    h = h0 // self.n_head
    assert sin == 1
    assert stage == 1

    this_qkv = self.query_key_value[layer_id](frame_hidden_state)
    query = this_qkv[..., :h0]
    new_key = this_qkv[..., h0:2*h0]
    new_value = this_qkv[..., 2*h0:]

    # Collect key/value pieces front-to-back: text memory, frame memory, self.
    key_parts, value_parts = [], []
    if memkv_text is not None:
        key_parts.append(memkv_text[..., :h0].expand(new_key.shape[0], -1, -1))
        value_parts.append(memkv_text[..., h0:].expand(new_value.shape[0], -1, -1))
    if memkv is not None:
        key_parts.append(memkv[..., :h0].expand(new_key.shape[0], -1, -1))
        value_parts.append(memkv[..., h0:].expand(new_value.shape[0], -1, -1))
    key_parts.append(new_key)
    value_parts.append(new_value)

    keys = torch.cat(key_parts, dim=-2).reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
    values = torch.cat(value_parts, dim=-2).reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
    query = query.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)  # [b0, n_head, 1, h]

    scores = torch.matmul(query / math.sqrt(h), keys.transpose(-1, -2))
    if memkv_text is not None:
        # bias applies only to the leading text positions
        scores[..., :memkv_text.shape[-2]] += log_text_attention_weights
    probs = F.softmax(scores, dim=-1)

    context_swin = torch.matmul(probs, values).permute(0, 2, 1, 3).reshape(b0, 1, h0)
    return context_swin, this_qkv[..., h0:]
443
+
444
+
445
def attention_localframe_and_text_NAR(q0, k0, v0, attention_mask,
                                      n_head, text_len, frame_len, frame_num,
                                      attention_dropout=None, log_text_attention_weights=0, stage=1, **kwargs):
    """Text self-attention plus frame attention to {text, own frame}, whole sequence at once.

    ``q0``/``k0``/``v0`` are [b, text_len + frame_len*frame_num, h0]. Text
    tokens attend to text under ``attention_mask[..., :text_len, :text_len]``.
    Each frame token attends to every text token and to the tokens of its
    own frame; in stage 1 the in-frame scores are masked with
    ``attention_mask[..., text_len:, text_len:]``.
    # NOTE(review): the original comment gives the mask shape as
    # [4, b or 1, 1, text_len+frame_len, text_len+frame_len] - confirm.

    Returns:
        (context_text, context_frame): [b, text_len, h0] and
        [b, frame_len*frame_num, h0]; ``context_frame`` is None when
        ``frame_num`` is not positive.
    """
    b, s0, h0 = q0.shape
    s1 = s0 - text_len
    h = h0 // n_head
    assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
    if stage == 2:
        assert frame_num == 3

    def to_heads(t):
        # [b, s0, h0] -> [b, n_head, s0, h]
        return t.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)

    def maybe_dropout(probs):
        if attention_dropout is None:
            return probs
        with get_cuda_rng_tracker().fork():
            return attention_dropout(probs)

    q_heads = to_heads(q0)
    v_heads = to_heads(v0)
    k_heads = to_heads(k0)
    scaled_q = q_heads / math.sqrt(h)
    text_values = v_heads[..., :text_len, :]

    # scores of every token against the text keys
    to_text = torch.matmul(scaled_q, k_heads.transpose(-1, -2)[..., :text_len])
    to_text += log_text_attention_weights

    # context for text
    text_mask = attention_mask[..., :text_len, :text_len]
    text_scores = torch.mul(to_text[..., :text_len, :], text_mask) - 10000.0 * (1.0 - text_mask)
    text_probs = maybe_dropout(F.softmax(text_scores, dim=-1))
    context_text2text = torch.matmul(text_probs, text_values).transpose(1, 2).reshape(b, text_len, h0)

    if frame_num <= 0:
        return context_text2text, None

    # scores inside each frame
    per_frame = (b, n_head, frame_num, frame_len, h)
    q_frames = scaled_q[:, :, text_len:].reshape(per_frame)
    v_frames = v_heads[:, :, text_len:].reshape(per_frame)
    kT_frames = k_heads[:, :, text_len:].reshape(per_frame).transpose(-1, -2)
    local_scores = torch.matmul(q_frames, kT_frames)
    if stage == 1:
        frame_mask = attention_mask[..., text_len:, text_len:].unsqueeze(1)
        local_scores = torch.mul(local_scores, frame_mask) - 10000.0 * (1.0 - frame_mask)

    # context for frame: softmax jointly over [text keys | own-frame keys]
    frame_scores = torch.cat((to_text[..., text_len:, :],
                              local_scores.view(b, n_head, s1, frame_len)), dim=-1)
    frame_probs = maybe_dropout(F.softmax(frame_scores, dim=-1))
    context_frame2text = torch.matmul(frame_probs[..., :text_len], text_values)  # [b, n_head, s1, h]
    local_probs = frame_probs[..., text_len:text_len+frame_len].view(b, n_head, frame_num, frame_len, frame_len)
    context_frame_local = torch.matmul(local_probs, v_frames).view(b, n_head, s1, h)

    context_frame = (context_frame2text + context_frame_local).transpose(1, 2).reshape(b, s1, h0)
    return context_text2text, context_frame
501
+
502
def attention_localframe_and_text_AR(q0, k0, v0, n_head, text_len, frame_len, frame_num,
                                     attention_dropout=None, log_text_attention_weights=0, layer_id=None, limited_spatial_channel_mem=False, stage=1, **kwargs):
    """One autoregressive step: the new token attends to text and to its own frame.

    ``q0`` is [b, 1, h0]; ``k0``/``v0`` are [b, s0, h0] and already include
    the new token as their last position. With
    ``limited_spatial_channel_mem=True`` the memory is expected to hold only
    {text, current frame}, so all keys are attended to; otherwise keys of
    fully completed earlier frames are skipped. ``frame_num``,
    ``attention_dropout`` and ``layer_id`` are accepted but not used.

    Returns:
        Context tensor of shape [b, 1, h0].
    """
    b, s0, h0 = k0.shape
    # number of complete frames that precede the frame being generated
    frames_done = (s0-text_len-1) // frame_len
    h = h0 // n_head
    assert q0.shape[1] == 1
    assert v0.shape[1] == k0.shape[1]

    query = q0.reshape(b, 1, n_head, h).permute(0, 2, 1, 3)
    values = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
    keys_t = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
    scaled_query = query / math.sqrt(h)

    if limited_spatial_channel_mem:
        assert frames_done == 0
        assert stage == 1  # not implemented for stage-2 yet
        scores = torch.matmul(scaled_query, keys_t)
        scores[..., :text_len] += log_text_attention_weights
        probs = F.softmax(scores, dim=-1)
        return torch.matmul(probs, values).transpose(1, 2).reshape(b, 1, h0)

    frame_start = text_len + frames_done*frame_len
    text_scores = torch.matmul(scaled_query, keys_t[..., :text_len])
    text_scores += log_text_attention_weights
    local_scores = torch.matmul(scaled_query, keys_t[..., frame_start:])
    probs = F.softmax(torch.cat((text_scores, local_scores), dim=-1), dim=-1)

    from_text = torch.matmul(probs[..., :text_len], values[..., :text_len, :])
    from_frame = torch.matmul(probs[..., text_len:], values[:, :, frame_start:, :])
    return (from_text + from_frame).transpose(1, 2).reshape(b, 1, h0)
537
+
538
+
539
class CogVideoCacheModel(BaseModel):
    """CogVideo inference model that returns per-layer key/value caches.

    Two mixins are attached: ``extra_position_embedding`` for positions at
    or beyond 512+400, and ``attention_plus``, a second attention branch
    (full attention for stage 1, window attention otherwise) whose output
    is blended with the base attention through a learned sigmoid gate.
    """

    def __init__(self, args, transformer=None, parallel_output=True, window_size=None, cogvideo_stage=None):
        """Build the base model and attach both mixins.

        ``window_size`` and ``cogvideo_stage`` override the values in
        ``args`` when they are not None.
        """
        super().__init__(args, transformer=transformer, parallel_output=parallel_output)
        self.layout = args.layout  # e.g. [64, 64+1024, 64+6*1024]
        self.stage = cogvideo_stage if cogvideo_stage is not None else args.cogvideo_stage  # 1 or 2
        self.n_head = args.num_attention_heads
        self.window_size = window_size if window_size is not None else args.window_size

        # tokens per frame = layout[1] - layout[0]; frames are square
        frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
        self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
            args.additional_seqlen, args.hidden_size
        ))

        frame_num = (args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0])
        if self.stage == 1:
            self.add_mixin('attention_plus', FullAttentionMixin(
                num_layers=args.num_layers,
                hidden_size=args.hidden_size,
                frame_resolution=frame_resolution,
                n_head=args.num_attention_heads,
                frame_num=frame_num,
            ))
        else:
            self.add_mixin('attention_plus', WindowAttentionMixin(
                num_layers=args.num_layers,
                hidden_size=args.hidden_size,
                frame_resolution=frame_resolution,
                window_size=self.window_size,
                shift_size=self.window_size//2,
                n_head=args.num_attention_heads,
                frame_num=frame_num,
            ))

    @classmethod
    def add_model_specific_args(cls, parser):
        """Register this model's command-line options on ``parser`` and return it."""
        group = parser.add_argument_group('VideoSwinLocalModel', 'video swin local model configurations')
        group.add_argument("--layout", type=str, default='64, 464, 2064')
        # lower priority than a value passed directly to the constructor
        group.add_argument("--window-size", type=int, default=10)
        group.add_argument("--additional-seqlen", type=int, default=2000)
        # lower priority than a value passed directly to the constructor
        group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2])
        return parser

    def disable_untrainable_params(self):
        """No-op: nothing is frozen by this model."""
        pass

    def position_embedding_forward(self, position_ids, **kw_args):
        """Look up position embeddings, routing ids >= 512+400 to the extra table.

        The constants 400 (tokens per frame) and 512+400 (first id served by
        the extra table) are hard-coded. For multi-token input the trailing
        tokens use the extra table: in stage 1 the last
        ``400*(frame_num-1)`` tokens, but only when the last id of the first
        row is >= 912; in stage 2 always the last 800 tokens.
        """
        frame_tokens = 400
        extra_start = 512 + 400
        base_table = self.transformer.position_embeddings

        if position_ids.shape[-1] > 1:
            if self.stage == 1:
                if position_ids[0, -1] >= extra_start:
                    frame_num = position_ids.shape[-1] // frame_tokens
                    tail = frame_tokens*(frame_num-1)
                    position_embeddings = torch.cat(
                        (
                            base_table(position_ids[..., :-tail]),
                            self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -tail:]-extra_start)
                        ),
                        dim=-2
                    )
                else:
                    position_embeddings = base_table(position_ids)
            else:
                # given 3, interpolate 2
                position_embeddings = torch.cat(
                    (
                        base_table(position_ids[..., :-800]),
                        self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -800:]-extra_start)
                    ),
                    dim=-2
                )
        else:
            if position_ids[0, 0] >= extra_start:
                position_embeddings = self.get_mixin('extra_position_embedding').position_embeddings(position_ids-extra_start)
            else:
                position_embeddings = base_table(position_ids)
        return position_embeddings

    def attention_forward(self, hidden_states, mask, layer_id, mems=None, log_text_attention_weights=0, text_len=0, frame_len=0, counter=0, enforce_no_swin=False, limited_spatial_channel_mem=False, **kw_args):
        """Attention for one layer, with and without cached memories.

        With ``mems is None`` the whole sequence is processed at once;
        otherwise ``hidden_states`` holds the new token and ``mems`` is a
        pair ``(base_kv_per_layer, extra_kv_per_layer)``. In both cases the
        new caches are stored in ``kw_args['output_this_layer']['mem_kv']``
        as ``(base_kv, extra_kv)``, each with keys and values concatenated
        on the last dim.

        Unless ``enforce_no_swin`` is set, frame outputs are
        ``gate * base_branch + (1 - gate) * extra_branch`` with
        ``gate = sigmoid(attn_distribution[layer_id])``.
        """
        attn_module = self.transformer.layers[layer_id].attention
        attn_plus = self.get_mixin('attention_plus')
        hidden_size = hidden_states.shape[-1]

        # base model qkv
        mixed_raw_layer = attn_module.query_key_value(hidden_states)
        q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)

        if mems is None:
            assert (q0.shape[1]-text_len) % frame_len == 0
            frame_num = (q0.shape[1]-text_len)//frame_len
            memkv0 = torch.cat((k0, v0), dim=-1)
            context_text, context_frame_local_text = attention_localframe_and_text_NAR(
                q0, k0, v0,
                mask,
                n_head=attn_module.num_attention_heads_per_partition,
                text_len=text_len,
                frame_len=frame_len,
                frame_num=frame_num,
                log_text_attention_weights=log_text_attention_weights,
                stage=self.stage
            )

            # change: self.swin_attend_to_text defaults to True, so the text
            # keys/values of the extra branch are always computed and cached
            memkv1_text = attn_plus.query_key_value[layer_id](hidden_states[..., :text_len, :])[..., hidden_size:]
            output_text = attn_module.dense(context_text)

            if frame_num > 0:
                context_frame_swin, memkv1_frame = attn_plus.attention_extra_NAR_inference(
                    hidden_states[:, text_len:], layer_id, memkv_text=memkv1_text, stage=self.stage)
                if not enforce_no_swin:
                    attn_distrib = torch.sigmoid(attn_plus.attn_distribution[layer_id])
                    attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
                    output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
                        + torch.mul(attn_plus.dense[layer_id](context_frame_swin), 1-attn_distrib)
                else:
                    # NOTE(review): only the first frame_len tokens are kept here - confirm intended
                    output_frame = attn_module.dense(context_frame_local_text[..., :frame_len, :])
                output = torch.cat((output_text, output_frame), dim=-2)
                memkv1 = torch.cat((memkv1_text, memkv1_frame), dim=-2)
            else:
                output = output_text
                memkv1 = memkv1_text
            kw_args['output_this_layer']['mem_kv'] = (memkv0, memkv1)

        else:
            new_memkv0 = torch.cat((k0, v0), dim=-1)
            old_k0, old_v0 = mems[0][layer_id][..., :hidden_size], mems[0][layer_id][..., hidden_size:]

            context_frame_local_text = attention_localframe_and_text_AR(
                q0,
                torch.cat((old_k0.expand(k0.shape[0], -1, -1), k0), dim=-2),
                torch.cat((old_v0.expand(v0.shape[0], -1, -1), v0), dim=-2),
                n_head=attn_module.num_attention_heads_per_partition,
                text_len=text_len,
                frame_len=frame_len,
                frame_num=None,
                log_text_attention_weights=log_text_attention_weights,
                layer_id=layer_id,
                limited_spatial_channel_mem=limited_spatial_channel_mem,
            )

            # The extra-branch cache may be absent. Pass None for both parts
            # in that case instead of dereferencing None.
            old_memkv1 = mems[1][layer_id] if mems[1] is not None else None
            if old_memkv1 is None:
                old_memkv1_frame = None
                old_memkv1_text = None
            else:
                old_memkv1_frame = old_memkv1[..., text_len:, :] if old_memkv1.shape[-2] > text_len else None
                old_memkv1_text = old_memkv1[..., :text_len, :]

            context_frame_swin, new_memkv1 = attn_plus.attention_extra_AR_inference(
                hidden_states,
                old_memkv1_frame,
                counter-text_len,
                layer_id,
                memkv_text=old_memkv1_text,
                log_text_attention_weights=log_text_attention_weights)
            if not enforce_no_swin:
                attn_distrib = torch.sigmoid(attn_plus.attn_distribution[layer_id])
                attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
                output = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
                    + torch.mul(attn_plus.dense[layer_id](context_frame_swin), 1-attn_distrib)
            else:
                output = attn_module.dense(context_frame_local_text)

            kw_args['output_this_layer']['mem_kv'] = (new_memkv0, new_memkv1)

        return output
models/cogvideo_model.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : cogvideo_model.py
4
+ @Time : 2022/07/11 16:12:05
5
+ @Author : Wenyi Hong
6
+ @Version : 1.0
7
+ @Contact : hwy22@mails.tsinghua.edu.cn
8
+ '''
9
+
10
+ # here put the import lib
11
+
12
+ import torch
13
+ from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
14
+
15
+ from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
16
+ from SwissArmyTransformer.model.transformer import unscaled_init_method
17
+ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
18
+ import torch.nn.functional as F
19
+ from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
20
+ import math
21
+
22
class PositionEmbeddingMixin(BaseMixin):
    """Mixin that owns an additional position-embedding table."""

    def __init__(self, additional_sequence_length, hidden_size,
                 init_method_std=0.02, reinit_slice=slice(512, 912),
                 ):
        """Create the table and fill it with N(0, init_method_std) values.

        ``reinit_slice`` selects which rows of the base model's position
        embeddings are copied into this table by ``reinit``.
        """
        super().__init__()
        self.reinit_slice = reinit_slice
        table = torch.nn.Embedding(additional_sequence_length, hidden_size)
        self.position_embeddings = table
        torch.nn.init.normal_(table.weight, mean=0.0, std=init_method_std)

    def reinit(self, parent_model=None):
        """Tile the selected base-model embedding rows across the whole table.

        The table is viewed as consecutive blocks of ``old_len`` rows and
        the source rows are broadcast-copied into every block.
        ``parent_model`` is unused.
        """
        source = self.transformer.position_embeddings.weight.data[self.reinit_slice]
        old_len, hidden_size = source.shape
        own_weight = self.position_embeddings.weight
        assert hidden_size == own_weight.shape[-1]
        blocks = own_weight.data.view(-1, old_len, hidden_size)
        blocks.copy_(source)
36
+
37
def window_partition(x, window_size):
    """Cut every frame into non-overlapping square windows.

    Args:
        x: (B, framenum, H, W, C)
        window_size (int): side length of a window
    Returns:
        windows: (num_windows*B, framenum, window_size, window_size, C),
        ordered by batch, then window row, then window column.
    """
    batch, frames, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    tiled = x.view(batch, frames, rows, window_size, cols, window_size, channels)
    # bring the (row, col) window indices in front of the frame axis
    tiled = tiled.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
    return tiled.view(-1, frames, window_size, window_size, channels)
49
+
50
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partition`` into full frames.

    Args:
        windows: (num_windows*B, frame_num, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, frame_num, H, W, C)
    """
    rows = H // window_size
    cols = W // window_size
    # Exact integer arithmetic; no round-trip through float division.
    B = windows.shape[0] // (rows * cols)
    framenum = windows.shape[1]
    x = windows.view(B, rows, cols, framenum, window_size, window_size, -1)
    x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
    return x
65
+
66
class WindowAttentionMixin(BaseMixin):
    """Extra attention branch using (optionally shifted) spatial windows.

    Each layer owns its own QKV and output projections. Layers with an even
    ``layer_id`` use unshifted windows with a plain causal mask; layers with
    an odd ``layer_id`` cyclically shift the frames by ``shift_size`` and
    additionally mask pairs of tokens that come from different regions of
    the shifted image. Masks are built for a fixed ``frame_num``.
    """

    def __init__(self, num_layers,
                 hidden_size,
                 frame_resolution,
                 window_size,
                 shift_size,
                 n_head,
                 frame_num,
                 init_method=unscaled_init_method(0.02),
                 output_layer_init_method=unscaled_init_method(0.02),
                 ):
        """Create per-layer projections and precompute the attention masks.

        Masks hold 1 for allowed and 0 for forbidden (query, key) pairs.
        "sequential" masks are lower-triangular over the window sequence;
        "interp" masks let even-indexed frames see only even-indexed frames,
        and odd-indexed frames see all even frames, earlier odd frames and
        the causal part of themselves.
        """
        super(WindowAttentionMixin, self).__init__()
        self.num_layers = num_layers  # replace attention in the LAST n layers
        self.query_key_value = torch.nn.ModuleList(
            [ColumnParallelLinear(hidden_size, 3*hidden_size, stride=3,
                                  gather_output=False, init_method=init_method)
             for _ in range(num_layers)
             ])
        self.dense = torch.nn.ModuleList(
            [RowParallelLinear(
                hidden_size,
                hidden_size,
                input_is_parallel=True,
                init_method=output_layer_init_method,
                bias=True,
                module=self,
                name="dense",
            )
                for _ in range(num_layers)
            ])

        self.n_head = n_head
        self.window_size = window_size
        self.frame_resolution = frame_resolution
        self.frame_len = frame_resolution * frame_resolution
        assert frame_resolution % window_size == 0
        assert 0 < shift_size < window_size
        ws_squre = self.window_size * self.window_size

        # Label the four regions created by the cyclic shift; tokens with
        # different labels must not attend to each other in shifted layers.
        img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
        h_slices = (slice(0, -shift_size),
                    slice(-shift_size, None))
        w_slices = (slice(0, -shift_size),
                    slice(-shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, :, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)  # nW, 1, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # [nW, window_size*window_size, window_size*window_size]
        sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # Both fills test the ORIGINAL difference: non-zero -> 0, zero -> 1.
        sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
        attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)

        self.attn_mask_sequential = attn_mask.clone().tril()
        self.causal_mask_sequential = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num).tril()

        self.causal_mask_interp = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num)
        self.attn_mask_interp = attn_mask.clone()

        def block(idx):
            # token range of frame ``idx`` inside one window sequence
            return slice(idx*ws_squre, (idx+1)*ws_squre)

        # bi-dir: even-indexed frames never look at odd-indexed frames
        for bi_idx in range(0, frame_num, 2):
            for uni_idx in range(1, frame_num, 2):
                self.attn_mask_interp[:, block(bi_idx), block(uni_idx)] = 0
                self.causal_mask_interp[:, block(bi_idx), block(uni_idx)] = 0
        # uni-dir: odd-indexed frames are causal within themselves and
        # never look at later odd-indexed frames
        for uni_idx in range(1, frame_num, 2):
            self.attn_mask_interp[:, block(uni_idx), block(uni_idx)].tril_()
            self.causal_mask_interp[:, block(uni_idx), block(uni_idx)].tril_()
            for uni_idx2 in range(uni_idx+2, frame_num, 2):
                self.attn_mask_interp[:, block(uni_idx), block(uni_idx2)] = 0
                self.causal_mask_interp[:, block(uni_idx), block(uni_idx2)] = 0

        # expand dim so the masks broadcast against
        # [batch, nW, n_head, seq, seq] attention scores
        self.attn_mask_sequential = self.attn_mask_sequential[None, None, :, None]
        self.attn_mask_interp = self.attn_mask_interp[None, None, :, None]
        self.causal_mask_sequential = self.causal_mask_sequential[None, None, :, None]
        self.causal_mask_interp = self.causal_mask_interp[None, None, :, None]

        self.shift_sizes = [0, shift_size]
        # Masks are plain attributes (not registered buffers); they are
        # moved to the input's device/dtype inside ``attention_extra``.
        self.mask_initialized = False

        self.attn_distribution = torch.nn.ParameterList([
            torch.nn.Parameter(torch.zeros(hidden_size))
            for _ in range(num_layers)
        ])

    def reinit(self, *pre_mixins):
        """Copy QKV weights and biases from the last ``num_layers`` base attention layers."""
        start_layer = len(self.transformer.layers) - self.num_layers
        assert start_layer >= 0
        for layer_id in range(self.num_layers):
            old_attention = self.transformer.layers[start_layer + layer_id].attention
            self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
            self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)

    def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
                        text_attn_mask=None, mode_sequential=True):
        """Window attention over frame tokens, optionally also attending to text.

        Args:
            frame_hidden_state: [batch, frame_num*frame_len, n_head*head_dim].
            layer_id: selects the projection and (by parity) shifted vs.
                unshifted windows.
            attn_dropout: callable applied to attention probabilities, or None.
            text_hidden_state: optional [batch, s0, hidden] text states.
            text_attn_mask: required when ``text_hidden_state`` is given.
                # NOTE(review): presumably [batch, 1, s0] with 1 = visible,
                # since it is unsqueezed twice at dim 2 - confirm with caller.
            mode_sequential: choose the "sequential" masks (True) or the
                "interp" masks (False).

        Returns:
            Context tensor of shape [batch, frame_num*frame_len, hidden].
        """
        # pb relax: scores are computed divided by ``alpha``, shifted by
        # their maximum and scaled back afterwards.
        swin_pb_relax = True
        alpha = 16

        # Keep the masks on the input's device/dtype. ``Tensor.to`` returns
        # the tensor itself when nothing changes, so doing this on every call
        # is cheap and avoids stale masks if the device or dtype changes
        # after the first call.
        device, dtype = frame_hidden_state.device, frame_hidden_state.dtype
        self.attn_mask_sequential = self.attn_mask_sequential.to(device=device, dtype=dtype)
        self.causal_mask_sequential = self.causal_mask_sequential.to(device=device, dtype=dtype)
        self.attn_mask_interp = self.attn_mask_interp.to(device=device, dtype=dtype)
        self.causal_mask_interp = self.causal_mask_interp.to(device=device, dtype=dtype)
        self.mask_initialized = True

        b0, s1, h0 = frame_hidden_state.shape
        h = h0 // self.n_head
        frame_len = self.frame_resolution * self.frame_resolution
        frame_num = s1 // frame_len
        assert frame_num*frame_len == s1
        wind_square = self.window_size * self.window_size
        nW = frame_len // wind_square
        bswin = b0 * nW
        seq = frame_num*wind_square  # tokens per window sequence
        shift = self.shift_sizes[layer_id % 2]

        causal_mask = self.causal_mask_sequential if mode_sequential else self.causal_mask_interp
        attn_mask = self.attn_mask_sequential if mode_sequential else self.attn_mask_interp
        if text_hidden_state is not None:
            s0 = text_hidden_state.shape[1]
            qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h)\
                .permute(2, 0, 3, 1, 4)  # [3, b0, n_head, s0, h]
            k_text, v_text = qkv_text[1], qkv_text[2]

        # shift
        frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
        if shift > 0:
            frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-shift, -shift), dims=(2, 3))
        # window partition
        frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, seq, h0)
        qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, seq, 3, self.n_head, h)\
            .permute(2, 0, 3, 1, 4)  # [3, bswin, n_head, seq, h]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # pb-relax
        if swin_pb_relax:
            q_scaled = q / (math.sqrt(h)*alpha)
        else:
            q_scaled = q / math.sqrt(h)
        attn = torch.matmul(q_scaled, k.transpose(-1, -2))

        # Shifted layers use the region-aware mask, unshifted ones the plain
        # causal mask; masked scores become exactly -10000.
        window_mask = attn_mask if shift > 0 else causal_mask
        attn = torch.mul(attn.view(b0, nW, self.n_head, seq, seq), window_mask)\
            - 10000.0 * (1.0 - window_mask)
        attn = attn.view(bswin, self.n_head, seq, seq)
        if swin_pb_relax:
            swin_pb_relax_const = torch.max(attn.reshape(bswin, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
            attn = (attn - swin_pb_relax_const)*alpha

        if text_hidden_state is not None:
            assert text_attn_mask is not None
            text_attn_mask = text_attn_mask.unsqueeze(2).unsqueeze(2)
            # Every window of a sample attends to that sample's text keys.
            attn_frame2text = torch.matmul(q_scaled.reshape(b0, nW, self.n_head, seq, h),
                                           k_text.unsqueeze(1).transpose(-1, -2))
            if swin_pb_relax:
                attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, nW, self.n_head, 1, 1))*alpha

            attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
            attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, seq, s0)
            # text scores are appended AFTER the frame scores on the key axis
            attn = torch.cat((attn, attn_frame2text), dim=-1)

        attn = F.softmax(attn, dim=-1)
        if attn_dropout is not None:
            with get_cuda_rng_tracker().fork():
                attn = attn_dropout(attn)

        # Split on ``seq`` rather than with negative indices so that an
        # empty text prefix (s0 == 0) does not yield an empty frame slice.
        context_swin = torch.matmul(attn[..., :seq], v)
        if text_hidden_state is not None:
            context_text = torch.matmul(attn[..., seq:].reshape(b0, nW, self.n_head, seq, s0), v_text.unsqueeze(1))
            context_swin = context_swin + context_text.reshape(bswin, self.n_head, seq, h)
        context_swin = context_swin.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)

        context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
        # reverse cycle shift
        if shift > 0:
            context_swin = torch.roll(context_swin, shifts=(shift, shift), dims=(2, 3))
        context_swin = context_swin.reshape(b0, s1, h0)

        return context_swin
261
+
262
+
263
class FullAttentionMixin(BaseMixin):
    """Extra attention branch with causal full attention over all frame tokens."""

    def __init__(self, num_layers,
                 hidden_size,
                 frame_resolution,
                 n_head,
                 frame_num,
                 init_method=unscaled_init_method(0.02),
                 output_layer_init_method=unscaled_init_method(0.02),
                 ):
        """Create per-layer projections and a causal mask sized for ``frame_num`` frames."""
        super(FullAttentionMixin, self).__init__()
        self.num_layers = num_layers  # replace attention in the LAST n layers
        self.query_key_value = torch.nn.ModuleList(
            [ColumnParallelLinear(hidden_size, 3*hidden_size, stride=3,
                                  gather_output=False, init_method=init_method)
             for _ in range(num_layers)
             ])
        self.dense = torch.nn.ModuleList(
            [RowParallelLinear(
                hidden_size,
                hidden_size,
                input_is_parallel=True,
                init_method=output_layer_init_method,
                bias=True,
                module=self,
                name="dense",)
                for _ in range(num_layers)
            ])

        self.n_head = n_head
        self.frame_resolution = frame_resolution
        self.frame_len = frame_resolution * frame_resolution
        # 1 = visible, 0 = masked; broadcasts over batch and heads
        self.causal_mask = torch.ones(1, 1, self.frame_len*frame_num, self.frame_len*frame_num).tril()

        self.mask_initialized = False

        self.attn_distribution = torch.nn.ParameterList([
            torch.nn.Parameter(torch.zeros(hidden_size))
            for _ in range(num_layers)
        ])

    def reinit(self, *pre_mixins):
        """Copy QKV weights and biases from the last ``num_layers`` base attention layers."""
        start_layer = len(self.transformer.layers) - self.num_layers
        assert start_layer >= 0
        for layer_id in range(self.num_layers):
            base_attention = self.transformer.layers[start_layer + layer_id].attention
            self.query_key_value[layer_id].weight.data.copy_(base_attention.query_key_value.weight.data)
            self.query_key_value[layer_id].bias.data.copy_(base_attention.query_key_value.bias.data)

    def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
                        text_attn_mask=None, mode_sequential=False):
        """Causal full attention over frame tokens, optionally also attending to text.

        Args:
            frame_hidden_state: [batch, frame_num*frame_len, n_head*head_dim].
            layer_id: index into ``self.query_key_value``.
            attn_dropout: callable applied to attention probabilities, or None.
            text_hidden_state: optional [batch, s0, hidden] text states.
            text_attn_mask: required when ``text_hidden_state`` is given.
                # NOTE(review): presumably [batch, 1, s0] with 1 = visible,
                # since it is unsqueezed once at dim 2 - confirm with caller.
            mode_sequential: only sequential mode is supported. The default
                is False, so callers must pass True explicitly or the
                assertion below fails.

        Returns:
            Context tensor of shape [batch, frame_num*frame_len, hidden].
        """
        assert mode_sequential == True  # only
        # pb relax: scores are computed divided by ``alpha``, shifted by
        # their maximum and scaled back afterwards.
        swin_pb_relax = True
        alpha = 16

        # ``Tensor.to`` returns the tensor itself when device and dtype
        # already match, so syncing on every call is cheap and keeps the
        # mask valid if the input's device or dtype changes between calls.
        self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
        self.mask_initialized = True

        b0, s1, h0 = frame_hidden_state.shape
        h = h0 // self.n_head
        frame_len = self.frame_resolution * self.frame_resolution
        frame_num = s1 // frame_len
        assert frame_num*frame_len == s1

        qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
            .permute(2, 0, 3, 1, 4)  # [3, b0, n_head, s1, h]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # frames-to-frames
        if swin_pb_relax:
            q_scaled = q / (math.sqrt(h)*alpha)
        else:
            q_scaled = q / math.sqrt(h)
        attn = torch.matmul(q_scaled, k.transpose(-1, -2))
        # The stored mask covers the constructor's ``frame_num``; slice it so
        # inputs with fewer frames are handled too.
        causal_mask = self.causal_mask[..., :s1, :s1]
        attn = torch.mul(attn, causal_mask) - 10000.0 * (1.0 - causal_mask)
        if swin_pb_relax:
            swin_pb_relax_const = torch.max(attn.reshape(b0, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
            attn = (attn - swin_pb_relax_const)*alpha

        v_text = None
        if text_hidden_state is not None:
            # frame-to-text
            assert text_attn_mask is not None
            s0 = text_hidden_state.shape[1]
            qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h)\
                .permute(2, 0, 3, 1, 4)  # [3, b0, n_head, s0, h]
            k_text, v_text = qkv_text[1], qkv_text[2]
            text_attn_mask = text_attn_mask.unsqueeze(2)
            attn_frame2text = torch.matmul(q_scaled, k_text.transpose(-1, -2))
            if swin_pb_relax:
                attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, self.n_head, 1, 1))*alpha
            attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
            attn_frame2text = attn_frame2text.reshape(b0, self.n_head, s1, s0)
            # text scores are appended AFTER the frame scores on the key axis
            attn = torch.cat((attn, attn_frame2text), dim=-1)

        attn = F.softmax(attn, dim=-1)
        if attn_dropout is not None:
            with get_cuda_rng_tracker().fork():
                attn = attn_dropout(attn)

        # Both paths now build the same result variable. Previously the
        # no-text path assigned a different name than the one returned.
        context = torch.matmul(attn[..., :s1], v)
        if v_text is not None:
            context = context + torch.matmul(attn[..., s1:], v_text)
        context_frame = context.permute(0, 2, 1, 3).reshape(b0, s1, h0)

        return context_frame
375
+
376
+
377
def attention_localframe_and_text(q0, k0, v0, attention_mask_totxt, attention_mask_local,
                                  n_head, text_len, frame_len, frame_num, attention_dropout=None, layer_id=0, **kwargs):
    """Attention where text attends to text and each frame attends to text plus itself.

    The sequence is ``[text | frame_0 | ... | frame_{frame_num-1}]``.  Text queries
    see only text keys; frame queries see the text keys and the keys of their own
    frame, the latter restricted by ``attention_mask_local``.

    Args:
        q0, k0, v0: [b, text_len + frame_len*frame_num, n_head*head_dim].
        attention_mask_totxt: mask over text keys (1 = keep); per the original
            notes [b, 1, 1, text_len].
        attention_mask_local: mask inside a frame (1 = keep); per the original
            notes [1, 1, frame_num, frame_len, frame_len].
        n_head: number of attention heads.
        text_len, frame_len, frame_num: layout of the sequence.
        attention_dropout: callable applied to the probabilities, or None.
        layer_id, **kwargs: accepted but unused.

    Returns:
        ``(context_text, context_frame)`` with shapes [b, text_len, h0] and
        [b, frame_len*frame_num, h0].
    """
    b, s0, h0 = q0.shape
    s1 = s0 - text_len
    h = h0 // n_head
    assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num

    def to_heads(t):
        # [b, s0, h0] -> [b, n_head, s0, h]
        return t.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)

    def frames_of(t):
        # [b, n_head, s0, h] -> [b, n_head, frame_num, frame_len, h]
        return t[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)

    q0, v0, k0 = to_heads(q0), to_heads(v0), to_heads(k0)
    scale = math.sqrt(h)
    text_values = v0[..., :text_len, :]

    # Scores of every query against the text keys; masked keys become -10000.
    text_penalty = 10000.0 * (1.0 - attention_mask_totxt)
    score_any2text = torch.matmul(q0 / scale, k0.transpose(-1, -2)[..., :text_len])
    score_text2text = torch.mul(score_any2text[..., :text_len, :], attention_mask_totxt) - text_penalty
    score_frame2text = torch.mul(score_any2text[..., text_len:, :], attention_mask_totxt) - text_penalty

    # Scores of every frame query against the keys of its own frame.
    q_frames, k_frames, v_frames = frames_of(q0), frames_of(k0), frames_of(v0)
    score_local = torch.matmul(q_frames / scale, k_frames.transpose(-1, -2))
    score_local = torch.mul(score_local, attention_mask_local) \
        - 10000.0 * (1.0 - attention_mask_local)

    # Frame context: a single softmax over [text keys | own-frame keys].
    probs_frame = F.softmax(
        torch.cat((score_frame2text, score_local.view(b, n_head, s1, frame_len)), dim=-1),
        dim=-1,
    )
    if attention_dropout is not None:
        with get_cuda_rng_tracker().fork():
            probs_frame = attention_dropout(probs_frame)

    ctx_from_text = torch.matmul(probs_frame[..., :text_len], text_values)  # [b, n_head, s1, h]
    local_probs = probs_frame[..., text_len:text_len+frame_len] \
        .view(b, n_head, frame_num, frame_len, frame_len)
    ctx_from_frame = torch.matmul(local_probs, v_frames).view(b, n_head, s1, h)
    context_frame = (ctx_from_text + ctx_from_frame).transpose(1, 2).reshape(b, s1, h0)

    # Text context: softmax over the text keys only.
    probs_text = F.softmax(score_text2text, dim=-1)
    if attention_dropout is not None:
        with get_cuda_rng_tracker().fork():
            probs_text = attention_dropout(probs_text)
    context_text2text = torch.matmul(probs_text, text_values) \
        .transpose(1, 2).reshape(b, text_len, h0)

    return context_text2text, context_frame
430
+
431
+
432
class CogVideoModel(BaseModel):
    """Base transformer extended with extra position embeddings and a second,
    frame-level attention path whose output is blended with the base attention.

    ``args.layout`` is ``[text_len, text_len+frame_len, text_len+frame_len*frame_num]``.
    Stage 1 runs in sequential mode; stage 2 uses the interpolation mask.
    """

    def __init__(self, args, transformer=None, parallel_output=True):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output)
        self.stage = args.cogvideo_stage  # 1 or 2
        self.mode_sequential = self.stage == 1
        self.layout = args.layout  # e.g. [64, 64+400, 64+5*400]
        self.n_head = args.num_attention_heads

        frame_len = self.layout[1] - self.layout[0]
        frame_resolution = int(math.sqrt(frame_len))
        frame_num = (self.layout[2] - self.layout[0]) // frame_len

        self.add_mixin(
            'extra_position_embedding',
            PositionEmbeddingMixin(args.additional_seqlen, args.hidden_size),
        )

        # Arguments common to both attention mixins.
        shared_kwargs = dict(
            num_layers=args.num_layers,
            hidden_size=args.hidden_size,
            frame_resolution=frame_resolution,
            n_head=args.num_attention_heads,
            frame_num=frame_num,
        )
        if args.window_size == -1:
            # full attention, only available in stage 1
            assert self.stage == 1
            attention_mixin = FullAttentionMixin(**shared_kwargs)
        else:
            attention_mixin = WindowAttentionMixin(
                window_size=args.window_size,
                shift_size=args.window_size//2,
                **shared_kwargs,
            )
        self.add_mixin('attention_plus', attention_mixin)

        # Per-frame local masks.  Sequential: every frame is lower-triangular.
        # Interpolation: only the odd-indexed frames are lower-triangular, the
        # even-indexed ones stay fully visible.
        self.attention_mask_local_sequential = \
            torch.ones(1, 1, frame_num, frame_len, frame_len).tril().unsqueeze(0)
        interp_mask = torch.ones(1, 1, frame_num, frame_len, frame_len)
        for idx in range(1, frame_num, 2):
            interp_mask[:, :, idx:idx+1].tril_()
        self.attention_mask_local_interp = interp_mask.unsqueeze(0)
        # The masks are moved to the activations' device/dtype on first forward.
        self.mask_initialized = False

    @classmethod
    def add_model_specific_args(cls, parser):
        """Register the CogVideo command-line options on ``parser`` and return it."""
        group = parser.add_argument_group('CogVideoModel', 'CogVideo model configurations')
        group.add_argument("--layout", type=str, default='64, 464, 2064',
                           help='text_len, textlen+frame_len, textlen+frame_len*frame_num')
        group.add_argument("--window-size", type=int, default=10,
                           help="swin attention's window size in temperal channel, -1 represents full attention")
        group.add_argument("--additional-seqlen", type=int, default=2000)
        group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2])
        return parser

    def disable_untrainable_params(self):
        """Freeze the base transformer; parameters added by mixins are not touched here."""
        self.transformer.requires_grad_(False)

    def position_embedding_forward(self, position_ids, **kw_args):
        """Embed positions: the first 64+400 ids via the base table, the rest via the mixin.

        Ids past the split are shifted down by 512+400 before indexing the extra
        embedding table.
        """
        split_at = 64 + 400
        base_ids = position_ids[..., :split_at]
        extra_ids = position_ids[..., split_at:]
        base_part = self.transformer.position_embeddings(base_ids)
        extra_part = self.get_mixin('extra_position_embedding') \
            .position_embeddings(extra_ids - (512 + 400))
        return torch.cat((base_part, extra_part), dim=-2)

    def attention_forward(self, hidden_states, mask, layer_id, **kw_args):
        """Run both attention paths for one layer and blend their frame outputs.

        ``mask`` is the text mask (original note: shape [bs, 1, 1, 64]).  Text
        tokens use the base attention only.  Frame tokens receive
        ``g * base_path + (1 - g) * mixin_path``, where ``g`` is the sigmoid of
        the mixin's ``attn_distribution[layer_id]``.
        """
        if not self.mask_initialized:
            placement = dict(device=hidden_states.device, dtype=hidden_states.dtype)
            self.attention_mask_local_sequential = self.attention_mask_local_sequential.to(**placement)
            self.attention_mask_local_interp = self.attention_mask_local_interp.to(**placement)
            self.mask_initialized = True

        attn_module = self.transformer.layers[layer_id].attention
        extra_attention = self.get_mixin('attention_plus')
        text_len = self.layout[0]
        frame_len = self.layout[1] - self.layout[0]

        # Base-model q/k/v for the whole sequence.
        mixed_raw_layer = attn_module.query_key_value(hidden_states)
        q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
        # Dropout only while training.
        dropout_fn = attn_module.attention_dropout if self.training else None

        if self.mode_sequential:
            attention_mask_local = self.attention_mask_local_sequential
        else:
            attention_mask_local = self.attention_mask_local_interp

        # Path 1: text->text and frame->(text + own frame), base-model weights.
        context_text, context_frame_local_text = attention_localframe_and_text(
            q0, k0, v0,
            attention_mask_totxt=mask,
            attention_mask_local=attention_mask_local,
            n_head=attn_module.num_attention_heads_per_partition,
            text_len=text_len,
            frame_len=frame_len,
            frame_num=(self.layout[2] - text_len) // frame_len,
            attention_dropout=dropout_fn,
            layer_id=layer_id,
        )

        # Path 2: the mixin's attention over the frame tokens (plus text).
        context_frame_swin = extra_attention.attention_extra(
            hidden_states[:, text_len:], layer_id, dropout_fn,
            text_hidden_state=hidden_states[:, :text_len],
            text_attn_mask=mask[..., 0, :],
            mode_sequential=self.mode_sequential)

        # Per-channel gate, broadcast over batch and sequence.
        attn_distrib = torch.sigmoid(extra_attention.attn_distribution[layer_id])
        attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)

        output_text = attn_module.dense(context_text)
        base_frame_out = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)
        extra_frame_out = torch.mul(extra_attention.dense[layer_id](context_frame_swin), 1-attn_distrib)
        output_frame = base_frame_out + extra_frame_out

        return torch.cat((output_text, output_frame), dim=-2)
pretrain_cogvideo.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : pretrain_cogvideo.py
4
+ @Time : 2021/10/06 00:58:32
5
+ @Author : Wenyi Hong
6
+ @Contact : hwy22@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+ import argparse
16
+ import numpy as np
17
+ from icetk import icetk as tokenizer
18
+ tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
19
+
20
+ from models.cogvideo_model import CogVideoModel
21
+ from SwissArmyTransformer import mpu, get_args
22
+ from SwissArmyTransformer.training.deepspeed_training import training_main
23
+ from SwissArmyTransformer.data_utils import BinaryDataset
24
+
25
def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
    """Build position ids for a batch of ``[text | frames]`` token sequences.

    Despite its name, this function only returns position ids.

    Text is left-padded: for each sample the padding slots keep position 0 and
    the remaining text slots are numbered 0, 1, 2, ...  All frame slots are
    numbered consecutively starting at 512.

    Args:
        data: [batch, layout[-1]] token tensor (only its size and device are used).
        attention_mask_totxt: [batch, layout[0]] mask whose row sums give the
            number of non-padding text tokens.
        args: object exposing ``layout``.

    Returns:
        LongTensor [batch, layout[2]] on ``data.device``.
    """
    batch_size, seq_length = data.size()
    assert attention_mask_totxt is not None
    layout = args.layout
    assert seq_length == layout[-1]

    text_len = layout[0]
    n_pads = text_len - attention_mask_totxt.sum(dim=-1).long()
    position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long,
                               device=data.device)

    # Frame positions are identical for every sample.
    position_ids[:, text_len:] = torch.arange(
        512, 512 + layout[2] - text_len, dtype=torch.long, device=data.device)

    # Text positions start after each sample's own padding.
    for i in range(batch_size):
        pad = int(n_pads[i])
        position_ids[i, pad:text_len] = torch.arange(
            text_len - pad, dtype=torch.long, device=data.device)
    return position_ids
41
+
42
+
43
def get_batch(data_iterator, args, timers):
    """Fetch one batch, broadcast it, and build the model inputs.

    Args:
        data_iterator: iterator yielding dicts with ``text``, ``loss_mask`` and
            ``attention_mask_totxt``; may be None, in which case None is handed to
            ``mpu.broadcast_data``.
        args: run configuration; reads ``layout`` and ``fp16``.
        timers: callable returning a timer with ``start``/``stop`` by name.

    Returns:
        ``(tokens, labels, loss_mask, attention_mask_totxt, position_ids)``.
        ``tokens``/``labels`` are the sequence without its last/first token;
        ``loss_mask`` is shifted to line up with ``labels``; the text mask is
        returned with two singleton axes inserted after the batch axis.
    """
    # Items and their type.
    keys = ['text', 'loss_mask', 'attention_mask_totxt']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    data = next(data_iterator) if data_iterator is not None else None
    timers('data loader').stop()

    data_b = mpu.broadcast_data(keys, data, datatype)
    # Unpack.
    tokens_ = data_b['text'].long()
    loss_mask = data_b['loss_mask'].float()
    attention_mask_totxt = data_b['attention_mask_totxt'].float()

    # Next-token prediction: inputs drop the last token, targets drop the first.
    labels = tokens_[:, 1:].clone().contiguous()
    loss_mask = loss_mask[:, 1:].contiguous()
    tokens = tokens_[:, :-1].clone().contiguous()

    # Put <start_of_image> at the first slot of every frame.  The stride is the
    # frame length taken from the layout (previously hard-coded to 400, which is
    # only right for the default layout 64,464,2064).
    frame_len = args.layout[1] - args.layout[0]
    for idx in range(args.layout[0], args.layout[2], frame_len):
        tokens[:, idx] = tokenizer['<start_of_image>']

    # Get the position ids.
    position_ids = get_masks_and_position_ids_video(
        tokens,
        attention_mask_totxt=attention_mask_totxt,
        args=args
    )
    attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1)
    # Match the model's precision.
    if args.fp16:
        attention_mask_totxt = attention_mask_totxt.half()
    return tokens, labels, loss_mask, attention_mask_totxt, position_ids
79
+
80
+
81
def forward_step(data_iterator, model, args, timers):
    """Run one forward pass and compute the masked image-token loss.

    Args:
        data_iterator: passed through to ``get_batch``.
        model: callable as ``model(tokens, position_ids, attention_mask_totxt)``
            whose first return value is the logits.
        args: run configuration; reads ``layout`` and ``cogvideo_stage``.
        timers: callable returning a timer with ``start``/``stop`` by name.

    Returns:
        ``(loss, log_loss_dict)``: the loss averaged over the positions selected
        by the loss mask, and per-frame mean losses for logging
        (``AR_f{i}_loss`` in stage 1, ``ITP_f{i}_loss`` otherwise).
    """
    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch(
        data_iterator, args, timers)
    timers('batch generator').stop()

    # Forward model.
    logits, *mems = model(tokens, position_ids, attention_mask_totxt)

    # Sequence geometry, derived from the layout instead of the former
    # hard-coded 64 / 400 / 5 (identical for the default layout 64,464,2064).
    text_len = args.layout[0]
    perframe_len = args.layout[1] - args.layout[0]
    frame_num = (args.layout[2] - args.layout[0]) // perframe_len

    # Only frame positions and only image-token logits enter the loss.
    logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
    losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:])
    loss_mask = loss_mask[:, text_len:].reshape(-1)

    losses_1d = losses.reshape(-1) * loss_mask
    loss = torch.sum(losses_1d) / loss_mask.sum()

    # ===================== Log partial losses ======================== #
    log_loss_dict = {}
    bs = losses.shape[0]
    if args.cogvideo_stage == 1:
        tag, frame_ids = 'AR', range(frame_num)
    else:
        tag, frame_ids = 'ITP', range(1, frame_num-1)
    tokens_per_frame = max((perframe_len*bs), 1)  # guard against division by zero
    for i in frame_ids:
        frame_losses = losses[:, i*perframe_len:(i+1)*perframe_len]
        log_loss_dict[f'{tag}_f{i}_loss'] = \
            frame_losses.contiguous().reshape(-1).detach().sum() / tokens_per_frame
    # ===================== END OF BLOCK ======================= #
    return loss, log_loss_dict
116
+
117
+
118
def create_dataset_function(path, args):
    """Create the ``BinaryDataset`` for ``path`` with a per-row preprocessing hook.

    Each stored row holds ``dataset_layout[-1]`` ids: the text slots followed by
    the frame tokens.  The hook turns a row into a dict with ``text`` (padded
    text, ``<start_of_image>``, frames), ``loss_mask`` and ``attention_mask_totxt``.

    Args:
        path: dataset location, passed to ``BinaryDataset``.
        args: run configuration; ``cogvideo_stage`` selects the masking scheme.
    """
    dataset_layout = [64, 464, 2064]
    input_layout = [64, 464, 2064]
    text_slots = input_layout[0]

    def process_fn(row):
        row = row.astype(np.int64)
        text, frames = row[:dataset_layout[0]], row[dataset_layout[0]:]

        # Rows may begin with a pad token because of how the data was written.
        if text[0] == tokenizer['<pad>']:
            text = text[1:]

        if args.cogvideo_stage == 1:
            text, loss_mask, frames = make_text_video_generation(text, frames)
        else:
            text, loss_mask, frames = mask_video_frame_interpolation(text, frames)

        # Left-pad the text up to the fixed number of text slots.
        n_pad = text_slots - len(text)
        padding = np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64)
        image_start = np.array([tokenizer['<start_of_image>']], dtype=np.int64)
        sample = np.concatenate([padding, text, image_start, frames], axis=0)

        # 0 on padding, 1 on real text tokens.
        attention_mask_totxt = np.array([0] * n_pad + [1] * (text_slots-n_pad))
        return {
            'text': sample,
            'loss_mask': loss_mask,
            'attention_mask_totxt': attention_mask_totxt,
        }

    return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])
150
+
151
def make_text_video_generation(text, frames):
    """Prepare a stage-1 (text-to-video generation) sample.

    Pad tokens are stripped from ``text`` (at most 64 tokens are kept) and a
    loss mask is built that is 0 over the text, the ``<start_of_image>`` slot
    and the first frame, and 1 over the remaining frames.  ``frames`` is
    returned unchanged.
    """
    text_len, first_frame_end, seq_end = 64, 464, 2064
    # dataset format: "1.0秒" (i.e. "1.0 seconds") <n>{text}<pad><pad> ...
    not_pad = text != tokenizer['<pad>']
    text = text[not_pad][:text_len]
    # Laid out like the input sequence; the loss mask is shifted left by one
    # position later on.
    loss_mask = np.repeat([0, 1], [first_frame_end + 1, seq_end - first_frame_end])
    return text, loss_mask, frames
156
+
157
def mask_video_frame_interpolation(text, frames):
    """Prepare a stage-2 (frame interpolation) sample.

    Pad tokens are stripped from ``text`` (at most 64 tokens are kept).  The
    loss mask is 0 over the text, the ``<start_of_image>`` slot and the first
    frame, then alternates 1, 0, 1, 0 over the four following frames.
    ``frames`` is returned unchanged.
    """
    text_len, first_frame_end = 64, 464
    frame_len = first_frame_end - text_len
    # text format: <pad> "1.0秒" (i.e. "1.0 seconds") <n> {text} <pad> <pad>
    not_pad = text != tokenizer['<pad>']
    text = text[not_pad][:text_len]
    # Laid out like the input sequence; the loss mask is shifted left by one
    # position later on.
    loss_mask = np.repeat(
        [0, 1, 0, 1, 0],
        [first_frame_end + 1, frame_len, frame_len, frame_len, frame_len],
    )

    return text, loss_mask, frames
169
+
170
+
171
+
172
if __name__ == '__main__':
    # Script- and model-specific flags are parsed here; everything this parser
    # does not recognise is forwarded to SwissArmyTransformer's get_args.
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument('--txt-loss-scale', type=float, default=1)
    CogVideoModel.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()

    # Merge both namespaces (a duplicated key raises TypeError).
    framework_args = get_args(args_list)
    args = argparse.Namespace(**vars(framework_args), **vars(known))

    # "--layout 64,464,2064" -> [64, 464, 2064]
    args.layout = [int(field) for field in args.layout.split(',')]

    training_main(
        args,
        model_cls=CogVideoModel,
        forward_step_function=forward_step,
        create_dataset_function=create_dataset_function,
    )
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ SwissArmyTransformer>=0.2.9
2
+ icetk
3
+ gifmaker
4
+ torchvision
scripts/ds_brain_pretrain_cogvideo_stage1.sh ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#! /bin/bash
# Launch CogVideo stage-1 pretraining (full attention, --cogvideo-stage 1) with
# torch.distributed.launch.  Expects RLAUNCH_REPLICA / RLAUNCH_REPLICA_TOTAL in the
# environment; replica 0 publishes its IP through the file ./master_ip.
# Usage: <script> <file-to-source> [task_set] [extra args...]

# Change for multinode config

NUM_WORKERS=1
NUM_GPUS_PER_WORKER=8
MP_SIZE=1

script_path=$(realpath "$0")
script_dir=$(dirname "$script_path")
main_dir=$(dirname "$script_dir")

OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="hostfile"
# HOST_FILE_PATH="hostfile_single"

video_data_test="" # TODO
CHECKPOINT_PATH="" # TODO: CogView2 ckpt

config_json="$script_dir/ds_config_zero.json"
gpt_options=" \
       --experiment-name pretrain-cogvideo-stage1 \
       --tokenizer-type fake \
       --vocab-size 150010 \
       --model-parallel-size ${MP_SIZE} \
       --mode finetune \
       --num-workers 0 \
       --num-layers 48 \
       --hidden-size 3072 \
       --num-attention-heads 48 \
       --layout 64,464,2064 \
       --window-size -1 \
       --cogvideo-stage 1 \
       --additional-seqlen 2000 \
       --train-iters 500000 \
       --resume-dataloader \
       --train-data ${video_data_test} \
       --train-data-weights 1 \
       --split 949,50,1 \
       --distributed-backend nccl \
       --lr-decay-style cosine \
       --warmup .001 \
       --checkpoint-activations \
       --max-sequence-length 1024 \
       --fp16 \
       --save-interval 2000 \
       --eval-interval 500 \
       --eval-iters 15 \
       --log-interval 50 \
       --save $main_dir/checkpoints \
       --sandwich-ln \
       --load $CHECKPOINT_PATH \
"
       # --load $CHECKPOINT_PATH \
       #  \ --sandwich-ln


gpt_options="${gpt_options}
       --deepspeed \
       --deepspeed_config ${config_json} \
"

# Distribute Example
#export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0
export NCCL_NET_GDR_LEVEL=2
#export NCCL_IB_CUDA_SUPPORT=1
#export NCCL_IB_GID_INDEX=3
#export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null)
export NCCL_DEBUG=info
export OMP_NUM_THREADS=4

# Replica 0 publishes its address.  The variable is quoted so that an unset
# RLAUNCH_REPLICA no longer makes the test a syntax error.
if [ "$RLAUNCH_REPLICA" == "0" ]; then
        ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip
fi

function finish {
        rm -rf master_ip
}

trap finish EXIT INT TERM

# Every replica waits until the address file is visible.
while [ ! -f master_ip ]; do
        echo "wait master_ip..."
        ls > /dev/null && sleep 1;
done

export MASTER_ADDR=$(cat master_ip)
echo "master_ip: $MASTER_ADDR"

MP_SIZE=1
task_set=$2
source "$1"
DATESTR=$(date +"%m-%d-%H-%M")

# -p: do not fail when the log directory already exists.
mkdir -p logs
run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \
        --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \
        --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt"


# run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}"
# NOTE: run_cmd is deliberately left unquoted below; word splitting folds the
# newline embedded in gpt_options into a space before eval sees it.
echo ${run_cmd}
eval ${run_cmd}

set +x
scripts/ds_brain_pretrain_cogvideo_stage2.sh ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#! /bin/bash
# Launch CogVideo stage-2 pretraining (window size 10, --cogvideo-stage 2) with
# torch.distributed.launch.  Expects RLAUNCH_REPLICA / RLAUNCH_REPLICA_TOTAL in the
# environment; replica 0 publishes its IP through the file ./master_ip.
# Usage: <script> <file-to-source> [task_set] [extra args...]

# Change for multinode config

NUM_WORKERS=1
NUM_GPUS_PER_WORKER=8
MP_SIZE=1

script_path=$(realpath "$0")
script_dir=$(dirname "$script_path")
main_dir=$(dirname "$script_dir")

OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="hostfile"
# HOST_FILE_PATH="hostfile_single"

video_data_test="" # TODO
CHECKPOINT_PATH="" # TODO: CogView2 ckpt

config_json="$script_dir/ds_config_zero.json"
gpt_options=" \
       --experiment-name pretrain-cogvideo-stage2 \
       --tokenizer-type fake \
       --vocab-size 150010 \
       --model-parallel-size ${MP_SIZE} \
       --mode finetune \
       --num-workers 0 \
       --num-layers 48 \
       --hidden-size 3072 \
       --num-attention-heads 48 \
       --layout 64,464,2064 \
       --window-size 10 \
       --cogvideo-stage 2 \
       --additional-seqlen 2000 \
       --train-iters 500000 \
       --resume-dataloader \
       --train-data ${video_data_test} \
       --train-data-weights 1 \
       --split 949,50,1 \
       --distributed-backend nccl \
       --lr-decay-style cosine \
       --warmup .001 \
       --checkpoint-activations \
       --max-sequence-length 1024 \
       --fp16 \
       --save-interval 2000 \
       --eval-interval 500 \
       --eval-iters 15 \
       --log-interval 50 \
       --save $main_dir/checkpoints \
       --sandwich-ln \
       --load $CHECKPOINT_PATH \
"
       # --load $CHECKPOINT_PATH \
       #  \ --sandwich-ln


gpt_options="${gpt_options}
       --deepspeed \
       --deepspeed_config ${config_json} \
"

# Distribute Example
#export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0
export NCCL_NET_GDR_LEVEL=2
#export NCCL_IB_CUDA_SUPPORT=1
#export NCCL_IB_GID_INDEX=3
#export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null)
export NCCL_DEBUG=info
export OMP_NUM_THREADS=4

# Replica 0 publishes its address.  The variable is quoted so that an unset
# RLAUNCH_REPLICA no longer makes the test a syntax error.
if [ "$RLAUNCH_REPLICA" == "0" ]; then
        ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip
fi

function finish {
        rm -rf master_ip
}

trap finish EXIT INT TERM

# Every replica waits until the address file is visible.
while [ ! -f master_ip ]; do
        echo "wait master_ip..."
        ls > /dev/null && sleep 1;
done

export MASTER_ADDR=$(cat master_ip)
echo "master_ip: $MASTER_ADDR"

MP_SIZE=1
task_set=$2
source "$1"
DATESTR=$(date +"%m-%d-%H-%M")

# -p: do not fail when the log directory already exists.
mkdir -p logs
run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \
        --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \
        --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt"


# run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}"
# NOTE: run_cmd is deliberately left unquoted below; word splitting folds the
# newline embedded in gpt_options into a space before eval sees it.
echo ${run_cmd}
eval ${run_cmd}

set +x
scripts/ds_config_zero.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train_micro_batch_size_per_gpu": 4,
3
+ "gradient_accumulation_steps": 1,
4
+ "steps_per_print": 1,
5
+ "gradient_clipping": 0.1,
6
+ "zero_optimization": {
7
+ "stage": 2,
8
+ "cpu_offload": true,
9
+ "contiguous_gradients": false,
10
+ "overlap_comm": true,
11
+ "reduce_scatter": false,
12
+ "reduce_bucket_size": 100000000,
13
+ "allgather_bucket_size": 1000000000,
14
+ "load_from_fp32_weights": false
15
+ },
16
+ "zero_allow_untested_optimizer": true,
17
+ "fp16": {
18
+ "enabled": true,
19
+ "loss_scale": 0,
20
+ "loss_scale_window": 400,
21
+ "hysteresis": 2,
22
+ "min_loss_scale": 1
23
+ },
24
+ "optimizer": {
25
+ "type": "Adam",
26
+ "params": {
27
+ "lr": 0.0002,
28
+ "betas": [
29
+ 0.9,
30
+ 0.95
31
+ ],
32
+ "eps": 1e-8,
33
+ "weight_decay": 1e-4
34
+ }
35
+ },
36
+ "activation_checkpointing": {
37
+ "partition_activations": false,
38
+ "contiguous_memory_optimization": false
39
+ },
40
+ "wall_clock_breakdown": false
41
+ }
42
+
scripts/inference_cogvideo_pipeline.sh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Run the CogVideo inference pipeline interactively (both stages, stage-1 guidance).
# Any arguments given to this script are appended to the python command line.

NLAYERS=48
NHIDDEN=3072
NATT=48
MAXSEQLEN=1024
# Random port so that several runs on one machine do not collide.
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
MPSIZE=1

#SAMPLING ARGS
TEMP=1.05
TOPK=12

script_path=$(realpath "$0")
script_dir=$(dirname "$script_path")

# "$@" is quoted so that pass-through arguments containing spaces stay intact.
MASTER_PORT=${MASTER_PORT} SAT_HOME=/sharefs/cogview-new python cogvideo_pipeline.py \
        --input-source interactive \
        --output-path ./output \
        --parallel-size 1 \
        --both-stages \
        --use-guidance-stage1 \
        --guidance-alpha 3.0 \
        --generate-frame-num 5 \
        --tokenizer-type fake \
        --mode inference \
        --distributed-backend nccl \
        --fp16 \
        --model-parallel-size $MPSIZE \
        --temperature $TEMP \
        --coglm-temperature2 0.89 \
        --top_k $TOPK \
        --sandwich-ln \
        --seed 1234 \
        --num-workers 0 \
        --batch-size 4 \
        --max-inference-batch-size 8 \
        "$@"
sr_pipeline/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : __init__.py
4
+ @Time : 2022/03/02 13:57:09
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+
15
+ from .direct_sr import DirectSuperResolution
16
+ from .iterative_sr import IterativeSuperResolution
17
+ from .sr_group import SRGroup
sr_pipeline/direct_sr.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : direct_sr.py
4
+ @Time : 2022/03/02 13:58:11
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+
16
+ # -*- encoding: utf-8 -*-
17
+ '''
18
+ @File : inference_cogview2.py
19
+ @Time : 2021/10/10 16:31:34
20
+ @Author : Ming Ding
21
+ @Contact : dm18@mails.tsinghua.edu.cn
22
+ '''
23
+
24
+ # here put the import lib
25
+ import os
26
+ import sys
27
+ import math
28
+ import random
29
+ from PIL import ImageEnhance, Image
30
+
31
+ import torch
32
+ import argparse
33
+ from torchvision import transforms
34
+
35
+ from SwissArmyTransformer import get_args
36
+ from SwissArmyTransformer.training.model_io import load_checkpoint
37
+ from .dsr_sampling import filling_sequence_dsr, IterativeEntfilterStrategy
38
+ from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
39
+
40
+ from .dsr_model import DsrModel
41
+
42
+ from icetk import icetk as tokenizer
43
+
44
class DirectSuperResolution:
    """Wrapper that loads a ``DsrModel`` checkpoint and runs direct super-resolution.

    The model lives on the CPU between calls unless ``onCUDA`` is set; in that
    case it is moved to the GPU once at construction time.
    """

    def __init__(self, args, path, max_bz=4, topk=6, onCUDA=False):
        """Load the model from ``path``.

        Args:
            args: argument namespace; several fields are overwritten below.
            path: checkpoint location, stored in ``args.load``.
            max_bz: largest number of samples processed per model call.
            topk: top-k value for the sampling strategy.
            onCUDA: keep the model on the GPU permanently.
        """
        # Settings required by the super-resolution model.
        args.load = path
        args.kernel_size = 5
        args.kernel_size2 = 5
        args.new_sequence_length = 4624
        args.layout = [96,496,4096]

        model = DsrModel(args)
        if args.fp16:
            model = model.half()
        load_checkpoint(model, args)  # on cpu
        model.eval()

        self.model = model
        self.onCUDA = onCUDA
        if onCUDA:
            self.model = self.model.cuda()

        # Ids at or above the number of image tokens are excluded from sampling.
        invalid_slices = [slice(tokenizer.num_image_tokens, None)]
        # temperature is fixed here (original note: "temperature not used").
        self.strategy = IterativeEntfilterStrategy(invalid_slices,
                                                   temperature=1.0, topk=topk)
        self.max_bz = max_bz

    def __call__(self, text_tokens, image_tokens, enhance=False):
        """Super-resolve ``image_tokens`` conditioned on ``text_tokens``.

        1-D inputs are given a batch axis IN PLACE.  With ``enhance`` set, every
        image is decoded, passed through PIL's sharpness enhancer with factor 1.
        and re-encoded at size 160 before being fed to the model.

        Returns:
            The per-batch outputs of ``filling_sequence_dsr`` (minus the first
            element of each) concatenated along dim 0.
        """
        if text_tokens.dim() == 1:
            text_tokens.unsqueeze_(0)
        if image_tokens.dim() == 1:
            image_tokens.unsqueeze_(0)

        if enhance:
            reencoded = []
            for small_img in image_tokens:
                decoded = tokenizer.decode(image_ids=small_img).squeeze(0)
                # float image -> uint8 HWC array
                ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
                sharpener = ImageEnhance.Sharpness(Image.fromarray(ndarr))
                reencoded.append(
                    tokenizer.encode(image_pil=sharpener.enhance(1.), image_size=160).view(-1))
            image_tokens = torch.stack(reencoded)

        batch = text_tokens.shape[0]
        seq = torch.cat((text_tokens, image_tokens), dim=1)
        # Sequence to be filled: 3601 <start_of_image> placeholders per sample.
        placeholder = torch.tensor([tokenizer['<start_of_image>']]*3601, device=image_tokens.device)
        seq1 = placeholder.unsqueeze(0).expand(batch, -1)

        if self.onCUDA:
            model = self.model
        else:
            print('Converting Dsr model...')
            model = self.model.cuda()

        print('Direct super-resolution...')
        output_list = []
        # Process the batch in chunks of at most ``max_bz`` (always at least one pass).
        n_chunks = max((batch + self.max_bz - 1) // self.max_bz, 1)
        for chunk in range(n_chunks):
            lo, hi = chunk*self.max_bz, (chunk+1)*self.max_bz
            output1 = filling_sequence_dsr(model,
                                           seq[lo:hi],
                                           seq1[lo:hi],
                                           warmup_steps=1, block_hw=(1, 0),
                                           strategy=self.strategy
                                           )
            output_list.extend(output1[1:])

        if not self.onCUDA:
            # Release GPU memory for whatever runs next.
            print('Moving back Dsr to cpu...')
            model = model.cpu()
            torch.cuda.empty_cache()
        return torch.cat(output_list, dim=0)
sr_pipeline/dsr_model.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : dsr_model.py
4
+ @Time : 2021/10/02 01:36:32
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
+ from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
19
+
20
+ from SwissArmyTransformer.model.transformer import split_tensor_along_last_dim, unscaled_init_method
21
+ from SwissArmyTransformer.mpu.utils import sqrt
22
+ from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
23
+ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
24
+
25
class PositionEmbeddingMixin(BaseMixin):
    '''
    Mixin that owns an extra position-embedding table, used for positions
    that lie beyond the base transformer's original sequence length.
    '''

    def __init__(self, additional_sequence_length, hidden_size,
                 init_method_std=0.02, reinit_slice=slice(512, 512+400)
                 ):
        '''
        additional_sequence_length: number of rows of the extra table
        hidden_size: width of every embedding row
        init_method_std: std of the zero-mean normal used for initialisation
        reinit_slice: rows of the base table that `reinit` copies from
        '''
        super().__init__()
        self.reinit_slice = reinit_slice
        extra_table = torch.nn.Embedding(additional_sequence_length, hidden_size)
        torch.nn.init.normal_(extra_table.weight, mean=0.0, std=init_method_std)
        self.position_embeddings = extra_table

    def reinit(self, parent_model=None):
        '''
        Fill the extra table by tiling the base model's position embeddings
        (rows `reinit_slice`) over a 2D grid. `parent_model` is not used.

        `sqrt` is the helper imported from SwissArmyTransformer.mpu.utils --
        presumably an exact integer square root; confirm against that module.
        '''
        source = self.transformer.position_embeddings.weight.data[self.reinit_slice]
        target = self.position_embeddings.weight
        source_len, width = source.shape
        assert width == target.shape[-1]
        source_edge = sqrt(source_len)
        target_edge = sqrt(target.shape[-2])
        assert target_edge % source_edge == 0
        repeats = target_edge // source_edge
        # View the target as (repeats, edge, repeats, edge, width) so that the
        # broadcasted copy repeats the source grid along both axes.
        tiled = target.data.view(repeats, source_edge, repeats, source_edge, width)
        tiled.copy_(source.view(1, source_edge, 1, source_edge, width))
        # Alternative kept from the original source (flat, 1D tiling):
        # self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
42
+
43
+
44
class AttentionMixin(BaseMixin):
    '''
    Mixin holding one additional QKV projection and one additional output
    projection per replaced layer.
    '''

    def __init__(self, num_layers,
                 hidden_size,
                 init_method=unscaled_init_method(0.02),
                 output_layer_init_method=unscaled_init_method(0.02)
                 ):
        '''
        num_layers: how many layers get an extra attention projection
        hidden_size: model width (the QKV projection outputs 3 * hidden_size)
        init_method: initialiser for the QKV projections
        output_layer_init_method: initialiser for the output projections
        '''
        super().__init__()
        self.num_layers = num_layers  # replace attention in the LAST n layers

        # All QKV projections are created before any output projection, so the
        # order of parameter initialisation matches the original code.
        qkv_projections = [
            ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3,
                                 gather_output=False, init_method=init_method)
            for _ in range(num_layers)
        ]
        self.query_key_value = torch.nn.ModuleList(qkv_projections)

        output_projections = [
            RowParallelLinear(hidden_size,
                              hidden_size,
                              input_is_parallel=True,
                              init_method=output_layer_init_method)
            for _ in range(num_layers)
        ]
        self.dense = torch.nn.ModuleList(output_projections)

    def reinit(self, parent_model=None):
        '''
        Copy weights and biases of the base transformer's last `num_layers`
        attention modules into this mixin. `parent_model` is not used.
        '''
        def copy_linear(dst, src):
            # weight first, then bias -- same order as the original statements
            dst.weight.data.copy_(src.weight.data)
            dst.bias.data.copy_(src.bias.data)

        first_layer = len(self.transformer.layers) - self.num_layers
        assert first_layer >= 0
        for offset in range(self.num_layers):
            base_attention = self.transformer.layers[first_layer + offset].attention
            copy_linear(self.query_key_value[offset], base_attention.query_key_value)
            copy_linear(self.dense[offset], base_attention.dense)
74
+
75
class DsrModel(BaseModel):
    '''
    Direct super-resolution model: the base transformer handles the first
    `layout[1]` positions, while the positions after it use the extra
    position table and the extra attention projections added as mixins.
    '''

    def __init__(self, args, transformer=None):
        '''
        args: namespace providing max_sequence_length, new_sequence_length,
              hidden_size, num_layers, layout, kernel_size and kernel_size2
        transformer: optional transformer forwarded to BaseModel
        '''
        super().__init__(args, transformer=transformer)
        self.original_sequence_length = args.max_sequence_length
        additional_seqlen = args.new_sequence_length - args.max_sequence_length
        self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
            additional_seqlen, args.hidden_size
        ))
        self.add_mixin('attention_plus', AttentionMixin(
            num_layers=args.num_layers,
            hidden_size=args.hidden_size
        ))
        # `--layout` is declared with type=str, so accept both the raw
        # command-line string and an already parsed sequence of ints.
        self.layout = self._parse_layout(args.layout)
        # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
        self.kernel_size = args.kernel_size
        self.kernel_size2 = args.kernel_size2
        self.log_attention_weights = None

    @staticmethod
    def _parse_layout(layout):
        '''Turn a comma-separated layout string into a list of ints; pass anything else through.'''
        if isinstance(layout, str):
            return [int(part) for part in layout.split(',')]
        return layout

    def position_embedding_forward(self, position_ids, **kw_args):
        '''
        Look up the first `layout[1]` positions in the base table and the
        remaining ones (shifted by the original sequence length) in the extra
        table, then concatenate along the sequence dimension.
        '''
        position = position_ids[..., :self.layout[1]]
        position_plus = position_ids[..., self.layout[1]:] - self.original_sequence_length
        position_embeddings = torch.cat(
            (
                self.transformer.position_embeddings(position),
                self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
            ),
            dim=-2
        )
        return position_embeddings

    def attention_forward(self, hidden_states, mask,
                          layer_id=None, log_attention_weights=None, **kw_args):
        '''
        Attention for one layer: the first `layout[1]` hidden states go
        through the base attention projections, the rest through the
        'attention_plus' projections; both are combined by
        `sparse_attention_2d_light`. An optional `add_scalar` keyword
        (default 0) is forwarded to that function.
        '''
        attn_module = self.transformer.layers[layer_id].attention
        # attention_plus on all layers
        attention_plus = self.get_mixin('attention_plus')
        query_key_value_plus = attention_plus.query_key_value[layer_id]
        dense_plus = attention_plus.dense[layer_id]
        # split two parts
        hidden_states_plus = hidden_states[:, self.layout[1]:]
        hidden_states = hidden_states[:, :self.layout[1]]
        # base model qkv
        mixed_raw_layer = attn_module.query_key_value(hidden_states)
        q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
        # cuda2d model qkv
        mixed_raw_layer = query_key_value_plus(hidden_states_plus)
        q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3)

        # dropout is only active in training mode
        dropout_fn = attn_module.attention_dropout if self.training else None

        # cuda2d attention
        context_layer0, context_layer1 = sparse_attention_2d_light(
            q0, k0, v0,
            q1, k1, v1,
            mask,
            n_head=attn_module.num_attention_heads_per_partition,
            text_len=self.layout[0],
            kernel_size=self.kernel_size,
            kernel_size2=self.kernel_size2,
            attention_dropout=dropout_fn,
            log_attention_weights=log_attention_weights,
            add_scalar=kw_args.get('add_scalar', 0)
        )

        output_0 = attn_module.dense(context_layer0)
        output_1 = dense_plus(context_layer1)
        output = torch.cat((output_0, output_1), dim=1)

        return output

    def final_forward(self, logits, **kwargs):
        '''
        Project the final hidden states onto the first 20000 rows of the word
        embedding matrix, computing in float32.
        '''
        logits_parallel = logits
        logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
        # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
        return logits_parallel

    def disable_untrainable_params(self):
        '''Freeze the base transformer's parameters.'''
        self.transformer.requires_grad_(False)

    @classmethod
    def add_model_specific_args(cls, parser):
        '''Register this model's command-line options on `parser` and return it.'''
        group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
        group.add_argument("--kernel-size", type=int, default=5)
        group.add_argument("--kernel-size2", type=int, default=5)
        group.add_argument("--layout", type=str, default='96,496,4096')
        group.add_argument("--new-sequence-length", type=int, default=4096)
        return parser
160
+
161
def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, add_scalar=0, **kwargs):
    '''
    Dense attention inside level 0 plus windowed attention for level 1,
    where level 1 attends to itself and to the image part of level 0.

    Shapes as noted by the original author (not validated here):
        q0, k0, v0: [batch_size, 1088, hidden_size]
        q1, k1, v1: [batch_size, 4096, h2]
        attention_mask: [batch_size, 1088, 1088]
    n_head: int
    text_len: number of leading level-0 positions that are not image tokens
    kernel_size / kernel_size2: window arguments passed to the local
        attention ops for level-1 self attention / level-1-to-level-0 attention
    attention_dropout: optional callable applied to both probability tensors
    log_attention_weights: optional additive bias for the level-0 scores
    add_scalar: constant added to the level-1-to-level-0 scores
    Returns (context0, context1), one context tensor per level.
    '''
    from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting

    batch, len0, width0 = q0.shape
    batch, len1, width1 = q1.shape
    head_dim0 = width0 // n_head
    head_dim1 = width1 // n_head
    edge0 = sqrt(len0 - text_len)
    edge1 = sqrt(len1)
    merged = batch * n_head
    image_len0 = edge0 ** 2

    # ---- level 0: standard attention ----
    queries0 = q0.reshape(batch, len0, n_head, head_dim0).permute(0, 2, 1, 3)
    values0 = v0.reshape(batch, len0, n_head, head_dim0).permute(0, 2, 1, 3)
    keys0_t = k0.reshape(batch, len0, n_head, head_dim0).permute(0, 2, 3, 1)

    scores0 = torch.matmul(queries0 / math.sqrt(head_dim0), keys0_t)
    if log_attention_weights is not None:
        scores0 += log_attention_weights
    scores0 = scores0 * attention_mask - 10000.0 * (1.0 - attention_mask)
    probs0 = F.softmax(scores0, dim=-1)

    # ---- level 1: local attention on a (edge1 x edge1) grid ----
    def heads_first(x):
        # [batch, len1, width1] -> [batch, n_head, head_dim1, len1]
        return x.view(batch, len1, n_head, head_dim1).permute(0, 2, 3, 1)

    grid_shape = (merged, head_dim1, edge1, edge1)
    q1 = (heads_first(q1) / math.sqrt(head_dim1)).contiguous().view(*grid_shape)
    k1 = heads_first(k1).contiguous().view(*grid_shape)
    v1 = heads_first(v1).contiguous().view(*grid_shape)
    # scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True)
    self_scores = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)

    # ---- cross attention: level 1 queries against level-0 image keys ----
    keys0_grid = keys0_t[..., -image_len0:].reshape(merged, head_dim0, edge0, edge0).contiguous()
    cross_scores = f_similar(q1, keys0_grid, kernel_size2, kernel_size2, False)  # [b*n_head, l1, l1, field]
    cross_field = cross_scores.shape[3]
    self_field = self_scores.shape[3]

    joint_scores = torch.cat(
        (
            cross_scores.view(merged, -1, cross_field) + add_scalar,
            self_scores.view(merged, -1, self_field)
        ),
        dim=-1)
    probs1 = F.softmax(joint_scores, dim=-1)

    if attention_dropout is not None:
        # with get_cuda_rng_tracker().fork():
        probs0 = attention_dropout(probs0)
        probs1 = attention_dropout(probs1)

    # weighting for level 0
    context0 = torch.matmul(probs0, values0)  # [b, n_head, s0, h]

    # weighting for level 1 (self part: trailing `self_field` columns)
    self_probs = probs1[:, :, -self_field:].view_as(self_scores)
    # context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True)
    self_context = f_weighting(v1, self_probs.contiguous(), kernel_size*2-1, kernel_size, False)
    context1 = self_context.view(batch, n_head * head_dim0, edge1**2)

    # weighting for cross attention (leading `cross_field` columns)
    cross_probs = probs1[:, :, :cross_field].view_as(cross_scores)
    values0_grid = values0[:, :, -image_len0:].transpose(-1, -2).contiguous().view(merged, head_dim0, edge0, edge0)
    cross_context = f_weighting(values0_grid, cross_probs.contiguous(), kernel_size2, kernel_size2, False)
    cross_context = cross_context.view(batch, n_head * head_dim0, edge1**2)

    context1 = context1 + cross_context
    return context0.transpose(1, 2).reshape(batch, len0, width0), context1.transpose(-1, -2)
sr_pipeline/dsr_sampling.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : dsr_sampling.py
4
+ @Time : 2021/10/09 00:46:04
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ from cv2 import reduce
15
+ import torch
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import numpy as np
20
+
21
def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
    '''
    In-place top-k filtering along the last dimension.

    logits: tensor of scores; modified in place and also returned
    top_k: number of largest entries to keep per row. A value <= 0 disables
        filtering; a value larger than the last dimension keeps everything.
    filter_value: value written over every entry that is strictly smaller
        than the k-th largest entry of its row (ties with it are kept)
    '''
    if top_k <= 0:
        # torch.topk(logits, 0) returns an empty tensor, so indexing its last
        # element would raise; "no top-k" simply means "filter nothing".
        return logits
    top_k = min(top_k, logits.size(-1))
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    logits[logits < kth_largest] = filter_value
    return logits
25
+
26
class IterativeEntfilterStrategy:
    '''
    Sampling strategy for iterative decoding: for every position first pick
    a token cluster among the `topk` most probable clusters, then sample a
    token restricted to that cluster.
    '''

    # Width of the per-cluster probability table; cluster labels must be < 500.
    NUM_CLUSTERS = 500
    # Temperature of the final within-cluster softmax when none is given.
    DEFAULT_TEMPERATURE2 = 0.6

    def __init__(self, invalid_slices=None, temperature=1., topk=6):
        '''
        invalid_slices: slices of the vocabulary that must never be sampled
            (None means no restriction)
        temperature: default temperature applied to the raw logits
        topk: number of candidate clusters per position
        Loads the token->cluster table from 'cluster_label2.npy' (relative to
        the working directory) onto the CUDA device.
        '''
        # None instead of a shared mutable `[]` default.
        self.invalid_slices = [] if invalid_slices is None else invalid_slices
        self.temperature = temperature
        self.topk = topk
        self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)

    def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
        '''
        logits_: [batch_size, seq_length, vocab] scores for every position
        tokens: [batch_size, seq_length + 1] current tokens; only the first
            column is kept, the rest is replaced by the new samples
        temperature: temperature for the cluster-selection softmax
            (defaults to self.temperature)
        entfilter, filter_topk: accepted for interface compatibility, unused
        temperature2: temperature of the within-cluster softmax
            (defaults to 0.6, the value that used to be hard coded)
        Returns the new token tensor with the same shape as `tokens`.
        '''
        # In iterative strategy, logits are of shape [batch_size, seq_length, vocab]
        if temperature is None:
            temperature = self.temperature
        if temperature2 is None:
            temperature2 = self.DEFAULT_TEMPERATURE2

        logits = logits_.float() / temperature
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -float('Inf')
        logits = logits.view(-1, logits.shape[-1])
        bz = logits.shape[0]

        # The label table is created on CUDA; follow the logits so that the
        # scatter and the comparison below never mix devices.
        cluster_labels = self.cluster_labels.to(logits.device)

        # Aggregate token probabilities into per-cluster probabilities.
        rprobs = F.softmax(logits.float(), dim=-1)
        c = cluster_labels.expand(*rprobs.shape)
        cprobs = torch.zeros(bz, self.NUM_CLUSTERS, device=logits.device).scatter_add_(1, c, rprobs)

        # Sample one cluster per position among the `topk` most probable ones.
        best_scores, best_clusters = cprobs.topk(self.topk)
        best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
        sampled_ids = torch.multinomial(best_scores, num_samples=1)
        selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)

        # Suppress every token outside the selected cluster.
        selected_mask = (cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters)  # cluster_labels [1, 20000] \in [0,500)
        logits[selected_mask] = -65504

        probs = F.softmax(logits.float() / temperature2, dim=-1)  # float is essential, due to a bug in Pytorch
        pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])

        assert tokens.shape[1] == pred.shape[1] + 1
        tokens = torch.cat((tokens[:, :1], pred), dim=1)
        return tokens
66
+
67
def filling_sequence_dsr(
        model,
        seq0,
        seq1,
        warmup_steps=3,
        block_hw=(4, 4),
        strategy=None,
        ):
    '''
    Iteratively fill the trailing image part of a sequence.

    seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
        4095 {layout[2]} final_token.
    model: callable with a 3-element `layout` attribute; called as
        model(tokens, position_ids, attention_mask, log_attention_weights=...)
        and expected to return the logits as its first output
    seq0: [batch, < layout[1]] conditioning tokens (left-padded with 0 here)
    seq1: [batch, layout[2] - layout[1] + 1] initial tokens of the part to fill
    warmup_steps: number of steps that resample every unfixed position
    block_hw: (ll, rr) block shape used to fix positions after the warm-up
    strategy: object whose `forward(logits, tokens, temperature, ...)` returns
        the new tokens; when None, `IterativeEntfilterStrategy(topk=10)` is
        created on first use
    Returns a list with the image tokens before the first step and after
    every step.

    Attention:
    The sampling temperatures are changing; for now they are hard coded here.
    The temperature in the strategy is not used.
    '''
    if strategy is None:
        # Built lazily: constructing the strategy loads a file onto the GPU,
        # which must not happen as a side effect of importing this module.
        strategy = IterativeEntfilterStrategy(topk=10)

    assert hasattr(model, 'layout')
    layout = model.layout
    assert len(seq0.shape) == 2 and len(seq1.shape) == 2 \
        and seq0.shape[0] == seq1.shape[0]
    assert len(layout) == 3
    assert seq1.shape[1] == layout[-1] - layout[-2] + 1
    assert (seq1 >= 0).all() and (seq0 >= 0).all()
    device = seq0.device
    # concat and pad sequences
    batch_size = seq0.shape[0]
    n_pad = layout[1] - seq0.shape[1]
    assert n_pad > 0, "You should truncate long input before filling."
    seq = torch.cat((
        torch.zeros(batch_size, n_pad, device=device, dtype=seq0.dtype),
        seq0, seq1), dim=1)  # [b, layout[-1]+1]
    assert seq.shape[1] == layout[-1] + 1

    # build initial tokens, attention_mask, and position_ids
    tokens = seq.clone()
    param_like = next(model.parameters())  # reference dtype/device (e.g. fp16)
    attention_mask = torch.ones(layout[1], layout[1]).to(device)
    attention_mask[:layout[0], layout[0]:] = 0
    attention_mask[n_pad:, :n_pad] = 0
    attention_mask = attention_mask.type_as(param_like)
    position_ids = torch.cat((
        torch.zeros(n_pad, dtype=torch.long),
        torch.arange(0, layout[0] - n_pad),
        torch.arange(513, 513 + layout[1] - layout[0]),
        torch.arange(1024, 1024+layout[2]-layout[1]))).to(device)
    log_attention_weights = torch.zeros(layout[1], layout[1],
                                        device=device).type_as(param_like)
    log_attention_weights[layout[0]:, n_pad:layout[0]] = 0.

    # prepare for iteration: only the trailing image positions start unfixed
    image_len = layout[-1] - layout[-2]
    unfixed = torch.zeros_like(tokens, dtype=torch.bool)
    unfixed[:, -image_len:] = True

    ll, rr = block_hw
    edge_len = int(math.sqrt(image_len) + 1e-4)
    num_steps = warmup_steps + ll - 1 + rr

    # iterative refining
    ret = [tokens[:, layout[-2]+1:].clone()]
    for step_cnt in range(1, num_steps+1):
        logits, *_dump = model(tokens[:, :-1], position_ids, attention_mask, log_attention_weights=log_attention_weights)
        real_temp = 1.
        if step_cnt <= warmup_steps:
            # warm-up: resample every position that is still unfixed
            new_tokens = strategy.forward(logits, tokens, real_temp)
            tokens[unfixed] = new_tokens[unfixed]
        else:
            new_tokens = strategy.forward(
                logits, tokens, real_temp,
                entfilter=1.3,
                filter_topk=5,
                temperature2=0.6
            )
            # fixed tokens (update unfixed): positions on the current
            # anti-diagonal of every (ll x rr) block are written this step
            # and marked as fixed.
            unfixed2 = torch.zeros_like(tokens, dtype=torch.bool)
            for x in range(min(ll, step_cnt - warmup_steps)):
                y = step_cnt - warmup_steps - x - 1
                if y < rr:
                    unfixed[..., -image_len:].view(
                        batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
                    unfixed2[..., -image_len:].view(
                        batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = True
            tokens[unfixed2] = new_tokens[unfixed2]

        ret.append(tokens[:, layout[-2]+1:].clone())

    return ret
sr_pipeline/iterative_sr.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : iterative_sr.py
4
+ @Time : 2022/03/02 15:57:45
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+
15
+ # here put the import lib
16
+ import os
17
+ import sys
18
+ import math
19
+ import random
20
+ from PIL import ImageEnhance, Image
21
+
22
+ import torch
23
+ import argparse
24
+ from torchvision import transforms
25
+
26
+ from SwissArmyTransformer.training.model_io import load_checkpoint
27
+ from SwissArmyTransformer import get_args
28
+ from .itersr_sampling import filling_sequence_itersr, IterativeEntfilterStrategy
29
+ from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
30
+
31
+ from .itersr_model import ItersrModel
32
+
33
+ from icetk import icetk as tokenizer
34
+
35
class IterativeSuperResolution:
    '''
    Iterative super-resolution stage: repeatedly masks part of an image
    token grid and lets the Itersr model re-sample the masked tokens.
    '''

    def __init__(self, args, path, max_bz=4, shared_transformer=None):
        '''
        args: argument namespace; load path, kernel sizes, sequence length and
            layout are overwritten here before the model is built
        path: checkpoint location assigned to args.load
        max_bz: maximum number of samples processed per model call
        shared_transformer: optional transformer handed to ItersrModel
        '''
        args.load = path
        args.kernel_size = 5
        args.kernel_size2 = 5
        args.new_sequence_length = 4624
        args.layout = [16, 3616]

        model = ItersrModel(args, transformer=shared_transformer)
        if args.fp16:
            model = model.half()

        load_checkpoint(model, args)  # on cpu
        model.eval()
        self.model = model.cuda()

        # save cpu weights
        self.saved_weights = dict((k, v.cpu())
                                  for k, v in model.named_parameters()
                                  if 'transformer' in k
                                  )

        invalid_slices = [slice(tokenizer.num_image_tokens, None)]

        self.strategy = IterativeEntfilterStrategy(invalid_slices,
                                                   temperature=args.temp_all_itersr, topk=args.topk_itersr)
        self.max_bz = max_bz

    @staticmethod
    def _num_batches(total, max_bz):
        '''Number of chunks of at most `max_bz` items needed for `total` items (at least 1).'''
        return max((total + max_bz - 1) // max_bz, 1)

    def _restore_transformer_from_cpu(self, non_blocking=False):
        '''Copy the saved CPU weights back into the matching model parameters.'''
        # In-place writes to leaf parameters that require grad are only legal
        # with autograd disabled.
        with torch.no_grad():
            for k, v in self.model.named_parameters():
                if k in self.saved_weights:
                    v.copy_(self.saved_weights[k], non_blocking=non_blocking)

    def __call__(self, text_tokens, image_tokens, enhance=False, input_mask=None):
        '''
        text_tokens: [batch, >=16] or 1D text tokens; only the first 16 are used
        image_tokens: [batch, 3600] or 1D image tokens to refine
        enhance: if True, decode each image, sharpen it (factor 1.5) and
            re-encode it at size 480 before refinement
        input_mask: optional boolean mask AND-ed with every refinement mask
        Returns the refined image tokens of all samples, concatenated on dim 0.
        '''
        if len(text_tokens.shape) == 1:
            text_tokens.unsqueeze_(0)
        text_tokens = text_tokens.clone()[..., :16]
        if len(image_tokens.shape) == 1:
            image_tokens.unsqueeze_(0)
        if enhance:
            new_image_tokens = []
            for big_img in image_tokens:
                decoded = tokenizer.decode(image_ids=big_img).squeeze(0)
                ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
                image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
                big_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
                new_image_tokens.append(big_img2)
            image_tokens = torch.stack(new_image_tokens)
        print('Converting Itersr model...')
        self._restore_transformer_from_cpu()
        model = self.model
        print('iterative super-resolution...')

        # Refinement order of one 6x6 tile, repeated over a 10x10 arrangement
        # of tiles (3600 positions). A position is re-sampled in round r when
        # its value is >= r. Built once: it does not depend on the batch.
        mask_raw = torch.tensor(
            [
                -1, 0, 1, 2, 3, 4,
                0, -1, 2, -1, -2, 5,
                1, -2, 3, 4, 5, 6,
                2, 3, 4, 5, -1, 1,
                3, -1, -2, 0, -1, 2,
                4, 5, 6, 1, 3, -2
            ]
        ).view(1, 6, 1, 6).expand(10, 6, 10, 6).reshape(-1).contiguous()
        # topks[r] is the cluster top-k used in round r (index 0 is unused)
        topks = [60, 40, 40, 40, 20, 20, 10]

        output_list = []
        # Ceil division: a trailing partial batch must not be dropped.
        for tim in range(self._num_batches(text_tokens.shape[0], self.max_bz)):
            # clone so the in-place masking below never touches the caller's tensor
            big_img = image_tokens[tim*self.max_bz:(tim+1)*self.max_bz].clone()
            text_seq = text_tokens[tim*self.max_bz:(tim+1)*self.max_bz]

            for mask_ratio in range(1, 7):
                self.strategy.topk = topks[mask_ratio]
                mask = (mask_raw.to(big_img.device) >= mask_ratio)
                if input_mask is not None:
                    mask = mask & input_mask
                big_img.masked_fill_(mask, tokenizer['<start_of_image>'])
                seq1 = big_img
                output1 = filling_sequence_itersr(model, text_seq, seq1,
                                                  warmup_steps=1, block_hw=(1, 0),
                                                  strategy=self.strategy
                                                  )
                big_img = output1
                print(f'Iter {mask_ratio} times.')
            output_list.append(output1.clone())
        return torch.cat(output_list, dim=0)
sr_pipeline/itersr_model.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : itersr_model.py
4
+ @Time : 2021/10/02 01:36:32
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
+ from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
19
+
20
+ from SwissArmyTransformer.mpu.utils import sqrt
21
+ from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
22
+ from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
23
+ from SwissArmyTransformer.model.transformer import unscaled_init_method, split_tensor_along_last_dim
24
+
25
class PositionEmbeddingMixin(BaseMixin):
    '''
    Mixin that owns an extra position-embedding table, used for positions
    that lie beyond the base transformer's original sequence length.
    '''

    def __init__(self, additional_sequence_length, hidden_size,
                 init_method_std=0.02, reinit_slice=slice(512, 512+400)
                 ):
        '''
        additional_sequence_length: number of rows of the extra table
        hidden_size: width of every embedding row
        init_method_std: std of the zero-mean normal used for initialisation
        reinit_slice: rows of the base table that `reinit` copies from
        '''
        super().__init__()
        self.reinit_slice = reinit_slice
        extra_table = torch.nn.Embedding(additional_sequence_length, hidden_size)
        torch.nn.init.normal_(extra_table.weight, mean=0.0, std=init_method_std)
        self.position_embeddings = extra_table

    def reinit(self, parent_model=None):
        '''
        Fill the extra table by tiling the base model's position embeddings
        (rows `reinit_slice`) over a 2D grid. `parent_model` is not used.

        `sqrt` is the helper imported from SwissArmyTransformer.mpu.utils --
        presumably an exact integer square root; confirm against that module.
        '''
        source = self.transformer.position_embeddings.weight.data[self.reinit_slice]
        target = self.position_embeddings.weight
        source_len, width = source.shape
        assert width == target.shape[-1]
        source_edge = sqrt(source_len)
        target_edge = sqrt(target.shape[-2])
        assert target_edge % source_edge == 0
        repeats = target_edge // source_edge
        # View the target as (repeats, edge, repeats, edge, width) so that the
        # broadcasted copy repeats the source grid along both axes.
        tiled = target.data.view(repeats, source_edge, repeats, source_edge, width)
        tiled.copy_(source.view(1, source_edge, 1, source_edge, width))
41
+
42
class ItersrModel(BaseModel):
    '''
    Iterative super-resolution model: the first `layout[0]` positions use
    the base position table, the remaining positions use the extra table;
    one set of attention projections serves both parts.
    '''

    def __init__(self, args, transformer=None):
        '''
        args: namespace providing max_sequence_length, new_sequence_length,
              hidden_size, layout, kernel_size and kernel_size2
        transformer: optional transformer forwarded to BaseModel
        '''
        super().__init__(args, transformer=transformer)
        self.original_sequence_length = args.max_sequence_length
        additional_seqlen = args.new_sequence_length - args.max_sequence_length
        self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
            additional_seqlen, args.hidden_size
        ))
        # No 'attention_plus' mixin is registered here (it is commented out in
        # the original source):
        # self.add_mixin('attention_plus', AttentionMixin(
        #     num_layers=args.num_layers,
        #     hidden_size=args.hidden_size
        # ))
        # `--layout` is declared with type=str, so accept both the raw
        # command-line string and an already parsed sequence of ints.
        self.layout = self._parse_layout(args.layout)
        # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
        self.kernel_size = args.kernel_size
        self.kernel_size2 = args.kernel_size2
        self.log_attention_weights = None

    @staticmethod
    def _parse_layout(layout):
        '''Turn a comma-separated layout string into a list of ints; pass anything else through.'''
        if isinstance(layout, str):
            return [int(part) for part in layout.split(',')]
        return layout

    def position_embedding_forward(self, position_ids, **kw_args):
        '''
        Look up the first `layout[0]` positions in the base table and the
        remaining ones (shifted by the original sequence length) in the extra
        table, then concatenate along the sequence dimension.
        '''
        position = position_ids[..., :self.layout[0]]
        position_plus = position_ids[..., self.layout[0]:] - self.original_sequence_length
        position_embeddings = torch.cat(
            (
                self.transformer.position_embeddings(position),
                self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
            ),
            dim=-2
        )
        return position_embeddings

    def attention_forward(self, hidden_states, mask,
                          layer_id=None, log_attention_weights=None, **kw_args):
        '''
        Attention for one layer: a single QKV projection is applied to the
        whole sequence and split at `layout[0]` into the two parts consumed
        by `sparse_attention_2d_text`.
        '''
        attn_module = self.transformer.layers[layer_id].attention
        # base model qkv
        mixed_raw_layer = attn_module.query_key_value(hidden_states)
        q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer[:, :self.layout[0]], 3)
        # cuda2d model qkv
        q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer[:, self.layout[0]:], 3)

        # dropout is only active in training mode
        dropout_fn = attn_module.attention_dropout if self.training else None

        # cuda2d attention
        context_layer = sparse_attention_2d_text(
            q0, k0, v0,
            q1, k1, v1,
            mask,
            n_head=attn_module.num_attention_heads_per_partition,
            text_len=self.layout[0],
            kernel_size=self.kernel_size,
            attention_dropout=dropout_fn,
            log_attention_weights=log_attention_weights,
        )

        output = attn_module.dense(context_layer)

        return output

    def final_forward(self, logits, **kwargs):
        '''
        Project the final hidden states onto the first 20000 rows of the word
        embedding matrix and return the result as float32.
        '''
        logits_parallel = logits
        logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]).float()
        # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
        return logits_parallel

    # def disable_untrainable_params(self):
    #     self.transformer.requires_grad_(False)

    @classmethod
    def add_model_specific_args(cls, parser):
        '''Register this model's command-line options on `parser` and return it.'''
        group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations')
        group.add_argument("--kernel-size", type=int, default=5)
        group.add_argument("--kernel-size2", type=int, default=5)
        group.add_argument("--layout", type=str, default='16,3616')
        group.add_argument("--new-sequence-length", type=int, default=4096)
        return parser
116
+
117
def sparse_attention_2d_text(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
    '''
    Dense attention inside the text part plus, for the image part, windowed
    self attention combined with dense attention to the text.

    Shapes as noted by the original author (not validated beyond the assert):
        q0, k0, v0: [batch_size, 16, hidden_size]
        q1, k1, v1: [batch_size, 3600, hidden_size]
        attention_mask: [batch_size, 16]
    n_head: int
    text_len: accepted for interface compatibility, not used here
    kernel_size: window argument passed to the local attention ops
    attention_dropout: optional callable applied to the image-part
        probabilities only
    log_attention_weights: optional additive bias for the image-to-text scores
    Returns one tensor of shape [batch, s0 + s1, hidden] (text part first).
    '''
    from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting

    batch, len0, width0 = q0.shape
    batch, len1, width1 = q1.shape
    head_dim0 = width0 // n_head
    head_dim1 = width1 // n_head
    edge1 = sqrt(len1)
    merged = batch * n_head
    assert attention_mask.shape[-1] == len0, f"Mask Shape: {attention_mask.shape}"

    # ---- level 0: standard attention ----
    queries0 = q0.reshape(batch, len0, n_head, head_dim0).permute(0, 2, 1, 3)
    values0 = v0.reshape(batch, len0, n_head, head_dim0).permute(0, 2, 1, 3)
    keys0_t = k0.reshape(batch, len0, n_head, head_dim0).permute(0, 2, 3, 1)

    scores0 = torch.matmul(queries0 / math.sqrt(head_dim0), keys0_t)
    scores0 = scores0 * attention_mask - 10000.0 * (1.0 - attention_mask)
    probs0 = F.softmax(scores0, dim=-1)

    # ---- level 1: local attention on a (edge1 x edge1) grid ----
    def heads_first(x):
        # [batch, len1, width1] -> [batch, n_head, head_dim1, len1]
        return x.view(batch, len1, n_head, head_dim1).permute(0, 2, 3, 1)

    grid_shape = (merged, head_dim1, edge1, edge1)
    q1 = (heads_first(q1) / math.sqrt(head_dim1)).contiguous().view(*grid_shape)
    k1 = heads_first(k1).contiguous().view(*grid_shape)
    v1 = heads_first(v1).contiguous().view(*grid_shape)
    self_scores = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)
    self_field = self_scores.shape[3]

    # ---- cross attention: image queries against text keys ----
    cross_scores = torch.matmul(q1.view(batch, n_head, head_dim0, len1).transpose(-1, -2), keys0_t)
    if log_attention_weights is not None:
        cross_scores += log_attention_weights
    cross_scores = cross_scores * attention_mask - 10000.0 * (1.0 - attention_mask)
    cross_field = cross_scores.shape[3]

    joint_scores = torch.cat(
        (
            cross_scores.view(merged, len1, len0),
            self_scores.view(merged, -1, self_field)
        ),
        dim=-1)
    probs1 = F.softmax(joint_scores, dim=-1)

    if attention_dropout is not None:
        with get_cuda_rng_tracker().fork():
            probs1 = attention_dropout(probs1)

    # weighting for level 0
    context0 = torch.matmul(probs0, values0)  # [b, n_head, s0, h]

    # weighting for level 1 (self part: trailing `self_field` columns)
    self_probs = probs1[:, :, -self_field:].view_as(self_scores)
    self_context = f_weighting(v1, self_probs.contiguous(), kernel_size*2-1, kernel_size, False)
    context1 = self_context.view(batch, n_head, head_dim0, edge1**2)

    # weighting for cross attention (leading `cross_field` columns)
    cross_probs = probs1[:, :, :cross_field].view(batch, n_head, -1, cross_field)
    cross_context = torch.matmul(cross_probs, values0)
    context1 = context1.transpose(-1, -2) + cross_context

    # merge heads back: [b, n_head, s0+s1, h] -> [b, s0+s1, hidden]
    stacked = torch.cat((context0, context1), dim=2)
    return stacked.transpose(1, 2).reshape(batch, len0 + len1, width0)
182
+
183
def sparse_attention_2d_notext(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs):
    '''
    Variant without cross attention: dense attention inside the text part
    and windowed self attention inside the image part.

    Shapes as noted by the original author (not validated beyond the assert):
        q0, k0, v0: [batch_size, 16, hidden_size]
        q1, k1, v1: [batch_size, 3600, hidden_size]
        attention_mask: 4D, last dimension equal to the text length
    n_head: int
    text_len, log_attention_weights: accepted for interface compatibility,
        not used here
    kernel_size: window argument passed to the local attention ops
    attention_dropout: optional callable applied to the image-part
        probabilities only
    Returns one tensor of shape [batch, s0 + s1, hidden] (text part first).
    '''
    # The sibling function in this file imports these ops from `ops`; keep the
    # older `mpu` location as a fallback so either package layout works.
    try:
        from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
    except ImportError:
        from SwissArmyTransformer.mpu.local_attention_function import f_similar, f_weighting
    b, s0, h0 = q0.shape
    b, s1, h1 = q1.shape
    h, l1 = h0 // n_head, sqrt(s1)
    assert len(attention_mask.shape) == 4 and attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"

    q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
    v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
    k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)

    # standard attention for level 0
    attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)

    attention_scores = torch.mul(attention_scores, attention_mask) - \
        10000.0 * (1.0 - attention_mask)

    attention_probs0 = F.softmax(attention_scores, dim=-1)

    # local attention for level 1: lay every head out as an (l1 x l1) grid
    q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1)
    k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
    v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1)
    scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False)

    attention_probs1 = F.softmax(scores_1_to_1, dim=-1)

    if attention_dropout is not None:
        with get_cuda_rng_tracker().fork():
            attention_probs1 = attention_dropout(attention_probs1)

    # weighting for level 0
    context0 = torch.matmul(attention_probs0, v0)  # [b, n_head, s0, h]
    # weighting for level 1
    probs_1_to_1 = attention_probs1
    context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False)

    context1 = context1_to_1.view(b, n_head, h, l1**2)
    # no cross attention in this variant; only bring the layout in line with context0
    context1 = context1.transpose(-1, -2)

    output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)

    return output
sr_pipeline/itersr_sampling.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : itersr_sampling.py
4
+ @Time : 2022/03/03 14:24:28
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from icetk import icetk as tokenizer
19
+
20
def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
    '''Keep only the ``top_k`` largest logits along the last dimension, in place.

    Every entry that is strictly smaller than the k-th largest value of its
    row (last dimension) is overwritten with ``filter_value``. Entries tied
    with the k-th largest value are kept.

    Args:
        logits: floating point tensor; modified in place.
        top_k: number of entries to keep per row. ``top_k <= 0`` disables
            filtering. Values larger than the last dimension are clamped.
        filter_value: value written into the filtered positions.

    Returns:
        The same ``logits`` tensor object that was passed in.
    '''
    # torch.topk raises when k exceeds the size of the dimension, so clamp.
    top_k = min(top_k, logits.size(-1))
    # With k == 0, torch.topk returns an empty tensor and the ``[..., -1]``
    # lookup below would raise; treat it as "no filtering" instead.
    if top_k <= 0:
        return logits
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    logits[logits < kth_largest] = filter_value
    return logits
24
+
25
+ # class IterativeEntfilterStrategy:
26
+ # def __init__(self, invalid_slices=[], temperature=1., topk=10):
27
+ # self.invalid_slices = invalid_slices
28
+ # self.temperature = temperature
29
+ # self.topk = topk
30
+ # self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
31
+
32
+
33
+ # def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
34
+ # # In iterative strategy, logits are of shape [batch_size, seq_length, hidden_size]
35
+ # if temperature is None:
36
+ # temperature = self.temperature
37
+
38
+ # logits = logits_.float() / temperature
39
+ # for invalid_slice in self.invalid_slices:
40
+ # logits[..., invalid_slice] = -float('Inf')
41
+ # logits = logits.view(-1, logits.shape[-1])
42
+
43
+ # rprobs = F.softmax(logits.float(), dim=-1)
44
+ # c = self.cluster_labels.expand(*rprobs.shape)
45
+ # cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
46
+
47
+ # best_scores, best_clusters = cprobs.topk(self.topk)
48
+ # bz = logits.shape[0]
49
+ # best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
50
+ # sampled_ids = torch.multinomial(best_scores, num_samples=1)
51
+ # selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
52
+ # selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500)
53
+ # logits[selected_mask] = -65504
54
+ # # for i in range(bz):
55
+ # # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
56
+ # # logits[i, self.cluster_labels != selected_cluster] = -65504
57
+
58
+ # # logits = top_k_logits(logits, self.topk, self.top_p)
59
+ # probs = F.softmax(logits.float(), dim=-1) # float is essential, due to a bug in PyTorch
60
+ # pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
61
+
62
+ # assert tokens.shape[1] == pred.shape[1]
63
+ # tokens = pred
64
+ # return tokens
65
+
66
class IterativeEntfilterStrategy:
    '''Top-k sampling strategy that re-samples every position of a sequence at once.

    Attributes:
        invalid_slices: indices/slices of the last logits dimension that must
            never be sampled (their logits are set to -inf).
        temperature: default softmax temperature, used when ``forward`` is
            called without one.
        topk: number of highest-scoring candidates kept per position.
    '''

    def __init__(self, invalid_slices=None, temperature=1., topk=10):
        # Build a fresh list per instance; a literal ``[]`` default would be a
        # single object shared by every instance created without the argument.
        self.invalid_slices = [] if invalid_slices is None else invalid_slices
        self.temperature = temperature
        self.topk = topk

    def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
        '''Sample one token id per position from temperature-scaled, top-k filtered logits.

        Args:
            logits: tensor of shape [batch_size, seq_length, n_candidates]
                (the last dimension is what gets sampled from).
            tokens: current tokens; only ``tokens.shape[1]`` is read, to check
                that it matches the sequence length of the prediction.
            temperature: divisor applied to the logits; falls back to
                ``self.temperature`` when ``None``.
            entfilter, filter_topk, temperature2: accepted for interface
                compatibility; not used by the current implementation (they
                belonged to an entropy-filter variant that is disabled).

        Returns:
            Tensor of shape [batch_size, seq_length] holding the sampled ids.
        '''
        if temperature is None:
            temperature = self.temperature

        # The division allocates a new tensor, so the in-place edits below do
        # not touch the caller's ``logits``.
        logits = logits.float() / temperature
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -float('Inf')

        top_k_logits_(logits, self.topk)
        probs = F.softmax(logits, dim=-1)
        # torch.multinomial takes at most 2-D input: flatten batch and
        # sequence dims, draw one sample per row, then restore the layout.
        pred = torch.multinomial(
            probs.view(-1, logits.shape[-1]), num_samples=1
        ).view(*logits.shape[:2])

        assert tokens.shape[1] == pred.shape[1]
        return pred
105
+
106
def filling_sequence_itersr(
        model,
        seq0,
        seq1,
        warmup_steps=3,
        block_hw=(4, 4),
        strategy=IterativeEntfilterStrategy(topk=10),
        ):
    '''
    Left-pad ``seq0`` with zeros up to ``model.layout[0]``, append ``seq1``,
    run the model and let ``strategy`` re-sample every position that holds
    the ``<start_of_image>`` token.

    The concatenated sequence must be exactly ``model.layout[-1]`` long.
    Attention:
        The sampling temperature is hard coded to 1.0 here, so the
        temperature stored in the strategy is not used.
        ``warmup_steps`` and ``block_hw`` do not influence the result in the
        current implementation, and a single refinement step is performed.

    Returns the tokens from position ``layout[-2]`` onward, one
    ``[batch_size, ...]`` chunk per step, concatenated along dim 0.
    '''
    assert hasattr(model, 'layout')
    layout = model.layout
    text_len = layout[0]

    device = seq0.device
    batch_size = seq0.shape[0]

    # concat and pad sequences
    n_pad = text_len - seq0.shape[1]
    assert n_pad >= 0, "You should truncate long input before filling."
    left_pad = seq0.new_zeros((batch_size, n_pad))
    seq = torch.cat((left_pad, seq0, seq1), dim=1)
    assert seq.shape[1] == layout[-1]

    # build initial tokens, attention_mask, and position_ids
    tokens = seq.clone()
    ref_param = next(model.parameters())  # supplies the model dtype (e.g. fp16)

    attention_mask = torch.ones(text_len, device=device)
    attention_mask[:n_pad] = 0  # padding positions are masked out
    attention_mask = attention_mask.unsqueeze(0).type_as(ref_param)

    pad_positions = torch.zeros(n_pad, dtype=torch.long)
    head_positions = torch.arange(0, text_len - n_pad)
    tail_positions = torch.arange(1024, 1024 + layout[1] - text_len)
    position_ids = torch.cat((pad_positions, head_positions, tail_positions)).to(device)

    log_attention_weights = torch.zeros(text_len, device=device).type_as(ref_param)
    log_attention_weights[n_pad:text_len] = 0.
    log_attention_weights = log_attention_weights.unsqueeze(0)

    # prepare for iteration
    unfixed = tokens == tokenizer['<start_of_image>']
    ll, rr = block_hw  # unpacked but not used below
    edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)  # not used below
    num_steps = 1
    real_temp = 1.

    # iterative refining
    ret = []
    step_cnt = 0
    while step_cnt < num_steps:
        step_cnt += 1
        logits, *_dump = model(
            tokens, position_ids, attention_mask,
            log_attention_weights=log_attention_weights)
        new_tokens = strategy.forward(logits, tokens, real_temp)
        # only the placeholder positions are overwritten
        tokens[unfixed] = new_tokens[unfixed]
        ret.append(tokens[:, layout[-2]:].clone())

    return torch.cat(ret, dim=0)
sr_pipeline/sr_group.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ '''
3
+ @File : sr_group.py
4
+ @Time : 2022/04/02 01:17:21
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ '''
8
+
9
+ # here put the import lib
10
+ import os
11
+ import sys
12
+ import math
13
+ import random
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from SwissArmyTransformer.resources import auto_create
19
+ from .direct_sr import DirectSuperResolution
20
+ from .iterative_sr import IterativeSuperResolution
21
+
22
class SRGroup:
    '''Bundles the direct and the iterative super-resolution stages.

    The iterative stage is constructed with the transformer of the direct
    stage passed as ``shared_transformer``.
    '''

    def __init__(self, args, home_path=None,):
        # Resolve both checkpoint locations first, then build the two stages.
        ckpt = {
            name: auto_create(name, path=home_path)
            for name in ('cogview2-dsr', 'cogview2-itersr')
        }
        self.dsr = DirectSuperResolution(args, ckpt['cogview2-dsr'])
        self.itersr = IterativeSuperResolution(
            args,
            ckpt['cogview2-itersr'],
            shared_transformer=self.dsr.model.transformer,
        )

    def sr_base(self, img_tokens, txt_tokens):
        '''Run direct SR followed by iterative SR on a batch of image tokens.

        ``img_tokens`` must be 2-D with 400 tokens per row. A 1-D
        ``txt_tokens`` is broadcast to one row per image. Only the last 3600
        tokens of the direct-SR output are handed to the iterative stage,
        and the last ``batch_size`` rows of its output are returned.
        '''
        assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2
        n_images = img_tokens.shape[0]
        n_text = txt_tokens.shape[-1]
        if txt_tokens.ndim == 1:
            txt_tokens = txt_tokens[None].expand(n_images, n_text)
        direct_out = self.dsr(txt_tokens, img_tokens)
        refined = self.itersr(txt_tokens, direct_out[:, -3600:].clone())
        return refined[-n_images:]
40
+
41
+ # def sr_patch(self, img_tokens, txt_tokens):
42
+ # assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
43
+ # batch_size = img_tokens.shape[0] * 9
44
+ # txt_len = txt_tokens.shape[-1]
45
+ # if len(txt_tokens.shape) == 1:
46
+ # txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
47
+ # img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
48
+ # iter_tokens = self.sr_base(img_tokens, txt_tokens)
49
+ # return iter_tokens