from typing import Any, Dict, Optional

import comet_llm
from langchain.callbacks.base import BaseCallbackHandler

from financial_bot import constants


class CometLLMMonitoringHandler(BaseCallbackHandler):
    """
    A callback handler for monitoring LLM models using Comet.ml.

    Args:
        project_name (str): The name of the Comet.ml project to log to.
        llm_model_id (str): The ID of the LLM model used for inference.
        llm_qlora_model_id (str): The ID of the QLoRA PEFT checkpoint used for inference.
        llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
        llm_inference_temperature (float): The sampling temperature used during inference.
    """

    def __init__(
        self,
        project_name: Optional[str] = None,
        llm_model_id: str = constants.LLM_MODEL_ID,
        llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
        llm_inference_max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
        llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
    ):
        self._project_name = project_name
        self._llm_model_id = llm_model_id
        self._llm_qlora_model_id = llm_qlora_model_id
        self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
        self._llm_inference_temperature = llm_inference_temperature

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """
        A callback that logs the prompt and the chain's output to Comet.ml.

        Args:
            outputs (Dict[str, Any]): The outputs of the finished chain.
            **kwargs (Any): Additional arguments passed to the callback.
        """

        # Only log when the chain forwarded its metadata to the callback.
        should_log_prompt = "metadata" in kwargs
        if should_log_prompt:
            metadata = kwargs["metadata"]

            comet_llm.log_prompt(
                project=self._project_name,
                prompt=metadata["prompt"],
                output=outputs["answer"],
                prompt_template=metadata["prompt_template"],
                prompt_template_variables=metadata["prompt_template_variables"],
                metadata={
                    "usage.prompt_tokens": metadata["usage.prompt_tokens"],
                    "usage.total_tokens": metadata["usage.total_tokens"],
                    "usage.max_new_tokens": self._llm_inference_max_new_tokens,
                    "usage.temperature": self._llm_inference_temperature,
                    "usage.actual_new_tokens": metadata["usage.actual_new_tokens"],
                    "model": self._llm_model_id,
                    "peft_model": self._llm_qlora_model_id,
                },
                duration=metadata["duration_milliseconds"],
            )
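
For context, here is a minimal sketch of how the handler could be exercised. The project name, answer text, and metadata values below are illustrative assumptions, not values from the real pipeline; in practice the chain is expected to forward a metadata dict with exactly these keys, and a Comet API key must be configured for comet_llm.log_prompt to succeed.

    # Minimal usage sketch. Assumes COMET_API_KEY is set in the environment;
    # the outputs and metadata values are made up for illustration only.
    handler = CometLLMMonitoringHandler(project_name="financial-bot-monitoring")

    # Invoking the callback directly shows the metadata contract the handler
    # expects from the chain at the end of a run.
    handler.on_chain_end(
        outputs={"answer": "Markets closed mixed; tech stocks led gains."},
        metadata={
            "prompt": "What happened in the markets today?",
            "prompt_template": "{question}",
            "prompt_template_variables": {"question": "What happened in the markets today?"},
            "usage.prompt_tokens": 42,
            "usage.total_tokens": 98,
            "usage.actual_new_tokens": 56,
            "duration_milliseconds": 850,
        },
    )

In normal operation the handler is not called by hand; it is passed in the callbacks list when the chain is invoked, and LangChain fires on_chain_end once the chain finishes.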