Hugo Flores commited on
Commit
183d21c
1 Parent(s): d6a029b

add opt for no prefix and no suffix

Browse files
Files changed (1) hide show
  1. vampnet/interface.py +8 -1
vampnet/interface.py CHANGED
@@ -189,7 +189,6 @@ class Interface:
189
  c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
190
  return c_vamp
191
 
192
-
193
  def coarse_vamp_v2(
194
  self,
195
  signal,
@@ -296,6 +295,14 @@ class Interface:
296
 
297
  c_vamp['suffix'].insert(0, cz_generated)
298
 
 
 
 
 
 
 
 
 
299
 
300
  n_to_insert = c_seq_len - (cz_new_prefix.shape[-1] + cz_new_suffix.shape[-1])
301
  to_insert = torch.zeros(cz_new_prefix.shape[0], cz_new_prefix.shape[1], n_to_insert).long().to(self.device)
 
189
  c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
190
  return c_vamp
191
 
 
192
  def coarse_vamp_v2(
193
  self,
194
  signal,
 
295
 
296
  c_vamp['suffix'].insert(0, cz_generated)
297
 
298
+ else:
299
+ # we have no prefix or suffix, so we'll just use the generated
300
+ # codes as the new prefix and suffix
301
+ cz_new_prefix = cz_generated.clone()
302
+ cz_new_suffix = _cz[:, :, :0].clone()
303
+
304
+ c_vamp['prefix'].append(cz_generated)
305
+
306
 
307
  n_to_insert = c_seq_len - (cz_new_prefix.shape[-1] + cz_new_suffix.shape[-1])
308
  to_insert = torch.zeros(cz_new_prefix.shape[0], cz_new_prefix.shape[1], n_to_insert).long().to(self.device)