PlantBasedTen committed on
Commit
3484e05
1 Parent(s): 297c21b

Initial commit

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import logging
from pathlib import Path
from threading import Thread
from typing import Iterator, List, Optional, Sequence

import gradio as gr
8
+
9
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
10
+
11
+
12
def parseargs(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """
    Parses command line arguments for the Financial Assistant Bot.

    Args:
        argv (Optional[Sequence[str]]): Argument strings to parse instead of the
            process command line. When None (the default) argparse falls back to
            `sys.argv[1:]`, so calling `parseargs()` behaves exactly as before.

    Returns:
        argparse.Namespace: An object containing the parsed arguments
            (`env_file_path`, `logging_config_path`, `model_cache_dir`,
            `embedding_model_device` and `debug`).
    """

    parser = argparse.ArgumentParser(description="Financial Assistant Bot")

    parser.add_argument(
        "--env-file-path",
        type=str,
        default=".env",
        help="Path to the environment file",
    )

    parser.add_argument(
        "--logging-config-path",
        type=str,
        default="logging.yaml",
        help="Path to the logging configuration file",
    )

    parser.add_argument(
        "--model-cache-dir",
        type=str,
        default="./model_cache",
        help="Path to the directory where the model cache will be stored",
    )

    parser.add_argument(
        "--embedding-model-device",
        type=str,
        default="cuda:0",
        help="Device to use for the embedding model (e.g. 'cpu', 'cuda:0', etc.)",
    )

    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Enable debug mode",
    )

    # Passing None makes argparse read sys.argv[1:]; an explicit list lets
    # callers (and tests) drive the parser without touching the process argv.
    return parser.parse_args(argv)
58
+
59
+
60
# Parse the command line once, at import time; the resulting namespace
# supplies the arguments passed to load_bot() further down.
args = parseargs()
61
+
62
+
63
+ # === Load Bot ===
64
+
65
+
66
def load_bot(
    env_file_path: str = ".env",
    logging_config_path: str = "logging.yaml",
    model_cache_dir: str = "/model_cache",
    embedding_model_device: str = "cuda:0",
    debug: bool = False,
):
    """
    Initialize the `financial_bot` package and build a streaming FinancialBot.

    The package is imported in two steps: `initialize` is imported and called
    first (with the logging config and environment file), and only afterwards
    are `utils` and `FinancialBot` imported. Available GPU memory and RAM are
    logged before the bot is constructed.

    NOTE(review): the original notes state that in DEV mode (`debug=True`) the
    embedding model runs on CPU and the fine-tuned LLM is mocked, while otherwise
    the GPU and the real LLM are used. That switch is not visible in this
    function, which only forwards `debug` and `embedding_model_device` to
    FinancialBot - confirm against the FinancialBot implementation.

    NOTE(review): the `model_cache_dir` default here is the absolute path
    "/model_cache", whereas the `--model-cache-dir` CLI option defaults to the
    relative "./model_cache" - confirm which one is intended.

    Args:
        env_file_path (str): Path to the environment file.
        logging_config_path (str): Path to the logging configuration file.
        model_cache_dir (str): Directory for the model cache. A falsy value
            (e.g. an empty string) is forwarded to the bot as None.
        embedding_model_device (str): Device to use for the embedding model.
        debug (bool): Debug flag forwarded to the bot.

    Returns:
        FinancialBot: The constructed bot, created with `streaming=True`.
    """

    from financial_bot import initialize

    # The environment must be initialized before the remaining financial_bot
    # modules are imported, so the order of the statements below matters.
    initialize(logging_config_path=logging_config_path, env_file_path=env_file_path)

    from financial_bot import utils
    from financial_bot.langchain_bot import FinancialBot

    # Frame the resource report with separator lines in the log.
    banner = "#" * 100
    logger.info(banner)
    utils.log_available_gpu_memory()
    utils.log_available_ram()
    logger.info(banner)

    cache_dir = None
    if model_cache_dir:
        cache_dir = Path(model_cache_dir)

    return FinancialBot(
        model_cache_dir=cache_dir,
        embedding_model_device=embedding_model_device,
        streaming=True,
        debug=debug,
    )
111
+
112
+
113
# Build the bot once, at import time, from the parsed command line options.
# predict() below reads this module-level instance.
bot = load_bot(
    env_file_path=args.env_file_path,
    logging_config_path=args.logging_config_path,
    model_cache_dir=args.model_cache_dir,
    embedding_model_device=args.embedding_model_device,
    debug=args.debug,
)
120
+
121
+
122
+ # === Gradio Interface ===
123
+
124
+
125
def predict(message: str, history: List[List[str]], about_me: str) -> Iterator[str]:
    """
    Generates a response to a given message for the financial_bot Gradio UI.

    This is a generator function. When the module-level `bot` is in streaming
    mode, `bot.answer` is started on a worker thread and every item produced by
    `bot.stream_answer()` is yielded as it arrives. Otherwise `bot.answer` is
    called directly and its result is yielded once.

    Args:
        message (str): The message to generate a response for.
        history (List[List[str]]): A list of previous conversations.
        about_me (str): A string describing the user.

    Yields:
        str: In streaming mode, each item produced by `bot.stream_answer()`
            (presumably partial answers - confirm against FinancialBot); in
            non-streaming mode, the single value returned by `bot.answer`.
    """

    generate_kwargs = {
        "about_me": about_me,
        "question": message,
        "to_load_history": history,
    }

    if not bot.is_streaming:
        yield bot.answer(**generate_kwargs)
        return

    # Generation runs on a separate thread so this generator stays free to relay
    # the stream to the UI. The thread is intentionally not joined, matching the
    # original behaviour; errors raised inside it are not propagated here.
    Thread(target=bot.answer, kwargs=generate_kwargs).start()

    yield from bot.stream_answer()
152
+
153
+
154
# Chat UI wired to predict(). The single additional input ("About me") is passed
# to predict() as `about_me`, so each example below is a
# [question, about_me] pair.
demo = gr.ChatInterface(
    predict,
    textbox=gr.Textbox(
        placeholder="Ask me a financial question",
        label="Financial question",
        container=False,
        scale=7,
    ),
    additional_inputs=[
        gr.Textbox(
            "I am a student and I have some money that I want to invest.",
            label="About me",
        )
    ],
    title="Your Personal Financial Assistant",
    description="Ask me any financial or crypto market questions, and I will do my best to answer them.",
    theme="soft",
    examples=[
        [
            "What's your opinion on investing in startup companies?",
            "I am a 30 year old graphic designer. I want to invest in something with potential for high returns.",
        ],
        [
            "What's your opinion on investing in AI-related companies?",
            # Adjacent string literals are joined at compile time. A backslash
            # continuation inside the literal would splice the next line's
            # leading indentation into the text shown to the user.
            "I'm a 25 year old entrepreneur interested in emerging technologies. "
            "I'm willing to take calculated risks for potential high returns.",
        ],
        [
            "Do you think advancements in gene therapy are impacting biotech company valuations?",
            "I'm a 31 year old scientist. I'm curious about the potential of biotech investments.",
        ],
    ],
    cache_examples=False,
    retry_btn=None,
    undo_btn=None,
    clear_btn="Clear",
)
191
+
192
+
193
# Start the Gradio app only when this file is executed as a script. The
# request queue is enabled before launching.
# NOTE(review): presumably the queue is needed for the generator-based
# streaming in predict() - confirm for the installed Gradio version.
if __name__ == "__main__":
    demo.queue().launch()