# devinschumacher's picture
# Duplicate from sklearn-docs/sklearn_vector_quantization
# 682d1dc
import gradio as gr
import matplotlib.pyplot as plt
from sklearn.preprocessing import KBinsDiscretizer
import numpy as np
from typing import Tuple
def build_init_plot(img_array: np.ndarray) -> Tuple[str, plt.Figure]:
    """Describe the input image and draw it beside its pixel histogram.

    Args:
        img_array: Image data as a NumPy array. It is rendered with a gray
            colormap and flattened for the histogram.

    Returns:
        A ``(text, figure)`` pair: a three-line summary (shape, dtype and
        byte size of ``img_array``) and a two-panel matplotlib figure with
        the rendered image on the left and a 256-bin histogram of its
        values on the right.
    """
    summary_lines = [
        f"The dimension of the image is {img_array.shape}",
        f"The data used to encode the image is of type {img_array.dtype}",
        f"The number of bytes taken in RAM is {img_array.nbytes}",
    ]
    init_text = "\n".join(summary_lines)

    fig, (image_ax, hist_ax) = plt.subplots(ncols=2, figsize=(12, 4))

    # Left panel: the image itself, without axis ticks or frame.
    image_ax.imshow(img_array, cmap=plt.cm.gray)
    image_ax.axis("off")
    image_ax.set_title("Rendering of the image")

    # Right panel: distribution of the raw pixel values.
    hist_ax.hist(img_array.ravel(), bins=256)
    hist_ax.set_xlabel("Pixel value")
    hist_ax.set_ylabel("Count of pixels")
    hist_ax.set_title("Distribution of the pixel values")

    fig.suptitle("Original image")
    return init_text, fig
def build_compressed_plot(compressed_image, img_array, sampling: str) -> Tuple[str, plt.Figure]:
    """Describe the quantized image and draw it beside its value histogram.

    Args:
        compressed_image: Quantized image as a NumPy array. It is assumed to
            hold ordinal bin indices starting at 0 (which is what ``infer``
            in this file passes in); the bit depth shown in the figure title
            is derived from its largest value.
        img_array: The original image array, used only to compute the
            compression ratio from the two ``nbytes`` values.
        sampling: Binning strategy name. ``"uniform"`` is shown as-is in the
            title; any other value is shown as ``"K-Means"``.

    Returns:
        A ``(text, figure)`` pair: a three-line summary (byte size,
        compression ratio, dtype) and a two-panel matplotlib figure with the
        rendered quantized image and a 256-bin histogram of its values.
    """
    compressed_text = "\n".join(
        [
            f"The number of bytes taken in RAM is {compressed_image.nbytes}",
            f"Compression ratio: {compressed_image.nbytes / img_array.nbytes}",
            f"Type of the compressed image: {compressed_image.dtype}",
        ]
    )
    strategy_label = sampling if sampling == "uniform" else "K-Means"

    # The title used to claim "3 bits" unconditionally, which is only true
    # for 8 bins. Derive the number of bits needed to index the levels that
    # are actually present (8 levels -> 3 bits, so the default is unchanged).
    n_levels = int(compressed_image.max()) + 1 if compressed_image.size else 1
    n_bits = max(1, (n_levels - 1).bit_length())
    bits_label = f"{n_bits} bit" if n_bits == 1 else f"{n_bits} bits"

    fig, (image_ax, hist_ax) = plt.subplots(ncols=2, figsize=(12, 4))

    # Left panel: the quantized image, without axis ticks or frame.
    image_ax.imshow(compressed_image, cmap=plt.cm.gray)
    image_ax.axis("off")
    image_ax.set_title("Rendering of the image")

    # Right panel: distribution of the quantized values.
    hist_ax.hist(compressed_image.ravel(), bins=256)
    hist_ax.set_xlabel("Pixel value")
    hist_ax.set_ylabel("Count of pixels")
    hist_ax.set_title("Sub-sampled distribution of the pixel values")

    fig.suptitle(
        f"Original compressed using {bits_label} and a {strategy_label} strategy"
    )
    return compressed_text, fig
