sheepymeh committed
Commit e9541e9
1 Parent(s): 47b94d1

Add model inference code

Files changed (3):
  1. README.md +11 -1
  2. metavoice.py +234 -0
  3. metavoice.sh +28 -0
README.md CHANGED
@@ -9,4 +9,14 @@ library_name: metavoice
  inference: false
  pipeline_tag: text-to-speech
  ---
- Converted safetensors version of [metavoiceio/metavoice-1B-v0.1](https://huggingface.co/metavoiceio/metavoice-1B-v0.1)
+ Converted safetensors version of [metavoiceio/metavoice-1B-v0.1](https://huggingface.co/metavoiceio/metavoice-1B-v0.1)
+
+ # Usage
+ 1. Install the original model code by running metavoice.sh
+ 2. Run the following code:
+
+ ```python
+ from metavoice import MetaVoiceModel
+ model = MetaVoiceModel("sheepymeh/metavoice-1B-v0.1")
+ model.generate("Hello world!")
+ ```
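For reference, the safetensors files in this repo can be reproduced with the `convert_to_safetensors` helper added in metavoice.py below. A minimal sketch of that conversion, assuming the original PyTorch checkpoints have been downloaded locally; the input filenames are placeholders (they mirror the names the non-safetensors loading path expects):

```python
import os
import torch
from metavoice import convert_to_safetensors

# Placeholder paths: point these at the original metavoiceio checkpoints.
ckpt_dir = 'metavoice-1B-v0.1'
os.makedirs('converted', exist_ok=True)  # output directory must exist before writing

convert_to_safetensors(
    stage1_path=f'{ckpt_dir}/first_stage.pt',
    stage2_path=f'{ckpt_dir}/second_stage.pt',
    spk_emb_ckpt_path=f'{ckpt_dir}/speaker_encoder.pt',
    precision=torch.bfloat16,
    output_path='converted',
)
```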
metavoice.py ADDED
@@ -0,0 +1,234 @@
+ import fam.llm.fast_inference_utils
+ from fam.llm.fast_inference import TTS as FAMTTS
+ from fam.llm.inference import Model as FAMModel
+ from fam.llm.inference import InferenceConfig
+ from fam.llm.adapters.tilted_encodec import TiltedEncodec
+ from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
+ from fam.llm.decoders import EncodecDecoder
+ from fam.llm.enhancers import get_enhancer
+ from fam.llm.utils import get_default_dtype, get_device
+ from fam.llm.fast_model import Transformer
+ from fam.llm.model import GPT, GPTConfig
+ from fam.quantiser.text.tokenise import TrainedBPETokeniser
+ from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder as FAMSpeakerEncoder
+ from fam.quantiser.audio.speaker_encoder.model import mel_n_channels, model_hidden_size, model_embedding_size, model_num_layers
+
+ import os
+ from pathlib import Path
+ from typing import Optional, Union
+ from json import load, dump
+ from base64 import b64encode, b64decode
+
+ import torch
+ from torch import nn
+ from huggingface_hub import snapshot_download, HfFileSystem
+ from safetensors.torch import load_model, save_model
+
+ def convert_to_safetensors(
+     stage1_path: str,
+     stage2_path: str,
+     spk_emb_ckpt_path: str,
+     precision: torch.dtype,
+     output_path: str
+ ):
+     config_second_stage = InferenceConfig(
+         ckpt_path=stage2_path,
+         num_samples=1,
+         seed=0,
+         device='cpu',
+         dtype='float16' if precision == torch.float16 else 'bfloat16',
+         compile=False,
+         init_from='resume',
+         output_dir='.',
+     )
+     data_adapter_second_stage = TiltedEncodec(end_of_audio_token=512)
+     stage2_model = Model(config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode)
+
+     stage2_checkpoint = torch.load(stage2_path, map_location='cpu')
+     stage2_state_dict = stage2_checkpoint['model']
+     unwanted_prefix = '_orig_mod.'
+     for k in list(stage2_state_dict.keys()):
+         if k.startswith(unwanted_prefix):
+             stage2_state_dict[k[len(unwanted_prefix):]] = stage2_state_dict.pop(k)
+     save_model(stage2_model.model, os.path.join(output_path, 'second_stage.safetensors'))
+
+     stage1_model, tokenizer, smodel = fam.llm.fast_inference_utils._load_model(stage1_path, spk_emb_ckpt_path, 'cpu', precision)
+     tokenizer_info = torch.load(stage1_path, map_location='cpu').get('meta', {}).get('tokenizer', {})
+     save_model(stage1_model, os.path.join(output_path, 'first_stage.safetensors'))
+     save_model(smodel, os.path.join(output_path, 'speaker_encoder.safetensors'))
+
+     with open(os.path.join(output_path, 'config.json'), 'w') as f:
+         tokenizer_info['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in tokenizer_info['mergeable_ranks'].items()}
+         stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'].items()}
+         dump({
+             'model_name': 'metavoice-1B-v0.1',
+             'stage1': {
+                 'tokenizer_info': tokenizer_info
+             },
+             'stage2': {
+                 'config': stage2_checkpoint['config'],
+                 'meta': stage2_checkpoint['meta'],
+                 'model_args': stage2_checkpoint['model_args']
+             }
+         }, f)
+
+ class SpeakerEncoder(FAMSpeakerEncoder):
+     def __init__(
+         self,
+         weights_fpath: str,
+         device: Optional[Union[str, torch.device]] = None,
+         verbose: bool = True,
+         eval: bool = False,
+     ):
+         nn.Module.__init__(self)
+
+         # Define the network
+         self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
+         self.linear = nn.Linear(model_hidden_size, model_embedding_size)
+         self.relu = nn.ReLU()
+
+         # Get the target device
+         if device is None:
+             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         elif isinstance(device, str):
+             device = torch.device(device)
+         self.device = device
+
+         weights_fpath = str(weights_fpath)
+         if weights_fpath.endswith('.safetensors'):
+             load_model(self, weights_fpath)
+         else:
+             checkpoint = torch.load(weights_fpath, map_location='cpu')
+             self.load_state_dict(checkpoint['model_state'], strict=False)
+         self.to(device)
+
+         if eval:
+             self.eval()
+
+ def load_safetensors_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
+     ##### MODEL
+     with torch.device(device):
+         model = Transformer.from_name('metavoice-1B')
+     load_model(model, checkpoint_path)
+     model = model.to(device=device, dtype=precision)
+
+     ###### TOKENIZER
+     with open(f'{os.path.dirname(checkpoint_path)}/config.json', 'r') as f:
+         config = load(f)['stage1']
+         config['tokenizer_info']['mergeable_ranks'] = {b64decode(k): v for k, v in config['tokenizer_info']['mergeable_ranks'].items()}
+     tokenizer_info = config['tokenizer_info']
+     tokenizer = TrainedBPETokeniser(**tokenizer_info)
+
+     ###### SPEAKER EMBEDDER
+     smodel = SpeakerEncoder(
+         weights_fpath=spk_emb_ckpt_path,
+         device=device,
+         eval=True,
+         verbose=False,
+     )
+     return model.eval(), tokenizer, smodel
+
+ class Model(FAMModel):
+     def _init_model(self):
+         if self.config.init_from == 'safetensors':
+             with open(f'{os.path.dirname(self.config.ckpt_path)}/config.json', 'r') as f:
+                 config = load(f)['stage2']
+             self.vocab_sizes = config['model_args']['vocab_sizes']
+             self.checkpoint_config = config['config']
+             config['meta']['tokenizer']['mergeable_ranks'] = {b64decode(k): v for k, v in config['meta']['tokenizer']['mergeable_ranks'].items()}
+
+             self.meta = config['meta']
+             self.load_meta = True
+             self.use_bpe_tokenizer = 'stoi' not in self.meta or 'itos' not in self.meta
+             self.speaker_cond = self.meta.get('speaker_cond')
+
+             speaker_emb_size = None
+             if self.speaker_cond:
+                 speaker_emb_size = self.meta['speaker_emb_size']
+
+             model_args = config['model_args']
+             if 'causal' in self.checkpoint_config and self.checkpoint_config['causal'] is False:
+                 self._encodec_ctx_window = model_args['block_size']
+
+             gptconf = GPTConfig(**model_args)
+             self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size)
+             load_model(self.model, self.config.ckpt_path)
+
+         super()._init_model()
+
+ class MetaVoiceModel(FAMTTS):
+     def __init__(self, model_name: str, *, seed: int = 1337, output_dir: str = 'outputs', enforce_safetensors: bool = True):
+         self._dtype = get_default_dtype()
+         self._device = get_device()
+
+         if os.path.exists(model_name):
+             if enforce_safetensors:
+                 assert all(x in os.listdir(model_name) for x in ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')), 'Model is not compatible with safetensors'
+                 self._model_dir = model_name
+             else:
+                 print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
+                 self._model_dir = model_name
+         else:
+             if enforce_safetensors:
+                 fs = HfFileSystem()
+                 files = [os.path.basename(x) for x in fs.ls(model_name, detail=False)]
+                 assert all(x in files for x in ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')), 'Model is not compatible with safetensors'
+                 self._model_dir = snapshot_download(repo_id=model_name, allow_patterns=['second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors', 'config.json'])
+             else:
+                 print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
+                 self._model_dir = snapshot_download(repo_id=model_name)
+
+         self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+
+         is_safetensors = os.path.exists(f'{self._model_dir}/second_stage.safetensors')
+         second_stage_ckpt_path = f'{self._model_dir}/{"second_stage.safetensors" if is_safetensors else "second_stage.pt"}'
+         config_second_stage = InferenceConfig(
+             ckpt_path=second_stage_ckpt_path,
+             num_samples=1,
+             seed=seed,
+             device=self._device,
+             dtype=self._dtype,
+             compile=False,
+             init_from='safetensors' if is_safetensors else 'resume',
+             output_dir=self.output_dir,
+         )
+         data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
+         self.llm_second_stage = Model(
+             config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
+         )
+
+         self.enhancer = get_enhancer('df')
+         self.precision = {'float16': torch.float16, 'bfloat16': torch.bfloat16}[self._dtype]
+         build_model_kwargs = {
+             'precision': self.precision,
+             'device': self._device,
+             'compile': False,
+             'compile_prefill': True,
+         }
+         if is_safetensors:
+             fam.llm.fast_inference_utils._load_model = load_safetensors_model
+             checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.safetensors'), Path(f'{self._model_dir}/speaker_encoder.safetensors')
+         else:
+             checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.pt'), Path(f'{self._model_dir}/speaker_encoder.pt')
+
+         self.model, self.tokenizer, self.smodel, self.model_size = fam.llm.fast_inference_utils.build_model(
+             checkpoint_path=checkpoint_path,
+             spk_emb_ckpt_path=spk_emb_ckpt_path,
+             **build_model_kwargs
+         )
+
+     @torch.inference_mode()
+     def generate(self, text: str, source: str = 'https://upload.wikimedia.org/wikipedia/commons/e/e1/King_Charles_Addresses_Scottish_Parliament_-_12_September_2022.flac'):
+         self.synthesise(text, source)
+
+     def save(self, path: str):
+         save_model(self.model, os.path.join(path, 'first_stage.safetensors'))
+         save_model(self.smodel, os.path.join(path, 'speaker_encoder.safetensors'))
+         save_model(self.llm_second_stage.model, os.path.join(path, 'second_stage.safetensors'))
+
+     @classmethod
+     def from_hub(cls, path: str):
+         # TODO: TEMPORARY OUTPUT DIR
+         return cls(path, enforce_safetensors=True)
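Besides `generate`, the `MetaVoiceModel` class above exposes `from_hub` and `save`. A short sketch of loading from the Hub and re-exporting the weights as safetensors; the output folder name is a placeholder:

```python
import os
from metavoice import MetaVoiceModel

model = MetaVoiceModel.from_hub('sheepymeh/metavoice-1B-v0.1')
model.generate('Hello world!')  # synthesised audio is written under output_dir ('outputs' by default)

os.makedirs('exported', exist_ok=True)  # save() expects the target directory to exist
model.save('exported')  # writes first_stage / second_stage / speaker_encoder .safetensors
```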
metavoice.sh ADDED
@@ -0,0 +1,28 @@
+ #!/bin/bash
+
+ set -euo pipefail
+
+ if ! command -v -- ffmpeg > /dev/null 2>&1; then
+     wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz
+     wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz.md5
+     md5sum -c ffmpeg-git-amd64-static.tar.xz.md5
+     tar xf ffmpeg-git-amd64-static.tar.xz
+     [ -z "${VIRTUAL_ENV:-}" ] \
+         && sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/ \
+         || mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg "$VIRTUAL_ENV/bin"
+     rm -rf ffmpeg-git-*
+ fi
+
+ (command -v -- rustup > /dev/null 2>&1) || (wget -qO - https://sh.rustup.rs | sh -s -- -y)
+ export PATH="$HOME/.cargo/bin:$PATH"
+
+ git clone https://github.com/metavoiceio/metavoice-src.git
+ cd metavoice-src
+ pip install torch
+ pip install -r requirements.txt
+ pip install --upgrade torch torchaudio
+ pip install .
+ cd -
+ rm -rf metavoice-src
+
+ python -c 'from audiocraft.models import MultiBandDiffusion; MultiBandDiffusion.get_mbd_24khz(bw=6)'