Johannes commited on
Commit
48db972
β€’
1 Parent(s): facf38a

add demo code

Browse files
Files changed (5) hide show
  1. README.md +2 -2
  2. app.py +92 -0
  3. examples/hamster.jpeg +0 -0
  4. examples/racoon.png +0 -0
  5. requirements.txt +3 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Sklearn Vector Quantization
3
- emoji: πŸ“š
4
  colorFrom: red
5
  colorTo: green
6
  sdk: gradio
 
1
  ---
2
+ title: sklearn Vector Quantization
3
+ emoji: πŸ“Š
4
  colorFrom: red
5
  colorTo: green
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ from sklearn.preprocessing import KBinsDiscretizer
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+
8
+ def build_init_plot(img_array: np.ndarray) -> tuple[str, plt.Figure]:
9
+ init_text = (f"The dimension of the image is {img_array.shape}\n"
10
+ f"The data used to encode the image is of type {img_array.dtype}\n"
11
+ f"The number of bytes taken in RAM is {img_array.nbytes}")
12
+
13
+ fig, ax = plt.subplots(ncols=2, figsize=(12, 4))
14
+
15
+ ax[0].imshow(img_array, cmap=plt.cm.gray)
16
+ ax[0].axis("off")
17
+ ax[0].set_title("Rendering of the image")
18
+ ax[1].hist(img_array.ravel(), bins=256)
19
+ ax[1].set_xlabel("Pixel value")
20
+ ax[1].set_ylabel("Count of pixels")
21
+ ax[1].set_title("Distribution of the pixel values")
22
+ _ = fig.suptitle("Original image")
23
+
24
+ return init_text, fig
25
+
26
+
27
+ def build_compressed_plot(compressed_image, img_array, sampling: str) -> plt.Figure:
28
+ compressed_text = (f"The number of bytes taken in RAM is {compressed_image.nbytes}\n"
29
+ f"Compression ratio: {compressed_image.nbytes / img_array.nbytes}\n"
30
+ f"Type of the compressed image: {compressed_image.dtype}")
31
+
32
+ sampling = sampling if sampling == "uniform" else "K-Means"
33
+
34
+ fig, ax = plt.subplots(ncols=2, figsize=(12, 4))
35
+ ax[0].imshow(compressed_image, cmap=plt.cm.gray)
36
+ ax[0].axis("off")
37
+ ax[0].set_title("Rendering of the image")
38
+ ax[1].hist(compressed_image.ravel(), bins=256)
39
+ ax[1].set_xlabel("Pixel value")
40
+ ax[1].set_ylabel("Count of pixels")
41
+ ax[1].set_title("Sub-sampled distribution of the pixel values")
42
+ _ = fig.suptitle(f"Original compressed using 3 bits and a {sampling} strategy")
43
+
44
+ return compressed_text, fig
45
+
46
+
47
+ def infer(img_array: np.ndarray, sampling: str):
48
+ # greyscale_image = input_image.convert("L")
49
+ # img_array = np.array(greyscale_image)
50
+
51
+ #raccoon_face = face(gray=True)
52
+ init_text, init_fig = build_init_plot(img_array)
53
+
54
+ n_bins = 8
55
+ encoder = KBinsDiscretizer(
56
+ n_bins=n_bins, encode="ordinal", strategy=sampling, random_state=0
57
+ )
58
+ compressed_image = encoder.fit_transform(img_array.reshape(-1, 1)).reshape(
59
+ img_array.shape
60
+ )
61
+
62
+ compressed_text, compressed_fig = build_compressed_plot(compressed_image,
63
+ img_array,
64
+ sampling)
65
+
66
+ bin_edges = encoder.bin_edges_[0]
67
+ bin_center = bin_edges[:-1] + (bin_edges[1:] - bin_edges[:-1]) / 2
68
+
69
+ comparison_fig, ax = plt.subplots()
70
+ ax.hist(img_array.ravel(), bins=256)
71
+ color = "tab:orange"
72
+ for center in bin_center:
73
+ ax.axvline(center, color=color)
74
+ ax.text(center - 10, ax.get_ybound()[1] + 100, f"{center:.1f}", color=color)
75
+
76
+ return init_text, init_fig, compressed_text, compressed_fig, comparison_fig
77
+
78
+
79
+ gr.Interface(
80
+ title="Vector Quantization with scikit-learn",
81
+ description="""<p style="text-align: center;">This is an interactive demo for the <a href="https://scikit-learn.org/stable/auto_examples/cluster/plot_face_compress.html">Vector Quantization Tutorial</a> from scikit-learn.
82
+ </br>You can upload an image and choose from two sampling methods - *uniform* and *kmeans*.</p>""",
83
+ fn=infer,
84
+ inputs=[gr.Image(image_mode="L", label="Input Image"),
85
+ gr.Dropdown(choices=["uniform", "kmeans"], label="Sampling Method")],
86
+ outputs=[gr.Text(label="Original Image Stats"),
87
+ gr.Plot(label="Original Image Histogram"),
88
+ gr.Text(label="Compressed Image Stats"),
89
+ gr.Plot(label="Compressed Image Histogram"),
90
+ gr.Plot(label="Pixel Distribution Comparison")],
91
+ examples=[["examples/hamster.jpeg", "uniform"],
92
+ ["examples/racoon.png", "kmeans"]]).launch()
examples/hamster.jpeg ADDED
examples/racoon.png ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ matplotlib==3.6.3
2
+ scikit-learn==1.2.1
3
+ scipy