# Source: moPPIt-v3, flow_matching/utils/multi_guidance_cnp.py
# Uploaded by AlienChen in commit 3527383 ("Upload 72 files", verified).
import torch
from flow_matching.utils import categorical
import math
import inspect
import random
def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
    """Enumerate the simplex-lattice weight vectors for ``num_obj`` objectives.

    Every returned row has ``num_obj`` entries, each a non-negative multiple of
    ``1 / num_div``, and the entries of a row sum to 1. Rows are produced in
    ascending lexicographic order of their integer numerators, and there are
    ``C(num_div + num_obj - 1, num_obj - 1)`` of them.

    Args:
        num_obj: number of objectives (vector length); must be >= 1.
        num_div: number of divisions of the unit interval; must be >= 1.

    Returns:
        float32 tensor of shape ``(num_points, num_obj)``.

    Raises:
        ValueError: if ``num_obj < 1`` (previously infinite recursion) or
            ``num_div < 1`` (previously a division by zero yielding NaNs).
    """
    if num_obj < 1:
        raise ValueError(f"num_obj must be >= 1, got {num_obj}")
    if num_div < 1:
        raise ValueError(f"num_div must be >= 1, got {num_div}")

    def rec(n, H):
        # All length-n lists of non-negative ints summing to H, first entry ascending.
        if n == 1:
            return [[H]]
        points = []
        for i in range(H + 1):
            for tail in rec(n - 1, H - i):
                points.append([i] + tail)
        return points

    points = rec(num_obj, num_div)
    return torch.tensor(points, dtype=torch.float32) / num_div
def select_random_weight_vector(num_obj: int, num_div: int):
    """Draw one simplex-lattice weight vector uniformly at random.

    The index is sampled with ``torch.randint``, i.e. from torch's global RNG.

    Returns:
        A pair ``(chosen_vector, all_vectors)`` where ``all_vectors`` is the
        full output of ``generate_simplex_lattice_points(num_obj, num_div)``
        and ``chosen_vector`` is one of its rows.
    """
    all_vectors = generate_simplex_lattice_points(num_obj, num_div)
    num_points = all_vectors.size(0)
    chosen_row = torch.randint(0, num_points, (1,)).item()
    return all_vectors[chosen_row], all_vectors
def z_score_norm(tensor, eps=1e-8):
    """Standardize ``tensor`` along its last dimension.

    Subtracts the mean and divides by the population (biased,
    ``unbiased=False``) standard deviation. The standard deviation is floored
    at ``eps``, so a constant slice maps to zeros instead of NaN.
    """
    centre = tensor.mean(dim=-1, keepdim=True)
    spread = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps)
    centred = tensor - centre
    return centred / spread
