"""Gradio demo for Busam feature-based image analysis.

The app exposes four tabs (edge detection, reduction of the feature map to
three colour channels, classical segmentation and one-click segmentation), all
built on the per-pixel features returned by ``busam.process_image``.
"""
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image

from busam import Busam

resize_to = 512
checkpoint = "weights.pth"
device = "cpu"

print("Loading model...")
busam = Busam(checkpoint=checkpoint, device=device, side=resize_to)


def minmaxnorm(x):
    """Linearly rescale *x* so that its minimum maps to 0 and its maximum to 1.

    A constant input (max == min) yields an all-zero float array instead of
    dividing by zero, which would otherwise produce NaNs and undefined values
    once cast to ``uint8``.
    """
    low = x.min()
    span = x.max() - low
    if span == 0:
        return np.zeros_like(x, dtype=np.float64)
    return (x - low) / span


def _as_rgb_array(img):
    """Return the first three channels of *img* as a new NumPy array.

    Raises:
        gr.Error: if *img* is ``None`` (no image was supplied in the UI).
    """
    if img is None:
        raise gr.Error("Please provide an input image.")
    return np.array(img[:, :, :3])


def _to_output_image(array, size):
    """Min-max normalise *array* to 8 bits and resize it to ``size[::-1]``.

    ``size`` is the second value returned by ``busam.process_image``;
    presumably (height, width) of the original image, hence the reversal for
    PIL -- TODO confirm against Busam.
    """
    return Image.fromarray((minmaxnorm(array) * 255).astype(np.uint8)).resize(size[::-1])


def edge_inference(img, algorithm, th_low=None, th_high=None):
    """Compute an edge map of *img* from Busam features.

    Args:
        img: H x W x C array; only the first three channels are used.
        algorithm: "sobel" or "canny" (case-insensitive).
        th_low: low threshold for Canny; defaults to 5000 when ``None``.
        th_high: high threshold for Canny; defaults to 10000 when ``None``.

    Returns:
        A ``PIL.Image`` with the min-max normalised edge map.

    Raises:
        ValueError: if *algorithm* is neither "sobel" nor "canny".
        gr.Error: if no image was supplied.
    """
    algorithm = algorithm.lower()
    print("Loading image...")
    img = _as_rgb_array(img)
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print(f"Computing {algorithm}...")
    if algorithm == "sobel":
        edge = busam.sobel_from_pred(pred, size)
    elif algorithm == "canny":
        # Compare against None explicitly: 0 is a valid slider value and must
        # not be replaced by the default.
        th_low = 5000 if th_low is None else th_low
        th_high = 10000 if th_high is None else th_high
        edge = busam.canny_from_pred(pred, size, th_low=th_low, th_high=th_high)
    else:
        raise ValueError("algorithm should be sobel or canny")
    if isinstance(edge, torch.Tensor):
        edge = edge.detach().cpu().numpy()
    print("Done")
    return _to_output_image(edge, size)


def _embed_via_nearest_sample(reducer, sampled_pixels, all_pixels):
    """Embed *all_pixels* with a reducer that only offers ``fit_transform``.

    scikit-learn's ``TSNE`` cannot project unseen points, so the sampled
    pixels are embedded and every pixel then receives the embedding of its
    nearest sampled pixel in feature space.  Note that scikit-learn rejects
    t-SNE when the number of samples does not exceed the perplexity.
    """
    from sklearn.neighbors import NearestNeighbors

    sample_embedding = reducer.fit_transform(sampled_pixels)
    index = NearestNeighbors(n_neighbors=1).fit(sampled_pixels)
    nearest = index.kneighbors(all_pixels, return_distance=False)[:, 0]
    return sample_embedding[nearest]


