AlienChen commited on
Commit
f930dca
·
verified ·
1 Parent(s): 65a78ff

Update flow_matching/utils/multi_guidance.py

Browse files
Files changed (1) hide show
  1. flow_matching/utils/multi_guidance.py +118 -11
flow_matching/utils/multi_guidance.py CHANGED
@@ -1,7 +1,8 @@
1
  import torch
2
- from flow_matching.utils import categorical
3
  import math
4
  import inspect
 
5
 
6
  def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
7
  def rec(n, H):
@@ -28,13 +29,17 @@ def z_score_norm(tensor, eps=1e-8):
28
  std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps)
29
  return (tensor - mean) / std
30
 
31
- def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
32
  B, L, vocab_size = u_t.shape
33
  device = x_t.device
34
  guided_u_t = u_t.clone()
35
 
36
  # 1. Randomly select one position per sequence.
37
- pos_indices = torch.randint(low=1, high=L-2, size=(B,), device=device) # shape: (B,) # CHANGE!
 
 
 
 
38
  batch_idx = torch.arange(B, device=device)
39
  current_tokens = x_t[batch_idx, pos_indices] # shape: (B,)
40
 
@@ -53,32 +58,42 @@ def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
53
  improvements_list = []
54
  with torch.no_grad():
55
  count = 0
 
 
 
 
 
56
  for i, s in enumerate(s_models):
57
  sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
58
  if 't' in sig.parameters:
59
- candidate_scores = s(new_x_flat, t)
60
- base_score = s(x_t, t)
61
  else:
62
- candidate_scores = s(new_x_flat)
63
- base_score = s(x_t)
64
 
65
  if isinstance(candidate_scores, tuple):
66
  for k, score in enumerate(candidate_scores):
67
  improvement = candidate_scores[k].view(B, vocab_size - 1) - base_score[k].unsqueeze(1)
68
- improvement = improvement.float()
69
  improvement *= importance[count]
70
  improvements_list.append(improvement.unsqueeze(2))
71
  count += 1
72
  else:
73
  improvement = candidate_scores.view(B, vocab_size - 1) - base_score.unsqueeze(1)
74
- improvement = improvement.float()
75
  improvement *= importance[count]
76
  improvements_list.append(improvement.unsqueeze(2)) # (B, vocab_size-1, 1)
77
  count += 1
78
 
79
  improvement_values = torch.cat(improvements_list, dim=2) # (B, vocab_size-1, N)
80
- if args.is_peptide:
81
- improvement_values[:, :4, :] = -10 # Mask non-residue positions
 
 
 
 
 
82
 
83
  # 5. Compute ranking scores I_n
84
  ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1 # (B, vocab_size-1, N)
@@ -107,6 +122,98 @@ def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
107
 
108
  return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
111
  B, num_candidates, N = improvement_values.shape
112
  device = improvement_values.device
 
1
  import torch
2
+ import random
3
  import math
4
  import inspect
5
+ import pdb
6
 
7
  def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
8
  def rec(n, H):
 
29
  std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps)
30
  return (tensor - mean) / std
31
 
32
+ def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, tokenizer, args, fixed_positions=None, invalid_tokens=None):
33
  B, L, vocab_size = u_t.shape
34
  device = x_t.device
35
  guided_u_t = u_t.clone()
36
 
37
  # 1. Randomly select one position per sequence.
38
+ all_positions = set(range(1, L-1))
39
+ available_positions = list(all_positions - set(fixed_positions))
40
+ assert len(available_positions) > 0
41
+ pos_indices = torch.tensor(random.choices(available_positions, k=B), device=device)
42
+ # pos_indices = torch.randint(low=1, high=L-2, size=(B,), device=device) # shape: (B,) # CHANGE!
43
  batch_idx = torch.arange(B, device=device)
44
  current_tokens = x_t[batch_idx, pos_indices] # shape: (B,)
45
 
 
58
  improvements_list = []
59
  with torch.no_grad():
60
  count = 0
61
+ input_seqs_cand = tokenizer.batch_decode(new_x_flat)
62
+ input_seqs_orig = tokenizer.batch_decode(x_t)
63
+ input_seqs_cand = [seq.replace(' ', '')[5:-5] for seq in input_seqs_cand]
64
+ input_seqs_orig = [seq.replace(' ', '')[5:-5] for seq in input_seqs_orig]
65
+
66
  for i, s in enumerate(s_models):
67
  sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
68
  if 't' in sig.parameters:
69
+ candidate_scores = s(input_seqs_cand, t)
70
+ base_score = s(input_seqs_orig, t)
71
  else:
72
+ candidate_scores = s(input_seqs_cand)
73
+ base_score = s(input_seqs_orig)
74
 
75
  if isinstance(candidate_scores, tuple):
76
  for k, score in enumerate(candidate_scores):
77
  improvement = candidate_scores[k].view(B, vocab_size - 1) - base_score[k].unsqueeze(1)
78
+ improvement = improvement.float().to(device)
79
  improvement *= importance[count]
80
  improvements_list.append(improvement.unsqueeze(2))
81
  count += 1
82
  else:
83
  improvement = candidate_scores.view(B, vocab_size - 1) - base_score.unsqueeze(1)
84
+ improvement = improvement.float().to(device)
85
  improvement *= importance[count]
86
  improvements_list.append(improvement.unsqueeze(2)) # (B, vocab_size-1, 1)
87
  count += 1
88
 
89
  improvement_values = torch.cat(improvements_list, dim=2) # (B, vocab_size-1, N)