def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args,
                              excluded_token=23, excluded_position=6,
                              num_special_tokens=4):
    """Reweight the velocity at one randomly chosen position using objective scores.

    One position is drawn at random with Python's ``random.choice`` and shared
    by every sequence in the batch. At that position, every single-token
    substitution except the current token and ``excluded_token`` is scored by
    each model in ``s_models``. The per-objective improvement over the
    unmodified sequence is computed for each candidate. These improvements are
    combined into a rank-based score and a ``w``-weighted directional score
    (both z-normalized), giving ``delta_S``. The off-diagonal velocities toward
    the candidates are multiplied by ``exp(args.beta * delta_S)``, capped at
    100. Finally, the self-transition entry is reset to minus the sum of all
    other entries at that position.

    Args:
        x_t: (B, L) integer token tensor (not modified).
        u_t: (B, L, V) velocity tensor (not modified; a clone is returned).
        w: objective weights; broadcast against the N objectives.
        s_models: score callables. Each is called as ``s(x, t)`` if its
            (``forward``) signature has a ``t`` parameter, else ``s(x)``. It may
            return one tensor or a tuple of tensors (one objective per element).
            Presumably one score per sequence — TODO confirm with callers.
        t: time value forwarded to models that accept it.
        importance: per-objective multipliers, indexed in the order the
            objectives are produced by ``s_models``.
        args: namespace read for ``is_peptide``, ``lambda_`` and ``beta``.
        excluded_token: token id never proposed as a candidate (hard-coded 23
            originally; its meaning is not visible here). Positions where any
            sequence currently holds it are never selected, because the
            candidate set there would have a different size.
        excluded_position: sequence position never selected (originally 6).
        num_special_tokens: when ``args.is_peptide``, candidates whose token id
            is below this value get improvement -10 for every objective
            (assumed to be the non-residue ids 0..3 — TODO confirm vocab).

    Returns:
        ``(guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S)``
        with shapes (B, L, V), (B,), (B, C), (B, C, N), (B, C), where C is the
        number of candidates (V - 2 when ``excluded_token`` is inside the
        vocabulary).

    Raises:
        ValueError: if no position is eligible to be mutated.
    """
    B, L, vocab_size = u_t.shape
    device = x_t.device
    guided_u_t = u_t.clone()
    # Each row drops its current token, plus excluded_token when it is a valid id.
    num_cand = vocab_size - 1 - (1 if 0 <= excluded_token < vocab_size else 0)

    # 1. Pick one position for the whole batch. Positions 0, L-2, L-1 and
    # `excluded_position` are never chosen. Positions holding `excluded_token`
    # in any sequence are skipped; originally they crashed the `.view` below.
    holds_excluded = (x_t == excluded_token).any(dim=0).tolist()
    eligible = [i for i in range(1, L - 2)
                if i != excluded_position and not holds_excluded[i]]
    if not eligible:
        raise ValueError("guided_transition_scoring: no eligible position to mutate")
    pos = random.choice(eligible)
    batch_idx = torch.arange(B, device=device)
    # Expanded to (B,) so downstream per-sequence indexing works for B > 1.
    pos_indices = torch.full((B,), pos, dtype=torch.long, device=device)
    current_tokens = x_t[batch_idx, pos_indices]  # (B,)

    # 2. Candidate tokens: every id except the current one and excluded_token.
    full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)
    mask = (full_cand_tokens != current_tokens.unsqueeze(1)) & (full_cand_tokens != excluded_token)
    cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, num_cand)

    # 3. One candidate sequence per substitution (boolean indexing copies).
    new_x = x_t.unsqueeze(1).expand(B, vocab_size, L)[mask].view(B, num_cand, L)
    new_x[:, :, pos] = cand_tokens
    new_x_flat = new_x.view(B * num_cand, L)

    # 4. Per-objective improvement of each candidate over the current sequence.
    improvements_list = []
    count = 0
    with torch.no_grad():
        for s in s_models:
            target = s.forward if hasattr(s, 'forward') else s
            if 't' in inspect.signature(target).parameters:
                candidate_scores = s(new_x_flat, t)
                base_score = s(x_t, t)
            else:
                candidate_scores = s(new_x_flat)
                base_score = s(x_t)
            if isinstance(candidate_scores, tuple):
                pairs = [(score, base_score[k]) for k, score in enumerate(candidate_scores)]
            else:
                pairs = [(candidate_scores, base_score)]
            for cand_score, base in pairs:
                improvement = (cand_score.view(B, num_cand) - base.unsqueeze(1)).float()
                improvement *= importance[count]
                improvements_list.append(improvement.unsqueeze(2))
                count += 1
    improvement_values = torch.cat(improvements_list, dim=2)  # (B, C, N)
    if args.is_peptide:
        # Penalize substitutions to special (non-residue) token ids, selected
        # by token id rather than candidate slot so the current token's removal
        # cannot shift the mask.
        improvement_values[cand_tokens < num_special_tokens] = -10

    # 5. Rank-based score: normalized 1-based rank per objective, averaged.
    ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1
    avg_I = (ranks / float(num_cand)).mean(dim=2)
    norm_avg_I = z_score_norm(avg_I)  # (B, C)

    # 6. Directional score: projection of the improvements onto w.
    D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    norm_D = z_score_norm(D)  # (B, C)

    # 7. Combined score.
    delta_S = norm_avg_I + args.lambda_ * norm_D  # (B, C)

    # 8. Scale the off-diagonal velocities. exp() is positive, so only the
    # upper bound of the clamp can take effect; it caps the boost at 100.
    factor = torch.exp(args.beta * delta_S).clamp(max=100)
    rows = batch_idx.unsqueeze(1)
    cols = pos_indices.unsqueeze(1)
    guided_u_t[rows, cols, cand_tokens] = u_t[rows, cols, cand_tokens] * factor

    # 9. Self-transition = -(sum of every other entry at the position). The
    # excluded token's entry is left as in u_t but still enters this sum.
    updated_vals = guided_u_t[batch_idx, pos_indices, :]  # (B, V)
    sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens]
    guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag
    return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
    """Pick one candidate token per sequence using an adaptive cone around ``w``.

    A candidate is *valid* when the angle between its improvement vector and
    ``w`` is strictly below pi/2, and *accepted* when that angle is also at
    most ``Phi``. Each sequence takes the candidate with the largest
    ``delta_S`` among its accepted candidates. If none is accepted, it takes
    the best valid one. If none is valid, it gets ``-1``.

    The rejection rate r_t is the fraction of valid candidates that were not
    accepted, averaged over sequences that have at least one valid candidate
    (0.0 if none do). It updates an exponential moving average
    ``alpha_r * ema + (1 - alpha_r) * r_t``, where ``ema`` defaults to
    ``args.tau``. ``Phi`` is then rescaled by ``exp(eta * (ema - tau))`` and
    clamped to ``[Phi_min, Phi_max]``. When no candidate in the whole batch is
    valid, ``Phi`` and the EMA are returned unchanged.

    Args:
        improvement_values: (B, C, N) per-objective improvements.
        cand_tokens: (B, C) candidate token ids.
        delta_S: (B, C) combined candidate scores.
        w: objective weights, broadcast against the N objectives.
        Phi: current cone half-angle in radians.
        args: namespace read for tau, alpha_r, eta, Phi_min, Phi_max.
        ema_r_t: previous EMA of the rejection rate, or None.

    Returns:
        ``(best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t)``;
        ``best_candidate`` is a (B,) long tensor using -1 for "no move".
    """
    B, _, _ = improvement_values.shape
    device = improvement_values.device
    eps = 1e-8

    # Angle between each candidate's improvement vector and w.
    lengths = torch.norm(improvement_values.float(), dim=2)
    projections = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    denominator = lengths * (torch.norm(w) + eps) + eps
    angles = torch.acos((projections / denominator).clamp(-1.0, 1.0))
    valid_mask = angles < math.pi / 2
    accepted_mask = valid_mask & (angles <= Phi)

    # Per-sequence choice: accepted pool first, then valid pool, else -1.
    best_candidate = torch.full((B,), -1, dtype=torch.long, device=device)
    for row, (row_valid, row_accepted) in enumerate(zip(valid_mask, accepted_mask)):
        if not row_valid.any():
            continue
        pool = row_accepted if row_accepted.any() else row_valid
        pick = torch.argmax(delta_S[row].masked_fill(~pool, float('-inf')))
        best_candidate[row] = cand_tokens[row, pick]

    # Rejection rate among valid candidates, over sequences that have any.
    valid_counts = valid_mask.sum(dim=1).tolist()
    accepted_counts = accepted_mask.sum(dim=1).tolist()
    rates = [(n_val - n_acc) / n_val
             for n_val, n_acc in zip(valid_counts, accepted_counts) if n_val > 0]
    r_t = sum(rates) / len(rates) if rates else 0.0

    if ema_r_t is None:
        ema_r_t = args.tau
    if not valid_mask.any():
        # Nothing valid anywhere: keep the cone and the EMA as they are.
        return best_candidate, accepted_mask, valid_mask, Phi, ema_r_t

    new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
    growth = torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device))
    new_Phi = (Phi * growth).clamp(args.Phi_min, args.Phi_max).item()
    return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t
