from model.fastchat.conversation import (Conversation, SeparatorStyle,
                                         compute_skip_echo_len)
from model.fastchat.serve.inference import ChatIO, generate_stream, load_model


class SimpleChatIO(ChatIO):
    """ Console ChatIO: reads user input from stdin and streams model
    output to stdout word by word. """

    def prompt_for_input(self, role: str) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
        # Each item from the stream is the full decoding so far; skip the
        # echoed prompt, then print only the words completed since the
        # previous chunk so the reply appears incrementally.
        pre = 0
        for outputs in output_stream:
            outputs = outputs[skip_echo_len:].strip()
            outputs = outputs.split(" ")
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(outputs[pre:]), flush=True)
        return " ".join(outputs)


class VicunaChatBot:
    """ Wraps a FastChat Vicuna model and tokenizer for streamed chat. """

    def __init__(
        self,
        model_path: str,
        device: str,
        num_gpus: str,
        max_gpu_memory: str,
        load_8bit: bool,
        chatio: ChatIO,
        debug: bool,
    ):
        self.model_path = model_path
        self.device = device
        self.chatio = chatio
        self.debug = debug

        self.model, self.tokenizer = load_model(self.model_path, device,
                                                num_gpus, max_gpu_memory,
                                                load_8bit, debug)

    def chat(self, inp: str, temperature: float, max_new_tokens: int,
             conv: Conversation):
        """ Append the user input to the conversation, stream a reply from
        Vicuna, and return the reply with the updated conversation. """
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)

        generate_stream_func = generate_stream
        prompt = conv.get_prompt()

        skip_echo_len = compute_skip_echo_len(self.model_path, conv, prompt)
        stop_str = (
            conv.sep if conv.sep_style
            in [SeparatorStyle.SINGLE, SeparatorStyle.BAIZE] else None)
        params = {
            "model": self.model_path,
            "prompt": prompt,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "stop": stop_str,
        }
        if self.debug:
            print(prompt)
        self.chatio.prompt_for_output(conv.roles[1])
        output_stream = generate_stream_func(self.model, self.tokenizer,
                                             params, self.device)
        outputs = self.chatio.stream_output(output_stream, skip_echo_len)
        # NOTE: strip is important to align with the training data.
        conv.messages[-1][-1] = outputs.strip()
        return outputs, conv


class VicunaHandler:
    """ Handles the communication between the gradio frontend and the
    Vicuna backend. """

    def __init__(self, config):
        self.config = config
        self.chat_io = SimpleChatIO()
        self.chatbot = VicunaChatBot(
            self.config['model_path'],
            self.config['device'],
            self.config['num_gpus'],
            self.config['max_gpu_memory'],
            self.config['load_8bit'],
            self.chat_io,
            self.config['debug'],
        )

    def chat(self):
        """ Command-line chat with Vicuna (not yet implemented). """
        pass

    def gr_chatbot_init(self, caption: str):
        """ Initialise the chatbot for gradio: build a conversation template
        from the video caption and return two independent copies. """

        template = self._construct_conversation(caption)
        print("Chatbot initialised.")
        return template.copy(), template.copy()

    def gr_chat(self, inp, conv: Conversation):
        """ Chat using gradio as the frontend. """
        return self.chatbot.chat(inp, self.config['temperature'],
                                 self.config['max_new_tokens'], conv)

    def _construct_conversation(self, prompt):
        """ Construct a conversation template.
        Args:
            prompt: the video caption text used to seed the conversation.
        """

        user_message = (
            "The following text describes what you have seen, found, "
            "heard and noticed in a continuous video. Some of the text "
            "may not be accurate. Try to conclude what happens in the "
            "video, then answer my question based on your conclusion.\n"
            "<video begin>\n" + prompt + "<video end>\n"
            "Example: Is this a video?")

        user_message = user_message.strip()

        print(user_message)

        return Conversation(
            system=
            "A chat between a curious user and an artificial intelligence "
            "assistant answering questions on videos. "
            "The assistant answers the questions based on the given video "
            "captions and speech in time order.",
            roles=("USER", "ASSISTANT"),
            messages=(("USER", user_message), ("ASSISTANT", "yes")),
            offset=0,
            sep_style=SeparatorStyle.TWO,
            sep=" ",
            sep2="</s>",
        )
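

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module logic. The config
# values below (model path, device, memory limits) are hypothetical
# placeholders; substitute your own checkpoint and hardware settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_config = {
        "model_path": "path/to/vicuna-7b",  # hypothetical checkpoint path
        "device": "cuda",
        "num_gpus": "1",
        "max_gpu_memory": "13GiB",
        "load_8bit": False,
        "debug": False,
        "temperature": 0.7,
        "max_new_tokens": 512,
    }
    handler = VicunaHandler(demo_config)
    # Seed the conversation with a (hypothetical) video caption, then ask
    # one question; gr_chat returns the reply and the updated conversation.
    state, _ = handler.gr_chatbot_init("A person walks a dog in a park.")
    reply, state = handler.gr_chat("What is the person doing?", state)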