File size: 3,451 Bytes
e30aaa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Team Search Tool

This module defines the TeamSearchTool, a LangChain-compatible tool for searching soccer teams 
in the fictional Huge League using the project's vector store.
"""

from pydantic import BaseModel, Field
from langchain.tools import BaseTool
from langchain_core.documents import Document
from typing import Type, List, Optional
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from data.vectorstore_singleton import get_vector_store

# Module-level vector store handle queried by TeamSearchTool's sync and async
# search paths. get_vector_store() runs once, when this module is imported;
# presumably it returns a shared instance (per the `vectorstore_singleton`
# module name) — NOTE(review): confirm.
vector_store = get_vector_store()


class TeamSearchInputSchema(BaseModel):
    # Argument schema for TeamSearchTool (wired in through its `args_schema`).
    # Intentionally no class docstring: pydantic would place it into the
    # generated JSON schema's "description", altering the published schema.
    team_query: str = Field(
        description="The search query to identify a soccer team in the fictional league. ",
    )

def _team_id(metadata, team_name):
    """Return the value to store under a team card's ``team_id`` key.

    Uses ``metadata["id"]`` when that key exists; otherwise derives a slug
    from *team_name* (lower-cased, spaces replaced by ``-``), falling back to
    ``"unknown-id"`` when no usable string name is available.
    """
    if "id" in metadata:
        return metadata["id"]
    if isinstance(team_name, str) and team_name:
        return team_name.lower().replace(" ", "-")
    return "unknown-id"


class TeamSearchTool(BaseTool):
    """LangChain tool that looks up a soccer team in the project's vector store.

    Both entry points return a list of ``Document`` objects whose metadata
    has ``show_team_card=True`` plus ``team_name``, ``team_id`` and ``city``
    for rendering a team card. An empty list means no team was found.
    Documents handed back by the vector store are never modified.
    """

    name: str = "team_search"
    description: str = (
        "Searches for a specific soccer team in the fictional league based on its name. "
        "Returns information about the team, which can be used to display a team card."
    )
    args_schema: Type[BaseModel] = TeamSearchInputSchema

    def _run(
        self,
        team_query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> List[Document]:
        """Search for a team using the vector store.

        Args:
            team_query: Free-text query matched against team documents.
            run_manager: LangChain callback manager (unused).

        Returns:
            Up to ``k=1`` documents, each carrying the stored page content and
            a *copied* metadata dict in which ``country`` and ``description``
            are dropped, ``show_team_card``/``team_name`` are set, and
            ``city``/``team_id`` are filled in when absent.
        """
        results = vector_store.similarity_search(
            query=team_query,
            k=1,
            filter=lambda doc: doc.metadata.get("type") == "team",
        )

        processed_results = []
        for doc in results:
            team_name_found = doc.metadata.get("name", team_query)

            # Build a fresh metadata dict rather than editing doc.metadata in
            # place: some stores return documents (or metadata dicts) shared
            # with their internal storage, and popping keys there would strip
            # them from the stored entry for every later search.
            metadata = {
                key: value
                for key, value in doc.metadata.items()
                if key not in ("country", "description")
            }
            metadata["show_team_card"] = True
            metadata["team_name"] = team_name_found
            metadata.setdefault("city", "N/A")
            # Same identifier scheme as the async path; an existing team_id
            # in the stored metadata is kept as-is.
            metadata.setdefault("team_id", _team_id(doc.metadata, team_name_found))

            processed_results.append(
                Document(
                    page_content=doc.page_content,
                    metadata=metadata,
                    id=getattr(doc, "id", None),
                )
            )

        return processed_results

    async def _arun(
        self,
        team_query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> List[Document]:
        """Asynchronously search for a team using the vector store.

        Fetches up to three candidates and builds a compact team-card
        document from the first one whose metadata has ``type == "team"``
        and a non-empty ``name``.

        Args:
            team_query: Free-text query matched against team documents.
            run_manager: LangChain async callback manager (unused).

        Returns:
            A one-element list with the team-card document, or an empty list
            when no candidate qualifies.
        """
        found_docs = await vector_store.asimilarity_search(
            query=team_query,
            k=3,
            # NOTE(review): a store that does not recognise a `metadata`
            # kwarg may silently ignore it, so candidates are re-checked below.
            metadata={"type": "team"},
        )
        found_docs = found_docs or []

        # Scan every candidate, not just the top hit: a non-team document
        # ranked first must not hide a valid team ranked lower.
        team_doc = next(
            (
                doc
                for doc in found_docs
                if doc.metadata.get("type") == "team" and doc.metadata.get("name")
            ),
            None,
        )

        if team_doc is None:
            if found_docs:
                print(f"Found document for query '{team_query}' but it's not a valid team entry or lacks name.")
            print(f"No team found for query: {team_query} after vector search and filtering.")
            return []

        team_name = team_doc.metadata["name"]
        metadata = {
            "show_team_card": True,
            "team_name": team_name,
            "team_id": _team_id(team_doc.metadata, team_name),
            "city": team_doc.metadata.get("city", "N/A"),
        }
        page_content = f"Found: {metadata['team_name']}. Location: {metadata.get('city')}."
        return [Document(page_content=page_content, metadata=metadata)]