# backend.py import numpy as np from PIL import Image, ImageDraw import torch from transformers import SamModel, SamProcessor from torchvision.transforms import v2 from samgeo.text_sam import LangSAM import os import logging preproc = v2.Compose([ v2.PILToTensor(), v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1] ]) # Load the necessary models. device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') CHECKPOINT_FILE = os.getenv("SAM_FINETUNE_CHECKPOINT", "checkpoints/bbox_finetune.ckpt") processor = SamProcessor.from_pretrained("facebook/sam-vit-base") tuned_model = SamModel.from_pretrained("facebook/sam-vit-large").to(device) tuned_model.load_state_dict(torch.load(CHECKPOINT_FILE, map_location=device)) langsam_model = LangSAM("vit_l") def process_image(image: Image, bbox: list[int, int, int, int] = None) -> Image: logging.info("Logging image information.") if bbox is None: # No bbox information. Use default (filters out zeroes) logging.debug("Using default, null bounding box.") bbox = list(map(float, image.getbbox())) # List of floats. inputs = processor(preproc(image), input_boxes=[[bbox]], do_rescale=False, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} # Map objects to our device. mask = get_sidewalk_mask(tuned_model, inputs) # Get tree masks. # Union 'em?? return mask def get_sidewalk_mask(model, inputs) -> Image: logging.info("Calculating mask.") model.eval() with torch.no_grad(): outputs = model(**inputs, multimask_output=False) ## apply sigmoid mask_probabilities = torch.sigmoid(outputs.pred_masks.squeeze(1)) ## Convert to numpy for the rest of our stuff. mask_probabilities = mask_probabilities.cpu().numpy().squeeze() ## Filter out smaller probs. mask_probabilities[mask_probabilities < 0.5] = 0 ## Map probabilities to color intensity linearly. mask_probabilities *= 255 greyscale_img = Image.fromarray(mask_probabilities).convert('L') return greyscale_img def get_tree_masks(image: Image): langsam_model.predict(image, "tree", box_threshold=0.24, text_threshold=0.24) # masks, boxes, phrases, logits = tuned_model.predict(image_pil, bbox) # tree_data = langsam_model.predict(image_pil, text_prompt) # def draw_layer_on_image(model, im: Image, text_prompt: str='sidewalk') -> Image: