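"""Safetensors support for MetaVoice-1B: converts the original .pt checkpoints to
.safetensors files and provides loaders plus a TTS wrapper that read them back
without relying on torch.load / pickle."""
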
import fam.llm.fast_inference_utils
from fam.llm.fast_inference import TTS as FAMTTS
from fam.llm.inference import Model as FAMModel
from fam.llm.inference import InferenceConfig
from fam.llm.adapters.tilted_encodec import TiltedEncodec
from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.enhancers import get_enhancer
from fam.llm.utils import get_default_dtype, get_device
from fam.llm.fast_model import Transformer
from fam.llm.model import GPT, GPTConfig
from fam.quantiser.text.tokenise import TrainedBPETokeniser
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder as FAMSpeakerEncoder
from fam.quantiser.audio.speaker_encoder.model import mel_n_channels, model_hidden_size, model_embedding_size, model_num_layers

import os
from pathlib import Path
from typing import Optional, Union
from json import load, dump
from base64 import b64encode, b64decode

import torch
from torch import nn
from huggingface_hub import snapshot_download, HfFileSystem
from safetensors.torch import load_model, save_model

def convert_to_safetensors(
	stage1_path: str,
	stage2_path: str,
	spk_emb_ckpt_path: str,
	precision: torch.dtype,
	output_path: str
):
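	"""Convert the original MetaVoice .pt checkpoints (first stage, second stage and
	speaker encoder) into .safetensors files plus a single config.json in output_path.
	The BPE merge tables contain raw bytes, so they are base64-encoded before being
	written to JSON."""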
	config_second_stage = InferenceConfig(
		ckpt_path=stage2_path,
		num_samples=1,
		seed=0,
		device='cpu',
		dtype='float16' if precision == torch.float16 else 'bfloat16',
		compile=False,
		init_from='resume',
		output_dir='.',
	)
	data_adapter_second_stage = TiltedEncodec(end_of_audio_token=512)
	stage2_model = Model(config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode)

	stage2_checkpoint = torch.load(stage2_path, map_location='cpu')
	stage2_state_dict = stage2_checkpoint['model']
	unwanted_prefix = '_orig_mod.'
	# Iterate over a snapshot of the keys: popping entries while iterating over the
	# live dict view would raise a RuntimeError.
	for k in list(stage2_state_dict.keys()):
		if k.startswith(unwanted_prefix):
			stage2_state_dict[k[len(unwanted_prefix):]] = stage2_state_dict.pop(k)
	save_model(stage2_model.model, os.path.join(output_path, 'second_stage.safetensors'))

	stage1_model, tokenizer, smodel = fam.llm.fast_inference_utils._load_model(stage1_path, spk_emb_ckpt_path, 'cpu', precision)
	tokenizer_info = torch.load(stage1_path, map_location='cpu').get('meta', {}).get('tokenizer', {})
	save_model(stage1_model, os.path.join(output_path, 'first_stage.safetensors'))
	save_model(smodel, os.path.join(output_path, 'speaker_encoder.safetensors'))

	with open(os.path.join(output_path, 'config.json'), 'w') as f:
		tokenizer_info['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in tokenizer_info['mergeable_ranks'].items()}
		stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'].items()}
		dump({
			'model_name': 'metavoice-1B-v0.1',
			'stage1': {
				'tokenizer_info': tokenizer_info
			},
			'stage2': {
				'config': stage2_checkpoint['config'],
				'meta': stage2_checkpoint['meta'],
				'model_args': stage2_checkpoint['model_args']
			}
		}, f)

class SpeakerEncoder(FAMSpeakerEncoder):
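	"""Speaker encoder that can load its weights from a .safetensors file, falling
	back to torch.load for the original .pt checkpoint format."""
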
	def __init__(
		self,
		weights_fpath: str,
		device: Optional[Union[str, torch.device]] = None,
		verbose: bool = True,
		eval: bool = False,
	):
		nn.Module.__init__(self)

		# Define the network
		self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
		self.linear = nn.Linear(model_hidden_size, model_embedding_size)
		self.relu = nn.ReLU()

		# Get the target device
		if device is None:
			device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
		elif isinstance(device, str):
			device = torch.device(device)
		self.device = device

		weights_fpath = str(weights_fpath)
		if weights_fpath.endswith('.safetensors'):
			load_model(self, weights_fpath)
		else:
			checkpoint = torch.load(weights_fpath, map_location='cpu')
			self.load_state_dict(checkpoint['model_state'], strict=False)
		self.to(device)

		if eval:
			self.eval()

def load_safetensors_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
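	"""Safetensors counterpart of fam.llm.fast_inference_utils._load_model (same
	signature, so it can be swapped in): builds the first-stage Transformer from the
	.safetensors checkpoint, rebuilds the BPE tokeniser from the adjacent config.json
	(decoding the base64-encoded merge table) and loads the speaker encoder."""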
	##### MODEL
	with torch.device(device):
		model = Transformer.from_name('metavoice-1B')
		load_model(model, checkpoint_path)
	model = model.to(device=device, dtype=precision)

	###### TOKENIZER
	with open(f'{os.path.dirname(checkpoint_path)}/config.json', 'r') as f:
		config = load(f)['stage1']
	config['tokenizer_info']['mergeable_ranks'] = {b64decode(k): v for k, v in config['tokenizer_info']['mergeable_ranks'].items()}
	tokenizer_info = config['tokenizer_info']
	tokenizer = TrainedBPETokeniser(**tokenizer_info)

	###### SPEAKER EMBEDDER
	smodel = SpeakerEncoder(
		weights_fpath=spk_emb_ckpt_path,
		device=device,
		eval=True,
		verbose=False,
	)
	return model.eval(), tokenizer, smodel

class Model(FAMModel):
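	"""fam.llm.inference.Model with an additional 'safetensors' init path: model args,
	config and tokeniser metadata are read from the adjacent config.json and the GPT
	weights from the .safetensors checkpoint given in the inference config."""
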
	def _init_model(self):
		if self.config.init_from == 'safetensors':
			with open(f'{os.path.dirname(self.config.ckpt_path)}/config.json', 'r') as f:
				config = load(f)['stage2']
			self.vocab_sizes = config['model_args']['vocab_sizes']
			self.checkpoint_config = config['config']
			config['meta']['tokenizer']['mergeable_ranks'] = {b64decode(k): v for k, v in config['meta']['tokenizer']['mergeable_ranks'].items()}

			self.meta = config['meta']
			self.load_meta = True
			self.use_bpe_tokenizer = 'stoi' not in self.meta or 'itos' not in self.meta
			self.speaker_cond = self.meta.get('speaker_cond')

			speaker_emb_size = None
			if self.speaker_cond:
				speaker_emb_size = self.meta['speaker_emb_size']

			model_args = config['model_args']
			if 'causal' in self.checkpoint_config and self.checkpoint_config['causal'] is False:
				self._encodec_ctx_window = model_args['block_size']

			gptconf = GPTConfig(**model_args)
			self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size)
			load_model(self.model, self.config.ckpt_path)

		super()._init_model()

class MetaVoiceModel(FAMTTS):
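	"""TTS wrapper around fam.llm.fast_inference.TTS that prefers safetensors
	checkpoints. model_name may be a local directory or a Hugging Face repo id; with
	enforce_safetensors=True the three .safetensors files and config.json must be
	present, otherwise the original .pt checkpoints are loaded via torch.load."""
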
	def __init__(self, model_name: str, *, seed: int = 1337, output_dir: str = 'outputs', enforce_safetensors: bool = True):
		self._dtype = get_default_dtype()
		self._device = get_device()

		if os.path.exists(model_name):
			if enforce_safetensors:
				assert all(x in os.listdir(model_name) for x in ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')), 'Model is not compatible with safetensors'
				self._model_dir = model_name
			else:
				print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
				self._model_dir = model_name
		else:
			if enforce_safetensors:
				fs = HfFileSystem()
				files = [os.path.basename(x) for x in fs.ls(model_name, detail=False)]
				assert all(x in files for x in ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')), 'Model is not compatible with safetensors'
				self._model_dir = snapshot_download(repo_id=model_name, allow_patterns=['second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors', 'config.json'])
			else:
				print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
				self._model_dir = snapshot_download(repo_id=model_name)

		self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
		self.output_dir = output_dir
		os.makedirs(self.output_dir, exist_ok=True)

		is_safetensors = os.path.exists(f'{self._model_dir}/second_stage.safetensors')
		second_stage_ckpt_path = f'{self._model_dir}/{"second_stage.safetensors" if is_safetensors else "second_stage.pt"}'
		config_second_stage = InferenceConfig(
			ckpt_path=second_stage_ckpt_path,
			num_samples=1,
			seed=seed,
			device=self._device,
			dtype=self._dtype,
			compile=False,
			init_from='safetensors' if is_safetensors else 'resume',
			output_dir=self.output_dir,
		)
		data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
		self.llm_second_stage = Model(
			config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
		)

		self.enhancer = get_enhancer('df')
		self.precision = {'float16': torch.float16, 'bfloat16': torch.bfloat16}[self._dtype]
		build_model_kwargs = {
			'precision': self.precision,
			'device': self._device,
			'compile': False,
			'compile_prefill': True,
		}
		if is_safetensors:
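			# Swap in the safetensors loader so that build_model() below reads the
			# .safetensors checkpoints instead of unpickling .pt files via torch.load.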
			fam.llm.fast_inference_utils._load_model = load_safetensors_model
			checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.safetensors'), Path(f'{self._model_dir}/speaker_encoder.safetensors')
		else:
			checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.pt'), Path(f'{self._model_dir}/speaker_encoder.pt')

		self.model, self.tokenizer, self.smodel, self.model_size = fam.llm.fast_inference_utils.build_model(
			checkpoint_path=checkpoint_path,
			spk_emb_ckpt_path=spk_emb_ckpt_path,
			**build_model_kwargs
		)

	@torch.inference_mode()
	def generate(self, text: str, source: str = 'https://upload.wikimedia.org/wikipedia/commons/e/e1/King_Charles_Addresses_Scottish_Parliament_-_12_September_2022.flac'):
		# Propagate the result of synthesise() (the generated wav path) to the caller.
		return self.synthesise(text, source)

	def save(self, path: str):
		save_model(self.model, os.path.join(path, 'first_stage.safetensors'))
		save_model(self.smodel, os.path.join(path, 'speaker_encoder.safetensors'))
		save_model(self.llm_second_stage.model, os.path.join(path, 'second_stage.safetensors'))

	@classmethod
	def from_hub(cls, path: str):
		# TODO: TEMPORARY OUTPUT DIR
		return cls(path, enforce_safetensors=True)
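
# Minimal usage sketch, not part of the module's API: convert previously downloaded
# .pt checkpoints to safetensors and load the converted directory. All file names
# below are placeholders for wherever the original MetaVoice checkpoints live.
if __name__ == '__main__':
	output_dir = 'converted_checkpoint'  # placeholder output directory
	os.makedirs(output_dir, exist_ok=True)

	convert_to_safetensors(
		stage1_path='first_stage.pt',            # placeholder path to the original first-stage .pt
		stage2_path='second_stage.pt',           # placeholder path to the original second-stage .pt
		spk_emb_ckpt_path='speaker_encoder.pt',  # placeholder path to the speaker encoder .pt
		precision=torch.bfloat16,
		output_path=output_dir,
	)

	# The converted directory now holds config.json plus the three .safetensors files,
	# so it passes the enforce_safetensors check. A Hub repo with the same layout would
	# work too, e.g. MetaVoiceModel.from_hub('<your-repo>/metavoice-1B-safetensors').
	tts = MetaVoiceModel(output_dir)
	tts.generate('Hello from the safetensors checkpoints.')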