# Author: Trang Dang
# Last commit: d152f7f ("update")
# File size at that commit: 1.26 kB (copied from a web file view; virus scan: clean)
import functools
import os

import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamConfig, SamProcessor

import app
@functools.lru_cache(maxsize=4)
def _load_sam(model_name, cache_dir, checkpoint):
    """Build the SAM processor and fine-tuned model once per argument set.

    Building the config/processor and reading the checkpoint costs far more
    than a single forward pass, so the result is memoised. Later calls with
    the same arguments reuse the same objects.

    Args:
        model_name: Hugging Face model id used for the config and processor.
        cache_dir: Download cache directory passed to ``from_pretrained``.
        checkpoint: Path to a saved ``state_dict`` for ``SamModel``.

    Returns:
        ``(processor, model)``, with the model already in eval mode.
    """
    config = SamConfig.from_pretrained(model_name, cache_dir=cache_dir)
    processor = SamProcessor.from_pretrained(model_name, cache_dir=cache_dir)
    # Architecture comes from the pretrained config; weights come from the
    # local fine-tuned checkpoint.
    model = SamModel(config=config)
    # NOTE(review): torch.load unpickles the file, which can execute code.
    # Only load trusted checkpoints, or pass weights_only=True if the
    # installed torch version supports it.
    state_dict = torch.load(checkpoint, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return processor, model


def pred(
    src,
    *,
    model_name: str = "facebook/sam-vit-base",
    cache_dir: str = "/code/cache",
    checkpoint: str = "sam_model.pth",
    threshold: float = 0.5,
):
    """Segment an image with the fine-tuned SAM model.

    The image is given to the processor without any point/box prompts, and
    the model is asked for a single mask (``multimask_output=False``).

    Args:
        src: Anything ``PIL.Image.open`` accepts (a path or a binary file
            object). A file object passed by the caller is not closed here.
        model_name: Hugging Face model id for the config and processor.
        cache_dir: Download cache directory for ``from_pretrained``.
        checkpoint: Path to the fine-tuned ``state_dict`` file.
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        ``(prob, mask)``. ``prob`` is a float numpy array of sigmoid
        probabilities with singleton dimensions squeezed out. ``mask`` is
        ``prob > threshold`` as a ``uint8`` array of 0/1.

        NOTE(review): this is the raw ``pred_masks`` output and is not
        resized to the input image (presumably SAM's 256x256 low-res mask;
        confirm). ``SamProcessor.post_process_masks`` can upscale it.
    """
    processor, model = _load_sam(model_name, cache_dir, checkpoint)

    # Decode to an RGB array and release the file handle before inference.
    with Image.open(src) as img:
        image = np.array(img.convert("RGB"))
    inputs = processor(image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Logits -> probabilities, then drop singleton batch/mask dimensions.
    prob = torch.sigmoid(outputs.pred_masks.squeeze(1)).cpu().numpy().squeeze()
    # Soft mask -> hard 0/1 mask.
    mask = (prob > threshold).astype(np.uint8)
    return prob, mask