Spaces:
Paused
Paused
Update text2vid_torch2.py
Browse files- text2vid_torch2.py +10 -1
text2vid_torch2.py
CHANGED
@@ -258,7 +258,16 @@ class AttnProcessor2_0:
|
|
258 |
# Ensure slice dimensions match
|
259 |
target_size = min(query[i*all_dim:(i*all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]].size(0),
|
260 |
query_list[i*batch_dim:(i+1)*batch_dim].size(0))
|
261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
# Assign values from query_list to query
|
263 |
query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]] = \
|
264 |
query_list[i*batch_dim:i*batch_dim + target_size]
|
|
|
258 |
# Ensure slice dimensions match
|
259 |
target_size = min(query[i*all_dim:(i*all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]].size(0),
|
260 |
query_list[i*batch_dim:(i+1)*batch_dim].size(0))
|
261 |
+
|
262 |
+
# Check if the target size is compatible with the query slice dimensions
|
263 |
+
query_slice_shape = query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]].shape
|
264 |
+
query_list_slice_shape = query_list[i*batch_dim:i*batch_dim + target_size].shape
|
265 |
+
|
266 |
+
if query_slice_shape[1] != query_list_slice_shape[1]: # Dimension mismatch check
|
267 |
+
print(f"Warning: Dimension mismatch. query_slice_shape: {query_slice_shape}, query_list_slice_shape: {query_list_slice_shape}. Adjusting to compatible sizes.")
|
268 |
+
# Adjust to the smaller dimension
|
269 |
+
target_size = min(query_slice_shape[1], query_list_slice_shape[1])
|
270 |
+
|
271 |
# Assign values from query_list to query
|
272 |
query[i*all_dim:(i*all_dim) + target_size, :query_list.shape[1], :query_list.shape[2]] = \
|
273 |
query_list[i*batch_dim:i*batch_dim + target_size]
|