Add model inference code
- README.md +11 -1
- metavoice.py +234 -0
- metavoice.sh +28 -0
README.md
CHANGED
````diff
@@ -9,4 +9,14 @@ library_name: metavoice
 inference: false
 pipeline_tag: text-to-speech
 ---
-Converted safetensors version of [metavoiceio/metavoice-1B-v0.1](https://huggingface.co/metavoiceio/metavoice-1B-v0.1)
+Converted safetensors version of [metavoiceio/metavoice-1B-v0.1](https://huggingface.co/metavoiceio/metavoice-1B-v0.1)
+
+# Usage
+1. Install the original model code by running metavoice.sh.
+2. Run the following code:
+
+```python
+from metavoice import MetaVoiceModel
+model = MetaVoiceModel("sheepymeh/metavoice-1B-v0.1")
+model.generate("Hello world!")
+```
````
metavoice.py
ADDED

```python
import fam.llm.fast_inference_utils
from fam.llm.fast_inference import TTS as FAMTTS
from fam.llm.inference import Model as FAMModel
from fam.llm.inference import InferenceConfig
from fam.llm.adapters.tilted_encodec import TiltedEncodec
from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.enhancers import get_enhancer
from fam.llm.utils import get_default_dtype, get_device
from fam.llm.fast_model import Transformer
from fam.llm.model import GPT, GPTConfig
from fam.quantiser.text.tokenise import TrainedBPETokeniser
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder as FAMSpeakerEncoder
from fam.quantiser.audio.speaker_encoder.model import mel_n_channels, model_hidden_size, model_embedding_size, model_num_layers

import os
from pathlib import Path
from typing import Optional, Union
from json import load, dump
from base64 import b64encode, b64decode

import torch
from torch import nn
from huggingface_hub import snapshot_download, HfFileSystem
from safetensors.torch import load_model, save_model

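# convert_to_safetensors re-saves the three original PyTorch checkpoints
# (second stage, first stage, speaker encoder) as safetensors files and
# collects their metadata into a single config.json. The tokenizer's
# mergeable_ranks keys are raw bytes, so they are base64-encoded for JSON.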
def convert_to_safetensors(
    stage1_path: str,
    stage2_path: str,
    spk_emb_ckpt_path: str,
    precision: torch.dtype,
    output_path: str,
):
    config_second_stage = InferenceConfig(
        ckpt_path=stage2_path,
        num_samples=1,
        seed=0,
        device='cpu',
        dtype='float16' if precision == torch.float16 else 'bfloat16',
        compile=False,
        init_from='resume',
        output_dir='.',
    )
    data_adapter_second_stage = TiltedEncodec(end_of_audio_token=512)
    stage2_model = Model(config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode)

    stage2_checkpoint = torch.load(stage2_path, map_location='cpu')
    stage2_state_dict = stage2_checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    # Iterate over a copy of the keys: the dict is mutated while renaming.
    for k in list(stage2_state_dict.keys()):
        if k.startswith(unwanted_prefix):
            stage2_state_dict[k[len(unwanted_prefix):]] = stage2_state_dict.pop(k)
    save_model(stage2_model.model, os.path.join(output_path, 'second_stage.safetensors'))

    stage1_model, tokenizer, smodel = fam.llm.fast_inference_utils._load_model(stage1_path, spk_emb_ckpt_path, 'cpu', precision)
    tokenizer_info = torch.load(stage1_path, map_location='cpu').get('meta', {}).get('tokenizer', {})
    save_model(stage1_model, os.path.join(output_path, 'first_stage.safetensors'))
    save_model(smodel, os.path.join(output_path, 'speaker_encoder.safetensors'))

    with open(os.path.join(output_path, 'config.json'), 'w') as f:
        tokenizer_info['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in tokenizer_info['mergeable_ranks'].items()}
        stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'].items()}
        dump({
            'model_name': 'metavoice-1B-v0.1',
            'stage1': {
                'tokenizer_info': tokenizer_info
            },
            'stage2': {
                'config': stage2_checkpoint['config'],
                'meta': stage2_checkpoint['meta'],
                'model_args': stage2_checkpoint['model_args']
            }
        }, f)

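# SpeakerEncoder mirrors the upstream fam speaker encoder but swaps the
# torch.load path for safetensors loading when the weights file ends in
# .safetensors; the network definition itself is unchanged.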
class SpeakerEncoder(FAMSpeakerEncoder):
    def __init__(
        self,
        weights_fpath: str,
        device: Optional[Union[str, torch.device]] = None,
        verbose: bool = True,
        eval: bool = False,
    ):
        nn.Module.__init__(self)

        # Define the network
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

        # Get the target device
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        weights_fpath = str(weights_fpath)
        if weights_fpath.endswith('.safetensors'):
            load_model(self, weights_fpath)
        else:
            checkpoint = torch.load(weights_fpath, map_location='cpu')
            self.load_state_dict(checkpoint['model_state'], strict=False)
        self.to(device)

        if eval:
            self.eval()

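# Drop-in replacement for fam.llm.fast_inference_utils._load_model: builds the
# first-stage Transformer, restores its weights from safetensors, and rebuilds
# the BPE tokenizer from the base64-encoded ranks stored in config.json.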
def load_safetensors_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
    ##### MODEL
    with torch.device(device):
        model = Transformer.from_name('metavoice-1B')
    load_model(model, checkpoint_path)
    model = model.to(device=device, dtype=precision)

    ###### TOKENIZER
    with open(f'{os.path.dirname(checkpoint_path)}/config.json', 'r') as f:
        config = load(f)['stage1']
    config['tokenizer_info']['mergeable_ranks'] = {b64decode(k): v for k, v in config['tokenizer_info']['mergeable_ranks'].items()}
    tokenizer_info = config['tokenizer_info']
    tokenizer = TrainedBPETokeniser(**tokenizer_info)

    ###### SPEAKER EMBEDDER
    smodel = SpeakerEncoder(
        weights_fpath=spk_emb_ckpt_path,
        device=device,
        eval=True,
        verbose=False,
    )
    return model.eval(), tokenizer, smodel

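# Second-stage model wrapper: when init_from == 'safetensors', reconstruct the
# GPT config from config.json instead of a pickled checkpoint and load the
# weights with safetensors; otherwise defer to the upstream loader.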
class Model(FAMModel):
    def _init_model(self):
        if self.config.init_from == 'safetensors':
            with open(f'{os.path.dirname(self.config.ckpt_path)}/config.json', 'r') as f:
                config = load(f)['stage2']
            self.vocab_sizes = config['model_args']['vocab_sizes']
            self.checkpoint_config = config['config']
            config['meta']['tokenizer']['mergeable_ranks'] = {b64decode(k): v for k, v in config['meta']['tokenizer']['mergeable_ranks'].items()}

            self.meta = config['meta']
            self.load_meta = True
            self.use_bpe_tokenizer = 'stoi' not in self.meta or 'itos' not in self.meta
            self.speaker_cond = self.meta.get('speaker_cond')

            speaker_emb_size = None
            if self.speaker_cond:
                speaker_emb_size = self.meta['speaker_emb_size']

            model_args = config['model_args']
            if 'causal' in self.checkpoint_config and self.checkpoint_config['causal'] is False:
                self._encodec_ctx_window = model_args['block_size']

            gptconf = GPTConfig(**model_args)
            self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size)
            load_model(self.model, self.config.ckpt_path)
        else:
            # Fall back to the upstream loader for regular PyTorch checkpoints.
            super()._init_model()

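# Top-level TTS entry point. Accepts either a local directory or a Hugging Face
# repo id; with enforce_safetensors=True it verifies that all three safetensors
# files plus config.json are present before downloading or loading anything.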
class MetaVoiceModel(FAMTTS):
    def __init__(self, model_name: str, *, seed: int = 1337, output_dir: str = 'outputs', enforce_safetensors: bool = True):
        self._dtype = get_default_dtype()
        self._device = get_device()

        safetensors_files = ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')
        if os.path.exists(model_name):
            if enforce_safetensors:
                assert all(x in os.listdir(model_name) for x in safetensors_files), 'Model is not compatible with safetensors'
            else:
                print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
            self._model_dir = model_name
        else:
            if enforce_safetensors:
                fs = HfFileSystem()
                files = [os.path.basename(x) for x in fs.ls(model_name, detail=False)]
                assert all(x in files for x in safetensors_files), 'Model is not compatible with safetensors'
                self._model_dir = snapshot_download(repo_id=model_name, allow_patterns=list(safetensors_files))
            else:
                print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
                self._model_dir = snapshot_download(repo_id=model_name)

        self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        is_safetensors = os.path.exists(f'{self._model_dir}/second_stage.safetensors')
        second_stage_ckpt_path = f'{self._model_dir}/{"second_stage.safetensors" if is_safetensors else "second_stage.pt"}'
        config_second_stage = InferenceConfig(
            ckpt_path=second_stage_ckpt_path,
            num_samples=1,
            seed=seed,
            device=self._device,
            dtype=self._dtype,
            compile=False,
            init_from='safetensors' if is_safetensors else 'resume',
            output_dir=self.output_dir,
        )
        data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.llm_second_stage = Model(
            config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
        )

        self.enhancer = get_enhancer('df')
        self.precision = {'float16': torch.float16, 'bfloat16': torch.bfloat16}[self._dtype]
        build_model_kwargs = {
            'precision': self.precision,
            'device': self._device,
            'compile': False,
            'compile_prefill': True,
        }
        if is_safetensors:
            # Patch the upstream loader so build_model restores weights via safetensors.
            fam.llm.fast_inference_utils._load_model = load_safetensors_model
            checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.safetensors'), Path(f'{self._model_dir}/speaker_encoder.safetensors')
        else:
            checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.pt'), Path(f'{self._model_dir}/speaker_encoder.pt')

        self.model, self.tokenizer, self.smodel, self.model_size = fam.llm.fast_inference_utils.build_model(
            checkpoint_path=checkpoint_path,
            spk_emb_ckpt_path=spk_emb_ckpt_path,
            **build_model_kwargs,
        )

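    # generate() synthesises `text` in the voice of `source`, which may be a
    # local file or a URL; the default points at a Wikimedia Commons speech
    # recording used as the reference voice.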
    @torch.inference_mode()
    def generate(self, text: str, source: str = 'https://upload.wikimedia.org/wikipedia/commons/e/e1/King_Charles_Addresses_Scottish_Parliament_-_12_September_2022.flac'):
        return self.synthesise(text, source)

    def save(self, path: str):
        save_model(self.model, os.path.join(path, 'first_stage.safetensors'))
        save_model(self.smodel, os.path.join(path, 'speaker_encoder.safetensors'))
        save_model(self.llm_second_stage.model, os.path.join(path, 'second_stage.safetensors'))

    @classmethod
    def from_hub(cls, path: str):
        # TODO: TEMPORARY OUTPUT DIR
        return cls(path, enforce_safetensors=True)
```
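For illustration, a conversion-plus-inference round trip might look like the sketch below. The checkpoint paths are hypothetical placeholders (they depend on where the original metavoiceio checkpoints were downloaded); only `convert_to_safetensors` and `MetaVoiceModel` come from the metavoice.py above.

```python
# Sketch: convert the original PyTorch checkpoints, then run inference from
# the converted directory. All paths are hypothetical placeholders.
import os
import torch
from metavoice import MetaVoiceModel, convert_to_safetensors

os.makedirs('converted', exist_ok=True)
convert_to_safetensors(
    stage1_path='checkpoints/first_stage.pt',            # assumed local download
    stage2_path='checkpoints/second_stage.pt',           # assumed local download
    spk_emb_ckpt_path='checkpoints/speaker_encoder.pt',  # assumed local download
    precision=torch.bfloat16,
    output_path='converted',
)

model = MetaVoiceModel('converted')  # local dirs are checked for the safetensors files
model.generate('Hello world!')       # output audio should land in outputs/
```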
metavoice.sh
ADDED

```bash
#!/bin/bash

set -euo pipefail

if ! command -v -- ffmpeg > /dev/null 2>&1; then
    wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz
    wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz.md5
    md5sum -c ffmpeg-git-amd64-static.tar.xz.md5
    tar xf ffmpeg-git-amd64-static.tar.xz
    # Install into the active virtualenv's bin if there is one, else system-wide.
    if [ -z "${VIRTUAL_ENV:-}" ]; then
        sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/
    else
        mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg "$VIRTUAL_ENV/bin"
    fi
    rm -rf ffmpeg-git-*
fi

(command -v -- rustup > /dev/null 2>&1) || (wget -qO - https://sh.rustup.rs | sh -s -- -y)
export PATH="$HOME/.cargo/bin:$PATH"

git clone https://github.com/metavoiceio/metavoice-src.git
cd metavoice-src
pip install torch
pip install -r requirements.txt
pip install --upgrade torch torchaudio
pip install .
cd -
rm -rf metavoice-src

# Pre-fetch the MultiBandDiffusion weights so the first inference run does not stall.
python -c 'from audiocraft.models import MultiBandDiffusion; MultiBandDiffusion.get_mbd_24khz(bw=6)'
```
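Assuming the script is saved locally as metavoice.sh and run inside the Python environment that will host the model, the README's installation step is a single invocation:

```bash
# Installs ffmpeg and rustup if missing, then the upstream metavoice-src package.
bash metavoice.sh
```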