Plachta commited on
Commit
cf8f15a
1 Parent(s): a55de4e

Added safety measure

Browse files
Files changed (1) hide show
  1. models/vallex.py +5 -0
models/vallex.py CHANGED
@@ -588,6 +588,11 @@ class VALLE(VALLF):
588
  print(f"Current memory used: {memory_used:.2f} MB")
589
  break
590
 
 
 
 
 
 
591
  y = torch.concat([y, samples], dim=1)
592
 
593
  codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
 
588
  print(f"Current memory used: {memory_used:.2f} MB")
589
  break
590
 
591
+ # safety measure, break if token sequence too long
592
+ if y.shape[1] > 2250:
593
+ print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
594
+ break
595
+
596
  y = torch.concat([y, samples], dim=1)
597
 
598
  codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]