File size: 4,355 Bytes
03f1c64
 
 
 
 
 
 
 
 
 
 
2fbdd0c
03f1c64
 
 
 
2fbdd0c
 
 
03f1c64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fbdd0c
 
 
03f1c64
 
2fbdd0c
03f1c64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from loguru import logger
from modules.knowledge_retrieval.domains.business_domain import BusinessChain, get_business_chain_config
from modules.knowledge_retrieval.domains.family_domain import FamilyChain, get_family_chain_config
from modules.knowledge_retrieval.domains.gaming_domain import GamingChain, get_gaming_chain_config
from modules.knowledge_retrieval.domains.finance_domain import FinanceChain, get_finance_chain_config
from modules.base.llm_chain_config import LLMChainConfig
from modules.knowledge_retrieval.base.router_chain import RouterChain
from modules.knowledge_retrieval.destination_chain import DestinationChainStrategy
from langchain import PromptTemplate, LLMChain

import pprint
from modules.settings.user_settings import UserSettings
from typing import Callable, Dict, Optional, Tuple
import re

class KnowledgeDomainRouter(RouterChain):
    """Route a question to one of four knowledge-domain chains via an LLM.

    The router builds a prompt that lists the ``usage`` text of each domain
    chain (business, family, gaming, finance), asks the LLM to answer with
    the number of the best-fitting domain, and forwards the question to the
    chain registered under that number.
    """

    def __init__(self, config: LLMChainConfig, question: str, display: Callable):
        """Build the domain chains, the routing prompt and the router LLM.

        Args:
            config: Router configuration; ``usage``, ``llm_class``,
                ``temperature`` and ``max_tokens`` are read from it.
            question: The question to route; kept on ``self.question``.
            display: Callback handed to every domain chain.
        """
        settings = UserSettings.get_instance()

        # Keys are the numbers the LLM is asked to answer with.
        chains: Dict[int, DestinationChainStrategy] = {
            1: BusinessChain(config=get_business_chain_config(), display=display),
            2: FamilyChain(config=get_family_chain_config(), display=display),
            3: GamingChain(config=get_gaming_chain_config(), display=display),
            4: FinanceChain(config=get_finance_chain_config(), display=display),
        }

        usage_block = f"""
        1. {chains[1].usage} [1].
        2. {chains[2].usage} [2].
        3. {chains[3].usage} [3].
        4. {chains[4].usage} [4].
        """
        template = """
            Consider the following problem : {question}. Based on the characteristics of the problem,
            identify the most suitable knowledge domain to query from the items provided. Consider each carefully 
            in the context of the question, write out the likelihood of relevant of each, and then select the most 
            appropriate knowledge domain:""" + usage_block + """
                Based on the characteristics of the given problem , select the domain that aligns most closely with the nature of the problem. It is important to first provide the number of the technique that best solves the problem, followed by a period. Then you may provide your reason why you have chosen this technique. 

            The number and name of the selected strategy is...
            """
        api_key = settings.get_api_key()
        # NOTE(review): `display` is neither forwarded to the base class nor
        # stored here, yet run()/determine_and_execute() call `self.display`.
        # Presumably RouterChain provides it - confirm.
        super().__init__(
            api_key=api_key,
            template=template,
            destination_chains=chains,
            usage=config.usage,
            llm=config.llm_class,
            question=question,
        )
        # Reuse the key fetched above instead of querying the settings again.
        self.api_key = api_key
        print("Creating Knowledge Domain Router with config: ")
        self.llm = config.llm_class(
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            api_key=self.api_key,
        )
        self.question: str = question

    def run(self, question: str) -> str:
        """Run ``question`` through every destination chain in turn.

        Each chain's response is printed and passed to ``self.display``.

        Args:
            question: The question handed to each chain's ``run``.

        Returns:
            The response of the last chain that ran, or an empty string when
            no destination chains are registered.
        """
        # Pre-bind so an empty chain mapping does not raise UnboundLocalError.
        response: str = ""
        for index, chain in self.destination_chains.items():
            print(f'Running Chain #{index}')
            self.display(f"Running Chain #{index}")
            response = chain.run(question)
            print(response)
            self.display(response)
        return response

    @staticmethod
    def find_first_integer(text: str) -> Optional[int]:
        """Return the first run of digits in ``text`` as an int, else ``None``."""
        match = re.search(r'\d+', text)
        return int(match.group()) if match else None

    def determine_and_execute(self, question) -> Tuple[str, str]:
        """Ask the LLM to pick a domain, then run the chosen chain.

        Args:
            question: The question to route. Falls back to ``self.question``
                when empty or ``None``.

        Returns:
            A ``(router_response, chain_response)`` tuple. ``chain_response``
            is a "not recognized" message when the first integer in the
            router response does not match a registered chain.
        """
        # Honour the argument; previously it was ignored in favour of
        # self.question.
        query = question or self.question

        prompt = PromptTemplate(template=self.template, input_variables=["question"])
        llm_chain = LLMChain(prompt=prompt, llm=self.llm)

        response = llm_chain.run(query)
        print(response)
        self.display(response)

        # The prompt asks the model to lead with the domain number, so the
        # first integer in the reply is taken as the selection.
        n = self.find_first_integer(response)

        if n in self.destination_chains:
            chain_resp = self.destination_chains[n].run(query)
        else:
            chain_resp = (f"Chain number {n} is not recognized.")
            print(chain_resp)

        return response, chain_resp

def get_knowledge_domain_router_config(temperature: float = 0.6) -> LLMChainConfig:
    """Return the chain configuration used by the knowledge-domain router.

    Args:
        temperature: Temperature value stored in the configuration.

    Returns:
        An ``LLMChainConfig`` carrying the given temperature, a ``max_tokens``
        of 3000 and the router's usage description.
    """
    return LLMChainConfig(
        usage="This router should be used when determining the most effective strategy for a query requiring domain-specific knowledge to derive",
        max_tokens=3000,
        temperature=temperature,
    )