File size: 1,913 Bytes
111afa2
 
 
 
 
 
 
 
 
 
 
 
 
 
518d841
111afa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518d841
111afa2
 
 
 
 
 
 
 
 
 
 
 
 
 
518d841
 
111afa2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import modal
from smolagents import Tool

from modal_apps.app import app
from modal_apps.task_model_retriever import TaskModelRetrieverModalApp


class TaskModelRetrieverTool(Tool):
    """smolagents tool that lists models able to perform a given vision task.

    The search itself runs remotely: ``__init__`` looks up the deployed Modal
    class ``TaskModelRetrieverModalApp`` under ``app.name``, ``setup`` creates
    an instance handle, and ``forward`` dispatches the query to the remote
    search method matching the requested task.
    """

    name = "task_model_retriever"
    # Sent to the LLM as the tool's usage instructions.
    description = """
    For a given task, retrieve the models that can perform that task.
    The supported tasks are:
    - object-detection
    - image-segmentation
    - image-classification

    The query is a string that describes the class of objects the model needs to detect.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "task": {
            "type": "string",
            "description": "The task the model needs to perform.",
        },
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        },
    }
    output_type = "object"

    def __init__(self):
        """Resolve the deployed Modal class; no remote instance is created yet."""
        super().__init__()
        # Single source of truth: supported task -> name of the remote search
        # method on TaskModelRetrieverModalApp that serves it.
        self._search_methods = {
            "object-detection": "object_detection_search",
            "image-segmentation": "image_segmentation_search",
            "image-classification": "image_classification_search",
        }
        # Kept as a list (same order as before) for callers that read it.
        self.tasks = list(self._search_methods)
        self.tool_class = modal.Cls.from_name(app.name, TaskModelRetrieverModalApp.__name__)

    def setup(self):
        """Create the remote class instance handle once, before first use.

        smolagents' ``Tool.__call__`` invokes ``setup()`` while
        ``is_initialized`` is false, and the base ``Tool.setup`` is what sets
        that flag. It is therefore called here; otherwise the handle would be
        rebuilt on every tool call.
        """
        self.tool: TaskModelRetrieverModalApp = self.tool_class()
        super().setup()

    def forward(self, task: str, query: str) -> dict:
        """Return the models that can perform ``task`` for ``query``.

        Args:
            task: One of ``self.tasks``.
            query: The class of objects the model needs to handle.

        Returns:
            Whatever the remote search method returns. Per the tool
            description, this is a dict mapping model id to supported labels.

        Raises:
            ValueError: If ``task`` is not supported.
            TypeError: If ``query`` is not a string.
        """
        # Explicit raises instead of ``assert``: asserts vanish under -O, which
        # previously let an unknown task reach ``return`` with nothing bound.
        method_name = self._search_methods.get(task)
        if method_name is None:
            raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        # Support direct ``forward`` calls that bypass ``Tool.__call__``.
        if not getattr(self, "is_initialized", False):
            self.setup()

        print(f"Retrieving models for task {task} with query {query}")
        return getattr(self.tool, method_name).remote(query)