ScouterAI / tools /task_model_retriever.py
stevenbucaille's picture
Enhance app.py with improved user interface and instructions, update model ID in llm.py, and add image classification capabilities across various components. Introduce segment anything functionality and refine README for clarity on model capabilities.
518d841
import modal
from smolagents import Tool
from modal_apps.app import app
from modal_apps.task_model_retriever import TaskModelRetrieverModalApp
class TaskModelRetrieverTool(Tool):
    """smolagents tool that retrieves models able to perform a given vision task.

    The actual search runs remotely through the Modal class
    ``TaskModelRetrieverModalApp`` registered under ``app.name``. This tool only
    validates its arguments and forwards the query to the remote search method
    that matches the requested task.
    """

    name = "task_model_retriever"
    description = """
For a given task, retrieve the models that can perform that task.
The supported tasks are:
- object-detection
- image-segmentation
- image-classification
The query is a string that describes the task the model needs to perform.
The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
"""
    inputs = {
        "task": {
            "type": "string",
            "description": "The task the model needs to perform.",
        },
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Tasks accepted by ``forward``; each one has a dedicated remote search method.
        self.tasks = ["object-detection", "image-segmentation", "image-classification"]
        # Reference to the Modal class registered under this app's name; it is
        # only instantiated later, in ``setup``.
        self.tool_class = modal.Cls.from_name(app.name, TaskModelRetrieverModalApp.__name__)

    def setup(self):
        """Instantiate the remote Modal class used by ``forward``."""
        self.tool: TaskModelRetrieverModalApp = self.tool_class()
        # The base smolagents ``setup`` marks the tool as initialized; without it,
        # ``Tool.__call__`` would re-run this setup (and re-create the remote
        # instance) on every single call.
        super().setup()

    def forward(self, task: str, query: str) -> dict:
        """Search for models that can perform ``task`` and match ``query``.

        Args:
            task: One of ``self.tasks`` ("object-detection", "image-segmentation",
                "image-classification").
            query: Free-text description of what the model should recognise.

        Returns:
            The result of the selected remote search method. Per the tool
            description, this is a dictionary mapping model ids to the labels
            each model supports.

        Raises:
            ValueError: if ``task`` is not one of ``self.tasks``.
            TypeError: if ``query`` is not a string.
            NotImplementedError: if ``task`` is listed in ``self.tasks`` but has no
                search method wired up below.
        """
        # Explicit raises instead of ``assert``: asserts are removed under
        # ``python -O``, which would let an unsupported task fall through silently.
        if task not in self.tasks:
            raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        print(f"Retrieving models for task {task} with query {query}")
        if task == "object-detection":
            return self.tool.object_detection_search.remote(query)
        if task == "image-segmentation":
            return self.tool.image_segmentation_search.remote(query)
        if task == "image-classification":
            return self.tool.image_classification_search.remote(query)
        # Only reachable if ``self.tasks`` was extended without adding a branch above.
        raise NotImplementedError(f"No search method is available for task {task}")