import numpy as np
import torch
from PIL import Image
from transformers import SamConfig, SamModel, SamProcessor

def pred(src):
    """Run the fine-tuned SAM model on the image at `src` and return the
    predicted probability map and the thresholded binary mask."""
    # Load the pretrained SAM configuration and processor from the local cache
    cache_dir = "/code/cache"
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)

    # Create an instance of the model architecture with the loaded configuration
    my_sam_model = SamModel(config=model_config)
    # Update the model by loading the weights from the saved file
    my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=torch.device('cpu')))
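    # ("sam_model.pth" is assumed to hold a state dict saved during fine-tuning,
    # e.g. via torch.save(model.state_dict(), "sam_model.pth"); load_state_dict
    # restores only the weights, while the architecture comes from the config above.)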

    # Read the input image and preprocess it into model-ready tensors
    new_image = np.array(Image.open(src).convert("RGB"))
    inputs = processor(new_image, return_tensors="pt")
    my_sam_model.eval()

    # Forward pass; multimask_output=False returns a single mask per prompt
    with torch.no_grad():
        outputs = my_sam_model(**inputs, multimask_output=False)

    # Apply sigmoid to turn the mask logits into probabilities
    single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # Convert the soft mask to a hard (binary) mask at a 0.5 threshold;
    # note the mask is at the model's low mask resolution, not the input image size
    single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
    single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)

    return single_patch_prob, single_patch_prediction
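
# A minimal usage sketch, assuming "sam_model.pth" exists as loaded above.
# "test_image.png" is a hypothetical placeholder path, not a file shipped
# with this repo; substitute any RGB image on disk.
if __name__ == "__main__":
    prob_map, binary_mask = pred("test_image.png")
    print("probability map:", prob_map.shape, "binary mask:", binary_mask.shape)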