import logging
from dataclasses import dataclass, field
from functools import lru_cache

import numpy as np
import pandas as pd
from openai.embeddings_utils import cosine_similarity, get_embedding

from buster.completers import completer_factory
from buster.completers.base import Completion
from buster.formatters.prompts import SystemPromptFormatter, prompt_formatter_factory
from buster.retriever import Retriever

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


@dataclass(slots=True)
class Response:
    """Result of processing one user input.

    Attributes:
        completion: The completion produced for the user input (or a
            placeholder completion when no documents matched).
        is_relevant: False when no documents matched, or when the completion
            was judged too similar to the configured "unknown" prompt.
        matched_documents: Documents that passed the similarity threshold;
            an empty frame (same columns) when the response is not relevant.
    """

    completion: Completion
    is_relevant: bool
    matched_documents: pd.DataFrame | None = None


@dataclass
class BusterConfig:
    """Configuration object for a chatbot.

    Attributes:
        embedding_model: Engine name passed to ``get_embedding``.
        unknown_threshold: Cosine-similarity cut-off used by
            ``Buster.check_response_relevance``; a completion scoring at or
            above it against ``unknown_prompt`` is flagged as not relevant.
        unknown_prompt: Reference "I don't know" text whose embedding the
            completions are compared against.
        document_source: Passed as ``source`` to the retriever.
        retriever_cfg: ``top_k`` is forwarded to the retriever; ``thresh`` is
            the minimum ``similarity`` a retrieved document must exceed.
        prompt_cfg: Forwarded as-is to ``prompt_formatter_factory``.
        completion_cfg: Forwarded as-is to ``completer_factory``.
    """

    embedding_model: str = "text-embedding-ada-002"
    unknown_threshold: float = 0.9
    unknown_prompt: str = "I Don't know how to answer your question."
    document_source: str = ""
    retriever_cfg: dict = field(
        default_factory=lambda: {
            "top_k": 3,
            "thresh": 0.7,
        }
    )
    prompt_cfg: dict = field(
        default_factory=lambda: {
            "max_words": 3000,
            "text_before_documents": "You are a chatbot answering questions.\n",
            "text_before_prompt": "Answer the following question:\n",
        }
    )
    completion_cfg: dict = field(
        default_factory=lambda: {
            "name": "ChatGPT",
            "completion_kwargs": {
                "engine": "gpt-3.5-turbo",
                "max_tokens": 200,
                "temperature": None,
                "top_p": None,
                "frequency_penalty": 1,
                "presence_penalty": 1,
            },
        }
    )


