harpomaxx commited on
Commit
41592fb
1 Parent(s): 89e50b1

add SHAP values analysis

Browse files
Files changed (5) hide show
  1. Dockerfile +3 -0
  2. app.R +106 -1
  3. calculate_shap.R +328 -0
  4. plot_shap.R +299 -0
  5. selected_features.tsv +16 -0
Dockerfile CHANGED
@@ -7,7 +7,10 @@ RUN install2.r --error \
7
  ggExtra \
8
  readr \
9
  caret \
 
10
  ggplot2 \
 
 
11
  shiny
12
 
13
  RUN install2.r --error \
 
7
  ggExtra \
8
  readr \
9
  caret \
10
+ fastshap \
11
  ggplot2 \
12
+ ggExtra \
13
+ forcats \
14
  shiny
15
 
16
  RUN install2.r --error \
app.R CHANGED
@@ -5,6 +5,9 @@ library(readr)
5
  library(catboost)
6
  library(ggplot2)
7
 
 
 
 
8
  # Load the pre-trained model
9
  model <- readRDS("goat_behavior_model_caret.rds")
10
 
@@ -149,7 +152,18 @@ ui <- fluidPage(
149
  tableOutput("contents"),
150
  verbatimTextOutput("confusionMatText"),
151
  plotOutput("confusionMatPlot"),
152
- downloadButton("downloadData", "Download Predictions"))
 
 
 
 
 
 
 
 
 
 
 
153
  )
154
  )
155
  )
@@ -229,6 +243,97 @@ server <- function(input, output) {
229
  theme_minimal() +
230
  theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
231
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  }
233
 
234
  # Create a Shiny app object
 
5
  library(catboost)
6
  library(ggplot2)
7
 
8
+ source("calculate_shap.R")
9
+ source("plot_shap.R")
10
+
11
  # Load the pre-trained model
12
  model <- readRDS("goat_behavior_model_caret.rds")
13
 
 
152
  tableOutput("contents"),
153
  verbatimTextOutput("confusionMatText"),
154
  plotOutput("confusionMatPlot"),
155
+ downloadButton("downloadData", "Download Predictions")),
156
+
157
+ tabPanel("SHAP Summary",
158
+ plotOutput("SHAPSummary")),
159
+
160
+ tabPanel("SHAP Summary per class",
161
+ plotOutput("SHAPSummaryperclass")),
162
+
163
+ tabPanel("SHAP Dependency",
164
+ plotOutput("SHAPDependency"))
165
+
166
+
167
  )
168
  )
169
  )
 
243
  theme_minimal() +
244
  theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
245
  })
