LucidDreamer / guidance /perpneg_utils.py
haodongli's picture
init
916b126
raw
history blame
2.06 kB
import torch
# Please refer to https://perp-neg.github.io/ for details about the paper and algorithm.
def get_perpendicular_component(x, y):
    """Return the part of ``x`` that is perpendicular to ``y``.

    Both tensors are treated as single flattened vectors: the inner product
    and the norm are taken over *all* elements. The projection of ``x`` onto
    ``y`` is subtracted from ``x``.

    Args:
        x: Tensor to be projected.
        y: Reference tensor; must have the same shape as ``x``.

    Returns:
        ``x - (<x, y> / max(||y||^2, 1e-6)) * y``, same shape as ``x``.

    Raises:
        AssertionError: if ``x`` and ``y`` differ in shape.
    """
    assert x.shape == y.shape
    inner_product = (x * y).sum()
    # Floor the squared norm so a (near-)zero ``y`` cannot cause a division
    # by zero; in that case the projection coefficient is simply tiny/zero.
    squared_norm = max(torch.norm(y) ** 2, 1e-6)
    projection_coeff = inner_product / squared_norm
    return x - projection_coeff * y
def batch_get_perpendicular_component(x, y):
    """Per-sample perpendicular component of ``x`` with respect to ``y``.

    Dimension 0 is the batch dimension. For every index ``i`` the result is
    the same quantity ``get_perpendicular_component(x[i], y[i])`` computes
    (up to floating-point rounding), but all samples are processed in one
    vectorized pass instead of a Python loop. Unlike a loop followed by
    ``torch.stack``, an empty batch is handled and yields an empty tensor.

    Args:
        x: Tensor of shape ``[B, ...]``.
        y: Tensor with the same shape as ``x``.

    Returns:
        Tensor of the same shape as ``x`` where sample ``i`` equals
        ``x[i] - (<x[i], y[i]> / max(||y[i]||^2, 1e-6)) * y[i]``.

    Raises:
        AssertionError: if ``x`` and ``y`` differ in shape.
    """
    assert x.shape == y.shape
    dot = x * y
    norm_sq = y * y
    if x.dim() > 1:
        # Reduce over every non-batch dimension; keepdim=True leaves shape
        # [B, 1, ..., 1] so the coefficient broadcasts back over each sample.
        sample_dims = tuple(range(1, x.dim()))
        dot = dot.sum(dim=sample_dims, keepdim=True)
        norm_sq = norm_sq.sum(dim=sample_dims, keepdim=True)
    # Floor the squared norm so a (near-)zero reference sample does not
    # divide by zero.
    coeff = dot / torch.clamp(norm_sq, min=1e-6)
    return x - coeff * y
def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size):
    """Combine a main noise prediction with weighted perpendicular components.

    ``delta_noise_preds`` and ``weights`` are split along dim 0 into chunks of
    ``batch_size``. The first chunk is the main (positive) prediction and its
    weights must all be 1.0. For every other chunk, the component
    perpendicular to the main prediction is computed per sample, scaled by
    that sample's weight, and accumulated. Samples whose weight magnitude is
    at most 1e-4 are skipped.

    Notes:
        - weights: an array with the weights for combining the noise predictions
        - delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir

    Args:
        delta_noise_preds: Stacked predictions, ``K`` chunks of ``B`` samples.
        weights: 1-D tensor of per-sample weights, laid out like dim 0 of
            ``delta_noise_preds``.
        batch_size: Chunk size ``B`` used for splitting.

    Returns:
        Tensor shaped like one chunk: the main prediction plus the
        accumulated weighted perpendicular components.

    Raises:
        AssertionError: if any weight in the first chunk is not 1.0.
    """
    delta_noise_preds = delta_noise_preds.split(batch_size, dim=0)  # K x [B, 4, 64, 64]
    weights = weights.split(batch_size, dim=0)  # K x [B]
    assert torch.all(weights[0] == 1.0)

    main_positive = delta_noise_preds[0]  # [B, 4, 64, 64]
    accumulated_output = torch.zeros_like(main_positive)
    # Shape that lets a flat weight vector broadcast over each sample's
    # trailing dimensions; this is (-1, 1, 1, 1) for 4-D predictions.
    weight_shape = (-1,) + (1,) * (main_positive.dim() - 1)

    for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1):
        chunk_weights = weights[i]
        idx_non_zero = torch.abs(chunk_weights) > 1e-4
        if not idx_non_zero.any():
            # No sample in this chunk carries a meaningful weight.
            continue
        perpendicular = batch_get_perpendicular_component(
            complementary_noise_pred[idx_non_zero],
            main_positive[idx_non_zero],
        )
        accumulated_output[idx_non_zero] += (
            chunk_weights[idx_non_zero].reshape(weight_shape) * perpendicular
        )

    return accumulated_output + main_positive