File size: 2,878 Bytes
f4fac26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import platform
import os
import time
from threading import Thread

from rich.text import Text
from rich.live import Live

from model.infer import ChatBot
from config import InferConfig

# Build the inference configuration and the chatbot once, at import time.
infer_config = InferConfig()
chat_bot = ChatBot(infer_config=infer_config)

# Shell command that clears the terminal: `cls` on Windows, `clear` elsewhere.
clear_cmd = 'cls' if platform.system().lower() == 'windows' else 'clear'

# Banner (in Chinese): welcomes the user and explains that typing `exit`
# quits and typing `cls` clears the screen. Printed once on import and
# re-used whenever the screen is redrawn.
welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
print(welcome_txt)

def build_prompt(history: list[list[str]]) -> str:
    """Render the whole conversation as a single printable string.

    The result starts with the module-level ``welcome_txt`` banner and is
    followed by one ANSI-coloured user / ChatBot section per history entry.

    Args:
        history: Conversation turns, each a ``[query, response]`` pair.

    Returns:
        The banner followed by every formatted turn, in order.
    """
    user_tpl = '\n\033[0;33;40m用户:\033[0m{}'
    bot_tpl = '\n\033[0;32;40mChatBot:\033[0m\n{}\n'

    pieces = [welcome_txt]
    for question, answer in history:
        pieces.append(user_tpl.format(question))
        pieces.append(bot_tpl.format(answer))
    return ''.join(pieces)

# Set to True by the chat loop to ask a running spinner thread to stop.
STOP_CIRCLE: bool = False


def circle_print(total_time: int = 60) -> None:
    """Print a rotating busy indicator for non-stream chat.

    A new spinner frame is drawn every 0.25 seconds on the current terminal
    line until either ``total_time`` seconds have elapsed or the module-level
    ``STOP_CIRCLE`` flag becomes true. Before returning, the cursor is moved
    back to the start of the line so the next output overwrites the spinner.

    Args:
        total_time: Maximum number of seconds to keep spinning.
    """
    # STOP_CIRCLE is only read here, so no `global` declaration is needed.
    frames = ["\\", "|", "/", "—"]
    for tick in range(total_time * 4):
        time.sleep(0.25)

        # Check the flag before drawing so no stale frame is printed once
        # the caller has requested a stop.
        if STOP_CIRCLE:
            break

        print("\r{}".format(frames[tick % len(frames)]), end="", flush=True)

    print("\r", end='', flush=True)


def chat(stream: bool = True) -> None:
    """Run the interactive terminal chat loop until the user types ``exit``.

    Each iteration reads one line from stdin. An empty line prints a hint
    and prompts again; ``exit`` (case-insensitive) leaves the loop; ``cls``
    (case-insensitive) forgets the history, clears the terminal and reprints
    the welcome banner. Any other input is sent to the module-level
    ``chat_bot``.

    Args:
        stream: When true, the reply is consumed piece by piece from
            ``chat_bot.stream_chat`` and shown live, after which the screen
            is redrawn from the accumulated history. When false, the reply
            comes from a single ``chat_bot.chat`` call while a spinner
            thread shows activity; history is not recorded in this mode.
    """
    global STOP_CIRCLE
    history = []

    while True:
        print('\r\033[0;33;40m用户:\033[0m', end='', flush=True)
        input_txt = input()

        if not input_txt:
            print('请输入问题')
            continue

        command = input_txt.lower()

        # Quit.
        if command == 'exit':
            break

        # Clear the screen and start a fresh conversation.
        if command == 'cls':
            history = []
            os.system(clear_cmd)
            print(welcome_txt)
            continue

        if not stream:
            STOP_CIRCLE = False
            thread = Thread(target=circle_print)
            thread.start()

            try:
                outs = chat_bot.chat(input_txt)
            finally:
                # Always stop and reap the spinner, even when inference
                # raises; otherwise it would keep drawing until its own
                # timeout expires.
                STOP_CIRCLE = True
                thread.join()

            print("\r\033[0;32;40mChatBot:\033[0m\n{}\n\n".format(outs), end='')

            continue

        # Record the turn now; the response slot is filled in once the
        # stream has been fully consumed.
        history.append([input_txt, ''])
        stream_txt = []
        streamer = chat_bot.stream_chat(input_txt)
        rich_text = Text()

        print("\r\033[0;32;40mChatBot:\033[0m\n", end='')

        # Live re-renders `rich_text` while pieces are appended to it.
        with Live(rich_text, refresh_per_second=15):
            for word in streamer:
                rich_text.append(word)
                stream_txt.append(word)

        stream_txt = ''.join(stream_txt)

        # Fall back to a canned apology when the model produced nothing.
        if not stream_txt:
            stream_txt = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"

        # The entry appended above is always the last one.
        history[-1][1] = stream_txt

        # Redraw the whole conversation so the live output is replaced by
        # the consistently formatted transcript.
        os.system(clear_cmd)
        print(build_prompt(history), flush=True)

# Script entry point: start the interactive chat loop in streaming mode.
if __name__ == '__main__':
    chat(stream=True)