SWivid commited on
Commit
d15ef36
·
1 Parent(s): b899a35

fix address #191

Browse files
model/backbones/dit.py CHANGED
@@ -45,9 +45,9 @@ class TextEmbedding(nn.Module):
45
  self.extra_modeling = False
46
 
47
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
48
- batch, text_len = text.shape[0], text.shape[1]
49
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
50
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
 
51
  text = F.pad(text, (0, seq_len - text_len), value=0)
52
 
53
  if drop_text: # cfg for text
 
45
  self.extra_modeling = False
46
 
47
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
48
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
49
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
+ batch, text_len = text.shape[0], text.shape[1]
51
  text = F.pad(text, (0, seq_len - text_len), value=0)
52
 
53
  if drop_text: # cfg for text
model/backbones/unett.py CHANGED
@@ -48,9 +48,9 @@ class TextEmbedding(nn.Module):
48
  self.extra_modeling = False
49
 
50
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
51
- batch, text_len = text.shape[0], text.shape[1]
52
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
53
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
 
54
  text = F.pad(text, (0, seq_len - text_len), value=0)
55
 
56
  if drop_text: # cfg for text
 
48
  self.extra_modeling = False
49
 
50
  def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
 
51
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
52
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
53
+ batch, text_len = text.shape[0], text.shape[1]
54
  text = F.pad(text, (0, seq_len - text_len), value=0)
55
 
56
  if drop_text: # cfg for text