fffiloni committed
Commit 415bea3 · verified · 1 Parent(s): 8063899

Update text2vid_torch2.py

Files changed (1)
1. text2vid_torch2.py  +54 -32
text2vid_torch2.py CHANGED
@@ -224,8 +224,9 @@ class AttnProcessor2_0:
         return query, key, dynamic_lambda, key1
     '''
 
-    def get_qk(
-            self, query, key):
+    import torch
+
+    def get_qk(self, query, key):
         r"""
         Compute the attention scores.
         Args:
@@ -240,45 +241,66 @@ class AttnProcessor2_0:
         dynamic_lambda = None
         key1 = None
 
-        if self.use_last_attn_slice:
-            if self.last_attn_slice is not None:
-
-                query_list = self.last_attn_slice[0]
-                key_list = self.last_attn_slice[1]
-
-                if query.shape[1] == self.num_frames and query.shape == key.shape:
-                    key1 = key.clone()
-                    key1[:,:1,:key_list.shape[2]] = key_list[:,:1]
-                    dynamic_lambda = torch.tensor([1 + self.LAMBDA * (i/50) for i in range(self.num_frames)]).to(key.dtype).cuda()
-
-                if q_old.shape == k_old.shape and q_old.shape[1] != self.num_frames:
-                    batch_dim = query_list.shape[0] // self.bs
-                    all_dim = query.shape[0] // self.bs
-                    for i in range(self.bs):
-                        # Ensure slice dimensions match
-                        target_size = min(query[i*all_dim:(i*all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]].size(0),
-                                          query_list[i*batch_dim:(i+1)*batch_dim].size(0))
-
-                        # Check if the target size is compatible with the query slice dimensions
-                        query_slice_shape = query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]].shape
-                        query_list_slice_shape = query_list[i*batch_dim:i*batch_dim + target_size].shape
-
-                        if query_slice_shape[1] != query_list_slice_shape[1]:  # Dimension mismatch check
-                            print(f"Warning: Dimension mismatch. query_slice_shape: {query_slice_shape}, query_list_slice_shape: {query_list_slice_shape}. Adjusting to compatible sizes.")
-                            # Adjust to the smaller dimension
-                            target_size = min(query_slice_shape[1], query_list_slice_shape[1])
-
-                        # Assign values from query_list to query
-                        query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]] = \
-                            query_list[i*batch_dim:i*batch_dim + target_size]
-
-        if self.save_last_attn_slice:
-            self.last_attn_slice = [query, key]
-            self.save_last_attn_slice = False
+        try:
+            if self.use_last_attn_slice:
+                if self.last_attn_slice is not None:
+
+                    query_list = self.last_attn_slice[0]
+                    key_list = self.last_attn_slice[1]
+
+                    if query.shape[1] == self.num_frames and query.shape == key.shape:
+                        key1 = key.clone()
+
+                        # Ensure the batch dimension of key1 and key_list match
+                        batch_size_key1 = key1.shape[0]
+                        batch_size_key_list = key_list.shape[0]
+
+                        if batch_size_key1 != batch_size_key_list:
+                            # Handle mismatch: either pad or slice to match sizes
+                            if batch_size_key1 > batch_size_key_list:
+                                # Pad key_list if key1 batch size is larger
+                                padding = (0, 0, 0, batch_size_key1 - batch_size_key_list)  # (left, right, top, bottom)
+                                key_list = torch.nn.functional.pad(key_list, padding, "constant", 0)
+                            else:
+                                # Slice key1 if key_list batch size is larger
+                                key1 = key1[:batch_size_key_list]
+
+                        # Proceed with assignment after matching batch dimensions
+                        key1[:,:1,:key_list.shape[2]] = key_list[:,:1]
+
+                        dynamic_lambda = torch.tensor([1 + self.LAMBDA * (i/50) for i in range(self.num_frames)]).to(key.dtype).cuda()
+
+                    if q_old.shape == k_old.shape and q_old.shape[1] != self.num_frames:
+                        batch_dim = query_list.shape[0] // self.bs
+                        all_dim = query.shape[0] // self.bs
+                        for i in range(self.bs):
+                            target_size = min(query[i*all_dim:(i*all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]].size(0),
+                                              query_list[i*batch_dim:(i+1)*batch_dim].size(0))
+                            query_slice_shape = query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]].shape
+                            query_list_slice_shape = query_list[i*batch_dim:i*batch_dim + target_size].shape
+
+                            if query_slice_shape[1] != query_list_slice_shape[1]:
+                                print(f"Warning: Dimension mismatch. query_slice_shape: {query_slice_shape}, query_list_slice_shape: {query_list_slice_shape}. Adjusting to compatible sizes.")
+                                target_size = min(query_slice_shape[1], query_list_slice_shape[1])
+
+                            query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]] = \
+                                query_list[i*batch_dim:i*batch_dim + target_size]
+
+            if self.save_last_attn_slice:
+                self.last_attn_slice = [query, key]
+                self.save_last_attn_slice = False
 
+        except RuntimeError as e:
+            # If a RuntimeError happens, catch it and clean CUDA memory
+            print(f"RuntimeError occurred: {e}. Cleaning up CUDA memory...")
+            torch.cuda.empty_cache()
+            raise  # Re-raise the error to let the caller handle it further if needed
+
         return query, key, dynamic_lambda, key1
 
 
+
+
 def init_attention_func(unet):
 
     for name, module in unet.named_modules():
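
Note: the snippet below is a minimal, self-contained sketch (not part of the commit) of the two patterns the new get_qk relies on: padding or slicing a cached tensor so its batch dimension matches the current one before an in-place assignment, and freeing the CUDA cache when a RuntimeError is raised. The helper and demo names (align_batch, demo) are hypothetical, and the tensors are assumed to be 3D (batch, frames, channels), which is why the pad tuple here has six entries rather than the four-entry tuple used in the diff.

import torch
import torch.nn.functional as F


def align_batch(current: torch.Tensor, cached: torch.Tensor):
    """Return (current, cached) with equal size along dim 0 (batch)."""
    diff = current.shape[0] - cached.shape[0]
    if diff > 0:
        # Zero-pad the cached tensor along the batch dimension. For a 3D
        # tensor F.pad reads the tuple as (D_left, D_right, S_left, S_right,
        # B_left, B_right), so only the last entry is non-zero here.
        cached = F.pad(cached, (0, 0, 0, 0, 0, diff), mode="constant", value=0)
    elif diff < 0:
        # Slice the current tensor down to the cached batch size.
        current = current[: cached.shape[0]]
    return current, cached


def demo():
    try:
        key1 = torch.randn(4, 16, 64)      # stand-in for the current key tensor
        key_list = torch.randn(2, 16, 64)  # stand-in for the cached key slice
        key1, key_list = align_batch(key1, key_list)
        # Same style of assignment as in get_qk: copy the first frame over.
        key1[:, :1, :key_list.shape[2]] = key_list[:, :1]
        print(key1.shape, key_list.shape)
    except RuntimeError as e:
        # Mirror the commit's recovery path: free cached CUDA memory
        # (a no-op without CUDA), then let the caller see the error.
        print(f"RuntimeError occurred: {e}. Cleaning up CUDA memory...")
        torch.cuda.empty_cache()
        raise


if __name__ == "__main__":
    demo()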