YazawaSunrise commited on
Commit
d6266d7
1 Parent(s): 2308d0e

Upload infer_tool_grad.py

Browse files
Files changed (1) hide show
  1. inference/infer_tool_grad.py +160 -0
inference/infer_tool_grad.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import json
3
+ import logging
4
+ import os
5
+ import time
6
+ from pathlib import Path
7
+ import io
8
+ import librosa
9
+ import maad
10
+ import numpy as np
11
+ from inference import slicer
12
+ import parselmouth
13
+ import soundfile
14
+ import torch
15
+ import torchaudio
16
+
17
+ from hubert import hubert_model
18
+ import utils
19
+ from models import SynthesizerTrn
20
+ logging.getLogger('numba').setLevel(logging.WARNING)
21
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
22
+
23
+ def resize2d_f0(x, target_len):
24
+ source = np.array(x)
25
+ source[source < 0.001] = np.nan
26
+ target = np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)),
27
+ source)
28
+ res = np.nan_to_num(target)
29
+ return res
30
+
31
+ def get_f0(x, p_len,f0_up_key=0):
32
+
33
+ time_step = 160 / 16000 * 1000
34
+ f0_min = 50
35
+ f0_max = 1100
36
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
37
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
38
+
39
+ f0 = parselmouth.Sound(x, 16000).to_pitch_ac(
40
+ time_step=time_step / 1000, voicing_threshold=0.6,
41
+ pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
42
+
43
+ pad_size=(p_len - len(f0) + 1) // 2
44
+ if(pad_size>0 or p_len - len(f0) - pad_size>0):
45
+ f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
46
+
47
+ f0 *= pow(2, f0_up_key / 12)
48
+ f0_mel = 1127 * np.log(1 + f0 / 700)
49
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
50
+ f0_mel[f0_mel <= 1] = 1
51
+ f0_mel[f0_mel > 255] = 255
52
+ f0_coarse = np.rint(f0_mel).astype(np.int)
53
+ return f0_coarse, f0
54
+
55
+ def clean_pitch(input_pitch):
56
+ num_nan = np.sum(input_pitch == 1)
57
+ if num_nan / len(input_pitch) > 0.9:
58
+ input_pitch[input_pitch != 1] = 1
59
+ return input_pitch
60
+
61
+
62
+ def plt_pitch(input_pitch):
63
+ input_pitch = input_pitch.astype(float)
64
+ input_pitch[input_pitch == 1] = np.nan
65
+ return input_pitch
66
+
67
+
68
+ def f0_to_pitch(ff):
69
+ f0_pitch = 69 + 12 * np.log2(ff / 440)
70
+ return f0_pitch
71
+
72
+
73
+ def fill_a_to_b(a, b):
74
+ if len(a) < len(b):
75
+ for _ in range(0, len(b) - len(a)):
76
+ a.append(a[0])
77
+
78
+
79
+ def mkdir(paths: list):
80
+ for path in paths:
81
+ if not os.path.exists(path):
82
+ os.mkdir(path)
83
+
84
+
85
+ class VitsSvc(object):
86
+ def __init__(self):
87
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
+ self.SVCVITS = None
89
+ self.hps = None
90
+ self.speakers = None
91
+ self.hubert_soft = hubert_model.hubert_soft("hubert/model.pt")
92
+
93
+ def set_device(self, device):
94
+ self.device = torch.device(device)
95
+ self.hubert_soft.to(self.device)
96
+ if self.SVCVITS != None:
97
+ self.SVCVITS.to(self.device)
98
+
99
+ def loadCheckpoint(self, path):
100
+ self.hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
101
+ self.SVCVITS = SynthesizerTrn(
102
+ self.hps.data.filter_length // 2 + 1,
103
+ self.hps.train.segment_size // self.hps.data.hop_length,
104
+ **self.hps.model)
105
+ _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", self.SVCVITS, None)
106
+ _ = self.SVCVITS.eval().to(self.device)
107
+ self.speakers = self.hps.spk
108
+
109
+ def get_units(self, source, sr):
110
+ source = source.unsqueeze(0).to(self.device)
111
+ with torch.inference_mode():
112
+ units = self.hubert_soft.units(source)
113
+ return units
114
+
115
+
116
+ def get_unit_pitch(self, in_path, tran):
117
+ source, sr = torchaudio.load(in_path)
118
+ source = torchaudio.functional.resample(source, sr, 16000)
119
+ if len(source.shape) == 2 and source.shape[1] >= 2:
120
+ source = torch.mean(source, dim=0).unsqueeze(0)
121
+ soft = self.get_units(source, sr).squeeze(0).cpu().numpy()
122
+ f0_coarse, f0 = get_f0(source.cpu().numpy()[0], soft.shape[0]*2, tran)
123
+ return soft, f0
124
+
125
+ def infer(self, speaker_id, tran, raw_path):
126
+ speaker_id = self.speakers[speaker_id]
127
+ sid = torch.LongTensor([int(speaker_id)]).to(self.device).unsqueeze(0)
128
+ soft, pitch = self.get_unit_pitch(raw_path, tran)
129
+ f0 = torch.FloatTensor(clean_pitch(pitch)).unsqueeze(0).to(self.device)
130
+ stn_tst = torch.FloatTensor(soft)
131
+ with torch.no_grad():
132
+ x_tst = stn_tst.unsqueeze(0).to(self.device)
133
+ x_tst = torch.repeat_interleave(x_tst, repeats=2, dim=1).transpose(1, 2)
134
+ audio = self.SVCVITS.infer(x_tst, f0=f0, g=sid)[0,0].data.float()
135
+ return audio, audio.shape[-1]
136
+
137
+ def inference(self,srcaudio,chara,tran,slice_db):
138
+ sampling_rate, audio = srcaudio
139
+ audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
140
+ if len(audio.shape) > 1:
141
+ audio = librosa.to_mono(audio.transpose(1, 0))
142
+ if sampling_rate != 16000:
143
+ audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
144
+ soundfile.write("tmpwav.wav", audio, 16000, format="wav")
145
+ chunks = slicer.cut("tmpwav.wav", db_thresh=slice_db)
146
+ audio_data, audio_sr = slicer.chunks2audio("tmpwav.wav", chunks)
147
+ audio = []
148
+ for (slice_tag, data) in audio_data:
149
+ length = int(np.ceil(len(data) / audio_sr * self.hps.data.sampling_rate))
150
+ raw_path = io.BytesIO()
151
+ soundfile.write(raw_path, data, audio_sr, format="wav")
152
+ raw_path.seek(0)
153
+ if slice_tag:
154
+ _audio = np.zeros(length)
155
+ else:
156
+ out_audio, out_sr = self.infer(chara, tran, raw_path)
157
+ _audio = out_audio.cpu().numpy()
158
+ audio.extend(list(_audio))
159
+ audio = (np.array(audio) * 32768.0).astype('int16')
160
+ return (self.hps.data.sampling_rate,audio)