|
from typing import Any
|
|
|
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
|
from core.tools.entities.common_entities import I18nObject
|
|
from core.tools.entities.tool_entities import (
|
|
ToolDescription,
|
|
ToolIdentity,
|
|
ToolInvokeMessage,
|
|
ToolParameter,
|
|
ToolProviderType,
|
|
)
|
|
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
|
from core.tools.tool.tool import Tool
|
|
|
|
|
|
class DatasetRetrieverTool(Tool):
    """Tool wrapper that exposes a ``DatasetRetrieverBaseTool`` through the ``Tool`` interface.

    Each instance delegates the actual retrieval to ``retrival_tool`` and
    returns the result as a text message.
    """

    # The wrapped retriever that performs the query (field name kept as-is,
    # since it is part of the constructor's keyword interface).
    retrival_tool: DatasetRetrieverBaseTool

    @staticmethod
    def get_dataset_tools(tenant_id: str,
                          dataset_ids: list[str],
                          retrieve_config: DatasetRetrieveConfigEntity,
                          return_resource: bool,
                          invoke_from: InvokeFrom,
                          hit_callback: DatasetIndexToolCallbackHandler
                          ) -> list['DatasetRetrieverTool']:
        """Build one ``DatasetRetrieverTool`` per retriever produced for the given datasets.

        Args:
            tenant_id: Forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``.
            dataset_ids: Datasets to build tools for; ``None`` or empty yields ``[]``.
            retrieve_config: Retrieval configuration; ``None`` yields ``[]``. Its
                ``retrieve_strategy`` is temporarily set to ``SINGLE`` while the
                retrievers are built and is restored before returning.
            return_resource: Forwarded to ``to_dataset_retriever_tool``.
            invoke_from: Forwarded to ``to_dataset_retriever_tool``.
            hit_callback: Forwarded to ``to_dataset_retriever_tool``.

        Returns:
            A list of wrapped tools, one per retriever returned by
            ``to_dataset_retriever_tool``.
        """
        if not dataset_ids:
            return []
        if retrieve_config is None:
            return []

        feature = DatasetRetrieval()

        # Force the SINGLE strategy only for the duration of tool construction.
        # The restore is in ``finally`` so the caller's config object is not
        # left mutated if building the retrievers raises.
        original_strategy = retrieve_config.retrieve_strategy
        retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
        try:
            retrieval_tools = feature.to_dataset_retriever_tool(
                tenant_id=tenant_id,
                dataset_ids=dataset_ids,
                retrieve_config=retrieve_config,
                return_resource=return_resource,
                invoke_from=invoke_from,
                hit_callback=hit_callback
            )
        finally:
            retrieve_config.retrieve_strategy = original_strategy

        # Wrap every retriever; name and LLM description come from the retriever itself,
        # all human-facing labels are left empty.
        return [
            DatasetRetrieverTool(
                retrival_tool=retriever,
                identity=ToolIdentity(
                    provider='',
                    author='',
                    name=retriever.name,
                    label=I18nObject(en_US='', zh_Hans=''),
                ),
                parameters=[],
                is_team_authorization=True,
                description=ToolDescription(
                    human=I18nObject(en_US='', zh_Hans=''),
                    llm=retriever.description,
                ),
                runtime=DatasetRetrieverTool.Runtime()
            )
            for retriever in retrieval_tools
        ]

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """Return the tool's runtime parameters: a single required string ``query``."""
        return [
            ToolParameter(
                name='query',
                label=I18nObject(en_US='', zh_Hans=''),
                human_description=I18nObject(en_US='', zh_Hans=''),
                type=ToolParameter.ToolParameterType.STRING,
                form=ToolParameter.ToolParameterForm.LLM,
                llm_description='Query for the dataset to be used to retrieve the dataset.',
                required=True,
                default='',
            ),
        ]

    def tool_provider_type(self) -> ToolProviderType:
        """Return ``ToolProviderType.DATASET_RETRIEVAL``."""
        return ToolProviderType.DATASET_RETRIEVAL

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """Run the wrapped retriever with the ``query`` parameter.

        Args:
            user_id: Not used by this implementation.
            tool_parameters: Parameters for the call; only ``'query'`` is read.

        Returns:
            A text message with the retriever's result, or a text message
            asking for a query when ``'query'`` is missing or empty.
        """
        query = tool_parameters.get('query')
        if not query:
            return self.create_text_message(text='please input query')

        result = self.retrival_tool._run(query=query)

        return self.create_text_message(text=result)

    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
        """No-op: this implementation performs no credential validation."""
        pass
|
|
|