Spaces:
Sleeping
Sleeping
import json
import sys

import gradio as gr
from smolagents import CodeAgent, HfApiModel, Tool

# Make the project-local helpers in ./lib importable.
if './lib' not in sys.path:
    sys.path.append('./lib')
from ingestion_chroma import retrieve_info_from_db  # noqa: E402
############################################################################################ | |
################################### TOOLS ################################################## | |
############################################################################################ | |
# Sentinel distinguishing "key absent" from a key whose stored value is None.
_KEY_NOT_FOUND = object()


def _find_key_recursive(data, target_key):
    """Depth-first search for *target_key*; return its value or ``_KEY_NOT_FOUND``."""
    if isinstance(data, dict):
        for key, value in data.items():
            if key == target_key:
                return value
            result = _find_key_recursive(value, target_key)
            # Only stop when the subtree actually contained the key; otherwise
            # keep scanning the remaining siblings.
            if result is not _KEY_NOT_FOUND:
                return result
    elif isinstance(data, (list, tuple)):
        # JSON arrays may hold nested objects that contain the key.
        for item in data:
            result = _find_key_recursive(item, target_key)
            if result is not _KEY_NOT_FOUND:
                return result
    return _KEY_NOT_FOUND


def find_key(data, target_key, default="Indicator not found"):
    """Return the value of the first *target_key* found anywhere in *data*.

    The search is depth-first in insertion order. At each dict level a key is
    compared before its value is descended into. Lists and tuples are searched
    element by element.

    Args:
        data: A (possibly nested) structure of dicts, lists and scalars, such
            as parsed JSON.
        target_key: The key to look for.
        default: Value returned when the key is not present anywhere.

    Returns:
        The value associated with the first matching key (which may itself be
        ``None``), or *default* if no key matches.
    """
    result = _find_key_recursive(data, target_key)
    return default if result is _KEY_NOT_FOUND else result
############################################################################################ | |
class Chroma_retrieverTool(Tool):
    """smolagents tool that retrieves knowledge-base passages by semantic similarity.

    The lookup itself is delegated to ``retrieve_info_from_db`` (imported from
    ``./lib/ingestion_chroma``). This class only validates the query and
    formats the returned passages as one string for the agent.
    """

    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the retrieved passages as one numbered string.

        Args:
            query: Text to compare against the stored documents.

        Returns:
            ``"\\nRetrieval texts : \\n"`` followed by one
            ``===== Text i =====`` section per retrieved document, each
            terminated by a newline. Only the header is returned when
            nothing is retrieved.

        Raises:
            TypeError: If *query* is not a string.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(query, str):
            raise TypeError("The request needs to be a string.")
        query_results = retrieve_info_from_db(query)
        # NOTE(review): assumes a Chroma-style query result in which
        # 'documents' holds one list of texts per query string. A single
        # query is sent, so only the first inner list is used.
        per_query_docs = query_results['documents']
        documents = per_query_docs[0] if per_query_docs else []
        # Terminate each passage with a newline so the next header starts on
        # its own line instead of being glued to the previous text.
        sections = [
            f"===== Text {i} =====\n{doc}\n" for i, doc in enumerate(documents)
        ]
        return "\nRetrieval texts : \n" + "".join(sections)
############################################################################################ | |
class ESRS_info_tool(Tool):
    """smolagents tool that looks up an ESRS indicator in ``./data/dico_esrs.json``."""

    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Return the entry stored under *indicator* in the ESRS JSON file.

        The JSON file is re-read on every call. The key is searched at any
        nesting depth via ``find_key``.

        Args:
            indicator: The indicator name (JSON key) to look up.

        Returns:
            The matching value as a string. Non-string values (e.g. nested
            objects) are serialized as JSON. If the key is absent, the result
            is ``find_key``'s not-found message.

        Raises:
            TypeError: If *indicator* is not a string.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(indicator, str):
            raise TypeError("The request needs to be a string.")
        # Explicit UTF-8 so that non-ASCII labels (ESRS names may contain
        # characters such as the en dash in "E1–5") do not depend on the
        # platform's default encoding.
        with open('./data/dico_esrs.json', encoding='utf-8') as json_data:
            dico_esrs = json.load(json_data)
        result = find_key(dico_esrs, indicator)
        # The matched value may be a nested dict/list. Serialize it so the
        # tool honours its declared output_type of "string".
        if isinstance(result, str):
            return result
        return json.dumps(result, ensure_ascii=False, indent=2)
############################################################################################ | |
############################################################################################ | |
############################################################################################ | |
# LLM backend used by the agent (smolagents HfApiModel wrapping the
# Qwen2.5-Coder-32B-Instruct model).
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

# Tool instances exposed to the agent.
retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()

# Code-writing agent wiring the model to both tools. It is capped at 10
# reasoning steps and allowed to import a small set of extra modules in the
# code it generates.
agent = CodeAgent(
    model=model,
    tools=[get_ESRS_info_tool, retriever_tool],
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime'],
    max_steps=10,
    max_print_outputs_length=16000,
)
def respond(message, history=None):
    """Answer one chat turn by running the agent on the user's message.

    ``gr.ChatInterface`` calls its function as ``fn(message, history)``.
    *history* is accepted to satisfy that contract but is not used; each turn
    is answered independently by ``agent.run``.

    Args:
        message: The user's latest chat message.
        history: Previous chat turns supplied by Gradio (ignored).

    Yields:
        The agent's final answer, converted to ``str`` if it is not one
        already.
    """
    system_prompt_added = (
        "You are an expert in environmental and corporate social responsibility. "
        "You must respond to requests using the query function in the document database.\n"
        "User's question : "
    )
    # Forward the actual user question (previously a hard-coded example
    # question was sent on every turn, ignoring `message`).
    agent_output = agent.run(system_prompt_added + message)
    # The agent's final answer may be any Python object; the chat UI needs text.
    yield agent_output if isinstance(agent_output, str) else str(agent_output)
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
# Chat UI wrapping `respond`; each submitted message triggers one agent run.
demo = gr.ChatInterface(fn=respond)

if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed directly.
    demo.launch()