import fam.llm.fast_inference_utils
from fam.llm.fast_inference import TTS as FAMTTS
from fam.llm.inference import Model as FAMModel
from fam.llm.inference import InferenceConfig
from fam.llm.adapters.tilted_encodec import TiltedEncodec
from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.enhancers import get_enhancer
from fam.llm.utils import get_default_dtype, get_device
from fam.llm.fast_model import Transformer
from fam.llm.model import GPT, GPTConfig
from fam.quantiser.text.tokenise import TrainedBPETokeniser
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder as FAMSpeakerEncoder
from fam.quantiser.audio.speaker_encoder.model import (
    mel_n_channels,
    model_hidden_size,
    model_embedding_size,
    model_num_layers,
)

import os
from pathlib import Path
from typing import Optional, Union
from json import load, dump
from base64 import b64encode, b64decode

import torch
from torch import nn
from huggingface_hub import snapshot_download, HfFileSystem
from safetensors.torch import load_model, save_model

def convert_to_safetensors(
    stage1_path: str,
    stage2_path: str,
    spk_emb_ckpt_path: str,
    precision: torch.dtype,
    output_path: str,
):
    """Convert the original MetaVoice .pt checkpoints (first stage, second stage and
    speaker encoder) into safetensors files plus a single config.json in output_path."""
    config_second_stage = InferenceConfig(
        ckpt_path=stage2_path,
        num_samples=1,
        seed=0,
        device='cpu',
        dtype='float16' if precision == torch.float16 else 'bfloat16',
        compile=False,
        init_from='resume',
        output_dir='.',
    )
    data_adapter_second_stage = TiltedEncodec(end_of_audio_token=512)
    # `Model` is the safetensors-aware subclass defined below; with init_from='resume'
    # it loads the .pt checkpoint through the upstream code path.
    stage2_model = Model(config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode)

    stage2_checkpoint = torch.load(stage2_path, map_location='cpu')
    stage2_state_dict = stage2_checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    # Iterate over a snapshot of the keys: the dict is mutated inside the loop.
    for k in list(stage2_state_dict.keys()):
        if k.startswith(unwanted_prefix):
            stage2_state_dict[k[len(unwanted_prefix):]] = stage2_state_dict.pop(k)
    save_model(stage2_model.model, os.path.join(output_path, 'second_stage.safetensors'))

    stage1_model, tokenizer, smodel = fam.llm.fast_inference_utils._load_model(stage1_path, spk_emb_ckpt_path, 'cpu', precision)
    tokenizer_info = torch.load(stage1_path, map_location='cpu').get('meta', {}).get('tokenizer', {})
    save_model(stage1_model, os.path.join(output_path, 'first_stage.safetensors'))
    save_model(smodel, os.path.join(output_path, 'speaker_encoder.safetensors'))

    with open(os.path.join(output_path, 'config.json'), 'w') as f:
        # JSON cannot represent bytes keys, so base64-encode the BPE merge tables.
        tokenizer_info['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in tokenizer_info['mergeable_ranks'].items()}
        stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'].items()}
        dump({
            'model_name': 'metavoice-1B-v0.1',
            'stage1': {
                'tokenizer_info': tokenizer_info
            },
            'stage2': {
                'config': stage2_checkpoint['config'],
                'meta': stage2_checkpoint['meta'],
                'model_args': stage2_checkpoint['model_args']
            }
        }, f)
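# Example usage (a sketch: the checkpoint filenames below are assumptions and depend
# on where the original MetaVoice .pt files were downloaded; output_path must already
# exist):
#
#   convert_to_safetensors(
#       stage1_path='./ckpts/first_stage.pt',
#       stage2_path='./ckpts/second_stage.pt',
#       spk_emb_ckpt_path='./ckpts/speaker_encoder.pt',
#       precision=torch.bfloat16,
#       output_path='./metavoice-1B-safetensors',
#   )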
class SpeakerEncoder(FAMSpeakerEncoder):
    """Speaker encoder that can load either the original .pt checkpoint or the
    safetensors export produced by convert_to_safetensors."""

    def __init__(
        self,
        weights_fpath: str,
        device: Optional[Union[str, torch.device]] = None,
        verbose: bool = True,
        eval: bool = False,
    ):
        # `verbose` and `eval` mirror the upstream constructor signature so this class
        # stays a drop-in replacement (`eval` intentionally shadows the builtin).
        nn.Module.__init__(self)

        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        weights_fpath = str(weights_fpath)
        if weights_fpath.endswith('.safetensors'):
            load_model(self, weights_fpath)
        else:
            checkpoint = torch.load(weights_fpath, map_location='cpu')
            self.load_state_dict(checkpoint['model_state'], strict=False)
        self.to(device)

        if eval:
            self.eval()
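# Minimal usage sketch (the weights path is an assumption; both the safetensors export
# and the original .pt checkpoint are accepted):
#
#   smodel = SpeakerEncoder('metavoice-1B-safetensors/speaker_encoder.safetensors',
#                           device='cpu', eval=True, verbose=False)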
def load_safetensors_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
    """Drop-in replacement for fam.llm.fast_inference_utils._load_model that reads the
    first-stage transformer, tokenizer and speaker encoder from the safetensors/config.json
    layout written by convert_to_safetensors."""
    with torch.device(device):
        model = Transformer.from_name('metavoice-1B')
    load_model(model, checkpoint_path)
    model = model.to(device=device, dtype=precision)

    with open(f'{os.path.dirname(checkpoint_path)}/config.json', 'r') as f:
        config = load(f)['stage1']
    # Undo the base64 encoding applied when the config was serialised to JSON.
    config['tokenizer_info']['mergeable_ranks'] = {b64decode(k): v for k, v in config['tokenizer_info']['mergeable_ranks'].items()}
    tokenizer_info = config['tokenizer_info']
    tokenizer = TrainedBPETokeniser(**tokenizer_info)

    smodel = SpeakerEncoder(
        weights_fpath=spk_emb_ckpt_path,
        device=device,
        eval=True,
        verbose=False,
    )
    return model.eval(), tokenizer, smodel
class Model(FAMModel):
    def _init_model(self):
        if self.config.init_from == 'safetensors':
            with open(f'{os.path.dirname(self.config.ckpt_path)}/config.json', 'r') as f:
                config = load(f)['stage2']
            self.vocab_sizes = config['model_args']['vocab_sizes']
            self.checkpoint_config = config['config']
            # Undo the base64 encoding applied when the config was serialised to JSON.
            config['meta']['tokenizer']['mergeable_ranks'] = {b64decode(k): v for k, v in config['meta']['tokenizer']['mergeable_ranks'].items()}

            self.meta = config['meta']
            self.load_meta = True
            self.use_bpe_tokenizer = 'stoi' not in self.meta or 'itos' not in self.meta
            self.speaker_cond = self.meta.get('speaker_cond')

            speaker_emb_size = None
            if self.speaker_cond:
                speaker_emb_size = self.meta['speaker_emb_size']

            model_args = config['model_args']
            if 'causal' in self.checkpoint_config and self.checkpoint_config['causal'] is False:
                self._encodec_ctx_window = model_args['block_size']

            gptconf = GPTConfig(**model_args)
            self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size)
            load_model(self.model, self.config.ckpt_path)
        else:
            # Fall back to the upstream .pt loading path (init_from='resume').
            super()._init_model()
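# Which branch runs is controlled entirely by InferenceConfig.init_from: MetaVoiceModel
# below passes init_from='safetensors' (with ckpt_path pointing at second_stage.safetensors)
# when the safetensors export is present, and init_from='resume' otherwise.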
class MetaVoiceModel(FAMTTS):
    """TTS wrapper around the MetaVoice first/second stage models that prefers the
    safetensors export and falls back to the original .pt checkpoints."""

    def __init__(self, model_name: str, *, seed: int = 1337, output_dir: str = 'outputs', enforce_safetensors: bool = True):
        self._dtype = get_default_dtype()
        self._device = get_device()

        required_files = ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')
        if os.path.exists(model_name):
            if enforce_safetensors:
                assert all(x in os.listdir(model_name) for x in required_files), 'Model is not compatible with safetensors'
                self._model_dir = model_name
            else:
                print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
                self._model_dir = model_name
        else:
            if enforce_safetensors:
                fs = HfFileSystem()
                files = [os.path.basename(x) for x in fs.ls(model_name, detail=False)]
                assert all(x in files for x in required_files), 'Model is not compatible with safetensors'
                self._model_dir = snapshot_download(repo_id=model_name, allow_patterns=list(required_files))
            else:
                print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
                self._model_dir = snapshot_download(repo_id=model_name)

        self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        is_safetensors = os.path.exists(f'{self._model_dir}/second_stage.safetensors')
        second_stage_ckpt_path = f'{self._model_dir}/{"second_stage.safetensors" if is_safetensors else "second_stage.pt"}'
        config_second_stage = InferenceConfig(
            ckpt_path=second_stage_ckpt_path,
            num_samples=1,
            seed=seed,
            device=self._device,
            dtype=self._dtype,
            compile=False,
            init_from='safetensors' if is_safetensors else 'resume',
            output_dir=self.output_dir,
        )
        data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.llm_second_stage = Model(
            config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
        )

        self.enhancer = get_enhancer('df')
        self.precision = {'float16': torch.float16, 'bfloat16': torch.bfloat16}[self._dtype]
        build_model_kwargs = {
            'precision': self.precision,
            'device': self._device,
            'compile': False,
            'compile_prefill': True,
        }
        if is_safetensors:
            # Swap in the safetensors-aware loader before build_model resolves the
            # first-stage checkpoint and speaker encoder.
            fam.llm.fast_inference_utils._load_model = load_safetensors_model
            checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.safetensors'), Path(f'{self._model_dir}/speaker_encoder.safetensors')
        else:
            checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.pt'), Path(f'{self._model_dir}/speaker_encoder.pt')

        self.model, self.tokenizer, self.smodel, self.model_size = fam.llm.fast_inference_utils.build_model(
            checkpoint_path=checkpoint_path,
            spk_emb_ckpt_path=spk_emb_ckpt_path,
            **build_model_kwargs
        )

    @torch.inference_mode()
    def generate(self, text: str, source: str = 'https://upload.wikimedia.org/wikipedia/commons/e/e1/King_Charles_Addresses_Scottish_Parliament_-_12_September_2022.flac'):
        return self.synthesise(text, source)

    def save(self, path: str):
        save_model(self.model, os.path.join(path, 'first_stage.safetensors'))
        save_model(self.smodel, os.path.join(path, 'speaker_encoder.safetensors'))
        save_model(self.llm_second_stage.model, os.path.join(path, 'second_stage.safetensors'))

    @classmethod
    def from_hub(cls, path: str):
        return cls(path, enforce_safetensors=True)
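if __name__ == '__main__':
    # Minimal end-to-end sketch. 'metavoiceio/metavoice-1B-v0.1' is the upstream Hugging
    # Face repo id; whether it ships the *.safetensors files required by
    # enforce_safetensors=True is an assumption - otherwise point MetaVoiceModel at a
    # local directory produced by convert_to_safetensors.
    tts = MetaVoiceModel.from_hub('metavoiceio/metavoice-1B-v0.1')
    tts.generate('Hello from MetaVoice running on safetensors weights.')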