AlienChen committed on
Commit
d17476e
·
verified ·
1 Parent(s): e5c0506

Update flow_matching/solver/discrete_solver.py

Browse files
flow_matching/solver/discrete_solver.py CHANGED
@@ -10,7 +10,6 @@ from typing import Callable, Optional, Union
10
 
11
  import torch
12
  from torch import Tensor
13
- import gc
14
  from torch.nn import functional as F
15
 
16
  from flow_matching.path import MixtureDiscreteProbPath
@@ -21,7 +20,7 @@ from .utils import get_nearest_times
21
  from ..utils.multi_guidance import *
22
 
23
  try:
24
- from tqdm import tqdm
25
 
26
  TQDM_AVAILABLE = True
27
  except ImportError:
@@ -275,18 +274,12 @@ class MixtureDiscreteEulerSolver(Solver):
275
  score_models: list = None,
276
  num_objectives: int = 1,
277
  weights: list = None,
 
 
 
278
  **model_extras,
279
  ) -> Tensor:
280
 
281
- # score_list_0 = []
282
- # score_list_1 = []
283
- # score_list_2 = []
284
- # score_list_3 = []
285
- # score_list_4 = []
286
- # score_list_5 = []
287
-
288
- import pdb
289
-
290
  if not div_free == 0.0:
291
  raise NotImplementedError
292
 
@@ -331,7 +324,7 @@ class MixtureDiscreteEulerSolver(Solver):
331
  raise ImportError(
332
  "tqdm is required for verbose mode. Please install it."
333
  )
334
- ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
335
  else:
336
  ctx = nullcontext()
337
 
@@ -342,7 +335,7 @@ class MixtureDiscreteEulerSolver(Solver):
342
  w, _ = select_random_weight_vector(num_objectives, args.num_div)
343
  # w = torch.tensor([0.2, 0.7, 0.05, 0.05]).to(x_t.device)
344
  w = w.to(device=x_init.device)
345
- print(f"Weight Vector: {w}")
346
  Phi = args.Phi_init
347
  ema_r_t = None
348
 
@@ -362,14 +355,10 @@ class MixtureDiscreteEulerSolver(Solver):
362
  d_k_t = scheduler_output.d_alpha_t
363
  u_t = d_k_t / (1 - k_t) * p_1t
364
 
365
- guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S = guided_transition_scoring(x_t, u_t, w, score_models, t, w, args)
366
 
367
  best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t)
368
 
369
- # best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t)
370
-
371
- # best_candidate = get_best_candidate(improvement_values, cand_tokens, delta_S)
372
-
373
  x_t = euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h)
374
 
375
 
@@ -377,37 +366,165 @@ class MixtureDiscreteEulerSolver(Solver):
377
  t = t + h
378
 
379
  scores = []
 
 
 
380
  for i, s in enumerate(score_models):
381
  sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
382
  if 't' in sig.parameters:
383
- candidate_scores = s(x_t, 1)
384
  else:
385
- candidate_scores = s(x_t)
386
 
387
  if isinstance(candidate_scores, tuple):
388
  for score in candidate_scores:
389
  scores.append(score.item())
390
  else:
391
  scores.append(candidate_scores.item())
392
- print(scores)
393
-
394
- # print(f"Score {i}: {[round(s.item(), 4) for s in candidate_scores]}")
395
- # if i == 0:
396
- # score_list_0.append(round(candidate_scores[0].item(), 2))
397
- # # score_list_0.append(round(1-candidate_scores.item(), 2))
398
- # # score_list_1.append(round(candidate_scores[1].item(), 2))
399
- # if i == 1:
400
- # score_list_1.append(round(candidate_scores.item(), 2))
401
- # # score_list_2.append(round(candidate_scores.item(), 2))
402
- # if i == 2:
403
- # score_list_2.append(round(candidate_scores.item(), 2))
404
- # if i == 3:
405
- # score_list_3.append(round(candidate_scores.item(), 2))
406
- # if i == 4:
407
- # score_list_4.append(round(candidate_scores.item(), 2))
408
- # if i == 5:
409
- # score_list_5.append(round(candidate_scores.item(), 2))
410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
  if return_intermediates and (t in time_grid):
