chatmlTest / cli_demo.py
fangshengren's picture
Upload 59 files
f4fac26 verified
raw
history blame
2.88 kB
import platform
import os
import time
from threading import Thread
from rich.text import Text
from rich.live import Live
from model.infer import ChatBot
from config import InferConfig
# Build the inference configuration once and hand it to the bot that every
# turn of the conversation will use.
infer_config = InferConfig()
chat_bot = ChatBot(infer_config=infer_config)

# Shell command that wipes the terminal: `cls` on Windows, `clear` elsewhere.
if platform.system().lower() == 'windows':
    clear_cmd = 'cls'
else:
    clear_cmd = 'clear'

# Banner shown at start-up and again whenever the screen is redrawn.
welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
print(welcome_txt)
def build_prompt(history: list[list[str]]) -> str:
    """Render the welcome banner followed by every conversation turn.

    Args:
        history: sequence of ``[query, response]`` pairs, oldest first.

    Returns:
        One string: the module-level ``welcome_txt`` and then, for each
        pair, an ANSI-coloured user line and an ANSI-coloured ChatBot block.
    """
    user_line = '\n\033[0;33;40m用户:\033[0m{}'
    bot_block = '\n\033[0;32;40mChatBot:\033[0m\n{}\n'

    pieces = [welcome_txt]
    for question, answer in history:
        pieces.append(user_line.format(question))
        pieces.append(bot_block.format(answer))
    return ''.join(pieces)
# Stop flag polled by circle_print(); chat() sets it to True once the
# non-streaming reply is ready.
STOP_CIRCLE: bool = False


def circle_print(total_time: int=60) -> None:
    """Show a busy spinner for non-stream chat until asked to stop.

    Draws one spinner frame every 0.25 s at the start of the current line.
    The loop ends when the module-level ``STOP_CIRCLE`` flag becomes true or
    after roughly ``total_time`` seconds, whichever comes first.

    Args:
        total_time: upper bound, in seconds, on how long the spinner runs.
    """
    frames = ["\\", "|", "/", "—"]
    for tick in range(total_time * 4):
        # Checked before sleeping so an already-set flag returns immediately.
        if STOP_CIRCLE:
            break
        time.sleep(0.25)
        # Checked again after the sleep so that no stale frame is drawn once
        # a stop has been requested; the next iteration then breaks.
        if not STOP_CIRCLE:
            print("\r{}".format(frames[tick % 4]), end="", flush=True)
    # Return the cursor to column 0 so the caller's output overwrites the frame.
    print("\r", end='', flush=True)
def chat(stream: bool=True) -> None:
    """Run the interactive terminal chat loop until the user leaves.

    One line is read per turn. A blank (empty or whitespace-only) line asks
    again, ``exit`` ends the loop, and ``cls`` clears both the screen and the
    displayed history. End of input (EOF on stdin) also ends the loop.

    Args:
        stream: when true, the reply from ``chat_bot.stream_chat`` is rendered
            incrementally and the whole conversation is redrawn afterwards;
            when false, ``chat_bot.chat`` is called once while a spinner
            thread shows a busy indicator.
    """
    global STOP_CIRCLE

    history = []
    while True:
        print('\r\033[0;33;40m用户:\033[0m', end='', flush=True)
        try:
            input_txt = input()
        except EOFError:
            # stdin was closed (Ctrl-D or end of piped input): leave quietly
            # instead of crashing with a traceback.
            print()
            break

        # Commands are matched on the trimmed, lower-cased text; the model
        # still receives the line exactly as typed.
        command = input_txt.strip().lower()
        if not command:
            print('请输入问题')
            continue

        # Quit.
        if command == 'exit':
            break

        # Clear the screen and forget the displayed conversation.
        if command == 'cls':
            history = []
            os.system(clear_cmd)
            print(welcome_txt)
            continue

        if not stream:
            STOP_CIRCLE = False
            spinner = Thread(target=circle_print)
            spinner.start()
            try:
                outs = chat_bot.chat(input_txt)
            finally:
                # Always stop and join the spinner, even if generation raised,
                # so the thread does not keep drawing after an error.
                STOP_CIRCLE = True
                spinner.join()
            print("\r\033[0;32;40mChatBot:\033[0m\n{}\n\n".format(outs), end='')
            continue

        stream_txt = []
        streamer = chat_bot.stream_chat(input_txt)
        rich_text = Text()

        print("\r\033[0;32;40mChatBot:\033[0m\n", end='')

        # Live re-renders rich_text as fragments are appended to it.
        with Live(rich_text, refresh_per_second=15):
            for word in streamer:
                rich_text.append(word)
                stream_txt.append(word)

        reply = ''.join(stream_txt)
        if not reply:
            reply = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"

        # Record the turn only once it is complete, so a failure while
        # streaming cannot leave a half-filled entry behind.
        history.append([input_txt, reply])

        # Redraw the whole conversation from scratch.
        os.system(clear_cmd)
        print(build_prompt(history), flush=True)
# Script entry point: start the interactive session in streaming mode.
if __name__ == "__main__":
    chat(stream=True)