import abc
import json
import os
import re
import time
import urllib
from queue import Queue
from threading import Thread
from typing import List, Optional
from urllib.parse import quote, urlparse, urlunparse

from langchain.chains.base import Chain

from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
from app_modules.utils import remove_extra_spaces

# Inject prior chat turns into the prompt when CHAT_HISTORY_ENABLED is set.
chat_history_enabled = os.getenv("CHAT_HISTORY_ENABLED", "false").lower() == "true"


def get_system_prompt_and_user_message(orca=False):
    # system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
    system_prompt = (
        "You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."
        if orca
        else "You are a chatbot having a conversation with a human."
    )

    user_message = "{input}"

    if chat_history_enabled:
        user_message = "Chat History:\n\n{history} \n\n" + user_message
        system_prompt += " Read the chat history to get context."

    return system_prompt, user_message


class LLMInference(metaclass=abc.ABCMeta):
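    """Base class for LLM inference wrappers.

    Subclasses implement create_chain(); this class adds lazy chain caching,
    batch execution, result normalization, and optional token streaming via
    the loader's Hugging Face streamer.
    """
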
    def __init__(self, llm_loader):
        self.llm_loader = llm_loader
        self.chain = None
        # Strips a trailing tag-like token (e.g. an end-of-sequence marker) from answers.
        self.pattern = re.compile(r"\s*<.+>$")

    @abc.abstractmethod
    def create_chain(self) -> Chain:
        pass

    def get_chain(self) -> Chain:
        if self.chain is None:
            self.chain = self.create_chain()

        return self.chain

    def reset(self) -> None:
        self.chain = None

    def _process_inputs(self, inputs):
        return inputs

    def _normalize_result(self, result):
        # print(f"_normalize_result: {result}")
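        # Some backends wrap the result in a single-element list; unwrap it.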
        if isinstance(result, list):
            result = result[0]

        key = "text" if "text" in result else "generated_text"
        if key in result:
            result["answer"] = result[key]
            del result[key]

        result["answer"] = self.pattern.sub("", result["answer"])
        return result

    def _process_results(self, results):
        if isinstance(results, list):
            return [self._normalize_result(result) for result in results]

        return self._normalize_result(results)

    def _run_batch(self, chain, inputs):
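        # The underlying Hugging Face pipeline accepts a list of prompts directly;
        # other backends go through LangChain's batch API.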
        if self.llm_loader.llm_model_type == "huggingface":
            results = self.llm_loader.llm.pipeline(inputs)
        else:
            results = chain.batch(inputs)

        return results

    def run_chain(self, chain, inputs, callbacks: Optional[List] = None):
        inputs = self._process_inputs(inputs)

        # A list of inputs is treated as a batch; a single input is invoked directly.
        if isinstance(inputs, list):
            results = self._run_batch(chain, inputs)
        else:
            results = chain.invoke(inputs, {"callbacks": callbacks or []})

        return self._process_results(results)

    def call_chain(
        self,
        inputs,
        streaming_handler,
        q: Optional[Queue] = None,
        testing: bool = False,
    ):
        print(json.dumps(inputs, indent=4))
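        # Token streaming shares a single streamer instance, so serialize access
        # with the loader's lock while streaming is enabled.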
        if self.llm_loader.huggingfaceStreamingEnabled():
            self.llm_loader.lock.acquire()

        try:
            if self.llm_loader.huggingfaceStreamingEnabled():
                self.llm_loader.streamer.reset(q)

            chain = self.get_chain()
            result = (
                self._run_chain_with_streaming_handler(
                    chain, inputs, streaming_handler, testing
                )
                if streaming_handler is not None
                else self.run_chain(chain, inputs)
            )

            if "answer" in result:
                result["answer"] = remove_extra_spaces(result["answer"])

            return result
        finally:
            if self.llm_loader.huggingfaceStreamingEnabled():
                self.llm_loader.lock.release()

    def _execute_chain(self, chain, inputs, q, sh):
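        # Worker-thread target: run the chain and hand the result back via the queue.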
        q.put(self.run_chain(chain, inputs, callbacks=[sh]))

    def _run_chain_with_streaming_handler(
        self, chain, inputs, streaming_handler, testing
    ):
        que = Queue()

        t = Thread(
            target=self._execute_chain,
            args=(chain, inputs, que, streaming_handler),
        )
        t.start()

        if self.llm_loader.huggingfaceStreamingEnabled():
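            # Drain the streamer once per expected generation: two when chat
            # history is present, otherwise one.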
            count = (
                2
                if "chat_history" in inputs and len(inputs.get("chat_history")) > 0
                else 1
            )

            while count > 0:
                try:
                    for token in self.llm_loader.streamer:
                        if not testing:
                            streaming_handler.on_llm_new_token(token)

                    self.llm_loader.streamer.reset()
                    count -= 1
                except Exception:
                    if not testing:
                        print("nothing generated yet - retry in 0.5s")
                    time.sleep(0.5)

        t.join()
        return que.get()

    def apply_chat_template(self, user_message):
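        # Gemma and Mistral chat templates do not accept a separate system role,
        # so the system prompt is only included for other model families.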
        result = (
            []
            if re.search(r"gemma|mistral", self.llm_loader.model_name, re.IGNORECASE)
            else [
                {
                    "role": "system",
                    "content": get_system_prompt_and_user_message()[0],
                }
            ]
        )
        result.append(
            {
                "role": "user",
                "content": user_message,
            }
        )
        return result