from typing import Optional, Tuple, Union, List

import numpy as np
import PIL
from PIL.Image import Image
import supervision as sv
import torch
from torch import nn
from transformers import OwlViTProcessor, OwlViTForObjectDetection, OwlViTVisionModel
from transformers.models.owlvit.modeling_owlvit import (
    center_to_corners_format,
    box_iou,
    generalized_box_iou,
    OwlViTObjectDetectionOutput,
)

from sam_extension.pipeline.base import Pipeline, Output


class OwlViTVisionEncoderPipeline(Pipeline):

    def __init__(self,
                 vision_model,
                 layer_norm,
                 processor,
                 device='cuda',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.vision_model = vision_model
        self.layer_norm = layer_norm
        self.processor = processor
        self.device = device
        torch.cuda.empty_cache()

    @classmethod
    def from_pretrained(cls, model_type, device='cuda', *args, **kwargs):
        owlvit_for_object_detection = OwlViTForObjectDetection.from_pretrained(model_type).to(device)
        processor = OwlViTProcessor.from_pretrained(model_type)
        return cls(owlvit_for_object_detection.owlvit.vision_model,
                   owlvit_for_object_detection.layer_norm,
                   processor,
                   device,
                   *args,
                   **kwargs)

    def process_image(self, image: Image):
        image = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
        return image

    def forward(
        self,
        pixel_values: Optional[Union[torch.FloatTensor, Image]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        if isinstance(pixel_values, Image):
            pixel_values = self.process_image(pixel_values)
        pixel_values = pixel_values.to(self.device)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Get image embeddings
        last_hidden_state = vision_outputs[0]
        image_embeds = self.vision_model.post_layernorm(last_hidden_state)

        # Broadcast the class token over all patch tokens (drop one token from the sequence dim)
        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)

        # Merge image embedding with class tokens
        image_embeds = image_embeds[:, 1:, :] * class_token_out
        image_embeds = self.layer_norm(image_embeds)

        # Resize to [batch_size, num_patches, num_patches, hidden_size]
        new_size = (
            image_embeds.shape[0],
            int(np.sqrt(image_embeds.shape[1])),
            int(np.sqrt(image_embeds.shape[1])),
            image_embeds.shape[-1],
        )
        image_embeds = image_embeds.reshape(new_size)
        return image_embeds
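
# Usage sketch for the pipeline above (illustrative only: the checkpoint name,
# device, and image path are assumptions, not fixed anywhere in this module):
#
#   encoder = OwlViTVisionEncoderPipeline.from_pretrained("google/owlvit-base-patch32", device="cuda")
#   feature_map = encoder.forward(PIL.Image.open("example.jpg").convert("RGB"))
#   # feature_map: [batch_size, num_patches, num_patches, hidden_size], consumed by OwlViTDecoderPipeline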


class OwlViTDecoderPipeline(Pipeline):
    prompt_template: str = 'a photo of a '

    def __init__(self,
                 owlvit_text,
                 text_projection,
                 class_head,
                 box_head,
                 processor,
                 device='cuda',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.owlvit_text = owlvit_text
        self.text_projection = text_projection
        self.class_head = class_head
        self.box_head = box_head
        self.sigmoid = nn.Sigmoid()
        self.processor = processor
        self.device = device
        torch.cuda.empty_cache()

    @classmethod
    def from_pretrained(cls, model_type, device='cuda', *args, **kwargs):
        owlvit_for_object_detection = OwlViTForObjectDetection.from_pretrained(model_type).to(device)
        processor = OwlViTProcessor.from_pretrained(model_type)
        return cls(owlvit_for_object_detection.owlvit.text_model,
                   owlvit_for_object_detection.owlvit.text_projection,
                   owlvit_for_object_detection.class_head,
                   owlvit_for_object_detection.box_head,
                   processor,
                   device,
                   *args,
                   **kwargs)

    def set_template(self, template: str):
        self.prompt_template = template

    def process_text(self, text: List, use_template: bool = True):
        # e.g. with the default template, [["cat", "remote control"]] becomes
        # [["a photo of a cat", "a photo of a remote control"]] before tokenization.
        if use_template:
            text = [[self.prompt_template + i for i in text[0]]]
        inputs = self.processor(text=text, return_tensors="pt")
        return inputs

    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
        # Computes normalized xy corner coordinates from feature_map.
        if not feature_map.ndim == 4:
            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")

        device = feature_map.device
        num_patches = feature_map.shape[1]

        box_coordinates = np.stack(
            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
        ).astype(np.float32)
        box_coordinates /= np.array([num_patches, num_patches], np.float32)

        # Flatten (h, w, 2) -> (h*w, 2)
        box_coordinates = box_coordinates.reshape(
            box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
        )
        box_coordinates = torch.from_numpy(box_coordinates).to(device)
        return box_coordinates

    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
        # The box center is biased to its position on the feature grid
        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
        box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)

        # Unnormalize xy
        box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)

        # The box size is biased to the patch size
        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
        box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)

        # Compute box bias
        box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
        return box_bias
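
    # Note on the bias above: the two log terms form an inverse sigmoid (logit), so
    # sigmoid(bias) recovers the grid coordinate. A zero output from the box head
    # therefore decodes to a box centred on its grid cell with width/height of about
    # one patch; e.g. on a 24x24 feature grid the first cell decodes to
    # cx = cy = w = h ≈ 1/24 ≈ 0.042 before any learned offset is applied.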

    def box_predictor(
        self,
        image_feats: torch.FloatTensor,
        feature_map: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Args:
            image_feats:
                Features extracted from the image, returned by the `image_text_embedder` method.
            feature_map:
                A spatial re-arrangement of image_feats, also returned by the `image_text_embedder` method.
        Returns:
            pred_boxes:
                Predicted box coordinates (cxcywh, normalized to [0, 1]) of shape [batch_size, num_boxes, 4].
        """
        # Bounding box detection head [batch_size, num_boxes, 4].
        pred_boxes = self.box_head(image_feats)

        # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
        pred_boxes += self.compute_box_bias(feature_map)
        pred_boxes = self.sigmoid(pred_boxes)
        return pred_boxes

    def class_predictor(
        self,
        image_feats: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor] = None,
        query_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            image_feats:
                Features extracted from the `image_text_embedder`.
            query_embeds:
                Text query embeddings.
            query_mask:
                Must be provided together with `query_embeds`. A mask indicating which query embeddings are valid.
        """
        (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
        return (pred_logits, image_class_embeds)

    def image_text_embedder(
        self,
        input_ids: torch.Tensor,
        image_embeds: torch.FloatTensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Tuple[torch.FloatTensor]:
        # Encode text; the image embeddings are already computed by OwlViTVisionEncoderPipeline
        text_outputs = self.owlvit_text(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # L2-normalize the text features
        text_embeds = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
        return (text_embeds, image_embeds, text_outputs)

    def embed_image_query(
        self, query_image_features: torch.FloatTensor, query_feature_map: torch.FloatTensor
    ) -> Tuple[Optional[torch.FloatTensor], Optional[torch.Tensor], torch.FloatTensor]:
        _, class_embeds = self.class_predictor(query_image_features)
        pred_boxes = self.box_predictor(query_image_features, query_feature_map)
        pred_boxes_as_corners = center_to_corners_format(pred_boxes)

        # Loop over query images
        best_class_embeds = []
        best_box_indices = []
        pred_boxes_device = pred_boxes_as_corners.device

        for i in range(query_image_features.shape[0]):
            each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device)
            each_query_pred_boxes = pred_boxes_as_corners[i]
            ious, _ = box_iou(each_query_box, each_query_pred_boxes)

            # If there are no overlapping boxes, fall back to generalized IoU
            if torch.all(ious[0] == 0.0):
                ious = generalized_box_iou(each_query_box, each_query_pred_boxes)

            # Use an adaptive threshold to include all boxes within 80% of the best IoU
            iou_threshold = torch.max(ious) * 0.8
            selected_inds = (ious[0] >= iou_threshold).nonzero()
            if selected_inds.numel():
                # Consider every box above the threshold and keep the embedding least
                # similar to the mean embedding (i.e. the most distinctive query box)
                selected_embeddings = class_embeds[i][selected_inds.squeeze(1)]
                mean_embeds = torch.mean(class_embeds[i], axis=0)
                mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings)
                best_box_ind = selected_inds[torch.argmin(mean_sim)]
                best_class_embeds.append(class_embeds[i][best_box_ind])
                best_box_indices.append(best_box_ind)

        if best_class_embeds:
            query_embeds = torch.stack(best_class_embeds)
            box_indices = torch.stack(best_box_indices)
        else:
            query_embeds, box_indices = None, None
        return query_embeds, box_indices, pred_boxes

    def forward(
        self,
        image_embeds: torch.FloatTensor,
        input_ids: Optional[torch.Tensor] = None,
        text: Optional[List] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> OwlViTObjectDetectionOutput:
        if text is not None:
            inputs = self.process_text(text)
            input_ids = inputs.input_ids
            attention_mask = inputs.attention_mask
        input_ids = input_ids.to(self.device)
        image_embeds = image_embeds.to(self.device)
        attention_mask = attention_mask.to(self.device)

        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
        return_dict = return_dict if return_dict is not None else True

        # Embed images and text queries
        query_embeds, feature_map, text_outputs = self.image_text_embedder(
            input_ids=input_ids,
            image_embeds=image_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Text and vision model outputs
        batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
        image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))

        # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
        max_text_queries = input_ids.shape[0] // batch_size
        query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])

        # If first token is 0, then this is a padded query [batch_size, num_queries].
        input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
        query_mask = input_ids[..., 0] > 0

        # Predict object classes [batch_size, num_patches, num_queries+1]
        (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)

        # Predict object boxes
        pred_boxes = self.box_predictor(image_feats, feature_map)

        if not return_dict:
            output = (
                pred_logits,
                pred_boxes,
                query_embeds,
                feature_map,
                class_embeds,
                text_outputs.to_tuple(),
                None,
            )
            output = tuple(x for x in output if x is not None)
            return output

        return OwlViTObjectDetectionOutput(
            image_embeds=feature_map,
            text_embeds=query_embeds,
            pred_boxes=pred_boxes.cpu(),
            logits=pred_logits.cpu(),
            class_embeds=class_embeds,
            text_model_output=text_outputs,
            vision_model_output=None,
        )

    def owlvit_visualize(self,
                         image: Image,
                         texts: List,
                         owlvit_objectdetection_output: OwlViTObjectDetectionOutput,
                         score_threshold: float = 0.1,
                         pil: bool = True):
        target_sizes = torch.Tensor([image.size[::-1]])

        # Convert outputs (bounding boxes and class logits) to COCO API
        results = self.processor.post_process(outputs=owlvit_objectdetection_output, target_sizes=target_sizes)
        text = texts[0]
        boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

        boxes_np = []
        labels_list = []
        # Print detected objects and rescaled box coordinates
        for box, score, label in zip(boxes, scores, labels):
            box = [int(i) for i in box.tolist()]
            if score >= score_threshold:
                labels_list.append(f"{text[label]} {round(score.item(), 3)}")
                boxes_np.append(box)
                print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")

        # reshape(-1, 4) keeps the expected (N, 4) shape even when nothing passes the threshold
        boxes_np = np.array(boxes_np).reshape(-1, 4)
        detections = sv.Detections(xyxy=boxes_np)

        # supervision annotates BGR images, so reverse the channel order
        image_np = np.uint8(image)[:, :, ::-1]
        box_annotator = sv.BoxAnnotator()
        annotated_frame = box_annotator.annotate(scene=image_np.copy(), detections=detections, labels=labels_list)
        if pil:
            return PIL.Image.fromarray(annotated_frame[:, :, ::-1])
        else:
            return annotated_frame[:, :, ::-1]
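

if __name__ == "__main__":
    # Minimal end-to-end sketch of the two pipelines above. Everything here is
    # illustrative: it assumes a CUDA device, network access to download the
    # "google/owlvit-base-patch32" checkpoint, and a local image "example.jpg".
    # .forward() is called explicitly because the Pipeline base interface is defined elsewhere.
    model_type = "google/owlvit-base-patch32"
    encoder = OwlViTVisionEncoderPipeline.from_pretrained(model_type, device="cuda")
    decoder = OwlViTDecoderPipeline.from_pretrained(model_type, device="cuda")

    image = PIL.Image.open("example.jpg").convert("RGB")
    texts = [["cat", "remote control"]]

    with torch.no_grad():
        feature_map = encoder.forward(image)                    # [1, num_patches, num_patches, hidden_size]
        detections = decoder.forward(feature_map, text=texts)   # OwlViTObjectDetectionOutput

    annotated = decoder.owlvit_visualize(image, texts, detections, score_threshold=0.1)
    annotated.save("annotated.jpg")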