# sam-vit-h-encoder-torchscript / convert_torchscript.py
# khsyee — "Change using inheritance" (commit 65dd0ae)
import os
import urllib
import torch
from segment_anything.modeling import Sam
from custom_encoder import build_sam_vit_h_torchscript
# Where the SAM ViT-H checkpoint is cached locally, and under which file name.
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
# Download location used by load_model when the checkpoint is not cached yet.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# Default value of load_model's ``model_type`` argument.
MODEL_TYPE = "default"

# Target device for the exported encoder: a GPU when PyTorch reports CUDA as
# available, otherwise the CPU. Evaluated once, at import time.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Build the SAM ViT-H model from a locally cached checkpoint.

    The checkpoint ``<checkpoint_path>/<checkpoint_name>`` is downloaded from
    ``checkpoint_url`` the first time it is needed and reused on later calls.

    Args:
        checkpoint_path: Directory holding the cached checkpoint; created if
            it does not exist.
        checkpoint_name: File name of the checkpoint inside ``checkpoint_path``.
        checkpoint_url: URL the checkpoint is fetched from when not cached.
        model_type: Not used by this function. The model is always built by
            ``build_sam_vit_h_torchscript``. The parameter is kept so existing
            callers keep working.

    Returns:
        Whatever ``build_sam_vit_h_torchscript(checkpoint=...)`` returns for
        the cached checkpoint.

    Raises:
        Any error from the download (e.g. ``urllib.error.URLError``) or from
        creating the cache directory is propagated unchanged.
    """
    # ``import urllib`` alone does not guarantee the ``request`` submodule is
    # loaded, so import it explicitly before calling urlretrieve.
    import urllib.request

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        # Download under a temporary name and rename only after success.
        # Otherwise an interrupted download would leave a truncated file at
        # ``checkpoint`` that every later run would treat as a valid cache hit.
        partial = checkpoint + ".part"
        try:
            urllib.request.urlretrieve(checkpoint_url, partial)
            os.replace(partial, checkpoint)
        finally:
            # After a successful rename ``partial`` is gone. After a failure,
            # remove the incomplete fragment.
            if os.path.exists(partial):
                os.remove(partial)
        print(f"The model weights saved as {checkpoint}")
    print(f"Load the model weights from {checkpoint}")
    return build_sam_vit_h_torchscript(checkpoint=checkpoint)
if __name__ == "__main__":
    # Export only the image encoder. eval() switches the module to inference
    # mode before it is scripted.
    model = load_model().image_encoder.eval().to(device)

    output_path = "model_repository/sam_torchscript_fp32/model.pt"
    # ScriptModule.save does not create missing parent directories, so make
    # sure the model-repository layout exists before saving.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with torch.jit.optimized_execution(True):
        script_model = torch.jit.script(model)
    script_model.save(output_path)