Create QuestionRetrieverTool.py
Browse files- QuestionRetrieverTool.py +38 -0
    	
        QuestionRetrieverTool.py
    ADDED
    
    | @@ -0,0 +1,38 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from langchain.docstore.document import Document
         | 
| 2 | 
            +
            from langchain.text_splitter import RecursiveCharacterTextSplitter
         | 
| 3 | 
            +
            from smolagents import Tool
         | 
| 4 | 
            +
            from langchain_community.retrievers import BM25Retriever
         | 
| 5 | 
            +
            from smolagents import CodeAgent, InferenceClientModel
         | 
| 6 | 
            +
            from datasets import load_dataset
         | 
| 7 | 
            +
            import re
         | 
| 8 | 
            +
            import pandas as pd
         | 
| 9 | 
            +
            #%%
         | 
| 10 | 
            +
            class QuestionRetrieverTool(Tool):
         | 
| 11 | 
            +
                name = "Question_retriever"
         | 
| 12 | 
            +
                description = "Uses semantic search to retrieve relevant question given the class, difficulty, and topic inputs by the user."
         | 
| 13 | 
            +
                inputs = {
         | 
| 14 | 
            +
                    "query": {
         | 
| 15 | 
            +
                        "type": "string",
         | 
| 16 | 
            +
                        "description": "This tool returns relevant question and answer pairs based on the provided context.",
         | 
| 17 | 
            +
                    }
         | 
| 18 | 
            +
                }
         | 
| 19 | 
            +
                output_type = "string"
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def __init__(self, docs, **kwargs):
         | 
| 22 | 
            +
                    super().__init__(**kwargs)
         | 
| 23 | 
            +
                    self.retriever = BM25Retriever.from_documents(
         | 
| 24 | 
            +
                        docs, k=5  # Retrieve the top 5 documents
         | 
| 25 | 
            +
                    )
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                def forward(self, query: str) -> str:
         | 
| 28 | 
            +
                    assert isinstance(query, str), "Your search query must be a string"
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    docs = self.retriever.invoke(
         | 
| 31 | 
            +
                        query,
         | 
| 32 | 
            +
                    )
         | 
| 33 | 
            +
                    return "\nRetrieved example question and answer pairs:\n" + "".join(
         | 
| 34 | 
            +
                        [
         | 
| 35 | 
            +
                            f"\n\n===== Q and A pairs {str(i)} =====\n" + doc.page_content
         | 
| 36 | 
            +
                            for i, doc in enumerate(docs)
         | 
| 37 | 
            +
                        ]
         | 
| 38 | 
            +
                    )
         |