Spaces:
Runtime error
Runtime error
File size: 2,878 Bytes
f4fac26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import platform
import os
import time
from threading import Thread
from rich.text import Text
from rich.live import Live
from model.infer import ChatBot
from config import InferConfig
# Module-level setup: these statements run once, at import time.
# Build the inference configuration and the chat bot from the project's own
# classes (presumably this loads the model -- confirm in model.infer).
infer_config = InferConfig()
chat_bot = ChatBot(infer_config=infer_config)
# Shell command used to clear the terminal: `cls` on Windows, `clear` elsewhere.
clear_cmd = 'cls' if platform.system().lower() == 'windows' else 'clear'
# Banner text (Chinese): "Welcome to ChatBot, type `exit` to quit, type `cls`
# to clear the screen." It is printed here at start-up and reused by the rest
# of the script whenever the screen is redrawn.
welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
print(welcome_txt)
def build_prompt(history: list[list[str]]) -> str:
    """Render the whole conversation as one printable string.

    The result starts with the module-level welcome banner, followed by one
    colourised "user" line and one colourised "ChatBot" paragraph per turn.

    Args:
        history: Conversation turns, each a ``[query, response]`` pair.

    Returns:
        The banner plus every turn, ready to be printed to the terminal.
    """
    user_tpl = '\n\033[0;33;40m用户:\033[0m{}'
    bot_tpl = '\n\033[0;32;40mChatBot:\033[0m\n{}\n'

    # Collect the fragments and join once instead of growing a string.
    pieces = [welcome_txt]
    for user_msg, bot_msg in history:
        pieces.append(user_tpl.format(user_msg))
        pieces.append(bot_tpl.format(bot_msg))
    return ''.join(pieces)
# Stop flag for circle_print(): the spinner exits once this becomes true.
STOP_CIRCLE: bool = False


def circle_print(total_time: int = 60) -> None:
    """Draw a busy spinner on the current line during a non-stream chat call.

    A new spinner frame is drawn every 0.25 seconds. The loop ends as soon as
    the module-level ``STOP_CIRCLE`` flag is true, or after about
    ``total_time`` seconds, whichever comes first. On exit the cursor is
    returned to the start of the line so the next output overwrites the
    spinner.

    Args:
        total_time: Upper bound, in seconds, on how long the spinner runs.
    """
    frames = ["\\", "|", "/", "—"]
    # Four ticks per second, matching the 0.25 s sleep below.
    for tick in range(total_time * 4):
        # Check before sleeping so an already-set flag returns immediately.
        if STOP_CIRCLE:
            break
        time.sleep(0.25)
        # Check again after the sleep so no stale frame is drawn once the
        # caller has asked the spinner to stop.
        if STOP_CIRCLE:
            break
        print("\r{}".format(frames[tick % len(frames)]), end="", flush=True)
    print("\r", end='', flush=True)
def chat(stream: bool = True) -> None:
    """Run the interactive terminal chat loop until the user leaves.

    Each iteration prompts for a line of input and then:

    * empty input  -> prints a reminder and prompts again;
    * ``exit``     -> leaves the loop (case-insensitive);
    * ``cls``      -> clears the terminal, reprints the banner and forgets
      the conversation history (case-insensitive);
    * anything else is sent to the module-level ``chat_bot``.

    End-of-file on stdin (for example Ctrl-D) is treated like ``exit``
    instead of crashing with a traceback.

    Args:
        stream: When true, the reply is rendered incrementally with a
            ``rich`` live display and the screen is then redrawn from the
            recorded history. When false, a spinner thread runs while the
            blocking ``chat_bot.chat`` call is in progress and the reply is
            printed in one piece; history is not recorded in this mode.
    """
    global STOP_CIRCLE
    history = []

    while True:
        print('\r\033[0;33;40m用户:\033[0m', end='', flush=True)
        try:
            input_txt = input()
        except EOFError:
            # stdin was closed; there is nothing more to read, so stop.
            break

        if not input_txt:
            print('请输入问题')
            continue

        command = input_txt.lower()

        # Exit
        if command == 'exit':
            break

        # Clear the screen and start a fresh conversation
        if command == 'cls':
            history = []
            os.system(clear_cmd)
            print(welcome_txt)
            continue

        if not stream:
            STOP_CIRCLE = False
            thread = Thread(target=circle_print)
            thread.start()
            try:
                outs = chat_bot.chat(input_txt)
            finally:
                # Always stop and reap the spinner, even if the model call
                # raises or is interrupted; otherwise the non-daemon thread
                # would keep spinning until its own timeout.
                STOP_CIRCLE = True
                thread.join()
            print("\r\033[0;32;40mChatBot:\033[0m\n{}\n\n".format(outs), end='')
            continue

        history.append([input_txt, ''])
        streamer = chat_bot.stream_chat(input_txt)
        rich_text = Text()

        print("\r\033[0;32;40mChatBot:\033[0m\n", end='')

        # Mirror every streamed piece into the live display and keep a copy
        # so the full reply can be stored in the history afterwards.
        stream_txt = []
        with Live(rich_text, refresh_per_second=15):
            for word in streamer:
                rich_text.append(word)
                stream_txt.append(word)

        stream_txt = ''.join(stream_txt)

        if not stream_txt:
            # The model produced nothing: fall back to a canned apology.
            stream_txt = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"

        # The turn appended above is always the last entry.
        history[-1][1] = stream_txt

        # Redraw the whole conversation from history.
        os.system(clear_cmd)
        print(build_prompt(history), flush=True)
if __name__ == '__main__':
    # Start the interactive session in streaming mode when run as a script.
    chat(stream=True)