from transformers.agents.tools import Tool | |
from data import Data | |
class SquadRetrieverTool(Tool):
    """Agent tool that answers questions using a query engine built from ``Data``.

    On construction it instantiates ``Data()`` and calls
    ``data.index.as_query_engine()``. The index is presumably a LlamaIndex
    index (TODO confirm against ``data.Data``). ``forward`` sends the question
    to that engine and formats the engine's text answer for the calling agent.
    """

    name = "squad_retriever"
    description = "Answers questions from the Stanford Question Answering Dataset (SQuAD)."
    inputs = {
        "query": {
            "type": "string",
            "description": "The question. This should be the literal question being asked, only modified to be informed by chat history. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Initialise the tool and build its query engine.

        Args:
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        # Built eagerly: constructing the tool pays for Data() and the
        # query-engine setup up front rather than on the first query.
        self.data = Data()
        self.query_engine = self.data.index.as_query_engine()

    def forward(self, query: str) -> str:
        """Answer ``query`` with the query engine.

        Args:
            query: The question text. The input description warns callers not
                to pass a dictionary, so the type is checked explicitly.

        Returns:
            The prefix ``Retrieved answer:`` and a blank line, followed by the
            engine's answer text. Returns ``"No answer found for this query."``
            when the engine's answer is missing, empty, or whitespace-only.

        Raises:
            TypeError: If ``query`` is not a ``str``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        response = self.query_engine.query(query)
        # NOTE(review): ``response.response`` is assumed to be Optional[str]
        # (as on LlamaIndex's Response); None, "" and blank text all mean
        # "no answer" instead of crashing on len(None).
        answer = response.response
        if not answer or not answer.strip():
            return "No answer found for this query."
        return "Retrieved answer:\n\n" + answer