from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
from patchify import patchify
from PIL import Image
def extract_patches(large_image, patch_size=256, step=256):
    # Split a large image into patches; step == patch_size means no overlap.
    # (Renamed from the library function so it no longer shadows patchify.)
    all_img_patches = []
    patches_img = patchify(large_image, (patch_size, patch_size), step=step)
    for i in range(patches_img.shape[0]):
        for j in range(patches_img.shape[1]):
            single_patch_img = patches_img[i, j, :, :]
            all_img_patches.append(single_patch_img)
    return np.array(all_img_patches)
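# Usage sketch for extract_patches (the input shape here is an illustrative
# assumption): a 1024x1024 grayscale image yields 16 patches of 256x256.
#   patches = extract_patches(np.zeros((1024, 1024), dtype=np.uint8))
#   patches.shape  # -> (16, 256, 256)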
def pred(src):
    device = torch.device("cpu")

    # Load the model configuration and processor for the base SAM checkpoint.
    cache_dir = "/code/cache"
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)

    # Create an instance of the model architecture with the loaded configuration,
    # then update it with the fine-tuned weights from the saved file.
    my_sam_model = SamModel(config=model_config)
    my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=device))
    my_sam_model.to(device)

    # Split the input image into non-overlapping 256x256 patches.
    new_image = np.array(Image.open(src))
    patches = patchify(new_image, (256, 256), step=256)
    # Build a uniform 10x10 grid of prompt points over the 256x256 patch.
    array_size = 256
    grid_size = 10
    x = np.linspace(0, array_size - 1, grid_size)
    y = np.linspace(0, array_size - 1, grid_size)
    xv, yv = np.meshgrid(x, y)

    # Combine the x and y coordinates into [x, y] pairs and reshape to the
    # (batch, point_batch, num_points, 2) layout the processor expects.
    input_points = [[[int(px), int(py)] for px, py in zip(x_row, y_row)]
                    for x_row, y_row in zip(xv.tolist(), yv.tolist())]
    input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
    # Select one patch for segmentation.
    i, j = 1, 2
    selected_patch = patches[i, j]
    single_patch = Image.fromarray(selected_patch)

    # Prepare the image and point prompts for the model.
    inputs = processor(single_patch, input_points=input_points, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    my_sam_model.eval()

    # Forward pass.
    with torch.no_grad():
        outputs = my_sam_model(**inputs, multimask_output=False)

    # Apply sigmoid to get per-pixel probabilities, then threshold the soft
    # mask into a hard binary mask.
    single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
    single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
    return single_patch_prediction
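# Minimal smoke test, assuming a fine-tuned checkpoint exists at sam_model.pth
# and an input image at the hypothetical path below; pred() returns a binary
# uint8 mask for the selected patch.
if __name__ == "__main__":
    mask = pred("example_input.png")
    print("mask shape:", mask.shape, "foreground pixels:", int(mask.sum()))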