import inspect
import math
import random

import torch
from flow_matching.utils import categorical


def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
    """Enumerate all weight vectors on the (num_obj - 1)-simplex whose
    coordinates are non-negative multiples of 1/num_div."""

    def rec(n, H):
        # All length-n tuples of non-negative integers that sum to H.
        if n == 1:
            return [[H]]
        points = []
        for i in range(H + 1):
            for tail in rec(n - 1, H - i):
                points.append([i] + tail)
        return points

    points = rec(num_obj, num_div)
    weight_vectors = torch.tensor(points, dtype=torch.float32) / num_div
    return weight_vectors
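
# Sanity check (a small illustration, not part of the pipeline): the lattice
# has C(num_div + num_obj - 1, num_obj - 1) points, e.g. num_obj=3, num_div=4
# gives C(6, 2) = 15 rows that each sum to 1:
#   >>> W = generate_simplex_lattice_points(3, 4)
#   >>> W.shape
#   torch.Size([15, 3])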


def select_random_weight_vector(num_obj: int, num_div: int):
    """Draw one weight vector uniformly at random from the simplex lattice."""
    weight_vectors = generate_simplex_lattice_points(num_obj, num_div)
    idx = torch.randint(0, weight_vectors.size(0), (1,)).item()
    random_weight_vector = weight_vectors[idx]
    return random_weight_vector, weight_vectors
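
# E.g. select_random_weight_vector(num_obj=2, num_div=10) returns one random
# vector such as tensor([0.3000, 0.7000]) plus the full (11, 2) lattice.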


def z_score_norm(tensor, eps=1e-8):
    """Standardize along the last dimension; eps guards against a zero std."""
    mean = tensor.mean(dim=-1, keepdim=True)
    std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps)
    return (tensor - mean) / std
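
# Example: z_score_norm(torch.tensor([[1., 2., 3.]])) gives
# tensor([[-1.2247, 0.0000, 1.2247]]) (mean 2, population std sqrt(2/3)).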


def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
    B, L, vocab_size = u_t.shape
    device = x_t.device
    guided_u_t = u_t.clone()

    # One randomly chosen editable position, shared across the batch; position
    # 0, the last two positions, and the fixed position 6 are skipped. Shape
    # (B,) so it also indexes per-batch downstream (e.g. pos_indices[valid_idx]
    # in euler_sample).
    pos_indices = torch.full(
        (B,),
        random.choice([i for i in range(1, L - 2) if i != 6]),
        dtype=torch.long,
        device=device,
    )
    batch_idx = torch.arange(B, device=device)
    current_tokens = x_t[batch_idx, pos_indices]

    # Candidate tokens at the chosen position: every vocabulary entry except
    # the current token and the reserved token id 23.
    full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)
    mask = (full_cand_tokens != current_tokens.unsqueeze(1)) & (full_cand_tokens != 23)
    cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, vocab_size - 2)

    # Materialize every single-token mutation of x_t at the chosen position.
    new_x = x_t.unsqueeze(1).expand(B, vocab_size, L).clone()
    new_x = new_x[mask].view(B, vocab_size - 2, L)
    new_x[batch_idx, :, pos_indices] = cand_tokens

    # Score all mutations with every surrogate model.
    new_x_flat = new_x.view(B * (vocab_size - 2), L)
    improvements_list = []
    with torch.no_grad():
        count = 0
        for s in s_models:
            # Some surrogates are time-conditioned; inspect the signature to
            # decide whether to pass t.
            sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
            if 't' in sig.parameters:
                candidate_scores = s(new_x_flat, t)
                base_score = s(x_t, t)
            else:
                candidate_scores = s(new_x_flat)
                base_score = s(x_t)

            # A model may return several objectives at once as a tuple.
            if isinstance(candidate_scores, tuple):
                for k, score in enumerate(candidate_scores):
                    improvement = score.view(B, vocab_size - 2) - base_score[k].unsqueeze(1)
                    improvement = improvement.float()
                    improvement *= importance[count]
                    improvements_list.append(improvement.unsqueeze(2))
                    count += 1
            else:
                improvement = candidate_scores.view(B, vocab_size - 2) - base_score.unsqueeze(1)
                improvement = improvement.float()
                improvement *= importance[count]
                improvements_list.append(improvement.unsqueeze(2))
                count += 1

    # (B, vocab_size - 2, num_objectives)
    improvement_values = torch.cat(improvements_list, dim=2)
    if args.is_peptide:
        # For peptides, suppress the first four candidate columns with a
        # large negative improvement.
        improvement_values[:, :4, :] = -10

    # Rank-normalize each objective across candidates, then average the
    # normalized ranks over objectives.
    ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1
    I_n = ranks / float(vocab_size - 2)
    avg_I = I_n.mean(dim=2)
    norm_avg_I = z_score_norm(avg_I)

    # Scalarized improvement along the preference vector w.
    D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    norm_D = z_score_norm(D)

    # Combined transition score.
    delta_S = norm_avg_I + args.lambda_ * norm_D

    # Tilt the off-diagonal rates; exp(.) is strictly positive, so only an
    # upper clamp is needed.
    factor = torch.exp(args.beta * delta_S)
    factor = torch.clamp(factor, max=100)

    guided_u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] = \
        u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] * factor

    # Restore the generator property at the edited row: the diagonal entry
    # (current token) is minus the sum of the off-diagonal rates.
    updated_vals = guided_u_t[batch_idx, pos_indices, :]
    sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens]
    guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag

    return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
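
# With the diagonal correction, each edited row of guided_u_t sums to zero, as
# a valid rate row should. A quick check (hypothetical driver code):
#   >>> g, pos, _, _, _ = guided_transition_scoring(x_t, u_t, w, s_models, t,
#   ...                                             importance, args)
#   >>> g[torch.arange(x_t.size(0)), pos].sum(dim=-1).abs().max()  # ~0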


def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    eps = 1e-8

    # Angle between each candidate's improvement vector and the preference
    # vector w.
    imp_norm = torch.norm(improvement_values.float(), dim=2)
    dot_product = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    w_norm = torch.norm(w) + eps
    cos_angle = dot_product / (imp_norm * w_norm + eps)
    cos_angle = cos_angle.clamp(-1.0, 1.0)
    angles = torch.acos(cos_angle)

    # Valid: in the half-space of w (angle < pi/2). Accepted: additionally
    # inside the hypercone of aperture Phi.
    valid_mask = angles < math.pi / 2
    accepted_mask = valid_mask & (angles <= Phi)

    # Best candidate per batch element: prefer accepted candidates, fall back
    # to merely valid ones, and use -1 as a "no candidate" sentinel.
    best_candidate = torch.empty(B, dtype=torch.long, device=device)
    for i in range(B):
        if valid_mask[i].any():
            if accepted_mask[i].any():
                candidate_idx = torch.argmax(delta_S[i].masked_fill(~accepted_mask[i], float('-inf')))
            else:
                candidate_idx = torch.argmax(delta_S[i].masked_fill(~valid_mask[i], float('-inf')))
            best_candidate[i] = cand_tokens[i, candidate_idx]
        else:
            best_candidate[i] = -1

    # Fraction of valid candidates rejected by the cone, averaged over batch
    # elements that have at least one valid candidate.
    rejection_rates = []
    for i in range(B):
        valid_candidates = valid_mask[i]
        total_valid = valid_candidates.sum().item()
        if total_valid > 0:
            num_rejected = (valid_candidates.sum() - accepted_mask[i].sum()).item()
            rejection_rates.append(num_rejected / total_valid)
    if len(rejection_rates) > 0:
        r_t = sum(rejection_rates) / len(rejection_rates)
    else:
        r_t = 0.0

    # Initialize the EMA of the rejection rate at the target rate.
    if ema_r_t is None:
        ema_r_t = args.tau

    # Update the EMA and adapt the aperture so the rejection rate is pulled
    # toward the target args.tau.
    if valid_mask.any():
        new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
        new_Phi = Phi * torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device))
        new_Phi = new_Phi.clamp(args.Phi_min, args.Phi_max).item()
    else:
        new_ema_r_t = ema_r_t
        new_Phi = Phi

    return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t
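
# The aperture update is Phi <- clip(Phi * exp(eta * (ema_r - tau)), Phi_min,
# Phi_max): a rejection rate above the target tau widens the cone, one below
# it narrows the cone. E.g. eta=0.1, tau=0.5, ema_r=0.7 grows Phi by
# exp(0.02), i.e. about 2%.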


def get_best_candidate(improvement_values, cand_tokens, delta_S):
    """Unfiltered variant: take the candidate with the highest delta_S."""
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    best_candidate = torch.empty(B, dtype=torch.long, device=device)

    for i in range(B):
        candidate_idx = torch.argmax(delta_S[i])
        best_candidate[i] = cand_tokens[i, candidate_idx]

    return best_candidate
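
# An equivalent vectorized form of the loop above (same semantics):
#   best_candidate = cand_tokens.gather(1, delta_S.argmax(dim=1, keepdim=True)).squeeze(1)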


def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
    B, L, V = guided_u_t.shape
    device = x_t.device
    u = torch.zeros_like(guided_u_t)

    # Keep only the rate of the selected transition; batch elements flagged
    # with the -1 sentinel (no valid candidate) get zero rate and never jump.
    valid_mask = best_candidate != -1
    if valid_mask.any():
        valid_idx = torch.nonzero(valid_mask).squeeze(-1)
        u[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]] = \
            guided_u_t[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]]

    # Total jump intensity at the edited position.
    intensity = torch.zeros(B, device=device)
    if valid_mask.any():
        intensity[valid_idx] = u[valid_idx, pos_indices[valid_idx]].sum(dim=-1)

    # Probability that at least one jump fires within a step of size h.
    p_jump = 1 - torch.exp(-h * intensity)

    rand_val = torch.rand(B, device=device)
    jump_decision = (rand_val < p_jump) & valid_mask

    # Apply the accepted jumps in place.
    x_t[jump_decision, pos_indices[jump_decision]] = best_candidate[jump_decision]

    return x_t
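
# A minimal end-to-end sketch of one guided sampling loop, assuming a trained
# rate model `model(x_t, t) -> u_t` of shape (B, L, V), surrogate objectives
# `s_models`, an importance list, and an `args` namespace carrying beta,
# lambda_, tau, alpha_r, eta, Phi_min, Phi_max, is_peptide -- these names are
# illustrative, not a fixed API:
#
#   w, _ = select_random_weight_vector(num_obj=len(s_models), num_div=10)
#   Phi, ema_r_t = args.Phi_max, None
#   for step in range(num_steps):
#       t = torch.full((x_t.size(0),), step * h)
#       u_t = model(x_t, t)
#       guided_u_t, pos, cand, imp, dS = guided_transition_scoring(
#           x_t, u_t, w, s_models, t, importance, args)
#       best, _, _, Phi, ema_r_t = adaptive_hypercone_filtering(
#           imp, cand, dS, w, Phi, args, ema_r_t)
#       x_t = euler_sample(x_t, pos, best, guided_u_t, h)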