File size: 1,347 Bytes
915d664
 
 
 
 
c8a3732
e952bfc
915d664
f2d8c06
78054c9
e8807d2
 
 
e952bfc
915d664
83cc330
915d664
 
1d5816e
915d664
 
 
 
 
 
 
 
1d5816e
915d664
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import app
import os

def pred(src):
    """Load the SAM model from ``sam_model.pth`` and preprocess the image at ``src``.

    The model configuration and processor for ``facebook/sam-vit-base`` are
    fetched (or read from the local cache directory), a ``SamModel`` is built
    from that configuration, and its weights are replaced by the state dict
    stored in ``sam_model.pth``. The image is then run through the processor
    and the resulting tensors are moved to the target device.

    The forward pass and mask post-processing are currently disabled (kept
    below as commented-out code), so the function only returns a placeholder.

    Args:
        src: Anything accepted by ``PIL.Image.open`` (a path or a binary
            file-like object) that identifies the input image.

    Returns:
        The placeholder value ``1``.
    """
    # Local import: the module-level imports do not bring ``Image`` into
    # scope, which made the original ``Image.open`` call raise NameError.
    from PIL import Image

    # os.environ['HUGGINGFACE_HUB_HOME'] = './.cache'
    cache_dir = "/code/cache"
    # Weights are loaded onto the CPU, so keep the inputs on the same device.
    # (``device`` was previously undefined, which raised NameError.)
    device = torch.device("cpu")

    # Load the model configuration and the matching pre-processor.
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)

    # Create an instance of the model architecture with the loaded configuration,
    # then overwrite its weights with the ones saved in the checkpoint file.
    my_sam_model = SamModel(config=model_config)
    my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=device))

    # Read the pixels inside a context manager so the file handle is closed.
    with Image.open(src) as image_file:
        new_image = np.array(image_file)

    inputs = processor(new_image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # TODO: inference is disabled; re-enable once the output format is settled.
    # my_sam_model.eval()
    # # forward pass
    # with torch.no_grad():
    #     outputs = my_sam_model(**inputs, multimask_output=False)

    # # apply sigmoid
    # single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # # convert soft mask to hard mask
    # single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
    # single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
    return 1