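"""LLMManager: a thin wrapper around the Hugging Face Inference API.

The class reads the available models, their display names and their prompt
templates from config.ini (sections LLM, LLM_Map and Prompt_map), and exposes
helpers to select a model, build a prompt and request text generation.
"""
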
from huggingface_hub import InferenceClient
from dotenv import load_dotenv
import configparser
import os


class LLMManager:
    def __init__(self, settings):

        # Load the Hugging Face token from a .env file, if one is present
        try:
            load_dotenv()
        except Exception:
            print("No .env file found")

        # Initialize the Hugging Face Inference client
        HF_TOKEN = os.environ.get('HF_TOKEN')
        self.client = InferenceClient(token=HF_TOKEN)

        # Create the parser and load the config file
        self.config = configparser.ConfigParser()
        self.config.read("config.ini")

        # Store the application settings
        self.set = settings

        # Default index of the LLM to use
        self.defaultLLM = self.set.defaultLLM

        # Load the available LLMs and their display names
        self.listLLM = self.get_llm()
        self.listLLMMap = self.get_llm_map()

        # Set the current model
        self.currentLLM = self.listLLM[self.defaultLLM]
    
    #Function used to select the LLM by its display name
    def selectLLM(self, llm):
        print("Selected {llm} LLM".format(llm=llm))
        llmIndex = self.listLLMMap.index(llm)
        self.currentLLM = self.listLLM[llmIndex]

    #Function used to get the list of available LLM model identifiers
    def get_llm(self):
        llm_section = 'LLM'
        if llm_section in self.config:
            return [self.config.get(llm_section, llm) for llm in self.config[llm_section]]
        else:
            return []
        
    #Function used to get the list of prompt templates
    def get_llm_prompts(self):
        prompt_section = 'Prompt_map'
        if prompt_section in self.config:
            return [self.config.get(prompt_section, llm) for llm in self.config[prompt_section]]
        else:
            return []
    
    #Function used to get the list of LLM display names (the LLM_Map section)
    def get_llm_map(self):
        llm_map_section = 'LLM_Map'
        if llm_map_section in self.config:
            return [self.config.get(llm_map_section, llm) for llm in self.config[llm_map_section]]
        else:
            return []
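
    # The three accessors above read parallel lists from config.ini. A sketch of
    # the expected layout (the section names come from this file; the keys and
    # values shown are illustrative only):
    #
    #   [LLM]
    #   llm1 = mistralai/Mixtral-8x7B-Instruct-v0.1
    #
    #   [LLM_Map]
    #   llm1 = Mixtral 8x7B
    #
    #   [Prompt_map]
    #   prompt1 = [INST] {sys_prompt} {{context}} {{history}} {{question}} [/INST]
    #
    # Note: get_prompt() formats the template twice, first with sys_prompt and
    # then with context/history/question, so for the template to be applied
    # (rather than falling back to the bare system prompt) the latter
    # placeholders need double braces to survive the first format pass.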
    
    #This function is used to retrieve the reply to a question
    def get_text(self, question):

        print("temp={temp}".format(temp=self.set.temperature))
        print("Repetition={rep}".format(rep=self.set.repetition_penalty))
        generate_kwargs = dict(
            temperature=self.set.temperature,
            max_new_tokens=self.set.max_new_token,
            top_p=self.set.top_p,
            repetition_penalty=self.set.repetition_penalty,
            do_sample=True,
            seed=42,
        )

        # Non-streaming call: the full generated text is returned as a single string
        response = self.client.text_generation(model=self.currentLLM, prompt=question, **generate_kwargs, stream=False, details=False, return_full_text=False)
        return response
    
    #This function is used to retrieve the best search terms for the RAG lookup
    def get_query_terms(self, question):
        generate_kwargs = dict(
            temperature=self.set.RAG_temperature,
            max_new_tokens=self.set.RAG_max_new_token,
            top_p=self.set.RAG_top_p,
            repetition_penalty=self.set.RAG_repetition_penalty,
            do_sample=True,
        )
        response = self.client.text_generation(model=self.currentLLM, prompt=question, **generate_kwargs, stream=False, details=False, return_full_text=False)
        return response
            

    #This function is used to generate the prompt for the LLM
    def get_prompt(self, user_input, rag_contex, chat_history, system_prompt=None):
        """Returns the formatted prompt for the currently selected LLM"""

        prompts = self.get_llm_prompts()
        prompt = ""

        if system_prompt is None:
            system_prompt = self.set.system_prompt
        else:
            print("System prompt set to:\n{sys_prompt}".format(sys_prompt=system_prompt))

        # Apply the model-specific template from the Prompt_map section, falling
        # back to the bare system prompt if no template is defined for the model
        try:
            prompt = prompts[self.listLLM.index(self.currentLLM)].format(sys_prompt=system_prompt)
        except Exception:
            print("Warning: prompt map for {llm} has not been defined".format(llm=self.currentLLM))
            prompt = "{sys_prompt}".format(sys_prompt=system_prompt)

        print("Prompt={pro}".format(pro=prompt))
        return prompt.format(context=rag_contex, history=chat_history, question=user_input)
        

# Example Usage (LLMManager requires a settings object, see the sketch below):
#if __name__ == "__main__":
#    llm_manager = LLMManager(settings)
#    print(llm_manager.config.get('Prompt_map', 'prompt1').format(
#        system_prompt="You are a good AI",
#        history="",
#        context="",
#        question=""))
#    llm_manager.selectLLM("Mixtral 7B")