"""Segment one patch of an input image with a SAM model loaded from local weights."""
import functools
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from patchify import patchify as _patchify_grid
from PIL import Image
from transformers import SamConfig, SamModel, SamProcessor

# NOTE(review): `app` is not referenced anywhere in this module; if `app`
# imports this module, this is a circular import — confirm it is needed.
import app

# Hugging Face checkpoint providing the SAM config and processor.
MODEL_NAME = "facebook/sam-vit-base"
# Local directory used as the Hugging Face download cache.
CACHE_DIR = "/code/cache"
# State dict loaded on top of the base SAM architecture.
WEIGHTS_PATH = "sam_model.pth"
# Side length (pixels) of the square patches the image is cut into.
PATCH_SIZE = 256
# Number of prompt points along each axis of the regular prompt grid.
GRID_SIZE = 10
# Inference runs on CPU; weights are loaded with a matching map_location.
DEVICE = torch.device("cpu")


def patchify(large_image, patch_size=PATCH_SIZE, step=PATCH_SIZE):
    """Split a 2-D image into a stack of square patches.

    The original helper shadowed the imported library function and called
    itself recursively with the wrong arity; it now delegates to the library
    (imported as ``_patchify_grid``).

    Args:
        large_image: 2-D array (H, W). The patch shape passed to the library
            is 2-D, so the image is expected to be single-channel.
        patch_size: Side length of each square patch, in pixels.
        step: Stride between patch origins; ``step == patch_size`` means
            non-overlapping patches.

    Returns:
        A new array of shape (n_patches, patch_size, patch_size), with the
        patches ordered row by row.
    """
    grid = _patchify_grid(large_image, (patch_size, patch_size), step=step)
    # grid is (rows, cols, patch_size, patch_size); flattening the first two
    # axes in row-major order matches a nested row/col loop. np.array copies,
    # so callers never get a strided view into the source image.
    return np.array(grid.reshape(-1, patch_size, patch_size))


def _grid_prompt_points(array_size=PATCH_SIZE, grid_size=GRID_SIZE):
    """Return an int64 tensor (1, 1, grid_size**2, 2) of evenly spaced [x, y] prompts."""
    coords = np.linspace(0, array_size - 1, grid_size)
    xv, yv = np.meshgrid(coords, coords)
    # astype truncates toward zero, same as int() on these non-negative values.
    points = np.stack([xv, yv], axis=-1).astype(np.int64)
    return torch.from_numpy(points).view(1, 1, grid_size * grid_size, 2)


@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Build the SAM model with the local weights plus its processor; cached after the first call."""
    model_config = SamConfig.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    processor = SamProcessor.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    model = SamModel(config=model_config)
    # NOTE: torch.load unpickles the file — only load trusted weight files.
    model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    return model, processor


def pred(src, patch_index=(1, 2)):
    """Segment one PATCH_SIZE x PATCH_SIZE patch of an image with the SAM model.

    Args:
        src: Path or file object accepted by ``PIL.Image.open``.
        patch_index: (row, col) of the patch to segment in the
            non-overlapping patch grid. Defaults to (1, 2), the patch the
            original code hard-coded.

    Returns:
        ``np.uint8`` mask of 0/1 values, obtained by thresholding the sigmoid
        of the model's predicted mask at 0.5. (The original returned a
        placeholder ``1`` and discarded this mask.)

    Raises:
        IndexError: if the image is too small to contain ``patch_index``.
    """
    model, processor = _load_model_and_processor()

    # Context manager closes the underlying file once the pixels are copied.
    with Image.open(src) as img:
        new_image = np.array(img)

    # Call the library directly: the grid layout (rows, cols, h, w) is needed
    # to select a patch by (row, col).
    patches = _patchify_grid(new_image, (PATCH_SIZE, PATCH_SIZE), step=PATCH_SIZE)
    row, col = patch_index
    single_patch = Image.fromarray(patches[row, col])

    inputs = processor(
        single_patch, input_points=_grid_prompt_points(), return_tensors="pt"
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Soft mask -> probabilities -> hard 0/1 mask.
    prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    prob = prob.cpu().numpy().squeeze()
    return (prob > 0.5).astype(np.uint8)