fffiloni committed (verified)
Commit 83eb248 · 1 Parent(s): d2aa136

Update text2vid_torch2.py

Files changed (1):
  1. text2vid_torch2.py +29 -39
text2vid_torch2.py CHANGED
@@ -224,67 +224,57 @@ class AttnProcessor2_0:
         return query, key, dynamic_lambda, key1
     '''
 
-    import torch
-
     def get_qk(self, query, key):
         r"""
         Compute the attention scores.
+
         Args:
             query (`torch.Tensor`): The query tensor.
             key (`torch.Tensor`): The key tensor.
             attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
         Returns:
             `torch.Tensor`: The attention probabilities/scores.
         """
-        q_old = query.clone()
-        k_old = key.clone()
-        dynamic_lambda = None
-        key1 = None
-
         try:
+            q_old = query.clone()
+            k_old = key.clone()
+            dynamic_lambda = None
+            key1 = None
+
             if self.use_last_attn_slice:
                 if self.last_attn_slice is not None:
-
                     query_list = self.last_attn_slice[0]
                     key_list = self.last_attn_slice[1]
-
+
+                    # Ensure that shapes are compatible before performing assignments
                     if query.shape[1] == self.num_frames and query.shape == key.shape:
                         key1 = key.clone()
 
-                        # Ensure the batch dimension of key1 and key_list match
-                        batch_size_key1 = key1.shape[0]
-                        batch_size_key_list = key_list.shape[0]
-
-                        if batch_size_key1 != batch_size_key_list:
-                            # Handle mismatch: either pad or slice to match sizes
-                            if batch_size_key1 > batch_size_key_list:
-                                # Pad key_list if key1 batch size is larger
-                                padding = (0, 0, 0, batch_size_key1 - batch_size_key_list) # (left, right, top, bottom)
-                                key_list = torch.nn.functional.pad(key_list, padding, "constant", 0)
-                            else:
-                                # Slice key1 if key_list batch size is larger
-                                key1 = key1[:batch_size_key_list]
-
-                        # Proceed with assignment after matching batch dimensions
-                        key1[:,:1,:key_list.shape[2]] = key_list[:,:1]
-
-                    dynamic_lambda = torch.tensor([1 + self.LAMBDA * (i/50) for i in range(self.num_frames)]).to(key.dtype).cuda()
+                        # Safety check: ensure key1 can receive the value from key_list without causing a size mismatch
+                        if key1.shape[0] >= key_list.shape[0]:
+                            key1[:, :1, :key_list.shape[2]] = key_list[:, :1]
+                        else:
+                            raise RuntimeError(f"Shape mismatch: key1 has {key1.shape[0]} batches, but key_list has {key_list.shape[0]} batches.")
+
+                    # Dynamic lambda scaling
+                    dynamic_lambda = torch.tensor([1 + self.LAMBDA * (i / 50) for i in range(self.num_frames)]).to(key.dtype).cuda()
 
                     if q_old.shape == k_old.shape and q_old.shape[1] != self.num_frames:
+                        # Ensure batch size division is valid
                         batch_dim = query_list.shape[0] // self.bs
                         all_dim = query.shape[0] // self.bs
+
                         for i in range(self.bs):
-                            target_size = min(query[i*all_dim:(i*all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]].size(0),
-                                              query_list[i*batch_dim:(i+1)*batch_dim].size(0))
-                            query_slice_shape = query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]].shape
-                            query_list_slice_shape = query_list[i*batch_dim:i*batch_dim + target_size].shape
-
-                            if query_slice_shape[1] != query_list_slice_shape[1]:
-                                print(f"Warning: Dimension mismatch. query_slice_shape: {query_slice_shape}, query_list_slice_shape: {query_list_slice_shape}. Adjusting to compatible sizes.")
-                                target_size = min(query_slice_shape[1], query_list_slice_shape[1])
-
-                            query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]] = \
-                                query_list[i*batch_dim:i*batch_dim + target_size]
+                            # Safety check for slicing indices to avoid memory access errors
+                            query_slice = query[i * all_dim:(i * all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]]
+                            target_slice = query_list[i * batch_dim:(i + 1) * batch_dim]
+
+                            # Validate dimensions match before assignment
+                            if query_slice.shape == target_slice.shape:
+                                query_slice[:] = target_slice
+                            else:
+                                raise RuntimeError(f"Shape mismatch during slicing: query slice shape {query_slice.shape}, target slice shape {target_slice.shape}")
 
             if self.save_last_attn_slice:
                 self.last_attn_slice = [query, key]
@@ -293,9 +283,9 @@ class AttnProcessor2_0:
         except RuntimeError as e:
             # If a RuntimeError happens, catch it and clean CUDA memory
             print(f"RuntimeError occurred: {e}. Cleaning up CUDA memory...")
-            torch.cuda.empty_cache()
-            raise # Re-raise the error to let the caller handle it further if needed
-
+            torch.cuda.empty_cache() # Free up CUDA memory to avoid further issues
+            raise # Re-raise the error to propagate it if needed
+
         return query, key, dynamic_lambda, key1
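
A note on why the new in-place write works: basic (non-advanced) slicing of a PyTorch tensor returns a view, so assigning through `query_slice[:]` updates the underlying `query` tensor rather than a copy. A minimal standalone sketch of that behavior (the tensor names and shapes below are illustrative, not taken from the repo):

    # Sketch: writing through a view produced by basic slicing mutates the source tensor.
    import torch

    query = torch.zeros(4, 3, 2)        # stand-in for the attention query
    target_slice = torch.ones(2, 3, 2)  # stand-in for a slice of query_list

    query_slice = query[0:2, :3, :2]    # basic slicing -> a view, not a copy
    query_slice[:] = target_slice       # in-place write through the view

    assert torch.equal(query[0:2], target_slice)  # query itself was updated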
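A side note on the padding path this commit removes: `torch.nn.functional.pad` consumes (before, after) pairs starting from the last dimension, so on a 3-D tensor the 4-tuple `(0, 0, 0, n)` pads the second-to-last (frame) dimension, not the batch dimension the old comment described; padding the batch dimension of a 3-D tensor needs a 6-tuple. A standalone sketch of that semantics (shapes illustrative only):

    # Sketch: F.pad applies (before, after) pairs from the last dimension backwards.
    import torch
    import torch.nn.functional as F

    x = torch.ones(2, 5, 8)             # (batch, frames, channels)

    y = F.pad(x, (0, 0, 0, 3))          # 4-tuple: pads dim -2 (frames)
    assert y.shape == (2, 8, 8)

    z = F.pad(x, (0, 0, 0, 0, 0, 3))    # 6-tuple: pads dim 0 (batch)
    assert z.shape == (5, 5, 8)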