# Source: Hugging Face Space page (status shown: "Runtime error").
# File size: 4,006 bytes; commits: e01d462, 509fb31.
# Import libraries
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.datasets import fetch_olivetti_faces
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import RidgeCV
from sklearn.neighbors import KNeighborsRegressor
from sklearn.utils.validation import check_random_state
# Load the faces dataset: `data` holds one flattened image per row and
# `targets` holds the per-image label used below to split train/test.
data, targets = fetch_olivetti_faces(return_X_y=True)
n_pixels = data.shape[1]
# Training subset: every image whose label is below 30.
train = data[targets < 30]
# Model inputs: the leading half of each flattened image. The images are
# reshaped back to 64x64 (row by row) elsewhere in this file, so this is the
# upper half of the face.
X_train = train[:, : (n_pixels + 1) // 2]
# Model targets: the trailing half of each flattened image (lower half of the face).
y_train = train[:, n_pixels // 2 :]
# Fit estimators. Each one learns to predict the lower half of a face
# (y_train) from its upper half (X_train), i.e. to extrapolate the missing
# half of the image from the visible half.
ESTIMATORS = {
    "Extra trees": ExtraTreesRegressor(
        n_estimators=10, max_features=32, random_state=0
    ),
    "K-nn": KNeighborsRegressor(),
    "Linear regression": LinearRegression(),
    "Ridge": RidgeCV(),
}
# Only the estimator objects are needed here; the display names are unused.
for estimator in ESTIMATORS.values():
    estimator.fit(X_train, y_train)
# Held-out set: images whose label is 30 or above (never seen during fitting).
test = data[targets >= 30]
# Pick 15 row positions with a fixed seed so the demo shows the same faces on
# every start. randint draws each position independently, so repeats are possible.
n_faces = 15
rng = check_random_state(4)
face_ids = rng.randint(test.shape[0], size=n_faces)
test = np.take(test, face_ids, axis=0)
# Function for returning 64*64 image, given the image index
def imageFromIndex(index):
    """Return the selected test face reshaped to a 64x64 array.

    Args:
        index: Row of the module-level ``test`` array to return. It is
            converted with ``int()``, so numeric slider values such as ``3.0``
            are accepted.

    Returns:
        A ``(64, 64)`` NumPy array with that face's pixel values.
    """
    # A single reshape of the flattened row is enough to get the 2-D image.
    return test[int(index)].reshape(64, 64)
# Function for extrapolating face
def extrapolateFace(index, ESTIMATORS=ESTIMATORS):
    """Plot a test face next to each estimator's completion of its lower half.

    The upper half of the selected face is fed to every estimator in
    ``ESTIMATORS``. Each predicted lower half is stacked under the true upper
    half and drawn in its own panel, after a first panel showing the true face.

    Args:
        index: Row of the module-level ``test`` array to use. It is converted
            with ``int()``, so numeric slider values such as ``3.0`` are
            accepted.
        ESTIMATORS: Mapping of display name to a fitted estimator exposing
            ``predict``. Defaults to the module-level estimators.

    Returns:
        A ``matplotlib.figure.Figure`` with ``1 + len(ESTIMATORS)`` panels in a
        single row.
    """
    image = test[int(index)].reshape(1, -1)
    image_shape = (64, 64)
    n_cols = 1 + len(ESTIMATORS)
    n_pixels = image.shape[1]
    # Upper half of the face: the input shown to every estimator.
    X_upper = image[:, : (n_pixels + 1) // 2]
    # Lower half of the face: the ground truth the estimators try to predict.
    y_ground_truth = image[:, n_pixels // 2 :]

    # Build a standalone Figure instead of going through pyplot. pyplot keeps
    # every figure it creates alive until it is explicitly closed, so creating
    # one per slider change leaked figures in this long-running app. A plain
    # Figure is garbage-collected once Gradio has rendered it.
    fig = Figure(figsize=(2.0 * n_cols, 2.5))

    true_face = np.hstack((X_upper, y_ground_truth))
    ax = fig.add_subplot(1, n_cols, 1, title="true face")
    ax.axis("off")
    ax.imshow(true_face.reshape(image_shape), cmap="gray", interpolation="nearest")

    # Panels are ordered by estimator name for a stable layout.
    for j, name in enumerate(sorted(ESTIMATORS)):
        y_predict = ESTIMATORS[name].predict(X_upper)
        completed_face = np.hstack((X_upper[0], y_predict[0]))
        ax = fig.add_subplot(1, n_cols, 2 + j, title=name)
        ax.axis("off")
        ax.imshow(
            completed_face.reshape(image_shape),
            cmap="gray",
            interpolation="nearest",
        )
    return fig
with gr.Blocks() as demo:
    link = "https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_multioutput_face_completion.html#sphx-glr-auto-examples-miscellaneous-plot-multioutput-face-completion-py"
    title = "Face completion with multi-output estimators"
    gr.Markdown(f"# {title}")
    gr.Markdown(f"### This demo is based on this [scikit-learn example]({link}).")
    gr.Markdown(
        "### In this demo, we compare 4 multi-output estimators to complete images. "
        "The goal is to predict the lower half of a face given its upper half."
    )
    gr.Markdown(
        "#### Use the below slider to choose a face's image. "
        "Consequently, observe how the four estimators complete the lower half of that face."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Both callbacks index the module-level ``test`` array directly,
            # so the slider must be 0-based and stop at its last row. A 1..15
            # range would skip the first face and raise IndexError at 15.
            image_index = gr.Slider(
                minimum=0,
                maximum=len(test) - 1,
                value=0,
                step=1,
                label="Image Index",
                info="Choose an image",
            )
            face_image = gr.Image()
        with gr.Column(scale=2):
            plot = gr.Plot(label="Face completion with multi-output estimators")
    # Refresh both the selected face and the completion plot whenever the
    # slider moves, and render them once when the page first loads.
    image_index.change(imageFromIndex, inputs=[image_index], outputs=[face_image])
    image_index.change(extrapolateFace, inputs=[image_index], outputs=[plot])
    demo.load(imageFromIndex, inputs=[image_index], outputs=[face_image])
    demo.load(extrapolateFace, inputs=[image_index], outputs=[plot])
if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script.
    demo.launch(debug=True)