import os import time import torch import numpy as np import gradio as gr from segment_anything import build_sam, SamAutomaticMaskGenerator from segment_anything.utils.amg import ( build_all_layer_point_grids ) os.system(r'python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth') hourglass_args = { "baseline": { "use_hourglass": False, "hourglass_clustering_location": -1, }, "1.2x faster": { "use_hourglass": True, "hourglass_clustering_location": 16, "hourglass_num_cluster": 81, }, "1.5x faster": { "use_hourglass": True, "hourglass_clustering_location": 6, "hourglass_num_cluster": 81, }, } device = torch.device("cuda" if torch.cuda.is_available() else "cpu") mask_generator = SamAutomaticMaskGenerator( build_sam(checkpoint="sam_vit_h_4b8939.pth", use_hourglass=True), ) mask_generator.predictor.model.to(device=device) def predict(image, speed_mode, points_per_side): points_per_side = int(points_per_side) mask_generator.predictor.model.image_encoder.load_hourglass_args(**hourglass_args[speed_mode]) if points_per_side is not None: mask_generator.point_grids = build_all_layer_point_grids( points_per_side, mask_generator.crop_n_layers, mask_generator.crop_n_points_downscale_factor, ) mask_generator.points_per_batch = 64 if points_per_side > 12 else points_per_side * points_per_side start = time.perf_counter() with torch.no_grad(): masks = mask_generator.generate(image) eta = time.perf_counter() - start eta_text = f"Time of generation: {eta:.2f} seconds" if len(masks) == 0: return image sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True) img = np.ones(image.shape) for mask in sorted_masks: m = mask['segmentation'] color_mask = np.random.random((1, 1, 3)) img = img * (1 - m[..., None]) + color_mask * m[..., None] image = (image * 0.65 + img * 255 * 0.35).astype(np.uint8) return image, eta_text description = """ #
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.