bf16_vs_fp8 / fastchat /model /model_xfastertransformer.py
zjasper666's picture
Upload folder using huggingface_hub
8655a4b verified
raw
history blame
2.35 kB
import gc
from threading import Thread
import torch
from transformers import TextIteratorStreamer
@torch.inference_mode()
def generate_stream_xft(
    model,
    tokenizer,
    params,
    device,
    context_len=8192,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream generated text for a prompt, yielding cumulative output chunks.

    Generation runs in a background thread via ``model.model.generate`` and
    its decoded text is consumed through a ``TextIteratorStreamer``.

    Args:
        model: Wrapper object exposing ``model.model.generate`` and a
            ``model.config`` with ``padding``, ``beam_width``,
            ``num_return_sequences``, ``early_stopping``, ``eos_token_id``
            and ``pad_token_id`` attributes.
        tokenizer: Callable tokenizer; called with the prompt and
            ``return_tensors="pt"``, and handed to the streamer for decoding.
        params: Request dict. Reads ``"prompt"`` (required),
            ``"repetition_penalty"`` (default 1.0), ``"max_new_tokens"``
            (default 4096) and ``"echo"`` (default True).
        device: Not used by this function.
        context_len: Not used by this function.
        stream_interval: Not used by this function.
        judge_sent_end: Not used by this function.
            NOTE(review): the unused parameters are presumably kept so the
            signature matches other stream generators - confirm with callers.

    Yields:
        dict: ``{"text", "usage", "finish_reason"}``. ``text`` is the
        cumulative output (prefixed with the prompt when ``echo`` is truthy).
        Intermediate items have ``finish_reason=None``; the last item carries
        the stripped text and ``finish_reason`` of ``"length"`` or ``"stop"``.

    Raises:
        KeyError: If ``params`` has no ``"prompt"`` entry.
        Exception: Any exception raised by ``model.model.generate`` in the
            worker thread is re-raised here once the stream has been drained.
    """
    prompt = params["prompt"]
    repetition_penalty = float(params.get("repetition_penalty", 1.0))

    # Unused for now; kept as placeholders for future use.
    # temperature = float(params.get("temperature", 1.0))
    # top_p = float(params.get("top_p", 1.0))

    max_new_tokens = int(params.get("max_new_tokens", 4096))
    echo = params.get("echo", True)

    inputs = tokenizer(
        prompt, return_tensors="pt", padding=model.config.padding
    ).input_ids
    input_echo_len = len(inputs[0])
    max_len = max_new_tokens + input_echo_len

    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
    generation_kwargs = {
        "input_ids": inputs,
        "streamer": streamer,
        "max_length": max_len,
        "num_beams": model.config.beam_width,
        # NOTE(review): repetition_penalty is forwarded under the
        # `length_penalty` key - confirm this matches the backend's
        # generate() signature and is intentional.
        "length_penalty": repetition_penalty,
        "num_return_sequences": model.config.num_return_sequences,
        "early_stopping": model.config.early_stopping,
        "eos_token_id": model.config.eos_token_id,
        "pad_token_id": model.config.pad_token_id,
    }

    generation_errors = []

    def _run_generation():
        """Run generate(); on failure record the error and unblock the reader."""
        try:
            model.model.generate(**generation_kwargs)
        except Exception as exc:
            # Record the error *before* signalling end-of-stream so the
            # consumer is guaranteed to see it once its loop finishes.
            generation_errors.append(exc)
            # Without this the consumer would wait on the streamer forever,
            # because a failed generate() never emits the stop signal.
            streamer.end()

    def _payload(text, count, finish_reason):
        """Build one streamed response item."""
        return {
            "text": text,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": count,
                "total_tokens": input_echo_len + count,
            },
            "finish_reason": finish_reason,
        }

    thread = Thread(target=_run_generation)
    thread.start()

    # With echo enabled the prompt is kept at the front of the output.
    output = prompt if echo else ""

    # `i` is the zero-based index of the latest chunk read from the streamer.
    # NOTE(review): it is reported as `completion_tokens`; streamer chunks may
    # not map one-to-one to tokens - confirm before relying on the count.
    i = 0
    for i, new_text in enumerate(streamer):
        output += new_text
        yield _payload(output, i, None)

    # The stream is drained; wait for the worker so it is not left dangling
    # and so any late failure has been recorded before we check for it.
    thread.join()
    if generation_errors:
        raise generation_errors[0]

    output = output.strip()
    finish_reason = "length" if i == max_new_tokens - 1 else "stop"
    yield _payload(output, i, finish_reason)
    gc.collect()