# DTW-CNN / app.py
# Last updated by NahuelCosta ("Update app.py", commit 354059a); 2.7 kB.
from pathlib import Path

import numpy as np
import gradio as gr
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
import pandas as pd
#from dtaidistance import dtw
# Disabled reference implementation of the DTW-image preprocessing. It is kept
# inside a bare string literal, so the function is never defined at runtime:
# the `dtaidistance` import above is commented out, and predict() reads
# ready-made images from `data_DTW` instead of calling this.
# What it would do: build the DTW warping-path matrix between a reference
# curve and a sample (window = size/2, psi = 2), set unfilled (inf) and
# negative cells to 0, scale by the maximum, and append a trailing channel
# axis as float32.
# NOTE(review): presumably data/data_DTW.npy was generated with this code —
# confirm before changing either the preprocessing or the stored images.
'''
def getDTWImage(IC_reference, sample, size):
d, paths = dtw.warping_paths(IC_reference, sample, window=int(size/2), psi=2)
x = np.array(paths)
# mask values that are not filled
x = np.where(x == np.inf, -99, x)
# negative values are replaced by 0
x = np.where(x < 0, 0, x)
# normalise values
x = x/np.max(x)
# reshape the array
x = np.expand_dims(x, -1).astype("float32")
return x
'''
# Resolve data/model files relative to this script rather than the current
# working directory, so the app starts no matter where it is launched from.
_APP_DIR = Path(__file__).resolve().parent

# data: input curves, indexed as data[duty_cycle - 1][cycle] by predict().
# data_DTW: the matching precomputed DTW images, indexed the same way.
data = np.load(_APP_DIR / 'data.npy')
data_DTW = np.load(_APP_DIR / 'data' / 'data_DTW.npy')
# compile=False: the model is only used for inference (model.predict), so the
# training configuration does not need to be restored.
model = tf.keras.models.load_model(str(_APP_DIR / 'models' / 'model.h5'), compile=False)
# Index on the second axis of `data` / `data_DTW` for each cycle-number choice
# offered by the UI. Any other value (including "1000") falls back to 5.
_CYCLE_INDEX = {'10': 0, '50': 1, '100': 2, '200': 3, '400': 4}
_DEFAULT_CYCLE_INDEX = 5


def normalise_data(x, x_min, x_max):
    """Min-max scale ``x`` into [0, 1] using the given bounds.

    Args:
        x: array-like of values to scale.
        x_min: value mapped to 0.
        x_max: value mapped to 1.

    Returns:
        A float ndarray with the same shape as ``x``. When ``x_max == x_min``
        the range is degenerate, so an all-zero array is returned instead of
        dividing by zero.
    """
    x = np.asarray(x, dtype=float)
    value_range = x_max - x_min
    if value_range == 0:
        return np.zeros_like(x)
    return (x - x_min) / value_range


def predict(Cell_number, Duty_Cycle, Cycle_number):
    """Predict LLI / LAMPE / LAMNE for one duty cycle and cycle number.

    Args:
        Cell_number: value of the cell radio button. Currently ignored: only
            one dataset is loaded, so every selection reads ``data`` /
            ``data_DTW``.
        Duty_Cycle: 1-based index into the first axis of ``data``.
        Cycle_number: cycle-number choice as a string ('10', '50', '100',
            '200', '400', '1000'); unrecognised values map to index 5.

    Returns:
        A tuple ``(pred, df, im)``:
        * pred: dict mapping "LLI ", "LAMPE ", "LAMNE " to the model outputs
          as floats (the Label output sorts entries by value, which needs
          numbers rather than strings).
        * df: DataFrame with an x column ``' '`` (1..n), the ``'pristine'``
          reference curve and the ``'degraded'`` sample.
        * im: PIL image of the DTW input rendered with the inferno colormap.

    Raises:
        ValueError: if ``Duty_Cycle`` is outside ``1..len(data)``.
    """
    # Slider values may arrive as floats; numpy indexing needs an int.
    duty_index = int(Duty_Cycle) - 1
    # Reject out-of-range values explicitly: 0 would otherwise silently wrap
    # to the last entry through negative indexing.
    if not 0 <= duty_index < len(data):
        raise ValueError(
            f"Duty_Cycle must be between 1 and {len(data)}, got {Duty_Cycle}"
        )
    cycle = _CYCLE_INDEX.get(Cycle_number, _DEFAULT_CYCLE_INDEX)

    # ------------------------ Prediction ------------------------
    # The curve at [0][0] is plotted as the 'pristine' reference below.
    IC_reference = data[0][0]
    sample = data[duty_index][cycle]
    # Precomputed DTW image for this sample (previously built by getDTWImage).
    sample_DTW = data_DTW[duty_index][cycle]
    prediction = model.predict(np.expand_dims(sample_DTW, axis=0))
    pred = {
        "LLI ": float(prediction[0][0]),
        "LAMPE ": float(prediction[0][1]),
        "LAMNE ": float(prediction[0][2]),
    }

    # --------------------------- ICA + image ----------------------------
    n_points = len(IC_reference)
    df = pd.DataFrame(data={
        ' ': np.linspace(1, n_points, n_points),
        'pristine': IC_reference,
        'degraded': sample,
    })
    # Collapse to 2-D (H, W) for the colormap; the stored images presumably
    # carry a trailing channel axis of size 1 — TODO confirm.
    image_array = sample_DTW.reshape(sample_DTW.shape[0], sample_DTW.shape[1])
    image_array = normalise_data(image_array, np.min(image_array), np.max(image_array))
    # cm.inferno maps floats in [0, 1] to RGBA in [0, 1]; scale to 8-bit.
    im = Image.fromarray(np.uint8(cm.inferno(image_array) * 255))
    return pred, df, im
# Web UI. Gradio passes one argument per input component, so the inputs list
# must match predict()'s parameters one-to-one:
# (Cell_number, Duty_Cycle, Cycle_number).
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Radio(["Cell #1", "Cell #2", "Cell #3"]),
        # Upper bound follows the number of duty cycles stored in `data`, so
        # every slider position maps to a valid index.
        gr.inputs.Slider(1, len(data), step=1),
        gr.inputs.Radio(["10", "50", "100", "200", "400", "1000"]),
    ],
    title="LFP degradation diagnosis",
    description="Enter cell number, duty cycle and cycle number to predict the percentage of LLI, LAMPE and LAMNE",
    outputs=[
        gr.outputs.Label(label="Prediction"),
        gr.outputs.Timeseries(x=" ", y=["pristine", "degraded"]),
        gr.outputs.Image(type='pil', label="DTW image"),
    ],
    allow_screenshot=False,
    layout="unaligned",
)

if __name__ == "__main__":
    iface.launch(share=True)