413
  res.append(x_t.clone())
 
10
 
11
  import torch
12
  from torch import Tensor
 
13
  from torch.nn import functional as F
14
 
15
  from flow_matching.path import MixtureDiscreteProbPath
 
20
  from ..utils.multi_guidance import *
21
 
22
  try:
23
+ from tqdm.auto import tqdm
24
 
25
  TQDM_AVAILABLE = True
26
  except ImportError:
 
274
  score_models: list = None,
275
  num_objectives: int = 1,
276
  weights: list = None,
277
+ tokenizer = None,
278
+ fixed_positions=None,
279
+ invalid_tokens=None,
280
  **model_extras,
281
  ) -> Tensor:
282
 
 
 
 
 
 
 
 
 
 
283
  if not div_free == 0.0:
284
  raise NotImplementedError
285
 
 
324
  raise ImportError(
325
  "tqdm is required for verbose mode. Please install it."
326
  )
327
+ ctx = tqdm(total=n_steps, desc=f"NFE", dynamic_ncols=True, leave=True, bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt}{postfix}")
328
  else:
329
  ctx = nullcontext()
330
 
 
335
  w, _ = select_random_weight_vector(num_objectives, args.num_div)
336
  # w = torch.tensor([0.2, 0.7, 0.05, 0.05]).to(x_t.device)
337
  w = w.to(device=x_init.device)
338
+ # print(f"Weight Vector: {w}")
339
  Phi = args.Phi_init
340
  ema_r_t = None
341
 
 
355
  d_k_t = scheduler_output.d_alpha_t
356
  u_t = d_k_t / (1 - k_t) * p_1t
357
 
358
+ guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S = guided_transition_scoring(x_t, u_t, w, score_models, t, w, tokenizer, args, fixed_positions, invalid_tokens)
359
 
360
  best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t)
361
 
 
 
 
 
362
  x_t = euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h)
363
 
364
 
 
366
  t = t + h
367
 
368
  scores = []
369
+ input_seqs = tokenizer.batch_decode(x_t)
370
+ input_seqs = [seq.replace(' ', '')[5:-5] for seq in input_seqs]
371
+
372
  for i, s in enumerate(score_models):
373
  sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
374
  if 't' in sig.parameters:
375
+ candidate_scores = s(input_seqs, 1)
376
  else:
377
+ candidate_scores = s(input_seqs)
378
 
379
  if isinstance(candidate_scores, tuple):
380
  for score in candidate_scores:
381
  scores.append(score.item())
382
  else:
383
  scores.append(candidate_scores.item())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
