# Hugging Face Spaces status: Sleeping
import torch

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    # On the CPU path `device` is rebound to the plain string "cpu" rather
    # than a torch.device; `.to(...)` accepts either form.
    device = "cpu"
else:
    # Cap this process at 30% of the GPU's memory -- presumably to leave room
    # for TensorFlow, which this script also loads (confirm).
    torch.cuda.set_per_process_memory_fraction(
        0.3, device=0 if device.index is None else device.index
    )
print(f"Torch device: {device}")
import tensorflow as tf

# Let TensorFlow grow its GPU allocation on demand instead of reserving the
# whole card up front (torch may also be using the GPU in this process).
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if gpu_devices:
    # TensorFlow requires the same memory-growth setting on every visible GPU.
    # Configuring only the first one fails at initialization on multi-GPU hosts.
    for gpu in gpu_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
else:
    print(f"TensorFlow device: {gpu_devices}")
from segment_anything import SamPredictor, sam_model_registry | |
import matplotlib.pyplot as plt | |
import cv2 | |
import numpy as np | |
from math import ceil | |
import os | |
from huggingface_hub import snapshot_download | |
if not os.path.exists('model'):
    # First run: fetch the fine-tuned SAM checkpoint from the Hugging Face Hub
    # into ./model.
    REPO_ID = 'Serrelab/SAM_Leaves'
    token = os.environ.get('READ_TOKEN')
    # Report only whether the token is present. Printing the secret itself
    # would leak it into stdout / the application logs.
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    else:
        print("Read token: found in environment.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='model', local_dir='model')
model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')
# Build the SAM architecture without a checkpoint and load the weights here,
# so that `map_location` can remap tensors saved on a GPU onto `device`.
# segment_anything's builder loads a checkpoint path with a bare torch.load,
# which fails on CPU-only hosts when the checkpoint was saved from CUDA.
sam = sam_model_registry["default"]()
sam.load_state_dict(torch.load(model_path, map_location=device))
sam.to(device)  # previously sam.cuda()
sam.eval()  # inference only: make sure train-time behaviour is off
predictor = SamPredictor(sam)
from torch.nn import functional as F | |
def pad_gt(x):
    """Zero-pad the last two axes of ``x`` up to the SAM encoder input size.

    Padding is added only on the bottom of the second-to-last axis and the
    right of the last axis, so existing values keep their indices.

    Args:
        x: torch tensor whose last two axes are (height, width).

    Returns:
        The padded tensor; its last two axes are both
        ``sam.image_encoder.img_size``.
    """
    target = sam.image_encoder.img_size
    rows, cols = x.shape[-2:]
    # F.pad lists (left, right) for the last axis first, then (top, bottom)
    # for the axis before it.
    return F.pad(x, (0, target - cols, 0, target - rows))
def preprocess(img):
    """Turn an image into the batched tensor fed to ``sam.image_encoder``.

    The image is:
    1. cast to a uint8 numpy array;
    2. resized with ``predictor.transform.apply_image`` (SamPredictor's
       resize transform -- presumably a longest-side resize to the encoder
       size; confirm);
    3. moved to ``device``;
    4. reordered from HWC to CHW;
    5. given a leading batch axis;
    6. passed through ``sam.preprocess``.

    Args:
        img: anything ``np.array`` accepts. ``permute(2, 0, 1)`` below
            requires the resized image to be (H, W, C).

    Returns:
        A tuple ``(tensor, shape)``:
        - ``tensor`` is the preprocessed batch.
        - ``shape`` is the spatial size of the resized image before
          ``sam.preprocess``. Callers use it to crop padding off predicted
          masks. If the resized array is neither rank 3 nor rank 4, its full
          shape is returned unchanged.
    """
    pixels = np.array(img).astype(np.uint8)
    resized = predictor.transform.apply_image(pixels)
    resized_shape = resized.shape

    batch = torch.as_tensor(resized).to(device)  # previously .cuda()
    batch = batch.permute(2, 0, 1).contiguous()[None, :, :, :]
    batch = sam.preprocess(batch)

    rank = len(resized_shape)
    if rank == 3:
        resized_shape = resized_shape[:2]
    elif rank == 4:
        resized_shape = resized_shape[1:3]
    return batch, resized_shape
def normalize(img):
    """Linearly rescale ``img`` so its values span [-1, 1].

    The minimum maps to -1 and the maximum to +1. For a constant input the
    rescaling divides 0 by 0, so the result is NaN.

    Args:
        img: TensorFlow tensor or array-like accepted by ``tf.math`` ops.

    Returns:
        The rescaled tensor.
    """
    shifted = img - tf.math.reduce_min(img)
    unit = shifted / tf.math.reduce_max(shifted)
    return unit * 2.0 - 1.0
def resize(img, size=None):
    """Bicubic-resize ``img`` to a ``size`` x ``size`` square.

    This is the default resize applied to the outputs of ``pi``.

    Args:
        img: 3-D or 4-D image tensor accepted by ``tf.image.resize``.
        size: target side length. When omitted, the module-level ``SIZE`` is
            used. NOTE(review): in the visible code ``SIZE`` exists only as a
            parameter of ``segmentation_sam``, so callers relying on the
            default need a global ``SIZE`` to be defined; confirm.

    Returns:
        The resized image tensor.
    """
    if size is None:
        size = SIZE
    return tf.image.resize(img, (size, size), method="bicubic")
def smooth_mask(mask, ds=20):
    """Blur ``mask`` by shrinking it to ``ds`` x ``ds`` and growing it back.

    Both resizes are bicubic, so fine detail below the ``ds`` grid is
    smoothed away.

    Args:
        mask: image-like mask; axes 0 and 1 are taken as the spatial size to
            restore (assumes an (H, W, C) layout -- TODO confirm).
        ds: side length of the intermediate low-resolution grid.

    Returns:
        The smoothed mask, resized back to the original first two axes.
    """
    dims = tf.shape(mask)
    original_hw = (dims[0], dims[1])
    coarse = tf.image.resize(mask, (ds, ds), method="bicubic")
    return tf.image.resize(coarse, original_hw, method="bicubic")
def pi(img, mask):
    """Build four SIZE x SIZE views of ``img`` focused on a segmentation mask.

    The mask is smoothed, its last axis is averaged away, and pixels whose
    smoothed value is <= 0.01 are zeroed in the image. The stacked views are:

    1. The whole masked image resized to SIZE x SIZE (bicubic, antialiased).
    2. Anchor 1: the bounding box of pixels whose smoothed mask is > 0.15,
       resized. Used only when the box spans more than 50 px along both
       axes; otherwise view 1 is reused.
    3. Anchor 2: the central half of that bounding box, resized. In the
       fallback case it is the whole image letterboxed with
       ``tf.image.resize_with_pad``.
    4. Max zoom: a window of up to 2*SIZE x 2*SIZE centred on the first pixel
       where the smoothed mask reaches its maximum, resized.

    Args:
        img: image convertible to a float32 tensor. Assumed (H, W, C) given
            the ``[:, :, None]`` broadcast below -- TODO confirm.
        mask: torch tensor (``.cpu().numpy()`` is called on it). Assumed to
            share img's H and W and to carry a trailing axis that is averaged
            away -- TODO confirm against callers.

    Returns:
        float32 tensor stacking the four views, presumably of shape
        (4, SIZE, SIZE, C).

    NOTE(review): ``SIZE`` is read as a module-level global. It is not
    defined in the visible part of the file (only as a parameter of
    ``segmentation_sam``), so confirm it exists before calling.
    """
    img = tf.cast(img, tf.float32)
    shape = tf.shape(img)
    height = tf.cast(shape[0], tf.int64)
    width = tf.cast(shape[1], tf.int64)

    # Smooth the mask, collapse its last axis and blank out the background.
    mask = smooth_mask(mask.cpu().numpy().astype(float))
    mask = tf.reduce_mean(mask, -1)
    img = img * tf.cast(mask > 0.01, tf.float32)[:, :, None]

    img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)

    # Building 2 anchors from the bounding box of the confident region.
    # An empty region is handled explicitly instead of relying on how
    # reduce_min/reduce_max behave on an empty tensor.
    anchors = tf.where(mask > 0.15)
    use_anchors = False
    if tf.size(anchors) > 0:
        x_min = tf.math.reduce_min(anchors[:, 0])
        x_max = tf.math.reduce_max(anchors[:, 0])
        y_min = tf.math.reduce_min(anchors[:, 1])
        y_max = tf.math.reduce_max(anchors[:, 1])
        use_anchors = x_max - x_min > 50 and y_max - y_min > 50
    if use_anchors:
        img_anchor_1 = resize(img[x_min:x_max, y_min:y_max])
        # Anchor 2 trims a quarter of the box from each side (its central half).
        delta_x = (x_max - x_min) // 4
        delta_y = (y_max - y_min) // 4
        img_anchor_2 = resize(img[x_min + delta_x:x_max - delta_x,
                                  y_min + delta_y:y_max - delta_y])
    else:
        img_anchor_1 = img_resize
        img_anchor_2 = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)

    # Building the max-zoom view around the first arg-max of the smoothed mask,
    # clamped to the image bounds.
    peak = tf.where(mask == tf.math.reduce_max(mask))[0]
    peak_x, peak_y = peak[0], peak[1]
    img_max_zoom1 = img[tf.math.maximum(peak_x - SIZE, 0):tf.math.minimum(peak_x + SIZE, height),
                        tf.math.maximum(peak_y - SIZE, 0):tf.math.minimum(peak_y + SIZE, width)]
    img_max_zoom1 = resize(img_max_zoom1)

    return tf.cast([
        img_resize,
        img_anchor_1,
        img_anchor_2,
        img_max_zoom1,
    ], tf.float32)
