import gradio as gr import matplotlib.pyplot as plt from sklearn.preprocessing import KBinsDiscretizer import numpy as np from typing import Tuple def build_init_plot(img_array: np.ndarray) -> Tuple[str, plt.Figure]: init_text = (f"The dimension of the image is {img_array.shape}\n" f"The data used to encode the image is of type {img_array.dtype}\n" f"The number of bytes taken in RAM is {img_array.nbytes}") fig, ax = plt.subplots(ncols=2, figsize=(12, 4)) ax[0].imshow(img_array, cmap=plt.cm.gray) ax[0].axis("off") ax[0].set_title("Rendering of the image") ax[1].hist(img_array.ravel(), bins=256) ax[1].set_xlabel("Pixel value") ax[1].set_ylabel("Count of pixels") ax[1].set_title("Distribution of the pixel values") _ = fig.suptitle("Original image") return init_text, fig def build_compressed_plot(compressed_image, img_array, sampling: str) -> plt.Figure: compressed_text = (f"The number of bytes taken in RAM is {compressed_image.nbytes}\n" f"Compression ratio: {compressed_image.nbytes / img_array.nbytes}\n" f"Type of the compressed image: {compressed_image.dtype}") sampling = sampling if sampling == "uniform" else "K-Means" fig, ax = plt.subplots(ncols=2, figsize=(12, 4)) ax[0].imshow(compressed_image, cmap=plt.cm.gray) ax[0].axis("off") ax[0].set_title("Rendering of the image") ax[1].hist(compressed_image.ravel(), bins=256) ax[1].set_xlabel("Pixel value") ax[1].set_ylabel("Count of pixels") ax[1].set_title("Sub-sampled distribution of the pixel values") _ = fig.suptitle(f"Original compressed using 3 bits and a {sampling} strategy") return compressed_text, fig def infer(img_array: np.ndarray, sampling: str, number_of_bins: int): # greyscale_image = input_image.convert("L") # img_array = np.array(greyscale_image) #raccoon_face = face(gray=True) init_text, init_fig = build_init_plot(img_array) n_bins = number_of_bins encoder = KBinsDiscretizer( n_bins=n_bins, encode="ordinal", strategy=sampling, random_state=0 ) compressed_image = 
encoder.fit_transform(img_array.reshape(-1, 1)).reshape( img_array.shape ) compressed_image = compressed_image.astype(np.uint8) compressed_text, compressed_fig = build_compressed_plot(compressed_image, img_array, sampling) bin_edges = encoder.bin_edges_[0] bin_center = bin_edges[:-1] + (bin_edges[1:] - bin_edges[:-1]) / 2 comparison_fig, ax = plt.subplots() ax.hist(img_array.ravel(), bins=256) color = "tab:orange" for center in bin_center: ax.axvline(center, color=color) ax.text(center - 10, ax.get_ybound()[1] + 100, f"{center:.1f}", color=color) return init_text, init_fig, compressed_text, compressed_fig, comparison_fig article = """
This is an interactive demo for the Vector Quantization Tutorial from scikit-learn. Vector Quantization is a compression technique that reduces the number of color values used in an image, saving memory while trying to keep good quality. In this demo this can be done naively via uniform sampling, which just uses N color values (specified via the slider) uniformly sampled from the whole spectrum, or via k-means, which pays closer attention to the actual pixel distribution and potentially leads to a better quality of the compressed image. In this demo we actually won't see a compression effect, because we cannot go smaller than uint8 in datatype size here. Usage: To run the demo you can simply upload an image and choose from two sampling methods - uniform and kmeans. Choose the number of bins and then click 'submit'. You will get information about the histogram, pixel distribution and other image statistics for your original image as grayscale and for the quantized version of it.
""", article=article, fn=infer, inputs=[gr.Image(image_mode="L", label="Input Image"), gr.Dropdown(choices=["uniform", "kmeans"], label="Sampling Method"), gr.Slider(minimum=2, maximum=50, value=8, step=1, label="Number of Bins")], outputs=[gr.Text(label="Original Image Stats"), gr.Plot(label="Original Image Histogram"), gr.Text(label="Compressed Image Stats"), gr.Plot(label="Compressed Image Histogram"), gr.Plot(label="Pixel Distribution Comparison")], examples=[["examples/hamster.jpeg", "uniform", 8], ["examples/racoon.png", "kmeans", 8]]).launch()