90
+
91
+ invalid_mask = cand_tokens.unsqueeze(-1) == invalid_tokens.view(1, 1, -1)
92
+ final_invalid_mask = invalid_mask.any(dim=-1)
93
+ improvement_values[final_invalid_mask] = -10.0
94
+
95
+ # if args.is_peptide:
96
+ # improvement_values[:, :4, :] = -10 # Mask non-residue positions
97
 
98
  # 5. Compute ranking scores I_n
99
  ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1 # (B, vocab_size-1, N)
 
122
 
123
  return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
124
 
125
+
126
+ def guided_transition_scoring_uaa(x_t, u_t, w, s_models, t, importance, tokenizer, args, fixed_positions=None, invalid_tokens=None):
127
+ B, L, vocab_size = u_t.shape
128
+ device = x_t.device
129
+ guided_u_t = u_t.clone()
130
+
131
+ # 1. Randomly select one position per sequence.
132
+ all_positions = set(range(1, L-1))
133
+ available_positions = list(all_positions - set(fixed_positions))
134
+ assert len(available_positions) > 0
135
+ pos_indices = torch.tensor(random.choices(available_positions, k=B), device=device)
136
+ # pos_indices = torch.randint(low=1, high=L-2, size=(B,), device=device) # shape: (B,) # CHANGE!
137
+ batch_idx = torch.arange(B, device=device)
138
+ current_tokens = x_t[batch_idx, pos_indices] # shape: (B,)
139
+
140
+ # 2. Build candidate tokens for each sequence and remove self-transition.
141
+ full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size) # (B, vocab_size)
142
+ mask = (full_cand_tokens != current_tokens.unsqueeze(1)) # (B, vocab_size)
143
+ # Now, cand_tokens contains only candidate tokens that differ from the current token.
144
+ cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, vocab_size - 1) # (B, vocab_size-1)
145
+
146
+ # 3. Create candidate sequences by replacing the token at the selected position.
147
+ new_x = x_t.unsqueeze(1).expand(B, vocab_size, L).clone()
148
+ new_x = new_x[mask].view(B, vocab_size - 1, L) # (B, vocab_size-1, L)
149
+ new_x[batch_idx, :, pos_indices] = cand_tokens
150
+ new_x_flat = new_x.view(B * (vocab_size - 1), L)
151
+ improvements_list = []
152
+ with torch.no_grad():
153
+ count = 0
154
+ input_seqs_cand_smiles, valid_mask_cand = tokenizer.batch_decode(new_x_flat, convert_to_smiles=True, cyclic=args.cyclic)
155
+ input_seqs_cand_aa = tokenizer.batch_decode(new_x_flat, convert_to_smiles=False)
156
+
157
+ input_seqs_orig_smiles, valid_mask_orig = tokenizer.batch_decode(x_t, convert_to_smiles=True, cyclic=args.cyclic)
158
+ input_seqs_orig_aa = tokenizer.batch_decode(x_t, convert_to_smiles=False)
159
+
160
+ for i, s in enumerate(s_models):
161
+ if i == 0:
162
+ candidate_scores = s(input_seqs_cand_aa) * valid_mask_cand
163
+ base_score = s(input_seqs_orig_aa) * valid_mask_orig
164
+ else:
165
+ candidate_scores = s(input_seqs_cand_smiles) * valid_mask_cand
166
+ base_score = s(input_seqs_orig_smiles) * valid_mask_orig
167
+
168
+ if isinstance(candidate_scores, tuple):
169
+ for k, score in enumerate(candidate_scores):
170
+ improvement = candidate_scores[k].view(B, vocab_size - 1) - base_score[k].unsqueeze(1)
171
+ improvement = improvement.float().to(device)
172
+ improvement *= importance[count]
173
+ improvements_list.append(improvement.unsqueeze(2))
174
+ count += 1
175
+ else:
176
+ improvement = candidate_scores.view(B, vocab_size - 1) - base_score.unsqueeze(1)
177
+ improvement = improvement.float().to(device)
178
+ improvement *= importance[count]
179
+ improvements_list.append(improvement.unsqueeze(2)) # (B, vocab_size-1, 1)
180
+ count += 1
181
+
182
+ improvement_values = torch.cat(improvements_list, dim=2) # (B, vocab_size-1, N)
183
+
184
+ invalid_mask = cand_tokens.unsqueeze(-1) == invalid_tokens.view(1, 1, -1)
185
+ final_invalid_mask = invalid_mask.any(dim=-1)
186
+ improvement_values[final_invalid_mask] = -10.0
187
+
188
+
189
+ # 5. Compute ranking scores I_n
190
+ ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1 # (B, vocab_size-1, N)
191
+ I_n = ranks / float(vocab_size - 1)
192
+ avg_I = I_n.mean(dim=2)
193
+ norm_avg_I = z_score_norm(avg_I) # (B, vocab_size-1)
194
+
195
+ # 6. Compute directional score D
196
+ D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
197
+ norm_D = z_score_norm(D) # (B, vocab_size-1)
198
+
199
+ # 7. Combine the scores
200
+ delta_S = norm_avg_I + args.lambda_ * norm_D # (B, vocab_size-1)
201
+
202
+ # 9. Update the guided velocities at the selected positions.
203
+ factor = torch.exp(args.beta * delta_S) # (B, vocab_size-1)
204
+ factor = torch.clamp(factor, min=-100, max=100)
205
+
206
+ guided_u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] = u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] * factor
207
+
208
+ # 10. For the self-transition (current token) at the selected position,
209
+ # set its guided velocity to be the negative sum of the updated off-diagonals.
210
+ updated_vals = guided_u_t[batch_idx, pos_indices, :] # (B, vocab_size)
211
+ sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens]
212
+ guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag
213
+
214
+ return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
215
+
216
+
217
  def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
218
  B, num_candidates, N = improvement_values.shape
219
  device = improvement_values.device