fffiloni commited on
Commit
8063899
·
verified ·
1 Parent(s): cfc7d51

Update text2vid_torch2.py

Browse files
Files changed (1) hide show
  1. text2vid_torch2.py +10 -1
text2vid_torch2.py CHANGED
@@ -258,7 +258,16 @@ class AttnProcessor2_0:
258
  # Ensure slice dimensions match
259
  target_size = min(query[i*all_dim:(i*all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]].size(0),
260
  query_list[i*batch_dim:(i+1)*batch_dim].size(0))
261
-
 
 
 
 
 
 
 
 
 
262
  # Assign values from query_list to query
263
  query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]] = \
264
  query_list[i*batch_dim:i*batch_dim + target_size]
 
258
  # Ensure slice dimensions match
259
  target_size = min(query[i*all_dim:(i*all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]].size(0),
260
  query_list[i*batch_dim:(i+1)*batch_dim].size(0))
261
+
262
+ # Check if the target size is compatible with the query slice dimensions
263
+ query_slice_shape = query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]].shape
264
+ query_list_slice_shape = query_list[i*batch_dim:i*batch_dim + target_size].shape
265
+
266
+ if query_slice_shape[1] != query_list_slice_shape[1]: # Dimension mismatch check
267
+ print(f"Warning: Dimension mismatch. query_slice_shape: {query_slice_shape}, query_list_slice_shape: {query_list_slice_shape}. Adjusting to compatible sizes.")
268
+ # Adjust to the smaller dimension
269
+ target_size = min(query_slice_shape[1], query_list_slice_shape[1])
270
+
271
  # Assign values from query_list to query
272
  query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]] = \
273
  query_list[i*batch_dim:i*batch_dim + target_size]