"""Gradio Space: point-cloud part segmentation with a Keras PointNet model.

Loads ``keras-io/pointnet_segmentation`` from the Hugging Face Hub, runs it on an
uploaded CSV with ``x``, ``y``, ``z`` columns, and returns the path of a PNG in
which each point is coloured by its predicted part label.
"""
import glob
import os
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub.keras_mixin import from_pretrained_keras

from utils import OrthogonalRegularizer

# load model
# The saved model references the custom OrthogonalRegularizer, so it must be
# supplied to Keras for deserialization.
model = from_pretrained_keras(
    "keras-io/pointnet_segmentation",
    custom_objects={"OrthogonalRegularizer": OrthogonalRegularizer},
)

# Examples shown in the UI: one row per bundled CSV file.
input_images = glob.glob("asset/source/*.csv")
examples = [[im] for im in input_images]

# Part names and their plot colours; the two lists are parallel.
LABELS = ["wing", "body", "tail", "engine"]
COLORS = ["blue", "green", "red", "pink"]


def visualize_data(point_cloud, labels, output_path=None):
    """Draw a 3-D scatter plot of a point cloud coloured by part label.

    Args:
        point_cloud: Array indexed as ``point_cloud[:, 0]``, ``[:, 1]``,
            ``[:, 2]`` for the x, y and z coordinates.
        labels: One label string per point. Only labels listed in ``LABELS``
            are drawn; any other label (e.g. ``"none"``) is omitted.
        output_path: If truthy, the figure is saved to this path. Missing
            parent directories are created.
    """
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    try:
        ax = fig.add_subplot(projection="3d")
        # zip stops at the shorter list, so a label without a colour is skipped
        # (the original achieved this by swallowing IndexError).
        for label, color in zip(LABELS, COLORS):
            c_df = df[df["label"] == label]
            ax.scatter(
                c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=color
            )
        ax.legend()
        if output_path:
            out_dir = os.path.dirname(output_path)
            # dirname() is "" for a bare filename, and os.makedirs("") raises.
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each request in the long-running server leaks a figure.
        plt.close(fig)


def inference(
    csv_file,
    output_path="asset/output",
    cpu=False,
):
    """Segment an uploaded point cloud and render the prediction as a PNG.

    Args:
        csv_file: The uploaded file. Its ``.name`` attribute is used as the
            CSV path; if it has no ``.name``, the value itself is treated as
            the path. The CSV must contain ``x``, ``y`` and ``z`` columns.
        output_path: Directory in which the rendered PNG is written.
        cpu: Unused; kept for interface compatibility.

    Returns:
        Path of the written PNG, or ``None`` (after emitting a warning) if the
        CSV path does not exist.
    """
    # Some Gradio versions pass a tempfile wrapper (with .name), others a str path.
    csv_path = getattr(csv_file, "name", csv_file)
    im_name = os.path.splitext(os.path.basename(csv_path))[0]

    if not os.path.exists(csv_path):
        warnings.warn(f"{csv_path} not found for {im_name}")
        return None

    df = pd.read_csv(csv_path, index_col=None)
    inputs = df[["x", "y", "z"]].values
    # TODO: show ground truth image if the CSV carries label columns
    # (df.iloc[:, 3:] in the original code).

    # Add a batch axis for the model, then take the single result back out.
    preds = model.predict(np.expand_dims(inputs, 0))[0]
    # Class index len(LABELS) maps to "none"; visualize_data does not draw it.
    label_map = LABELS + ["none"]
    pred_labels = [label_map[i] for i in np.argmax(preds, axis=-1)]

    image_path = os.path.join(output_path, f"{im_name}.png")
    visualize_data(inputs, pred_labels, image_path)
    return image_path


article = """
Space by Nouamane Tazi
Keras example by Soumik Rakshit, Sayak Paul
"""

iface = gr.Interface(
    inference,  # main function
    inputs=[
        "file",
    ],
    outputs=[
        gr.outputs.Image(label="result"),  # generated image
    ],
    title="Point cloud segmentation with PointNet",
    article=article,
    examples=examples,
    cache_examples=True,
)
iface.launch(enable_queue=True)