import torch
from typing import List, Tuple


@torch.inference_mode()
def chatglm_generate_stream(
    model, tokenizer, params, device, context_len=2048, stream_interval=2
):
    """Generate text using the model's chat API (ChatGLM ``stream_chat``)."""
    messages = params["prompt"]
    max_new_tokens = int(params.get("max_new_tokens", 256))
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 0.7))

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "logits_processor": None,
    }

    # Rebuild the (query, response) history pairs that stream_chat expects.
    # messages alternates user/assistant turns; the last two entries are the
    # current query and the assistant placeholder for the pending reply.
    hist: List[Tuple[str, str]] = []
    for i in range(0, len(messages) - 2, 2):
        hist.append((messages[i][1], messages[i + 1][1]))
    query = messages[-2][1]

    # stream_chat yields (partial_response, updated_history) tuples; forward
    # the sampling parameters so max_new_tokens/temperature/top_p take effect.
    for response, new_hist in model.stream_chat(tokenizer, query, hist, **gen_kwargs):
        output = query + " " + response
        yield output
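

# Minimal usage sketch; assumes a ChatGLM checkpoint that exposes
# stream_chat (the model path below is illustrative) and a "prompt" shaped
# as alternating (role, text) pairs ending with an empty assistant
# placeholder, matching how the generator slices messages above.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_path = "THUDM/chatglm-6b"  # assumed checkpoint; any stream_chat-capable model works
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()

    params = {
        "prompt": [
            ("user", "What is the capital of France?"),
            ("assistant", ""),  # placeholder for the reply being generated
        ],
        "max_new_tokens": 128,
        "temperature": 0.9,
        "top_p": 0.7,
    }

    # Each yielded string is the query plus the partial response so far.
    for output in chatglm_generate_stream(model, tokenizer, params, device="cuda"):
        print(output)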