import gc
import time
import uuid
from threading import Thread
from types import MethodType
from typing import Iterable, Dict, Any
import torch
from transformers import (
    TextIteratorStreamer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from api.generation.qwen import check_is_qwen
from api.generation.utils import (
    prepare_logits_processor,
    is_partial_stop,
    apply_stopping_strings,
)

@torch.inference_mode()
def generate_stream(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
):
    # Read parameters
    input_ids = params.get("inputs")
    prompt = params.get("prompt")
    model_name = params.get("model", "llm")
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_tokens", 256))
    logprobs = params.get("logprobs")
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop")

    stop_token_ids = params.get("stop_token_ids") or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    device = model.device
    if model.config.is_encoder_decoder:
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values, sent_interrupt = None, False
    token_logprobs = [None]  # The first token has no logprobs.
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    previous_text = ""
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefill logprobs for the prompt.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])

        else:  # decoding
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [output_ids if sent_interrupt else [token]], device=device
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=None if sent_interrupt else past_key_values,
                )
                sent_interrupt = False
                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [output_ids if sent_interrupt else [token]], device=device
                    ),
                    use_cache=True,
                    past_key_values=None if sent_interrupt else past_key_values,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device.type == "mps":
            # Move to CPU to work around bugs in the MPS backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)

        if logprobs is not None:
            # Cannot use last_token_logits here because logprobs must be computed from the raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        stopped = token in stop_token_ids
        # Yield the output tokens
        if i % 2 == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len(prompt)
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=False if check_is_qwen(model) else True,  # fix for qwen react
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )

            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(token)
                        for token in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs if echo else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}] * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)

            partially_stopped, finish_reason = False, None
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            if each_stop == "Observation:":
                                finish_reason = "function_call"
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding partial stop sequence
            if (not partially_stopped) and output and output[-1] != "�":
                delta_text = output[len(previous_text):]
                previous_text = output

                yield {
                    "id": completion_id,
                    "object": "text_completion",
                    "created": created,
                    "model": model_name,
                    "delta": delta_text,
                    "text": output,
                    "logprobs": ret_logprobs,
                    "finish_reason": finish_reason,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                }

        if stopped:
            break
    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": output,
        "logprobs": ret_logprobs,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
    }

    # Clean
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
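

# ---------------------------------------------------------------------------
# Illustrative usage only (not part of the original module): a minimal sketch
# showing how generate_stream can be driven end to end. The "gpt2" checkpoint,
# the prompt, and the sampling values below are placeholder assumptions, not
# values used by the surrounding code.
# ---------------------------------------------------------------------------
def _demo_generate_stream():  # pragma: no cover
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    prompt = "The quick brown fox"
    params = {
        "inputs": tokenizer(prompt).input_ids,  # token ids for the prompt
        "prompt": prompt,
        "model": "gpt2",
        "temperature": 0.7,
        "top_p": 0.9,
        "max_tokens": 32,
        "echo": False,  # stream only the newly generated text
    }
    # Each yielded chunk is an OpenAI-style completion delta.
    for chunk in generate_stream(model, tokenizer, params):
        print(chunk["delta"], end="", flush=True)
    print()
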
@torch.inference_mode()
def generate_stream_v2(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
):
    # Read parameters
    input_ids = params.get("inputs")
    functions = params.get("functions")
    model_name = params.get("model", "llm")
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", 40))
    max_new_tokens = int(params.get("max_tokens", 256))

    stop_token_ids = params.get("stop_token_ids") or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)
    stop_strings = params.get("stop", [])

    input_echo_len = len(input_ids)
    device = model.device

    generation_kwargs = dict(
        input_ids=torch.tensor([input_ids], device=device),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Treat near-zero temperature as greedy decoding.
    if temperature <= 1e-5:
        generation_kwargs["do_sample"] = False
        generation_kwargs.pop("top_k")

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs["streamer"] = streamer

    # Some remote-code models ship a custom generate(); rebind the stock implementation.
    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)

    # Run generation in a background thread and consume decoded text from the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text, func_call_found = "", False
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    previous_text = ""
    for i, new_text in enumerate(streamer):
        generated_text += new_text
        if functions:
            # Detect the ReAct "Observation:" marker to surface function calls.
            _, func_call_found = apply_stopping_strings(generated_text, ["Observation:"])
        generated_text, stop_found = apply_stopping_strings(generated_text, stop_strings)

        if generated_text and generated_text[-1] != "�":
            delta_text = generated_text[len(previous_text):]
            previous_text = generated_text

            yield {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": model_name,
                "delta": delta_text,
                "text": generated_text,
                "logprobs": None,
                "finish_reason": "function_call" if func_call_found else None,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": i,
                    "total_tokens": input_echo_len + i,
                },
            }

        if stop_found:
            break

    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": generated_text,
        "logprobs": None,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
    }
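

# ---------------------------------------------------------------------------
# Illustrative usage only (not part of the original module): a minimal sketch
# of generate_stream_v2, which streams through transformers' TextIteratorStreamer
# instead of the manual decoding loop above. "gpt2" is a placeholder checkpoint.
# ---------------------------------------------------------------------------
def _demo_generate_stream_v2():  # pragma: no cover
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    prompt = "Once upon a time"
    params = {
        "inputs": tokenizer(prompt).input_ids,
        "model": "gpt2",
        "temperature": 0.8,
        "top_p": 0.95,
        "max_tokens": 32,
        "stop": [],  # no extra stop strings beyond the EOS token
    }
    for chunk in generate_stream_v2(model, tokenizer, params):
        print(chunk["delta"], end="", flush=True)
    print()


if __name__ == "__main__":
    # Run the demos directly, e.g. `python stream.py` (module name assumed);
    # _demo_generate_stream() can be invoked the same way.
    _demo_generate_stream_v2()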