import torch
import gc

import os
import time
import random
from typing import Dict, Optional, Sequence, List, Tuple
from transformers.cache_utils import Cache, DynamicCache
from transformers import (
    LlamaModel,
    LlamaForCausalLM,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
import torch.nn.functional as F


def get_jacobian_trajectory(
    model, tokenizer, input_ids, attention_mask, max_new_tokens
):
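    """Run one Jacobi (fixed-point) decoding pass over a block of max_new_tokens.

    The block is seeded with random tokens drawn from the prompt and then
    repeatedly re-decoded in parallel until it stops changing between
    iterations, or an EOS token appears in the confirmed prefix. Confirmed
    tokens are streamed to stdout as they stabilize. Returns the final token
    tensor and the number of iterations used.
    """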
    bsz = input_ids.shape[0]
    prompt_len = [torch.sum(t) for t in attention_mask]
    max_prompt_len = max(prompt_len)
    total_len = max_prompt_len + max_new_tokens

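    # Initialize the trajectory: pad every sequence to total_len, fill the
    # generation slots with tokens sampled (with replacement) from the prompt,
    # then copy the true prompt back over the first prompt_len positions.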
    tokens = torch.full(
        (bsz, total_len), tokenizer.pad_token_id, dtype=torch.long, device=model.device
    )
    for i in range(bsz):
        tokens[i, :] = torch.tensor(
            random.choices(input_ids[i][attention_mask[i] == 1], k=total_len),
            dtype=torch.long,
            device=model.device,
        )
        tokens[i, : prompt_len[i]] = input_ids[i][: prompt_len[i]].to(
            dtype=torch.long, device=model.device
        )

    itr = 0
    next_generation = tokens
    generate_attention_mask = torch.full_like(next_generation, 1).to(model.device)
    accurate_lengths = torch.tensor([prompt_len[i].item()] * bsz, device=model.device)
    prev_len = 0

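    # Jacobi iteration: feed the whole block through the model in parallel,
    # greedily re-decode every position, and repeat until the block no longer
    # changes between consecutive iterations.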
    while True:
        current_generation = next_generation
        with torch.no_grad():
            logits = model(current_generation, generate_attention_mask).logits
        # Greedy selection: dividing the softmax output by a constant does not
        # change the argmax, so this is equivalent to argmax over the logits.
        next_generation = torch.argmax(
            torch.nn.functional.softmax(logits, dim=-1) / 0.001, dim=-1
        )

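        # Keep the prompt fixed and shift the predictions one position to the
        # right, so the next token predicted at position t - 1 lands at t.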
        for i in range(bsz):
            next_generation[i, :] = torch.cat(
                (
                    tokens[i, : prompt_len[i]],
                    next_generation[i, prompt_len[i] - 1 : total_len - 1],
                ),
                dim=0,
            )

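        # Stop once the block has converged after max_new_tokens iterations,
        # or an EOS token already appears in the confirmed prefix.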
        if (
            (
                torch.all(torch.eq(next_generation, current_generation)).item()
                and itr == max_new_tokens
            )
            or len(
                torch.where(
                    current_generation[0, : accurate_lengths[0]]
                    == tokenizer.eos_token_id
                )[0]
            )
            > 0
        ):
            return next_generation, itr

        # Fast-forward: find how far the new iterate agrees with the previous
        # one and treat that prefix as confirmed, streamable output (skipped
        # on the very first iteration).
        if itr != 0:
            if torch.all(torch.eq(next_generation, current_generation)).item():
                matched_position = total_len
            else:
                # First position where the two iterates disagree.
                matched_position = (
                    torch.eq(current_generation, next_generation).squeeze(0) == False
                ).nonzero(as_tuple=True)[0][0]
            # Number of tokens newly confirmed in this iteration.
            fast_forward_cnt = matched_position - accurate_lengths[0]

            for i in range(bsz):
                accurate_lengths[i] = matched_position.item()

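            # Decode the confirmed span and print only the newly added text.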
            generated_str = tokenizer.decode(
                next_generation[0, prompt_len[0] : accurate_lengths[0]],
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            print(generated_str[prev_len:], flush=True, end="")
            prev_len = len(generated_str)

        # Fixed point reached: the block no longer changes, so decoding is done.
        if torch.all(torch.eq(next_generation, current_generation)).item():
            return next_generation, itr

        itr += 1


def generate_stream_cllm(
    model,
    tokenizer,
    params,
    device,
    context_len,
    stream_interval=2,
    judge_sent_end=False,
):
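    """Stream a response with block-wise Jacobi (CLLM-style) decoding.

    The response is produced block by block: each outer pass asks
    get_jacobian_trajectory for a converged block of n_token_seq_length
    tokens, and the accepted text is streamed to stdout as it is confirmed.
    """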
prompt = params["prompt"] |
|
inputs = tokenizer(prompt, return_tensors="pt").to(device) |
|
max_new_tokens = int(params.get("n_token_seq_length", 32)) |
|
max_new_seq_len = int(params.get("max_new_tokens", 1024)) |
|
|
|
    prompt_len = torch.sum(inputs["attention_mask"], dim=-1)
    generation = inputs["input_ids"]
    # Number of prompt tokens; generation has shape [batch, seq_len], so
    # len(generation) would give the batch size instead.
    input_echo_len = generation.shape[-1]

    itr = 0
    eos_reached = False
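    # Outer loop: each pass decodes one block of max_new_tokens with the
    # Jacobi iteration, and the converged output becomes the next pass's input.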
    while True:
        if itr == 0:
            input_ids = inputs["input_ids"]
            input_masks = inputs["attention_mask"]
        else:
            # Mask out positions beyond the prompt plus the blocks generated
            # so far.
            input_masks = torch.ones_like(input_ids).to(device)
            for j in range(bsz):
                input_masks[j][
                    torch.sum(inputs["attention_mask"], dim=-1)[j]
                    + itr * max_new_tokens :
                ] = 0

        bsz = input_ids.shape[0]
        eos_reached = torch.tensor([False] * bsz, device=device)

        generation, iter_steps = get_jacobian_trajectory(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            attention_mask=input_masks,
            max_new_tokens=max_new_tokens,
        )

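        # Inspect each sequence in the converged block for an EOS token and
        # pad out everything that follows it.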
        for j in range(bsz):
            prompt_len = torch.sum(input_masks, dim=-1)
            eos_positions = torch.where(generation[j] == tokenizer.eos_token_id)[0]

            if len(eos_positions) == 0:
                # No EOS yet: pad the tail beyond this block and keep going.
                generation[j][prompt_len[j] + max_new_tokens :] = tokenizer.pad_token_id
                continue

            eos_reached[j] = True
            generation[j, int(eos_positions[0]) + 1 :] = tokenizer.pad_token_id

        itr += 1

        if all(eos_reached) or itr * max_new_tokens >= max_new_seq_len:
            break
        # Keep generating only for sequences that have not emitted EOS yet.
        input_ids = generation[
            torch.where(eos_reached == False)[0].tolist(), ...
        ]

    if all(eos_reached):
        finish_reason = "eos"
    elif itr * max_new_tokens >= max_new_seq_len:
        # Match the loop's exit test so hitting the budget exactly still
        # reports a length-limited finish.
        finish_reason = "length"
    else:
        finish_reason = "stop"

    output = tokenizer.decode(input_ids[0], skip_special_tokens=False)

    # The text itself has already been streamed to stdout inside
    # get_jacobian_trajectory, so the final chunk only reports token usage
    # and the finish reason.
    yield {
        "text": "",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": itr * max_new_tokens,
            "total_tokens": input_echo_len + itr * max_new_tokens,
        },
        "finish_reason": finish_reason,
    }

    # Release cached accelerator memory after the request finishes.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()