# metavoice-1B-v0.1 / metavoice.py
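"""Safetensors conversion and loading utilities for the metavoice-1B TTS model.

Wraps the `fam` inference classes so that the first-stage, second-stage and
speaker-encoder checkpoints can be saved to and loaded from .safetensors files.
"""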
import os
from base64 import b64decode, b64encode
from json import dump, load
from pathlib import Path
from typing import Optional, Union

import torch
from torch import nn
from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_model, save_model

import fam.llm.fast_inference_utils
from fam.llm.adapters.flattened_encodec import FlattenedInterleavedEncodec2Codebook
from fam.llm.adapters.tilted_encodec import TiltedEncodec
from fam.llm.decoders import EncodecDecoder
from fam.llm.enhancers import get_enhancer
from fam.llm.fast_inference import TTS as FAMTTS
from fam.llm.fast_model import Transformer
from fam.llm.inference import InferenceConfig
from fam.llm.inference import Model as FAMModel
from fam.llm.model import GPT, GPTConfig
from fam.llm.utils import get_default_dtype, get_device
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder as FAMSpeakerEncoder
from fam.quantiser.audio.speaker_encoder.model import (
    mel_n_channels,
    model_embedding_size,
    model_hidden_size,
    model_num_layers,
)
from fam.quantiser.text.tokenise import TrainedBPETokeniser
def convert_to_safetensors(
    stage1_path: str,
    stage2_path: str,
    spk_emb_ckpt_path: str,
    precision: torch.dtype,
    output_path: str,
):
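    """Convert the original .pt checkpoints (first stage, second stage and speaker
    encoder) to .safetensors files, and write a combined config.json, in output_path."""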
    config_second_stage = InferenceConfig(
        ckpt_path=stage2_path,
        num_samples=1,
        seed=0,
        device='cpu',
        dtype='float16' if precision == torch.float16 else 'bfloat16',
        compile=False,
        init_from='resume',
        output_dir='.',
    )
    data_adapter_second_stage = TiltedEncodec(end_of_audio_token=512)
    stage2_model = Model(config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode)
    stage2_checkpoint = torch.load(stage2_path, map_location='cpu')
    stage2_state_dict = stage2_checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    # Iterate over a copy of the keys: popping entries while iterating the live view would raise.
    for k in list(stage2_state_dict.keys()):
        if k.startswith(unwanted_prefix):
            stage2_state_dict[k[len(unwanted_prefix):]] = stage2_state_dict.pop(k)
    save_model(stage2_model.model, os.path.join(output_path, 'second_stage.safetensors'))
    stage1_model, tokenizer, smodel = fam.llm.fast_inference_utils._load_model(stage1_path, spk_emb_ckpt_path, 'cpu', precision)
    tokenizer_info = torch.load(stage1_path, map_location='cpu').get('meta', {}).get('tokenizer', {})
    save_model(stage1_model, os.path.join(output_path, 'first_stage.safetensors'))
    save_model(smodel, os.path.join(output_path, 'speaker_encoder.safetensors'))
    with open(os.path.join(output_path, 'config.json'), 'w') as f:
        # BPE merge-table keys are raw bytes, which JSON cannot store, so they are
        # base64-encoded here and decoded again at load time.
        tokenizer_info['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in tokenizer_info['mergeable_ranks'].items()}
        stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'] = {b64encode(k).decode('ascii'): v for k, v in stage2_checkpoint['meta']['tokenizer']['mergeable_ranks'].items()}
        dump({
            'model_name': 'metavoice-1B-v0.1',
            'stage1': {
                'tokenizer_info': tokenizer_info
            },
            'stage2': {
                'config': stage2_checkpoint['config'],
                'meta': stage2_checkpoint['meta'],
                'model_args': stage2_checkpoint['model_args']
            }
        }, f)
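# Illustrative invocation of the converter (a sketch only; the checkpoint paths below
# are assumptions, not files shipped with this repo):
#
#     convert_to_safetensors(
#         stage1_path='first_stage.pt',
#         stage2_path='second_stage.pt',
#         spk_emb_ckpt_path='speaker_encoder.pt',
#         precision=torch.bfloat16,
#         output_path='.',
#     )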
class SpeakerEncoder(FAMSpeakerEncoder):
    def __init__(
        self,
        weights_fpath: str,
        device: Optional[Union[str, torch.device]] = None,
        verbose: bool = True,
        eval: bool = False,
    ):
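        """Build the speaker-encoder network and load its weights either from a
        .safetensors file or from a regular torch checkpoint."""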
        nn.Module.__init__(self)

        # Define the network
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

        # Get the target device
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        weights_fpath = str(weights_fpath)
        if weights_fpath.endswith('.safetensors'):
            load_model(self, weights_fpath)
        else:
            checkpoint = torch.load(weights_fpath, map_location='cpu')
            self.load_state_dict(checkpoint['model_state'], strict=False)
        self.to(device)

        if eval:
            self.eval()
def load_safetensors_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
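    """Drop-in replacement for fam.llm.fast_inference_utils._load_model that reads the
    first-stage Transformer, tokenizer config and speaker encoder from safetensors files."""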
    ##### MODEL
    with torch.device(device):
        model = Transformer.from_name('metavoice-1B')
    load_model(model, checkpoint_path)
    model = model.to(device=device, dtype=precision)

    ##### TOKENIZER
    with open(f'{os.path.dirname(checkpoint_path)}/config.json', 'r') as f:
        config = load(f)['stage1']
    config['tokenizer_info']['mergeable_ranks'] = {b64decode(k): v for k, v in config['tokenizer_info']['mergeable_ranks'].items()}
    tokenizer_info = config['tokenizer_info']
    tokenizer = TrainedBPETokeniser(**tokenizer_info)

    ##### SPEAKER EMBEDDER
    smodel = SpeakerEncoder(
        weights_fpath=spk_emb_ckpt_path,
        device=device,
        eval=True,
        verbose=False,
    )
    return model.eval(), tokenizer, smodel
class Model(FAMModel):
    def _init_model(self):
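        """When init_from == 'safetensors', build the second-stage GPT from config.json and
        safetensors weights before handing over to the parent class; otherwise the parent
        loads the torch checkpoint as usual."""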
        if self.config.init_from == 'safetensors':
            with open(f'{os.path.dirname(self.config.ckpt_path)}/config.json', 'r') as f:
                config = load(f)['stage2']
            self.vocab_sizes = config['model_args']['vocab_sizes']
            self.checkpoint_config = config['config']
            config['meta']['tokenizer']['mergeable_ranks'] = {b64decode(k): v for k, v in config['meta']['tokenizer']['mergeable_ranks'].items()}
            self.meta = config['meta']
            self.load_meta = True
            self.use_bpe_tokenizer = 'stoi' not in self.meta or 'itos' not in self.meta
            self.speaker_cond = self.meta.get('speaker_cond')

            speaker_emb_size = None
            if self.speaker_cond:
                speaker_emb_size = self.meta['speaker_emb_size']
            model_args = config['model_args']
            if 'causal' in self.checkpoint_config and self.checkpoint_config['causal'] is False:
                self._encodec_ctx_window = model_args['block_size']

            gptconf = GPTConfig(**model_args)
            self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size)
            load_model(self.model, self.config.ckpt_path)
        super()._init_model()
class MetaVoiceModel(FAMTTS):
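    """FAMTTS wrapper that loads the two-stage MetaVoice pipeline from safetensors
    checkpoints, either from a local directory or from a Hugging Face Hub repo."""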
    def __init__(self, model_name: str, *, seed: int = 1337, output_dir: str = 'outputs', enforce_safetensors: bool = True):
        self._dtype = get_default_dtype()
        self._device = get_device()

        if os.path.exists(model_name):
            if enforce_safetensors:
                assert all(x in os.listdir(model_name) for x in ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')), 'Model is not compatible with safetensors'
                self._model_dir = model_name
            else:
                print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
                self._model_dir = model_name
        else:
            if enforce_safetensors:
                fs = HfFileSystem()
                files = [os.path.basename(x) for x in fs.ls(model_name, detail=False)]
                assert all(x in files for x in ('config.json', 'second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors')), 'Model is not compatible with safetensors'
                # allow_patterns expects a list of patterns; download only the safetensors files and the config.
                self._model_dir = snapshot_download(
                    repo_id=model_name,
                    allow_patterns=['second_stage.safetensors', 'first_stage.safetensors', 'speaker_encoder.safetensors', 'config.json'],
                )
            else:
                print('WARNING: metavoice is allowing the use of non-safetensors models. Ensure you understand the risks of loading untrusted models at https://pytorch.org/docs/stable/generated/torch.load.html')
                self._model_dir = snapshot_download(repo_id=model_name)

        self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        is_safetensors = os.path.exists(f'{self._model_dir}/second_stage.safetensors')
        second_stage_ckpt_path = f'{self._model_dir}/{"second_stage.safetensors" if is_safetensors else "second_stage.pt"}'
        config_second_stage = InferenceConfig(
            ckpt_path=second_stage_ckpt_path,
            num_samples=1,
            seed=seed,
            device=self._device,
            dtype=self._dtype,
            compile=False,
            init_from='safetensors' if is_safetensors else 'resume',
            output_dir=self.output_dir,
        )
        data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.llm_second_stage = Model(
            config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
        )
        self.enhancer = get_enhancer('df')

        self.precision = {'float16': torch.float16, 'bfloat16': torch.bfloat16}[self._dtype]
        build_model_kwargs = {
            'precision': self.precision,
            'device': self._device,
            'compile': False,
            'compile_prefill': True,
        }
        if is_safetensors:
            # Patch the loader so that build_model reads safetensors instead of calling torch.load.
            fam.llm.fast_inference_utils._load_model = load_safetensors_model
            checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.safetensors'), Path(f'{self._model_dir}/speaker_encoder.safetensors')
        else:
            checkpoint_path, spk_emb_ckpt_path = Path(f'{self._model_dir}/first_stage.pt'), Path(f'{self._model_dir}/speaker_encoder.pt')
        self.model, self.tokenizer, self.smodel, self.model_size = fam.llm.fast_inference_utils.build_model(
            checkpoint_path=checkpoint_path,
            spk_emb_ckpt_path=spk_emb_ckpt_path,
            **build_model_kwargs,
        )
    @torch.inference_mode()
    def generate(self, text: str, source: str = 'https://upload.wikimedia.org/wikipedia/commons/e/e1/King_Charles_Addresses_Scottish_Parliament_-_12_September_2022.flac'):
        # Delegate to FAMTTS.synthesise and pass its result through.
        return self.synthesise(text, source)
    def save(self, path: str):
        save_model(self.model, os.path.join(path, 'first_stage.safetensors'))
        save_model(self.smodel, os.path.join(path, 'speaker_encoder.safetensors'))
        save_model(self.llm_second_stage.model, os.path.join(path, 'second_stage.safetensors'))
    @classmethod
    def from_hub(cls, path: str):
        # TODO: TEMPORARY OUTPUT DIR
        return cls(path, enforce_safetensors=True)
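
# Minimal usage sketch (not part of the upstream fam API): it assumes a local
# directory containing config.json and the three .safetensors files produced by
# convert_to_safetensors(); swap in a Hub repo id to download instead.
if __name__ == '__main__':
    tts = MetaVoiceModel('metavoice-1B-safetensors', enforce_safetensors=True)  # hypothetical local path
    tts.generate('This is a test of MetaVoice running from safetensors weights.')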