PlantBasedTen committed on
Commit
bb59984
1 Parent(s): d432e43

Upload 22 files

financial_bot/__init__.py ADDED
@@ -0,0 +1,55 @@
+import logging
+import logging.config
+from pathlib import Path
+
+import yaml
+from dotenv import find_dotenv, load_dotenv
+
+logger = logging.getLogger(__name__)
+
+
+def initialize(logging_config_path: str = "logging.yaml", env_file_path: str = ".env"):
+    """
+    Initializes the logger and environment variables.
+
+    Args:
+        logging_config_path (str): The path to the logging configuration file. Defaults to "logging.yaml".
+        env_file_path (str): The path to the environment variables file. Defaults to ".env".
+    """
+
+    logger.info("Initializing logger...")
+    try:
+        initialize_logger(config_path=logging_config_path)
+    except FileNotFoundError:
+        logger.warning(
+            f"No logging configuration file found at: {logging_config_path}. Setting logging level to INFO."
+        )
+        logging.basicConfig(level=logging.INFO)
+
+    logger.info("Initializing env vars...")
+    if env_file_path is None:
+        env_file_path = find_dotenv(raise_error_if_not_found=True, usecwd=False)
+
+    logger.info(f"Loading environment variables from: {env_file_path}")
+    found_env_file = load_dotenv(env_file_path, verbose=True, override=True)
+    if found_env_file is False:
+        raise RuntimeError(f"Could not find environment file at: {env_file_path}")
+
+
+def initialize_logger(
+    config_path: str = "logging.yaml", logs_dir_name: str = "logs"
+) -> None:
+    """Initialize the logger from a YAML config file."""
+
+    # Create the logs directory.
+    config_path_parent = Path(config_path).parent
+    logs_dir = config_path_parent / logs_dir_name
+    logs_dir.mkdir(parents=True, exist_ok=True)
+
+    with open(config_path, "rt") as f:
+        config = yaml.safe_load(f.read())
+
+    # Make sure that existing loggers will still work.
+    config["disable_existing_loggers"] = False
+
+    logging.config.dictConfig(config)
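
For reference, a minimal usage sketch of this module; it assumes a `logging.yaml` and a `.env` file exist next to the caller (both paths are illustrative, not part of this commit):

from financial_bot import initialize

# Illustrative paths; the files are assumed to exist in the working directory.
initialize(logging_config_path="logging.yaml", env_file_path=".env")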
financial_bot/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.96 kB).
financial_bot/__pycache__/base.cpython-310.pyc ADDED
Binary file (936 Bytes).
financial_bot/__pycache__/chains.cpython-310.pyc ADDED
Binary file (6.98 kB).
financial_bot/__pycache__/constants.cpython-310.pyc ADDED
Binary file (720 Bytes).
financial_bot/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (4.37 kB).
financial_bot/__pycache__/handlers.cpython-310.pyc ADDED
Binary file (2.59 kB).
financial_bot/__pycache__/langchain_bot.cpython-310.pyc ADDED
Binary file (7.71 kB).
financial_bot/__pycache__/models.cpython-310.pyc ADDED
Binary file (8.25 kB).
financial_bot/__pycache__/qdrant.cpython-310.pyc ADDED
Binary file (1.56 kB).
financial_bot/__pycache__/template.cpython-310.pyc ADDED
Binary file (3.84 kB).
financial_bot/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.34 kB).
financial_bot/base.py ADDED
@@ -0,0 +1,38 @@
+from threading import Lock
+
+
+class SingletonMeta(type):
+    """
+    This is a thread-safe implementation of a Singleton.
+    """
+
+    _instances = {}
+
+    _lock: Lock = Lock()
+
+    """
+    We now have a lock object that will be used to synchronize threads during
+    first access to the Singleton.
+    """
+
+    def __call__(cls, *args, **kwargs):
+        """
+        Possible changes to the value of the `__init__` argument do not affect
+        the returned instance.
+        """
+        # Now, imagine that the program has just been launched. Since there's no
+        # Singleton instance yet, multiple threads can reach this point almost
+        # at the same time. The first of them will acquire the lock and will
+        # proceed further, while the rest will wait here until the lock is
+        # released.
+        with cls._lock:
+            # The first thread to acquire the lock reaches this conditional,
+            # goes inside and creates the Singleton instance. Once it leaves the
+            # lock block, a thread that might have been waiting for the lock
+            # release may then enter this section. But since the Singleton field
+            # is already initialized, the thread won't create a new object.
+            if cls not in cls._instances:
+                instance = super().__call__(*args, **kwargs)
+                cls._instances[cls] = instance
+
+        return cls._instances[cls]
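
A short sketch of how the metaclass behaves; `AppConfig` is a hypothetical class used only for illustration:

class AppConfig(metaclass=SingletonMeta):
    def __init__(self, value: int = 0):
        self.value = value

first = AppConfig(value=1)
second = AppConfig(value=2)
assert first is second      # only one instance is ever created
assert second.value == 1    # later __init__ arguments are ignored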
financial_bot/chains.py ADDED
@@ -0,0 +1,226 @@
+import time
+from typing import Any, Dict, List, Optional
+
+import qdrant_client
+from langchain import chains
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.chains.base import Chain
+from langchain.llms import HuggingFacePipeline
+from unstructured.cleaners.core import (
+    clean,
+    clean_extra_whitespace,
+    clean_non_ascii_chars,
+    group_broken_paragraphs,
+    replace_unicode_quotes,
+)
+
+from financial_bot.embeddings import EmbeddingModelSingleton
+from financial_bot.template import PromptTemplate
+
+
+class StatelessMemorySequentialChain(chains.SequentialChain):
+    """
+    A sequential chain that uses a stateless memory to store context between calls.
+
+    This chain overrides the _call and prep_outputs methods to load and clear the memory
+    before and after each call, respectively.
+    """
+
+    history_input_key: str = "to_load_history"
+
+    def _call(self, inputs: Dict[str, str], **kwargs) -> Dict[str, str]:
+        """
+        Override _call to load the history before calling the chain.
+
+        This method loads the history from the input dictionary and saves it to the
+        stateless memory. It then updates the inputs dictionary with the memory values
+        and removes the history input key. Finally, it calls the parent _call method
+        with the updated inputs and returns the results.
+        """
+
+        to_load_history = inputs[self.history_input_key]
+        for (
+            human,
+            ai,
+        ) in to_load_history:
+            self.memory.save_context(
+                inputs={self.memory.input_key: human},
+                outputs={self.memory.output_key: ai},
+            )
+        memory_values = self.memory.load_memory_variables({})
+        inputs.update(memory_values)
+
+        del inputs[self.history_input_key]
+
+        return super()._call(inputs, **kwargs)
+
+    def prep_outputs(
+        self,
+        inputs: Dict[str, str],
+        outputs: Dict[str, str],
+        return_only_outputs: bool = False,
+    ) -> Dict[str, str]:
+        """
+        Override prep_outputs to clear the internal memory after each call.
+
+        This method calls the parent prep_outputs method to get the results, then
+        clears the stateless memory and removes the memory key from the results
+        dictionary. It then returns the updated results.
+        """
+
+        results = super().prep_outputs(inputs, outputs, return_only_outputs)
+
+        # Clear the internal memory.
+        self.memory.clear()
+        if self.memory.memory_key in results:
+            results[self.memory.memory_key] = ""
+
+        return results
+
+
+class ContextExtractorChain(Chain):
+    """
+    Encode the question, search the vector store for the top-k matching articles and
+    return their summaries as context from the Alpaca news document collection.
+
+    Attributes:
+    -----------
+    top_k : int
+        The number of top matches to retrieve from the vector store.
+    embedding_model : EmbeddingModelSingleton
+        The embedding model to use for encoding the question.
+    vector_store : qdrant_client.QdrantClient
+        The vector store to search for matches.
+    vector_collection : str
+        The name of the collection to search in the vector store.
+    """
+
+    top_k: int = 1
+    embedding_model: EmbeddingModelSingleton
+    vector_store: qdrant_client.QdrantClient
+    vector_collection: str
+
+    @property
+    def input_keys(self) -> List[str]:
+        return ["about_me", "question"]
+
+    @property
+    def output_keys(self) -> List[str]:
+        return ["context"]
+
+    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        _, quest_key = self.input_keys
+        question_str = inputs[quest_key]
+
+        cleaned_question = self.clean(question_str)
+        # TODO: Instead of cutting the question at 'max_input_length', chunk the question in 'max_input_length' chunks,
+        # pass them through the model and average the embeddings.
+        cleaned_question = cleaned_question[: self.embedding_model.max_input_length]
+        embeddings = self.embedding_model(cleaned_question)
+
+        # TODO: Using the metadata, use the filter to take into consideration only the news from the last 24 hours
+        # (or other time frame).
+        matches = self.vector_store.search(
+            query_vector=embeddings,
+            k=self.top_k,
+            collection_name=self.vector_collection,
+        )
+
+        context = ""
+        for match in matches:
+            context += match.payload["summary"] + "\n"
+
+        return {
+            "context": context,
+        }
+
+    def clean(self, question: str) -> str:
+        """
+        Clean the input question by removing unwanted characters.
+
+        Parameters:
+        -----------
+        question : str
+            The input question to clean.
+
+        Returns:
+        --------
+        str
+            The cleaned question.
+        """
+        question = clean(question)
+        question = replace_unicode_quotes(question)
+        question = clean_non_ascii_chars(question)
+
+        return question
+
+
+class FinancialBotQAChain(Chain):
+    """This custom chain handles LLM generation upon a given prompt."""
+
+    hf_pipeline: HuggingFacePipeline
+    template: PromptTemplate
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Returns a list of input keys for the chain"""
+
+        return ["context"]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Returns a list of output keys for the chain"""
+
+        return ["answer"]
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        """Calls the chain with the given inputs and returns the output"""
+
+        inputs = self.clean(inputs)
+        prompt = self.template.format_infer(
+            {
+                "user_context": inputs["about_me"],
+                "news_context": inputs["context"],
+                "chat_history": inputs["chat_history"],
+                "question": inputs["question"],
+            }
+        )
+
+        start_time = time.time()
+        response = self.hf_pipeline(prompt["prompt"])
+        end_time = time.time()
+        duration_milliseconds = (end_time - start_time) * 1000
+
+        if run_manager:
+            run_manager.on_chain_end(
+                outputs={
+                    "answer": response,
+                },
+                # TODO: Count tokens instead of using len().
+                metadata={
+                    "prompt": prompt["prompt"],
+                    "prompt_template_variables": prompt["payload"],
+                    "prompt_template": self.template.infer_raw_template,
+                    "usage.prompt_tokens": len(prompt["prompt"]),
+                    "usage.total_tokens": len(prompt["prompt"]) + len(response),
+                    "usage.actual_new_tokens": len(response),
+                    "duration_milliseconds": duration_milliseconds,
+                },
+            )
+
+        return {"answer": response}
+
+    def clean(self, inputs: Dict[str, str]) -> Dict[str, str]:
+        """Cleans the inputs by removing extra whitespace and grouping broken paragraphs."""
+
+        for key, input_value in inputs.items():
+            cleaned_input = clean_extra_whitespace(input_value)
+            cleaned_input = group_broken_paragraphs(cleaned_input)
+
+            inputs[key] = cleaned_input
+
+        return inputs
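
To make the flow between the two custom chains easier to follow, here is an illustrative sketch of the payload shapes only (all values are made up, not real data):

# What StatelessMemorySequentialChain receives from the caller:
inputs = {
    "about_me": "I am a retail investor.",        # illustrative
    "question": "Should I buy Tesla shares?",     # illustrative
    "to_load_history": [],                        # list of (human, ai) tuples
}
# ContextExtractorChain adds {"context": "<top-k news summaries>"} to the payload,
# and FinancialBotQAChain finally returns {"answer": "<generated text>"}.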
financial_bot/constants.py ADDED
@@ -0,0 +1,23 @@
+from pathlib import Path
+
+# == Embeddings model ==
+EMBEDDING_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
+EMBEDDING_MODEL_MAX_INPUT_LENGTH = 384
+
+# == Vector Database ==
+VECTOR_DB_OUTPUT_COLLECTION_NAME = "alpaca_financial_news"
+VECTOR_DB_SEARCH_TOPK = 1
+
+# == LLM Model ==
+LLM_MODEL_ID = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
+LLM_QLORA_CHECKPOINT = "plantbased/mistral-7b-instruct-v0.2-4bit"
+
+LLM_INFERENCE_MAX_NEW_TOKENS = 500
+LLM_INFERENCE_TEMPERATURE = 1.0
+
+
+# == Prompt Template ==
+TEMPLATE_NAME = "mistral"
+
+# === Misc ===
+CACHE_DIR = Path.home() / ".cache" / "hands-on-llms"
financial_bot/embeddings.py ADDED
@@ -0,0 +1,123 @@
+import logging
+import traceback
+from typing import Optional, Union
+
+import numpy as np
+from transformers import AutoModel, AutoTokenizer
+
+from financial_bot import constants
+from financial_bot.base import SingletonMeta
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingModelSingleton(metaclass=SingletonMeta):
+    """
+    A singleton class that provides a pre-trained transformer model for generating embeddings of input text.
+
+    Args:
+        model_id (str): The identifier of the pre-trained transformer model to use.
+        max_input_length (int): The maximum length of input text to tokenize.
+        device (str): The device to use for running the model (e.g. "cpu", "cuda").
+        cache_dir (Optional[Path]): The directory to cache the pre-trained model files.
+            If None, the default cache directory is used.
+
+    Attributes:
+        max_input_length (int): The maximum length of input text to tokenize.
+        tokenizer (AutoTokenizer): The tokenizer used to tokenize input text.
+    """
+
+    def __init__(
+        self,
+        model_id: str = constants.EMBEDDING_MODEL_ID,
+        max_input_length: int = constants.EMBEDDING_MODEL_MAX_INPUT_LENGTH,
+        device: str = "cuda:0",
+        cache_dir: Optional[str] = None,
+    ):
+        """
+        Initializes the EmbeddingModelSingleton instance.
+
+        Args:
+            model_id (str): The identifier of the pre-trained transformer model to use.
+            max_input_length (int): The maximum length of input text to tokenize.
+            device (str): The device to use for running the model (e.g. "cpu", "cuda").
+            cache_dir (Optional[Path]): The directory to cache the pre-trained model files.
+                If None, the default cache directory is used.
+        """
+
+        self._model_id = model_id
+        self._device = device
+        self._max_input_length = max_input_length
+
+        self._tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self._model = AutoModel.from_pretrained(
+            model_id,
+            cache_dir=str(cache_dir) if cache_dir else None,
+        ).to(self._device)
+        self._model.eval()
+
+    @property
+    def max_input_length(self) -> int:
+        """
+        Returns the maximum length of input text to tokenize.
+
+        Returns:
+            int: The maximum length of input text to tokenize.
+        """
+
+        return self._max_input_length
+
+    @property
+    def tokenizer(self) -> AutoTokenizer:
+        """
+        Returns the tokenizer used to tokenize input text.
+
+        Returns:
+            AutoTokenizer: The tokenizer used to tokenize input text.
+        """
+
+        return self._tokenizer
+
+    def __call__(
+        self, input_text: str, to_list: bool = True
+    ) -> Union[np.ndarray, list]:
+        """
+        Generates embeddings for the input text using the pre-trained transformer model.
+
+        Args:
+            input_text (str): The input text to generate embeddings for.
+            to_list (bool): Whether to return the embeddings as a list or numpy array. Defaults to True.
+
+        Returns:
+            Union[np.ndarray, list]: The embeddings generated for the input text.
+        """
+
+        try:
+            tokenized_text = self._tokenizer(
+                input_text,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+                max_length=self._max_input_length,
+            ).to(self._device)
+        except Exception:
+            logger.error(traceback.format_exc())
+            logger.error(f"Error tokenizing the following input text: {input_text}")
+
+            return [] if to_list else np.array([])
+
+        try:
+            result = self._model(**tokenized_text)
+        except Exception:
+            logger.error(traceback.format_exc())
+            logger.error(
+                f"Error generating embeddings for the following model_id: {self._model_id} and input text: {input_text}"
+            )
+
+            return [] if to_list else np.array([])
+
+        embeddings = result.last_hidden_state[:, 0, :].cpu().detach().numpy()
+        if to_list:
+            embeddings = embeddings.flatten().tolist()
+
+        return embeddings
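
A hedged usage sketch, run on CPU so no GPU is required (the model weights are fetched from the Hugging Face Hub on first use):

from financial_bot.embeddings import EmbeddingModelSingleton

embedding_model = EmbeddingModelSingleton(device="cpu")
vector = embedding_model("Tesla shares rallied after strong quarterly earnings.")
print(len(vector))  # 384 for sentence-transformers/all-MiniLM-L6-v2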
financial_bot/handlers.py ADDED
@@ -0,0 +1,64 @@
+from typing import Any, Dict, Optional
+
+import comet_llm
+from langchain.callbacks.base import BaseCallbackHandler
+
+from financial_bot import constants
+
+
+class CometLLMMonitoringHandler(BaseCallbackHandler):
+    """
+    A callback handler for monitoring LLM models using Comet.ml.
+
+    Args:
+        project_name (str): The name of the Comet.ml project to log to.
+        llm_model_id (str): The ID of the LLM model to use for inference.
+        llm_qlora_model_id (str): The ID of the PEFT model to use for inference.
+        llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
+        llm_inference_temperature (float): The temperature to use during inference.
+    """
+
+    def __init__(
+        self,
+        project_name: Optional[str] = None,
+        llm_model_id: str = constants.LLM_MODEL_ID,
+        llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
+        llm_inference_max_new_tokens: int = constants.LLM_INFERENCE_MAX_NEW_TOKENS,
+        llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
+    ):
+        self._project_name = project_name
+        self._llm_model_id = llm_model_id
+        self._llm_qlora_model_id = llm_qlora_model_id
+        self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
+        self._llm_inference_temperature = llm_inference_temperature
+
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+        """
+        A callback function that logs the prompt and output to Comet.ml.
+
+        Args:
+            outputs (Dict[str, Any]): The output of the LLM model.
+            **kwargs (Any): Additional arguments passed to the function.
+        """
+
+        should_log_prompt = "metadata" in kwargs
+        if should_log_prompt:
+            metadata = kwargs["metadata"]
+
+            comet_llm.log_prompt(
+                project=self._project_name,
+                prompt=metadata["prompt"],
+                output=outputs["answer"],
+                prompt_template=metadata["prompt_template"],
+                prompt_template_variables=metadata["prompt_template_variables"],
+                metadata={
+                    "usage.prompt_tokens": metadata["usage.prompt_tokens"],
+                    "usage.total_tokens": metadata["usage.total_tokens"],
+                    "usage.max_new_tokens": self._llm_inference_max_new_tokens,
+                    "usage.temperature": self._llm_inference_temperature,
+                    "usage.actual_new_tokens": metadata["usage.actual_new_tokens"],
+                    "model": self._llm_model_id,
+                    "peft_model": self._llm_qlora_model_id,
+                },
+                duration=metadata["duration_milliseconds"],
+            )
financial_bot/langchain_bot.py ADDED
@@ -0,0 +1,223 @@
+import logging
+import os
+from pathlib import Path
+from typing import Iterable, List, Optional, Tuple
+
+from langchain import chains
+from langchain.memory import ConversationBufferWindowMemory
+
+from financial_bot import constants
+from financial_bot.chains import (
+    ContextExtractorChain,
+    FinancialBotQAChain,
+    StatelessMemorySequentialChain,
+)
+from financial_bot.embeddings import EmbeddingModelSingleton
+from financial_bot.handlers import CometLLMMonitoringHandler
+from financial_bot.models import build_huggingface_pipeline
+from financial_bot.qdrant import build_qdrant_client
+from financial_bot.template import get_llm_template
+
+logger = logging.getLogger(__name__)
+
+
+class FinancialBot:
+    """
+    A language chain bot that uses a language model to generate responses to user inputs.
+
+    Args:
+        llm_model_id (str): The ID of the Hugging Face language model to use.
+        llm_qlora_model_id (str): The ID of the Hugging Face QLoRA model to use.
+        llm_template_name (str): The name of the LLM template to use.
+        llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
+        llm_inference_temperature (float): The temperature to use during inference.
+        vector_collection_name (str): The name of the Qdrant vector collection to use.
+        vector_db_search_topk (int): The number of nearest neighbors to search for in the Qdrant vector database.
+        model_cache_dir (Path): The directory to use for caching the language model and embedding model.
+        streaming (bool): Whether to use the Hugging Face streaming API for inference.
+        embedding_model_device (str): The device to use for the embedding model.
+        debug (bool): Whether to enable debug mode.
+
+    Attributes:
+        finbot_chain (Chain): The language chain that generates responses to user inputs.
+    """
+
+    def __init__(
+        self,
+        llm_model_id: str = constants.LLM_MODEL_ID,
+        llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
+        llm_template_name: str = constants.TEMPLATE_NAME,
+        llm_inference_max_new_tokens: int = constants.LLM_INFERENCE_MAX_NEW_TOKENS,
+        llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
+        vector_collection_name: str = constants.VECTOR_DB_OUTPUT_COLLECTION_NAME,
+        vector_db_search_topk: int = constants.VECTOR_DB_SEARCH_TOPK,
+        model_cache_dir: Path = constants.CACHE_DIR,
+        streaming: bool = False,
+        embedding_model_device: str = "cuda:0",
+        debug: bool = False,
+    ):
+        self._llm_model_id = llm_model_id
+        self._llm_qlora_model_id = llm_qlora_model_id
+        self._llm_template_name = llm_template_name
+        self._llm_template = get_llm_template(name=self._llm_template_name)
+        self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
+        self._llm_inference_temperature = llm_inference_temperature
+        self._vector_collection_name = vector_collection_name
+        self._vector_db_search_topk = vector_db_search_topk
+        self._debug = debug
+
+        self._qdrant_client = build_qdrant_client()
+
+        self._embd_model = EmbeddingModelSingleton(
+            cache_dir=model_cache_dir, device=embedding_model_device
+        )
+        self._llm_agent, self._streamer = build_huggingface_pipeline(
+            llm_model_id=llm_model_id,
+            llm_lora_model_id=llm_qlora_model_id,
+            max_new_tokens=llm_inference_max_new_tokens,
+            temperature=llm_inference_temperature,
+            use_streamer=streaming,
+            cache_dir=model_cache_dir,
+            debug=debug,
+        )
+        self.finbot_chain = self.build_chain()
+
+    @property
+    def is_streaming(self) -> bool:
+        return self._streamer is not None
+
+    def build_chain(self) -> chains.SequentialChain:
+        """
+        Constructs and returns a financial bot chain.
+        This chain is designed to take the user description (`about_me`) and a `question` as input, connect to the
+        VectorDB, search for financial news related to the user's question and inject it into the payload that is
+        further passed as a prompt to a fine-tuned financial LLM that provides the answer.
+
+        The chain consists of two primary stages:
+        1. Context Extractor: This stage is responsible for embedding the user's question,
+        which means converting the textual question into a numerical representation.
+        This embedded question is then used to retrieve relevant context from the VectorDB.
+        The output of this chain will be a dict payload.
+
+        2. LLM Generator: Once the context is extracted,
+        this stage uses it to format a full prompt for the LLM and
+        then feeds it to the model to get a response that is relevant to the user's question.
+
+        Returns
+        -------
+        chains.SequentialChain
+            The constructed financial bot chain.
+
+        Notes
+        -----
+        The actual processing flow within the chain can be visualized as:
+        [about: str][question: str] > ContextChain >
+        [about: str][question: str] + [context: str] > FinancialChain >
+        [answer: str]
+        """
+
+        logger.info("Building 1/3 - ContextExtractorChain")
+        context_retrieval_chain = ContextExtractorChain(
+            embedding_model=self._embd_model,
+            vector_store=self._qdrant_client,
+            vector_collection=self._vector_collection_name,
+            top_k=self._vector_db_search_topk,
+        )
+
+        logger.info("Building 2/3 - FinancialBotQAChain")
+        if self._debug:
+            callbacks = []
+        else:
+            try:
+                comet_project_name = os.environ["COMET_PROJECT_NAME"]
+            except KeyError:
+                raise RuntimeError(
+                    "Please set the COMET_PROJECT_NAME environment variable."
+                )
+            callbacks = [
+                CometLLMMonitoringHandler(
+                    project_name=f"{comet_project_name}-monitor-prompts",
+                    llm_model_id=self._llm_model_id,
+                    llm_qlora_model_id=self._llm_qlora_model_id,
+                    llm_inference_max_new_tokens=self._llm_inference_max_new_tokens,
+                    llm_inference_temperature=self._llm_inference_temperature,
+                )
+            ]
+        llm_generator_chain = FinancialBotQAChain(
+            hf_pipeline=self._llm_agent,
+            template=self._llm_template,
+            callbacks=callbacks,
+        )
+
+        logger.info("Building 3/3 - Connecting chains into SequentialChain")
+        seq_chain = StatelessMemorySequentialChain(
+            history_input_key="to_load_history",
+            memory=ConversationBufferWindowMemory(
+                memory_key="chat_history",
+                input_key="question",
+                output_key="answer",
+                k=3,
+            ),
+            chains=[context_retrieval_chain, llm_generator_chain],
+            input_variables=["about_me", "question", "to_load_history"],
+            output_variables=["answer"],
+            verbose=True,
+        )
+
+        logger.info("Done building SequentialChain.")
+        logger.info("Workflow:")
+        logger.info(
+            """
+            [about: str][question: str] > ContextChain >
+            [about: str][question: str] + [context: str] > FinancialChain >
+            [answer: str]
+            """
+        )
+
+        return seq_chain
+
+    def answer(
+        self,
+        about_me: str,
+        question: str,
+        to_load_history: Optional[List[Tuple[str, str]]] = None,
+    ) -> str:
+        """
+        Given a short description about the user and a question, make the LLM
+        generate a response.
+
+        Parameters
+        ----------
+        about_me : str
+            Short user description.
+        question : str
+            User question.
+
+        Returns
+        -------
+        str
+            LLM generated response.
+        """
+
+        inputs = {
+            "about_me": about_me,
+            "question": question,
+            "to_load_history": to_load_history if to_load_history else [],
+        }
+        response = self.finbot_chain.run(inputs)
+
+        return response
+
+    def stream_answer(self) -> Iterable[str]:
+        """Stream the answer from the LLM token by token after calling `answer()`."""
+
+        assert (
+            self.is_streaming
+        ), "Stream answer not available. Build the bot with `streaming=True`."
+
+        partial_answer = ""
+        for new_token in self._streamer:
+            if new_token != self._llm_template.eos:
+                partial_answer += new_token
+
+                yield partial_answer
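
An end-to-end usage sketch under stated assumptions: QDRANT_URL and QDRANT_API_KEY are set in the environment, `debug=True` swaps in the mocked text-generation pipeline (so no GPU model is loaded), and the embedding model is forced onto the CPU:

from financial_bot import initialize
from financial_bot.langchain_bot import FinancialBot

initialize()  # loads the logging config and the .env file
bot = FinancialBot(debug=True, embedding_model_device="cpu")
answer = bot.answer(
    about_me="I am a student with a small budget.",       # illustrative
    question="Is now a good time to buy index funds?",    # illustrative
)
print(answer)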
financial_bot/models.py ADDED
@@ -0,0 +1,264 @@
+import logging
+import os
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import torch
+from comet_ml import API
+from langchain.llms import HuggingFacePipeline
+from peft import LoraConfig, PeftConfig, PeftModel
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    TextIteratorStreamer,
+    pipeline,
+)
+
+from financial_bot import constants
+from financial_bot.utils import MockedPipeline
+
+logger = logging.getLogger(__name__)
+
+
+def download_from_model_registry(
+    model_id: str, cache_dir: Optional[Path] = None
+) -> Path:
+    """
+    Downloads a model from the Comet ML model registry.
+
+    Args:
+        model_id (str): The ID of the model to download, in the format "workspace/model_name:version".
+        cache_dir (Optional[Path]): The directory to cache the downloaded model in. Defaults to the value of
+            `constants.CACHE_DIR`.
+
+    Returns:
+        Path: The path to the downloaded model directory.
+    """
+
+    if cache_dir is None:
+        cache_dir = constants.CACHE_DIR
+    output_folder = cache_dir / "models" / model_id
+
+    already_downloaded = output_folder.exists()
+    if not already_downloaded:
+        workspace, model_id = model_id.split("/")
+        model_name, version = model_id.split(":")
+
+        api = API()
+        model = api.get_model(workspace=workspace, model_name=model_name)
+        model.download(version=version, output_folder=output_folder, expand=True)
+    else:
+        logger.info(f"Model {model_id=} already downloaded to: {output_folder}")
+
+    subdirs = [d for d in output_folder.iterdir() if d.is_dir()]
+    if len(subdirs) == 1:
+        model_dir = subdirs[0]
+    else:
+        raise RuntimeError(
+            f"There should be only one directory inside the model folder. \
+            Check the downloaded model at: {output_folder}"
+        )
+
+    logger.info(f"Model {model_id=} downloaded from the registry to: {model_dir}")
+
+    return model_dir
+
+
+class StopOnTokens(StoppingCriteria):
+    """
+    A stopping criteria that stops generation when a specific token is generated.
+
+    Args:
+        stop_ids (List[int]): A list of token ids that will trigger the stopping criteria.
+    """
+
+    def __init__(self, stop_ids: List[int]):
+        super().__init__()
+
+        self._stop_ids = stop_ids
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ) -> bool:
+        """
+        Check if the last generated token is in the stop_ids list.
+
+        Args:
+            input_ids (torch.LongTensor): The input token ids.
+            scores (torch.FloatTensor): The scores of the generated tokens.
+
+        Returns:
+            bool: True if the last generated token is in the stop_ids list, False otherwise.
+        """
+
+        for stop_id in self._stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+
+        return False
+
+
+def build_huggingface_pipeline(
+    llm_model_id: str,
+    llm_lora_model_id: str,
+    max_new_tokens: int = constants.LLM_INFERENCE_MAX_NEW_TOKENS,
+    temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
+    gradient_checkpointing: bool = False,
+    use_streamer: bool = False,
+    cache_dir: Optional[Path] = None,
+    debug: bool = False,
+) -> Tuple[HuggingFacePipeline, Optional[TextIteratorStreamer]]:
+    """
+    Builds a HuggingFace pipeline for text generation using a custom LLM + fine-tuned checkpoint.
+
+    Args:
+        llm_model_id (str): The ID or path of the LLM model.
+        llm_lora_model_id (str): The ID or path of the LLM LoRA model.
+        max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to `constants.LLM_INFERENCE_MAX_NEW_TOKENS`.
+        temperature (float, optional): The temperature to use for sampling. Defaults to `constants.LLM_INFERENCE_TEMPERATURE`.
+        gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
+        use_streamer (bool, optional): Whether to use a text iterator streamer. Defaults to False.
+        cache_dir (Optional[Path], optional): The directory to use for caching. Defaults to None.
+        debug (bool, optional): Whether to use a mocked pipeline for debugging. Defaults to False.
+
+    Returns:
+        Tuple[HuggingFacePipeline, Optional[TextIteratorStreamer]]: A tuple containing the HuggingFace pipeline
+            and the text iterator streamer (if used).
+    """
+
+    if debug is True:
+        return (
+            HuggingFacePipeline(
+                pipeline=MockedPipeline(f=lambda _: "You are doing great!")
+            ),
+            None,
+        )
+
+    model, tokenizer, _ = build_qlora_model(
+        pretrained_model_name_or_path=llm_model_id,
+        peft_pretrained_model_name_or_path=llm_lora_model_id,
+        gradient_checkpointing=gradient_checkpointing,
+        cache_dir=cache_dir,
+    )
+    model.eval()
+
+    if use_streamer:
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+        )
+        stop_on_tokens = StopOnTokens(stop_ids=[tokenizer.eos_token_id])
+        stopping_criteria = StoppingCriteriaList([stop_on_tokens])
+    else:
+        streamer = None
+        stopping_criteria = StoppingCriteriaList([])
+
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        streamer=streamer,
+        stopping_criteria=stopping_criteria,
+    )
+    hf = HuggingFacePipeline(pipeline=pipe)
+
+    return hf, streamer
+
+
+def build_qlora_model(
+    pretrained_model_name_or_path: str = "tiiuae/falcon-7b-instruct",
+    peft_pretrained_model_name_or_path: Optional[str] = None,
+    gradient_checkpointing: bool = True,
+    cache_dir: Optional[Path] = None,
+) -> Tuple[AutoModelForCausalLM, AutoTokenizer, PeftConfig]:
+    """
+    Function that builds a QLoRA LLM model based on the given HuggingFace name:
+    1. Create and prepare the bitsandbytes configuration for QLoRA's quantization
+    2. Download, load, and quantize on-the-fly the given base model (default: Falcon-7B-Instruct)
+    3. Create and prepare the LoRA configuration
+    4. Load and configure the base model's tokenizer
+
+    Args:
+        pretrained_model_name_or_path (str): The name or path of the pretrained model to use.
+        peft_pretrained_model_name_or_path (Optional[str]): The name or path of the PEFT pretrained model to use.
+        gradient_checkpointing (bool): Whether to use gradient checkpointing or not.
+        cache_dir (Optional[Path]): The directory to cache the downloaded models.
+
+    Returns:
+        Tuple[AutoModelForCausalLM, AutoTokenizer, PeftConfig]:
+            A tuple containing the QLoRA LLM model, tokenizer, and PEFT config.
+    """
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        pretrained_model_name_or_path,
+        revision="main",
+        quantization_config=bnb_config,
+        load_in_4bit=True,
+        device_map="auto",
+        trust_remote_code=False,
+        cache_dir=str(cache_dir) if cache_dir else None,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path,
+        trust_remote_code=False,
+        truncation=True,
+        cache_dir=str(cache_dir) if cache_dir else None,
+    )
+    if tokenizer.pad_token_id is None:
+        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+        with torch.no_grad():
+            model.resize_token_embeddings(len(tokenizer))
+        model.config.pad_token_id = tokenizer.pad_token_id
+
+    if peft_pretrained_model_name_or_path:
+        is_model_name = not os.path.isdir(peft_pretrained_model_name_or_path)
+        if is_model_name:
+            logger.info(
+                f"Downloading {peft_pretrained_model_name_or_path} from the Comet ML model registry:"
+            )
+            peft_pretrained_model_name_or_path = download_from_model_registry(
+                model_id=peft_pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+            )
+
+        logger.info(f"Loading LoRA config from: {peft_pretrained_model_name_or_path}")
+        lora_config = LoraConfig.from_pretrained(peft_pretrained_model_name_or_path)
+        assert (
+            lora_config.base_model_name_or_path == pretrained_model_name_or_path
+        ), f"LoRA model trained on a different base model than the one requested: \
+        {lora_config.base_model_name_or_path} != {pretrained_model_name_or_path}"
+
+        logger.info(f"Loading PEFT model from: {peft_pretrained_model_name_or_path}")
+        model = PeftModel.from_pretrained(model, peft_pretrained_model_name_or_path)
+    else:
+        lora_config = LoraConfig(
+            lora_alpha=16,
+            lora_dropout=0.1,
+            r=64,
+            bias="none",
+            task_type="CAUSAL_LM",
+            target_modules=["query_key_value"],
+        )
+
+    if gradient_checkpointing:
+        model.gradient_checkpointing_enable()
+        model.config.use_cache = (
+            False  # Gradient checkpointing is not compatible with caching.
+        )
+    else:
+        model.gradient_checkpointing_disable()
+        model.config.use_cache = True  # It is good practice to enable caching when using the model for inference.
+
+    return model, tokenizer, lora_config
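
A sketch of the debug path only, which skips the quantized model entirely and wires in the mocked pipeline (the model IDs below are placeholders and are ignored on this branch):

from financial_bot.models import build_huggingface_pipeline

hf_pipeline, streamer = build_huggingface_pipeline(
    llm_model_id="ignored-in-debug",
    llm_lora_model_id="ignored-in-debug",
    debug=True,
)
assert streamer is None
print(hf_pipeline("How is the market doing?"))  # -> "You are doing great!"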
financial_bot/qdrant.py ADDED
@@ -0,0 +1,49 @@
+import logging
+import os
+from typing import Optional
+
+import qdrant_client
+
+logger = logging.getLogger(__name__)
+
+
+def build_qdrant_client(
+    url: Optional[str] = None,
+    api_key: Optional[str] = None,
+):
+    """
+    Builds a Qdrant client object using the provided URL and API key.
+
+    Args:
+        url (Optional[str]): The URL of the Qdrant server. If not provided, the function will attempt
+            to read it from the QDRANT_URL environment variable.
+        api_key (Optional[str]): The API key to use for authentication. If not provided, the function will attempt
+            to read it from the QDRANT_API_KEY environment variable.
+
+    Raises:
+        KeyError: If the URL or API key is not provided and cannot be read from the environment variables.
+
+    Returns:
+        qdrant_client.QdrantClient: A Qdrant client object.
+    """
+
+    logger.info("Building QDrant Client")
+    if url is None:
+        try:
+            url = os.environ["QDRANT_URL"]
+        except KeyError:
+            raise KeyError(
+                "QDRANT_URL must be set as environment variable or manually passed as an argument."
+            )
+
+    if api_key is None:
+        try:
+            api_key = os.environ["QDRANT_API_KEY"]
+        except KeyError:
+            raise KeyError(
+                "QDRANT_API_KEY must be set as environment variable or manually passed as an argument."
+            )
+
+    client = qdrant_client.QdrantClient(url, api_key=api_key)
+
+    return client
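
A usage sketch; the URL and key below are placeholders, not real credentials:

import os

from financial_bot.qdrant import build_qdrant_client

os.environ.setdefault("QDRANT_URL", "https://your-cluster.cloud.qdrant.io:6333")
os.environ.setdefault("QDRANT_API_KEY", "your-api-key")

client = build_qdrant_client()  # reads both values from the environment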
financial_bot/template.py ADDED
@@ -0,0 +1,132 @@
+"""
+This script defines a PromptTemplate class that assists in generating
+conversation/prompt templates. The script facilitates formatting prompts
+for inference and training by combining various context elements and user inputs.
+"""
+
+
+import dataclasses
+from typing import Dict, List, Union
+
+
+@dataclasses.dataclass
+class PromptTemplate:
+    """A class that manages prompt templates"""
+
+    # The name of this template
+    name: str
+    # The template of the system prompt
+    system_template: str = "{system_message}"
+    # The template for the system context
+    context_template: str = "{user_context}\n{news_context}"
+    # The template for the conversation history
+    chat_history_template: str = "{chat_history}"
+    # The template of the user question
+    question_template: str = "{question}"
+    # The template of the system answer
+    answer_template: str = "{answer}"
+    # The system message
+    system_message: str = ""
+    # Separator
+    sep: str = "\n"
+    eos: str = "</s>"
+
+    @property
+    def input_variables(self) -> List[str]:
+        """Returns a list of input variables for the prompt template"""
+
+        return ["user_context", "news_context", "chat_history", "question", "answer"]
+
+    @property
+    def train_raw_template(self):
+        """Returns the training prompt template format"""
+
+        system = self.system_template.format(system_message=self.system_message)
+        context = f"{self.sep}{self.context_template}"
+        chat_history = f"{self.sep}{self.chat_history_template}"
+        question = f"{self.sep}{self.question_template}"
+        answer = f"{self.sep}{self.answer_template}"
+
+        return f"{system}{context}{chat_history}{question}{answer}{self.eos}"
+
+    @property
+    def infer_raw_template(self):
+        """Returns the inference prompt template format"""
+
+        system = self.system_template.format(system_message=self.system_message)
+        context = f"{self.sep}{self.context_template}"
+        chat_history = f"{self.sep}{self.chat_history_template}"
+        question = f"{self.sep}{self.question_template}"
+
+        return f"{system}{context}{chat_history}{question}{self.eos}"
+
+    def format_train(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
+        """Formats the data sample to a training sample"""
+
+        prompt = self.train_raw_template.format(
+            user_context=sample["user_context"],
+            news_context=sample["news_context"],
+            chat_history=sample.get("chat_history", ""),
+            question=sample["question"],
+            answer=sample["answer"],
+        )
+        return {"prompt": prompt, "payload": sample}
+
+    def format_infer(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
+        """Formats the data sample to an inference sample"""
+
+        prompt = self.infer_raw_template.format(
+            user_context=sample["user_context"],
+            news_context=sample["news_context"],
+            chat_history=sample.get("chat_history", ""),
+            question=sample["question"],
+        )
+        return {"prompt": prompt, "payload": sample}
+
+
+# Global Templates registry
+templates: Dict[str, PromptTemplate] = {}
+
+
+def register_llm_template(template: PromptTemplate):
+    """Register a new template to the global templates registry"""
+
+    templates[template.name] = template
+
+
+def get_llm_template(name: str) -> PromptTemplate:
+    """Returns the template assigned to the given name"""
+
+    return templates[name]
+
+
+##### Register Templates #####
+# - Mistral 7B Instruct v0.2 Template
+register_llm_template(
+    PromptTemplate(
+        name="mistral",
+        system_template="<s>{system_message}",
+        system_message="You are a helpful assistant, with financial expertise.",
+        context_template="{user_context}\n{news_context}",
+        chat_history_template="Summary: {chat_history}",
+        question_template="[INST] {question} [/INST]",
+        answer_template="{answer}",
+        sep="\n",
+        eos=" </s>",
+    )
+)
+
+# - FALCON (spec: https://huggingface.co/tiiuae/falcon-7b/blob/main/tokenizer.json)
+register_llm_template(
+    PromptTemplate(
+        name="falcon",
+        system_template=">>INTRODUCTION<< {system_message}",
+        system_message="You are a helpful assistant, with financial expertise.",
+        context_template=">>DOMAIN<< {user_context}\n{news_context}",
+        chat_history_template=">>SUMMARY<< {chat_history}",
+        question_template=">>QUESTION<< {question}",
+        answer_template=">>ANSWER<< {answer}",
+        sep="\n",
+        eos="<|endoftext|>",
+    )
+)
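
A quick sketch of how the registered "mistral" template is used to build an inference prompt (the sample values are illustrative):

from financial_bot.template import get_llm_template

template = get_llm_template("mistral")
result = template.format_infer(
    {
        "user_context": "I am a long-term investor.",
        "news_context": "Tesla beat earnings expectations this quarter.",
        "chat_history": "",
        "question": "Should I hold my Tesla shares?",
    }
)
print(result["prompt"])  # system message, context, history summary and "[INST] ... [/INST]"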
financial_bot/utils.py ADDED
@@ -0,0 +1,106 @@
+import logging
+import os
+import subprocess
+from typing import Callable, Dict, List
+
+import psutil
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def log_available_gpu_memory():
+    """
+    Logs the available GPU memory for each available GPU device.
+
+    If no GPUs are available, logs "No GPUs available".
+
+    Returns:
+        None
+    """
+
+    if torch.cuda.is_available():
+        for i in range(torch.cuda.device_count()):
+            memory_info = subprocess.check_output(
+                f"nvidia-smi -i {i} --query-gpu=memory.free --format=csv,nounits,noheader",
+                shell=True,
+            )
+            memory_info = str(memory_info).split("\\")[0][2:]
+
+            logger.info(f"GPU {i} memory available: {memory_info} MiB")
+    else:
+        logger.info("No GPUs available")
+
+
+def log_available_ram():
+    """
+    Logs the amount of available RAM in gigabytes.
+
+    Returns:
+        None
+    """
+
+    memory_info = psutil.virtual_memory()
+
+    # Convert bytes to GB.
+    logger.info(f"Available RAM: {memory_info.available / (1024.0 ** 3):.2f} GB")
+
+
+def log_files_and_subdirs(directory_path: str):
+    """
+    Logs all files and subdirectories in the specified directory.
+
+    Args:
+        directory_path (str): The path to the directory to log.
+
+    Returns:
+        None
+    """
+
+    # Check if the directory exists.
+    if os.path.exists(directory_path) and os.path.isdir(directory_path):
+        for dirpath, dirnames, filenames in os.walk(directory_path):
+            logger.info(f"Directory: {dirpath}")
+            for filename in filenames:
+                logger.info(f"File: {os.path.join(dirpath, filename)}")
+            for dirname in dirnames:
+                logger.info(f"Sub-directory: {os.path.join(dirpath, dirname)}")
+    else:
+        logger.info(f"The directory '{directory_path}' does not exist")
+
+
+class MockedPipeline:
+    """
+    A mocked pipeline class that is used as a replacement for the HF pipeline class.
+
+    Attributes:
+    -----------
+    task : str
+        The task of the pipeline, which is text-generation.
+    f : Callable[[str], str]
+        A function that takes a prompt string as input and returns a generated text string.
+    """
+
+    task: str = "text-generation"
+
+    def __init__(self, f: Callable[[str], str]):
+        self.f = f
+
+    def __call__(self, prompt: str) -> List[Dict[str, str]]:
+        """
+        Calls the pipeline with a given prompt and returns a list of generated text.
+
+        Parameters:
+        -----------
+        prompt : str
+            The prompt string to generate text from.
+
+        Returns:
+        --------
+        List[Dict[str, str]]
+            A list of dictionaries, where each dictionary contains a generated_text key with the generated text string.
+        """
+
+        result = self.f(prompt)
+
+        return [{"generated_text": f"{prompt}{result}"}]
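
A small sketch of the mocked pipeline that `build_huggingface_pipeline(..., debug=True)` wraps; the lambda is illustrative:

from financial_bot.utils import MockedPipeline

pipe = MockedPipeline(f=lambda prompt: " ...a mocked completion.")
print(pipe("The market today:"))
# -> [{'generated_text': 'The market today: ...a mocked completion.'}]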