Johannes
minimal adaptations
4b6052a
import numpy as np
from sklearn.feature_extraction import image
from sklearn.cluster import spectral_clustering
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import gradio as gr
from scipy.cluster.vq import kmeans
def get_coordinates_from_mask(mask_in, number_of_centroids):
    """Locate the centres of the marks drawn on a white canvas using k-means.

    Every pixel that is not pure white (255) counts as "drawn". The
    ``(row, col)`` coordinates of those pixels are clustered into
    ``number_of_centroids`` groups, and the cluster centres are rounded to the
    nearest pixel.

    Args:
        mask_in: Canvas image, either ``(H, W)`` grayscale or ``(H, W, C)``
            colour. Only the first three channels are inspected for colour
            input, and white is 255 in every inspected channel.
        number_of_centroids: Number of clusters to look for. It is validated
            by ``scipy.cluster.vq.kmeans``, which rejects non-integer or
            non-positive values.

    Returns:
        ``np.int64`` array of shape ``(k, 2)`` holding ``(row, col)`` centres.
        ``k`` may be smaller than ``number_of_centroids`` because ``kmeans``
        removes centroids that end up with no assigned observations.

    Raises:
        ValueError: If no non-white pixel is found.
    """
    pixels = np.asarray(mask_in)
    if pixels.ndim == 3:
        # A pixel is drawn if any colour channel differs from white. Reducing
        # over the channel axis yields exactly one sample per pixel, so
        # partially coloured pixels are not under-weighted relative to black
        # ones.
        drawn = np.any(pixels[..., :3] != 255, axis=-1)
    else:
        drawn = pixels != 255

    rows, cols = np.nonzero(drawn)
    if rows.size == 0:
        raise ValueError("No marks found: draw at least one point on the canvas.")

    # kmeans works on floating-point observations.
    coords = np.column_stack((rows, cols)).astype(np.float32)
    centroids, _ = kmeans(coords, number_of_centroids)
    # Round to the nearest pixel instead of truncating toward zero.
    return np.rint(centroids).astype(np.int64)
def infer(input_image: np.ndarray, number_of_circles: int, radius: int):
    """Draw circles at the marked centres and segment them with spectral clustering.

    This mirrors scikit-learn's "Spectral clustering for image segmentation"
    example. A synthetic image of (possibly overlapping) discs is built, turned
    into a pixel-adjacency graph restricted to the discs, and partitioned with
    ``spectral_clustering``.

    Args:
        input_image: Canvas image whose non-white pixels mark the circle centres.
        number_of_circles: How many circle centres k-means should look for.
        radius: Radius of every circle, in pixels.

    Returns:
        A matplotlib Figure with two panels. The left panel shows the noisy
        synthetic image. The right panel shows the cluster labels, where
        ``-1`` marks background pixels.
    """
    centroids = get_coordinates_from_mask(input_image, number_of_circles)

    # The canvas is (rows, cols), matching the input image. The centroids are
    # (row, col), so centroid[0] must be compared against the axis-0 index.
    height, width = input_image.shape[:2]
    rows, cols = np.indices((height, width))

    # Take the union of all discs, as in the tutorial's boolean
    # ``circle1 + circle2 + ...``. Summing floats instead would give overlaps
    # an intensity of 2 and introduce artificial edges inside the foreground.
    mask = np.zeros((height, width), dtype=bool)
    for row_c, col_c in centroids:
        mask |= (rows - row_c) ** 2 + (cols - col_c) ** 2 < radius**2

    # The foreground sits at ~2 and the background at ~1, plus Gaussian noise.
    img = mask.astype(float)
    img += 1 + 0.2 * np.random.randn(*img.shape)

    # Build a graph over foreground pixels only. Edge weights decay with the
    # intensity gradient between neighbouring pixels.
    graph = image.img_to_graph(img, mask=mask)
    graph.data = np.exp(-graph.data / graph.data.std())

    labels = spectral_clustering(graph, n_clusters=len(centroids), eigen_solver="arpack")
    label_im = np.full(mask.shape, -1.0)
    label_im[mask] = labels

    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    axs[0].matshow(img)
    axs[1].matshow(label_im)
    # Remove the figure from pyplot's global registry so repeated requests do
    # not accumulate open figures. The Figure object itself is still returned
    # and can be rendered by the caller.
    plt.close(fig)
    return fig
# HTML footer for the demo (passed as ``article`` to gr.Interface).
article = """<center>
Demo by <a href='https://huggingface.co/johko' target='_blank'>Johannes (johko) Kolbe</a>
</center>"""
# HTML intro text for the demo (passed as ``description`` to gr.Interface).
description = """<p style="text-align: center;">This is an interactive demo for the <a href="https://scikit-learn.org/stable/auto_examples/cluster/plot_segmentation_toy.html#sphx-glr-auto-examples-cluster-plot-segmentation-toy-py">Spectral clustering for image segmentation tutorial</a> from scikit-learn.
<br><br><b>How To Use</b>
<br>The demo lets you mark places in the input where the centers of circles should be. The circles are then segmented from one another using spectral image clustering.
<br>Ideally, the circles should be close together (connected) so that the algorithm works correctly.
<br>Because the demo uses k-means to locate the circle centers from your marks, you also need to specify how many circles you want.
<br><br><b>What is spectral image clustering?</b> From the tutorial:
<br><i>"The spectral clustering approach solves the problem known as ‘normalized graph cuts’: the image is seen as a graph of connected voxels, and the spectral clustering algorithm amounts to choosing graph cuts defining regions while minimizing the ratio of the gradient along the cut, and the volume of the region."</i></p>"""
# Build the Gradio UI around ``infer`` and start the server.
# NOTE(review): ``source=``, ``tool=`` and ``shape=`` on gr.Image are gradio 3.x
# arguments; confirm the pinned gradio version still accepts them.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(
            source="canvas",
            tool="sketch",
            label="Mark the Circle Centers",
            shape=[100, 100],
        ),
        gr.Number(label="Number of circles to draw", value=4, precision=0),
        gr.Slider(label="Circle Radius", minimum=5, maximum=25, value=15, step=1),
    ],
    outputs=[gr.Plot(label="Output Plot")],
    title="Spectral Clustering with scikit-learn",
    description=description,
    article=article,
)
demo.launch()