nouamanetazi HF staff commited on
Commit
4ac4e3b
1 Parent(s): fafff42
Files changed (2) hide show
  1. app.py +46 -36
  2. utils.py +3 -3
app.py CHANGED
@@ -4,19 +4,31 @@ import numpy as np
4
  import glob
5
  import warnings
6
  import pandas as pd
7
- from .utils import OrthogonalRegularizer
8
 
 
9
  from huggingface_hub.keras_mixin import from_pretrained_keras
10
 
11
  # load model
12
- model = from_pretrained_keras("keras-io/pointnet_segmentation", custom_objects={"OrthogonalRegularizer": OrthogonalRegularizer})
 
 
13
 
14
  # Examples
15
  samples = []
16
- input_images = glob.glob('asset/source/*.png')
17
- examples = [[im, f"asset/source/{im.split('/')[-1].split('.')[0]}.csv", f'asset/ground_truth/{im_name}.png'] for im in input_images]
18
- LABELS = ['wing', 'body', 'tail', 'engine']
19
-
 
 
 
 
 
 
 
 
 
20
  def visualize_data(point_cloud, labels, output_path=None):
21
  df = pd.DataFrame(
22
  data={
@@ -31,9 +43,7 @@ def visualize_data(point_cloud, labels, output_path=None):
31
  for index, label in enumerate(LABELS):
32
  c_df = df[df["label"] == label]
33
  try:
34
- ax.scatter(
35
- c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index]
36
- )
37
  except IndexError:
38
  pass
39
  ax.legend()
@@ -41,45 +51,45 @@ def visualize_data(point_cloud, labels, output_path=None):
41
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
42
  plt.savefig(output_path)
43
 
44
- def inference(im_path,
45
- truth_path,
46
- file_obj,
47
- output_path = 'asset/output',
48
- cpu = False,
49
- ):
 
 
50
 
51
  csv_path = file_obj.name
52
- im_name = csv_path.split('/')[-1].split('.')[0]
53
-
54
  if os.path.exists(csv_path):
55
  df = pd.read_csv(csv_path, index_col=None)
56
- inputs = df[['x', 'y', 'z']].values
57
  y_test = df.iloc[:, 3:].values
58
  else:
59
- warnings.warn(f'{csv_path} not found for {im_path}')
60
  return
61
 
62
-
63
  preds = model.predict(np.expand_dims(inputs, 0))[0]
64
  label_map = LABELS + ["none"]
65
- visualize_data(inputs, [label_map[np.argmax(label)] for label in preds], f'{output_path}/{im_name}.png')
66
- return f'{output_path}/{im_name}.png'
 
67
 
68
  article = "<div style='text-align: center;'><a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a><br><a href='https://keras.io/examples/vision/pointnet_segmentation' target='_blank'>Keras example by Soumik Rakshit, Sayak Paul</a></div>"
69
 
70
  iface = gr.Interface(
71
- inference, # main function
72
- inputs = [
73
- gr.inputs.Image(label='Image', type="filepath"),
74
- gr.inputs.Image(label='Ground Truth', type="filepath"),
75
- "file"
76
-
 
 
77
  ],
78
- outputs = [
79
- gr.outputs.Image(label='result'), # generated image
80
- ],
81
-
82
- title = 'Point cloud segmentation with PointNet',
83
- article = article,
84
- examples = examples,
85
- ).launch(enable_queue=True, cache_examples=True)
 
4
  import glob
5
  import warnings
6
  import pandas as pd
7
+ import matplotlib.pyplot as plt
8
 
9
+ from utils import OrthogonalRegularizer
10
  from huggingface_hub.keras_mixin import from_pretrained_keras
11
 
12
  # load model
13
+ model = from_pretrained_keras(
14
+ "keras-io/pointnet_segmentation", custom_objects={"OrthogonalRegularizer": OrthogonalRegularizer}
15
+ )
16
 
17
  # Examples
18
  samples = []
19
+ input_images = glob.glob("asset/source/*.png")
20
+ examples = [
21
+ [
22
+ im,
23
+ f"asset/ground_truth/{im.split('/')[-1].split('.')[0]}.png",
24
+ f"asset/source/{im.split('/')[-1].split('.')[0]}.csv",
25
+ ]
26
+ for im in input_images
27
+ ]
28
+ LABELS = ["wing", "body", "tail", "engine"]
29
+ COLORS = ["blue", "green", "red", "pink"]
30
+
31
+
32
  def visualize_data(point_cloud, labels, output_path=None):
33
  df = pd.DataFrame(
34
  data={
 
43
  for index, label in enumerate(LABELS):
44
  c_df = df[df["label"] == label]
45
  try:
46
+ ax.scatter(c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index])
 
 
47
  except IndexError:
48
  pass
49
  ax.legend()
 
51
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
52
  plt.savefig(output_path)
53
 
54
+
55
+ def inference(
56
+ im_path,
57
+ truth_path,
58
+ file_obj,
59
+ output_path="asset/output",
60
+ cpu=False,
61
+ ):
62
 
63
  csv_path = file_obj.name
64
+ im_name = csv_path.split("/")[-1].split(".")[0]
65
+
66
  if os.path.exists(csv_path):
67
  df = pd.read_csv(csv_path, index_col=None)
68
+ inputs = df[["x", "y", "z"]].values
69
  y_test = df.iloc[:, 3:].values
70
  else:
71
+ warnings.warn(f"{csv_path} not found for {im_path}")
72
  return
73
 
 
74
  preds = model.predict(np.expand_dims(inputs, 0))[0]
75
  label_map = LABELS + ["none"]
76
+ visualize_data(inputs, [label_map[np.argmax(label)] for label in preds], f"{output_path}/{im_name}.png")
77
+ return f"{output_path}/{im_name}.png"
78
+
79
 
80
  article = "<div style='text-align: center;'><a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a><br><a href='https://keras.io/examples/vision/pointnet_segmentation' target='_blank'>Keras example by Soumik Rakshit, Sayak Paul</a></div>"
81
 
82
  iface = gr.Interface(
83
+ inference, # main function
84
+ inputs=[
85
+ gr.inputs.Image(label="Image", type="filepath"),
86
+ gr.inputs.Image(label="Ground Truth", type="filepath"),
87
+ "file",
88
+ ],
89
+ outputs=[
90
+ gr.outputs.Image(label="result"), # generated image
91
  ],
92
+ title="Point cloud segmentation with PointNet",
93
+ article=article,
94
+ examples=examples,
95
+ ).launch(enable_queue=True, cache_examples=True)
 
 
 
 
utils.py CHANGED
@@ -1,5 +1,7 @@
 
1
  from tensorflow import keras
2
 
 
3
  class OrthogonalRegularizer(keras.regularizers.Regularizer):
4
  """Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""
5
 
@@ -12,9 +14,7 @@ class OrthogonalRegularizer(keras.regularizers.Regularizer):
12
  identity = tf.cast(self.identity, x.dtype)
13
  x = tf.reshape(x, (tf.shape(x)[0], self.num_features, self.num_features))
14
  xxt = tf.tensordot(x, x, axes=(2, 2))
15
- xxt = tf.reshape(
16
- xxt, (tf.shape(x)[0] * tf.shape(x)[0], self.num_features, self.num_features)
17
- )
18
  return tf.reduce_sum(self.l2reg * tf.square(xxt - identity))
19
 
20
  def get_config(self):
 
1
+ import tensorflow as tf
2
  from tensorflow import keras
3
 
4
+
5
  class OrthogonalRegularizer(keras.regularizers.Regularizer):
6
  """Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""
7
 
 
14
  identity = tf.cast(self.identity, x.dtype)
15
  x = tf.reshape(x, (tf.shape(x)[0], self.num_features, self.num_features))
16
  xxt = tf.tensordot(x, x, axes=(2, 2))
17
+ xxt = tf.reshape(xxt, (tf.shape(x)[0] * tf.shape(x)[0], self.num_features, self.num_features))
 
 
18
  return tf.reduce_sum(self.l2reg * tf.square(xxt - identity))
19
 
20
  def get_config(self):