+ postfix = {}
386
+ for i, objective in enumerate(args.objectives):
387
+ postfix[objective] = scores[i]
388
+
389
+ ctx.set_description(f"NFE: {steps_counter}", refresh=False)
390
+ ctx.set_postfix({k: f"{v:.3f}" for k, v in postfix.items()}, refresh=False)
391
+ ctx.update(1)
392
+
393
+ if return_intermediates and (t in time_grid):
394
+ res.append(x_t.clone())
395
+
396
+ if return_intermediates:
397
+ if step_size is None:
398
+ return torch.stack(res, dim=0)
399
+ else:
400
+ return torch.stack(res, dim=0)[order]
401
+ else:
402
+ return x_t
403
+
404
+
405
+ def multi_guidance_sample_uaa(
406
+ self,
407
+ args,
408
+ x_init: Tensor,
409
+ step_size: Optional[float],
410
+ div_free: Union[float, Callable[[float], float]] = 0.0,
411
+ dtype_categorical: torch.dtype = torch.float32,
412
+ time_grid: Tensor = torch.tensor([0.0, 1.0]),
413
+ return_intermediates: bool = False,
414
+ verbose: bool = False,
415
+ score_models: list = None,
416
+ num_objectives: int = 1,
417
+ weights: list = None,
418
+ tokenizer = None,
419
+ fixed_positions=None,
420
+ invalid_tokens=None,
421
+ **model_extras,
422
+ ) -> Tensor:
423
+
424
+ if not div_free == 0.0:
425
+ raise NotImplementedError
426
+
427
+ # Initialize the current state `x_t` with the initial state `X_0`.
428
+ time_grid = time_grid.to(device=x_init.device)
429
+
430
+ if step_size is None:
431
+ # If step_size is None then set the t discretization to time_grid.
432
+ t_discretization = time_grid
433
+ n_steps = len(time_grid) - 1
434
+ else:
435
+ # If step_size is float then t discretization is uniform with step size set by step_size.
436
+ t_init = time_grid[0].item()
437
+ t_final = time_grid[-1].item()
438
+ assert (
439
+ t_final - t_init
440
+ ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."
441
+
442
+ n_steps = ceil((t_final - t_init) / step_size)
443
+ t_discretization = torch.tensor(
444
+ [t_init + step_size * i for i in range(n_steps)] + [t_final],
445
+ device=x_init.device,
446
+ )
447
+
448
+ if return_intermediates:
449
+ # get order of intermediate steps:
450
+ order = torch.argsort(time_grid)
451
+ # Compute intermediate steps to return via nearest points in t_discretization to time_grid.
452
+ time_grid = get_nearest_times(
453
+ time_grid=time_grid, t_discretization=t_discretization
454
+ )
455
+
456
+ x_t = x_init.clone()
457
+ steps_counter = 0
458
+ res = []
459
+
460
+ if return_intermediates:
461
+ res = [x_init.clone()]
462
+
463
+ if verbose:
464
+ if not TQDM_AVAILABLE:
465
+ raise ImportError(
466
+ "tqdm is required for verbose mode. Please install it."
467
+ )
468
+ ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
469
+ else:
470
+ ctx = nullcontext()
471
+
472
+ # Randomly sample a weight vector
473
+ if weights is not None:
474
+ w = torch.tensor(weights).to(device=x_init.device)
475
+ else:
476
+ w, _ = select_random_weight_vector(num_objectives, args.num_div)
477
+ # w = torch.tensor([0.2, 0.7, 0.05, 0.05]).to(x_t.device)
478
+ w = w.to(device=x_init.device)
479
+ # print(f"Weight Vector: {w}")
480
+ Phi = args.Phi_init
481
+ ema_r_t = None
482
+
483
+ with ctx:
484
+ for i in range(n_steps):
485
+ t = t_discretization[i : i + 1]
486
+ h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]
487
+
488
+ p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
489
+ x_1 = categorical(p_1t.to(dtype=dtype_categorical))
490
+
491
+ # Checks if final step
492
+ if i != n_steps - 1:
493
+ # Compute u_t(y,x)
494
+ scheduler_output = self.path.scheduler(t=t)
495
+ k_t = scheduler_output.alpha_t
496
+ d_k_t = scheduler_output.d_alpha_t
497
+ u_t = d_k_t / (1 - k_t) * p_1t
498
+
499
+ guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S = guided_transition_scoring_uaa(x_t, u_t, w, score_models, t, w, tokenizer, args, fixed_positions, invalid_tokens)
500
+
501
+ best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t)
502
+
503
+ # best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t)
504
+
505
+ # best_candidate = get_best_candidate(improvement_values, cand_tokens, delta_S)
506
+
507
+ x_t = euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h)
508
+
509
+ steps_counter += 1
510
+ t = t + h
511
+
512
+ scores = []
513
+ input_seqs_smiles, _ = tokenizer.batch_decode(x_t, convert_to_smiles=True, cyclic=args.cyclic)
514
+ input_seqs_aa = tokenizer.batch_decode(x_t, convert_to_smiles=False)
515
+
516
+ for i, s in enumerate(score_models):
517
+ if i == 0:
518
+ score = s(input_seqs_aa)
519
+ else:
520
+ score = s(input_seqs_smiles)
521
+
522
+ if isinstance(score, tuple):
523
+ for s in score:
524
+ scores.append(s.item())
525
+ else:
526
+ scores.append(score.item())
527
+ ctx.write(scores)
528
 
529
  if return_intermediates and (t in time_grid):
530
  res.append(x_t.clone())