Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. | |
| import numpy as np | |
| import torch | |
def _rope_int_triple(v):
    """Return the three entries of a length-3 tensor or sequence as ints."""
    # One .tolist() reads the whole row at once instead of three separate
    # scalar reads from the tensor.
    values = v.tolist() if isinstance(v, torch.Tensor) else v
    first, second, third = (int(e) for e in values)
    return first, second, third


def _rope_grid_freqs(freqs, offset, end, target):
    """Gather the rotary factors of one 3-D grid, flattened to ``(L, 1, C)``.

    Args:
        freqs: ``(frame, height, width)`` bands obtained by splitting the
            frequency table along dim 1; each band is indexed by integer
            position along dim 0.
        offset: ``(f_o, h_o, w_o)`` first position on each axis.
        end: ``(f, h, w)``; ``end - offset`` is the number of positions
            taken on each axis.
        target: ``(t_f, t_h, t_w)`` extent over which those positions are
            evenly spread; ``t_f`` is expected to be positive here.

    Returns:
        Tensor of shape ``(seq_f * seq_h * seq_w, 1, C)`` ordered
        frame-major, then height, then width; ``C`` is the sum of the three
        band widths.

    Raises:
        ValueError: if an offset and its end have opposite signs.
    """
    (f_o, h_o, w_o), (f, h, w), (t_f, t_h, t_w) = offset, end, target
    if f_o * f < 0 or h_o * h < 0 or w_o * w < 0:
        raise ValueError(
            "rope_precompute: each grid offset must share the sign of its "
            f"end, got offset={offset} end={end}")
    seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o

    def sample(lo, hi, num):
        # ``num`` evenly spaced points from lo to hi inclusive, truncated to
        # ints; when hi - lo + 1 == num this is just lo, lo + 1, ..., hi.
        return np.linspace(lo, hi, num).astype(int).tolist()

    if f_o >= 0:
        freqs_0 = freqs[0][sample(f_o, t_f + f_o - 1, seq_f)]
    else:
        # Negative frame offsets read positions |f_o|, |f_o| - 1, ... and
        # take the complex conjugate of the gathered factors.
        freqs_0 = freqs[0][sample(-f_o, -t_f - f_o + 1, seq_f)].conj()
    h_sam = sample(h_o, t_h + h_o - 1, seq_h)
    w_sam = sample(w_o, t_w + w_o - 1, seq_w)

    shape = (seq_f, seq_h, seq_w, -1)
    return torch.cat([
        freqs_0.view(seq_f, 1, 1, -1).expand(*shape),
        freqs[1][h_sam].view(1, seq_h, 1, -1).expand(*shape),
        freqs[2][w_sam].view(1, 1, seq_w, -1).expand(*shape),
    ], dim=-1).reshape(seq_f * seq_h * seq_w, 1, -1)


def rope_precompute(x, grid_sizes, freqs, start=None):
    """Build complex rotary (RoPE) factors for every token of ``x``.

    The result has ``x``'s layout with the last dimension viewed as ``C``
    complex values (``C = x.size(3) // 2``) in ``complex128``. Token spans
    described by ``grid_sizes`` are overwritten with rotary factors; tokens
    not covered by any group keep ``x``'s own values reinterpreted as complex
    numbers. ``x`` itself is never modified.

    Args:
        x: 4-D tensor ``(b, s, n, 2 * C)``.
        grid_sizes: list of groups processed in order; each group occupies
            the next contiguous span along dim 1. A group is either
            ``[start, end, target]``, three tensors where ``[i]`` yields three
            integers for sample ``i`` (sample count is ``start.shape[0]``), or
            a bare tensor of sizes, treated as ``[zeros, sizes, sizes]``
            (plain positions ``0 .. size - 1``). A non-list value is wrapped
            as a one-group list.
        freqs: frequency table indexed by position along dim 0 and split
            along dim 1 into frame/height/width bands of
            ``C - 2 * (C // 3)``, ``C // 3`` and ``C // 3`` channels; or a
            list ``[freqs, trainable_freqs]`` whose second entry (which must
            broadcast to ``(seq_len, C)``) is used for groups with a negative
            target frame count.
        start: optional per-sample offsets (tensor rows or int sequences)
            used instead of each group's ``start``.

    Returns:
        ``complex128`` tensor of shape ``(b, s, n, C)``.

    Raises:
        ValueError: if a non-empty grid has a zero target frame count, a
            negative target frame count is given without trainable freqs,
            or an offset's sign differs from its end's.
    """
    b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2

    trainable_freqs = None
    if isinstance(freqs, list):
        freqs, trainable_freqs = freqs[0], freqs[1]
    # The frame band absorbs the remainder when c is not a multiple of 3.
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # copy=True: when x is already float64, .to() would otherwise return x's
    # own storage and the writes below would clobber the caller's tensor.
    output = torch.view_as_complex(
        x.detach().reshape(b, s, n, c, 2).to(torch.float64, copy=True))

    if not isinstance(grid_sizes, list):
        grid_sizes = [grid_sizes]
    offset = 0  # first token of the current group's span along dim 1
    for group in grid_sizes:
        if not isinstance(group, list):
            group = [torch.zeros_like(group), group, group]
        seq_len = 0  # stays 0 if the group holds no samples
        for i in range(group[0].shape[0]):
            origin = _rope_int_triple(
                group[0][i] if start is None else start[i])
            end = _rope_int_triple(group[1][i])
            target = _rope_int_triple(group[2][i])
            seq_len = ((end[0] - origin[0]) * (end[1] - origin[1]) *
                       (end[2] - origin[2]))
            if seq_len <= 0:
                continue
            t_f = target[0]
            if t_f > 0:
                freqs_i = _rope_grid_freqs(freqs, origin, end, target)
            elif t_f < 0:
                if trainable_freqs is None:
                    raise ValueError(
                        "rope_precompute: a negative target frame count "
                        "requires freqs passed as [freqs, trainable_freqs]")
                freqs_i = trainable_freqs.unsqueeze(1)
            else:
                raise ValueError(
                    "rope_precompute: target frame count is 0 for a "
                    "non-empty grid; expected > 0 (RoPE grid) or < 0 "
                    "(trainable freqs)")
            output[i, offset:offset + seq_len] = freqs_i
        # All samples of a group share one span, sized by the last sample.
        offset += seq_len
    return output