Plachta commited on
Commit
1d0192f
1 Parent(s): 06e7a0f

updated requirements

Browse files
Files changed (1) hide show
  1. models/vallex.py +8 -3
models/vallex.py CHANGED
@@ -22,7 +22,6 @@ import torch.nn.functional as F
22
  # from icefall.utils import make_pad_mask
23
  # from torchmetrics.classification import MulticlassAccuracy
24
 
25
-
26
  from modules.embedding import SinePositionalEmbedding, TokenEmbedding
27
  from modules.transformer import (
28
  AdaptiveLayerNorm,
@@ -493,7 +492,10 @@ class VALLE(VALLF):
493
  x = self.ar_text_embedding(text)
494
  # Add language embedding
495
  prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
496
- text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
 
 
 
497
  x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
498
  x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
499
  x = self.ar_text_prenet(x)
@@ -599,7 +601,10 @@ class VALLE(VALLF):
599
  x = self.nar_text_embedding(text)
600
  # Add language embedding
601
  prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
602
- text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
 
 
 
603
  x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
604
  x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
605
  x = self.nar_text_prenet(x)
 
22
  # from icefall.utils import make_pad_mask
23
  # from torchmetrics.classification import MulticlassAccuracy
24
 
 
25
  from modules.embedding import SinePositionalEmbedding, TokenEmbedding
26
  from modules.transformer import (
27
  AdaptiveLayerNorm,
 
492
  x = self.ar_text_embedding(text)
493
  # Add language embedding
494
  prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
495
+ if isinstance(text_language, str):
496
+ text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
497
+ elif isinstance(text_language, List):
498
+ text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
499
  x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
500
  x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
501
  x = self.ar_text_prenet(x)
 
601
  x = self.nar_text_embedding(text)
602
  # Add language embedding
603
  prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
604
+ if isinstance(text_language, str):
605
+ text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
606
+ elif isinstance(text_language, List):
607
+ text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
608
  x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
609
  x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
610
  x = self.nar_text_prenet(x)