def dimred_inference(
    img,
    algorithm,
    resample_pct,
):
    """Reduce the Busam feature map of *img* to three channels for display.

    Args:
        img: H x W x C array; only the first three channels are used.
        algorithm: "pca", "tsne" or "umap" (case-insensitive).
        resample_pct: base-10 exponent of the fraction of pixels used to fit
            the reducer (e.g. -3 fits on roughly 0.1% of the pixels).

    Returns:
        A ``PIL.Image`` whose three channels are the min-max normalised
        reduced features.

    Raises:
        ValueError: if *algorithm* is not one of the supported names.
        gr.Error: if no image was supplied.
    """
    algorithm = algorithm.lower()
    img = _as_rgb_array(img)
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)  # pred is 1, F, S, S
    assert pred.shape[1] >= 3, "should have at least 3 channels"

    # The reducers are imported lazily so that optional packages are only
    # required when actually selected.
    if algorithm == 'pca':
        from sklearn.decomposition import PCA

        reducer = PCA(n_components=3)
    elif algorithm == 'tsne':
        from sklearn.manifold import TSNE

        reducer = TSNE(n_components=3)
    elif algorithm == 'umap':
        from umap import UMAP

        reducer = UMAP(n_components=3)
    else:
        raise ValueError('algorithm should be pca, tsne or umap')

    features = pred.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
    features = features.reshape(features.shape[0], -1).T  # BHW, F
    n_pixels = features.shape[0]

    # Keep at least 3 samples (needed for a 3-component fit) and a stride of
    # at least 1, so tiny fractions cannot cause a division by zero and
    # fractions above 1 cannot produce a zero slice step.
    resample_size = max(3, int(10**resample_pct * n_pixels))
    stride = max(1, n_pixels // resample_size)
    sampled_pixels = features[::stride]

    if hasattr(reducer, "transform"):
        print("dim reduction fit..." + " " * 30, end="\r")
        reducer = reducer.fit(sampled_pixels)
        print("dim reduction transform..." + " " * 30, end="\r")
        embedded = reducer.transform(features)  # BHW, 3
    else:
        print("dim reduction fit..." + " " * 30, end="\r")
        embedded = _embed_via_nearest_sample(reducer, sampled_pixels, features)
    print()
    print('Done. Saving...')

    # revert back to original shape
    colors = embedded.reshape(pred.shape[2], pred.shape[3], 3)
    return _to_output_image(colors, size)


def _features_as_uint8_image(pred):
    """Return the first feature map of *pred* as an H x W x F ``uint8`` array."""
    features = pred[0].detach().cpu().numpy()
    return (minmaxnorm(features) * 255).astype(np.uint8).transpose(1, 2, 0)


def _shuffle_labels(labels):
    """Relabel *labels* through a fixed random permutation.

    Neighbouring segments tend to have neighbouring label values, which makes
    them hard to tell apart once mapped to gray levels.  A seeded permutation
    spreads them out while keeping distinct labels distinct and the output
    reproducible.
    """
    unique, inverse = np.unique(labels, return_inverse=True)
    permutation = np.random.default_rng(0).permutation(unique.size)
    return permutation[inverse].reshape(labels.shape)


def segmentation_inference(img, algorithm, scale):
    """Segment *img* by running a classical algorithm on its Busam features.

    Args:
        img: H x W x C array; only the first three channels are used.
        algorithm: "kmeans", "felzenszwalb" or "slic" (case-insensitive).
        scale: coarseness control; larger values give fewer segments
            (``int(100 / 100**scale)`` clusters/segments for KMeans and SLIC,
            ``10**(8*scale-3)`` as Felzenszwalb's scale).

    Returns:
        A grayscale ``PIL.Image`` in which each segment has its own gray level.

    Raises:
        ValueError: if *algorithm* is not one of the supported names.
        gr.Error: if no image was supplied.
    """
    algorithm = algorithm.lower()
    img = _as_rgb_array(img)
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Computing segmentation...")
    if algorithm == "kmeans":
        from sklearn.cluster import KMeans

        n_clusters = int(100 / 100**scale)
        # One row per pixel, one column per feature, as a plain NumPy array.
        pixels = pred.detach().cpu().reshape(pred.shape[1], -1).T.numpy()
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(pixels)
        labels = kmeans.labels_.reshape(pred.shape[2], pred.shape[3])
    elif algorithm == "felzenszwalb":
        from skimage.segmentation import felzenszwalb

        labels = felzenszwalb(
            _features_as_uint8_image(pred),
            scale=10**(8*scale-3),
            sigma=0,
            min_size=50,
        )
    elif algorithm == "slic":
        from skimage.segmentation import slic

        labels = slic(
            _features_as_uint8_image(pred),
            n_segments=int(100 / 100**scale),
            compactness=0.00001,
            sigma=1,
        )
    else:
        raise ValueError("algorithm should be kmeans, felzenszwalb or slic")
    print("Done")
    return _to_output_image(_shuffle_labels(labels), size)


def one_click_segmentation(img, row, col, threshold):
    """Segment the object under a single click.

    Args:
        img: H x W x C array; only the first three channels are used.
        row: row of the click in image pixels (anything ``int()`` accepts).
        col: column of the click in image pixels (anything ``int()`` accepts).
        threshold: currently unused by this function.

    Returns:
        ``(image, [(mask, 'Prediction'), (click_map, 'Click')])``, the format
        passed to the ``gr.AnnotatedImage`` output.

    Raises:
        gr.Error: if no image was supplied, or the click is not an integer
            position inside the image.
    """
    try:
        row, col = int(row), int(col)
    except (TypeError, ValueError) as err:
        raise gr.Error("Click's row and column must be integers.") from err
    img = _as_rgb_array(img)
    height, width = img.shape[:2]
    # Reject out-of-range clicks explicitly; negative indices would otherwise
    # silently wrap around to the opposite side of the image.
    if not (0 <= row < height and 0 <= col < width):
        raise gr.Error(
            f"Click ({row}, {col}) is outside the image ({height} x {width})."
        )

    # Draw a small cross centred on the click.
    click_map = np.zeros(img.shape[:2], dtype=bool)
    click_map[max(0, row-5):min(height, row+5), col] = True
    click_map[row, max(0, col-5):min(width, col+5)] = True

    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Getting mask...")
    # NOTE(review): `threshold` is not forwarded to get_mask -- confirm whether
    # Busam.get_mask accepts a threshold and wire it through if so.
    mask = busam.get_mask((pred, size), (row, col))
    print("Done")
    print('shapes=', img.shape, mask.shape, click_map.shape)
    return (img, [(mask, 'Prediction'), (click_map, 'Click')])


# Example images shared by every tab.
example_images = sorted(str(p) for p in Path('demoimgs').glob('*'))

with gr.Blocks() as demo:
    with gr.Tab('Edge detection'):
        with gr.Row():
            def enable_sliders(algorithm):
                """Show the two threshold sliders only when Canny is selected."""
                show = algorithm.lower() == "canny"
                return gr.Slider(visible=show), gr.Slider(visible=show)

            with gr.Column():
                image_input = gr.Image(label="Input Image")
                run_button = gr.Button("Run")
                algorithm = gr.Radio(["Sobel", "Canny"], label="Algorithm", value="Sobel")
                # add sliders for th_low, th_high
                th_low_slider = gr.Slider(
                    0, 32768, 10000, label="Canny's low threshold", visible=False
                )
                th_high_slider = gr.Slider(
                    0, 32768, 20000, label="Canny's high threshold", visible=False
                )
                algorithm.change(
                    enable_sliders,
                    inputs=[algorithm],
                    outputs=[th_low_slider, th_high_slider],
                )
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(
            edge_inference,
            inputs=[image_input, algorithm, th_low_slider, th_high_slider],
            outputs=output_image,
        )
        gr.Examples(example_images, inputs=image_input)

    with gr.Tab('Reduction to 3D'):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                # A default is required: with nothing selected the callback
                # receives None and fails on `.lower()`.
                algorithm = gr.Radio(["PCA", "TSNE", "UMAP"], label="Algorithm", value="PCA")
                run_button = gr.Button("Run")
                gr.Markdown("⚠️ UMAP is slow, TSNE is ultra-slow, use resample x<-3 ⚠️")
                resample_pct = gr.Slider(-5, 0, -3, label="Resample (10^x)*100%")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(
            dimred_inference,
            inputs=[image_input, algorithm, resample_pct],
            outputs=output_image,
        )
        gr.Examples(example_images, inputs=image_input)

    with gr.Tab('Classical Segmentation'):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                algorithm = gr.Radio(
                    ['KMeans', 'Felzenszwalb', 'SLIC'], label="Algorithm", value="SLIC"
                )
                scale = gr.Slider(0.1, 1.0, 0.5, label="Scale")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(
            segmentation_inference,
            inputs=[image_input, algorithm, scale],
            outputs=output_image,
        )
        gr.Examples(example_images, inputs=image_input)

    with gr.Tab('One-click segmentation'):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                threshold = gr.Slider(0, 1, 0.5, label="Threshold")
                with gr.Row():
                    row = gr.Textbox(10, label="Click's row")
                    col = gr.Textbox(10, label="Click's column")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.AnnotatedImage(label="Output")
        run_button.click(
            one_click_segmentation,
            inputs=[image_input, row, col, threshold],
            outputs=output_image,
        )
        gr.Examples(example_images, inputs=image_input)

demo.launch(share=False)