def one_step_inference(x):
    """Run SAM without prompts on one image and return a mask at input size.

    Args:
        x: image of rank 3 (H, W, C) or rank 4 (N, H, W, C). Only its spatial
            size is read here; ``preprocess`` handles the conversion.

    Returns:
        torch tensor on ``device`` holding the mask decoder's low-resolution
        output upsampled (nearest) to the input's spatial size. It has a
        leading batch axis and, with ``multimask_output=False``, a single mask
        channel (callers index it as ``[0][0]``).

    Raises:
        ValueError: if ``x`` is neither rank 3 nor rank 4.
    """
    rank = len(x.shape)
    if rank == 3:
        original_size = x.shape[:2]
    elif rank == 4:
        original_size = x.shape[1:3]
    else:
        raise ValueError(f"expected a rank-3 or rank-4 image, got shape {tuple(x.shape)}")

    x, intermediate_shape = preprocess(x)
    with torch.no_grad():
        image_embedding = sam.image_encoder(x)
        # No points, boxes or masks: ask the decoder for an unprompted mask.
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(points=None, boxes=None, masks=None)
        low_res_masks, _ = sam.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

    # Upsample to the encoder's square input, crop off the padding added by
    # preprocessing, then resize to the caller's original spatial size.
    encoder_size = sam.image_encoder.img_size
    mask = F.interpolate(low_res_masks, (encoder_size, encoder_size))
    mask = mask[:, :, :intermediate_shape[0], :intermediate_shape[1]]
    mask = F.interpolate(mask, (int(original_size[0]), int(original_size[1])))
    return mask.to(device)
