David Vaillant committed on
Commit
a073fdd
1 Parent(s): 5b4a37c

Basic func.

Browse files
Files changed (3) hide show
  1. baby_shiny.py +102 -0
  2. backend.py +72 -0
  3. checkpoints/bbox_finetune.ckpt +3 -0
baby_shiny.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shiny import App, Inputs, Outputs, Session, reactive, render, ui
2
+ from shiny.types import FileInfo, ImgData
3
+ import asyncio
4
+ import concurrent.futures
5
+
6
+ import backend
7
+
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw
10
+ from pathlib import Path
11
+ import tempfile
12
+
13
+
14
def draw_layer_on_image(im: "Image.Image") -> "Image.Image":
    """Draw a diagonal cross over an image and return it.

    Args:
        im: Image to draw on. It is modified in place.

    Returns:
        The same ``im`` object, now carrying the two diagonal lines.
    """
    draw = ImageDraw.Draw(im)
    width, height = im.size
    # Top-left to bottom-right corner, drawn 5 px wide.
    draw.line((0, 0, width, height), fill=128, width=5)
    # Bottom-left to top-right corner, at the default line width.
    draw.line((0, height, width, 0), fill=128)
    return im
27
+
28
+
29
# UI layout notes:
#   - title element, centered
#   - file input, centered
#   - two cards side by side: the uploaded image on the left and its mask on
#     the right (an arrow between them is part of the intended design)
# A "Process mask" task button (id "mask_btn") for the mask card is currently
# disabled; the mask is rendered as soon as a file is uploaded.
card_height = '700px'
app_ui = ui.page_fixed(
    ui.input_file(
        "file1",
        "Upload a sidewalk.",
        accept=[".jpg", ".png", ".jpeg"],
        multiple=False,
    ),
    ui.layout_columns(
        # Both cards share the same shape: a header plus one image output.
        *(
            ui.card(
                ui.card_header(title),
                ui.output_image(output_id),
                height=card_height,
            )
            for title, output_id in (
                ("Uploaded Image", "show_image"),
                ("Image Mask", "samwalk"),
            )
        )
    ),
)
51
+
52
def strip_alpha(image: "Image.Image") -> "Image.Image":
    """Flatten an image onto a white background and return it as RGB.

    Args:
        image: Source image in any mode. It is not modified.

    Returns:
        A new RGB image in which transparent areas show white.
    """
    # alpha_composite only accepts RGBA operands, so normalise the mode first;
    # this lets callers pass other modes without a ValueError.
    rgba = image.convert('RGBA')
    background = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, rgba).convert('RGB')
58
+
59
def server(input: Inputs, output: Outputs, session: Session):
    """Define the reactive server logic for the upload and mask outputs.

    Args:
        input: Shiny inputs; ``input.file1`` is the file-upload control.
        output: Shiny outputs (unused directly; outputs are registered by
            the ``@render`` decorators below).
        session: The Shiny session (unused).
    """

    @reactive.calc
    def parsed_file():
        """Return the first uploaded file's info, or None if there is none."""
        files: list[FileInfo] | None = input.file1()
        if not files:
            return None
        return files[0]

    @render.image
    def show_image():
        """Show the uploaded file unchanged."""
        uploaded_file = parsed_file()
        if uploaded_file is None:
            return None
        img: ImgData = {"src": str(uploaded_file['datapath']), "width": "500px"}
        return img

    # The mask file is regenerated on every render, so have Shiny delete it
    # once it has been sent instead of leaking one temp file per upload.
    # TODO: gate this behind the (currently disabled) "mask_btn" task button
    # and run inference off the main thread so the UI stays responsive.
    @render.image(delete_file=True)
    def samwalk():
        """Run the backend model on the uploaded image and show the mask."""
        uploaded_file = parsed_file()
        if uploaded_file is None:
            return None
        uploaded_src = Path(uploaded_file['datapath'])

        # Keep the upload open only for as long as the backend needs it.
        with Image.open(uploaded_src) as uploaded_img:
            if uploaded_img.mode == 'RGBA':
                model_input = strip_alpha(uploaded_img)
            else:
                model_input = uploaded_img
            output_img = backend.process_image(model_input)

        # Write the mask to its own temp file. Joining a directory with the
        # absolute upload path would discard the directory and overwrite the
        # uploaded file itself, so only the upload's extension is reused
        # (it selects the output format).
        suffix = uploaded_src.suffix or ".png"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            output_path = Path(tmp.name)
        output_img.save(output_path)
        return {"src": str(output_path), "width": "500px"}