def get_best_candidate(improvement_values, cand_tokens, delta_S):
    """Greedy choice: for each sequence, the candidate token with the highest ``delta_S``.

    ``improvement_values`` is only used for its batch size (it must be 3-D)
    and device. Ties resolve to the first maximal candidate, as with
    ``torch.argmax``.

    Args:
        improvement_values: (B, C, N) tensor.
        cand_tokens: (B, C) candidate token ids.
        delta_S: (B, C) candidate scores.

    Returns:
        (B,) long tensor of chosen token ids.
    """
    batch_size, _, _ = improvement_values.shape
    top_index = delta_S[:batch_size].argmax(dim=1, keepdim=True)
    chosen = cand_tokens[:batch_size].gather(1, top_index).squeeze(1)
    return chosen.to(device=improvement_values.device, dtype=torch.long)
def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
    """Take one stochastic jump step at the guided position of each sequence.

    For each sequence whose ``best_candidate`` is not -1, the jump intensity is
    the guided velocity toward that token at ``pos_indices``. The sequence jumps
    with probability ``1 - exp(-intensity)``; if it does, the token at that
    position is replaced by ``best_candidate``. Sequences with
    ``best_candidate == -1`` never jump. Exactly one ``torch.rand(B)`` draw is
    consumed per call.

    Args:
        x_t: (B, L) token tensor. Modified IN PLACE and also returned.
        pos_indices: positions, shape (B,). A single-element tensor (one
            position shared by the batch) is broadcast to all B sequences.
        best_candidate: (B,) target token per sequence; -1 means "no move".
        guided_u_t: (B, L, V) guided velocity tensor.
        h: step size; intentionally unused (see the note below).

    Returns:
        ``x_t`` after the jump.
    """
    B, _, _ = guided_u_t.shape
    device = x_t.device
    # Accept a shared position (shape (1,) or scalar) as well as per-sequence
    # positions; previously a shared position raised IndexError for B > 1.
    pos_indices = pos_indices.reshape(-1)
    if pos_indices.numel() == 1:
        pos_indices = pos_indices.expand(B)

    valid_mask = best_candidate != -1
    # Intensity is the guided rate toward the chosen token; zero (no jump)
    # for sequences without a valid candidate.
    intensity = torch.zeros(B, device=device)
    if valid_mask.any():
        valid_idx = torch.nonzero(valid_mask).squeeze(-1)
        intensity[valid_idx] = guided_u_t[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]]

    # The Euler formula is p_jump = 1 - exp(-h * intensity). Because h = 1 / T
    # is small, that makes jumps rare and sampling slow. Dropping h is
    # equivalent to using T * args.beta upstream, so we use 1 - exp(-intensity).
    p_jump = 1 - torch.exp(-1 * intensity)
    rand_val = torch.rand(B, device=device)
    jump_decision = (rand_val < p_jump) & valid_mask
    x_t[jump_decision, pos_indices[jump_decision]] = best_candidate[jump_decision]
    return x_t