PlantBasedTen commited on
Commit
200cf42
1 Parent(s): 216850a

Upload bot.py

Browse files
Files changed (1) hide show
  1. bot.py +200 -0
bot.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import List, Tuple
4
+
5
+ import fire
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ # === Bot Loaders ===
10
+
11
+
12
+ def load_bot(
13
+ env_file_path: str = ".env",
14
+ logging_config_path: str = "logging.yaml",
15
+ model_cache_dir: str = "/model_cache",
16
+ embedding_model_device: str = "cuda:0",
17
+ debug: bool = False,
18
+ ):
19
+ """
20
+ Load the financial assistant bot in production or development mode based on the `debug` flag
21
+
22
+ In DEV mode the embedding model runs on CPU and the fine-tuned LLM is mocked.
23
+ Otherwise, the embedding model runs on GPU and the fine-tuned LLM is used.
24
+
25
+ Args:
26
+ env_file_path (str): Path to the environment file.
27
+ logging_config_path (str): Path to the logging configuration file.
28
+ model_cache_dir (str): Path to the directory where the model cache is stored.
29
+ embedding_model_device (str): Device to use for the embedding model.
30
+ debug (bool): Flag to indicate whether to run the bot in debug mode or not.
31
+
32
+ Returns:
33
+ FinancialBot: An instance of the FinancialBot class.
34
+ """
35
+
36
+ from financial_bot import initialize
37
+
38
+ # Be sure to initialize the environment variables before importing any other modules.
39
+ initialize(logging_config_path=logging_config_path, env_file_path=env_file_path)
40
+
41
+ from financial_bot import utils
42
+ from financial_bot.langchain_bot import FinancialBot
43
+
44
+ logger.info("#" * 100)
45
+ utils.log_available_gpu_memory()
46
+ utils.log_available_ram()
47
+ logger.info("#" * 100)
48
+
49
+ bot = FinancialBot(
50
+ model_cache_dir=Path(model_cache_dir) if model_cache_dir else None,
51
+ embedding_model_device=embedding_model_device,
52
+ debug=debug,
53
+ )
54
+
55
+ return bot
56
+
57
+
58
+ def load_bot_dev(
59
+ env_file_path: str = ".env",
60
+ logging_config_path: str = "logging.yaml",
61
+ model_cache_dir: str = "./model_cache",
62
+ ):
63
+ """
64
+ Load the Financial Assistant Bot in dev mode: the embedding model runs on CPU and the LLM is mocked.
65
+
66
+ Args:
67
+ env_file_path (str): Path to the environment file.
68
+ logging_config_path (str): Path to the logging configuration file.
69
+ model_cache_dir (str): Path to the directory where the model cache is stored.
70
+
71
+ Returns:
72
+ The loaded Financial Assistant Bot in dev mode.
73
+ """
74
+
75
+ return load_bot(
76
+ env_file_path=env_file_path,
77
+ logging_config_path=logging_config_path,
78
+ model_cache_dir=model_cache_dir,
79
+ embedding_model_device="cpu",
80
+ debug=True,
81
+ )
82
+
83
+
84
+ # === Bot Runners ===
85
+
86
+
87
+ @financial_bot.rest_api(keep_warm_seconds=300, loader=load_bot)
88
+ def run(**inputs):
89
+ """
90
+ Run the bot under the Beam RESTful API endpoint.
91
+
92
+ Args:
93
+ inputs (dict): A dictionary containing the following keys:
94
+ - context: The bot instance.
95
+ - about_me (str): Information about the user.
96
+ - question (str): The user's question.
97
+ - history (list): A list of previous conversations (optional).
98
+
99
+ Returns:
100
+ str: The bot's response to the user's question.
101
+ """
102
+
103
+ response = _run(**inputs)
104
+
105
+ return response
106
+
107
+
108
+ @financial_bot_dev.rest_api(keep_warm_seconds=300, loader=load_bot_dev)
109
+ def run_dev(**inputs):
110
+ """
111
+ Run the bot under the Beam RESTful API endpoint [Dev Mode].
112
+
113
+ Args:
114
+ inputs (dict): A dictionary containing the following keys:
115
+ - context: The bot instance.
116
+ - about_me (str): Information about the user.
117
+ - question (str): The user's question.
118
+ - history (list): A list of previous conversations (optional).
119
+
120
+ Returns:
121
+ str: The bot's response to the user's question.
122
+ """
123
+
124
+ response = _run(**inputs)
125
+
126
+ return response
127
+
128
+
129
+ def run_local(
130
+ about_me: str,
131
+ question: str,
132
+ history: List[Tuple[str, str]] = None,
133
+ debug: bool = False,
134
+ ):
135
+ """
136
+ Run the bot locally in production or dev mode.
137
+
138
+ Args:
139
+ about_me (str): A string containing information about the user.
140
+ question (str): A string containing the user's question.
141
+ history (List[Tuple[str, str]], optional): A list of tuples containing the user's previous questions
142
+ and the bot's responses. Defaults to None.
143
+ debug (bool, optional): A boolean indicating whether to run the bot in debug mode. Defaults to False.
144
+
145
+ Returns:
146
+ str: A string containing the bot's response to the user's question.
147
+ """
148
+
149
+ if debug is True:
150
+ bot = load_bot_dev(model_cache_dir=None)
151
+ else:
152
+ bot = load_bot(model_cache_dir=None)
153
+
154
+ inputs = {
155
+ "about_me": about_me,
156
+ "question": question,
157
+ "history": history,
158
+ "context": bot,
159
+ }
160
+
161
+ response = _run(**inputs)
162
+
163
+ return response
164
+
165
+
166
+ def _run(**inputs):
167
+ """
168
+ Central function that calls the bot and returns the response.
169
+
170
+ Args:
171
+ inputs (dict): A dictionary containing the following keys:
172
+ - context: The bot instance.
173
+ - about_me (str): Information about the user.
174
+ - question (str): The user's question.
175
+ - history (list): A list of previous conversations (optional).
176
+
177
+ Returns:
178
+ str: The bot's response to the user's question.
179
+ """
180
+
181
+ from financial_bot import utils
182
+
183
+ logger.info("#" * 100)
184
+ utils.log_available_gpu_memory()
185
+ utils.log_available_ram()
186
+ logger.info("#" * 100)
187
+
188
+ bot = inputs["context"]
189
+ input_payload = {
190
+ "about_me": inputs["about_me"],
191
+ "question": inputs["question"],
192
+ "to_load_history": inputs["history"] if "history" in inputs else [],
193
+ }
194
+ response = bot.answer(**input_payload)
195
+
196
+ return response
197
+
198
+
199
+ if __name__ == "__main__":
200
+ fire.Fire(run_local)