# Shiny application object combining the UI definition with the server logic.
app = App(app_ui, server)
backend.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # backend.py
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+ import torch
5
+ from transformers import SamModel, SamProcessor
6
+ from torchvision.transforms import v2
7
+ from samgeo.text_sam import LangSAM
8
+ import os
9
+ import logging
10
+
11
+
12
# Convert a PIL image to a float32 tensor with values scaled to [0, 1].
preproc = v2.Compose([
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
])


# Load the necessary models once, at import time.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# The default points at the checkpoint file tracked in this repository
# (checkpoints/bbox_finetune.ckpt); override it with SAM_FINETUNE_CHECKPOINT.
CHECKPOINT_FILE = os.getenv("SAM_FINETUNE_CHECKPOINT", "checkpoints/bbox_finetune.ckpt")

# NOTE(review): the processor comes from sam-vit-base while the weights are
# sam-vit-large — presumably the preprocessing config is the same; confirm.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
tuned_model = SamModel.from_pretrained("facebook/sam-vit-large").to(device)
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. This also assumes the file holds a bare state dict — confirm.
tuned_model.load_state_dict(torch.load(CHECKPOINT_FILE,
                                       map_location=device))
langsam_model = LangSAM("vit_l")
27
+
28
+
29
def process_image(image: "Image.Image", bbox: "list[float] | None" = None) -> "Image.Image":
    """Run the fine-tuned SAM model on an image and return the sidewalk mask.

    Args:
        image: PIL image to segment.
        bbox: Box prompt as four numbers, in the same order that
            ``image.getbbox()`` returns. When omitted, the bounding box of
            the image's non-zero region is used, or the full image extent
            if the image has no non-zero region.

    Returns:
        The greyscale mask image produced by ``get_sidewalk_mask``.
    """
    logging.info("Logging image information.")
    if bbox is None:
        # No bbox information. Use default (filters out zeroes)
        logging.debug("Using default, null bounding box.")
        # getbbox() returns None for an all-zero image; fall back to the
        # whole frame instead of crashing on the float conversion below.
        bbox = image.getbbox() or (0, 0, *image.size)
    # Always hand the processor a list of floats, whoever supplied the box.
    bbox = [float(coord) for coord in bbox]

    # preproc has already scaled pixels to [0, 1], hence do_rescale=False.
    inputs = processor(preproc(image), input_boxes=[[bbox]],
                       do_rescale=False, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}  # Map objects to our device.

    mask = get_sidewalk_mask(tuned_model, inputs)
    # TODO: get tree masks and (possibly) union them with the sidewalk mask.
    return mask
43
+
44
+
45
def get_sidewalk_mask(model, inputs) -> Image:
    """Predict a mask with ``model`` and render it as a greyscale image.

    Args:
        model: Model called as ``model(**inputs, multimask_output=False)``;
            it is switched to eval mode first.
        inputs: Keyword arguments forwarded to the model.

    Returns:
        An 'L'-mode image in which pixel intensity is the mask probability
        scaled to 0-255, with probabilities below 0.5 set to 0.
    """
    logging.info("Calculating mask.")
    model.eval()
    with torch.no_grad():
        prediction = model(**inputs, multimask_output=False)

    # Logits -> probabilities.
    probs = torch.sigmoid(prediction.pred_masks.squeeze(1))
    # Zero out anything below the 0.5 cut-off, keep the rest untouched.
    kept = torch.where(probs < 0.5, torch.zeros_like(probs), probs)
    # Linear map from probability to colour intensity, then over to numpy
    # with all singleton dimensions dropped.
    intensities = (kept * 255).cpu().numpy().squeeze()

    return Image.fromarray(intensities).convert('L')
63
+
64
+
65
def get_tree_masks(image: "Image.Image"):
    """Run LangSAM on an image with the text prompt "tree".

    Args:
        image: PIL image to search for trees.

    Returns:
        Whatever ``langsam_model.predict`` returns (presumably masks, boxes,
        phrases and logits — confirm against the installed samgeo version).
        Previously the prediction was computed and then discarded.
    """
    return langsam_model.predict(image, "tree", box_threshold=0.24, text_threshold=0.24)
67
+
68
+
69
+ # masks, boxes, phrases, logits = tuned_model.predict(image_pil, bbox)
70
+ # tree_data = langsam_model.predict(image_pil, text_prompt)
71
+
72
+ # def draw_layer_on_image(model, im: Image, text_prompt: str='sidewalk') -> Image:
checkpoints/bbox_finetune.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c72e371f7cd4644c9d9550649db4a5473ad63c21472b9d0973670d0dff1ff69
3
+ size 1249561500