Spaces:
Runtime error
Runtime error
import torch | |
# Please refer to the https://perp-neg.github.io/ for details about the paper and algorithm | |
def get_perpendicular_component(x, y): | |
assert x.shape == y.shape | |
return x - ((torch.mul(x, y).sum())/max(torch.norm(y)**2, 1e-6)) * y | |
def batch_get_perpendicular_component(x, y): | |
assert x.shape == y.shape | |
result = [] | |
for i in range(x.shape[0]): | |
result.append(get_perpendicular_component(x[i], y[i])) | |
return torch.stack(result) | |
def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size): | |
""" | |
Notes: | |
- weights: an array with the weights for combining the noise predictions | |
- delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir | |
""" | |
delta_noise_preds = delta_noise_preds.split(batch_size, dim=0) # K x [B, 4, 64, 64] | |
weights = weights.split(batch_size, dim=0) # K x [B] | |
# print(f"{weights[0].shape = } {weights = }") | |
assert torch.all(weights[0] == 1.0) | |
main_positive = delta_noise_preds[0] # [B, 4, 64, 64] | |
accumulated_output = torch.zeros_like(main_positive) | |
for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1): | |
# print(f"\n{i = }, {weights[i] = }, {weights[i].shape = }\n") | |
idx_non_zero = torch.abs(weights[i]) > 1e-4 | |
# print(f"{idx_non_zero.shape = }, {idx_non_zero = }") | |
# print(f"{weights[i][idx_non_zero].shape = }, {weights[i][idx_non_zero] = }") | |
# print(f"{complementary_noise_pred.shape = }, {complementary_noise_pred[idx_non_zero].shape = }") | |
# print(f"{main_positive.shape = }, {main_positive[idx_non_zero].shape = }") | |
if sum(idx_non_zero) == 0: | |
continue | |
accumulated_output[idx_non_zero] += weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * batch_get_perpendicular_component(complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero]) | |
#assert accumulated_output.shape == main_positive.shape,# f"{accumulated_output.shape = }, {main_positive.shape = }" | |
return accumulated_output + main_positive |