|
import numpy as np
|
|
import pandas as pd
|
|
import gradio as gr
|
|
import cv2
|
|
from tensorflow import keras as k
|
|
|
|
|
|
# Side length (pixels) that every input image is resized to.
image_size = 256
# Number of output classes of the classifier.
num_classes = 3

# Channel counts of the two model inputs: the tool image (3 channels as read by
# cv2.imread) and the three stacked spectrogram images (3 x 3 = 9 channels).
in_channel_tool = 3
in_channel_spec = 9

img_rows, img_cols = image_size, image_size

# Trained multimodal classifier.
model_class_path = "Models/minape_base_multi_ts.h5"

# Sample label table and the directories holding the per-sample images.
csv_path = "Dataset/labels_sample.csv"
tool_path = "Dataset/tool"
spec_path = "Dataset/spec"
|
|
|
|
|
|
# Label table with one row per sample; this block needs an ``id`` column.
df = pd.read_csv(csv_path)

# Derive the on-disk image paths of every sample from its id.
df["tool"] = df["id"].map(lambda sample_id: f"{tool_path}/{sample_id}.jpg")
df["spec_x"] = df["id"].map(lambda sample_id: f"{spec_path}/x/{sample_id}.jpg")
df["spec_y"] = df["id"].map(lambda sample_id: f"{spec_path}/y/{sample_id}.jpg")
df["spec_z"] = df["id"].map(lambda sample_id: f"{spec_path}/z/{sample_id}.jpg")
|
|
|
|
|
|
# Gradio example rows, one per sample, in the same order as the demo's input
# components: [tool_id, image_label, tool image, spec_x, spec_y, spec_z].
example_columns = ["id", "image_label", "tool", "spec_x", "spec_y", "spec_z"]
exs = [list(row) for row in df[example_columns].itertuples(index=False, name=None)]
|
|
|
|
|
|
def process_img(img, img_rows, img_cols, channels):
    """
    Read an image file from disk, resize it and scale pixel values to [0, 1].

    @params:
        img - Path of the image file to read.
        img_rows - Target image height in pixels.
        img_cols - Target image width in pixels.
        channels - Number of channels of the returned array (must match the
                   channel count cv2.imread produces, i.e. 3 for colour).

    @returns:
        A float32 array of shape (img_rows, img_cols, channels).

    @raises:
        ValueError - If no path is given or the file cannot be decoded.
    """
    if img is None:
        raise ValueError("No image file was provided")
    data = cv2.imread(img)
    # cv2.imread signals failure by returning None instead of raising.
    if data is None:
        raise ValueError(f"Could not read image file: {img!r}")
    # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
    data = cv2.resize(data, dsize=(img_cols, img_rows), interpolation=cv2.INTER_CUBIC)
    data = np.asarray(data, dtype=np.float32) / 255.0
    return data.reshape(img_rows, img_cols, channels)
|
|
|
|
def process_specs(img_x, img_y, img_z, img_rows, img_cols, channels):
    """
    Load the x/y/z spectrogram images and stack them along the channel axis.

    Each image is read with cv2.imread, resized, concatenated with the others
    along the last axis and scaled to [0, 1].

    @params:
        img_x, img_y, img_z - Paths of the three spectrogram image files.
        img_rows - Target image height in pixels.
        img_cols - Target image width in pixels.
        channels - Total channel count after stacking (3 images x 3 channels = 9
                   for colour images); used to validate the result's shape.

    @returns:
        A float32 array of shape (img_rows, img_cols, channels).

    @raises:
        ValueError - If a path is missing or a file cannot be decoded.
    """
    resized = []
    for path in (img_x, img_y, img_z):
        if path is None:
            raise ValueError("A spectrogram image file was not provided")
        data = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if data is None:
            raise ValueError(f"Could not read image file: {path!r}")
        # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
        resized.append(
            cv2.resize(data, dsize=(img_cols, img_rows), interpolation=cv2.INTER_CUBIC)
        )
    img = np.asarray(np.concatenate(resized, axis=2), dtype=np.float32) / 255.0
    return img.reshape(img_rows, img_cols, channels)
|
|
|
|
|
|
|
|
# Load the trained classifier once at import time. compile=False skips
# restoring training configuration; this script only calls predict() on it.
model_class = k.models.load_model(model_class_path, compile=False)


def predict(tool_id, label, tool, spec_x, spec_y, spec_z):
    """
    Classify a tool's state from its image and its three spectrograms.

    @params:
        tool_id - Tool identifier from the UI (not used for the prediction).
        label - Ground-truth label, echoed back for comparison.
        tool - Path of the tool image.
        spec_x, spec_y, spec_z - Paths of the x/y/z spectrogram images.

    @returns:
        Two gr.Label components: the actual label and a dict mapping each class
        name ('sharp', 'used', 'dulled') to the model's score for it.

    @raises:
        gr.Error - If the tool image or any spectrogram is missing.
    """
    if any(not path for path in (tool, spec_x, spec_y, spec_z)):
        raise gr.Error("Please provide the tool image and all three spectrograms.")

    class_names = ['sharp', 'used', 'dulled']

    tool_img = process_img(tool, img_rows, img_cols, in_channel_tool)
    spec_img = process_specs(spec_x, spec_y, spec_z, img_rows, img_cols, in_channel_spec)

    # The model takes a batch; wrap each single sample in a batch of size 1.
    inputs = [tool_img[np.newaxis, ...], spec_img[np.newaxis, ...]]
    y_score = model_class.predict(inputs)

    y_pred = {name: float(score) for name, score in zip(class_names, y_score[0])}

    return [
        gr.Label(value=label, label="Actual Label", visible=True),
        gr.Label(value=y_pred, label="Predicted Label", visible=True),
    ]
|
|
|
|
|
|
title = r"""
|
|
<h1 align="center">Minape</h1>
|
|
"""
|
|
description = r"""
|
|
<b>Official π€ Gradio demo</b> for <a href='https://github.com/hubtru/Minape'
|
|
target='_blank'><b>Multimodal, Isotropic Neural Architecture with Patch Embedding for Recognition of Device State</b></a>.<br>
|
|
"""
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(value=title)
    gr.Markdown(description)

    with gr.Row():
        # Left column: sample inputs.
        with gr.Column():
            with gr.Row():
                tool_id = gr.Textbox("T1R2B1", label="Tool ID")
                label_input = gr.Textbox("Sharp", label="Label")
            with gr.Row():
                tool = gr.Image(label="Tool", type="filepath")
            with gr.Row():
                spec_x = gr.Image(label="Spec_x", type="filepath")
                spec_y = gr.Image(label="Spec_y", type="filepath")
                spec_z = gr.Image(label="Spec_z", type="filepath")
            submit_btn = gr.Button("Submit", variant="primary")

        # Right column: actual vs. predicted label.
        with gr.Column():
            output_labels = [
                gr.Label("Sharp", label="Actual Label"),
                gr.Label("Sharp", label="Predicted Label"),
            ]

    # Shared by the examples table and the submit handler; the order matches
    # predict()'s parameters and the columns of each example row.
    demo_inputs = [tool_id, label_input, tool, spec_x, spec_y, spec_z]
    examples = gr.Examples(examples=exs, inputs=demo_inputs)
    submit_btn.click(fn=predict, inputs=demo_inputs, outputs=output_labels)

if __name__ == "__main__":
    demo.launch()