dify / api /core /tools /tool /dataset_retriever_tool.py
kikuepi's picture
Upload 4913 files
4304c6d verified
raw
history blame contribute delete
No virus
4.28 kB
from typing import Any
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.tool.tool import Tool
class DatasetRetrieverTool(Tool):
    """Adapter that exposes a dataset retriever through the ``Tool`` interface.

    Each instance wraps a single ``DatasetRetrieverBaseTool`` and forwards the
    ``query`` tool parameter to it when invoked.
    """

    # Wrapped retriever that actually performs the lookup. The attribute name
    # (spelling included) is kept as-is because instances are constructed with
    # the ``retrival_tool=`` keyword.
    retrival_tool: DatasetRetrieverBaseTool

    @staticmethod
    def get_dataset_tools(tenant_id: str,
                          dataset_ids: list[str],
                          retrieve_config: DatasetRetrieveConfigEntity,
                          return_resource: bool,
                          invoke_from: InvokeFrom,
                          hit_callback: DatasetIndexToolCallbackHandler
                          ) -> list['DatasetRetrieverTool']:
        """
        Build one ``DatasetRetrieverTool`` per retriever created for the datasets.

        :param tenant_id: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :param dataset_ids: ids of the datasets to create retrievers for
        :param retrieve_config: retrieval configuration; its ``retrieve_strategy``
            is temporarily switched to SINGLE while the retrievers are created
            and is restored afterwards, even if creation fails
        :param return_resource: forwarded to ``to_dataset_retriever_tool``
        :param invoke_from: forwarded to ``to_dataset_retriever_tool``
        :param hit_callback: forwarded to ``to_dataset_retriever_tool``
        :return: the wrapped tools, or an empty list when ``dataset_ids`` is
            empty/None or ``retrieve_config`` is None
        """
        # nothing to build without datasets or a retrieve config
        if not dataset_ids or retrieve_config is None:
            return []

        feature = DatasetRetrieval()

        # Agent only supports SINGLE mode, so force it while the retrievers are
        # created. The config object belongs to the caller, so the original
        # strategy is put back in ``finally`` to avoid leaking the override
        # when ``to_dataset_retriever_tool`` raises.
        original_retriever_mode = retrieve_config.retrieve_strategy
        retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
        try:
            retrival_tools = feature.to_dataset_retriever_tool(
                tenant_id=tenant_id,
                dataset_ids=dataset_ids,
                retrieve_config=retrieve_config,
                return_resource=return_resource,
                invoke_from=invoke_from,
                hit_callback=hit_callback
            )
        finally:
            # restore retrieve strategy
            retrieve_config.retrieve_strategy = original_retriever_mode

        # convert retrieval tools to Tools
        return [
            DatasetRetrieverTool(
                retrival_tool=retrival_tool,
                identity=ToolIdentity(
                    provider='',
                    author='',
                    name=retrival_tool.name,
                    label=I18nObject(en_US='', zh_Hans='')
                ),
                parameters=[],
                is_team_authorization=True,
                description=ToolDescription(
                    human=I18nObject(en_US='', zh_Hans=''),
                    llm=retrival_tool.description
                ),
                runtime=DatasetRetrieverTool.Runtime()
            )
            for retrival_tool in retrival_tools
        ]

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """
        Describe the tool's runtime parameters.

        :return: a single required string parameter named ``query`` that is
            filled in by the LLM
        """
        return [
            ToolParameter(name='query',
                          label=I18nObject(en_US='', zh_Hans=''),
                          human_description=I18nObject(en_US='', zh_Hans=''),
                          type=ToolParameter.ToolParameterType.STRING,
                          form=ToolParameter.ToolParameterForm.LLM,
                          llm_description='Query for the dataset to be used to retrieve the dataset.',
                          required=True,
                          default=''),
        ]

    def tool_provider_type(self) -> ToolProviderType:
        """
        :return: ``ToolProviderType.DATASET_RETRIEVAL``
        """
        return ToolProviderType.DATASET_RETRIEVAL

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        invoke dataset retriever tool

        :param user_id: not used by this tool
        :param tool_parameters: expected to carry the search text under ``query``
        :return: a text message with the retriever's result, or a prompt asking
            for a query when ``query`` is missing or empty
        """
        query = tool_parameters.get('query')
        if not query:
            return self.create_text_message(text='please input query')

        # invoke dataset retriever tool
        result = self.retrival_tool._run(query=query)

        return self.create_text_message(text=result)

    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
        """
        validate the credentials for dataset retriever tool

        This implementation performs no checks and always returns None.
        """