jason-on-salt-a40 committed
Commit 579d79b
Parent: 78774ba

better hf integration

Files changed (2)
  1. app.py +1 -1
  2. models/voicecraft.py +37 -10
app.py CHANGED
@@ -93,7 +93,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name,
     transcribe_model = WhisperxModel(whisper_model_name, align_model)
 
     voicecraft_name = f"{voicecraft_model_name}.pth"
-    model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
+    model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}")
     phn2num = model.args.phn2num
     config = model.args
     model.to(device)
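
Since the Hub mixin now lives on VoiceCraft itself, the app loads checkpoints directly with from_pretrained instead of going through the removed VoiceCraftHF wrapper. A minimal sketch of that loading path, assuming a checkpoint name such as "giga830M" and a `from models import voicecraft` import (both assumptions about the surrounding app code, not part of this diff):

import torch
from models import voicecraft  # assumed import path, mirroring how app.py refers to the module

voicecraft_model_name = "giga830M"  # hypothetical value; load_models() receives this as an argument
device = "cuda" if torch.cuda.is_available() else "cpu"

# PyTorchModelHubMixin fetches config.json plus the weights and rebuilds the model
model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_model_name}")
phn2num = model.args.phn2num  # phoneme-to-index mapping carried inside the serialized args
model.to(device)
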
models/voicecraft.py CHANGED
@@ -3,6 +3,7 @@ import random
 import numpy as np
 import logging
 import argparse, copy
+from typing import Dict, Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -17,8 +18,11 @@ from .modules.transformer import (
     TransformerEncoderLayer,
 )
 from .codebooks_patterns import DelayedPatternProvider
-from huggingface_hub import PyTorchModelHubMixin
+
 from argparse import Namespace
+from huggingface_hub import PyTorchModelHubMixin
+
+
 def top_k_top_p_filtering(
     logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
 ):
@@ -83,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
 
 
 
-class VoiceCraft(nn.Module):
-    def __init__(self, args):
+class VoiceCraft(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="voicecraft",
+    repo_url="https://github.com/jasonppy/VoiceCraft",
+    tags=["text-to-speech"],
+):
+    def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft":
+        # If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json
+        # Won't affect instance initialization
+        if args is not None:
+            if config is not None:
+                raise ValueError("Cannot provide both `args` and `config`.")
+            config = vars(args)
+        return super().__new__(cls, args=args, config=config, **kwargs)
+
+    def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
         super().__init__()
+
+        # If loaded from HF Hub => convert config.json to Namespace args before initializing
+        if args is None:
+            if config is None:
+                raise ValueError("Either `args` or `config` must be provided.")
+            args = Namespace(**config)
+
         self.args = copy.copy(args)
         self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
         if not getattr(self.args, "special_first", False):
@@ -97,7 +123,7 @@ class VoiceCraft(nn.Module):
         if self.args.eos > 0:
             assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
             self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
-        if type(self.args.audio_vocab_size) == str:
+        if isinstance(self.args.audio_vocab_size, str):
             self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
 
         self.n_text_tokens = self.args.text_vocab_size + 1
@@ -410,6 +436,10 @@ class VoiceCraft(nn.Module):
             .expand(-1, self.args.nhead, -1, -1)
             .reshape(bsz * self.args.nhead, 1, src_len)
         )
+        # Check shapes and resize+broadcast as necessary
+        if xy_attn_mask.shape != _xy_padding_mask.shape:
+            assert xy_attn_mask.ndim + 1 == _xy_padding_mask.ndim, f"xy_attn_mask.shape: {xy_attn_mask.shape}, _xy_padding_mask: {_xy_padding_mask.shape}"
+            xy_attn_mask = xy_attn_mask.unsqueeze(0).repeat(_xy_padding_mask.shape[0], 1, 1) # Example approach
         xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
 
         new_attn_mask = torch.zeros_like(xy_attn_mask)
@@ -455,8 +485,10 @@ class VoiceCraft(nn.Module):
         before padding.
         """
         x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
+        if len(x) == 0:
+            return None
         x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
-        y = y[:, :y_lens.max()]
+        y = y[:, :, :y_lens.max()]
         assert x.ndim == 2, x.shape
         assert x_lens.ndim == 1, x_lens.shape
         assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
@@ -1405,8 +1437,3 @@ class VoiceCraft(nn.Module):
         flatten_gen = flatten_gen - int(self.args.n_special)
 
         return res, flatten_gen[0].unsqueeze(0)
-
-class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]):
-    def __init__(self, config: dict):
-        args = Namespace(**config)
-        super().__init__(args)
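
The __new__/__init__ split above is what lets a single class serve both entry points: training-style code keeps passing a Namespace, while from_pretrained passes the config dict read from config.json. A minimal, self-contained sketch of the same pattern on a toy module (assumes huggingface_hub >= 0.23 with safetensors installed; TinyModel and its fields are invented for illustration and are not part of VoiceCraft):

from argparse import Namespace
from typing import Dict, Optional

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class TinyModel(
    nn.Module,
    PyTorchModelHubMixin,          # same mixin and class-level metadata style as the diff
    library_name="voicecraft",
    repo_url="https://github.com/jasonppy/VoiceCraft",
    tags=["text-to-speech"],
):
    def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs):
        # Mirror the Namespace into a plain dict so the mixin can serialize it as config.json.
        if args is not None:
            if config is not None:
                raise ValueError("Cannot provide both `args` and `config`.")
            config = vars(args)
        return super().__new__(cls, args=args, config=config, **kwargs)

    def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None):
        super().__init__()
        # When loaded from the Hub, only `config` is passed; rebuild the Namespace from it.
        if args is None:
            if config is None:
                raise ValueError("Either `args` or `config` must be provided.")
            args = Namespace(**config)
        self.args = args
        self.proj = nn.Linear(self.args.dim_in, self.args.dim_out)  # stand-in for the real layers


model = TinyModel(Namespace(dim_in=8, dim_out=2))   # Namespace-based construction, as in training
model.save_pretrained("tiny_ckpt")                  # writes config.json and the weights locally
restored = TinyModel.from_pretrained("tiny_ckpt")   # goes through the `config` branch of __init__
print(restored.args.dim_in)                         # -> 8

push_to_hub follows the same path as save_pretrained, just uploading the resulting config.json and weight file to a Hub repository, which is what makes the VoiceCraft.from_pretrained call in app.py possible.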