arnavmehta7
commited on
Update inference.py (#4)
Browse files- Update inference.py (5840a7df2257c4c60e02973caec0dbfd97217f5b)
- Update inference.py (f4a06e5d8cb52e048267289fad984e6ae81b04d2)
- inference.py +79 -16
inference.py
CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
|
|
4 |
import torch.nn.functional as F
|
5 |
import logging
|
6 |
import json
|
7 |
-
from typing import Optional
|
8 |
from pathlib import Path
|
9 |
from dataclasses import dataclass
|
10 |
import os
|
@@ -18,6 +18,8 @@ from mars5.minbpe.codebook import CodebookTokenizer
|
|
18 |
from mars5.ar_generate import ar_generate
|
19 |
from mars5.utils import nuke_weight_norm
|
20 |
from mars5.trim import trim
|
|
|
|
|
21 |
import tempfile
|
22 |
import logging
|
23 |
|
@@ -64,9 +66,7 @@ class InferenceConfig():
|
|
64 |
beam_width: int = 1 # only beam width of 1 is currently supported
|
65 |
ref_audio_pad: float = 0
|
66 |
|
67 |
-
|
68 |
-
class Mars5TTS(nn.Module):
|
69 |
-
|
70 |
def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
|
71 |
super().__init__()
|
72 |
|
@@ -79,16 +79,16 @@ class Mars5TTS(nn.Module):
|
|
79 |
|
80 |
# save and load text tokenize
|
81 |
self.texttok = RegexTokenizer(GPT4_SPLIT_PATTERN)
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
# save and load speech tokenizer
|
87 |
-
sfn = tempfile.mkstemp(suffix='speechtok.model')[1]
|
88 |
self.speechtok = CodebookTokenizer(GPT4_SPLIT_PATTERN)
|
89 |
-
|
90 |
-
self.speechtok.load(
|
91 |
-
|
|
|
92 |
# keep track of tokenization things.
|
93 |
self.n_vocab = len(self.texttok.vocab) + len(self.speechtok.vocab)
|
94 |
self.n_text_vocab = len(self.texttok.vocab) + 1
|
@@ -111,7 +111,44 @@ class Mars5TTS(nn.Module):
|
|
111 |
self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device).eval()
|
112 |
nuke_weight_norm(self.codec)
|
113 |
nuke_weight_norm(self.vocos)
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
@torch.inference_mode
|
116 |
def vocode(self, tokens: Tensor) -> Tensor:
|
117 |
""" Vocodes tokens of shape (seq_len, n_q) """
|
@@ -126,6 +163,33 @@ class Mars5TTS(nn.Module):
|
|
126 |
wav_diffusion = self.vocos.decode(features, bandwidth_id=bandwidth_id)
|
127 |
return wav_diffusion.cpu().squeeze()[None]
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
@torch.inference_mode
|
130 |
def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
|
131 |
cfg: Optional[InferenceConfig] = InferenceConfig()) -> Tensor:
|
@@ -183,12 +247,12 @@ class Mars5TTS(nn.Module):
|
|
183 |
first_codec_idx = prompt.shape[-1] - n_speech_inp + 1
|
184 |
|
185 |
# ---> perform AR code generation
|
186 |
-
|
187 |
logging.debug(f"Raw acoustic prompt length: {raw_prompt_acoustic_len}")
|
188 |
|
189 |
ar_codes = ar_generate(self.texttok, self.speechtok, self.codeclm,
|
190 |
prompt, spk_ref_codec, first_codec_idx,
|
191 |
-
max_len=cfg.generate_max_len_override if cfg.generate_max_len_override > 1 else 2000,
|
|
|
192 |
temperature=cfg.temperature, topk=cfg.top_k, top_p=cfg.top_p, typical_p=cfg.typical_p,
|
193 |
alpha_frequency=cfg.freq_penalty, alpha_presence=cfg.presence_penalty, penalty_window=cfg.rep_penalty_window,
|
194 |
eos_penalty_decay=cfg.eos_penalty_decay, eos_penalty_factor=cfg.eos_penalty_factor,
|
@@ -211,7 +275,6 @@ class Mars5TTS(nn.Module):
|
|
211 |
x_padding_mask = torch.zeros((1, _x.shape[1]), dtype=torch.bool, device=_x.device)
|
212 |
|
213 |
# ---> perform DDPM NAR inference
|
214 |
-
|
215 |
T = self.default_T
|
216 |
diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=T, device=self.device)
|
217 |
|
|
|
4 |
import torch.nn.functional as F
|
5 |
import logging
|
6 |
import json
|
7 |
+
from typing import Optional, Dict, Type, Union, List, Tuple
|
8 |
from pathlib import Path
|
9 |
from dataclasses import dataclass
|
10 |
import os
|
|
|
18 |
from mars5.ar_generate import ar_generate
|
19 |
from mars5.utils import nuke_weight_norm
|
20 |
from mars5.trim import trim
|
21 |
+
from huggingface_hub import ModelHubMixin, hf_hub_download
|
22 |
+
from safetensors import safe_open
|
23 |
import tempfile
|
24 |
import logging
|
25 |
|
|
|
66 |
beam_width: int = 1 # only beam width of 1 is currently supported
|
67 |
ref_audio_pad: float = 0
|
68 |
|
69 |
+
class Mars5TTS(nn.Module, ModelHubMixin):
|
|
|
|
|
70 |
def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
|
71 |
super().__init__()
|
72 |
|
|
|
79 |
|
80 |
# save and load text tokenize
|
81 |
self.texttok = RegexTokenizer(GPT4_SPLIT_PATTERN)
|
82 |
+
texttok_data = io.BytesIO(ar_ckpt['vocab']['texttok.model'].encode('utf-8'))
|
83 |
+
self.texttok.load(texttok_data)
|
84 |
+
texttok_data.close()
|
85 |
+
|
86 |
# save and load speech tokenizer
|
|
|
87 |
self.speechtok = CodebookTokenizer(GPT4_SPLIT_PATTERN)
|
88 |
+
speechtok_data = io.BytesIO(ar_ckpt['vocab']['speechtok.model'].encode('utf-8'))
|
89 |
+
self.speechtok.load(speechtok_data)
|
90 |
+
speechtok_data.close()
|
91 |
+
|
92 |
# keep track of tokenization things.
|
93 |
self.n_vocab = len(self.texttok.vocab) + len(self.speechtok.vocab)
|
94 |
self.n_text_vocab = len(self.texttok.vocab) + 1
|
|
|
111 |
self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device).eval()
|
112 |
nuke_weight_norm(self.codec)
|
113 |
nuke_weight_norm(self.vocos)
|
114 |
+
|
115 |
+
@classmethod
|
116 |
+
def _from_pretrained(
|
117 |
+
cls: Type["Mars5TTS"],
|
118 |
+
*,
|
119 |
+
model_id: str,
|
120 |
+
revision: Optional[str],
|
121 |
+
cache_dir: Optional[Union[str, Path]],
|
122 |
+
force_download: bool,
|
123 |
+
proxies: Optional[Dict],
|
124 |
+
local_files_only: bool,
|
125 |
+
token: Optional[Union[str, bool]],
|
126 |
+
device: str = None,
|
127 |
+
**model_kwargs,
|
128 |
+
) -> "Mars5TTS":
|
129 |
+
# Download files from Hub
|
130 |
+
print(f">>>>> Downloading AR model")
|
131 |
+
ar_ckpt_path = hf_hub_download(repo_id=model_id, filename="mars5_ar.safetensors", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token)
|
132 |
+
print(f">>>>> Downloading NAR model")
|
133 |
+
nar_ckpt_path = hf_hub_download(repo_id=model_id, filename="mars5_nar.safetensors", revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token)
|
134 |
+
|
135 |
+
ar_ckpt = {}
|
136 |
+
with safe_open(ar_ckpt_path, framework='pt', device='cpu') as f:
|
137 |
+
metadata = f.metadata()
|
138 |
+
ar_ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']}
|
139 |
+
ar_ckpt['model'] = {}
|
140 |
+
for k in f.keys(): ar_ckpt['model'][k] = f.get_tensor(k)
|
141 |
+
nar_ckpt = {}
|
142 |
+
with safe_open(nar_ckpt_path, framework='pt', device='cpu') as f:
|
143 |
+
metadata = f.metadata()
|
144 |
+
nar_ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']}
|
145 |
+
nar_ckpt['model'] = {}
|
146 |
+
for k in f.keys(): nar_ckpt['model'][k] = f.get_tensor(k)
|
147 |
+
|
148 |
+
|
149 |
+
# Init
|
150 |
+
return cls(ar_ckpt=ar_ckpt, nar_ckpt=nar_ckpt, device=device)
|
151 |
+
|
152 |
@torch.inference_mode
|
153 |
def vocode(self, tokens: Tensor) -> Tensor:
|
154 |
""" Vocodes tokens of shape (seq_len, n_q) """
|
|
|
163 |
wav_diffusion = self.vocos.decode(features, bandwidth_id=bandwidth_id)
|
164 |
return wav_diffusion.cpu().squeeze()[None]
|
165 |
|
166 |
+
@torch.inference_mode
|
167 |
+
def get_speaker_embedding(self, ref_audio: Tensor) -> Tensor:
|
168 |
+
""" Given `ref_audio` (bs, T) audio tensor, compute the implicit speakre embedding of shape (bs, dim). """
|
169 |
+
if ref_audio.dim() == 1: ref_audio = ref_audio[None]
|
170 |
+
spk_reference = self.codec.encode(ref_audio[None].to(self.device))[0][0]
|
171 |
+
spk_reference = spk_reference.permute(0, 2, 1)
|
172 |
+
bs = spk_reference.shape[0]
|
173 |
+
if bs != 1:
|
174 |
+
raise AssertionError(f"Speaker embedding extraction only implemented using for bs=1 currently.")
|
175 |
+
spk_seq = self.codeclm.ref_chunked_emb(spk_reference) # (bs, sl, dim)
|
176 |
+
spk_ref_emb = self.codeclm.spk_identity_emb.weight[None].expand(bs, -1, -1) # (bs, 1, dim)
|
177 |
+
|
178 |
+
spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1) # (bs, 1+sl, dim)
|
179 |
+
# add pos encoding
|
180 |
+
spk_seq = self.codeclm.pos_embedding(spk_seq)
|
181 |
+
# codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry)
|
182 |
+
src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024)
|
183 |
+
src_key_padding_mask = torch.cat((
|
184 |
+
# append a zero here since we DO want to attend to initial position.
|
185 |
+
torch.zeros(src_key_padding_mask.shape[0], 1, dtype=bool, device=src_key_padding_mask.device),
|
186 |
+
src_key_padding_mask
|
187 |
+
),
|
188 |
+
dim=1)
|
189 |
+
# pass through transformer
|
190 |
+
res = self.codeclm.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim).
|
191 |
+
return res.squeeze(1)
|
192 |
+
|
193 |
@torch.inference_mode
|
194 |
def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
|
195 |
cfg: Optional[InferenceConfig] = InferenceConfig()) -> Tensor:
|
|
|
247 |
first_codec_idx = prompt.shape[-1] - n_speech_inp + 1
|
248 |
|
249 |
# ---> perform AR code generation
|
|
|
250 |
logging.debug(f"Raw acoustic prompt length: {raw_prompt_acoustic_len}")
|
251 |
|
252 |
ar_codes = ar_generate(self.texttok, self.speechtok, self.codeclm,
|
253 |
prompt, spk_ref_codec, first_codec_idx,
|
254 |
+
max_len=cfg.generate_max_len_override if cfg.generate_max_len_override > 1 else 2000,
|
255 |
+
fp16=True if torch.cuda.is_available() else False,
|
256 |
temperature=cfg.temperature, topk=cfg.top_k, top_p=cfg.top_p, typical_p=cfg.typical_p,
|
257 |
alpha_frequency=cfg.freq_penalty, alpha_presence=cfg.presence_penalty, penalty_window=cfg.rep_penalty_window,
|
258 |
eos_penalty_decay=cfg.eos_penalty_decay, eos_penalty_factor=cfg.eos_penalty_factor,
|
|
|
275 |
x_padding_mask = torch.zeros((1, _x.shape[1]), dtype=torch.bool, device=_x.device)
|
276 |
|
277 |
# ---> perform DDPM NAR inference
|
|
|
278 |
T = self.default_T
|
279 |
diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=T, device=self.device)
|
280 |
|