def segmentation_sam(x, SIZE=384):
    """Segment an image with SAM and return the overlay rendering as RGB pixels.

    The image is letterboxed to ``SIZE`` x ``SIZE`` and segmented with
    ``one_step_inference``. The mask is thresholded at 0.2 and drawn
    semi-transparently over the image in a Matplotlib figure, which is then
    rasterized.

    Args:
        x: input image. Presumably (H, W, 3) with values in [0, 255], since
            ``preprocess`` casts it to uint8 -- TODO confirm.
        SIZE: side length of the square canvas the image is padded/resized
            into. This parameter shadows any module-level ``SIZE``.

    Returns:
        uint8 numpy array of shape (rows, cols, 3) holding the rendered
        figure.

    Side effects:
        Writes the figure to ``test.png`` in the working directory.
    """
    x = tf.image.resize_with_pad(x, SIZE, SIZE)
    predicted_mask = one_step_inference(x)

    # resize_with_pad returns float32. imshow clips float RGB data to [0, 1],
    # which would wash the image out, so convert back to uint8 for display.
    img = np.clip(x.numpy(), 0, 255).astype(np.uint8)
    mask = predicted_mask.cpu().numpy()[0][0] > 0.2

    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.imshow(mask, cmap='jet', alpha=0.4)
        fig.savefig('test.png')  # debug snapshot, written before the axes are hidden
        ax.axis('off')
        fig.canvas.draw()
        # tostring_rgb() was removed in Matplotlib 3.10. buffer_rgba() is the
        # supported accessor; drop alpha and copy before the figure is closed.
        data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)
    return data