# Author: Trang Dang
# Commit 78054c9: "use sam model without config"
# (Hugging Face file-page metadata: raw | history | blame | No virus | 1.26 kB)
from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import app
import os
def _to_rgb_uint8(image):
    """Return *image* as an (H, W, 3) uint8 array accepted by ``SamProcessor``.

    ``plt.imread`` returns float arrays in [0, 1] for PNG files and integer
    arrays for other formats, and may return 2-D (grayscale) or 4-channel
    (RGBA) arrays. Those variants are converted here.

    Raises:
        ValueError: if the dtype or shape cannot be mapped to 8-bit RGB.
    """
    arr = np.asarray(image)
    if np.issubdtype(arr.dtype, np.floating):
        # PNG data from plt.imread is already scaled to [0, 1].
        arr = np.rint(np.clip(arr, 0.0, 1.0) * 255.0)
    elif arr.dtype == np.uint16:
        # Map 16-bit range onto 8-bit (65535 // 257 == 255).
        arr = arr // 257
    elif arr.dtype != np.uint8:
        raise ValueError(f"Unsupported image dtype: {arr.dtype}")
    arr = arr.astype(np.uint8)

    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    if arr.ndim == 2:
        # Grayscale -> replicate into three channels.
        arr = np.stack([arr, arr, arr], axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 4:
        # Drop the alpha channel.
        arr = arr[..., :3]
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError(f"Unsupported image shape: {arr.shape}")
    return arr


def pred(src, *, checkpoint="sam_model.pth", base_model="facebook/sam-vit-base",
         device="cpu", threshold=0.5):
    """Segment the image at *src* with the fine-tuned SAM checkpoint.

    The model architecture and the preprocessing come from the Hugging Face
    *base_model* repository. The weights come from the local *checkpoint*
    file, which is loaded with ``load_state_dict``. No point or box prompt
    is passed to the model.

    Args:
        src: Path or file-like object of the input image, read with
            ``matplotlib.pyplot.imread``.
        checkpoint: Path of the saved ``state_dict`` to load.
        base_model: Hub id or local directory that provides ``SamConfig``
            and ``SamProcessor``. It must match the architecture the
            checkpoint was trained from.
        device: Torch device used for loading and inference.
        threshold: Probability cutoff applied after the sigmoid.

    Returns:
        A ``uint8`` ``np.ndarray``: 1 where the predicted mask probability
        exceeds *threshold*, 0 elsewhere. This is SAM's low-resolution mask
        output (presumably 256x256; confirm), not resized to the input image.
    """
    device = torch.device(device)

    # SamModel.__init__ requires a config. Taking it from the base repo
    # reproduces the architecture whose weights the checkpoint holds.
    config = SamConfig.from_pretrained(base_model)
    processor = SamProcessor.from_pretrained(base_model)
    model = SamModel(config)
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # trusted checkpoint files.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.to(device)
    model.eval()

    image = _to_rgb_uint8(plt.imread(src))
    inputs = processor(image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Convert logits to probabilities, then to a hard 0/1 mask.
    prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    prob = prob.cpu().numpy().squeeze()
    return (prob > threshold).astype(np.uint8)