"""Leaf segmentation helpers built on a fine-tuned Segment Anything (SAM) model.

Importing this module has side effects:

* Torch is pinned to the CPU and every GPU is hidden from TensorFlow.
* If the local ``model`` directory does not exist, the checkpoint repository
  is downloaded from the Hugging Face Hub (``READ_TOKEN`` is read from the
  environment for authentication).
* The SAM checkpoint is loaded and exposed as the module-level ``sam`` and
  ``predictor`` objects used by the functions below.
"""
import os
from math import ceil

# torch and tensorflow are deliberately imported first, in this order, to
# preserve the original import sequence of the module.
import torch
import tensorflow as tf
from segment_anything import SamPredictor, sam_model_registry
import matplotlib.pyplot as plt
import cv2
import numpy as np
from huggingface_hub import snapshot_download
from torch.nn import functional as F

# --------------------------------------------------------------------------
# Device configuration
# --------------------------------------------------------------------------
# Inference is pinned to the CPU for Torch ...
device = torch.device("cpu")
print(f"Torch device: {device}")

# ... and all GPUs are hidden from TensorFlow.
tf.config.set_visible_devices([], 'GPU')

# Default side length (in pixels) of the square crops produced by `resize`
# and `pi`. It mirrors the default of `segmentation_sam`; without it those two
# functions would fail with a NameError.
SIZE = 384

