Spaces:
No application file
No application file
File size: 2,144 Bytes
0833ce4 4c91492 3653851 4c91492 3653851 4c91492 3653851 4c91492 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
"""
Game Search Tool
This module defines the GameSearchTool, a LangChain-compatible tool for searching soccer games and their related events
from the project's vector store. It includes a schema for specifying game queries, and provides both synchronous and
asynchronous search methods. Used by the agent workflow to retrieve structured game/event data for downstream processing.
"""
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from langchain_core.documents import Document
from typing import Type, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from data.vectorstore_singleton import get_vector_store
vector_store = get_vector_store()

# Upper bound on the number of "game" documents fetched when building the list
# of available games. The LangChain VectorStore API defaults similarity_search()
# to k=4, which would silently truncate the list advertised to the agent.
_GAME_LIST_LIMIT = 1000

# Names of all games in the store (the page_content of every document whose
# metadata type is "game"). The search text is empty, so similarity ranking is
# meaningless here; names are de-duplicated and sorted so the tool schema
# description built from this list is deterministic across runs.
games = sorted(
    {
        game.page_content
        for game in vector_store.similarity_search(
            "",
            k=_GAME_LIST_LIMIT,
            filter=lambda doc: doc.metadata.get("type") == "game",
        )
    }
)
class GameSearchSchema(BaseModel):
    # Argument schema for GameSearchTool: a single required ``query`` string.
    # The field description interpolates the module-level ``games`` list, which
    # is evaluated once when this class is created, so the agent is shown the
    # valid game names. (Comments rather than a class docstring: pydantic would
    # copy a docstring into the generated JSON schema.)
    query: str = Field(
        description=f"Name of the game to retrieve. Available options: {games}"
    )
class GameSearchTool(BaseTool):
    """Return the event documents recorded for a single game.

    Events are selected from the module-level ``vector_store`` purely by
    metadata (``type == "event"`` and ``game_name == query``); the search text
    is empty, so similarity ranking does not pick meaningful events. At most
    ``max_events`` documents are fetched, and they are returned ordered by the
    integer suffix of each document id (the part after the last ``_``).
    """

    name: str = "game_search"
    description: str = "Search for games in the vector store"
    args_schema: Type[BaseModel] = GameSearchSchema
    # Maximum number of event documents fetched per call (previously the
    # hard-coded k=20 in both the sync and async paths).
    max_events: int = 20

    def _search_kwargs(self, query: str) -> dict:
        """Build the ``k``/``filter`` kwargs shared by the sync and async searches."""

        def is_event_of_game(doc: Document) -> bool:
            return (
                doc.metadata.get("type") == "event"
                and doc.metadata.get("game_name") == query
            )

        return {"k": self.max_events, "filter": is_event_of_game}

    @staticmethod
    def _event_sort_key(doc: Document) -> tuple:
        """Sort key: the integer after the last ``_`` in ``doc.id``.

        Documents whose id is missing or lacks an integer suffix sort after
        the numbered ones (keeping retrieval order, since ``sorted`` is
        stable) instead of raising.
        """
        try:
            return (0, int(doc.id.rsplit("_", 1)[-1]))
        except (AttributeError, ValueError):
            return (1, 0)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> List[Document]:
        """Synchronously return the events of game ``query``, ordered by id."""
        result = vector_store.similarity_search("", **self._search_kwargs(query))
        return sorted(result, key=self._event_sort_key)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> List[Document]:
        """Asynchronously return the events of game ``query``, ordered by id."""
        result = await vector_store.asimilarity_search(
            "", **self._search_kwargs(query)
        )
        return sorted(result, key=self._event_sort_key)
|