zyingt commited on
Commit
2023a17
1 Parent(s): b07c30e

support timbre confusion

Browse files
Files changed (1) hide show
  1. utils/util.py +14 -0
utils/util.py CHANGED
@@ -29,6 +29,20 @@ from utils.hparam import HParams
29
  import logging
30
  from logging import handlers
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def str2bool(v):
34
  """Used in argparse.ArgumentParser.add_argument to indicate
 
29
  import logging
30
  from logging import handlers
31
 
32
+ def interpolate_embeddings(embedding1, embedding2, alpha):
33
+ """
34
+ Interpolate between two embeddings.
35
+
36
+ :param embedding1: First embedding vector.
37
+ :param embedding2: Second embedding vector.
38
+ :param alpha: Interpolation factor (0 to 1).
39
+ :return: Interpolated embedding.
40
+ """
41
+ if not 0 <= alpha <= 1:
42
+ raise ValueError("Alpha should be between 0 and 1")
43
+ fused_embedding = (1 - alpha) * np.array(embedding1.cpu()) + alpha * np.array(embedding2.cpu())
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ return torch.from_numpy(fused_embedding).to(device)
46
 
47
  def str2bool(v):
48
  """Used in argparse.ArgumentParser.add_argument to indicate