class Buster:
    """Chatbot that retrieves documents, prompts a completer and checks relevance."""

    def __init__(self, cfg: BusterConfig, retriever: Retriever):
        """Store the retriever and apply ``cfg`` (see ``update_cfg``)."""
        self._unk_embedding = None
        self.update_cfg(cfg)
        self.retriever = retriever

    @property
    def unk_embedding(self):
        """Embedding of the configured unknown prompt."""
        return self._unk_embedding

    @unk_embedding.setter
    def unk_embedding(self, embedding):
        logger.info("Setting new UNK embedding...")
        self._unk_embedding = embedding

    def update_cfg(self, cfg: BusterConfig):
        """Every time we set a new config, we update the things that need to be updated.

        Copies the config fields onto the instance, recomputes the embedding of
        the unknown prompt, and rebuilds the completer and prompt formatter.
        """
        logger.info(f"Updating config to {cfg.document_source}:\n{cfg}")
        self._cfg = cfg
        self.embedding_model = cfg.embedding_model
        self.unknown_threshold = cfg.unknown_threshold
        self.unknown_prompt = cfg.unknown_prompt
        self.document_source = cfg.document_source

        self.retriever_cfg = cfg.retriever_cfg
        self.completion_cfg = cfg.completion_cfg
        self.prompt_cfg = cfg.prompt_cfg

        # set the unk. embedding
        self.unk_embedding = self.get_embedding(self.unknown_prompt, engine=self.embedding_model)

        # update completer and formatter cfg
        self.completer = completer_factory(self.completion_cfg)
        self.prompt_formatter = prompt_formatter_factory(self.prompt_cfg)

        logger.info("Config Updated.")

    def _compute_embedding(self, query: str, engine: str):
        """Request the embedding of ``query`` from ``engine`` (uncached)."""
        logger.info("generating embedding")
        return get_embedding(query, engine=engine)

    def get_embedding(self, query: str, engine: str):
        """Return the embedding of ``query`` for ``engine``, memoized per instance.

        The cache is created lazily and stored on the instance rather than on
        the class: a class-level ``lru_cache`` on a method is keyed on ``self``
        and would keep every instance alive for as long as its entries stay
        in the cache.
        """
        try:
            cached_embedding = self._embedding_cache
        except AttributeError:
            # maxsize=128 matches the default of a bare ``@lru_cache``.
            cached_embedding = lru_cache(maxsize=128)(self._compute_embedding)
            self._embedding_cache = cached_embedding
        return cached_embedding(query, engine)

    def rank_documents(
        self,
        query: str,
        top_k: int,
        thresh: float,
        engine: str,
        source: str,
    ) -> pd.DataFrame:
        """
        Compare the question to the series of documents and return the best matching documents.

        Args:
            query: Text to embed and look up.
            top_k: Forwarded to the retriever as ``top_k``.
            thresh: Only documents whose ``similarity`` is strictly greater
                than this value are kept.
            engine: Embedding engine used for the query.
            source: Forwarded to the retriever as ``source``.

        Returns:
            The retrieved documents that passed the threshold (possibly empty).
        """

        query_embedding = self.get_embedding(
            query,
            engine=engine,
        )
        matched_documents = self.retriever.retrieve(query_embedding, top_k=top_k, source=source)

        # log matched_documents to the console
        logger.info(f"matched documents before thresh: {matched_documents}")

        # filter out matched_documents using a threshold
        matched_documents = matched_documents[matched_documents.similarity > thresh]
        logger.info(f"matched documents after thresh: {matched_documents}")

        return matched_documents

    def check_response_relevance(
        self,
        completion_text: str,
        engine: str,
        unk_embedding: np.ndarray | list[float],
        unk_threshold: float,
    ) -> bool:
        """Check to see if a response is relevant to the chatbot's knowledge or not.

        We assume we've prompt-engineered our bot to say a response is unrelated to the context if it isn't relevant.
        Here, we compare the embedding of the response to the embedding of the prompt-engineered "I don't know" embedding.

        The response is considered relevant when its cosine similarity to the
        "I don't know" embedding is strictly below ``unk_threshold``. Because
        cosine similarity never exceeds 1, set ``unk_threshold`` to a value
        greater than 1 to essentially turn off this feature (a threshold of 0
        would instead flag nearly every response as not relevant).

        Args:
            completion_text: Text of the generated completion.
            engine: Embedding engine used for the completion text.
            unk_embedding: Embedding of the "I don't know" prompt.
            unk_threshold: Similarity cut-off described above.

        Returns:
            True if the completion is judged relevant, False otherwise.
        """
        response_embedding = self.get_embedding(
            completion_text,
            engine=engine,
        )
        score = cosine_similarity(response_embedding, unk_embedding)
        logger.info(f"UNK score: {score}")

        # Likely that the answer is meaningful, add the top sources.
        # bool() so callers get a plain Python bool rather than a numpy scalar.
        return bool(score < unk_threshold)

    def process_input(self, user_input: str) -> Response:
        """
        Main function to process the input question and generate a formatted output.

        Retrieves matching documents, builds the system prompt from them, asks
        the completer for an answer, and checks that answer for relevance.

        Args:
            user_input: The user's question.

        Returns:
            A ``Response``. When no documents match, or the completion is
            judged not relevant, ``matched_documents`` is an empty frame and
            ``is_relevant`` is False.
        """

        logger.info(f"User Input:\n{user_input}")

        # We make sure there is always a newline at the end of the question to avoid completing the question.
        if not user_input.endswith("\n"):
            user_input += "\n"

        matched_documents = self.rank_documents(
            query=user_input,
            top_k=self.retriever_cfg["top_k"],
            thresh=self.retriever_cfg["thresh"],
            engine=self.embedding_model,
            source=self.document_source,
        )

        if len(matched_documents) == 0:
            # Nothing passed the threshold: skip the completer entirely.
            logger.warning("No documents found...")
            completion = Completion(text="No documents found.")
            matched_documents = pd.DataFrame(columns=matched_documents.columns)
            response = Response(completion=completion, matched_documents=matched_documents, is_relevant=False)
            return response

        # prepare the prompt
        system_prompt = self.prompt_formatter.format(matched_documents)
        completion: Completion = self.completer.generate_response(user_input=user_input, system_prompt=system_prompt)
        logger.info(f"GPT Response:\n{completion.text}")

        # check for relevance
        is_relevant = self.check_response_relevance(
            completion_text=completion.text,
            engine=self.embedding_model,
            unk_embedding=self.unk_embedding,
            unk_threshold=self.unknown_threshold,
        )
        if not is_relevant:
            # The answer generated was the chatbot saying it doesn't know how
            # to answer, so no sources are attached. The completion text
            # itself is returned unchanged.
            matched_documents = pd.DataFrame(columns=matched_documents.columns)

        response = Response(completion=completion, matched_documents=matched_documents, is_relevant=is_relevant)

        return response