# Page-header residue from the file viewer: "Spaces: Running on Zero"
# File size: 8,849 Bytes; revision 164603c
import torch
from nemo.collections.tts.models import AudioCodecModel
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
@dataclass
class Config:
    """Default settings for the codec player and the language-model wrapper."""

    # Identifiers handed to the ``from_pretrained`` loaders, plus the
    # device placement option used when loading the language model.
    model_name: str = "nineninesix/lfm-nano-codec-tts-exp-4-large-61468-st"
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    device_map: str = "auto"

    # Token-id layout: the speech/human/ai special tokens and the audio
    # tokens are placed at fixed offsets above ``tokeniser_length``.
    tokeniser_length: int = 64400
    start_of_text: int = 1
    end_of_text: int = 2

    # Sampling parameters forwarded to ``generate``.
    max_new_tokens: int = 2000
    temperature: float = 0.6
    top_p: float = 0.95
    repetition_penalty: float = 1.1
class NemoAudioPlayer:
    """Decode language-model output ids into a waveform with a NeMo audio codec.

    The generated id sequence is expected to contain a span delimited by the
    start-of-speech and end-of-speech tokens. Inside that span the ids are
    interleaved audio codes, ``_NUM_CODEBOOKS`` per frame, where the code for
    codebook ``i`` is stored as
    ``audio_tokens_start + i * codebook_size + code``.
    """

    # Number of codebooks interleaved per audio frame in the token stream.
    _NUM_CODEBOOKS = 4

    def __init__(self, config, text_tokenizer_name: str = None) -> None:
        """Load the codec model and derive the special-token ids.

        Args:
            config: Object providing ``audiocodec_name``, ``tokeniser_length``,
                ``start_of_text`` and ``end_of_text``.
            text_tokenizer_name: Optional tokenizer identifier. When given, a
                tokenizer is loaded so ``get_waveform`` can also return the
                decoded text span.
        """
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")

        # Load NeMo codec model
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)

        # Always define the attribute so it exists even without a tokenizer.
        self.tokenizer = None
        self.text_tokenizer_name = text_tokenizer_name
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Token configuration: special tokens sit at fixed offsets above the
        # text tokenizer's vocabulary length.
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032

    @staticmethod
    def _first_index(out_ids, token_id):
        """Return the index of the first ``token_id`` in 1-D ``out_ids``, or None."""
        positions = (out_ids == token_id).nonzero(as_tuple=True)[0]
        if positions.numel() == 0:
            return None
        return int(positions[0].item())

    def output_validation(self, out_ids):
        """Validate that output contains required speech tokens.

        Raises:
            ValueError: If the start- or end-of-speech token is absent.
        """
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids
        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')
        print("Output validation passed - speech tokens found")

    def get_nano_codes(self, out_ids):
        """Extract nano codec tokens from model output.

        Args:
            out_ids: Tensor of generated token ids; it is flattened first.

        Returns:
            ``(audio_codes, len_)`` where ``audio_codes`` has shape
            ``(1, _NUM_CODEBOOKS, frames)`` and ``len_`` is a one-element
            tensor holding the frame count.

        Raises:
            ValueError: If the speech delimiters are missing or out of order,
                the span is empty or not a whole number of frames, or any
                code falls outside ``[0, codebook_size)``.
        """
        out_ids = out_ids.flatten()

        # Use the first occurrence of each delimiter; ``.item()`` on the full
        # match tensor fails when a token is absent or repeated.
        start_a_idx = self._first_index(out_ids, self.start_of_speech)
        end_a_idx = self._first_index(out_ids, self.end_of_speech)
        if start_a_idx is None or end_a_idx is None:
            raise ValueError('Speech start/end tokens not found!')
        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')

        audio_codes = out_ids[start_a_idx + 1: end_a_idx]
        if len(audio_codes) == 0:
            # reshape(-1, n) is ambiguous for zero elements, so reject early.
            raise ValueError('No audio codes between speech start/end tokens!')
        if len(audio_codes) % self._NUM_CODEBOOKS:
            raise ValueError('Audio codes length must be multiple of 4!')
        audio_codes = audio_codes.reshape(-1, self._NUM_CODEBOOKS)

        # Decode audio codes: remove the per-codebook offset and the global
        # audio-token offset. The offsets live on the same device as the ids.
        offsets = self.codebook_size * torch.arange(
            self._NUM_CODEBOOKS, device=audio_codes.device
        )
        audio_codes = audio_codes - offsets
        audio_codes = audio_codes - self.audio_tokens_start

        # Every code must index into its codebook: check both bounds.
        out_of_range = (audio_codes < 0) | (audio_codes >= self.codebook_size)
        if out_of_range.any().item():
            raise ValueError('Invalid audio tokens detected!')

        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]], device=audio_codes.device)
        print(f"Extracted audio codes shape: {audio_codes.shape}")
        return audio_codes, len_

    def get_text(self, out_ids):
        """Extract text from model output.

        Decodes the ids from the first start-of-text token through the first
        end-of-text token (inclusive) with the loaded tokenizer.

        Raises:
            ValueError: If either text delimiter is missing.
        """
        out_ids = out_ids.flatten()
        start_t_idx = self._first_index(out_ids, self.start_of_text)
        end_t_idx = self._first_index(out_ids, self.end_of_text)
        if start_t_idx is None or end_t_idx is None:
            raise ValueError('Text start/end tokens not found!')

        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        text = self.tokenizer.decode(txt_tokens, skip_special_tokens=True)
        return text

    def get_waveform(self, out_ids):
        """Convert model output to audio waveform.

        Returns:
            ``(output_audio, text)``: the decoded audio as a NumPy array, and
            the decoded text span, or ``None`` when no tokenizer name was
            configured.

        Raises:
            ValueError: Propagated from validation and code extraction.
        """
        out_ids = out_ids.flatten()
        print("Starting waveform generation...")

        # Validate output
        self.output_validation(out_ids)

        # Extract audio codes
        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)

        print("Decoding audio with NeMo codec...")
        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
        output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()
        print(f"Generated audio shape: {output_audio.shape}")

        if self.text_tokenizer_name:
            text = self.get_text(out_ids)
            return output_audio, text
        else:
            return output_audio, None