246
+
247
# Stacked bar chart of mean |SHAP| per feature, split by activity class,
# computed on the uploaded file.
output$SHAPSummary <- renderPlot({
  # Render nothing until a file has been uploaded.
  req(input$file1)

  dataset <- readr::read_delim(input$file1$datapath, delim = "\t")
  predictions <- predict(model, dataset)
  selected_variables <- readr::read_delim(
    "selected_features.tsv",
    col_types = readr::cols(),
    delim = "\t"
  )
  # all_of() states explicitly that the feature names come from an external
  # character vector, not from columns of `dataset`.
  new_dataset <- dataset %>%
    select(all_of(selected_variables$variable), Anim, Activity)
  new_dataset <- cbind(new_dataset, predictions)

  shap_values <- calculate_shap(new_dataset, model, nsim = 30)
  shap_summary_plot(shap_values) + xlim(0, 0.35)
})
269
+
270
# One mean |SHAP| bar chart per activity class, arranged in a grid.
output$SHAPSummaryperclass <- renderPlot({
  # Render nothing until a file has been uploaded.
  req(input$file1)

  dataset <- readr::read_delim(input$file1$datapath, delim = "\t")
  predictions <- predict(model, dataset)
  selected_variables <- readr::read_delim(
    "selected_features.tsv",
    col_types = readr::cols(),
    delim = "\t"
  )
  # all_of() states explicitly that the feature names come from an external
  # character vector, not from columns of `dataset`.
  new_dataset <- dataset %>%
    select(all_of(selected_variables$variable), Anim, Activity)
  new_dataset <- cbind(new_dataset, predictions)

  shap_values <- calculate_shap(new_dataset, model, nsim = 30)

  # Fill colour per class; the vector order is the panel order in the grid.
  class_colours <- c(
    W = "#C77CFF",
    R = "#00BFC4",
    G = "#F8766D",
    GM = "#7CAE00"
  )
  panels <- lapply(names(class_colours), function(cls) {
    shap_summary_plot_perclass(
      shap_values,
      class = cls,
      color = class_colours[[cls]]
    ) +
      xlab(paste("Activity", cls)) +
      xlim(0, 0.25)
  })

  # Namespace-qualified so the call does not depend on gridExtra being
  # attached elsewhere.
  do.call(gridExtra::grid.arrange, panels)
})
300
# SHAP dependency plots for a fixed set of features, stacked vertically.
output$SHAPDependency <- renderPlot({
  # Render nothing until a file has been uploaded.
  req(input$file1)

  dataset <- readr::read_delim(input$file1$datapath, delim = "\t")
  predictions <- predict(model, dataset)
  selected_variables <- readr::read_delim(
    "selected_features.tsv",
    col_types = readr::cols(),
    delim = "\t"
  )
  # all_of() states explicitly that the feature names come from an external
  # character vector, not from columns of `dataset`.
  new_dataset <- dataset %>%
    select(all_of(selected_variables$variable), Anim, Activity)
  new_dataset <- cbind(new_dataset, predictions)

  shap_values <- calculate_shap(new_dataset, model, nsim = 30)

  # Features currently shown. Previously disabled candidates: prev_steps1,
  # prev_headdown1, prev_Active1, prev_Standing1, X_Act, Y_Act, DBA123, DFA123.
  features <- c("Steps", "%HeadDown", "Active", "Standing")
  panels <- lapply(features, function(feature) {
    dependency_plot(feature, dataset = new_dataset, shap = shap_values)
  })

  # Namespace-qualified so the call does not depend on gridExtra being
  # attached elsewhere.
  do.call(gridExtra::grid.arrange, c(panels, ncol = 1))
})
337
  }
338
 
339
  # Create a Shiny app object
calculate_shap.R ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ suppressPackageStartupMessages(library(dplyr))
2
+ suppressPackageStartupMessages(library(fastshap)) # for fast (approximate) Shapley values
3
+ suppressPackageStartupMessages(library(caret))
4
+ suppressPackageStartupMessages(library(doMC))
5
+
6
+ registerDoMC(cores = 10)
7
+
8
+
9
# Prediction wrappers handed to fastshap::explain() as `pred_wrapper`.
# Each one asks the caret model for class probabilities and returns the
# column belonging to a single activity class.

# Probability of class "G".
p_function_G <- function(object, newdata) {
  class_probs <- caret::predict.train(object, newdata = newdata, type = "prob")
  class_probs[, "G"]
}

# Probability of class "GM".
p_function_GM <- function(object, newdata) {
  class_probs <- caret::predict.train(object, newdata = newdata, type = "prob")
  class_probs[, "GM"]
}

# Probability of class "R".
p_function_R <- function(object, newdata) {
  class_probs <- caret::predict.train(object, newdata = newdata, type = "prob")
  class_probs[, "R"]
}

