liesdillen's picture
Update app.py
4c41b67 verified
raw
history blame
7.71 kB
import csv
import numpy as np
import gradio as gr
import plotly.graph_objs as go
import datetime
from plotly.subplots import make_subplots
from testing_interface import model_defining
# Function to load data from a text file into a numpy array of the right format
def load_data(filename):
    """Load one integer per line from *filename* as an (N, 1) int array.

    Blank (whitespace-only) lines are ignored, so a trailing empty line does
    not crash parsing. Every value equal to -100 is replaced by -1.

    Args:
        filename: path of a text file with one integer per line.

    Returns:
        numpy int array of shape (N, 1); shape (0, 1) for a file with no values.

    Raises:
        ValueError: if a non-blank line is not a valid integer.
    """
    with open(filename, 'r') as file:
        # int() tolerates surrounding whitespace/newline, so no strip needed
        # once blank lines have been filtered out.
        values = [int(line) for line in file if line.strip()]
    data = np.array(values, dtype=int)
    # NOTE(review): -100 looks like an "ignore" label sentinel; it is remapped
    # to -1 here — confirm the meaning against the model code.
    data[data == -100] = -1
    return data.reshape((-1, 1))
# Convert a string with comma decimal separator to a float
def convert_float(value_str):
    """Parse *value_str* as a float, treating every ',' as a decimal point.

    Example: "1,5" -> 1.5. Raises ValueError when the normalized text is not
    a valid float literal.
    """
    dotted = value_str.translate({ord(','): '.'})
    return float(dotted)
# Find indicated time in the acc_data
def find_index(timestamps, date, begin_time, end_time):
    """Return the positions of the first timestamps matching the begin/end minute.

    The target strings are built as "<date> <HH:MM>:00.000", where one leading
    '0' is removed from *date* first (e.g. "05/03/2024" -> "5/03/2024").
    Matching is exact string equality against the items of *timestamps*.

    Returns:
        (begin_index, end_index); each is None when no item matches.
    """
    # NOTE(review): only the first character is stripped; a zero-padded month
    # ("5/03/2024") is kept as-is — confirm this matches the device's format.
    day = date[1:] if date.startswith('0') else date

    def stamp(clock):
        return day + " " + clock + ":00.000"

    wanted_begin = stamp(begin_time)
    wanted_end = stamp(end_time)

    hits = [None, None]
    # Single pass: record the first occurrence of each target, stop once both are known.
    for position, timestamp in enumerate(timestamps):
        if hits[0] is None and timestamp == wanted_begin:
            hits[0] = position
        if hits[1] is None and timestamp == wanted_end:
            hits[1] = position
        if None not in hits:
            break
    return hits[0], hits[1]
# Extra samples the model is fed on each side of the requested interval; the
# trimming below assumes it consumes exactly this many on each side.
_CONTEXT_SAMPLES = 121
# Sampling rate implied by the original 33.333333 ms-per-sample constant.
_SAMPLING_RATE_HZ = 30
# Number of CSV lines skipped before the data rows start.
_CSV_HEADER_LINES = 11


def _read_accelerometer_csv(path):
    """Read timestamped accelerometer rows from the CSV file at *path*.

    The first _CSV_HEADER_LINES lines are skipped (fewer is tolerated). A row
    is kept when its first cell has at least two whitespace-separated parts
    (date and time) and at least three value cells follow it; values are
    parsed with convert_float (comma decimal separator allowed).

    Returns:
        (timestamps, samples): parallel lists — timestamps[i] is the raw first
        cell of the i-th kept row and samples[i] its list of floats.
    """
    timestamps = []
    samples = []
    with open(path, 'r', newline='') as csvfile:
        csv_reader = csv.reader(csvfile)
        for _ in range(_CSV_HEADER_LINES):
            if next(csv_reader, None) is None:
                break
        for row in csv_reader:
            if not row or len(row[0].split()) < 2:
                continue
            values = [convert_float(val_str) for val_str in row[1:]]
            if len(values) < 3:
                continue
            # Append both together so an index found in `timestamps` is also
            # valid for `samples` (skipped rows must not shift one list only).
            timestamps.append(row[0])
            samples.append(values)
    return timestamps, samples


def _format_duration(n_samples):
    """Format *n_samples* at _SAMPLING_RATE_HZ as HH:MM:SS, truncating partial seconds."""
    # Integer arithmetic avoids float rounding (e.g. 1800 samples is exactly 60 s).
    total_seconds = int(n_samples) // _SAMPLING_RATE_HZ
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return "{:02}:{:02}:{:02}".format(hours, minutes, seconds)


def _build_figure(time_axis, acc, labels):
    """Plot the X/Y/Z columns of *acc* (top) and *labels* (bottom) against *time_axis*."""
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, row_heights=[0.6, 0.4],
                        specs=[[{"type": "scatter"}], [{"type": "scatter"}]])
    for column, trace_name in enumerate(('Acc X', 'Acc Y', 'Acc Z')):
        fig.add_trace(go.Scatter(x=time_axis, y=acc[:, column], mode='lines',
                                 name=trace_name, line=dict(width=0.75)), row=1, col=1)
    fig.add_trace(go.Scatter(x=time_axis, y=labels, mode='lines',
                             name='Predicted labels', line=dict(width=1)), row=2, col=1)
    fig.update_layout(
        title='Accelerometer Data with Annotated Labels',
        xaxis=dict(title='Time (milliseconds)'),
        yaxis=dict(title='Accelerometer Data'),
        yaxis2=dict(title='Predicted'),
        showlegend=True,
        height=600,
    )
    return fig


def process_csv(file, date, begin_time, end_time):
    """Run the activity model on one time interval of an accelerometer CSV.

    Args:
        file: the uploaded file — either an object with a ``.name`` path
            attribute or a plain path string; None when nothing was uploaded.
        date: recording date as DD/MM/YYYY.
        begin_time: interval start as HH:MM.
        end_time: interval end as HH:MM (inclusive).
        Begin and end are matched exactly against the CSV timestamps via find_index.

    Returns:
        (summary_text, plotly_figure) on success, or (error_message, None)
        when the input is missing, malformed or does not cover the interval.
    """
    if file is None:
        return "No file uploaded", None
    try:
        datetime.datetime.strptime(date, '%d/%m/%Y')  # validation only
        begin_time_t = datetime.datetime.strptime(begin_time, '%H:%M').time()
        end_time_t = datetime.datetime.strptime(end_time, '%H:%M').time()
    except ValueError:
        return "Invalid date or time format. Please use DD/MM/YYYY for date and HH:MM for time.", None
    # Both times refer to the same date, so an earlier end time can never be valid.
    if end_time_t < begin_time_t:
        return "End time must not be before begin time.", None

    # Older Gradio passes a tempfile wrapper (with .name); newer Gradio passes a path string.
    path = getattr(file, 'name', file)
    timestamps, samples = _read_accelerometer_csv(path)
    if not samples:
        return "No accelerometer data found in the file.", None
    acc_data = np.array(samples, dtype=float)

    begin_index, end_index = find_index(timestamps, date, begin_time, end_time)
    if end_index is None:
        return "End time not found in data. Please check the specified end time.", None
    if begin_index is None:
        return "Begin time not found in data. Please check the specified begin time.", None
    # Without full context on both sides the slice below would wrap around
    # (negative start) or be truncated, misaligning samples and labels.
    if begin_index < _CONTEXT_SAMPLES or end_index + _CONTEXT_SAMPLES >= len(acc_data):
        return (f"Not enough data around the selected interval: the model needs "
                f"{_CONTEXT_SAMPLES} samples before the begin time and after the end time."), None

    # Call the model_defining function from testing.py on the padded interval.
    window = acc_data[begin_index - _CONTEXT_SAMPLES:end_index + _CONTEXT_SAMPLES + 1, :]
    name_model = "S3_101_102_103_validation_epoch_10.pth"
    output_file = "predicted_labels.txt"
    model_defining(window, name_model, output_file)
    predicted_labels = load_data(output_file)

    # Unpadded interval, begin..end inclusive.
    acc_core = acc_data[begin_index:end_index + 1, :]
    time_axis = timestamps[begin_index:end_index + 1]
    # Presumably the model emits one label per unpadded sample; verify instead of crashing later.
    if len(predicted_labels) != len(acc_core):
        return "Model output length does not match the selected interval.", None
    labels = predicted_labels[:, 0]

    # NOTE(review): any non-zero label (including -1, remapped from -100 by
    # load_data) is counted as functional — confirm this is intended.
    n_total = len(labels)
    n_functional = int(np.count_nonzero(labels))
    n_non_functional = n_total - n_functional

    functional_pct = n_functional / n_total * 100
    non_functional_pct = n_non_functional / n_total * 100
    functional_time = _format_duration(n_functional)
    non_functional_time = _format_duration(n_non_functional)

    return_string = (
        f"Percentage of predicted functional activity: {functional_pct:.2f}%\n"
        f"Percentage of predicted non-functional activity: {non_functional_pct:.2f}%\n\n"
        f"Number of minutes of functional activity in predicted labels: {functional_time}\n"
        f"Number of minutes of non-functional activity in predicted labels: {non_functional_time}\n"
    )
    return return_string, _build_figure(time_axis, acc_core, labels)
# Gradio UI: inputs and buttons on the left, statistics text beside them, plot below.
# NOTE(review): the original layout nesting could not be read from the source
# (indentation was lost); confirm the Row/Column placement of the outputs.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown(
        """
# Functional Upper Limb Activity Recognition Model
Upload your csv file containing accelerometer data to obtain a prediction on the amount of functional activity of the upper limbs.
""")
    with gr.Row(equal_height=True):
        with gr.Column():
            input_file = gr.File(label="Upload CSV file")
            input_date = gr.Textbox(label="Date (DD/MM/YYYY)")
            input_begin_time = gr.Textbox(label="Begin Time (HH:MM)")
            input_end_time = gr.Textbox(label="End Time (HH:MM)")
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.Button("Clear", variant='secondary')
        output_text = gr.Textbox(label="Prediction statistics")
    output_plot = gr.Plot(label="CSV Plot")

    submit_btn.click(fn=process_csv,
                     inputs=[input_file, input_date, input_begin_time, input_end_time],
                     outputs=[output_text, output_plot])
    # One return value per output component: file, the three text inputs,
    # the statistics text and the plot (previously only 4 of the 6 were returned).
    clear_btn.click(fn=lambda: (None, "", "", "", "", None),
                    outputs=[input_file, input_date, input_begin_time, input_end_time,
                             output_text, output_plot])
demo.launch()