# Scraped Hugging Face file-view header, kept as a comment so the module parses:
# rashmi's picture | Create app.py | 2d51d39 | raw | history blame | 4.44 kB
import time
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter
from skimage.data import coins
from skimage.transform import rescale
from sklearn.cluster import AgglomerativeClustering
from sklearn.feature_extraction.image import grid_to_graph
# Use matplotlib's non-interactive Agg backend: figures built in this module
# are returned to the caller as objects and are never shown in a GUI window.
plt.switch_backend('agg')
def _single_choice(value, default):
    """Return a dropdown selection as a bare value, falling back to *default*.

    Dropdown selections may arrive wrapped in a one-element list (for example
    when the widget's default value was given as a list); unwrap those so the
    value can be forwarded to library calls that expect a plain string.
    """
    if isinstance(value, (list, tuple)):
        value = value[0] if value else None
    return value or default


def cluster(gf_sigma, scale, anti_alias, mode, n_clusters, linkage):
    """Segment the scikit-image ``coins`` picture with agglomerative clustering.

    The image is smoothed with a Gaussian filter, rescaled, and its pixels are
    then clustered on intensity with ``AgglomerativeClustering``, using the
    pixel grid (``grid_to_graph``) as the connectivity constraint.

    Args:
        gf_sigma: Standard deviation passed to ``gaussian_filter``.
        scale: Scale factor passed to ``rescale``.
        anti_alias: Forwarded to ``rescale`` as ``anti_aliasing``.
        mode: Boundary mode forwarded to ``rescale``. A one-element list or
            tuple is unwrapped; an empty value falls back to ``"reflect"``.
        n_clusters: Number of clusters to find.
        linkage: Linkage criterion forwarded to ``AgglomerativeClustering``.
            A one-element list or tuple is unwrapped; an empty value falls
            back to ``"ward"``.

    Returns:
        A ``(report, figure)`` tuple: ``report`` is a multi-line text summary
        (elapsed time, pixel count, cluster count) and ``figure`` is the
        matplotlib figure showing the rescaled image with one contour per
        cluster.
    """
    mode = _single_choice(mode, "reflect")
    linkage = _single_choice(linkage, "ward")

    orig_coins = coins()
    # Smoothing with a Gaussian filter prior to down-scaling reduces aliasing
    # artifacts; the image is then resized by `scale` to speed up processing.
    smoothened_coins = gaussian_filter(orig_coins, sigma=gf_sigma)
    rescaled_coins = rescale(
        smoothened_coins,
        scale=scale,
        mode=mode,
        anti_aliasing=anti_alias,
    )

    # One sample per pixel with a single feature (its intensity); the grid
    # graph restricts merges to neighbouring pixels.
    X = np.reshape(rescaled_coins, (-1, 1))
    connectivity = grid_to_graph(*rescaled_coins.shape)

    st = time.time()
    model = AgglomerativeClustering(
        n_clusters=n_clusters, linkage=linkage, connectivity=connectivity
    )
    model.fit(X)
    label = np.reshape(model.labels_, rescaled_coins.shape)

    result = (
        "Compute structured hierarchical clustering...\n"
        f"Elapsed time: {time.time() - st:.3f}s \n"
        f"Number of pixels: {label.size} \n"
        f"Number of clusters: {np.unique(label).size} \n"
    )

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(rescaled_coins, cmap=plt.cm.gray)
    for cluster_id in range(n_clusters):
        ax.contour(
            label == cluster_id,
            colors=[plt.cm.nipy_spectral(cluster_id / float(n_clusters))],
        )
    ax.axis("off")
    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the Figure object itself stays usable by the caller.
    plt.close(fig)
    return result, fig
# Reference scikit-learn example:
# https://scikit-learn.org/stable/auto_examples/cluster/plot_coin_ward_segmentation.html
# Title string for the demo page.
title = "A demo of structured Ward hierarchical clustering on an image of coins"
def do_submit(gf_sigma, scale, anti_alias, mode, n_clusters, linkage):
    """Click handler: coerce the raw widget values and delegate to ``cluster``.

    ``gf_sigma`` and ``scale`` are converted to ``float``, ``n_clusters`` to
    ``int``, and ``anti_alias`` becomes ``True`` only for the exact string
    ``"True"``. ``mode`` and ``linkage`` are passed through untouched.

    Returns:
        The ``(result, fig)`` pair produced by ``cluster``.
    """
    report, figure = cluster(
        float(gf_sigma),
        float(scale),
        anti_alias == "True",
        mode,
        int(n_clusters),
        linkage,
    )
    return report, figure
# Build the Gradio UI: two rows of parameter widgets, a text box and a plot
# for the results, and a button that runs `do_submit` with the widget values.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_coin_ward_segmentation.html)"
    )
    # Adjacent literals are used instead of backslash continuations inside the
    # string so source indentation never leaks into the displayed text.
    gr.Markdown(
        "Compute the segmentation of a 2D image with Ward hierarchical clustering. "
        "The clustering is spatially constrained in order for each segmented region to be in one piece."
    )

    # Image pre-processing parameters (Gaussian filter and rescale).
    with gr.Row(variant="evenly-spaced"):
        gf_sigma = gr.Slider(
            minimum=1,
            maximum=10,
            label="Gaussian Filter Sigma",
            value=2,
            info="Standard deviation for Gaussian filtering before down-scaling.",
            step=0.1,
        )
        scale = gr.Slider(
            minimum=0.1,
            maximum=0.7,
            label="Scale",
            value=0.2,
            info="Scale factor for the image.",
            step=0.1,
        )
        anti_alias = gr.Radio(
            ["True", "False"],
            label="Anti Aliasing",
            value="False",
            info=(
                "Whether to apply a Gaussian filter to smooth the image prior to down-scaling. "
                "It is crucial to filter when down-sampling the image to avoid aliasing artifacts. "
                "If input image data type is bool, no anti-aliasing is applied."
            ),
        )
        # Single-select dropdown: the default is a plain string, not a list.
        mode = gr.Dropdown(
            ["constant", "edge", "symmetric", "reflect", "wrap"],
            value="reflect",
            multiselect=False,
            label="mode",
            info="Points outside the boundaries of the input are filled according to the given mode. Modes match the behaviour of numpy.pad.",
        )

    # Agglomerative Clustering parameters.
    with gr.Row():
        n_clusters = gr.Slider(
            minimum=2,
            maximum=70,
            label="Number of Clusters",
            value=27,
            info="The number of clusters to find.",
            step=1,
        )
        # Single-select dropdown: the default is a plain string, not a list.
        linkage = gr.Dropdown(
            ["ward", "complete", "average", "single"],
            value="ward",
            multiselect=False,
            label="linkage",
            info="Which linkage criterion to use. The linkage criterion determines which distance to use between sets of observation. ",
        )

    # Result widgets: textual report and the contour plot.
    output = gr.Textbox(label="Output Box")
    plt_out = gr.Plot()

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=do_submit,
        inputs=[gf_sigma, scale, anti_alias, mode, n_clusters, linkage],
        outputs=[output, plt_out],
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()