# Probability of class "W".
p_function_W <- function(object, newdata) {
  class_probs <- caret::predict.train(object, newdata = newdata, type = "prob")
  class_probs[, "W"]
}
21
+
22
# DEPRECATED: superseded by calculate_shap().
#
# Computes approximate SHAP values per ground-truth class (G, GM, R, W) and
# returns them stacked in that class order with a `class` column. Unlike
# calculate_shap(), only `Activity` is removed from the feature set and the
# original row order is not restored.
calculate_shap_deprecated <- function(dataset, model, nsim = 10) {
  complete_rows <- dataset %>%
    na.omit() %>%
    as.data.frame()
  # Labels are taken from the NA-filtered rows so they stay aligned with the
  # feature rows even when NAs occur in columns other than Activity.
  labels <- complete_rows$Activity
  trainset <- complete_rows %>% select(-Activity)

  wrappers <- list(
    G = p_function_G,
    GM = p_function_GM,
    R = p_function_R,
    W = p_function_W
  )

  per_class <- lapply(names(wrappers), function(cls) {
    message(" - Calculating SHAP values for class ", cls)
    shap_cls <- fastshap::explain(
      model,
      X = trainset,
      pred_wrapper = wrappers[[cls]],
      nsim = nsim,
      newdata = trainset[which(labels == cls), , drop = FALSE],
      .parallel = TRUE
    )
    shap_cls$class <- cls
    shap_cls
  })

  do.call(rbind, per_class)
}
95
+
96
+
97
#' Calculate per-class SHAP values, returned in the original row order
#'
#' Rows of `dataset` containing any `NA` are dropped. The remaining rows are
#' split by their ground-truth `Activity`; for each of the classes G, GM, R
#' and W, `fastshap::explain()` is run on that class's rows with the matching
#' `p_function_*` wrapper, using all remaining rows as the background set.
#' SHAP values therefore refer to the ground-truth class, not the predicted
#' one; the `predictions` column is only carried along for later filtering.
#' Use `calculate_shap_class()` to explain arbitrary rows with a chosen
#' class wrapper.
#'
#' @param dataset data frame with the feature columns plus `Anim`,
#'   `Activity` and `predictions`.
#' @param model the model handed to `fastshap::explain()` and to the
#'   `p_function_*` wrappers.
#' @param nsim number of Monte Carlo repetitions used by
#'   `fastshap::explain()`.
#'
#' @return a data frame with one SHAP column per feature plus `class`, `Anim`
#'   and `predictions`. It holds one row per complete row of `dataset` whose
#'   `Activity` is G, GM, R or W, in the order those rows appear in `dataset`.
calculate_shap <- function(dataset, model, nsim = 10) {
  complete_rows <- dataset %>%
    na.omit() %>%
    as.data.frame()
  # Labels are taken from the NA-filtered rows so they stay aligned with the
  # feature rows even when NAs occur in columns other than Activity.
  labels <- complete_rows$Activity
  features <- complete_rows %>% select(-Activity, -Anim, -predictions)

  wrappers <- list(
    G = p_function_G,
    GM = p_function_GM,
    R = p_function_R,
    W = p_function_W
  )

  per_class <- list()
  row_ids <- integer(0)
  for (cls in names(wrappers)) {
    rows <- which(labels == cls)
    if (length(rows) == 0L) {
      # Nothing to explain for a class that is absent from the data.
      next
    }
    message(" - Calculating SHAP values for class ", cls)
    shap_cls <- fastshap::explain(
      model,
      X = features,
      pred_wrapper = wrappers[[cls]],
      nsim = nsim,
      newdata = features[rows, , drop = FALSE],
      .parallel = TRUE
    )
    # NOTE(review): recent fastshap releases reportedly return a matrix
    # instead of a tibble; coercing keeps `$<-` below working either way.
    # Confirm against the installed fastshap version.
    shap_cls <- tibble::as_tibble(shap_cls)
    shap_cls$class <- cls
    per_class[[cls]] <- shap_cls
    row_ids <- c(row_ids, rows)
  }

  if (length(per_class) == 0L) {
    stop("No rows with Activity in G, GM, R or W to explain.", call. = FALSE)
  }

  shap_values <- do.call(rbind, unname(per_class))
  # Index the metadata by row id so factor columns keep their type.
  shap_values <- shap_values %>%
    tibble::add_column(Anim = complete_rows$Anim[row_ids])
  shap_values <- shap_values %>%
    tibble::add_column(predictions = complete_rows$predictions[row_ids])

  # Undo the per-class grouping: restore the order of the input rows.
  shap_values[order(row_ids), ]
}
222
+
223
#' Calculate SHAP values for new observations with respect to one class
#'
#' Explains every row of `new_data` with `fastshap::explain()`, using
#' `function_class` as the prediction wrapper and the complete rows of
#' `dataset` as the background set.
#'
#' @param dataset the dataset used for permutation during SHAP calculation;
#'   must contain `Activity`, `predictions` and `Anim` columns, which are
#'   removed before use.
#' @param new_data the observations to explain; must contain `Anim`,
#'   `Activity` and `predictions` columns, which are removed before use.
#' @param model the model used for explanation.
#' @param nsim the number of Monte Carlo simulations.
#' @param function_class a wrapper function returning the prediction for one
#'   particular class, e.g. `p_function_G`.
#' @param class_name the name of the class; only used in the progress message.
#'
#' @return the SHAP values for `new_data`, one column per feature, plus
#'   `class` (the `Activity` of each row), `Anim` and `predictions`.
#'
#' @examples
#' \dontrun{
#' # Calculate the SHAP values for class G on new data
#' shap_values_G <- calculate_shap_class(
#'   dataset,
#'   new_data = newdata,
#'   model = goat_model,
#'   nsim = 100,
#'   function_class = p_function_G,
#'   class_name = "G"
#' )
#' }
calculate_shap_class <- function(dataset, new_data, model, nsim = 10,
                                 function_class, class_name = "G") {
  background <- dataset %>%
    na.omit() %>%
    as.data.frame() %>%
    select(-Activity, -predictions, -Anim)

  rows_to_explain <- new_data %>% select(-Anim, -Activity, -predictions)

  message(" - Calculating SHAP values for class ", class_name)
  shap_values <- fastshap::explain(
    model,
    X = background,
    pred_wrapper = function_class,
    nsim = nsim,
    newdata = rows_to_explain,
    .parallel = TRUE
  )
  # NOTE(review): recent fastshap releases reportedly return a matrix instead
  # of a tibble; coercing keeps `$<-` below working either way. Confirm
  # against the installed fastshap version.
  shap_values <- tibble::as_tibble(shap_values)

  shap_values$class <- new_data$Activity
  shap_values <- shap_values %>%
    tibble::add_column(Anim = new_data$Anim)
  shap_values <- shap_values %>%
    tibble::add_column(predictions = new_data$predictions)
  shap_values
}
288
+
289
# Stacked bar chart (dark theme) of the mean absolute SHAP value per feature,
# with one bar segment per class.
# Note: plot_shap.R defines a function with the same name; whichever file is
# sourced last provides the definition in use.
shap_summary_plot <- function(shap_values) {
  long_shap <- reshape2::melt(shap_values)

  mean_abs_shap <- long_shap %>%
    group_by(class, variable) %>%
    summarise(mean = mean(abs(value))) %>%
    arrange(desc(mean))

  bars <- geom_col(
    aes(y = variable, x = mean, group = class, fill = class),
    position = "stack"
  )

  ggplot(mean_abs_shap) +
    ggdark::dark_theme_classic() +
    bars +
    xlab("Mean(|Shap Value|) Average impact on model output magnitude")
}
306
+
307
# Beeswarm (sina) plot of SHAP values per feature, faceted by class and
# coloured by the min-max scaled feature value.
#
# shap_values: SHAP values with a `class` column.
# dataset:     the feature data with an `Activity` column.
# Assumes both melt to the same number of rows in the same order, since the
# scaled feature values are attached by position. TODO confirm with callers.
# Relies on range01() and ggplot2 being available in the calling session.
shap_beeswarm_plot <- function(shap_values, dataset) {
  shap_long <- shap_values %>% reshape2::melt()

  feature_long <- dataset %>%
    mutate(class = Activity) %>%
    select(-Activity) %>%
    reshape2::melt() %>%
    group_by(variable) %>%
    mutate(value_scale = range01(value))

  cbind(shap_long, feature_value = feature_long$value_scale) %>%
    ggplot() +
    facet_wrap(~class) +
    theme_classic() +
    geom_hline(yintercept = 0, color = "red", size = 0.5) +
    ggforce::geom_sina(
      aes(x = variable, y = value, color = feature_value),
      size = 0.5, bins = 4, alpha = 0.9, shape = 15
    ) +
    # A single colour scale: the earlier yellow-to-red scale was always
    # replaced by this one and only produced a "scale already present" message.
    scale_colour_gradient(low = "skyblue", high = "orange", na.value = NA) +
    xlab("Feature") +
    ylab("SHAP value") +
    theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))
}
plot_shap.R ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ suppressPackageStartupMessages(library(dplyr))
2
+ suppressPackageStartupMessages(library(ggplot2))
3
+ suppressPackageStartupMessages(library(ggExtra))
4
+ suppressPackageStartupMessages(library(forcats))
5
+
6
+
7
+
8
# Min-max rescale a numeric vector to [0, 1].
#
# NA entries are ignored when computing the range and stay NA in the result.
# A constant vector maps to 0 instead of NaN (0 / 0).
range01 <- function(x) {
  if (all(is.na(x))) {
    return(rep(NA_real_, length(x)))
  }
  lowest <- min(x, na.rm = TRUE)
  span <- max(x, na.rm = TRUE) - lowest
  if (is.na(span) || span == 0) {
    # No spread to scale by; every non-missing value sits at the minimum.
    return(x - lowest)
  }
  (x - lowest) / span
}
9
+
10
# Stacked bar chart of the mean absolute SHAP value per feature, with one bar
# segment per activity class.
shap_summary_plot <- function(shap_values) {
  long_shap <- reshape2::melt(shap_values)

  mean_abs_shap <- long_shap %>%
    group_by(class, variable) %>%
    summarise(mean = mean(abs(value))) %>%
    arrange(desc(mean))

  bars <- geom_col(
    aes(y = variable, x = mean, group = class, fill = class),
    position = "stack"
  )
  x_label <- paste0(
    "Mean(|Shap Value|) Average impact on model output magnitude ",
    "per activity."
  )

  ggplot(mean_abs_shap) +
    theme_classic() +
    bars +
    ylab("Feature") +
    xlab(x_label) +
    guides(fill = guide_legend(title = "Activity"))
}
30
+
31
+
32
# Bar chart of the mean absolute SHAP value per feature for a single class.
#
# shap_values: SHAP values with a `class` column.
# class:       the class whose rows are kept.
# color:       fill colour of the bars.
# Features are ordered by their mean absolute SHAP value.
shap_summary_plot_perclass <- function(shap_values, class = "G",
                                       color = "#F8766D") {
  class_rows <- filter(as.data.frame(shap_values), class == {{ class }})

  mean_abs_shap <- class_rows %>%
    reshape2::melt() %>%
    group_by(variable) %>%
    summarise(mean = mean(abs(value)))

  x_label <- paste0(
    "Mean(|Shap Value|) Average impact on model output magnitude for activity ",
    class
  )

  ggplot(mean_abs_shap) +
    theme_classic() +
    geom_col(aes(x = mean, y = fct_reorder(variable, mean)), fill = color) +
    ylab("Feature") +
    xlab(x_label) +
    guides(fill = guide_legend(title = "Activity"))
}
51
+
52
+
53
# Beeswarm (sina) plot of SHAP values per feature, faceted by class, with
# points filled by the min-max scaled feature value.
#
# shap_values: SHAP values with a `class` column.
# dataset:     the feature data with an `Activity` column.
# Assumes both melt to the same number of rows in the same order, since the
# scaled feature values are attached by position. TODO confirm with callers.
shap_beeswarm_plot <- function(shap_values, dataset) {
  shap_long <- shap_values %>% reshape2::melt()

  feature_long <- dataset %>%
    mutate(class = Activity) %>%
    select(-Activity) %>%
    reshape2::melt() %>%
    group_by(variable) %>%
    mutate(value_scale = range01(value))

  cbind(shap_long, feature_value = feature_long$value_scale) %>%
    ggplot() +
    facet_wrap(~class) +
    theme_classic() +
    geom_hline(yintercept = 0, color = "red", size = 0.5) +
    ggforce::geom_sina(
      aes(x = variable, y = value, fill = feature_value),
      color = "black", size = 2.4, bins = 4, alpha = 0.9, shape = 22
    ) +
    # A single fill scale: the earlier yellow-to-red scale was always
    # replaced by this one and only produced a "scale already present" message.
    scale_fill_gradient(low = "skyblue", high = "orange", na.value = NA) +
    xlab("Feature") +
    ylab("SHAP value") +
    theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))
}
76
+
77
+
78
#' Dependency plot for a single feature, one panel per activity
#'
#' For every activity present in `dataset`, plots the min-max scaled feature
#' value (x) against its SHAP value (y), colours the points by whether the
#' prediction matched the ground truth (TP) or not (FP), overlays a smoothed
#' trend line and adds marginal histograms.
#'
#' @param feature name (string) of the feature column to plot.
#' @param dataset a dataset with goat information; must contain the feature
#'   column plus `Activity` and `predictions`.
#' @param shap a SHAP value dataset with one column per feature and a `class`
#'   column. For each activity, its rows must correspond one-to-one and in
#'   the same order to the rows of `dataset` with that `Activity`.
#'
#' @return the arranged panels, as returned by `gridExtra::grid.arrange()`.
#'
#' @examples
#' \dontrun{
#' dataset <-
#'   readr::read_delim("data/split/seba-caprino_loocv.tsv", delim = "\t")
#' selected_variables <-
#'   readr::read_delim(
#'     "data/topnfeatures/seba-caprino_selected_features.tsv",
#'     col_types = readr::cols(),
#'     delim = "\t"
#'   )
#' goat_model <- readRDS("models/boost/seba-caprino_model.rds")
#' predictions <- predict(goat_model, dataset)
#' dataset <- dataset %>%
#'   select(all_of(selected_variables$variable), Anim, Activity)
#' dataset <- cbind(dataset, predictions)
#' shap_values <- calculate_shap(dataset, model = goat_model, nsim = 30)
#' dependency_plot(feature = "Steps", dataset = dataset, shap = shap_values)
#' }
dependency_plot <- function(feature, dataset, shap) {
  if (!feature %in% names(dataset) || !feature %in% names(shap)) {
    stop(
      "Feature '", feature, "' must be a column of both `dataset` and `shap`.",
      call. = FALSE
    )
  }

  # Scale the feature over the whole dataset so all panels share the x range.
  newdata <- dataset
  newdata[[feature]] <- range01(newdata[[feature]])

  activities <- unique(as.character(dataset$Activity))
  activities <- activities[!is.na(activities)]

  plots <- lapply(activities, function(activity) {
    shap_rows <- which(shap$class == activity)
    data_rows <- which(newdata$Activity == activity)
    if (length(shap_rows) != length(data_rows)) {
      stop(
        "Activity '", activity, "': `shap` has ", length(shap_rows),
        " rows but `dataset` has ", length(data_rows),
        "; they must match row for row.",
        call. = FALSE
      )
    }

    # Compare as character so factors with different level sets do not error.
    truth <- as.character(newdata$Activity[data_rows])
    predicted <- as.character(newdata$predictions[data_rows])
    data <- data.frame(
      shap = shap[[feature]][shap_rows],
      feature = newdata[[feature]][data_rows],
      tp = ifelse(truth == predicted, "TP", "FP")
    )

    p <- ggplot(data, aes(x = feature)) +
      geom_point(aes(y = shap, color = tp), alpha = 0.3, size = 0.8) +
      geom_smooth(
        aes(y = shap),
        se = FALSE,
        size = 0.5,
        linetype = "dashed"
      ) +
      geom_hline(
        yintercept = 0,
        color = "red",
        size = 0.5,
        alpha = 0.5
      ) +
      xlab(feature) +
      labs(title = paste0("Activity ", activity)) +
      ylab("SHAP Value") +
      ylim(-0.1, 0.4) +
      xlim(0, 1) +
      theme_light() +
      theme(legend.position = "none")

    ggMarginal(
      p,
      type = "histogram",
      fill = "gray",
      color = "white",
      size = 10,
      xparams = list(bins = 25),
      yparams = list(bins = 15)
    )
  })

  # Namespace-qualified so the call does not depend on gridExtra being
  # attached elsewhere.
  do.call(gridExtra::grid.arrange, c(plots, ncol = 4))
}
161
+
162
+
163
#' Dependency plot for a single feature on a single animal
#'
#' For every activity recorded for animal `anim`, plots the min-max scaled
#' feature value (x) against its SHAP value (y), colours the points by
#' whether the prediction matched the ground truth (TP) or not (FP),
#' overlays a smoothed trend line and adds marginal histograms.
#'
#' @param feature name (string) of the feature column to plot.
#' @param dataset a dataset with goat information; must contain the feature
#'   column plus `Anim`, `Activity` and `predictions`.
#' @param shap a SHAP value dataset with one column per feature plus `class`
#'   and `Anim` columns. For each activity of the animal, its rows must
#'   correspond one-to-one and in the same order to the matching rows of
#'   `dataset`.
#' @param anim the id of the animal.
#'
#' @return the arranged panels, as returned by `gridExtra::grid.arrange()`.
#'
#' @examples
#' \dontrun{
#' dataset <-
#'   readr::read_delim("data/split/seba-caprino_loocv.tsv", delim = "\t")
#' selected_variables <-
#'   readr::read_delim(
#'     "data/topnfeatures/seba-caprino_selected_features.tsv",
#'     col_types = readr::cols(),
#'     delim = "\t"
#'   )
#' goat_model <- readRDS("models/boost/seba-caprino_model.rds")
#' predictions <- predict(goat_model, dataset)
#' dataset <- dataset %>%
#'   select(all_of(selected_variables$variable), Anim, Activity)
#' dataset <- cbind(dataset, predictions)
#' shap_values <- calculate_shap(dataset, model = goat_model, nsim = 30)
#' dependency_plot_anim(
#'   feature = "Steps",
#'   dataset = dataset,
#'   shap = shap_values,
#'   anim = "a13"
#' )
#' }
dependency_plot_anim <- function(feature, dataset, shap, anim) {
  if (!feature %in% names(dataset) || !feature %in% names(shap)) {
    stop(
      "Feature '", feature, "' must be a column of both `dataset` and `shap`.",
      call. = FALSE
    )
  }

  # Scale over the whole dataset (all animals), then restrict to `anim`.
  newdata <- dataset
  newdata[[feature]] <- range01(newdata[[feature]])

  is_anim <- which(newdata$Anim == anim)
  activities <- unique(as.character(newdata$Activity[is_anim]))
  activities <- activities[!is.na(activities)]

  plots <- lapply(activities, function(activity) {
    shap_rows <- which(shap$class == activity & shap$Anim == anim)
    data_rows <- which(newdata$Activity == activity & newdata$Anim == anim)
    if (length(shap_rows) != length(data_rows)) {
      stop(
        "Activity '", activity, "': `shap` has ", length(shap_rows),
        " rows but `dataset` has ", length(data_rows),
        "; they must match row for row.",
        call. = FALSE
      )
    }

    # Compare as character so factors with different level sets do not error.
    truth <- as.character(newdata$Activity[data_rows])
    predicted <- as.character(newdata$predictions[data_rows])
    data <- data.frame(
      shap = shap[[feature]][shap_rows],
      feature = newdata[[feature]][data_rows],
      tp = ifelse(truth == predicted, "TP", "FP")
    )

    p <- ggplot(data, aes(x = feature)) +
      geom_point(aes(y = shap, color = tp), alpha = 0.3, size = 1.8) +
      geom_smooth(
        aes(y = shap),
        se = FALSE,
        size = 0.5,
        linetype = "dashed"
      ) +
      geom_hline(
        yintercept = 0,
        color = "red",
        size = 0.5,
        alpha = 0.5
      ) +
      xlab(feature) +
      labs(title = paste0("Activity ", activity)) +
      ylab("SHAP Value") +
      ylim(-0.1, 0.4) +
      xlim(0, 1) +
      theme_light() +
      theme(legend.position = "none")

    ggMarginal(
      p,
      type = "histogram",
      fill = "gray",
      color = "white",
      size = 15,
      xparams = list(bins = 25),
      yparams = list(bins = 15)
    )
  })

  # Namespace-qualified so the call does not depend on gridExtra being
  # attached elsewhere.
  do.call(gridExtra::grid.arrange, c(plots, ncol = length(activities)))
}
247
+
248
#' Contribution plot for SHAP values
#'
#' Horizontal bar chart of the SHAP value of each feature for the selected
#' observation. If several rows are selected, their SHAP values are summed
#' per feature.
#'
#' @param s SHAP values for a particular class, animal, etc. The feature
#'   columns must come first.
#' @param num_row the row number of the observation to show.
#' @param n_features number of leading columns of `s` that hold feature SHAP
#'   values (defaults to the 15 features this function was written for).
#'
#' @return ggplot object
#'
#' @examples
#' \dontrun{
#' shap_values_G <- calculate_shap_class(
#'   dataset = dataset,
#'   new_data = newdata,
#'   model = model,
#'   nsim = 100,
#'   function_class = p_function_G,
#'   class_name = "G"
#' )
#' p1 <- contribution_plot(shap_values_G, num_row = 1) +
#'   labs(title = "Anim a13: class G (FN)", subtitle = "SHAP analysis for class G")
#' }
contribution_plot <- function(s, num_row = 1, n_features = 15) {
  selected <- s[num_row, seq_len(n_features), drop = FALSE]
  contributions <- data.frame(
    Variable = names(selected),
    Importance = colSums(as.matrix(selected))
  )
  ggplot(contributions, aes(Variable, Importance, fill = Importance)) +
    geom_col() +
    coord_flip() +
    xlab("") +
    ylab("Shapley value") +
    theme_classic() +
    theme(legend.position = "none")
}
282
+
283
+
284
#' Contribution plot for SHAP values, annotated with the feature values
#'
#' Horizontal bar chart of the SHAP value of each feature for the selected
#' observation. Each bar is filled and labelled according to the value of
#' the feature itself.
#'
#' @param s SHAP values; the feature columns must come first.
#' @param f feature values. Its leading columns are matched to those of `s`
#'   by position, so they are assumed to be in the same order. TODO confirm
#'   with callers.
#' @param num_row the row number of the observation to show.
#' @param n_features number of leading columns of `s` and `f` that hold
#'   features (defaults to the 15 features this function was written for).
#'
#' @return ggplot object
contribution_plot_w_feature <- function(s, f, num_row = 1, n_features = 15) {
  feature_cols <- seq_len(n_features)
  shap_row <- s[num_row, feature_cols, drop = FALSE]
  value_row <- f[num_row, feature_cols, drop = FALSE]

  d <- data.frame(
    variable = names(shap_row),
    importance = colSums(as.matrix(shap_row)),
    value = colSums(as.matrix(value_row))
  )
  ggplot(d, aes(variable, importance, fill = value)) +
    geom_col() +
    geom_text(aes(label = round(value, digits = 2), hjust = 1.0), size = 2) +
    coord_flip() +
    xlab("") +
    ylab("Shapley value") +
    scale_fill_gradient(low = "lightgray", high = "skyblue") +
    theme_classic() +
    theme(legend.position = "none")
}
selected_features.tsv ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ variable
2
+ Steps
3
+ %HeadDown
4
+ Standing
5
+ Active
6
+ MeanXY
7
+ distance(m)
8
+ prev_steps1
9
+ X_Act
10
+ prev_Active1
11
+ prev_Standing1
12
+ DFA123
13
+ prev_headdown1
14
+ Lying
15
+ Y_Act
16
+ DBA123