Spaces:
Sleeping
Sleeping
| import logging | |
| from typing import Sequence | |
| import numpy as np | |
| import torch | |
| from transformers import Owlv2ForObjectDetection, Owlv2Processor | |
| from models.detectors.base import DetectionResult, ObjectDetector | |
class Owlv2Detector(ObjectDetector):
    """Text-queried object detector backed by a pretrained OWLv2 checkpoint.

    The checkpoint named by ``MODEL_NAME`` is loaded once at construction
    time, placed on CUDA when available (in half precision) or on the CPU
    (in single precision), and switched to eval mode.
    """

    MODEL_NAME = "google/owlv2-large-patch14"
    # Minimum confidence passed to the processor's post-processing step.
    SCORE_THRESHOLD = 0.3

    def __init__(self) -> None:
        """Load the processor and model weights and move the model to the device."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
        self.processor = Owlv2Processor.from_pretrained(self.MODEL_NAME)
        # Half precision only on GPU; CPU inference stays in float32.
        torch_dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.model = Owlv2ForObjectDetection.from_pretrained(
            self.MODEL_NAME, torch_dtype=torch_dtype
        )
        self.model.to(self.device)
        self.model.eval()
        self.name = "owlv2"

    @staticmethod
    def _as_list(values) -> list:
        """Convert a tensor, NumPy array, or plain iterable to a Python list."""
        if hasattr(values, "cpu"):
            return values.cpu().tolist()
        if isinstance(values, np.ndarray):
            return values.tolist()
        return list(values)

    def _prepare_inputs(self, inputs) -> dict:
        """Move processor outputs to the model's device and dtype.

        Floating-point tensors (e.g. pixel values) are cast to the model's
        dtype so that a half-precision model does not receive float32 input;
        integer tensors (token ids, attention masks) only change device.
        Non-tensor values are passed through untouched.
        """
        model_dtype = self.model.dtype
        prepared = {}
        for key, value in inputs.items():
            if torch.is_tensor(value):
                if value.is_floating_point():
                    value = value.to(device=self.device, dtype=model_dtype)
                else:
                    value = value.to(self.device)
            prepared[key] = value
        return prepared

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Detect objects in ``frame`` that match the text ``queries``.

        Args:
            frame: Image array; its first two dimensions are used as the
                target size when rescaling boxes.
            queries: Text prompts describing the objects to look for. A bare
                string is treated as a single query.

        Returns:
            A ``DetectionResult`` whose ``boxes`` is a float32 NumPy array and
            whose ``scores`` and ``labels`` are Python lists, as produced by
            the processor's post-processing (labels are presumably indices
            into ``queries`` -- confirm against the transformers docs).
            Empty ``queries`` yield an empty result without running the model.
        """
        # The processor only recognises ``str`` and ``list`` text inputs, so
        # normalise other sequences (tuples, etc.) up front.
        query_list = [queries] if isinstance(queries, str) else list(queries)
        if not query_list:
            return DetectionResult(
                boxes=np.empty((0, 4), dtype=np.float32), scores=[], labels=[]
            )

        inputs = self.processor(text=query_list, images=frame, return_tensors="pt")
        model_inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # NOTE(review): OWLv2 pads images to a square before resizing; whether
        # the post-processor compensates for that depends on the installed
        # transformers version -- confirm box alignment on non-square frames.
        processed = self.processor.post_process_object_detection(
            outputs,
            threshold=self.SCORE_THRESHOLD,
            target_sizes=[frame.shape[:2]],
        )[0]

        boxes = processed["boxes"]
        if hasattr(boxes, "cpu"):
            # Upcast so callers see float32 regardless of the inference dtype.
            boxes_np = boxes.cpu().float().numpy()
        else:
            boxes_np = np.asarray(boxes, dtype=np.float32)

        return DetectionResult(
            boxes=boxes_np,
            scores=self._as_list(processed.get("scores", [])),
            labels=self._as_list(processed.get("labels", [])),
        )