File size: 4,562 Bytes
097caae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from huggingface_hub import InferenceClient
from dotenv import load_dotenv
import configparser
import os
class LLMManager:
    """Manages HuggingFace LLM selection, configuration and prompt building.

    Model identifiers, display names and prompt templates are read from the
    ``[LLM]``, ``[LLM_Map]`` and ``[Prompt_map]`` sections of ``config.ini``;
    the three lists are parallel (same key order within each section).
    """

    def __init__(self, settings):
        """Initialise the inference client, read config.ini and pick the default LLM.

        Args:
            settings: object exposing generation settings — ``defaultLLM``
                (index into the model list), ``temperature``, ``max_new_token``,
                ``top_p``, ``repetition_penalty``, ``system_prompt`` and their
                ``RAG_*`` counterparts.
        """
        # Load the HF token from a .env file when present; a missing .env is a
        # normal condition (the token may already be in the environment), so
        # this is best-effort.  Narrowed from a bare `except:` which would
        # also have swallowed KeyboardInterrupt/SystemExit.
        try:
            load_dotenv()
        except Exception:
            print("No .env file")
        # Initialise the HuggingFace Inference client (token may be None,
        # in which case the client runs unauthenticated).
        hf_token = os.environ.get('HF_TOKEN')
        self.client = InferenceClient(token=hf_token)
        # Create and load the config file.
        self.config = configparser.ConfigParser()
        self.config.read("config.ini")
        self.set = settings
        # Index of the LLM selected at start-up.
        self.defaultLLM = self.set.defaultLLM
        # Parallel lists: model identifiers and their display names.
        self.listLLM = self.get_llm()
        self.listLLMMap = self.get_llm_map()
        # Currently selected model identifier.
        self.currentLLM = self.listLLM[self.defaultLLM]

    def selectLLM(self, llm):
        """Select the current model by its display name.

        Args:
            llm: display name as listed in the ``[LLM_Map]`` section.

        Raises:
            ValueError: if ``llm`` is not a known display name.
        """
        # BUG FIX: the original was missing the f-prefix and printed the
        # literal text "Selected {llm} LLM".
        print(f"Selected {llm} LLM")
        self.currentLLM = self.listLLM[self.listLLMMap.index(llm)]

    def _section_values(self, section):
        """Return every value of *section* in config.ini, or [] if absent."""
        if section in self.config:
            return [self.config.get(section, key) for key in self.config[section]]
        return []

    def get_llm(self):
        """Return the list of available model identifiers ([LLM] section)."""
        return self._section_values('LLM')

    def get_llm_prompts(self):
        """Return the list of prompt templates ([Prompt_map] section)."""
        return self._section_values('Prompt_map')

    def get_llm_map(self):
        """Return the list of model display names ([LLM_Map] section)."""
        return self._section_values('LLM_Map')

    def get_text(self, question):
        """Generate a reply to *question* with the currently selected LLM.

        Returns:
            The generated text (non-streaming call).
        """
        print(f"temp={self.set.temperature}")
        print(f"Repetition={self.set.repetition_penalty}")
        generate_kwargs = dict(
            temperature=self.set.temperature,
            max_new_tokens=self.set.max_new_token,
            top_p=self.set.top_p,
            repetition_penalty=self.set.repetition_penalty,
            do_sample=True,
            seed=42,  # fixed seed for reproducible sampling
        )
        return self.client.text_generation(
            model=self.currentLLM, prompt=question, **generate_kwargs,
            stream=False, details=False, return_full_text=False)

    def get_query_terms(self, question):
        """Generate search terms for RAG retrieval using the RAG_* settings.

        Returns:
            The generated text (non-streaming call).
        """
        generate_kwargs = dict(
            temperature=self.set.RAG_temperature,
            max_new_tokens=self.set.RAG_max_new_token,
            top_p=self.set.RAG_top_p,
            repetition_penalty=self.set.RAG_repetition_penalty,
            do_sample=True,
        )
        return self.client.text_generation(
            model=self.currentLLM, prompt=question, **generate_kwargs,
            stream=False, details=False, return_full_text=False)

    def get_prompt(self, user_input, rag_contex, chat_history, system_prompt=None):
        """Return the formatted prompt for the current LLM.

        Args:
            user_input: the user's question ({question} placeholder).
            rag_contex: retrieved context ({context} placeholder).
            chat_history: conversation so far ({history} placeholder).
            system_prompt: overrides ``settings.system_prompt`` when given.

        Returns:
            The template for the current model with sys_prompt/context/
            history/question substituted.
        """
        prompts = self.get_llm_prompts()
        if system_prompt is None:
            system_prompt = self.set.system_prompt
        else:
            print(f"System prompt set to : \n {system_prompt}")
        try:
            # NOTE(review): a template that contains bare {context}/{history}/
            # {question} placeholders raises KeyError here and falls through to
            # the fallback below — confirm templates in config.ini escape them
            # as {{context}} etc. for the second-stage format to see them.
            prompt = prompts[self.listLLM.index(self.currentLLM)].format(
                sys_prompt=system_prompt)
        except Exception:
            # No (or unusable) template for this model: use the bare system prompt.
            print(f"Warning prompt map for {self.currentLLM} has not been defined")
            prompt = f"{system_prompt}"
        print(f"Prompt={prompt}")
        return prompt.format(context=rag_contex, history=chat_history,
                             question=user_input)
# Example Usage:
#if __name__ == "__main__":
# llm_manager = LLMManager()
# print(llm_manager.config.get('Prompt_map', 'prompt1').format(
# system_prompt="Sei una brava IA",
# history="",
# context="",
# question=""))
#llm_manager.selectLLM("Mixtral 7B") |