def infer(img_array: np.ndarray, sampling: str, number_of_bins: int):
    """Quantize an image with ``KBinsDiscretizer`` and build the demo outputs.

    Args:
        img_array: Input image as a NumPy array, or ``None`` when the user
            submitted the form without an image.
        sampling: Binning strategy forwarded to ``KBinsDiscretizer``
            (the UI offers ``"uniform"`` and ``"kmeans"``), or ``None`` when
            nothing was selected.
        number_of_bins: Number of bins to quantize the pixel values into.

    Returns:
        A 5-tuple ``(init_text, init_fig, compressed_text, compressed_fig,
        comparison_fig)``: statistics and figure for the original image,
        statistics and figure for the quantized image, and a histogram of
        the original pixel values overlaid with the centre of every bin.

    Raises:
        gr.Error: If no image or no sampling method was provided.
    """
    # Fail with a readable message instead of an AttributeError or a
    # scikit-learn parameter-validation error further down.
    if img_array is None:
        raise gr.Error("Please provide an input image.")
    if sampling is None:
        raise gr.Error("Please choose a sampling method.")

    # KBinsDiscretizer expects an integer bin count; normalise whatever
    # numeric type the slider delivered.
    n_bins = int(number_of_bins)

    init_text, init_fig = build_init_plot(img_array)

    encoder = KBinsDiscretizer(
        n_bins=n_bins, encode="ordinal", strategy=sampling, random_state=0
    )
    # The discretizer works on a column of samples, so flatten the image to
    # a single feature column and restore the image shape afterwards.
    flat_pixels = img_array.reshape(-1, 1)
    compressed_image = encoder.fit_transform(flat_pixels).reshape(img_array.shape)
    compressed_image = compressed_image.astype(np.uint8)

    compressed_text, compressed_fig = build_compressed_plot(
        compressed_image, img_array, sampling
    )

    # Midpoint of every bin, computed from consecutive bin edges.
    bin_edges = encoder.bin_edges_[0]
    bin_center = bin_edges[:-1] + (bin_edges[1:] - bin_edges[:-1]) / 2

    comparison_fig, ax = plt.subplots()
    ax.hist(img_array.ravel(), bins=256)
    color = "tab:orange"
    for center in bin_center:
        # Mark the bin centre and label it just above the top of the axes.
        ax.axvline(center, color=color)
        ax.text(center - 10, ax.get_ybound()[1] + 100, f"{center:.1f}", color=color)

    return init_text, init_fig, compressed_text, compressed_fig, comparison_fig
# Footer shown below the interface. The <center> element is closed explicitly
# so the markup is well-formed.
article = """<center>
Demo by <a href='https://huggingface.co/johko' target='_blank'>Johannes (johko) Kolbe</a>
</center>"""

# Header text shown above the interface. Line breaks use <br>; the previous
# </br> form is not valid HTML.
description = """<p style="text-align: center;">This is an interactive demo for the <a href="https://scikit-learn.org/stable/auto_examples/cluster/plot_face_compress.html">Vector Quantization Tutorial</a> from scikit-learn.
<br><b>Vector Quantization</b> is a compression technique to reduce the number of color values that are used in an image and with this save memory while trying to keep a good quality.
In this demo this can be done naively via <i>uniform</i> sampling, which just uses <i>N</i> color values (specified via slider) uniformly sampled from the whole spectrum or via <i>k-means</i> which pays closer attention
to the actual pixel distribution and potentially leads to a better quality of the compressed image.
In this demo we actually won't see a compression effect, because we cannot go smaller than <i>uint8</i> in datatype size here.
<br>
<br><b>Usage</b>: To run the demo you can simply upload an image and choose from two sampling methods - <i>uniform</i> and <i>kmeans</i>. Choose the number of bins and then click 'submit'.
You will get information about the histogram, pixel distribution and other image statistics for your original image as grayscale and the quantized version of it.
</p>"""

# Build the interface: a grayscale image, a strategy dropdown and a bin-count
# slider go into `infer`; its five return values fill the five outputs below.
demo = gr.Interface(
    title="Vector Quantization with scikit-learn",
    description=description,
    article=article,
    fn=infer,
    inputs=[
        gr.Image(image_mode="L", label="Input Image"),
        gr.Dropdown(choices=["uniform", "kmeans"], label="Sampling Method"),
        gr.Slider(minimum=2, maximum=50, value=8, step=1, label="Number of Bins"),
    ],
    outputs=[
        gr.Text(label="Original Image Stats"),
        gr.Plot(label="Original Image Histogram"),
        gr.Text(label="Compressed Image Stats"),
        gr.Plot(label="Compressed Image Histogram"),
        gr.Plot(label="Pixel Distribution Comparison"),
    ],
    examples=[
        ["examples/hamster.jpeg", "uniform", 8],
        ["examples/racoon.png", "kmeans", 8],
    ],
)

# Start the web server as soon as the script runs, as before.
demo.launch()