track-anything-annotate / sam_controller.py
lniki's picture
add model
0e83290 verified
import cv2
from XMem2.inference.interact.interactive_utils import overlay_davis
from segmenter import Segmenter
from tools.mask_display import visualize_unique_mask
from tools.mask_merge import merge_masks
import numpy as np
class SegmenterController:
def __init__(self):
"""
Инициализация контроллера для работы с Segmenter2.
:param device: Устройство для выполнения вычислений ('cuda' или 'cpu').
"""
self.segmenter = Segmenter()
self.image_set = False
def load_image(self, image: np.ndarray):
"""
Загружает изображение в Segmenter2.
:param image: Изображение в формате NumPy массива (H, W, C).
"""
if self.image_set:
print("Изображение уже загружено. Сбросьте его перед загрузкой нового.")
return
try:
self.segmenter.set_image(image)
self.image_set = True
print("Изображение успешно загружено.")
except Exception as e:
print(f"Ошибка при загрузке изображения: {e}")
def reset_image(self):
"""
Сбрасывает текущее изображение в Segmenter2.
"""
if not self.image_set:
print("Нет загруженного изображения для сброса.")
return
try:
self.segmenter.reset_image()
self.image_set = False
print("Изображение успешно сброшено.")
except Exception as e:
print(f"Ошибка при сбросе изображения: {e}")
def _process_point_prompt(
self,
point_coords: list[list[int] | list[list[int]]],
point_labels: list[list[int] | list[list[int]]],
) -> list[dict[str, np.ndarray]]:
"""
Обрабатывает промпт для точек.
:param point_coords: Координаты точек.
:param point_labels: Метки точек.
:return: Список словарей с подготовленными данными для predict.
"""
prompts = []
for coords, labels in zip(point_coords, point_labels):
# Определяем, является ли текущий элемент списком координат или одной координатой
if isinstance(coords[0], list) and isinstance(labels, list):
# Если несколько точек и меток, multimask=False
prompt = {
"point_coords": np.array(coords),
"point_labels": np.array(labels),
}
prompts.append((prompt, False))
else:
# Если одна точка, multimask=True
prompt = {
"point_coords": np.array([coords]),
"point_labels": np.array([labels]),
}
prompts.append((prompt, True))
return prompts
def _process_box_prompt(
self, boxes: list[list[int]]
) -> list[dict[str, np.ndarray]]:
"""
Обрабатывает промпт для рамок.
:param boxes: Рамки.
:return: Список словарей с подготовленными данными для predict.
"""
prompts = []
for box in boxes:
prompt = {"boxes": np.array([box])}
prompts.append((prompt, True)) # multimask=True для каждой рамки
return prompts
def _process_both_prompt(
self,
point_coords: list[list[int] | None],
point_labels: list[int | None],
boxes: list[list[int]],
) -> list[dict[str, np.ndarray]]:
"""
Обрабатывает промпт для комбинированного режима.
:param point_coords: Координаты точек.
:param point_labels: Метки точек.
:param boxes: Рамки.
:return: Список словарей с подготовленными данными для predict.
"""
prompts = []
for box, coords, labels in zip(boxes, point_coords, point_labels):
prompt = {"boxes": np.array([box])}
if coords is not None and labels is not None:
prompt["point_coords"] = np.array([coords])
prompt["point_labels"] = np.array([labels])
prompts.append((prompt, False)) # multimask=False, если есть точки
else:
prompts.append((prompt, True)) # multimask=True, если точек нет
return prompts
def predict_from_prompts(
self, prompts: dict[str, str | list]
) -> list[list[np.ndarray, np.ndarray, np.ndarray]]:
"""
Выполняет предсказание на основе заданного промпта.
:param prompts: Словарь с данными для предсказания.
:return: Список кортежей (маски, оценки, логиты).
"""
if not self.image_set:
raise RuntimeError("Изображение не загружено. Сначала вызовите load_image.")
mode = prompts.get("mode")
results = []
if mode == "point":
point_coords = prompts.get("point_coords", [])
point_labels = prompts.get("point_labels", [])
processed_prompts = self._process_point_prompt(point_coords, point_labels)
elif mode == "box":
boxes = prompts.get("boxes", [])
processed_prompts = self._process_box_prompt(boxes)
elif mode == "both":
point_coords = prompts.get(
"point_coords", [None] * len(prompts.get("boxes", []))
)
point_labels = prompts.get(
"point_labels", [None] * len(prompts.get("boxes", []))
)
boxes = prompts.get("boxes", [])
processed_prompts = self._process_both_prompt(
point_coords, point_labels, boxes
)
else:
raise ValueError("Режим должен быть 'point', 'box' или 'both'.")
# TODO: добавить вариант без цикла
for prompt, multimask in processed_prompts:
try:
masks, scores, logits = self.segmenter.predict(
prompt, mode=mode, multimask=multimask
)
results.append([masks, scores, logits])
except Exception as e:
print(f"Ошибка при выполнении предсказания: {e}")
raise
return results
if __name__ == '__main__':
# Создаем контроллер
controller = SegmenterController()
path = 'video-test/truck.jpg'
path = 'video-test/video.mp4'
video = cv2.VideoCapture(path)
ret, frame = video.read()
frame_cop = frame.copy()
video.release()
controller.load_image(frame)
import timeit
# Пример 1: Точки
prompts = {
'mode': 'point',
'point_coords': [[531, 230], [45, 321], [226, 360], [194, 313]],
'point_labels': [1, 1, 1, 1],
}
# prompts = {
# 'mode': 'point',
# 'point_coords': [[[531, 230], [45, 321]], [226, 360], [194, 313]],
# 'point_labels': [[1, 0], 1, 1],
# }
def run_segmentation():
prompts = {
'mode': 'point',
'point_coords': [[531, 230], [45, 321], [226, 360], [194, 313]],
'point_labels': [1, 0, 1, 1],
}
return controller.predict_from_prompts(prompts)
results = controller.predict_from_prompts(prompts)
execution_time_ms = timeit.timeit(run_segmentation, number=1) * 1000
print(f"Время выполнения: {execution_time_ms:.2f} мс")
# Пример 2: Рамки
# prompts = {
# 'mode': 'box',
# 'boxes': [
# [476, 166, 578, 320],
# [8, 252, 99, 401],
# [106, 335, 317, 425],
# [155, 283, 225, 339],
# ],
# }
# results = controller.predict_from_prompts(prompts)
# Пример 3: Комбинированный режим
# prompts = {
# 'mode': 'both',
# 'point_coords': [[575, 750]],
# 'point_labels': [0],
# 'boxes': [[425, 600, 700, 875]],
# }
# results = controller.predict_from_prompts(prompts)
print(len(results))
res = [result[np.argmax(scores)] for result, scores, logits in results]
mask, unique_mask = merge_masks(res)
f = overlay_davis(frame, unique_mask)
mask = visualize_unique_mask(unique_mask)
f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
cv2.imshow('asd', mask)
cv2.imshow('asd', f)
cv2.waitKey(0)
cv2.destroyAllWindows()
# Сбрасываем изображение
controller.reset_image()