# coding=utf-8
# Implement stream chat in command line for ChatGLM fine-tuned with PEFT.
# This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py

import os
import signal
import platform

from utils import ModelArguments, load_pretrained
from transformers import HfArgumentParser

os_name = platform.system()
clear_command = "cls" if os_name == "Windows" else "clear"  # screen-clear command for the current OS
stop_stream = False  # set by the SIGINT handler to interrupt a streaming response
welcome = "Welcome to the ChatGLM-6B model. Type your message to chat, 'clear' to clear the dialogue history, 'stop' to exit the program."


def build_prompt(history):
    # Render the full conversation so it can be reprinted after clearing the screen.
    prompt = welcome
    for query, response in history:
        prompt += f"\n\nUser: {query}"
        prompt += f"\n\nChatGLM-6B: {response}"
    return prompt


def signal_handler(signum, frame):
    # On Ctrl+C, ask the streaming loop to stop instead of killing the process.
    global stop_stream
    stop_stream = True


def main():
    global stop_stream

    parser = HfArgumentParser(ModelArguments)
    model_args, = parser.parse_args_into_dataclasses()  # unpack the single dataclass
    model, tokenizer = load_pretrained(model_args)
    model = model.cuda()
    model.eval()

    history = []
    print(welcome)
    while True:
        try:
            query = input("\nInput: ")
        except UnicodeDecodeError:
            print("Detected a decoding error in the input, please set the terminal encoding to UTF-8.")
            continue

        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print(welcome)
            continue

        count = 0
        for _, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            count += 1
            if count % 8 == 0:  # refresh the screen every 8 streamed chunks
                os.system(clear_command)
                print(build_prompt(history), flush=True)
                signal.signal(signal.SIGINT, signal_handler)  # (re)arm Ctrl+C to interrupt streaming
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()
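

# Example invocation (a sketch, not part of the script). The actual CLI flags
# are whatever fields utils.ModelArguments declares; the option name below is
# an assumption for illustration:
#
#   python cli_demo.py --checkpoint_dir path_to_peft_checkpoint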