# --------------------------------------------------------------------------
# Checkpoint download
# --------------------------------------------------------------------------
if not os.path.exists('model'):
    REPO_ID = 'Serrelab/SAM_Leaves'
    token = os.environ.get('READ_TOKEN')
    # Only report whether a token is configured; never echo the secret itself
    # into logs.
    print(f"Read token set: {token is not None}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='model', local_dir='model')

# --------------------------------------------------------------------------
# Model loading
# --------------------------------------------------------------------------
original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the real ``torch.load`` with ``map_location`` forced to `device`."""
    kwargs['map_location'] = device
    return original_torch_load(*args, **kwargs)


model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')

# The checkpoint is loaded through `sam_model_registry`, so `torch.load` is
# patched temporarily to remap the weights onto `device`. The `finally`
# guarantees the patch is removed even if loading fails.
torch.load = patched_torch_load
try:
    sam = sam_model_registry["default"](model_path)
    sam.to(device)
    predictor = SamPredictor(sam)
finally:
    torch.load = original_torch_load


def pad_gt(x):
    """Zero-pad the last two dimensions of `x` up to the SAM encoder size.

    Padding is added after the existing content (end of each of the last two
    dimensions) so that both reach ``sam.image_encoder.img_size``.

    Args:
        x: torch tensor whose last two dimensions are the spatial ones.

    Returns:
        The padded tensor.
    """
    h, w = x.shape[-2:]
    padh = sam.image_encoder.img_size - h
    padw = sam.image_encoder.img_size - w
    return F.pad(x, (0, padw, 0, padh))


def preprocess(img):
    """Turn an image into the tensor expected by the SAM image encoder.

    The image is cast to uint8, resized with the predictor's transform, moved
    to `device`, rearranged to ``(1, C, H, W)`` and passed through
    ``sam.preprocess``.

    Args:
        img: array-like image, channels last. Values are cast to uint8, so it
            is presumably expected in the 0..255 range — TODO confirm.

    Returns:
        A tuple ``(tensor, intermediate_shape)`` where `intermediate_shape` is
        the spatial size of the image after the predictor's resize and before
        ``sam.preprocess``.
    """
    img = np.array(img).astype(np.uint8)
    resized = predictor.transform.apply_image(img)
    intermediate_shape = resized.shape

    tensor = torch.as_tensor(resized).to(device)
    # channels-last -> channels-first, plus a leading batch dimension
    tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :]
    tensor = sam.preprocess(tensor)

    if len(intermediate_shape) == 3:
        intermediate_shape = intermediate_shape[:2]
    elif len(intermediate_shape) == 4:
        intermediate_shape = intermediate_shape[1:3]
    return tensor, intermediate_shape


def normalize(img):
    """Linearly rescale `img` so that its values span [-1, 1].

    Note: a constant image has a zero range, so the division yields NaN/inf.
    """
    img = img - tf.math.reduce_min(img)
    img = img / tf.math.reduce_max(img)
    img = img * 2.0 - 1.0
    return img


def resize(img):
    """Bicubic resize of `img` to ``(SIZE, SIZE)``; default for all `pi` outputs."""
    return tf.image.resize(img, (SIZE, SIZE), method="bicubic")


def smooth_mask(mask, ds=20):
    """Smooth `mask` by shrinking it to ``(ds, ds)`` and resizing it back.

    Both resizes are bicubic; the output has the same first two dimensions as
    the input.
    """
    shape = tf.shape(mask)
    w, h = shape[0], shape[1]
    small = tf.image.resize(mask, (ds, ds), method="bicubic")
    return tf.image.resize(small, (w, h), method="bicubic")


def pi(img, mask):
    """Build a stack of four ``(SIZE, SIZE)`` views of the masked image.

    The mask is smoothed and averaged over its last axis; pixels where it is
    not above 0.01 are zeroed in the image. The returned views are, in order:

    1. the whole masked image resized to ``(SIZE, SIZE)``;
    2. the bounding box of the region where the mask exceeds 0.15;
    3. that bounding box shrunk by a quarter of its extent on every side;
    4. a window of up to ``2 * SIZE`` pixels per side centred on the first
       location of the mask maximum, clipped to the image.

    When the bounding box is missing or not larger than 50 pixels along both
    axes, views 2 and 3 fall back to the plain resize and to a padded resize
    of the whole image.

    Args:
        img: image indexed as ``img[axis0, axis1, channel]``.
        mask: torch tensor; it is converted with ``.cpu().numpy()``. Assumed
            to share the image's first two dimensions and to carry a trailing
            channel axis — TODO confirm against the caller.

    Returns:
        A float32 tensor stacking the four views along a new leading axis.
    """
    img = tf.cast(img, tf.float32)
    shape = tf.shape(img)
    # int64 so these bounds can be combined with the int64 indices of tf.where
    w, h = tf.cast(shape[0], tf.int64), tf.cast(shape[1], tf.int64)

    mask = smooth_mask(mask.cpu().numpy().astype(float))
    mask = tf.reduce_mean(mask, -1)

    img = img * tf.cast(mask > 0.01, tf.float32)[:, :, None]

    img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)
    img_pad = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)

    # building 2 anchors
    anchors = tf.where(mask > 0.15)
    bbox_ok = False
    # Guard the empty case explicitly instead of relying on the reductions of
    # an empty tensor happening to fail the size test below.
    if int(tf.shape(anchors)[0]) > 0:
        anchor_xmin = tf.math.reduce_min(anchors[:, 0])
        anchor_xmax = tf.math.reduce_max(anchors[:, 0])
        anchor_ymin = tf.math.reduce_min(anchors[:, 1])
        anchor_ymax = tf.math.reduce_max(anchors[:, 1])
        bbox_ok = bool(anchor_xmax - anchor_xmin > 50) and bool(anchor_ymax - anchor_ymin > 50)

    if bbox_ok:
        img_anchor_1 = resize(img[anchor_xmin:anchor_xmax, anchor_ymin:anchor_ymax])

        delta_x = (anchor_xmax - anchor_xmin) // 4
        delta_y = (anchor_ymax - anchor_ymin) // 4
        img_anchor_2 = img[anchor_xmin + delta_x:anchor_xmax - delta_x,
                           anchor_ymin + delta_y:anchor_ymax - delta_y]
        img_anchor_2 = resize(img_anchor_2)
    else:
        img_anchor_1 = img_resize
        img_anchor_2 = img_pad

    # building the anchors max
    anchor_max = tf.where(mask == tf.math.reduce_max(mask))[0]
    anchor_max_x, anchor_max_y = anchor_max[0], anchor_max[1]

    img_max_zoom1 = img[tf.math.maximum(anchor_max_x - SIZE, 0): tf.math.minimum(anchor_max_x + SIZE, w),
                        tf.math.maximum(anchor_max_y - SIZE, 0): tf.math.minimum(anchor_max_y + SIZE, h)]
    img_max_zoom1 = resize(img_max_zoom1)

    # NOTE: the padded view and a tighter (SIZE // 2) zoom around the maximum
    # are intentionally not part of the output.
    return tf.cast([
        img_resize,
        img_anchor_1,
        img_anchor_2,
        img_max_zoom1,
    ], tf.float32)


def one_step_inference(x):
    """Run SAM without any prompt and return a mask at the input resolution.

    Args:
        x: image with 3 dimensions (spatial size read from the first two) or
            4 dimensions (spatial size read from dimensions 1 and 2).

    Returns:
        The mask-decoder output as a torch tensor on `device`, resized to the
        spatial size of `x`. The values are not thresholded here.

    Raises:
        ValueError: if `x` has neither 3 nor 4 dimensions.
    """
    ndim = len(x.shape)
    if ndim == 3:
        original_size = x.shape[:2]
    elif ndim == 4:
        original_size = x.shape[1:3]
    else:
        raise ValueError(f"expected an image with 3 or 4 dimensions, got {ndim}")
    # Plain ints so the sizes are accepted by F.interpolate regardless of the
    # array library `x` comes from.
    out_size = (int(original_size[0]), int(original_size[1]))

    x, intermediate_shape = preprocess(x)

    with torch.no_grad():
        image_embedding = sam.image_encoder(x)
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(points=None, boxes=None, masks=None)
        low_res_masks, _ = sam.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

    # Upscale to the encoder resolution, crop away the area beyond the resized
    # image, then bring the mask back to the caller's resolution.
    encoder_size = sam.image_encoder.img_size
    mask = F.interpolate(low_res_masks, (encoder_size, encoder_size))
    mask = mask[:, :, :intermediate_shape[0], :intermediate_shape[1]]
    mask = F.interpolate(mask, out_size)
    return mask.to(device)


def segmentation_sam(x, SIZE=384):
    """Segment `x` with SAM and render the mask overlaid on the image.

    The image is resized with padding to ``(SIZE, SIZE)``, segmented, and the
    mask (thresholded at 0.2) is drawn over it with a semi-transparent 'jet'
    colormap.

    Args:
        x: input image accepted by ``tf.image.resize_with_pad``.
        SIZE: side length of the square the image is resized to.

    Returns:
        The rendered figure as a uint8 array of shape ``(height, width, 3)``.
    """
    x = tf.image.resize_with_pad(x, SIZE, SIZE)
    predicted_mask = one_step_inference(x)

    # Display the same uint8 image the model receives (see `preprocess`);
    # a float image in 0..255 would be clipped to [0, 1] by imshow.
    img = np.asarray(x).astype(np.uint8)
    mask = predicted_mask.cpu().numpy()[0][0] > 0.2

    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.imshow(mask, cmap='jet', alpha=0.4)
        # NOTE(review): this looks like a debugging leftover (written to the
        # working directory, before the axes are hidden); kept because callers
        # may rely on the file — confirm and remove.
        fig.savefig('test.png')
        ax.axis('off')
        fig.canvas.draw()
        # Read the rendered pixels as RGBA and drop alpha; the copy detaches
        # the result from the canvas buffer before the figure is closed.
        data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Close this specific figure, even if rendering failed.
        plt.close(fig)
    return data