"""Inference for FastChat models.""" import abc import gc import math import sys import time from typing import Iterable, Optional, Dict import warnings import psutil import torch from transformers import ( AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, LlamaForCausalLM, AutoModel, AutoModelForSeq2SeqLM, T5Tokenizer, AutoConfig, ) from transformers.generation.logits_process import ( LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, ) from fastchat.conversation import get_conv_template, SeparatorStyle from fastchat.model.model_adapter import ( load_model, get_conversation_template, get_generate_stream_function, ) from fastchat.modules.gptq import GptqConfig from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length def prepare_logits_processor( temperature: float, repetition_penalty: float, top_p: float, top_k: int ) -> LogitsProcessorList: processor_list = LogitsProcessorList() # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. if temperature >= 1e-5 and temperature != 1.0: processor_list.append(TemperatureLogitsWarper(temperature)) if repetition_penalty > 1.0: processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) if 1e-8 <= top_p < 1.0: processor_list.append(TopPLogitsWarper(top_p)) if top_k > 0: processor_list.append(TopKLogitsWarper(top_k)) return processor_list @torch.inference_mode() def generate_stream( model, tokenizer, params: Dict, device: str, context_len: int, stream_interval: int = 2, judge_sent_end: bool = False, ): # Read parameters prompt = params["prompt"] len_prompt = len(prompt) temperature = float(params.get("temperature", 1.0)) repetition_penalty = float(params.get("repetition_penalty", 1.0)) top_p = float(params.get("top_p", 1.0)) top_k = int(params.get("top_k", -1)) # -1 means disable max_new_tokens = int(params.get("max_new_tokens", 256)) echo = bool(params.get("echo", True)) stop_str = params.get("stop", None) stop_token_ids = params.get("stop_token_ids", None) or [] stop_token_ids.append(tokenizer.eos_token_id) logits_processor = prepare_logits_processor( temperature, repetition_penalty, top_p, top_k ) input_ids = tokenizer(prompt).input_ids output_ids = list(input_ids) if model.config.is_encoder_decoder: max_src_len = context_len else: # truncate max_src_len = context_len - max_new_tokens - 8 input_ids = input_ids[-max_src_len:] input_echo_len = len(input_ids) if model.config.is_encoder_decoder: encoder_output = model.encoder( input_ids=torch.as_tensor([input_ids], device=device) )[0] start_ids = torch.as_tensor( [[model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=device, ) past_key_values = out = None sent_interrupt = False for i in range(max_new_tokens): if i == 0: # prefill if model.config.is_encoder_decoder: out = model.decoder( input_ids=start_ids, encoder_hidden_states=encoder_output, use_cache=True, ) logits = model.lm_head(out[0]) else: out = model(torch.as_tensor([input_ids], device=device), use_cache=True) logits = out.logits past_key_values = out.past_key_values else: # decoding if model.config.is_encoder_decoder: out = model.decoder( input_ids=torch.as_tensor( [[token] if not sent_interrupt else output_ids], device=device ), encoder_hidden_states=encoder_output, use_cache=True, past_key_values=past_key_values if not sent_interrupt else None, ) sent_interrupt = False logits = model.lm_head(out[0]) else: out = model( input_ids=torch.as_tensor( [[token] if not sent_interrupt 
@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    # Read parameters
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    stop_token_ids = params.get("stop_token_ids", None) or []
    stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids
    output_ids = list(input_ids)

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:  # truncate
        max_src_len = context_len - max_new_tokens - 8

    input_ids = input_ids[-max_src_len:]
    input_echo_len = len(input_ids)

    if model.config.is_encoder_decoder:
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )

    past_key_values = out = None
    sent_interrupt = False
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values
        else:  # decoding
            # After a sentence-end interrupt, the KV cache is stale, so the
            # full output_ids sequence is re-fed instead of the last token.
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device == "mps":
            # Move to CPU to work around bugs in the MPS backend.
            last_token_logits = last_token_logits.float().to("cpu")

        # Two candidates are kept so that judge_sent_end can swap in the
        # runner-up token when a sentence ends prematurely.
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)

        if token in stop_token_ids:
            stopped = True
        else:
            stopped = False

        # Yield the output tokens
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            # TODO: Patch for the issue of incomplete sentences interrupting
            # the output; this could be handled more elegantly.
            if judge_sent_end and stopped and not is_sentence_complete(output):
                if len(tokens) > 1:
                    token = tokens[1]
                    output_ids[-1] = token
                else:
                    output_ids.pop()
                stopped = False
                sent_interrupt = True

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding a partial stop sequence.
            if not partially_stopped:
                yield {
                    "text": output,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

        if stopped:
            break

    # Final stream event, which contains the finish reason.
    if i == max_new_tokens - 1:
        finish_reason = "length"
    elif stopped:
        finish_reason = "stop"
    else:
        finish_reason = None

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Clean up.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Stream output."""
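
# Minimal console implementation of the ChatIO contract, sketched here for
# illustration only (it is not part of this module; FastChat's CLI ships its
# own implementations). stream_output() must consume the dicts yielded by
# generate_stream() and return the final text, since chat_loop() stores its
# return value in the conversation.
class _ExampleConsoleChatIO(ChatIO):
    def prompt_for_input(self, role: str) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        pre = 0
        text = ""
        for outputs in output_stream:
            text = outputs["text"].strip()
            words = text.split(" ")
            # Print only the complete words that are new since the last chunk.
            if len(words) - 1 > pre:
                print(" ".join(words[pre : len(words) - 1]), end=" ", flush=True)
                pre = len(words) - 1
        print(" ".join(text.split(" ")[pre:]), flush=True)
        return text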
def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    load_8bit: bool,
    cpu_offloading: bool,
    conv_template: Optional[str],
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int,
    chatio: ChatIO,
    gptq_config: GptqConfig,
    revision: str,
    judge_sent_end: bool,
    debug: bool,
):
    # Model
    model, tokenizer = load_model(
        model_path,
        device,
        num_gpus,
        max_gpu_memory,
        load_8bit,
        cpu_offloading,
        gptq_config,
        revision,
        debug,
    )
    generate_stream_func = get_generate_stream_function(model, model_path)

    model_type = str(type(model)).lower()
    is_t5 = "t5" in model_type
    is_codet5p = "codet5p" in model_type

    # Hardcode T5's default repetition penalty to be 1.2
    if is_t5 and repetition_penalty == 1.0:
        repetition_penalty = 1.2

    # Set context length
    context_len = get_context_length(model.config)

    # Chat
    def new_chat():
        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        return conv

    conv = new_chat()

    while True:
        try:
            inp = chatio.prompt_for_input(conv.roles[0])
        except EOFError:
            inp = ""

        if inp == "!!exit" or not inp:
            print("exit...")
            break

        if inp == "!!reset":
            print("resetting...")
            conv = new_chat()
            continue

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        if is_codet5p:  # codet5p is a code completion model.
            prompt = inp

        gen_params = {
            "model": model_path,
            "prompt": prompt,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "stop": conv.stop_str,
            "stop_token_ids": conv.stop_token_ids,
            "echo": False,
        }

        chatio.prompt_for_output(conv.roles[1])
        output_stream = generate_stream_func(
            model,
            tokenizer,
            gen_params,
            device,
            context_len=context_len,
            judge_sent_end=judge_sent_end,
        )
        t = time.time()
        outputs = chatio.stream_output(output_stream)
        duration = time.time() - t
        conv.update_last_message(outputs.strip())

        if debug:
            num_tokens = len(tokenizer.encode(outputs))
            msg = {
                "conv_template": conv.name,
                "prompt": prompt,
                "outputs": outputs,
                "speed (token/s)": round(num_tokens / duration, 2),
            }
            print(f"\n{msg}\n")
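
# Hedged usage sketch (not part of the original module): wiring chat_loop()
# to the example ChatIO above. The model path, device, and GptqConfig field
# values are assumptions; adjust them to your environment. With wbits=16 the
# GPTQ path is effectively disabled, so an ordinary checkpoint is loaded.
if __name__ == "__main__":
    chat_loop(
        model_path="lmsys/vicuna-7b-v1.5",  # assumed model; any supported path works
        device="cuda",
        num_gpus=1,
        max_gpu_memory=None,  # let the loader pick a default split
        load_8bit=False,
        cpu_offloading=False,
        conv_template=None,  # fall back to the template inferred from model_path
        temperature=0.7,
        repetition_penalty=1.0,
        max_new_tokens=512,
        chatio=_ExampleConsoleChatIO(),
        gptq_config=GptqConfig(ckpt=None, wbits=16, groupsize=-1, act_order=False),
        revision="main",
        judge_sent_end=False,
        debug=False,
    )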