class KaniModel:
    """Text-to-speech pipeline: prompt text -> generated token ids -> audio.

    Wraps a causal language model and its tokenizer, and delegates the
    conversion of generated ids to audio to a ``NemoAudioPlayer``.
    """

    def __init__(self, config, player: "NemoAudioPlayer", token: str) -> None:
        """Load the language model and tokenizer named in ``config``.

        Args:
            config: Object providing ``model_name``, ``device_map`` and the
                sampling parameters read by ``model_request``.
            player: Supplies the special-token ids and decodes output to audio.
            token: Access token forwarded to both ``from_pretrained`` calls.
        """
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")

        # Load model with proper configuration.
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # with the model repository; only use with repositories you trust.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            trust_remote_code=True  # May be needed for some models
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )
        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare input tokens for the model.

        The tokenized prompt is wrapped as
        ``[START_OF_HUMAN] + prompt ids + [END_OF_TEXT, END_OF_HUMAN]``.

        Returns:
            ``(input_ids, attention_mask)``, both of shape ``(1, seq_len)``;
            the mask is all ones.
        """
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human

        # Tokenize input text
        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids

        # Add special tokens
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)

        # Concatenate tokens
        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)

        print(f"Input sequence length: {modified_input_ids.shape[1]}")
        return modified_input_ids, attention_mask

    def model_request(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Generate tokens using the model.

        Sampling stops at the player's end-of-speech token.

        Returns:
            The generated ids, moved to the CPU.
        """
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Fall back to the eos id only when no pad id is defined. A plain
        # truthiness test would wrongly discard a valid pad id of 0.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        print("Starting model generation...")
        print(f"Generation parameters: max_tokens={self.conf.max_new_tokens}, "
              f"temp={self.conf.temperature}, top_p={self.conf.top_p}")

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.conf.max_new_tokens,
                do_sample=True,
                temperature=self.conf.temperature,
                top_p=self.conf.top_p,
                repetition_penalty=self.conf.repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=pad_token_id
            )

        print(f"Generated sequence length: {generated_ids.shape[1]}")
        return generated_ids.to('cpu')

    def run_model(self, text: str):
        """Complete pipeline: text -> tokens -> generation -> audio.

        Returns:
            ``(audio, text)``: the audio from the player and the input text
            unchanged.
        """
        print(f"Processing text: '{text[:50]}{'...' if len(text) > 50 else ''}'")

        # Prepare input
        input_ids, attention_mask = self.get_input_ids(text)

        # Generate tokens
        model_output = self.model_request(input_ids, attention_mask)

        # Convert to audio
        audio, _ = self.player.get_waveform(model_output)

        print("Text-to-speech generation completed successfully!")
        return audio, text