# Hugging Face Space app by nouamanetazi (HF staff).
# Commit bbe1d3e: "Update app.py (#2)".
import os
import gradio as gr
import numpy as np
import glob
import warnings
import pandas as pd
import matplotlib.pyplot as plt
from utils import OrthogonalRegularizer
from huggingface_hub.keras_mixin import from_pretrained_keras
# Load the pretrained PointNet segmentation model from the Hugging Face Hub.
# The saved model references the project's OrthogonalRegularizer (from utils),
# which Keras cannot resolve by name on its own, so it is handed over through
# custom_objects at load time.
model = from_pretrained_keras(
    "keras-io/pointnet_segmentation",
    custom_objects=dict(OrthogonalRegularizer=OrthogonalRegularizer),
)
# Example inputs listed under the Gradio interface (and cached at startup).
# glob.glob() returns files in arbitrary filesystem order, so sort the paths to
# get a stable, reproducible example list.
samples = []  # NOTE(review): not used in this file; kept in case something imports it.
input_images = sorted(glob.glob("asset/source/*.csv"))  # point-cloud CSV paths, despite the name
examples = [[csv_path] for csv_path in input_images]

# Part labels the segmentation model predicts, and the colour each one is drawn
# in. The two lists are paired by position.
LABELS = ["wing", "body", "tail", "engine"]
COLORS = ["blue", "green", "red", "pink"]
def visualize_data(point_cloud, labels, output_path=None):
    """Render a labelled point cloud as a 3D scatter plot.

    Each label in ``LABELS`` is drawn in the colour at the same position in
    ``COLORS``. Points whose label is not in ``LABELS`` (e.g. ``"none"``) are
    not drawn.

    Args:
        point_cloud: Array indexed as ``point_cloud[:, 0..2]``; columns 0, 1
            and 2 are used as the x, y and z coordinates.
        labels: One label string per point, in the same order as
            ``point_cloud``'s rows.
        output_path: If given, the figure is saved to this path. Missing parent
            directories are created first, and the figure is then closed to
            release its memory.

    Returns:
        The matplotlib ``Figure``. When ``output_path`` is None the figure is
        left open so the caller can show or save it.
    """
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection="3d")
    # zip() pairs each label with its colour and stops at the shorter list.
    # This replaces a try/except IndexError that silently skipped labels
    # without a colour.
    for label, color in zip(LABELS, COLORS):
        subset = df[df["label"] == label]
        ax.scatter(subset["x"], subset["y"], subset["z"], label=label, alpha=0.5, c=color)
    ax.legend()
    if output_path:
        out_dir = os.path.dirname(output_path)
        # dirname() returns "" for a bare filename, and os.makedirs("") raises.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
        # pyplot keeps every open figure alive. This runs once per request in a
        # long-lived app, so close the figure once it has been written.
        plt.close(fig)
    return fig
def inference(
    csv_file,
    output_path="asset/output",
    cpu=False,
):
    """Segment the point cloud in an uploaded CSV and return a rendered image.

    Args:
        csv_file: The uploaded file from the Gradio ``"file"`` input. It may be
            an object with a ``.name`` path attribute or a plain path string.
        output_path: Directory the rendered PNG is written into.
        cpu: Unused; kept so existing callers keep working.

    Returns:
        Path of the saved PNG (``<output_path>/<csv stem>.png``), or None when
        no file was given or the file does not exist (a warning is emitted).
    """
    if csv_file is None:
        warnings.warn("No CSV file was provided.")
        return None
    # Older Gradio versions pass a tempfile wrapper; newer ones pass a path string.
    csv_path = getattr(csv_file, "name", csv_file)
    # Strip only the final extension, so "a.b.csv" becomes "a.b" rather than "a".
    im_name = os.path.splitext(os.path.basename(csv_path))[0]
    if not os.path.exists(csv_path):
        # The previous message referenced an undefined name (im_path), so this
        # branch raised NameError instead of warning.
        warnings.warn(f"{csv_path} not found")
        return None

    df = pd.read_csv(csv_path, index_col=None)
    inputs = df[["x", "y", "z"]].values
    # TODO: show a ground-truth image from df.iloc[:, 3:] when those columns exist.

    # Add a batch axis of 1 for the model, then drop it from the output.
    # Each row of preds holds per-class scores for one point.
    preds = model.predict(np.expand_dims(inputs, 0))[0]
    # label_map has one entry more than LABELS. Presumably the model's last
    # class means "no part"; confirm against the model definition.
    label_map = LABELS + ["none"]
    point_labels = [label_map[class_idx] for class_idx in np.argmax(preds, axis=-1)]

    image_path = os.path.join(output_path, f"{im_name}.png")
    visualize_data(inputs, point_labels, image_path)
    return image_path
# HTML footer shown below the interface, crediting the Space and the Keras example.
article = "<div style='text-align: center;'><a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a><br><a href='https://keras.io/examples/vision/pointnet_segmentation' target='_blank'>Keras example by Soumik Rakshit, Sayak Paul</a></div>"

# Build the interface and launch it as separate steps, so that `iface` refers
# to the gr.Interface itself rather than to whatever launch() returns.
# NOTE(review): gr.outputs.* and launch(enable_queue=...) are legacy Gradio
# APIs (removed in Gradio 4); keep the Space's gradio version pinned accordingly.
iface = gr.Interface(
    inference,  # main function
    inputs=[
        "file",
    ],
    outputs=[
        gr.outputs.Image(label="result"),  # generated image
    ],
    title="Point cloud segmentation with PointNet",
    article=article,
    examples=examples,
    # Pre-compute the outputs for the examples so clicking one returns instantly.
    cache_examples=True,
)
iface.launch(enable_queue=True)