# Import necessary libraries
library(shiny)
library(caret)
library(readr)
library(catboost)
library(ggplot2)
library(gridExtra)
source("calculate_shap.R")
source("plot_shap.R")
# Deserialize the pre-trained model once, when the script is first evaluated;
# the server function below reads this object for every prediction.
model <- readRDS(file = "goat_behavior_model_caret.rds")
# Define UI for application
#
# Layout: a sidebar holding the TSV upload control and a main panel with five
# tabs ("About", "Results", "SHAP Summary", "SHAP Summary per class",
# "SHAP Dependency"). Output ids declared here are filled in by `server`.
#
# The "About" text is passed through HTML(), which emits the string verbatim,
# so paragraphs and the feature table need explicit markup to render as such.
# NOTE(review): the two "here" link targets were reconstructed from the
# default dataset URL used in `server` -- confirm they are the intended ones.
ui <- fluidPage(
  # App title ----
  titlePanel("Detecting Goat Behaviors"),
  # Sidebar layout with input and output definitions ----
  sidebarLayout(
    # Sidebar panel for inputs ----
    sidebarPanel(
      # Input: Select a file ----
      # One MIME type / extension per element; fileInput() joins them itself.
      fileInput(
        "file1", "Choose TSV File",
        accept = c(
          "text/tsv",
          "text/tab-separated-values",
          "text/plain",
          ".tsv"
        )
      )
    ),
    # Main panel for displaying outputs ----
    mainPanel(
      # Output: Tabset with data, confusion matrix, and download button
      tabsetPanel(
        id = "dataset",
        tabPanel(
          "About",
          HTML("
<p>The following model was part of the research article:
<em>Developing an Interpretable Machine Learning Model for the Detection of Mimosa Grazing in Goats</em></p>
<p>You can test the app using an example dataset available
<a href='https://raw.githubusercontent.com/harpomaxx/goat-behavior-model/main/data/split/dataset_b.tsv' target='_blank'>here</a>.
A dataset is already preloaded in the app for demonstration purposes.</p>
<p>In the last years, several machine learning approaches for detecting animal behaviors have been proposed.
However, despite their successful application, their complexity and lack of explainability have made them difficult to
apply to real-world scenarios. The article presents a machine-learning model for differentiating between grazing mimosa and other activities
(resting, walking, and grazing) in goats using sensor data. Boruta, an algorithm for selecting the most relevant features, and SHAP,
a technique for interpreting the decision of a machine learning model, are two fundamental components of the methodology used for creating the model.
The resulting model, a gradient boost algorithm with 15 selected features, proved to be extremely accurate in detecting Grazing activities.
The study demonstrates the fundamental role of model explainability in identifying model weaknesses and errors, thereby creating a path for future
improvements. In addition, the simplicity of the resulting model not only reduces computational complexity and processing time but also enhances
interpretability and facilitates the deployment of real-life scenarios.</p>
<p>This application allows users to test the pre-trained machine learning model that predicts goat behavior based on input sensor data.
The input data should be a tab-separated value (.tsv) file containing specific sensor data related to the goat's activity.
The application then generates predictions, provides a confusion matrix result, and offers the option to download the predictions.
In addition, you can explore the decisions of the model via SHAP analysis.</p>
<p>The key features expected in the dataset are:</p>
<table class='table table-striped table-condensed'>
<thead>
<tr><th>No</th><th>Feature</th><th>Definition</th></tr>
</thead>
<tbody>
<tr><td>1</td><td>Steps</td><td>Number of steps</td></tr>
<tr><td>2</td><td>HeadDown</td><td>% time with head down</td></tr>
<tr><td>3</td><td>Standing</td><td>% time Standing</td></tr>
<tr><td>4</td><td>Active</td><td>% time Active</td></tr>
<tr><td>5</td><td>MeanXY</td><td>Arithmetic mean between X and Y positions</td></tr>
<tr><td>6</td><td>Distance</td><td>Distance in meters</td></tr>
<tr><td>7</td><td>prev_steps1</td><td>Number of steps one step backward</td></tr>
<tr><td>8</td><td>X_Act</td><td>X position actuator</td></tr>
<tr><td>9</td><td>prev_Active1</td><td>% time Active one step backward</td></tr>
<tr><td>10</td><td>prev_Standing1</td><td>% time Standing one step backward</td></tr>
<tr><td>11</td><td>DFA123</td><td>Accumulative Euclidean distance from actual position to three positions forward</td></tr>
<tr><td>12</td><td>prev_headdown1</td><td>% time with head down one step backward</td></tr>
<tr><td>13</td><td>Lying</td><td>% time Lying</td></tr>
<tr><td>14</td><td>Y_Act</td><td>Y position actuator</td></tr>
<tr><td>15</td><td>DBA123</td><td>Accumulative Euclidean distance from actual position to three positions backward</td></tr>
</tbody>
</table>
<p>Experiments, datasets, source code and more
<a href='https://github.com/harpomaxx/goat-behavior-model' target='_blank'>here</a>.</p>
"
          )
        ),
        tabPanel(
          "Results",
          tableOutput("contents"),
          verbatimTextOutput("confusionMatText"),
          plotOutput("confusionMatPlot"),
          downloadButton("downloadData", "Download Predictions")
        ),
        tabPanel(
          "SHAP Summary",
          plotOutput("SHAPSummary")
        ),
        tabPanel(
          "SHAP Summary per class",
          plotOutput("SHAPSummaryperclass")
        ),
        tabPanel(
          "SHAP Dependency",
          plotOutput("SHAPDependency")
        )
      )
    )
  )
)
# Define server logic
#
# Server function for the app.
#
# Arguments:
#   input  - Shiny input list; only `input$file1` (the uploaded TSV) is read.
#   output - Shiny output list; fills `contents`, `downloadData`,
#            `confusionMatText`, `confusionMatPlot`, `SHAPSummary`,
#            `SHAPSummaryperclass` and `SHAPDependency`.
#
# The dataset, the model predictions, the confusion matrix and the SHAP values
# are each computed in a single reactive and shared by every output that needs
# them. This avoids re-reading the file (a network download in the default
# case) and re-running the model and `calculate_shap()` once per output, and
# it guarantees that the three SHAP tabs are drawn from the same SHAP values.
server <- function(input, output) {
  # Dataset shown when the user has not uploaded anything yet
  default_file_path <- "https://raw.githubusercontent.com/harpomaxx/goat-behavior-model/main/data/split/dataset_b.tsv"

  # Dataset currently in use: the uploaded file, or the default one.
  # `input$file1` is NULL until a file is uploaded; afterwards it is a data
  # frame with 'name', 'size', 'type' and 'datapath' columns.
  dataset <- reactive({
    file_path <- if (is.null(input$file1)) {
      default_file_path
    } else {
      input$file1$datapath
    }
    readr::read_delim(file_path, delim = "\t", progress = FALSE)
  })

  # Model predictions for the current dataset
  predictions <- reactive({
    predict(model, dataset())
  })

  # Confusion matrix of predictions against the `Activity` column
  confusion <- reactive({
    data <- dataset()
    # Without ground-truth labels caret fails with an unhelpful message;
    # show a readable one in the output instead.
    validate(
      need(
        "Activity" %in% names(data),
        "The dataset must contain an 'Activity' column to compute the confusion matrix."
      )
    )
    caret::confusionMatrix(
      data = predictions(),
      reference = as.factor(data$Activity),
      mode = "everything"
    )
  })

  # Inputs shared by the three SHAP tabs: the selected features plus `Anim`,
  # `Activity` and the predictions, and the SHAP values computed on them.
  shap_inputs <- reactive({
    selected_variables <- readr::read_delim(
      "selected_features.tsv",
      col_types = readr::cols(),
      delim = "\t"
    )
    new_dataset <-
      dataset() %>% select(selected_variables$variable, Anim, Activity)
    preds <- predictions()
    # Name the column explicitly so it is always called `predictions`
    new_dataset <- cbind(new_dataset, predictions = preds)
    shap_values <- calculate_shap(new_dataset, model, nsim = 30)
    list(dataset = new_dataset, shap = shap_values)
  })

  # Preview: first five rows of the dataset
  output$contents <- renderTable({
    head(dataset(), n = 5)
  })

  # Download function for predictions
  output$downloadData <- downloadHandler(
    filename = function() {
      paste0("predictions-", Sys.Date(), ".csv")
    },
    content = function(file) {
      preds <- predictions()
      # seq_along() stays valid for zero predictions, unlike 1:length()
      write.csv(
        data.frame(Index = seq_along(preds), Prediction = preds),
        file,
        row.names = FALSE
      )
    }
  )

  # Confusion matrix: overall statistics as text
  output$confusionMatText <- renderPrint({
    confusion()$overall
  })

  # Confusion matrix: heat map of counts
  output$confusionMatPlot <- renderPlot({
    confusion_table <- as.data.frame(as.table(confusion()$table))
    ggplot(confusion_table, aes(x = Reference, y = Prediction)) +
      # log1p() keeps zero-count cells finite so they take the low (white)
      # end of the gradient; log(0) is -Inf and fell outside the scale.
      geom_tile(aes(fill = log1p(Freq)), colour = "white") +
      geom_text(aes(label = sprintf("%1.0f", Freq)), vjust = 1) +
      scale_fill_gradient(low = "white", high = "steelblue") +
      theme_minimal() +
      theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
  })

  # SHAP summary over all classes
  output$SHAPSummary <- renderPlot({
    shap_summary_plot(as.data.frame(shap_inputs()$shap)) + xlim(0, 0.35)
  })

  # SHAP summary, one panel per activity class
  output$SHAPSummaryperclass <- renderPlot({
    shap_values <- shap_inputs()$shap
    pW <- shap_summary_plot_perclass(shap_values, class = "W", color = "#C77CFF") +
      xlab("Activity W") + xlim(0, 0.25)
    pGM <- shap_summary_plot_perclass(shap_values, class = "GM", color = "#7CAE00") +
      xlab("Activity GM") + xlim(0, 0.25)
    pG <- shap_summary_plot_perclass(shap_values, class = "G", color = "#F8766D") +
      xlab("Activity G") + xlim(0, 0.25)
    pR <- shap_summary_plot_perclass(shap_values, class = "R", color = "#00BFC4") +
      xlab("Activity R") + xlim(0, 0.25)
    grid.arrange(pW, pR, pG, pGM)
  })

  # SHAP dependency plots, stacked in a single column
  output$SHAPDependency <- renderPlot({
    shap <- shap_inputs()
    # Features currently plotted. The earlier version of this block kept
    # disabled calls for prev_steps1, prev_headdown1, prev_Active1,
    # prev_Standing1, X_Act, Y_Act, DBA123 and DFA123; add names here to
    # plot them as well.
    dependency_features <- c("Steps", "%HeadDown", "Active", "Standing")
    plots <- lapply(
      dependency_features,
      dependency_plot,
      dataset = shap$dataset,
      shap = shap$shap
    )
    do.call(grid.arrange, c(plots, ncol = 1))
  })
}
# Bundle the UI definition and the server function into the app object
shinyApp(ui, server)