import fam.llm.fast_inference_utils
from fam.llm.fast_inference import TTS as FAMTTS
from fam.llm.inference import Model as FAMModel
from fam.llm.inference import InferenceConfig
from fam.llm.adapters.tilted_encodec import TiltedEncodec
from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.enhancers import get_enhancer
from fam.llm.utils import get_default_dtype, get_device
from fam.llm.fast_model import Transformer
from fam.llm.model import GPT, GPTConfig
from fam.quantiser.text.tokenise import TrainedBPETokeniser
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder as FAMSpeakerEncoder
from fam.quantiser.audio.speaker_encoder.model import (
    mel_n_channels,
    model_hidden_size,
    model_embedding_size,
    model_num_layers,
)

import os
from pathlib import Path
from typing import Optional, Union
from json import load, dump
from base64 import b64encode, b64decode

import torch
from torch import nn
from huggingface_hub import snapshot_download, HfFileSystem
from safetensors.torch import load_model, save_model


def convert_to_safetensors(
    stage1_path: str,
    stage2_path: str,
    spk_emb_ckpt_path: str,
    precision: torch.dtype,
    output_path: str,
):
    config_second_stage = InferenceConfig(
        ckpt_path=stage2_path,
        num_samples=1,
        seed=0,
        device='cpu',
        dtype='float16' if precision == torch.float16 else 'bfloat16',
        compile=False,
        init_from='resume',
        output_dir='.',
    )
    data_adapter_second_stage = TiltedEncodec(end_of_audio_token=512)
    stage2_model = Model(
        config_second_stage,
        TrainedBPETokeniser,
        EncodecDecoder,
        data_adapter_fn=data_adapter_second_stage.decode,
    )
    stage2_checkpoint = torch.load(stage2_path, map_location='cpu')
    stage2_state_dict = stage2_checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    # Strip the torch.compile prefix so keys match the uncompiled module
    # (iterate over a list copy: the dict is mutated while looping).
    for k in list(stage2_state_dict.keys()):
        if k.startswith(unwanted_prefix):
            stage2_state_dict[k[len(unwanted_prefix):]] = stage2_state_dict.pop(k)
    save_model(stage2_model.model, os.path.join(output_path, 'second_stage.safetensors'))

    stage1_model, tokenizer, smodel = fam.llm.fast_inference_utils._load_model(
        stage1_path, spk_emb_ckpt_path, 'cpu', precision
    )
    tokenizer_info = torch.load(stage1_path, map_location='cpu').get('meta', {}).get('tokenizer', {})
    save_model(stage1_model, os.path.join(output_path, 'first_stage.safetensors'))
    save_model(smodel, os.path.join(output_path, 'speaker_encoder.safetensors'))

    with open(os.path.join(output_path, 'config.json'), 'w') as f:
        # BPE merge ranks are keyed by raw bytes; base64-encode them so they are JSON-serialisable.
        tokenizer_info['mergeable_ranks'] = {
            b64encode(k).decode('ascii'): v for k, v in tokenizer_info['mergeable_ranks'].items()
        }
        stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'] = {
            b64encode(k).decode('ascii'): v
            for k, v in stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'].items()
        }
        dump(
            {
                'model_name': 'metavoice-1B-v0.1',
                'stage1': {'tokenizer_info': tokenizer_info},
                'stage2': {
                    'config': stage2_checkpoint['config'],
                    'meta': stage2_checkpoint['meta'],
                    'model_args': stage2_checkpoint['model_args'],
                },
            },
            f,
        )


class SpeakerEncoder(FAMSpeakerEncoder):
    def __init__(
        self,
        weights_fpath: str,
        device: Optional[Union[str, torch.device]] = None,
        verbose: bool = True,
        eval: bool = False,
    ):
        nn.Module.__init__(self)

        # Define the network
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

        # Get the target device
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        # Load the weights: safetensors files go through safetensors, anything else through torch.load.
        weights_fpath = str(weights_fpath)
        if weights_fpath.endswith('.safetensors'):
            load_model(self, weights_fpath)
        else:
            checkpoint = torch.load(weights_fpath, map_location='cpu')
            self.load_state_dict(checkpoint['model_state'], strict=False)
        self.to(device)

        if eval:
            self.eval()


def load_safetensors_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
    ##### MODEL
    with torch.device(device):
        model = Transformer.from_name('metavoice-1B')
    load_model(model, checkpoint_path)
    model = model.to(device=device, dtype=precision)

    ###### TOKENIZER
    with open(f'{os.path.dirname(checkpoint_path)}/config.json', 'r') as f:
        config = load(f)['stage1']
    # Undo the base64 encoding applied in convert_to_safetensors.
    config['tokenizer_info']['mergeable_ranks'] = {
        b64decode(k): v for k, v in config['tokenizer_info']['mergeable_ranks'].items()
    }
    tokenizer_info = config['tokenizer_info']
    tokenizer = TrainedBPETokeniser(**tokenizer_info)

    ###### SPEAKER EMBEDDER
    smodel = SpeakerEncoder(
        weights_fpath=spk_emb_ckpt_path,
        device=device,
        eval=True,
        verbose=False,
    )
    return model.eval(), tokenizer, smodel


class Model(FAMModel):
    def _init_model(self):
        if self.config.init_from == 'safetensors':
            with open(f'{os.path.dirname(self.config.ckpt_path)}/config.json', 'r') as f:
                config = load(f)['stage2']

            self.vocab_sizes = config['model_args']['vocab_sizes']
            self.checkpoint_config = config['config']
            config['meta']['tokenizer']['mergeable_ranks'] = {
                b64decode(k): v for k, v in config['meta']['tokenizer']['mergeable_ranks'].items()
            }
            self.meta = config['meta']
            self.load_meta = True
            self.use_bpe_tokenizer = 'stoi' not in self.meta or 'itos' not in self.meta
            self.speaker_cond = self.meta.get('speaker_cond')

            speaker_emb_size = None
            if self.speaker_cond:
                speaker_emb_size = self.meta['speaker_emb_size']

            model_args = config['model_args']
            if 'causal' in self.checkpoint_config and self.checkpoint_config['causal'] is False:
                self._encodec_ctx_window = model_args['block_size']

            gptconf = GPTConfig(**model_args)
            self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size)
            load_model(self.model, self.config.ckpt_path)
        # The parent implementation handles the regular 'resume' (torch.load) checkpoint path.
        super()._init_model()


class MetaVoiceModel(FAMTTS):
    def __init__(
        self,
        model_name: str,
        *,
        seed: int = 1337,
        output_dir: str = 'outputs',
        enforce_safetensors: bool = True,
    ):
        self._dtype = get_default_dtype()
        self._device = get_device()

        if os.path.exists(model_name):
            # Local directory.
            if enforce_safetensors:
                assert all(
                    x in os.listdir(model_name)
                    for x in ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')
                ), 'Model is not compatible with safetensors'
                self._model_dir = model_name
            else:
                print(
                    'WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of '
                    'loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html'
                )
                self._model_dir = model_name
        else:
            # Hugging Face Hub repo id.
            if enforce_safetensors:
                fs = HfFileSystem()
                files = [os.path.basename(x) for x in fs.ls(model_name, detail=False)]
                assert all(
                    x in files
                    for x in ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')
                ), 'Model is not compatible with safetensors'
                # allow_patterns takes a list of patterns, not a comma-separated string.
                self._model_dir = snapshot_download(
                    repo_id=model_name,
                    allow_patterns=['second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors', 'config.json'],
                )
            else:
                print(
                    'WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of '
                    'loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html'
                )
                self._model_dir = snapshot_download(repo_id=model_name)

        self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        is_safetensors = os.path.exists(f'{self._model_dir}/second_stage.safetensors')
        second_stage_ckpt_path = f'{self._model_dir}/{"second_stage.safetensors" if is_safetensors else "second_stage.pt"}'
        config_second_stage = InferenceConfig(
            ckpt_path=second_stage_ckpt_path,
            num_samples=1,
            seed=seed,
            device=self._device,
            dtype=self._dtype,
            compile=False,
            init_from='safetensors' if is_safetensors else 'resume',
            output_dir=self.output_dir,
        )
        data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.llm_second_stage = Model(
            config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
        )
        self.enhancer = get_enhancer('df')

        self.precision = {'float16': torch.float16, 'bfloat16': torch.bfloat16}[self._dtype]
        build_model_kwargs = {
            'precision': self.precision,
            'device': self._device,
            'compile': False,
            'compile_prefill': True,
        }
        if is_safetensors:
            # Route the first-stage loader through the safetensors-aware implementation above.
            fam.llm.fast_inference_utils._load_model = load_safetensors_model
            checkpoint_path = Path(f'{self._model_dir}/first_stage.safetensors')
            spk_emb_ckpt_path = Path(f'{self._model_dir}/speaker_encoder.safetensors')
        else:
            checkpoint_path = Path(f'{self._model_dir}/first_stage.pt')
            spk_emb_ckpt_path = Path(f'{self._model_dir}/speaker_encoder.pt')
        self.model, self.tokenizer, self.smodel, self.model_size = fam.llm.fast_inference_utils.build_model(
            checkpoint_path=checkpoint_path,
            spk_emb_ckpt_path=spk_emb_ckpt_path,
            **build_model_kwargs,
        )

    @torch.inference_mode()
    def generate(self, text: str, source: str = 'https://upload.wikimedia.org/wikipedia/commons/e/e1/King_Charles_Addresses_Scottish_Parliament_-_12_September_2022.flac'):
        self.synthesise(text, source)

    def save(self, path: str):
        save_model(self.model, os.path.join(path, 'first_stage.safetensors'))
        save_model(self.smodel, os.path.join(path, 'speaker_encoder.safetensors'))
        save_model(self.llm_second_stage.model, os.path.join(path, 'second_stage.safetensors'))

    @classmethod
    def from_hub(cls, path: str):
        # TODO: TEMPORARY OUTPUT DIR
        return cls(path, enforce_safetensors=True)
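

# Usage sketch (not part of the original module): shows how the pieces above fit
# together. The repo id 'metavoiceio/metavoice-1B-v0.1' and the local paths below
# are placeholders / assumptions; point them at a checkpoint that actually carries
# the safetensors layout MetaVoiceModel expects (config.json plus the three
# *.safetensors files).
if __name__ == '__main__':
    # Optionally convert original torch checkpoints into that layout first:
    # convert_to_safetensors('first_stage.pt', 'second_stage.pt', 'speaker_encoder.pt',
    #                        torch.bfloat16, 'converted')

    # Load a safetensors-only model and synthesise a sample into ./outputs.
    tts = MetaVoiceModel.from_hub('metavoiceio/metavoice-1B-v0.1')
    tts.generate('Hello from MetaVoice running on safetensors weights.')

    # Re-export the loaded weights as safetensors.
    os.makedirs('exported', exist_ok=True)
    tts.save('exported')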