# ------------------------------------------------------------------------ # Copyright (c) 2023-present, BAAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ------------------------------------------------------------------------ """Gradio application.""" import argparse import multiprocessing as mp import os import time import numpy as np import torch from tokenize_anything import engine from tokenize_anything.utils.image import im_rescale from tokenize_anything.utils.image import im_vstack def parse_args(): """Parse arguments.""" parser = argparse.ArgumentParser(description="Launch gradio application") parser.add_argument("--model-type", type=str, default="tap_vit_h") parser.add_argument("--checkpoint", type=str, default="models/tap_vit_h_v1_1.pkl") parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl") parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices") return parser.parse_args() class Predictor(object): """Predictor.""" def __init__(self, model, kwargs): self.model = model self.kwargs = kwargs self.prompt_size = kwargs.get("prompt_size", 256) self.model.concept_projector.reset_weights(kwargs["concept_weights"]) self.model.text_decoder.reset_cache(max_batch_size=self.prompt_size) def preprocess_images(self, imgs): """Preprocess the inference images.""" im_batch, im_shapes, im_scales = [], [], [] for img in imgs: scaled_imgs, scales = im_rescale(img, scales=[1024]) im_batch += scaled_imgs im_scales += scales im_shapes += [x.shape[:2] for x in scaled_imgs] im_batch = im_vstack(im_batch, self.model.pixel_mean_value, size=(1024, 1024)) im_shapes = np.array(im_shapes) im_scales = np.array(im_scales).reshape((len(im_batch), -1)) im_info = np.hstack([im_shapes, im_scales]).astype("float32") return im_batch, im_info @torch.inference_mode() def get_results(self, examples): """Return the results.""" # Preprocess images and prompts. imgs = [example["img"] for example in examples] points = np.concatenate([example["points"] for example in examples]) im_batch, im_info = self.preprocess_images(imgs) num_prompts = points.shape[0] if len(points.shape) > 2 else 1 batch_shape = im_batch.shape[0], num_prompts // im_batch.shape[0] batch_points = points.reshape(batch_shape + (-1, 3)) batch_points[:, :, :, :2] *= im_info[:, None, None, 2:4] batch_points = batch_points.reshape(points.shape) # Predict tokens and masks. inputs = self.model.get_inputs({"img": im_batch}) inputs.update(self.model.get_features(inputs)) outputs = self.model.get_outputs(dict(**inputs, **{"points": batch_points})) # Select final mask. iou_pred = outputs["iou_pred"].cpu().numpy() point_score = batch_points[:, 0, 2].__eq__(2).__sub__(0.5)[:, None] rank_scores = iou_pred + point_score * ([1000] + [0] * (iou_pred.shape[1] - 1)) mask_index = np.arange(rank_scores.shape[0]), rank_scores.argmax(1) iou_scores = outputs["iou_pred"][mask_index].cpu().numpy().reshape(batch_shape) # Upscale masks to the original image resolution. mask_pred = outputs["mask_pred"][mask_index].unsqueeze_(1) mask_pred = self.model.upscale_masks(mask_pred, im_batch.shape[1:-1]) mask_pred = mask_pred.view(batch_shape + mask_pred.shape[2:]) # Predict concepts. concepts, scores = self.model.predict_concept(outputs["sem_embeds"][mask_index]) concepts, scores = [x.reshape(batch_shape) for x in (concepts, scores)] # Generate captions. sem_tokens = outputs["sem_tokens"][mask_index] captions = self.model.generate_text(sem_tokens).reshape(batch_shape) # Postprocess results. results = [] for i in range(batch_shape[0]): pred_h, pred_w = im_info[i, :2].astype("int") masks = mask_pred[i : i + 1, :, :pred_h, :pred_w] masks = self.model.upscale_masks(masks, imgs[i].shape[:2]).flatten(0, 1) results.append( { "scores": np.stack([iou_scores[i], scores[i]], axis=-1), "masks": masks.gt(0).cpu().numpy().astype("uint8"), "concepts": concepts[i], "captions": captions[i], } ) return results class ServingCommand(object): """Command to run serving.""" def __init__(self, output_queue): self.output_queue = output_queue self.output_dict = mp.Manager().dict() self.output_index = mp.Value("i", 0) def postprocess_outputs(self, outputs): """Main the detection objects.""" scores, masks = outputs["scores"], outputs["masks"] concepts, captions = outputs["concepts"], outputs["captions"] text_template = "{} ({:.2f}, {:.2f}): {}" text_contents = concepts, scores[:, 0], scores[:, 1], captions texts = np.array([text_template.format(*vals) for vals in zip(*text_contents)]) return masks, texts def run(self): """Main loop to make the serving outputs.""" while True: img_id, outputs = self.output_queue.get() self.output_dict[img_id] = self.postprocess_outputs(outputs) def build_gradio_app(queues, command): """Build the gradio application.""" import gradio as gr import gradio_image_prompter as gr_ext title = "Tokenize Anything" header = ( "