# Scraped Hugging Face Spaces page header (not code): "Spaces: Build error", "Build error"
import functools
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from patchify import patchify
from PIL import Image
from transformers import SamConfig, SamModel, SamProcessor

import app
def patchify(large_image, patch_size=256, step=256):
    """Cut *large_image* into a stack of tiles, in row-major order.

    Tile origins start at the top-left corner and advance by *step* pixels
    along each axis. Tiles that would run past the right or bottom edge are
    dropped, so ``step == patch_size`` yields non-overlapping tiles.

    Args:
        large_image: Array-like image whose first two axes are height and
            width; any trailing axes (e.g. channels) are kept in every tile.
        patch_size: Tile size, either a scalar for square tiles or a
            ``(height, width)`` pair.
        step: Stride in pixels between consecutive tile origins (>= 1).

    Returns:
        ``numpy.ndarray`` of shape ``(n_tiles, patch_h, patch_w, *rest)``.
        ``n_tiles`` is 0 when the image is smaller than one tile.

    Raises:
        ValueError: if *step* is less than 1.
    """
    image = np.asarray(large_image)
    if np.ndim(patch_size) == 0:
        patch_h = patch_w = int(patch_size)
    else:
        patch_h, patch_w = patch_size
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    height, width = image.shape[:2]
    # This function shadows the library ``patchify`` imported at module level,
    # so the tiling is done here directly. Calling ``patchify(...)`` from inside
    # would recurse into this very function.
    all_img_patches = [
        image[top:top + patch_h, left:left + patch_w]
        for top in range(0, height - patch_h + 1, step)
        for left in range(0, width - patch_w + 1, step)
    ]
    if not all_img_patches:
        # Keep the documented rank/dtype even when no full tile fits.
        return np.empty((0, patch_h, patch_w) + image.shape[2:], dtype=image.dtype)
    return np.array(all_img_patches)
@functools.lru_cache(maxsize=1)
def _load_sam(cache_dir="/code/cache", weights_path="sam_model.pth"):
    """Build the fine-tuned SAM model and its processor once per process.

    The base config and processor come from the "facebook/sam-vit-base"
    checkpoint (cached under *cache_dir*). The model weights are then replaced
    by the state dict stored at *weights_path*.

    Returns:
        ``(model, processor, device)``. The model has been moved to *device*
        and put in eval mode.
    """
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    model = SamModel(config=model_config)
    # NOTE(review): torch.load unpickles arbitrary objects; only point this at a
    # trusted checkpoint (weights_only=True is an option for plain state dicts).
    model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Model and inputs must live on the same device for the forward pass.
    model.to(device)
    model.eval()
    return model, processor, device


def _grid_prompt_points(array_size=256, grid_size=10):
    """Return a ``(1, 1, grid_size**2, 2)`` tensor of evenly spaced [x, y] points.

    The points cover ``[0, array_size - 1]`` on both axes. Coordinates are
    truncated to ints.
    """
    coords = np.linspace(0, array_size - 1, grid_size)
    xv, yv = np.meshgrid(coords, coords)
    input_points = [
        [[int(x), int(y)] for x, y in zip(x_row, y_row)]
        for x_row, y_row in zip(xv.tolist(), yv.tolist())
    ]
    return torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)


def pred(src, *, patch_index=(1, 2), return_mask=False):
    """Segment one 256x256 tile of the image at *src* with the fine-tuned SAM model.

    Steps:
        1. Read the image.
        2. Split it into non-overlapping 256x256 tiles.
        3. Prompt the model on the tile at *patch_index* with a 10x10 grid of
           points.
        4. Threshold the sigmoid of the predicted mask at 0.5.

    Args:
        src: Path or file object accepted by ``PIL.Image.open``.
        patch_index: ``(row, col)`` of the tile to segment in the tile grid.
        return_mask: If True, return the binary mask.

    Returns:
        The legacy placeholder ``1`` by default (kept for existing callers).
        With ``return_mask=True``, a ``numpy.uint8`` array of 0/1 values.

    Raises:
        IndexError: if *patch_index* lies outside the tile grid.
    """
    # The module-level ``def patchify`` shadows the library function imported
    # at the top of the file; import the library version under its own name so
    # this (image, shape, step) call reaches it.
    from patchify import patchify as patchify_grid

    model, processor, device = _load_sam()

    with Image.open(src) as img:
        new_image = np.array(img)
    patches = patchify_grid(new_image, (256, 256), step=256)

    i, j = patch_index
    rows, cols = patches.shape[:2]
    if not (-rows <= i < rows and -cols <= j < cols):
        raise IndexError(
            f"patch_index {patch_index} out of range for a {rows}x{cols} patch grid"
        )
    single_patch = Image.fromarray(patches[i, j])

    inputs = processor(single_patch, input_points=_grid_prompt_points(), return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Soft mask via sigmoid, then hard 0/1 mask at 0.5.
    single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
    single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)

    if return_mask:
        return single_patch_prediction
    return 1