File size: 2,144 Bytes
0833ce4
 
 
 
 
 
 
 
4c91492
 
 
3653851
 
 
 
 
4c91492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3653851
 
 
 
 
4c91492
 
 
 
 
 
 
3653851
 
 
 
4c91492
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
Game Search Tool

This module defines the GameSearchTool, a LangChain-compatible tool for searching soccer games and their related events
from the project's vector store. It includes a schema for specifying game queries, and provides both synchronous and
asynchronous search methods. Used by the agent workflow to retrieve structured game/event data for downstream processing.
"""

from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from langchain_core.documents import Document
from typing import Type, List, Optional
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from data.vectorstore_singleton import get_vector_store

vector_store = get_vector_store()

# Upper bound on how many game documents are listed in GameSearchSchema.
# `similarity_search` returns at most `k` results and LangChain's default is
# k=4, so without an explicit, generous `k` every game past the first few
# would be silently dropped from the options offered to the agent.
_MAX_GAMES = 1000

# Names (page content) of every document tagged as a game. The empty query
# only carries the metadata filter; it is not meant to rank anything.
# NOTE(review): the callable `filter` assumes a store (e.g. InMemoryVectorStore)
# that accepts a Document predicate -- confirm against get_vector_store().
games = [
    game.page_content
    for game in vector_store.similarity_search(
        "",
        k=_MAX_GAMES,
        filter=lambda doc: doc.metadata.get("type") == "game",
    )
]


# Argument schema for GameSearchTool: a single `query` field holding the game
# name. The field description embeds the module-level `games` list (evaluated
# once, when this class is defined), so the valid names appear in the
# argument schema. Comments are used instead of a class docstring because
# pydantic uses a model's docstring as its JSON-schema description.
class GameSearchSchema(BaseModel):
    query: str = Field(description=f"Name of the game to retrieve. Available options: {games}")


def _is_event_of(game_name: str):
    """Return a Document predicate matching event documents of ``game_name``."""
    return lambda doc: (
        doc.metadata.get("type") == "event"
        and doc.metadata.get("game_name") == game_name
    )


def _event_sort_key(doc: Document) -> tuple:
    """Sort key ordering events by the integer suffix of their id (``..._<n>``).

    Documents whose id is missing or whose last ``_``-separated part is not an
    integer are placed after all numbered events (ordered by raw id) instead
    of aborting the whole search with an exception.
    """
    raw_id = doc.id or ""
    suffix = raw_id.rsplit("_", 1)[-1]
    try:
        return (0, int(suffix), "")
    except ValueError:
        return (1, 0, raw_id)


# LangChain tool returning the event documents of a single game, ordered by the
# numeric suffix of their ids. `query` is matched exactly against each event's
# `game_name` metadata.
class GameSearchTool(BaseTool):
    name: str = "game_search"
    description: str = "Search for games in the vector store"
    args_schema: Type[BaseModel] = GameSearchSchema
    # Maximum number of event documents fetched per game; events beyond this
    # count are not returned. Defaults to the previously hard-coded 20.
    max_events: int = 20

    def _run(self,
             query: str,
             run_manager: Optional[CallbackManagerForToolRun] = None,
             ) -> List[Document]:
        """Return up to ``max_events`` events of game ``query``, sorted by id suffix.

        Args:
            query: Exact game name to match against event ``game_name`` metadata.
            run_manager: LangChain callback manager (unused).
        """
        result = vector_store.similarity_search(
            "",
            k=self.max_events,
            filter=_is_event_of(query),
        )
        return sorted(result, key=_event_sort_key)

    async def _arun(self,
                    query: str,
                    run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
                    ) -> List[Document]:
        """Async variant of :meth:`_run` using ``asimilarity_search``."""
        result = await vector_store.asimilarity_search(
            "",
            k=self.max_events,
            filter=_is_event_of(query),
        )
        return sorted(result, key=_event_sort_key)