import numpy as np
import torch
import time
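# Inference utilities for Spark-TTS: build prompt-conditioned input embeddings,
# run two-stage (prefill + decode) autoregressive generation of BiCodec
# semantic tokens with repetition-aware sampling, and (in __main__) detokenize
# the result back to a waveform.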
def generate_embeddings(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None):
"""
为 Spark LLM 生成预测所需的输入嵌入
Args:
model: Spark LLM 模型
tokenizer: 文本分词器
text: 要生成语音的文本
bicodec: BiCodecTokenizer 实例
prompt_text: 提示文本(可选)
prompt_audio: 提示音频数组(可选)
Returns:
dict: 包含 input_embs 的字典,用于模型预测
"""
device = next(model.parameters()).device
with torch.no_grad():
        # 1. Process the prompt audio and extract global_tokens and semantic_tokens
        if prompt_audio is not None:
            # Make sure the audio data is float32
            audio_data = np.array(prompt_audio, dtype=np.float32)
            target_sample_rate = bicodec.config['sample_rate']
            # Check whether resampling is needed.
            # Note: prompt_audio is assumed to have been loaded with soundfile and
            # resampled by the caller; BiCodecTokenizer expects 16 kHz audio.
            print(f"BiCodecTokenizer expects a sample rate of {target_sample_rate}Hz")
            print(f"Audio data shape: {audio_data.shape}")
            # Extract tokens with BiCodec (returned in order: global_tokens, semantic_tokens)
global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
else:
global_tokens = []
semantic_tokens = []
        # 2. Process the text
        if prompt_text is not None:
            # Concatenate the prompt text and the target text
            full_text = prompt_text + text
            # The initial semantic tokens are those extracted from prompt_audio
            initial_semantic_tokens = semantic_tokens.copy()
        else:
            full_text = text
            initial_semantic_tokens = []
        # 3. Get the text tokens
text_tokens = tokenizer.encode(full_text, add_special_tokens=False)
        # 4. Convert to tensors
text_tokens_tensor = torch.tensor(text_tokens, dtype=torch.long, device=device)
global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
semantic_tokens_tensor = torch.tensor(initial_semantic_tokens, dtype=torch.long, device=device)
        # 5. Look up the embeddings
text_embs = model.text_embedder(text_tokens_tensor)
global_embs = model.global_embedder(global_tokens_tensor)
semantic_embs = model.model.embeddings(semantic_tokens_tensor)
        # 6. Look up the special TTS tag embeddings
tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))
        # 7. Concatenate the embeddings: [tag_2] text [tag_0] global [tag_1] semantic
input_embs = torch.cat([
tag_2_emb,
text_embs,
tag_0_emb,
global_embs,
tag_1_emb,
semantic_embs
], dim=0)
        # 8. Add the batch dimension
input_embs = input_embs.unsqueeze(0) # [1, seq_len, hidden_size]
return {
"input_embs": input_embs,
"global_tokens": global_tokens_tensor,
}
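# Minimal usage sketch for the single-text path (not executed; names mirror the
# __main__ block below, and `prompt_audio` is assumed to be a float32 numpy
# array already at BiCodec's sample rate):
#
#     embs = generate_embeddings(model, tokenizer, text="...",
#                                bicodec=audio_tokenizer,
#                                prompt_text=prompt_text,
#                                prompt_audio=prompt_audio)
#     # embs["input_embs"] has shape [1, seq_len, hidden_size]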
def generate_embeddings_batch(model, tokenizer, texts, bicodec, prompt_text=None, prompt_audio=None):
"""
为 Spark LLM 批量生成预测所需的输入嵌入,支持多个文本的并行处理
Args:
model: Spark LLM 模型
tokenizer: 文本分词器
texts: 要生成语音的文本列表
bicodec: BiCodecTokenizer 实例
prompt_text: 提示文本(可选)
prompt_audio: 提示音频数组(可选)
Returns:
tuple: (embeddings_dict, attention_mask) 包含批量 input_embs 的字典和注意力掩码
"""
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
batch_size = len(texts)
with torch.no_grad():
        # 1. Process the prompt audio and extract global_tokens and semantic_tokens
        if prompt_audio is not None:
            # Make sure the audio data is float32
            audio_data = np.array(prompt_audio, dtype=np.float32)
            target_sample_rate = bicodec.config['sample_rate']
            print(f"BiCodecTokenizer expects a sample rate of {target_sample_rate}Hz")
            print(f"Audio data shape: {audio_data.shape}")
            # Extract tokens with BiCodec (returned in order: global_tokens, semantic_tokens)
global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
else:
global_tokens = []
semantic_tokens = []
        # 2. Process every text and collect its embedding components
all_input_embs = []
all_attention_masks = []
for text in texts:
            # Process one text
if prompt_text is not None:
full_text = prompt_text + text
initial_semantic_tokens = semantic_tokens.copy()
else:
full_text = text
initial_semantic_tokens = []
            # Get the text tokens
text_tokens = tokenizer.encode(full_text, add_special_tokens=False)
            # Convert to tensors
text_tokens_tensor = torch.tensor(text_tokens, dtype=torch.long, device=device)
global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
semantic_tokens_tensor = torch.tensor(initial_semantic_tokens, dtype=torch.long, device=device)
            # Look up the embeddings
text_embs = model.text_embedder(text_tokens_tensor)
global_embs = model.global_embedder(global_tokens_tensor)
semantic_embs = model.model.embeddings(semantic_tokens_tensor)
            # Look up the special TTS tag embeddings
tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))
            # Concatenate the embeddings
input_embs = torch.cat([
tag_2_emb,
text_embs,
tag_0_emb,
global_embs,
tag_1_emb,
semantic_embs
], dim=0) # [seq_len, hidden_size]
all_input_embs.append(input_embs)
all_attention_masks.append(torch.ones(input_embs.shape[0], dtype=torch.long, device=device))
        # 3. Find the maximum sequence length
max_seq_len = max(emb.shape[0] for emb in all_input_embs)
hidden_size = all_input_embs[0].shape[1]
        # 4. Create the batch tensors, using left padding with zeros. Left padding
        # keeps the last position of every sequence aligned, so the decode loop
        # can read logits[:, -1, :] for the whole batch.
batch_input_embs = torch.zeros(batch_size, max_seq_len, hidden_size, device=device, dtype=dtype)
batch_attention_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.long, device=device)
for i, (input_embs, attention_mask) in enumerate(zip(all_input_embs, all_attention_masks)):
seq_len = input_embs.shape[0]
            # Left padding: place the sequence on the right, zeros on the left
batch_input_embs[i, -seq_len:, :] = input_embs
batch_attention_mask[i, -seq_len:] = attention_mask
        # 5. Create a batched version of global_tokens
global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device, requires_grad=False)
batch_global_tokens = global_tokens_tensor.unsqueeze(0).expand(batch_size, -1)
return {
"input_embs": batch_input_embs,
"global_tokens": batch_global_tokens,
}, batch_attention_mask
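# Left-padding illustration (hypothetical lengths 3 and 5, so max_seq_len = 5):
#   attention_mask row 0: [0, 0, 1, 1, 1]
#   attention_mask row 1: [1, 1, 1, 1, 1]
# Every row ends on a real token, which is what lets the decode loop in
# generate() below sample from logits[:, -1, :] for the whole batch.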
# Repetition Aware Sampling in VALL-E 2
def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
if rep_num >= win_size * tau_r:
top_ids = random_sampling(weighted_scores)
return top_ids
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
prob, indices = [], []
cum_prob = 0.0
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
for i in range(len(sorted_idx)):
        # Keep tokens while cumulative probability is below top_p and count is below top_k
        if cum_prob < top_p and len(prob) < top_k:
cum_prob += sorted_value[i]
prob.append(sorted_value[i])
indices.append(sorted_idx[i])
else:
break
    prob = torch.stack(prob).to(weighted_scores)
    indices = torch.stack(indices).to(weighted_scores.device)
top_ids = indices[prob.multinomial(1, replacement=True)]
return top_ids
def random_sampling(weighted_scores):
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
return top_ids
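# Illustrative call (values made up): with decoded history [7, 7, 7, 7] and the
# defaults win_size=10, tau_r=0.1, a nucleus-sampled id of 7 gives
# rep_num = 4 >= win_size * tau_r = 1, so ras_sampling falls back to
# random_sampling over the full distribution to break the repetition loop:
#
#     next_id = ras_sampling(logits, [7, 7, 7, 7], top_p=0.8, top_k=25)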
def generate(model,
             inputs_embeds,
             attention_mask,
             new_max_tokens,
             top_k,
             top_p,
             temperature,
             eos_token_id,
             pad_token_id,
             past_key_values
             ):
"""
seperate two stages of generation:
1. prefill
2. decode
we will measure the time of each stage and the total time
"""
start_time = time.time()
model.eval()
batch_size = inputs_embeds.shape[0]
decoded_tokens = [[] for _ in range(batch_size)]
is_decoding = [True for _ in range(batch_size)]
with torch.no_grad():
# 1. prefill
outputs = model.model.forward(
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=True,
output_attentions=False,
output_hidden_states=True,
return_dict=False
)
hidden_states = outputs[0]
past_key_values = outputs[1]
prefill_time = time.time() - start_time
        num_prefill_tokens = attention_mask.shape[0] * attention_mask.shape[1]
        print(f"Prefill time: {prefill_time} seconds, total tokens: {num_prefill_tokens}, speed: {num_prefill_tokens/prefill_time} tokens/s")
# 2. decode
start_time = time.time()
        # Sample the logits using top_k, top_p, and temperature
decoded_tokens_size = 0
while True:
logits = model.lm_head(hidden_states)
last_time_decoded = []
logits = logits[:, -1, :]
continue_decoding = False
for i in range(batch_size):
if is_decoding[i]:
                    # Apply temperature scaling before sampling
                    logits_i = logits[i, :] / temperature
                    top_ids = ras_sampling(logits_i, decoded_tokens[i], top_p=top_p, top_k=top_k).item()
decoded_tokens[i].append(top_ids)
last_time_decoded.append([top_ids])
if top_ids == eos_token_id:
is_decoding[i] = False
else:
continue_decoding = True
decoded_tokens_size += 1
else:
decoded_tokens[i].append(pad_token_id)
last_time_decoded.append([pad_token_id])
            # Stop once every sequence has emitted EOS or the token budget is spent
            if not continue_decoding or len(decoded_tokens[0]) >= new_max_tokens:
                break
last_time_decoded = torch.tensor(last_time_decoded, dtype=torch.long, device=device)
lm_input = model.get_input_embeddings()(last_time_decoded)
outputs = model.model.forward(
inputs_embeds=lm_input,
past_key_values=past_key_values,
use_cache=True,
output_attentions=False,
output_hidden_states=True,
return_dict=False
)
hidden_states = outputs[0]
past_key_values = outputs[1]
decode_time = time.time() - start_time
print(f"Decode time: {decode_time} seconds, all tokens is {decoded_tokens_size}, speed is {decoded_tokens_size/decode_time} tokens/s ")
print(f"decoded_tokens: {decoded_tokens}")
return decoded_tokens, past_key_values
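# Sketch of the call pattern (mirrors the __main__ block below; `embs` and
# `mask` would come from generate_embeddings_batch). The first forward consumes
# the full prompt and fills the KV cache (prefill); each later forward feeds
# only the newly sampled token ids with past_key_values carrying the cached
# attention states (decode):
#
#     tokens, kv = generate(model, embs, mask, new_max_tokens=1000,
#                           top_k=50, top_p=0.8, temperature=1.0,
#                           eos_token_id=eos, pad_token_id=pad,
#                           past_key_values=None)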
if __name__ == "__main__":
import os
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
print('add current dir to sys.path', current_dir)
sys.path.append(current_dir)
device = 'cuda:2'
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
import soundfile as sf
audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
print(audio_tokenizer)
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
print(tokenizer)
print(model)
model = model.bfloat16().to(device)
model.eval()
prompt_text = "我们并不是通过物理移动手段找到星河的。"
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
prompt_audio, sampling_rate = sf.read(prompt_audio_file)
print(f"Loaded prompt audio from {prompt_audio_file}")
print(f"Original sampling rate: {sampling_rate}Hz")
print(f"Audio shape: {prompt_audio.shape}")
target_sample_rate = audio_tokenizer.config['sample_rate']
if sampling_rate != target_sample_rate:
print(f"Resampling from {sampling_rate}Hz to {target_sample_rate}Hz...")
from librosa import resample
prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
prompt_audio = np.array(prompt_audio, dtype=np.float32)
print(f"Resampled audio shape: {prompt_audio.shape}")
else:
print(f"Audio sampling rate already matches target ({target_sample_rate}Hz)")
texts = ["为了点燃青少年对科技的热情,培养他们的创新思维与动手能力,杏花岭区巨轮街道社区教育学校携手中车社区教育分校,与太原市科学技术协会联手,于暑期精心策划了一场别开生面的青少年数智技术服务港探索之旅,吸引了众多社区青少年的积极参与。"]
eos_token_id = model.config.vocab_size - 1
print(f"EOS token ID: {eos_token_id}")
    # Generate the input embeddings
    embeddings, attention_mask = generate_embeddings_batch(
model=model,
tokenizer=tokenizer,
texts=texts,
bicodec=audio_tokenizer,
prompt_text=prompt_text,
prompt_audio=prompt_audio
)
input_embs = embeddings['input_embs']
print(f"input_embs shape: {input_embs.shape}")
print(f"attention_mask shape: {attention_mask.shape}")
print(f"input_embs dtype: {input_embs.dtype}")
print(f"attention_mask dtype: {attention_mask.dtype}")
print(f"input_embs: {input_embs}")
print(f"attention_mask: {attention_mask}")
print(f"input_embs: {input_embs}")
with torch.no_grad():
        generate(model,
                 input_embs,
                 attention_mask,
                 new_max_tokens=1000,
                 top_k=25,
                 top_p=0.95,
                 temperature=1.0,
                 eos_token_id=eos_token_id,
                 pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                 past_key_values=None)
with torch.no_grad():
        audio_tokens, past_key_values = generate(model,
                 input_embs,
                 attention_mask,
                 new_max_tokens=1000,
                 top_k=50,
                 top_p=0.8,
                 temperature=1.0,
                 eos_token_id=eos_token_id,
                 pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                 past_key_values=None)
audio_tokens = torch.tensor(audio_tokens, dtype=torch.long, device=device)
    audio_tokens = audio_tokens[:, :-1]
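    # The slice above drops the trailing EOS token; this is safe here because
    # the batch contains a single text, so every row ends with EOS.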
print(f"audio_tokens: {audio_tokens}")
print(f"past_key_values: {past_key_values}")
global_tokens = embeddings['global_tokens']
print(f"global_tokens shape: {global_tokens.shape}")
print(f"audio_tokens shape: {audio_tokens.shape}")
with torch.no_grad():
wav = audio_tokenizer.detokenize(global_tokens, audio_tokens)
print(f"wav shape: {wav.shape}")
    sf.write('test.wav', wav, audio_tokenizer.config['sample_rate'])