Spaces:
Sleeping
Sleeping
add SHAP values analysis
Browse files- Dockerfile +3 -0
- app.R +106 -1
- calculate_shap.R +328 -0
- plot_shap.R +299 -0
- selected_features.tsv +16 -0
Dockerfile
CHANGED
@@ -7,7 +7,10 @@ RUN install2.r --error \
|
|
7 |
ggExtra \
|
8 |
readr \
|
9 |
caret \
|
|
|
10 |
ggplot2 \
|
|
|
|
|
11 |
shiny
|
12 |
|
13 |
RUN install2.r --error \
|
|
|
7 |
ggExtra \
|
8 |
readr \
|
9 |
caret \
|
10 |
+
fastshap \
|
11 |
ggplot2 \
|
12 |
+
ggExtra \
|
13 |
+
forcats \
|
14 |
shiny
|
15 |
|
16 |
RUN install2.r --error \
|
app.R
CHANGED
@@ -5,6 +5,9 @@ library(readr)
|
|
5 |
library(catboost)
|
6 |
library(ggplot2)
|
7 |
|
|
|
|
|
|
|
8 |
# Load the pre-trained model
|
9 |
model <- readRDS("goat_behavior_model_caret.rds")
|
10 |
|
@@ -149,7 +152,18 @@ ui <- fluidPage(
|
|
149 |
tableOutput("contents"),
|
150 |
verbatimTextOutput("confusionMatText"),
|
151 |
plotOutput("confusionMatPlot"),
|
152 |
-
downloadButton("downloadData", "Download Predictions"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
)
|
154 |
)
|
155 |
)
|
@@ -229,6 +243,97 @@ server <- function(input, output) {
|
|
229 |
theme_minimal() +
|
230 |
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
|
231 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
}
|
233 |
|
234 |
# Create a Shiny app object
|
|
|
5 |
library(catboost)
|
6 |
library(ggplot2)
|
7 |
|
8 |
+
source("calculate_shap.R")
|
9 |
+
source("plot_shap.R")
|
10 |
+
|
11 |
# Load the pre-trained model
|
12 |
model <- readRDS("goat_behavior_model_caret.rds")
|
13 |
|
|
|
152 |
tableOutput("contents"),
|
153 |
verbatimTextOutput("confusionMatText"),
|
154 |
plotOutput("confusionMatPlot"),
|
155 |
+
downloadButton("downloadData", "Download Predictions")),
|
156 |
+
|
157 |
+
tabPanel("SHAP Summary",
|
158 |
+
plotOutput("SHAPSummary")),
|
159 |
+
|
160 |
+
tabPanel("SHAP Summary per class",
|
161 |
+
plotOutput("SHAPSummaryperclass")),
|
162 |
+
|
163 |
+
tabPanel("SHAP Dependency",
|
164 |
+
plotOutput("SHAPDependency"))
|
165 |
+
|
166 |
+
|
167 |
)
|
168 |
)
|
169 |
)
|
|
|
243 |
theme_minimal() +
|
244 |
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
|
245 |
})
|
246 |
+
|
247 |
+
output$SHAPSummary <- renderPlot({
|
248 |
+
|
249 |
+
if (is.null(input$file1))
|
250 |
+
return(NULL)
|
251 |
+
|
252 |
+
inFile <- input$file1
|
253 |
+
dataset <- readr::read_delim(inFile$datapath,delim='\t')
|
254 |
+
predictions <- predict(model, dataset)
|
255 |
+
selected_variables <-
|
256 |
+
readr::read_delim(
|
257 |
+
"selected_features.tsv",
|
258 |
+
col_types = cols(),
|
259 |
+
delim = '\t'
|
260 |
+
)
|
261 |
+
new_dataset <-
|
262 |
+
dataset %>% select(selected_variables$variable, Anim, Activity)
|
263 |
+
new_dataset <- cbind(new_dataset, predictions)
|
264 |
+
|
265 |
+
shap_values <- calculate_shap(new_dataset, model, nsim = 30)
|
266 |
+
pall<-shap_summary_plot(shap_values)
|
267 |
+
pall+xlim(0,0.35)
|
268 |
+
})
|
269 |
+
|
270 |
+
output$SHAPSummaryperclass <- renderPlot({
|
271 |
+
|
272 |
+
if (is.null(input$file1))
|
273 |
+
return(NULL)
|
274 |
+
|
275 |
+
inFile <- input$file1
|
276 |
+
dataset <- readr::read_delim(inFile$datapath,delim='\t')
|
277 |
+
predictions <- predict(model, dataset)
|
278 |
+
selected_variables <-
|
279 |
+
readr::read_delim(
|
280 |
+
"selected_features.tsv",
|
281 |
+
col_types = cols(),
|
282 |
+
delim = '\t'
|
283 |
+
)
|
284 |
+
new_dataset <-
|
285 |
+
dataset %>% select(selected_variables$variable, Anim, Activity)
|
286 |
+
new_dataset <- cbind(new_dataset, predictions)
|
287 |
+
|
288 |
+
shap_values <- calculate_shap(new_dataset, model, nsim = 30)
|
289 |
+
|
290 |
+
pW<-shap_summary_plot_perclass(shap_values, class= "W",color="#C77CFF")+xlab("Activity W")+xlim(0,0.25)
|
291 |
+
pGM<-shap_summary_plot_perclass(shap_values, class= "GM",color="#7CAE00")+xlab("Activity GM")+xlim(0,0.25)
|
292 |
+
pG<-shap_summary_plot_perclass(shap_values, class= "G",color="#F8766D")+xlab("Activity G")+xlim(0,0.25)
|
293 |
+
pR<-shap_summary_plot_perclass(shap_values, class= "R",color="#00BFC4")+xlab("Activity R")+xlim(0,0.25)
|
294 |
+
|
295 |
+
grid.arrange(pW,pR,pG,pGM)
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
})
|
300 |
+
output$SHAPDependency <- renderPlot({
|
301 |
+
|
302 |
+
if (is.null(input$file1))
|
303 |
+
return(NULL)
|
304 |
+
|
305 |
+
inFile <- input$file1
|
306 |
+
dataset <- readr::read_delim(inFile$datapath,delim='\t')
|
307 |
+
predictions <- predict(model, dataset)
|
308 |
+
selected_variables <-
|
309 |
+
readr::read_delim(
|
310 |
+
"selected_features.tsv",
|
311 |
+
col_types = cols(),
|
312 |
+
delim = '\t'
|
313 |
+
)
|
314 |
+
new_dataset <-
|
315 |
+
dataset %>% select(selected_variables$variable, Anim, Activity)
|
316 |
+
new_dataset <- cbind(new_dataset, predictions)
|
317 |
+
|
318 |
+
shap_values <- calculate_shap(new_dataset, model, nsim = 30)
|
319 |
+
|
320 |
+
li<-list()
|
321 |
+
li[[1]]<-dependency_plot("Steps",dataset = new_dataset,shap=shap_values)
|
322 |
+
#li[[2]]<-dependency_plot("prev_steps1",dataset = new_dataset,shap=shap_values)
|
323 |
+
li[[2]]<-dependency_plot("%HeadDown",dataset = new_dataset,shap=shap_values)
|
324 |
+
#li[[4]]<-dependency_plot("prev_headdown1",dataset = new_dataset,shap=shap_values)
|
325 |
+
li[[3]]<-dependency_plot("Active",dataset = new_dataset,shap=shap_values)
|
326 |
+
#li[[6]]<-dependency_plot("prev_Active1",dataset = new_dataset,shap=shap_values)
|
327 |
+
li[[4]]<-dependency_plot("Standing",dataset = new_dataset,shap=shap_values)
|
328 |
+
#li[[8]]<-dependency_plot("prev_Standing1",dataset = new_dataset,shap=shap_values)
|
329 |
+
#li[[9]]<-dependency_plot("X_Act",dataset = new_dataset, shap=shap_values)
|
330 |
+
#li[[10]]<-dependency_plot("Y_Act",dataset = new_dataset, shap=shap_values)
|
331 |
+
#li[[11]]<-dependency_plot("DBA123",dataset = new_dataset, shap=shap_values)
|
332 |
+
#li[[12]]<-dependency_plot("DFA123",dataset = new_dataset, shap=shap_values)
|
333 |
+
do.call(grid.arrange, c(li, ncol = 1))
|
334 |
+
|
335 |
+
|
336 |
+
})
|
337 |
}
|
338 |
|
339 |
# Create a Shiny app object
|
calculate_shap.R
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
suppressPackageStartupMessages(library(dplyr))
|
2 |
+
suppressPackageStartupMessages(library(fastshap)) # for fast (approximate) Shapley values
|
3 |
+
suppressPackageStartupMessages(library(caret))
|
4 |
+
suppressPackageStartupMessages(library(doMC))
|
5 |
+
|
6 |
+
registerDoMC(cores = 10)
|
7 |
+
|
8 |
+
|
9 |
+
p_function_G <-
|
10 |
+
function(object, newdata)
|
11 |
+
caret::predict.train(object, newdata = newdata, type = "prob")[, "G"]
|
12 |
+
p_function_GM <-
|
13 |
+
function(object, newdata)
|
14 |
+
caret::predict.train(object, newdata = newdata, type = "prob")[, "GM"]
|
15 |
+
p_function_R <-
|
16 |
+
function(object, newdata)
|
17 |
+
caret::predict.train(object, newdata = newdata, type = "prob")[, "R"]
|
18 |
+
p_function_W <-
|
19 |
+
function(object, newdata)
|
20 |
+
caret::predict.train(object, newdata = newdata, type = "prob")[, "W"]
|
21 |
+
|
22 |
+
# DEPRECATED
|
23 |
+
calculate_shap_deprecated <- function(dataset,model,nsim=10) {
|
24 |
+
# library(doParallel)
|
25 |
+
# registerDoParallel(8)
|
26 |
+
|
27 |
+
trainset <- dataset %>% na.omit() %>%
|
28 |
+
as.data.frame()
|
29 |
+
trainset_y <- dataset %>%
|
30 |
+
select(Activity) %>%
|
31 |
+
na.omit() %>%
|
32 |
+
unlist() %>%
|
33 |
+
unname()
|
34 |
+
trainset <- trainset %>% select(-Activity)
|
35 |
+
trainset_G <- trainset[which(trainset_y == "G"), ]
|
36 |
+
trainset_GM <- trainset[which(trainset_y == "GM"), ]
|
37 |
+
trainset_R <- trainset[which(trainset_y == "R"), ]
|
38 |
+
trainset_W <- trainset[which(trainset_y == "W"), ]
|
39 |
+
|
40 |
+
|
41 |
+
# Compute fast (approximate) Shapley values using 50 Monte Carlo repetitions
|
42 |
+
message(" - Calculating SHAP values for class G")
|
43 |
+
shap_values_G <-
|
44 |
+
fastshap::explain(
|
45 |
+
model,
|
46 |
+
X = trainset,
|
47 |
+
pred_wrapper = p_function_G,
|
48 |
+
nsim = nsim,
|
49 |
+
newdata = trainset_G,
|
50 |
+
.parallel = TRUE
|
51 |
+
)
|
52 |
+
message(" - Calculating SHAP values for class GM")
|
53 |
+
shap_values_GM <-
|
54 |
+
fastshap::explain(
|
55 |
+
model,
|
56 |
+
X = trainset,
|
57 |
+
pred_wrapper = p_function_GM,
|
58 |
+
nsim = nsim,
|
59 |
+
newdata = trainset_GM,
|
60 |
+
.parallel = TRUE
|
61 |
+
)
|
62 |
+
message(" - Calculating SHAP values for class R")
|
63 |
+
shap_values_R <-
|
64 |
+
fastshap::explain(
|
65 |
+
model,
|
66 |
+
X = trainset,
|
67 |
+
pred_wrapper = p_function_R,
|
68 |
+
nsim = nsim,
|
69 |
+
newdata = trainset_R,
|
70 |
+
.parallel = TRUE
|
71 |
+
)
|
72 |
+
message(" - Calculating SHAP values for class W")
|
73 |
+
shap_values_W <-
|
74 |
+
fastshap::explain(
|
75 |
+
model,
|
76 |
+
X = trainset,
|
77 |
+
pred_wrapper = p_function_W,
|
78 |
+
nsim = nsim,
|
79 |
+
newdata = trainset_W,
|
80 |
+
.parallel = TRUE
|
81 |
+
# adjust = TRUE
|
82 |
+
)
|
83 |
+
|
84 |
+
shap_values_GM$class<-"GM"
|
85 |
+
shap_values_G$class<-"G"
|
86 |
+
shap_values_R$class<-"R"
|
87 |
+
shap_values_W$class<-"W"
|
88 |
+
|
89 |
+
shap_values<-rbind(shap_values_G,
|
90 |
+
shap_values_GM,
|
91 |
+
shap_values_R,
|
92 |
+
shap_values_W)
|
93 |
+
shap_values
|
94 |
+
}
|
95 |
+
|
96 |
+
|
97 |
+
#' A new function for calcualting SHAP values
|
98 |
+
#' the function returns a dataframe with SHAP values in the same
|
99 |
+
#' order of the original dataset.
|
100 |
+
#'
|
101 |
+
#' SHAP value dataframe also contains information about Animal and
|
102 |
+
#' the prediction of the model. Notice that SHAP are calculated considering
|
103 |
+
#' the class (ground truth) and not the prediction. The prediction column is only
|
104 |
+
#' used for filtering ana analysis. The function `calculate_shapp_class()` can be
|
105 |
+
#' used for calculating SHAP values on prediction
|
106 |
+
#'
|
107 |
+
#' @param dataset a dataset used for calcuating SHAP. The dataset is used for
|
108 |
+
#' permutation during SHAP calculation and also each class is filtered and SHAP
|
109 |
+
#' value for each class is calculated.
|
110 |
+
#' @param model a model
|
111 |
+
#' @param nsim number of monte carlo simulation
|
112 |
+
#'
|
113 |
+
#' @return
|
114 |
+
#' @export
|
115 |
+
#'
|
116 |
+
#' @examples
|
117 |
+
calculate_shap <- function(dataset,model,nsim=10) {
|
118 |
+
trainset <- dataset %>% na.omit() %>%
|
119 |
+
as.data.frame()
|
120 |
+
trainset_y <- dataset %>%
|
121 |
+
select(Activity) %>%
|
122 |
+
na.omit() %>%
|
123 |
+
unlist() %>%
|
124 |
+
unname()
|
125 |
+
## Create an ID for maintaining the order
|
126 |
+
trainset <- cbind(id=seq(1:nrow(trainset)), trainset)
|
127 |
+
trainset <- trainset %>% select(-Activity)
|
128 |
+
|
129 |
+
trainset_G <- trainset[which(trainset_y == "G"), ]
|
130 |
+
trainset_GM <- trainset[which(trainset_y == "GM"), ]
|
131 |
+
trainset_R <- trainset[which(trainset_y == "R"), ]
|
132 |
+
trainset_W <- trainset[which(trainset_y == "W"), ]
|
133 |
+
|
134 |
+
id <- c(trainset_G$id,
|
135 |
+
trainset_GM$id,
|
136 |
+
trainset_R$id,
|
137 |
+
trainset_W$id)
|
138 |
+
trainset <- trainset %>% select(-id)
|
139 |
+
trainset_G <- trainset_G %>% select(-id)
|
140 |
+
trainset_GM <- trainset_GM %>% select(-id)
|
141 |
+
trainset_R <- trainset_R %>% select(-id)
|
142 |
+
trainset_W <- trainset_W %>% select(-id)
|
143 |
+
|
144 |
+
Anim <- c(trainset_G$Anim,
|
145 |
+
trainset_GM$Anim,
|
146 |
+
trainset_R$Anim,
|
147 |
+
trainset_W$Anim)
|
148 |
+
trainset <- trainset %>% select(-Anim)
|
149 |
+
trainset_G <- trainset_G %>% select(-Anim)
|
150 |
+
trainset_GM <- trainset_GM %>% select(-Anim)
|
151 |
+
trainset_R <- trainset_R %>% select(-Anim)
|
152 |
+
trainset_W <- trainset_W %>% select(-Anim)
|
153 |
+
|
154 |
+
predictions <- c(trainset_G$predictions,
|
155 |
+
trainset_GM$predictions,
|
156 |
+
trainset_R$predictions,
|
157 |
+
trainset_W$predictions)
|
158 |
+
trainset <- trainset %>% select(-predictions)
|
159 |
+
trainset_G <- trainset_G %>% select(-predictions)
|
160 |
+
trainset_GM <- trainset_GM %>% select(-predictions)
|
161 |
+
trainset_R <- trainset_R %>% select(-predictions)
|
162 |
+
trainset_W <- trainset_W %>% select(-predictions)
|
163 |
+
|
164 |
+
# Compute fast (approximate) Shapley values using 50 Monte Carlo repetitions
|
165 |
+
message(" - Calculating SHAP values for class G")
|
166 |
+
shap_values_G <-
|
167 |
+
fastshap::explain(
|
168 |
+
model,
|
169 |
+
X = trainset,
|
170 |
+
pred_wrapper = p_function_G,
|
171 |
+
nsim = nsim,
|
172 |
+
newdata = trainset_G,
|
173 |
+
.parallel = TRUE
|
174 |
+
)
|
175 |
+
message(" - Calculating SHAP values for class GM")
|
176 |
+
shap_values_GM <-
|
177 |
+
fastshap::explain(
|
178 |
+
model,
|
179 |
+
X = trainset,
|
180 |
+
pred_wrapper = p_function_GM,
|
181 |
+
nsim = nsim,
|
182 |
+
newdata = trainset_GM,
|
183 |
+
.parallel = TRUE
|
184 |
+
)
|
185 |
+
message(" - Calculating SHAP values for class R")
|
186 |
+
shap_values_R <-
|
187 |
+
fastshap::explain(
|
188 |
+
model,
|
189 |
+
X = trainset,
|
190 |
+
pred_wrapper = p_function_R,
|
191 |
+
nsim = nsim,
|
192 |
+
newdata = trainset_R,
|
193 |
+
.parallel = TRUE
|
194 |
+
)
|
195 |
+
message(" - Calculating SHAP values for class W")
|
196 |
+
shap_values_W <-
|
197 |
+
fastshap::explain(
|
198 |
+
model,
|
199 |
+
X = trainset,
|
200 |
+
pred_wrapper = p_function_W,
|
201 |
+
nsim = nsim,
|
202 |
+
newdata = trainset_W,
|
203 |
+
.parallel = TRUE
|
204 |
+
# adjust = TRUE
|
205 |
+
)
|
206 |
+
|
207 |
+
shap_values_G$class<-"G"
|
208 |
+
shap_values_GM$class<-"GM"
|
209 |
+
shap_values_R$class<-"R"
|
210 |
+
shap_values_W$class<-"W"
|
211 |
+
|
212 |
+
shap_values<-rbind(shap_values_G,
|
213 |
+
shap_values_GM,
|
214 |
+
shap_values_R,
|
215 |
+
shap_values_W)
|
216 |
+
|
217 |
+
shap_values <- shap_values %>% tibble::add_column(Anim)
|
218 |
+
shap_values <- shap_values %>% tibble::add_column(predictions)
|
219 |
+
#shap_values <-shap_values %>% tibble::add_column(id)
|
220 |
+
shap_values[order(id),]
|
221 |
+
}
|
222 |
+
|
223 |
+
#' Calculate SHAP values for a given PREDICTED class
|
224 |
+
#'
|
225 |
+
#' @param dataset the dataset used for permutation during SHAP calculation
|
226 |
+
#' @param new_data the new data we want to calculate SHAP
|
227 |
+
#' @param model the model used for explanation
|
228 |
+
#' @param nsim the number of Monte Carlos Simulations
|
229 |
+
#' @param function_class a wrapper function to obtain only a particular class
|
230 |
+
#' @param class_name the name of the class
|
231 |
+
#'
|
232 |
+
#' @return
|
233 |
+
#' @export
|
234 |
+
#'
|
235 |
+
#' @examples
|
236 |
+
#'
|
237 |
+
#' # Calculate the SHAP values for class G on new data
|
238 |
+
#' shap_values_G <- calculate_shap_class(
|
239 |
+
#' dataset,
|
240 |
+
#' new_data = newdata,
|
241 |
+
#' model = goat_model
|
242 |
+
#' nsim = 100,
|
243 |
+
#' function_class = p_function_G,
|
244 |
+
#' class_name = "G")
|
245 |
+
#'
|
246 |
+
#'
|
247 |
+
calculate_shap_class <- function(dataset, new_data, model,nsim=10,
|
248 |
+
function_class, class_name = "G") {
|
249 |
+
trainset <- dataset %>% na.omit() %>%
|
250 |
+
as.data.frame()
|
251 |
+
trainset_y <- dataset %>%
|
252 |
+
select(predictions) %>%
|
253 |
+
na.omit() %>%
|
254 |
+
unlist() %>%
|
255 |
+
unname()
|
256 |
+
|
257 |
+
trainset<- trainset %>%select (-Activity,-predictions,-Anim)
|
258 |
+
new_data_class <- new_data
|
259 |
+
|
260 |
+
Anim <- new_data_class$Anim
|
261 |
+
new_data_class <- new_data_class %>% select(-Anim)
|
262 |
+
|
263 |
+
Activity <- new_data_class$Activity
|
264 |
+
new_data_class <- new_data_class %>% select(-Activity)
|
265 |
+
|
266 |
+
predictions <- new_data_class$predictions
|
267 |
+
new_data_class <- new_data_class %>% select(-predictions)
|
268 |
+
|
269 |
+
# Compute fast (approximate) Shapley values using 50 Monte Carlo repetitions
|
270 |
+
message(" - Calculating SHAP values for class ",class_name)
|
271 |
+
shap_values_class <-
|
272 |
+
fastshap::explain(
|
273 |
+
model,
|
274 |
+
X = trainset,
|
275 |
+
pred_wrapper = function_class,
|
276 |
+
nsim = nsim,
|
277 |
+
newdata = new_data_class,
|
278 |
+
.parallel = TRUE
|
279 |
+
)
|
280 |
+
|
281 |
+
shap_values_class$class<-Activity
|
282 |
+
shap_values<-shap_values_class
|
283 |
+
|
284 |
+
shap_values <- shap_values %>% tibble::add_column(Anim)
|
285 |
+
shap_values <- shap_values %>% tibble::add_column(predictions)
|
286 |
+
shap_values
|
287 |
+
}
|
288 |
+
|
289 |
+
shap_summary_plot<-function(shap_values){
|
290 |
+
summary_plot <-
|
291 |
+
shap_values %>% reshape2::melt() %>% group_by(class, variable) %>%
|
292 |
+
summarise(mean = mean(abs(value))) %>%
|
293 |
+
arrange(desc(mean)) %>%
|
294 |
+
ggplot() +
|
295 |
+
ggdark::dark_theme_classic() +
|
296 |
+
geom_col(aes(
|
297 |
+
y = variable,
|
298 |
+
x = mean,
|
299 |
+
group = class,
|
300 |
+
fill = class
|
301 |
+
), position = "stack") +
|
302 |
+
xlab("Mean(|Shap Value|) Average impact on model output magnitude")
|
303 |
+
summary_plot
|
304 |
+
|
305 |
+
}
|
306 |
+
|
307 |
+
shap_beeswarm_plot<-function(shap_values,dataset){
|
308 |
+
|
309 |
+
shap_values <- shap_values %>% reshape2::melt()
|
310 |
+
dataset<-dataset %>% mutate(class=Activity) %>% select(-Activity) %>%
|
311 |
+
reshape2::melt() %>% group_by(variable) %>%
|
312 |
+
mutate(value_scale=range01(value))
|
313 |
+
|
314 |
+
beeswarm_plot<-cbind(shap_values, feature_value=dataset$value_scale) %>%
|
315 |
+
# filter(class=="GM") %>%
|
316 |
+
ggplot()+
|
317 |
+
facet_wrap(~class)+
|
318 |
+
#ggdark::dark_theme_bw()+
|
319 |
+
theme_classic()+
|
320 |
+
geom_hline(yintercept=0,
|
321 |
+
color = "red", size=0.5)+
|
322 |
+
ggforce::geom_sina(aes(x=variable,y=value,color=feature_value),size=0.5,bins=4,alpha=0.9,shape=15)+
|
323 |
+
scale_colour_gradient(low = "yellow", high = "red", na.value = NA)+
|
324 |
+
scale_colour_gradient(low = "skyblue", high = "orange", na.value = NA)+
|
325 |
+
xlab("Feature")+ylab("SHAP value")+
|
326 |
+
theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
|
327 |
+
beeswarm_plot
|
328 |
+
}
|
plot_shap.R
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
suppressPackageStartupMessages(library(dplyr))
|
2 |
+
suppressPackageStartupMessages(library(ggplot2))
|
3 |
+
suppressPackageStartupMessages(library(ggExtra))
|
4 |
+
suppressPackageStartupMessages(library(forcats))
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
range01 <- function(x){(x-min(x))/(max(x)-min(x))}
|
9 |
+
|
10 |
+
shap_summary_plot<-function(shap_values){
|
11 |
+
summary_plot <-
|
12 |
+
shap_values %>% reshape2::melt() %>% group_by(class, variable) %>%
|
13 |
+
summarise(mean = mean(abs(value))) %>%
|
14 |
+
arrange(desc(mean)) %>%
|
15 |
+
ggplot() +
|
16 |
+
# ggdark::dark_theme_classic() +
|
17 |
+
theme_classic()+
|
18 |
+
geom_col(aes(
|
19 |
+
y = variable,
|
20 |
+
x = mean,
|
21 |
+
group = class,
|
22 |
+
fill = class
|
23 |
+
), position = "stack") +
|
24 |
+
ylab("Feature")+
|
25 |
+
xlab("Mean(|Shap Value|) Average impact on model output magnitude per activity.")+
|
26 |
+
guides(fill=guide_legend(title="Activity"))
|
27 |
+
summary_plot
|
28 |
+
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
shap_summary_plot_perclass<-function(shap_values, class="G",color="#F8766D"){
|
33 |
+
shap_values <-shap_values %>% as.data.frame() %>% filter(class == {{class}} )
|
34 |
+
summary_plot <-
|
35 |
+
shap_values %>% reshape2::melt() %>% group_by(variable) %>%
|
36 |
+
summarise(mean = mean(abs(value))) %>%
|
37 |
+
ggplot() +
|
38 |
+
theme_classic()+
|
39 |
+
geom_col(aes(
|
40 |
+
x = mean,
|
41 |
+
y = fct_reorder(variable,mean)
|
42 |
+
),
|
43 |
+
fill = color
|
44 |
+
) +
|
45 |
+
ylab("Feature")+
|
46 |
+
xlab(paste0("Mean(|Shap Value|) Average impact on model output magnitude for activity ", class))+
|
47 |
+
guides(fill=guide_legend(title="Activity"))
|
48 |
+
summary_plot
|
49 |
+
|
50 |
+
}
|
51 |
+
|
52 |
+
|
53 |
+
shap_beeswarm_plot<-function(shap_values,dataset){
|
54 |
+
|
55 |
+
shap_values <- shap_values %>% reshape2::melt()
|
56 |
+
dataset<-dataset %>% mutate(class=Activity) %>% select(-Activity) %>%
|
57 |
+
reshape2::melt() %>% group_by(variable) %>%
|
58 |
+
mutate(value_scale=range01(value))
|
59 |
+
|
60 |
+
beeswarm_plot<-cbind(shap_values, feature_value=dataset$value_scale) %>% # filter(class=="GM") %>%
|
61 |
+
ggplot()+
|
62 |
+
facet_wrap(~class)+
|
63 |
+
#ggdark::dark_theme_bw()+
|
64 |
+
theme_classic()+
|
65 |
+
geom_hline(yintercept=0,
|
66 |
+
color = "red", size=0.5)+
|
67 |
+
ggforce::geom_sina(aes(x=variable,y=value,fill=feature_value),color="black", size=2.4,bins=4,alpha=0.9,shape=22)+
|
68 |
+
scale_fill_gradient(low = "yellow", high = "red", na.value = NA)+
|
69 |
+
scale_fill_gradient(low = "skyblue", high = "orange", na.value = NA)+
|
70 |
+
xlab("Feature")+ylab("SHAP value")+
|
71 |
+
theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
|
72 |
+
beeswarm_plot
|
73 |
+
|
74 |
+
|
75 |
+
}
|
76 |
+
|
77 |
+
|
78 |
+
#' Dependency plot for a particular feature. The plot considers
|
79 |
+
#' activities and FP/TP
|
80 |
+
#'
|
81 |
+
#' @param feature a particular feature to calculate
|
82 |
+
#' @param dataset a dataset with goat information
|
83 |
+
#' @param shap a shap value dataset for each feature.
|
84 |
+
#'
|
85 |
+
#' @return a dependency plot for each activity considering the selected feature
|
86 |
+
#' @export ggplot object
|
87 |
+
#'
|
88 |
+
#' @examples
|
89 |
+
#'
|
90 |
+
#' dataset <-
|
91 |
+
#' readr::read_delim("data/split/seba-caprino_loocv.tsv",
|
92 |
+
#' delim = '\t')
|
93 |
+
#' selected_variables <-
|
94 |
+
#' readr::read_delim(
|
95 |
+
#' "data/topnfeatures/seba-caprino_selected_features.tsv",
|
96 |
+
#' col_types = cols(),
|
97 |
+
#' delim = '\t'
|
98 |
+
#' )
|
99 |
+
#' dataset <-
|
100 |
+
#' dataset %>% select(selected_variables$variable,
|
101 |
+
#' Anim,
|
102 |
+
#' Activity)
|
103 |
+
#' goat_model <- readRDS("models/boost/seba-caprino_model.rds")
|
104 |
+
#' shap_values <- calculate_shap(dataset,
|
105 |
+
#' model = goat_model,
|
106 |
+
#' nsim = 30)
|
107 |
+
#' dependency_plot_full(feature = "Steps",
|
108 |
+
#' dataset = dataset,
|
109 |
+
#' shap = shap_values)
|
110 |
+
|
111 |
+
dependency_plot <- function(feature, dataset, shap) {
|
112 |
+
newdata <- dataset %>% mutate({{ feature }} := range01(!!sym(feature)))
|
113 |
+
#activities <- c("G", "GM", "W", "R")
|
114 |
+
activities<-dataset %>% pull(Activity) %>% unique()
|
115 |
+
plots <- list()
|
116 |
+
for (activity in activities) {
|
117 |
+
s <- shap[which(shap$class == activity), 1:18]
|
118 |
+
x <- newdata[which(newdata$Activity == activity), ]
|
119 |
+
data <- cbind(
|
120 |
+
shap = (s %>% as.data.frame %>% select(feature)),
|
121 |
+
feature = (x %>% select(feature)),
|
122 |
+
tp = x %>% mutate(tp = ifelse(Activity == predictions, "TP", "FP")) %>%
|
123 |
+
pull(tp)
|
124 |
+
)
|
125 |
+
names(data) <- c("shap", "feature", "tp")
|
126 |
+
p <- ggplot(data, aes(x = feature)) +
|
127 |
+
geom_point(aes(y = shap, color = tp), alpha = 0.3, size = 0.8) +
|
128 |
+
geom_smooth(aes(y = shap),
|
129 |
+
se = FALSE,
|
130 |
+
size = 0.5,
|
131 |
+
linetype = "dashed") +
|
132 |
+
geom_hline(
|
133 |
+
yintercept = 0,
|
134 |
+
color = 'red',
|
135 |
+
size = 0.5,
|
136 |
+
alpha = 0.5
|
137 |
+
) +
|
138 |
+
xlab(feature) +
|
139 |
+
labs(title = paste0("Activity ", activity)) +
|
140 |
+
ylab("SHAP Value") +
|
141 |
+
ylim(-0.1, 0.4) +
|
142 |
+
xlim(0, 1) +
|
143 |
+
theme_light() +
|
144 |
+
theme(legend.position = 'none')
|
145 |
+
|
146 |
+
p1 <-
|
147 |
+
ggMarginal(
|
148 |
+
p,
|
149 |
+
type = "histogram",
|
150 |
+
fill = 'gray',
|
151 |
+
color = 'white',
|
152 |
+
size = 10,
|
153 |
+
xparams = list(bins = 25),
|
154 |
+
yparams = list(bins = 15)
|
155 |
+
) #,margins='x')
|
156 |
+
plots[[activity]] <- p1
|
157 |
+
}
|
158 |
+
#plots
|
159 |
+
do.call(grid.arrange, c(plots, ncol = 4))
|
160 |
+
}
|
161 |
+
|
162 |
+
|
163 |
+
#' Dependency plot for a particular feature on a particular animal.
|
164 |
+
#' The plot considers activities and FP/TP
|
165 |
+
#'
|
166 |
+
#' @param feature a particular feature to calculate
|
167 |
+
#' @param dataset a dataset with goat information
|
168 |
+
#' @param shap a shap value dataset for each feature.
|
169 |
+
#' @param anim the id of the animal
|
170 |
+
#' @return a dependency plot for each activity considering the selected feature
|
171 |
+
#' @export ggplot object
|
172 |
+
#'
|
173 |
+
#' @examples
|
174 |
+
#'
|
175 |
+
#' dataset <-
|
176 |
+
#' readr::read_delim("data/split/seba-caprino_loocv.tsv",
|
177 |
+
#' delim = '\t')
|
178 |
+
#' selected_variables <-
|
179 |
+
#' readr::read_delim(
|
180 |
+
#' "data/topnfeatures/seba-caprino_selected_features.tsv",
|
181 |
+
#' col_types = cols(),
|
182 |
+
#' delim = '\t'
|
183 |
+
#' )
|
184 |
+
#' dataset <-
|
185 |
+
#' dataset %>% select(selected_variables$variable,
|
186 |
+
#' Anim,
|
187 |
+
#' Activity)
|
188 |
+
#' goat_model <- readRDS("models/boost/seba-caprino_model.rds")
|
189 |
+
#' shap_values <- calculate_shap(dataset,
|
190 |
+
#' model = goat_model,
|
191 |
+
#' nsim = 30)
|
192 |
+
#' dependency_plot_anim(feature = "Steps",
|
193 |
+
#' dataset = dataset,
|
194 |
+
#' shap = shap_values,
|
195 |
+
#' anim = 'a13')
|
196 |
+
dependency_plot_anim<- function(feature,dataset,shap,anim){
|
197 |
+
|
198 |
+
newdata <- dataset %>% mutate({{feature}} := range01(!!sym(feature)))
|
199 |
+
plots<-list()
|
200 |
+
activities<-newdata %>% filter(Anim == anim) %>% pull(Activity) %>% unique()
|
201 |
+
for (activity in activities) {
|
202 |
+
s <- shap[which(shap$class == activity &
|
203 |
+
shap$Anim == anim
|
204 |
+
), 1:18]
|
205 |
+
x <- newdata[which(newdata$Activity == activity &
|
206 |
+
newdata$Anim == anim
|
207 |
+
),]
|
208 |
+
data <- cbind(shap=(s %>% as.data.frame %>% select(feature)),
|
209 |
+
feature = (x %>% select(feature)),
|
210 |
+
tp = x %>% mutate(tp=ifelse(Activity == predictions,"TP","FP")) %>% pull(tp) )
|
211 |
+
names(data)<-c("shap","feature","tp")
|
212 |
+
|
213 |
+
p <- ggplot(data, aes(x = feature)) +
|
214 |
+
geom_point(aes(y = shap, color = tp), alpha = 0.3, size = 1.8) +
|
215 |
+
geom_smooth(aes(y = shap),
|
216 |
+
se = FALSE,
|
217 |
+
size = 0.5,
|
218 |
+
linetype = "dashed") +
|
219 |
+
geom_hline(
|
220 |
+
yintercept = 0,
|
221 |
+
color = 'red',
|
222 |
+
size = 0.5,
|
223 |
+
alpha = 0.5
|
224 |
+
) +
|
225 |
+
xlab(feature) +
|
226 |
+
labs(title = paste0("Activity ", activity)) +
|
227 |
+
ylab("SHAP Value") +
|
228 |
+
ylim(-0.1, 0.4) +
|
229 |
+
xlim(0, 1) +
|
230 |
+
theme_light() +
|
231 |
+
theme(legend.position = 'none')
|
232 |
+
|
233 |
+
p1 <-
|
234 |
+
ggMarginal(
|
235 |
+
p,
|
236 |
+
type = "histogram",
|
237 |
+
fill = 'gray',
|
238 |
+
color = 'white',
|
239 |
+
size = 15,
|
240 |
+
xparams = list(bins = 25),
|
241 |
+
yparams = list(bins = 15)
|
242 |
+
) #,margins='x')
|
243 |
+
plots[[activity]] <- p1
|
244 |
+
}
|
245 |
+
do.call(grid.arrange, c(plots, ncol = length(activities)))
|
246 |
+
}
|
247 |
+
|
248 |
+
#' contribution plot for SHAP values
|
249 |
+
#'
|
250 |
+
#' @param shap shap values for a particular class, animal, etc.
|
251 |
+
#' @param num_row the row number of the observation to show
|
252 |
+
#'
|
253 |
+
#' @return ggplot object
|
254 |
+
#' @export
|
255 |
+
#'
|
256 |
+
#' @examples
|
257 |
+
#'
|
258 |
+
#' shap_values_G <- calculate_shap_class(
|
259 |
+
#' dataset = dataset,
|
260 |
+
#' new_data = newdata,
|
261 |
+
#' model= model,
|
262 |
+
#' nsim = 100,
|
263 |
+
#' function_class = p_function_G,
|
264 |
+
#' class_name ="G")
|
265 |
+
#' p1 <- contribution_plot(shap_values_G,num_row = 1) +
|
266 |
+
#' labs(title="Anim a13: class G (FN)", subtitle = "SHAP analysis for class G")
|
267 |
+
#'
|
268 |
+
contribution_plot <-function(s, num_row = 1){
|
269 |
+
s<-s[num_row,]
|
270 |
+
s <- data.frame(
|
271 |
+
Variable = names(s[,1:15]),
|
272 |
+
Importance = apply(s[,1:15], MARGIN = 2, FUN = function(x) sum(x))
|
273 |
+
)
|
274 |
+
ggplot(s, aes(Variable, Importance, Importance,fill=Importance) )+
|
275 |
+
geom_col() +
|
276 |
+
coord_flip() +
|
277 |
+
xlab("") +
|
278 |
+
ylab("Shapley value")+
|
279 |
+
theme_classic()+
|
280 |
+
theme(legend.position = 'none')
|
281 |
+
}
|
282 |
+
|
283 |
+
|
284 |
+
contribution_plot_w_feature <-function(s, f, num_row = 1){
|
285 |
+
d <- data.frame(
|
286 |
+
variable = names(s[num_row,1:15]),
|
287 |
+
importance = apply(s[num_row,1:15], MARGIN = 2, FUN = function(x) sum(x)),
|
288 |
+
value = apply(f[num_row,1:15], MARGIN = 2, FUN = function(x) sum(x))
|
289 |
+
)
|
290 |
+
ggplot(d, aes(variable, importance, value,fill=value) )+
|
291 |
+
geom_col() +
|
292 |
+
geom_text(aes(label=round(value,digits = 2),hjust = 1.0),size=2)+
|
293 |
+
coord_flip() +
|
294 |
+
xlab("") +
|
295 |
+
ylab("Shapley value")+
|
296 |
+
scale_fill_gradient(low = 'lightgray', high = 'skyblue')+
|
297 |
+
theme_classic()+
|
298 |
+
theme(legend.position = 'none')
|
299 |
+
}
|
selected_features.tsv
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
variable
|
2 |
+
Steps
|
3 |
+
%HeadDown
|
4 |
+
Standing
|
5 |
+
Active
|
6 |
+
MeanXY
|
7 |
+
distance(m)
|
8 |
+
prev_steps1
|
9 |
+
X_Act
|
10 |
+
prev_Active1
|
11 |
+
prev_Standing1
|
12 |
+
DFA123
|
13 |
+
prev_headdown1
|
14 |
+
Lying
|
15 |
+
Y_Act
|
16 |
+
DBA123
|