import logging
import os
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

from langchain import chains
from langchain.memory import ConversationBufferWindowMemory

from financial_bot import constants
from financial_bot.chains import (
    ContextExtractorChain,
    FinancialBotQAChain,
    StatelessMemorySequentialChain,
)
from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.handlers import CometLLMMonitoringHandler
from financial_bot.models import build_huggingface_pipeline
from financial_bot.qdrant import build_qdrant_client
from financial_bot.template import get_llm_template

logger = logging.getLogger(__name__)
class FinancialBot:
    """
    A LangChain-based bot that uses a language model to generate responses to user inputs.

    Args:
        llm_model_id (str): The ID of the Hugging Face language model to use.
        llm_qlora_model_id (str): The ID of the Hugging Face QLoRA model to use.
        llm_template_name (str): The name of the LLM template to use.
        llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
        llm_inference_temperature (float): The temperature to use during inference.
        vector_collection_name (str): The name of the Qdrant vector collection to use.
        vector_db_search_topk (int): The number of nearest neighbors to search for in the Qdrant vector database.
        model_cache_dir (Path): The directory to use for caching the language model and embedding model.
        streaming (bool): Whether to use a Hugging Face streamer during inference.
        embedding_model_device (str): The device to use for the embedding model.
        debug (bool): Whether to enable debug mode.

    Attributes:
        finbot_chain (Chain): The language chain that generates responses to user inputs.
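
    Example
    -------
    A minimal usage sketch; it assumes a reachable Qdrant instance and, unless
    `debug=True`, the Comet environment variables required for monitoring:

    >>> bot = FinancialBot(debug=True)
    >>> bot.answer(
    ...     about_me="I am a 30 year old engineer saving for retirement.",
    ...     question="Should I consider index funds?",
    ... )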
""" | |

    def __init__(
        self,
        llm_model_id: str = constants.LLM_MODEL_ID,
        llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
        llm_template_name: str = constants.TEMPLATE_NAME,
        llm_inference_max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
        llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
        vector_collection_name: str = constants.VECTOR_DB_OUTPUT_COLLECTION_NAME,
        vector_db_search_topk: int = constants.VECTOR_DB_SEARCH_TOPK,
        model_cache_dir: Path = constants.CACHE_DIR,
        streaming: bool = False,
        embedding_model_device: str = "cuda:0",
        debug: bool = False,
    ):
        self._llm_model_id = llm_model_id
        self._llm_qlora_model_id = llm_qlora_model_id
        self._llm_template_name = llm_template_name
        self._llm_template = get_llm_template(name=self._llm_template_name)
        self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
        self._llm_inference_temperature = llm_inference_temperature
        self._vector_collection_name = vector_collection_name
        self._vector_db_search_topk = vector_db_search_topk
        self._debug = debug

        self._qdrant_client = build_qdrant_client()
        self._embd_model = EmbeddingModelSingleton(
            cache_dir=model_cache_dir, device=embedding_model_device
        )
        self._llm_agent, self._streamer = build_huggingface_pipeline(
            llm_model_id=llm_model_id,
            llm_lora_model_id=llm_qlora_model_id,
            max_new_tokens=llm_inference_max_new_tokens,
            temperature=llm_inference_temperature,
            use_streamer=streaming,
            cache_dir=model_cache_dir,
            debug=debug,
        )
        self.finbot_chain = self.build_chain()

    def is_streaming(self) -> bool:
        return self._streamer is not None

    def build_chain(self) -> chains.SequentialChain:
        """
        Constructs and returns the financial bot chain.

        The chain takes the user description (`about_me`) and a `question` as
        input, connects to the vector DB, searches for financial news relevant
        to the user's question, and injects the results into the payload that
        is then passed as a prompt to the fine-tuned financial LLM, which
        generates the answer.

        The chain consists of two primary stages:
        1. Context Extractor: This stage embeds the user's question, i.e.
            converts the textual question into a numerical representation.
            The embedded question is then used to retrieve relevant context
            from the vector DB. The output of this chain is a dict payload.

        2. LLM Generator: Once the context is extracted, this stage uses it to
            format the full prompt for the LLM and then feeds the prompt to
            the model to get a response that is relevant to the user's
            question.

        Returns
        -------
        chains.SequentialChain
            The constructed financial bot chain.

        Notes
        -----
        The actual processing flow within the chain can be visualized as:

        [about: str][question: str] > ContextChain >
        [about: str][question: str] + [context: str] > FinancialChain >
        [answer: str]
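
        Example
        -------
        A hypothetical direct invocation of the returned chain; in normal use
        `answer()` builds this exact input payload:

        >>> bot = FinancialBot(debug=True)
        >>> bot.finbot_chain.run(
        ...     {
        ...         "about_me": "I am a student.",
        ...         "question": "Is gold a good inflation hedge?",
        ...         "to_load_history": [],
        ...     }
        ... )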
""" | |
logger.info("Building 1/3 - ContextExtractorChain") | |
context_retrieval_chain = ContextExtractorChain( | |
embedding_model=self._embd_model, | |
vector_store=self._qdrant_client, | |
vector_collection=self._vector_collection_name, | |
top_k=self._vector_db_search_topk, | |
) | |
logger.info("Building 2/3 - FinancialBotQAChain") | |
if self._debug: | |
callabacks = [] | |
else: | |
try: | |
comet_project_name = os.environ["COMET_PROJECT_NAME"] | |
except KeyError: | |
raise RuntimeError( | |
"Please set the COMET_PROJECT_NAME environment variable." | |
) | |
callabacks = [ | |
CometLLMMonitoringHandler( | |
project_name=f"{comet_project_name}-monitor-prompts", | |
llm_model_id=self._llm_model_id, | |
llm_qlora_model_id=self._llm_qlora_model_id, | |
llm_inference_max_new_tokens=self._llm_inference_max_new_tokens, | |
llm_inference_temperature=self._llm_inference_temperature, | |
) | |
] | |
llm_generator_chain = FinancialBotQAChain( | |
hf_pipeline=self._llm_agent, | |
template=self._llm_template, | |
callbacks=callabacks, | |
) | |
logger.info("Building 3/3 - Connecting chains into SequentialChain") | |
seq_chain = StatelessMemorySequentialChain( | |
history_input_key="to_load_history", | |
memory=ConversationBufferWindowMemory( | |
memory_key="chat_history", | |
input_key="question", | |
output_key="answer", | |
k=3, | |
), | |
chains=[context_retrieval_chain, llm_generator_chain], | |
input_variables=["about_me", "question", "to_load_history"], | |
output_variables=["answer"], | |
verbose=True, | |
) | |
logger.info("Done building SequentialChain.") | |
logger.info("Workflow:") | |
logger.info( | |
""" | |
[about: str][question: str] > ContextChain > | |
[about: str][question:str] + [context: str] > FinancialChain > | |
[answer: str] | |
""" | |
) | |
return seq_chain | |

    def answer(
        self,
        about_me: str,
        question: str,
        to_load_history: Optional[List[Tuple[str, str]]] = None,
    ) -> str:
        """
        Given a short description about the user and a question, make the LLM
        generate a response.

        Parameters
        ----------
        about_me : str
            Short user description.
        question : str
            User question.
        to_load_history : Optional[List[Tuple[str, str]]]
            Previous conversation turns to load into the chain's stateless
            memory, as (question, answer) pairs matching the memory's
            `question`/`answer` keys.

        Returns
        -------
        str
            LLM generated response.
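
        Example
        -------
        A sketch of passing prior turns; the (question, answer) pair format is
        an assumption based on the memory's input/output keys:

        >>> bot.answer(
        ...     about_me="I am a student.",
        ...     question="What about bonds instead?",
        ...     to_load_history=[("Is gold a good inflation hedge?", "...")],
        ... )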
""" | |
        inputs = {
            "about_me": about_me,
            "question": question,
            "to_load_history": to_load_history if to_load_history else [],
        }
        response = self.finbot_chain.run(inputs)

        return response

    def stream_answer(self) -> Iterable[str]:
        """Stream the answer from the LLM token by token while `answer()` runs
        (typically from a separate thread).
        assert (
            self.is_streaming()
        ), "Stream answer not available. Build the bot with `streaming=True`."

        partial_answer = ""
        for new_token in self._streamer:
            if new_token != self._llm_template.eos:
                partial_answer += new_token

                yield partial_answer