# Page header captured together with this file (not code):
#   author: Matteo-CNPPS | commit message: error_correction | commit: 1f7a3b4
import json
import sys

import gradio as gr
from smolagents import CodeAgent, HfApiModel, Tool

if './lib' not in sys.path :
    sys.path.append('./lib')
from ingestion_chroma import retrieve_info_from_db
############################################################################################
################################### TOOLS ##################################################
############################################################################################
def find_key(data, target_key):
    """Return the value stored under ``target_key`` anywhere in nested dicts.

    The search is depth-first and follows dict insertion order: each key is
    compared with ``target_key`` first, then its value is searched before
    moving on to the next sibling key. Only ``dict`` values are descended
    into; any other value is treated as a leaf.

    Args:
        data: The (possibly nested) dictionary to search. A non-dict value
            simply yields the "not found" result.
        target_key: The key to look for.

    Returns:
        The value of the first matching key, or the string
        ``"Indicator not found"`` when no dict in the structure has that key.
    """
    # Private sentinel for "no match in this subtree". The user-facing
    # "Indicator not found" string must only be produced once, at the top
    # level: returning it from recursive calls would look like a successful
    # hit to the caller and abort the search on the first non-matching sibling.
    missing = object()

    def _search(node):
        if not isinstance(node, dict):
            return missing
        for key, value in node.items():
            if key == target_key:
                return value
            found = _search(value)
            if found is not missing:
                return found
        return missing

    result = _search(data)
    return "Indicator not found" if result is missing else result
############################################################################################
class Chroma_retrieverTool(Tool):
    """Agent tool that returns knowledge-base texts matching a query.

    The lookup itself is delegated to ``retrieve_info_from_db``; this class
    only declares the tool metadata and formats the returned texts.
    """

    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Run the retrieval for ``query`` and return the texts as one string.

        Args:
            query: The search text.

        Returns:
            A header line followed by one "===== Text i =====" section per
            retrieved text, sections separated by a blank line.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        assert isinstance(query, str), "The request needs to be a string."
        query_results = retrieve_info_from_db(query)
        # NOTE(review): assumes 'documents' holds one list of texts per query
        # and that a single query was issued, hence index 0 — confirm against
        # retrieve_info_from_db.
        documents = query_results['documents'][0]
        sections = [
            f"===== Text {i} =====\n{text}" for i, text in enumerate(documents)
        ]
        # Separate sections with a blank line so the next header is not glued
        # to the end of the previous text.
        return "\nRetrieval texts : \n" + "\n\n".join(sections)
############################################################################################
class ESRS_info_tool(Tool):
    """Agent tool that looks up an indicator in the ESRS JSON dictionary."""

    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Return the entry stored under ``indicator`` in ``dico_esrs.json``.

        Args:
            indicator: The key to search for in the (nested) JSON dictionary.

        Returns:
            Whatever ``find_key`` returns for that key: the stored value, or
            its "not found" message when the key is absent.

        Raises:
            AssertionError: If ``indicator`` is not a string.
        """
        assert isinstance(indicator, str), "The request needs to be a string."
        # Read as UTF-8 explicitly: the platform default encoding is
        # locale-dependent and can fail or garble non-ASCII characters.
        with open('./data/dico_esrs.json', encoding='utf-8') as json_data:
            dico_esrs = json.load(json_data)
        return find_key(dico_esrs, indicator)
############################################################################################
############################################################################################
############################################################################################
# Tool instances handed to the agent.
get_ESRS_info_tool = ESRS_info_tool()
retriever_tool = Chroma_retrieverTool()

# Language model driving the agent.
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

# Code agent wired with both tools, a step limit and the extra modules it is
# allowed to import in the code it generates.
agent = CodeAgent(
    model=model,
    tools=[get_ESRS_info_tool, retriever_tool],
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime'],
    max_print_outputs_length=16000,
    max_steps=10,
)
def respond(message, history=None):
    """Answer one chat message by running the agent on it.

    Args:
        message: The user's question as typed in the chat box.
        history: Previous conversation turns. The chat interface passes it as
            a second positional argument; it is accepted for compatibility
            and not used.

    Yields:
        The agent's final answer, converted to a string.
    """
    system_prompt_added = """You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database.
User's question : """
    # Forward the actual user message instead of a fixed, hard-coded question.
    agent_output = agent.run(system_prompt_added + message)
    # The agent's result is not guaranteed to be a string; stringify it so it
    # can always be displayed as a chat reply.
    yield str(agent_output)
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI whose replies are produced by `respond`.
demo = gr.ChatInterface(fn=respond)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()