Hugo Flores commited on
Commit
326b5bb
1 Parent(s): 04c5b94

fix: sample prefix suffix

Browse files
Files changed (1) hide show
  1. scripts/exp/train.py +4 -3
scripts/exp/train.py CHANGED
@@ -216,6 +216,7 @@ def accuracy(
216
  return accuracy
217
 
218
  def sample_prefix_suffix_amt(
 
219
  n_batch,
220
  prefix_amt,
221
  suffix_amt,
@@ -362,7 +363,7 @@ def train(
362
  n_batch = z.shape[0]
363
  r = rng.draw(n_batch)[:, 0].to(accel.device)
364
 
365
- n_prefix, n_suffix = sample_prefix_suffix_amt(
366
  n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
367
  prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
368
  rng=rng
@@ -448,7 +449,7 @@ def train(
448
  n_batch = z.shape[0]
449
  r = rng.draw(n_batch)[:, 0].to(accel.device)
450
 
451
- n_prefix, n_suffix = sample_prefix_suffix_amt(
452
  n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
453
  prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
454
  rng=rng
@@ -606,7 +607,7 @@ def train(
606
 
607
  n_batch = z.shape[0]
608
 
609
- n_prefix, n_suffix = sample_prefix_suffix_amt(
610
  n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
611
  prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
612
  rng=rng
 
216
  return accuracy
217
 
218
  def sample_prefix_suffix_amt(
219
+ z,
220
  n_batch,
221
  prefix_amt,
222
  suffix_amt,
 
363
  n_batch = z.shape[0]
364
  r = rng.draw(n_batch)[:, 0].to(accel.device)
365
 
366
+ n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
367
  n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
368
  prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
369
  rng=rng
 
449
  n_batch = z.shape[0]
450
  r = rng.draw(n_batch)[:, 0].to(accel.device)
451
 
452
+ n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
453
  n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
454
  prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
455
  rng=rng
 
607
 
608
  n_batch = z.shape[0]
609
 
610
+ n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
611
  n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
612
  prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
613
  rng=rng