Spaces:
Running
Running
Enhance app.py with improved user interface and instructions, update model ID in llm.py, and add image classification capabilities across various components. Introduce segment anything functionality and refine README for clarity on model capabilities.
518d841
import modal | |
import numpy as np | |
import supervision as sv | |
from smolagents import Tool | |
from modal_apps.app import app | |
from modal_apps.segment_anything import SegmentAnythingModalApp | |
def get_detections_from_segment_anything(detections, list_of_masks, iou_scores):
    """Build a new ``sv.Detections`` pairing the input boxes with SAM masks.

    Args:
        detections: The ``sv.Detections`` whose boxes were segmented. Only its
            ``xyxy`` boxes are reused; its other fields (class ids,
            confidences, ``data``, tracker ids) are not carried over.
        list_of_masks: One mask per box, in the same order as
            ``detections.xyxy``.
        iou_scores: One score per box, in the same order; stored as the
            result's ``confidence``.

    Returns:
        A new ``sv.Detections`` with the original boxes, the given masks,
        ``class_id`` set to each box's position (0..n-1) and ``confidence``
        set to ``iou_scores``. When there are no boxes,
        ``sv.Detections.empty()`` is returned.
    """
    bounding_boxes = detections.xyxy.tolist()
    if not bounding_boxes:
        # np.array([]) has shape (0,), which sv.Detections rejects for xyxy
        # (it requires shape (n, 4)), so an empty input needs its own path.
        return sv.Detections.empty()
    return sv.Detections(
        xyxy=np.array(bounding_boxes),
        mask=np.array(list_of_masks),
        # Each box gets its own id (its index) rather than its original class.
        class_id=np.arange(len(bounding_boxes)),
        confidence=np.array(iou_scores),
    )
class SegmentAnythingTool(Tool):
    """smolagents tool that adds Segment Anything masks to existing detections.

    The segmentation itself runs remotely on the Modal class named
    ``SegmentAnythingModalApp`` inside the Modal app ``app.name``; this tool
    only forwards the image and boxes and repackages the returned masks.
    """

    name = "segment_anything"
    description = """
    Given an image and an already detected object (a sv.Detections object), segment the image and return masks for each bounding box.
    The image is a PIL image.
    The detections are an object of type sv.Detections, obtainable from the usage of the object_detection tool with task_inference_output_converter.
    The output is a new sv.Detections object with the same bounding boxes and one mask per box.
    Its confidence holds the IoU scores reported by Segment Anything and its class_id holds each box's index; other input fields (such as data and tracker_id) are not kept.
    If the input contains no detections, it is returned unchanged.
    """
    inputs = {
        "image": {
            "type": "image",
            "description": "The image to segment",
        },
        "detections": {
            "type": "object",
            "description": """
            The detections to segment the image with.
            The detections are an object of type supervision.Detections.
            """,
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Handle to the deployed Modal class; every segmentation request is
        # sent through its remote `forward` method.
        self.modal_app = modal.Cls.from_name(app.name, SegmentAnythingModalApp.__name__)()

    # NOTE: the parameter names and annotations below must stay in sync with
    # `inputs`; smolagents validates the forward signature against them.
    def forward(
        self,
        image,
        detections: sv.Detections,
    ):
        """Segment every box in ``detections`` and return detections with masks.

        Args:
            image: The image to segment (described to the agent as a PIL image).
            detections: Boxes to segment; only ``detections.xyxy`` is sent.

        Returns:
            A new ``sv.Detections`` built from the boxes, masks and IoU scores,
            or ``detections`` itself when it contains no boxes.
        """
        bounding_boxes = detections.xyxy.tolist()
        if not bounding_boxes:
            # Nothing to segment: skip the remote round-trip entirely.
            return detections
        # The remote call is expected to return masks and scores aligned
        # one-to-one with `bounding_boxes`; the helper pairs them by position.
        masks, iou_scores = self.modal_app.forward.remote(image=image, bounding_boxes=bounding_boxes)
        return get_detections_from_segment_anything(detections, masks, iou_scores)