Johannes commited on
Commit
f589fdc
1 Parent(s): 2f3a29d

add working code

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +73 -0
  3. requirements.txt +3 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Sklearn Spectral Clustering
3
- emoji: 👀
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
 
1
  ---
2
+ title: sklearn Spectral Clustering
3
+ emoji: 🔴🔵🔴
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.feature_extraction import image
3
+ from sklearn.cluster import spectral_clustering
4
+ import matplotlib
5
+ matplotlib.use('Agg')
6
+ import matplotlib.pyplot as plt
7
+ import gradio as gr
8
+ from scipy.cluster.vq import kmeans
9
+
10
+
11
+ def get_coordinates_from_mask(mask_in, number_of_centroids):
12
+ x_y = np.where(mask_in != [255, 255, 255])[:2]
13
+ x_y = np.column_stack((x_y[0], x_y[1]))
14
+ x_y = np.float32(x_y)
15
+ centroids,_ = kmeans(x_y,number_of_centroids)
16
+ centroids = np.int64(centroids)
17
+
18
+ return centroids
19
+
20
+
21
+ def infer(input_image: np.ndarray, number_of_circles: int, radius: int):
22
+ centroids = get_coordinates_from_mask(input_image, number_of_circles)
23
+
24
+ img = np.zeros((input_image.shape[1], input_image.shape[0]))
25
+
26
+ x, y = np.indices((input_image.shape[1], input_image.shape[0]))
27
+
28
+ for centroid in centroids:
29
+ circle = (x - centroid[0]) ** 2 + (y - centroid[1]) ** 2 < radius**2
30
+ img += circle
31
+
32
+ mask = img.astype(bool)
33
+
34
+ img = img.astype(float)
35
+ img += 1 + 0.2 * np.random.randn(*img.shape)
36
+
37
+
38
+ graph = image.img_to_graph(img, mask=mask)
39
+ graph.data = np.exp(-graph.data / graph.data.std())
40
+
41
+ labels = spectral_clustering(graph, n_clusters=len(centroids), eigen_solver="arpack")
42
+ label_im = np.full(mask.shape, -1.0)
43
+ label_im[mask] = labels
44
+
45
+ fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
46
+ axs[0].matshow(img)
47
+ axs[1].matshow(label_im)
48
+
49
+ return fig
50
+
51
+
52
+ article = """<center>
53
+ Demo by <a href='https://huggingface.co/johko' target='_blank'>Johannes (johko) Kolbe</a>"""
54
+
55
+
56
+ description = """<p style="text-align: center;">This is an interactive demo for the <a href="https://scikit-learn.org/stable/auto_examples/cluster/plot_segmentation_toy.html#sphx-glr-auto-examples-cluster-plot-segmentation-toy-py">Spectral clustering for image segmentation tutorial</a> from scikit-learn.
57
+ </br></br>The demo lets you mark places in the input where the centers of circles should be. The circles should then be segmented from one another using Spectral Image Clustering.
58
+ </br>The circles should ideally be close together(connected), to let the algorithm work correctly.
59
+ </br>As the demo uses k-means to determine the centroids of the circles exactly, you also need to specify the number of circles you want to get.
60
+ </br></br><b>What is Spectral Image clustering?</b> From the tutorial:
61
+ </br><i>"The Spectral clustering approach solves the problem know as ‘normalized graph cuts’: the image is seen as a graph of connected voxels, and the spectral clustering algorithm amounts to choosing graph cuts defining regions while minimizing the ratio of the gradient along the cut, and the volume of the region."</i> .</p>"""
62
+
63
+
64
+ gr.Interface(
65
+ title="Spectral Clustering with scikit-learn",
66
+ description=description,
67
+ article=artsicle,
68
+ fn=infer,
69
+ inputs=[gr.Image(source="canvas", tool="sketch", label="Input Image", shape=[100, 100]),
70
+ gr.Number(label="Number of circles to draw", value=4, precision=0),
71
+ gr.Slider(label="Circle Radius", minimum=5, maximum=25, value=15, step=1)],
72
+ outputs=[gr.Plot(label="Original Image Histogram")]
73
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ matplotlib==3.6.3
2
+ scikit-learn==1.2.1
3
+ scipy