# sidewalk-sam / backend.py
# Author: David Vaillant
# Commit a073fdd: "Basic func."
# backend.py
import numpy as np
from PIL import Image, ImageDraw
import torch
from transformers import SamModel, SamProcessor
from torchvision.transforms import v2
from samgeo.text_sam import LangSAM
import os
import logging
# Image -> tensor pipeline applied before SamProcessor.
# PILToTensor converts the PIL image to a (C, H, W) tensor without scaling.
# ToDtype(..., scale=True) then casts it to float32 in [0, 1]; this is why
# the processor is later called with do_rescale=False.
_to_tensor = v2.PILToTensor()
_to_unit_float = v2.ToDtype(torch.float32, scale=True)
preproc = v2.Compose([_to_tensor, _to_unit_float])
# Load the necessary models once, at import time.
# process_image / get_sidewalk_mask use `processor` and `tuned_model`;
# get_tree_masks uses `langsam_model`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the fine-tuned SAM weights. It can be overridden through the environment.
CHECKPOINT_FILE = os.getenv("SAM_FINETUNE_CHECKPOINT", "checkpoints/bbox_finetune.pth")

# NOTE(review): the processor is taken from the ViT-B checkpoint while the model
# below is ViT-L. The preprocessing config is presumably the same across SAM
# variants, but confirm this.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Build the base ViT-L SAM model on `device`, then overwrite its weights with
# the fine-tuned checkpoint. The state dict is passed inline so no extra copy
# stays referenced at module level.
tuned_model = SamModel.from_pretrained("facebook/sam-vit-large").to(device)
tuned_model.load_state_dict(
    torch.load(CHECKPOINT_FILE, map_location=device)
)

# samgeo's text-prompted SAM; "vit_l" selects the SAM backbone variant.
langsam_model = LangSAM("vit_l")
def process_image(image: Image.Image, bbox: "list[float] | None" = None) -> Image.Image:
    """Segment the sidewalk in *image* using the fine-tuned SAM model.

    Args:
        image: Input PIL image.
        bbox: Optional box prompt ``[x0, y0, x1, y1]`` in pixel coordinates.
            When omitted, the bounding box of the image's non-zero region
            (``Image.getbbox``) is used. If the image is entirely zero, the
            full image extent is used instead.

    Returns:
        The greyscale mask image produced by ``get_sidewalk_mask``.
    """
    logging.info("Logging image information.")
    if bbox is None:
        # No bbox information. Use default (filters out zeroes)
        logging.debug("Using default, null bounding box.")
        box = image.getbbox()
        if box is None:
            # getbbox() returns None for an all-zero image. Prompt with the
            # whole frame instead of crashing on map(float, None).
            box = (0, 0, image.width, image.height)
        bbox = [float(coord) for coord in box]
    inputs = processor(
        preproc(image),
        input_boxes=[[bbox]],
        # preproc already scales pixel values into [0, 1].
        do_rescale=False,
        return_tensors="pt",
    )
    # Move every tensor in the processor output to the model's device.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    mask = get_sidewalk_mask(tuned_model, inputs)
    # TODO: compute tree masks (get_tree_masks) and decide how to combine
    # them with the sidewalk mask (union?).
    return mask
def get_sidewalk_mask(model, inputs, *, threshold: float = 0.5) -> Image.Image:
    """Run *model* on pre-processed *inputs* and return the mask as a greyscale image.

    Args:
        model: SAM model to run. It is switched to eval mode here.
        inputs: Keyword arguments forwarded to ``model``, e.g. the
            ``SamProcessor`` output already moved to the model's device.
        threshold: Mask probabilities below this value are zeroed.
            Defaults to 0.5.

    Returns:
        A mode ``'L'`` image. Each pixel is ``255 * p`` for a mask
        probability ``p >= threshold`` and 0 elsewhere.
    """
    logging.info("Calculating mask.")
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)
        # NOTE(review): SamModel's `pred_masks` are low-resolution mask logits.
        # They are not upscaled to the input image size (that would need
        # processor.image_processor.post_process_masks). Confirm callers expect this.
        probs = torch.sigmoid(outputs.pred_masks.squeeze(1))
    # Convert to numpy and drop all singleton dims. This presumably yields a
    # 2-D (H, W) array only for one image with one box; verify for batched input.
    probs = probs.cpu().numpy().squeeze()
    # Filter out low-confidence pixels.
    probs[probs < threshold] = 0
    # Map probabilities to greyscale intensity linearly.
    probs *= 255
    return Image.fromarray(probs).convert('L')
def get_tree_masks(image: Image.Image, *, box_threshold: float = 0.24, text_threshold: float = 0.24):
    """Detect and segment trees in *image* with the text-prompted LangSAM model.

    Args:
        image: Input PIL image.
        box_threshold: Box confidence threshold passed to ``LangSAM.predict``.
        text_threshold: Text confidence threshold passed to ``LangSAM.predict``.

    Returns:
        The ``(masks, boxes, phrases, logits)`` result of ``LangSAM.predict``.
        Previously the result was discarded and ``None`` was returned.
    """
    # return_results=True asks samgeo to return the predictions rather than
    # only keeping them on the model object (confirm against the installed
    # samgeo version).
    return langsam_model.predict(
        image,
        "tree",
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        return_results=True,
    )
# masks, boxes, phrases, logits = tuned_model.predict(image_pil, bbox)
# tree_data = langsam_model.predict(image_pil, text_prompt)
# def draw_layer_on_image(model, im: Image, text_prompt: str='sidewalk') -> Image: