# financial_bot/langchain_bot.py
import logging
import os
from pathlib import Path
from typing import Iterable, List, Optional, Tuple
from langchain import chains
from langchain.memory import ConversationBufferWindowMemory
from financial_bot import constants
from financial_bot.chains import (
ContextExtractorChain,
FinancialBotQAChain,
StatelessMemorySequentialChain,
)
from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.handlers import CometLLMMonitoringHandler
from financial_bot.models import build_huggingface_pipeline
from financial_bot.qdrant import build_qdrant_client
from financial_bot.template import get_llm_template
logger = logging.getLogger(__name__)
class FinancialBot:
"""
    A LangChain-based bot that uses a language model to generate responses to user inputs.
Args:
llm_model_id (str): The ID of the Hugging Face language model to use.
llm_qlora_model_id (str): The ID of the Hugging Face QLora model to use.
llm_template_name (str): The name of the LLM template to use.
llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
llm_inference_temperature (float): The temperature to use during inference.
vector_collection_name (str): The name of the Qdrant vector collection to use.
vector_db_search_topk (int): The number of nearest neighbors to search for in the Qdrant vector database.
model_cache_dir (Path): The directory to use for caching the language model and embedding model.
streaming (bool): Whether to use the Hugging Face streaming API for inference.
embedding_model_device (str): The device to use for the embedding model.
debug (bool): Whether to enable debug mode.
Attributes:
finbot_chain (Chain): The language chain that generates responses to user inputs.
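
    Example:
        A minimal usage sketch; it assumes the default constants resolve to
        available model checkpoints and a reachable Qdrant instance, and that
        `debug=True` is acceptable (it skips the Comet monitoring callback):

        >>> bot = FinancialBot(debug=True)
        >>> bot.answer(
        ...     about_me="I am a 30 year old engineer saving for retirement.",
        ...     question="Should I invest in index funds?",
        ... )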
"""
def __init__(
self,
llm_model_id: str = constants.LLM_MODEL_ID,
llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
llm_template_name: str = constants.TEMPLATE_NAME,
llm_inference_max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
vector_collection_name: str = constants.VECTOR_DB_OUTPUT_COLLECTION_NAME,
vector_db_search_topk: int = constants.VECTOR_DB_SEARCH_TOPK,
model_cache_dir: Path = constants.CACHE_DIR,
streaming: bool = False,
embedding_model_device: str = "cuda:0",
debug: bool = False,
):
self._llm_model_id = llm_model_id
self._llm_qlora_model_id = llm_qlora_model_id
self._llm_template_name = llm_template_name
self._llm_template = get_llm_template(name=self._llm_template_name)
self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
self._llm_inference_temperature = llm_inference_temperature
self._vector_collection_name = vector_collection_name
self._vector_db_search_topk = vector_db_search_topk
self._debug = debug
self._qdrant_client = build_qdrant_client()
self._embd_model = EmbeddingModelSingleton(
cache_dir=model_cache_dir, device=embedding_model_device
)
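        # Load the base LLM (optionally with the QLoRA adapter attached); when
        # `streaming` is enabled, a token streamer is returned alongside the
        # pipeline, otherwise `self._streamer` is None.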
self._llm_agent, self._streamer = build_huggingface_pipeline(
llm_model_id=llm_model_id,
llm_lora_model_id=llm_qlora_model_id,
max_new_tokens=llm_inference_max_new_tokens,
temperature=llm_inference_temperature,
use_streamer=streaming,
cache_dir=model_cache_dir,
debug=debug,
)
self.finbot_chain = self.build_chain()
@property
def is_streaming(self) -> bool:
return self._streamer is not None
def build_chain(self) -> chains.SequentialChain:
"""
Constructs and returns a financial bot chain.
        This chain is designed to take the user description (`about_me`) and a `question` as input. It
        connects to the VectorDB, searches for financial news relevant to the user's question, and injects
        the results into the payload that is passed as a prompt to the financially fine-tuned LLM, which generates the answer.
The chain consists of two primary stages:
1. Context Extractor: This stage is responsible for embedding the user's question,
which means converting the textual question into a numerical representation.
This embedded question is then used to retrieve relevant context from the VectorDB.
The output of this chain will be a dict payload.
2. LLM Generator: Once the context is extracted,
this stage uses it to format a full prompt for the LLM and
        then feeds it to the model to get a response that is relevant to the user's question.
Returns
-------
chains.SequentialChain
The constructed financial bot chain.
Notes
-----
The actual processing flow within the chain can be visualized as:
[about: str][question: str] > ContextChain >
[about: str][question:str] + [context: str] > FinancialChain >
[answer: str]
"""
logger.info("Building 1/3 - ContextExtractorChain")
context_retrieval_chain = ContextExtractorChain(
embedding_model=self._embd_model,
vector_store=self._qdrant_client,
vector_collection=self._vector_collection_name,
top_k=self._vector_db_search_topk,
)
logger.info("Building 2/3 - FinancialBotQAChain")
if self._debug:
            callbacks = []
else:
try:
comet_project_name = os.environ["COMET_PROJECT_NAME"]
except KeyError:
raise RuntimeError(
"Please set the COMET_PROJECT_NAME environment variable."
)
            callbacks = [
CometLLMMonitoringHandler(
project_name=f"{comet_project_name}-monitor-prompts",
llm_model_id=self._llm_model_id,
llm_qlora_model_id=self._llm_qlora_model_id,
llm_inference_max_new_tokens=self._llm_inference_max_new_tokens,
llm_inference_temperature=self._llm_inference_temperature,
)
]
llm_generator_chain = FinancialBotQAChain(
hf_pipeline=self._llm_agent,
template=self._llm_template,
            callbacks=callbacks,
)
logger.info("Building 3/3 - Connecting chains into SequentialChain")
seq_chain = StatelessMemorySequentialChain(
history_input_key="to_load_history",
memory=ConversationBufferWindowMemory(
memory_key="chat_history",
input_key="question",
output_key="answer",
k=3,
),
chains=[context_retrieval_chain, llm_generator_chain],
input_variables=["about_me", "question", "to_load_history"],
output_variables=["answer"],
verbose=True,
)
logger.info("Done building SequentialChain.")
logger.info("Workflow:")
logger.info(
"""
[about: str][question: str] > ContextChain >
[about: str][question:str] + [context: str] > FinancialChain >
[answer: str]
"""
)
return seq_chain
def answer(
self,
about_me: str,
question: str,
        to_load_history: Optional[List[Tuple[str, str]]] = None,
) -> str:
"""
        Given a short description of the user and a question, make the LLM
        generate a response.
Parameters
----------
about_me : str
Short user description.
question : str
            User question.
        to_load_history : Optional[List[Tuple[str, str]]], optional
            Previous conversation turns, given as (question, answer) pairs, used
            to seed the chat history for this call.
Returns
-------
str
LLM generated response.
"""
inputs = {
"about_me": about_me,
"question": question,
"to_load_history": to_load_history if to_load_history else [],
}
response = self.finbot_chain.run(inputs)
return response
def stream_answer(self) -> Iterable[str]:
"""Stream the answer from the LLM after each token is generated after calling `answer()`."""
assert (
self.is_streaming
), "Stream answer not available. Build the bot with `use_streamer=True`."
partial_answer = ""
for new_token in self._streamer:
if new_token != self._llm_template.eos:
partial_answer += new_token
yield partial_answer
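

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the default constants resolve to
    # downloadable checkpoints, a Qdrant instance is reachable with the configured
    # credentials, and a CUDA device is available for the embedding model.
    # `debug=True` skips the Comet monitoring callback (and may swap in a
    # lightweight pipeline, depending on `build_huggingface_pipeline`), so
    # COMET_PROJECT_NAME does not have to be set.
    logging.basicConfig(level=logging.INFO)

    bot = FinancialBot(debug=True)
    response = bot.answer(
        about_me="I am a 30 year old engineer saving for retirement.",
        question="Is now a good time to rebalance my portfolio into index funds?",
    )
    print(response)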