# ZeroRVC: infer/modules/train/extract_feature_print.py
# (JacobLinCool, commit 3a010aa "feat: infer"; 3.11 kB)
import os
import traceback
import fairseq
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from model import hubert, hubert_cfg, device, fp16 as is_half
# Input audio must be sampled at 16 kHz; the expected hop size is 320 samples (20 ms per frame).
def readwave(wav_path, normalize=False):
    """Load a 16 kHz audio file as a ``(1, num_samples)`` float32 tensor.

    Multi-channel audio is down-mixed to mono by averaging over the last
    axis. When ``normalize`` is true the waveform is normalized over its
    whole length with ``F.layer_norm`` (zero mean, unit variance).

    Args:
        wav_path: Path of the audio file to read with ``soundfile``.
        normalize: Whether to layer-normalize the waveform.

    Returns:
        A float32 tensor of shape ``(1, num_samples)``.

    Raises:
        ValueError: If the file is not sampled at 16000 Hz, or the decoded
            audio is neither 1-D nor 2-D.
    """
    wav, sr = sf.read(wav_path)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if sr != 16000:
        raise ValueError("expected a 16000 Hz wav, got %s Hz: %s" % (sr, wav_path))
    feats = torch.from_numpy(wav).float()
    if feats.dim() == 2:  # multi-channel (frames, channels): average to mono
        feats = feats.mean(-1)
    if feats.dim() != 1:
        raise ValueError(
            "expected 1-D or 2-D audio, got %d dims: %s" % (feats.dim(), wav_path)
        )
    if normalize:
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
    # Add a leading batch dimension of size 1.
    return feats.view(1, -1)
class HubertFeatureExtractor:
    """Extract HuBERT features for every wav file of an experiment directory.

    Reads ``<exp_dir>/1_16k_wavs/*.wav``, runs the module-level ``hubert``
    model on each file (``output_layer=12``) and saves the result as
    ``<exp_dir>/3_feature768/<name>.npy``. Progress and errors are printed to
    stdout and appended to ``<exp_dir>/extract_f0_feature.log``.

    The log file stays open until :meth:`close` is called; the object can
    also be used as a context manager.
    """

    def __init__(self, exp_dir: str):
        self.exp_dir = exp_dir
        # Append mode: repeated runs accumulate in the same log file.
        self.logfile = open("%s/extract_f0_feature.log" % exp_dir, "a+")
        self.wavPath = "%s/1_16k_wavs" % exp_dir
        self.outPath = "%s/3_feature768" % exp_dir
        os.makedirs(self.outPath, exist_ok=True)

    def close(self):
        """Close the log file. Safe to call more than once."""
        if not self.logfile.closed:
            self.logfile.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def println(self, strr):
        """Print ``strr`` and append it, flushed, to the log file."""
        print(strr)
        self.logfile.write("%s\n" % strr)
        self.logfile.flush()

    @staticmethod
    def _feature_name(file):
        """Map ``<name>.wav`` to ``<name>.npy``, changing only the extension."""
        return os.path.splitext(file)[0] + ".npy"

    def _extract(self, wav_path):
        """Run HuBERT on one wav file and return its features as a float32 array."""
        feats = readwave(wav_path, normalize=hubert_cfg.task.normalize)
        # No padding: a single un-batched waveform, so nothing is masked.
        padding_mask = torch.zeros_like(feats, dtype=torch.bool)
        inputs = {
            "source": feats.half().to(device) if is_half else feats.to(device),
            "padding_mask": padding_mask.to(device),
            "output_layer": 12,
        }
        with torch.no_grad():
            logits = hubert.extract_features(**inputs)
        return logits[0].squeeze(0).float().cpu().numpy()

    def run(self):
        """Extract features for every pending ``.wav`` file.

        Files whose ``.npy`` output already exists are skipped, so an
        interrupted run can be resumed. Features containing NaN are not
        saved. A failure on one file is logged with its traceback and does
        not stop the remaining files.
        """
        todo = sorted(os.listdir(self.wavPath))
        if not todo:
            self.println("no-feature-todo")
            return
        # Throttle progress output: log every n-th file (about ten lines).
        n = max(1, len(todo) // 10)
        self.println("all-feature-%s" % len(todo))
        for idx, file in enumerate(todo):
            if not file.endswith(".wav"):
                continue
            try:
                out_path = os.path.join(self.outPath, self._feature_name(file))
                if os.path.exists(out_path):
                    continue
                feats = self._extract(os.path.join(self.wavPath, file))
                if np.isnan(feats).any():
                    self.println("%s-contains nan" % file)
                else:
                    np.save(out_path, feats, allow_pickle=False)
                if idx % n == 0:
                    self.println(
                        "now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape)
                    )
            except Exception:
                # Best-effort per file: log the error and continue with the rest.
                self.println(traceback.format_exc())
        self.println("all-feature-done")