ai_agents / modules /knowledge_retrieval /destination_chain.py
jpfearnworks's picture
Resolve issue with key logic
2fbdd0c
raw
history blame
2.34 kB
from modules.base.chain import IChain
from modules.base.llm_chain_config import LLMChainConfig
from modules.knowledge_retrieval.base.knowledge_domain import KnowledgeDomain
from modules.settings.user_settings import UserSettings
from typing import Dict , Any, Callable
import os
class DestinationChain(IChain):
    """
    Chain that answers a question by delegating to a KnowledgeDomain.

    Design:
        DestinationChain extends IChain and supplies an implementation of
        ``run``, so it can be used wherever an IChain is expected. Its
        ``knowledge_domain`` attribute is typed against the KnowledgeDomain
        base type rather than a specific subclass.

    Intended Implementation:
        The class is a thin wrapper around a KnowledgeDomain instance:
        ``run`` passes its input straight to
        ``knowledge_domain.generate_response`` and returns whatever that
        call returns.
    """

    # Object whose generate_response() produces the answer returned by run().
    knowledge_domain: KnowledgeDomain
    # API key supplied at construction time.
    api_key: str
    # LLM class or instance; the precise type is not constrained here.
    llm: Any
    # Callable supplied at construction; not invoked by this class itself.
    display: Callable
    # Text describing the chain's usage -- presumably read by a router
    # elsewhere; confirm against callers.
    usage: str

    def run(self, input: str) -> str:
        """Return the knowledge domain's response for ``input``."""
        answer = self.knowledge_domain.generate_response(input)
        return answer
class DestinationChainStrategy(DestinationChain):
    """
    Base class for Chain Strategies.

    Builds a DestinationChain from an LLMChainConfig: the API key is read
    from the shared UserSettings instance, the LLM is instantiated from the
    config, and ``run`` additionally hands each response to ``display``.
    """

    def __init__(self, config: LLMChainConfig, display: Callable, knowledge_domain: KnowledgeDomain, usage: str):
        """
        Initialise the strategy.

        Args:
            config: Supplies ``llm_class``, ``temperature``, ``max_tokens``
                and ``usage``.
            display: Callable invoked with every response produced by ``run``.
            knowledge_domain: Object whose ``generate_response`` answers
                questions.
            usage: Passed to the parent initialiser, then superseded by
                ``config.usage`` below.
        """
        settings = UserSettings.get_instance()
        api_key = settings.get_api_key()
        # The API key is a secret: it must never be echoed to stdout, where
        # it would end up in consoles and captured logs.
        super().__init__(
            api_key=api_key,
            knowledge_domain=knowledge_domain,
            llm=config.llm_class,
            display=display,
            usage=usage,
        )
        # Replace the class reference handed to the parent with a configured
        # instance.
        self.llm = config.llm_class(temperature=config.temperature, max_tokens=config.max_tokens)
        # The config's usage text takes precedence over the ``usage`` argument.
        self.usage = config.usage

    def run(self, question):
        """
        Generate a response for ``question``, display it, and return it.

        Args:
            question: Forwarded unchanged to
                ``knowledge_domain.generate_response``.

        Returns:
            The value returned by ``knowledge_domain.generate_response``.
        """
        response = self.knowledge_domain.generate_response(question)
        self.display(response)
        return response
def get_chain_config(temperature: float = 0.7) -> LLMChainConfig:
    """
    Build the configuration for the default (fallback) destination chain.

    Args:
        temperature: Sampling temperature stored on the returned config.
            Previously this argument was accepted but silently ignored.

    Returns:
        An LLMChainConfig carrying the fallback usage text and the requested
        temperature.
    """
    usage = "This is the default chain model that should only be used as a last resort"
    # Forward the caller's temperature so the parameter actually takes effect.
    return LLMChainConfig(usage=usage, temperature=temperature)