import requests
import yaml
import torch
import librosa
import numpy as np
import soundfile as sf
from pathlib import Path
from transformers import T5Tokenizer, T5EncoderModel
from tqdm import tqdm

from .src.vc_wrapper import ReDiffVC, DreamVC
from .src.plugin_wrapper import DreamVG
from .src.modules.speaker_encoder.encoder import inference as spk_encoder
from .src.modules.BigVGAN.inference import load_model as load_vocoder
from .src.feats.contentvec_hf import get_content_model, get_content


class DreamVoice:
    def __init__(self, config='dreamvc.yaml', mode='plugin', device='cuda', chunk_size=16):
        # Load the YAML config that ships next to this file.
        script_dir = Path(__file__).resolve().parent
        config_path = script_dir / config
        with open(config_path, 'r') as fp:
            self.config = yaml.safe_load(fp)
        self.script_dir = script_dir

        # Download any checkpoints that are missing locally.
        self._ensure_checkpoints_exist()

        self.device = device
        self.sr = self.config['sample_rate']

        # BigVGAN vocoder: generated spectrogram -> waveform.
        vocoder_path = script_dir / self.config['vocoder_path']
        self.hifigan, _ = load_vocoder(vocoder_path, device)
        self.hifigan.eval()

        # ContentVec model for extracting content features from audio.
        self.content_model = get_content_model().to(device)

        # T5 text encoder for the voice-description prompt.
        lm_path = self.config['lm_path']
        self.tokenizer = T5Tokenizer.from_pretrained(lm_path)
        self.text_encoder = T5EncoderModel.from_pretrained(lm_path).to(device).eval()

        self.mode = mode
        if mode == 'plugin':
            self._init_plugin_mode()
        elif mode == 'end2end':
            self._init_end2end_mode()
        else:
            raise NotImplementedError("Select mode from 'plugin' and 'end2end'")

        # chunk_size is given in seconds of audio; content features are
        # assumed to run at 50 frames per second.
        self.chunk_size = chunk_size * 50

    def _ensure_checkpoints_exist(self):
        # (path_key, download_url) pairs; nested config keys use dots,
        # e.g. 'dreamvc.ckpt_path'.
        checkpoints = [
            ('vocoder_path', self.config.get('vocoder_url')),
            ('vocoder_config_path', self.config.get('vocoder_config_url')),
            ('speaker_path', self.config.get('speaker_url')),
            ('dreamvc.ckpt_path', self.config.get('dreamvc', {}).get('ckpt_url')),
            ('rediffvc.ckpt_path', self.config.get('rediffvc', {}).get('ckpt_url')),
            ('dreamvg.ckpt_path', self.config.get('dreamvg', {}).get('ckpt_url'))
        ]

        for path_key, url in checkpoints:
            local_path = self._get_local_path(path_key)
            if not local_path.exists() and url:
                print(f"Downloading {path_key} from {url}")
                self._download_file(url, local_path)

    def _get_local_path(self, path_key):
        # Resolve a (possibly dotted) config key to a path relative to this file.
        keys = path_key.split('.')
        local_path = self.config
        for key in keys:
            local_path = local_path.get(key, {})
        return self.script_dir / local_path

    def _download_file(self, url, local_path):
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error encountered: {e}")

            # The checkpoint may be private; retry with a Hugging Face token.
            user_input = input("Private checkpoint: please request authorization "
                               "and enter your Hugging Face API key: ")
            self.hf_key = user_input if user_input else None
            headers = {'Authorization': f'Bearer {self.hf_key}'} if self.hf_key else {}

            try:
                response = requests.get(url, stream=True, headers=headers)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                # Abort instead of continuing with a null response.
                raise RuntimeError(f"Failed to download {url}: {e}") from e

        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Stream the file to disk with a progress bar.
        total_size = int(response.headers.get('content-length', 0))
        block_size = 8192
        t = tqdm(total=total_size, unit='iB', unit_scale=True)
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=block_size):
                t.update(len(chunk))
                f.write(chunk)
        t.close()

    def _init_plugin_mode(self):
        # Voice-conversion model conditioned on a speaker embedding.
        self.dreamvc = ReDiffVC(
            config_path=self.script_dir / self.config['rediffvc']['config_path'],
            ckpt_path=self.script_dir / self.config['rediffvc']['ckpt_path'],
            device=self.device
        )

        # Plugin that generates a speaker embedding from a text prompt.
        self.dreamvg = DreamVG(
            config_path=self.script_dir / self.config['dreamvg']['config_path'],
            ckpt_path=self.script_dir / self.config['dreamvg']['ckpt_path'],
            device=self.device
        )

        # Reference speaker encoder, used when converting towards a speaker clip.
        spk_encoder.load_model(self.script_dir / self.config['speaker_path'], self.device)
        self.spk_encoder = spk_encoder
        self.spk_embed_cache = None

    def _init_end2end_mode(self):
        # Single model that maps (text prompt, content) directly to converted audio.
        self.dreamvc = DreamVC(
            config_path=self.script_dir / self.config['dreamvc']['config_path'],
            ckpt_path=self.script_dir / self.config['dreamvc']['ckpt_path'],
            device=self.device
        )

    def _load_content(self, audio_path):
        content_audio, _ = librosa.load(audio_path, sr=16000)

        # Pad the waveform up to a multiple of 16 * 160 samples so the content
        # model sees whole frames.
        target_length = ((len(content_audio) + 16 * 160 - 1) // (16 * 160)) * (16 * 160)
        if len(content_audio) < target_length:
            content_audio = np.pad(content_audio, (0, target_length - len(content_audio)), mode='constant')

        content_audio = torch.tensor(content_audio).unsqueeze(0).to(self.device)
        content_clip = get_content(self.content_model, content_audio)
        return content_clip

    def load_spk_embed(self, emb_path):
        self.spk_embed_cache = torch.load(emb_path, map_location=self.device)

    def save_spk_embed(self, emb_path):
        assert self.spk_embed_cache is not None
        torch.save(self.spk_embed_cache.cpu(), emb_path)

    def save_audio(self, output_path, audio, sr):
        sf.write(output_path, audio, samplerate=sr)

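    # A minimal usage sketch for the embedding cache (paths are hypothetical):
    #   dv = DreamVoice(mode='plugin')
    #   dv.genvc('content.wav', 'a deep male voice')  # fills spk_embed_cache
    #   dv.save_spk_embed('speaker.pt')               # persist for later runs
    #   dv.load_spk_embed('speaker.pt')
    #   dv.simplevc('other.wav', use_spk_cache=True)  # reuse without re-prompting
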
    @torch.no_grad()
    def genvc(self, content_audio, prompt,
              prompt_guidance_scale=3, prompt_guidance_rescale=0.0,
              prompt_ddim_steps=100, prompt_eta=1, prompt_random_seed=None,
              vc_guidance_scale=3, vc_guidance_rescale=0.0,
              vc_ddim_steps=50, vc_eta=1, vc_random_seed=None,
              ):
        content_clip = self._load_content(content_audio)

        # Encode the text prompt with T5.
        text_batch = self.tokenizer(prompt, max_length=32,
                                    padding='max_length', truncation=True, return_tensors="pt")
        text, text_mask = text_batch.input_ids.to(self.device), \
            text_batch.attention_mask.to(self.device)
        text = self.text_encoder(input_ids=text, attention_mask=text_mask)[0]

        if self.mode == 'plugin':
            # First generate a speaker embedding from the prompt, then run
            # voice conversion on the content features chunk by chunk.
            spk_embed = self.dreamvg.inference([text, text_mask],
                                               guidance_scale=prompt_guidance_scale,
                                               guidance_rescale=prompt_guidance_rescale,
                                               ddim_steps=prompt_ddim_steps, eta=prompt_eta,
                                               random_seed=prompt_random_seed)

            B, L, D = content_clip.shape
            gen_audio_chunks = []
            num_chunks = (L + self.chunk_size - 1) // self.chunk_size
            for i in range(num_chunks):
                start_idx = i * self.chunk_size
                end_idx = min((i + 1) * self.chunk_size, L)
                content_clip_chunk = content_clip[:, start_idx:end_idx, :]

                gen_audio_chunk = self.dreamvc.inference(
                    spk_embed, content_clip_chunk, None,
                    guidance_scale=vc_guidance_scale,
                    guidance_rescale=vc_guidance_rescale,
                    ddim_steps=vc_ddim_steps,
                    eta=vc_eta,
                    random_seed=vc_random_seed)

                gen_audio_chunks.append(gen_audio_chunk)

            gen_audio = torch.cat(gen_audio_chunks, dim=-1)

            self.spk_embed_cache = spk_embed

        elif self.mode == 'end2end':
            # The end-to-end model consumes the text prompt directly.
            B, L, D = content_clip.shape
            gen_audio_chunks = []
            num_chunks = (L + self.chunk_size - 1) // self.chunk_size

            for i in range(num_chunks):
                start_idx = i * self.chunk_size
                end_idx = min((i + 1) * self.chunk_size, L)
                content_clip_chunk = content_clip[:, start_idx:end_idx, :]

                gen_audio_chunk = self.dreamvc.inference([text, text_mask], content_clip_chunk,
                                                         guidance_scale=prompt_guidance_scale,
                                                         guidance_rescale=prompt_guidance_rescale,
                                                         ddim_steps=prompt_ddim_steps,
                                                         eta=prompt_eta, random_seed=prompt_random_seed)
                gen_audio_chunks.append(gen_audio_chunk)

            gen_audio = torch.cat(gen_audio_chunks, dim=-1)

        else:
            raise NotImplementedError("Select mode from 'plugin' and 'end2end'")

        # Vocode the generated spectrogram to a waveform.
        gen_audio = self.hifigan(gen_audio.squeeze(1))
        gen_audio = gen_audio.cpu().numpy().squeeze(0).squeeze(0)

        return gen_audio, self.sr

    @torch.no_grad()
    def simplevc(self, content_audio, speaker_audio=None, use_spk_cache=False,
                 vc_guidance_scale=3, vc_guidance_rescale=0.0,
                 vc_ddim_steps=50, vc_eta=1, vc_random_seed=None,
                 ):
        assert self.mode == 'plugin', "simplevc requires plugin mode"
        if speaker_audio is not None:
            # Embed the reference speaker and cache the result.
            speaker_audio, _ = librosa.load(speaker_audio, sr=16000)
            speaker_audio = torch.tensor(speaker_audio).unsqueeze(0).to(self.device)
            spk_embed = spk_encoder.embed_utterance_batch(speaker_audio)
            self.spk_embed_cache = spk_embed
        elif use_spk_cache:
            assert self.spk_embed_cache is not None, 'no speaker embedding cached yet'
            spk_embed = self.spk_embed_cache
        else:
            raise ValueError("Provide speaker_audio or set use_spk_cache=True")

        content_clip = self._load_content(content_audio)

        B, L, D = content_clip.shape
        gen_audio_chunks = []
        num_chunks = (L + self.chunk_size - 1) // self.chunk_size
        for i in range(num_chunks):
            start_idx = i * self.chunk_size
            end_idx = min((i + 1) * self.chunk_size, L)
            content_clip_chunk = content_clip[:, start_idx:end_idx, :]

            gen_audio_chunk = self.dreamvc.inference(
                spk_embed, content_clip_chunk, None,
                guidance_scale=vc_guidance_scale,
                guidance_rescale=vc_guidance_rescale,
                ddim_steps=vc_ddim_steps,
                eta=vc_eta,
                random_seed=vc_random_seed)

            gen_audio_chunks.append(gen_audio_chunk)

        gen_audio = torch.cat(gen_audio_chunks, dim=-1)

        gen_audio = self.hifigan(gen_audio.squeeze(1))
        gen_audio = gen_audio.cpu().numpy().squeeze(0).squeeze(0)

        return gen_audio, self.sr


if __name__ == '__main__':
    dreamvoice = DreamVoice(config='dreamvc.yaml', mode='plugin', device='cuda')
    content_audio = 'test.wav'
    speaker_audio = 'speaker.wav'
    prompt = 'young female voice, sounds young and cute'
    gen_audio, sr = dreamvoice.genvc(content_audio, prompt)
    dreamvoice.save_audio('debug.wav', gen_audio, sr)
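
    # A follow-up sketch: the same converter can also clone a reference
    # speaker directly via simplevc ('speaker.wav' and the output filename
    # here are assumptions for illustration).
    gen_audio, sr = dreamvoice.simplevc(content_audio, speaker_audio=speaker_audio)
    dreamvoice.save_audio('debug_simplevc.wav', gen_audio, sr)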