DTW-CNN / app.py
NahuelCosta's picture
Update app.py
cc85a94
raw
history blame contribute delete
No virus
2.78 kB
import numpy as np
import gradio as gr
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
import pandas as pd
def normalise_data(data, min_val, max_val, low=0, high=1):
    '''
    Normalises the data to the range [low, high]

    Values equal to ``min_val`` map to ``low`` and values equal to ``max_val``
    map to ``high``. Values outside [min_val, max_val] are extrapolated
    linearly, not clipped.

    Parameters
    ----------
    data: numpy array (or scalar), data to normalise
    min_val: float or array broadcastable against data, value mapped to low
    max_val: float or array broadcastable against data, value mapped to high
    low: float, minimum value of the range
    high: float, maximum value of the range

    Returns
    -------
    normalised_data: numpy array (or numpy scalar), normalised data. Wherever
        max_val == min_val (e.g. a constant image) the result is ``low``
        instead of the NaN/inf a plain division would produce.
    '''
    span = np.asarray(max_val, dtype=float) - np.asarray(min_val, dtype=float)
    degenerate = span == 0
    # Divide by 1 where the range is empty so no divide-by-zero warning or
    # NaN is produced; those positions are then forced to 0 (i.e. ``low``).
    safe_span = np.where(degenerate, 1.0, span)
    normalised_data = np.where(degenerate, 0.0, (data - min_val) / safe_span)
    normalised_data = (high - low) * normalised_data + low
    return normalised_data
# Arrays loaded once at startup and indexed by predict() as
# array[duty_cycle_index][cycle_index]: `data` holds the IC curves and
# `data_DTW` the matching DTW images fed to the model.
data, data_DTW = (np.load(npy_path) for npy_path in ('./data.npy', './data_DTW.npy'))
# Inference only, so skip compiling (no optimizer/loss needed).
model = tf.keras.models.load_model('./model.h5', compile=False)
def predict(Cell_number, Duty_Cycle, Cycle_number):
    '''
    Runs the DTW-CNN on one stored sample and builds the three UI outputs.

    Parameters
    ----------
    Cell_number: str, value of the "Cell #..." radio button. Currently
        ignored: only the module-level ``data`` / ``data_DTW`` arrays are used.
    Duty_Cycle: int or float, 1-based duty-cycle index (gradio sliders pass
        their value as a float such as ``3.0``).
    Cycle_number: str (or int), one of "10", "50", "100", "200", "400", "1000".

    Returns
    -------
    pred: dict, model outputs for LLI, LAMPE and LAMNE formatted as strings
    df: pandas DataFrame, reference IC curve (data[0][0]) and selected curve
    im: PIL Image, the selected DTW image rendered with the inferno colormap

    Raises
    ------
    ValueError: if Cycle_number is not one of the offered choices, or
        Duty_Cycle falls outside the stored duty cycles.
    '''
    # Position of each offered cycle number along the second index of
    # ``data`` / ``data_DTW``.
    cycle_index = {'10': 0, '50': 1, '100': 2, '200': 3, '400': 4, '1000': 5}

    # ------------------------ Selection ------------------------
    cycle = cycle_index.get(str(Cycle_number))
    if cycle is None:
        raise ValueError(f"Unknown cycle number: {Cycle_number!r}")
    # Slider values arrive as floats; numpy indexing needs an int. Also guard
    # against 0 silently wrapping around to the last duty cycle.
    duty_index = int(Duty_Cycle) - 1
    if not 0 <= duty_index < len(data):
        raise ValueError(f"Duty cycle must be between 1 and {len(data)}, got {Duty_Cycle!r}")

    IC_reference = data[0][0]
    sample = data[duty_index][cycle]
    sample_DTW = data_DTW[duty_index][cycle]

    # ------------------------ Prediction ------------------------
    # The model expects a batch axis, so wrap the single DTW image.
    prediction = model.predict(np.expand_dims(sample_DTW, axis=0))
    pred = {"LLI ": str(prediction[0][0]), "LAMPE ": str(prediction[0][1]), "LAMNE ": str(prediction[0][2])}

    # --------------------------- IC + image ----------------------------
    n_points = len(IC_reference)
    # The ' ' column is the x axis referenced by the Timeseries output.
    d = {' ': np.linspace(1, n_points, n_points), 'pristine': IC_reference, 'degraded': sample}
    df = pd.DataFrame(data=d)
    # Collapse to 2-D for the colormap (presumably drops a trailing
    # single-channel axis, i.e. (H, W, 1) -> (H, W) -- TODO confirm).
    image_array = sample_DTW.reshape(sample_DTW.shape[0], sample_DTW.shape[1])
    image_array = normalise_data(image_array, np.min(image_array), np.max(image_array))
    # cm.inferno returns RGBA floats in [0, 1]; scale to 8-bit for PIL.
    im = Image.fromarray(np.uint8(cm.inferno(image_array) * 255))
    return pred, df, im
# Input widgets, in the same order as predict()'s parameters.
input_widgets = [
    gr.inputs.Radio(["Cell #1", "Cell #2", "Cell #3"]),
    gr.inputs.Slider(1, 1000, step=1),
    gr.inputs.Radio(["10", "50", "100", "200", "400", "1000"]),
]
# Output widgets, in the same order as predict()'s returned tuple.
output_widgets = [
    gr.outputs.Label(label="Prediction"),
    gr.outputs.Timeseries(x=" ", y=["pristine", "degraded"], label="IC curves"),
    gr.outputs.Image(type='pil', label="DTW image"),
]
iface = gr.Interface(
    fn=predict,
    inputs=input_widgets,
    outputs=output_widgets,
    title="LFP degradation diagnosis",
    description="Enter cell number, duty cycle and cycle number to predict the percentage of LLI, LAMPE and LAMNE",
    allow_screenshot=False,
    theme="darkpeach",
    layout="unaligned",
)
iface.launch()