import numpy as np
import gradio as gr
import segment_anything
import base64
import torch
import typing
import os
import subprocess
import requests
import PIL.Image
import urllib.parse


def download_image(url: str) -> PIL.Image.Image:
    """Download an image from a URL and return it as a PIL image."""
    response = requests.get(url, stream=True)
    response.raise_for_status()
    return PIL.Image.open(response.raw)


def image_to_sam_image_embedding(
    image_url: str,
    model_size: typing.Literal["base", "large", "huge"] = "base",
) -> str:
    """Generate a SAM image embedding and return it as a base64-encoded string."""

    image_url = urllib.parse.unquote(image_url)

    try:
        image = download_image(image_url)
    except Exception as exc:
        raise gr.Error(f"Could not load image from URL: {image_url}.") from exc
    image = image.convert("RGB")
    image = np.asarray(image)

    # Select the predictor for the requested model size
    if model_size == "base":
        predictor = base_predictor
    elif model_size == "large":
        predictor = large_predictor
    elif model_size == "huge":
        predictor = huge_predictor
    else:
        raise gr.Error(f"Unknown model size: {model_size}.")

    # Run model
    predictor.set_image(image)
    # Output shape is (1, 256, 64, 64)
    image_embedding = predictor.get_image_embedding().cpu().numpy()

    # Flatten the array to a 1D array
    flat_arr = image_embedding.flatten()
    # Convert the 1D array to bytes
    bytes_arr = flat_arr.astype(np.float32).tobytes()
    # Encode the bytes to base64
    base64_str = base64.b64encode(bytes_arr).decode("utf-8")

    return base64_str
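

# A minimal client-side sketch of the inverse of the encoding above: decode the
# base64 string back into the (1, 256, 64, 64) float32 embedding. The shape is
# taken from the comment in image_to_sam_image_embedding; this helper is
# illustrative and not called anywhere in this script. Note the payload is
# sizeable: 1 * 256 * 64 * 64 float32 values are 4 MiB raw, ~5.6 MB as base64.
def decode_sam_image_embedding(base64_str: str) -> np.ndarray:
    """Decode a base64 embedding string back into a (1, 256, 64, 64) array."""
    bytes_arr = base64.b64decode(base64_str)
    flat_arr = np.frombuffer(bytes_arr, dtype=np.float32)
    return flat_arr.reshape(1, 256, 64, 64)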


if __name__ == "__main__":

    # Load the models into memory once so repeated predictions are efficient
    device = "cuda" if torch.cuda.is_available() else "cpu"

    base_sam_checkpoint = "sam_vit_b_01ec64.pth"  # 375 MB
    large_sam_checkpoint = "sam_vit_l_0b3195.pth"  # 1.25 GB
    huge_sam_checkpoint = "sam_vit_h_4b8939.pth"  # 2.56 GB

    # Download the model checkpoints
    for model in [base_sam_checkpoint, large_sam_checkpoint, huge_sam_checkpoint]:
        if not os.path.exists(f"./{model}"):
            result = subprocess.run(
                ["wget", f"https://dl.fbaipublicfiles.com/segment_anything/{model}"],
                check=True,
            )
            print(f"wget {model} result = {result}")

    base_sam = segment_anything.sam_model_registry["vit_b"](
        checkpoint=base_sam_checkpoint
    )
    large_sam = segment_anything.sam_model_registry["vit_l"](
        checkpoint=large_sam_checkpoint
    )
    huge_sam = segment_anything.sam_model_registry["vit_h"](
        checkpoint=huge_sam_checkpoint
    )

    base_sam.to(device=device)
    large_sam.to(device=device)
    huge_sam.to(device=device)

    base_predictor = segment_anything.SamPredictor(base_sam)
    large_predictor = segment_anything.SamPredictor(large_sam)
    huge_predictor = segment_anything.SamPredictor(huge_sam)

    # Gradio app
    app = gr.Interface(
        fn=image_to_sam_image_embedding,
        inputs=[
            gr.components.Textbox(label="Image URL"),
            gr.components.Radio(
                choices=["base", "large", "huge"], label="Model Size", value="base"
            ),
        ],
        outputs="text",
    )
    app.launch()
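
    # Hedged usage sketch: with the app running, the endpoint can be called via
    # gradio_client. The local URL and the default api_name "/predict" for a
    # single-function gr.Interface are assumptions; check the running app's
    # "Use via API" page for the exact route.
    #
    #   from gradio_client import Client
    #   client = Client("http://127.0.0.1:7860")
    #   embedding_b64 = client.predict(
    #       "https://example.com/cat.jpg",  # hypothetical image URL
    #       "base",                         # model size
    #       api_name="/predict",
    #   )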