Chapter 6 Model Interpretation
A general criticism of machine learning methods is that they are often black boxes: it is difficult to disentangle the influence of one variable from another. This is particularly true for complex models such as random forests and gradient boosting machines. However, there are several methods that can be used to interpret the results of these models.
Given a set of data and a fitted model, there are several mechanistic insights that we may want:
Variable importance: Which variables are most important in predicting the outcome?
Variable interactions: Are there interactions between variables that are important in predicting the outcome?
Local interpretability: Given a single observation, can we understand why the model made the prediction it did?
Model structure: What does the model look like? What are the decision rules that the model is using?
In this chapter, we will explore several methods for interpreting the results of a model. We will use the DALEX package to do this; it provides a unified interface for interpreting the results of a wide variety of models.
6.1 Variable importance
In a simple linear regression model, the coefficients of the model give us a direct measure of the importance of each variable. For example, consider the following linear regression model:
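The output below is produced by the following call (as confirmed by the Call line in the output):
# Regress fuel efficiency on all the other variables in mtcars
summary(lm(mpg ~ ., data = mtcars))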
##
## Call:
## lm(formula = mpg ~ ., data = mtcars)
##
## Residuals:
## Min 1Q Median 3Q Max
## -3.4506 -1.6044 -0.1196 1.2193 4.6271
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 12.30337 18.71788 0.657 0.5181
## cyl -0.11144 1.04502 -0.107 0.9161
## disp 0.01334 0.01786 0.747 0.4635
## hp -0.02148 0.02177 -0.987 0.3350
## drat 0.78711 1.63537 0.481 0.6353
## wt -3.71530 1.89441 -1.961 0.0633 .
## qsec 0.82104 0.73084 1.123 0.2739
## vs 0.31776 2.10451 0.151 0.8814
## am 2.52023 2.05665 1.225 0.2340
## gear 0.65541 1.49326 0.439 0.6652
## carb -0.19942 0.82875 -0.241 0.8122
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 2.65 on 21 degrees of freedom
## Multiple R-squared: 0.869, Adjusted R-squared: 0.8066
## F-statistic: 13.93 on 10 and 21 DF, p-value: 3.793e-07
For example, the coefficient of wt is -3.72, which means that for every one-unit increase in wt (weight, in 1000 lbs), the predicted value of mpg decreases by 3.72, holding all other variables fixed.
It is also straightforward to assess interactions between variables by adding interaction terms to the model.
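The interaction model is fitted as follows (again matching the Call line in the output below):
# Add a cyl:disp interaction term to the additive model
summary(lm(mpg ~ . + cyl:disp, data = mtcars))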
##
## Call:
## lm(formula = mpg ~ . + cyl:disp, data = mtcars)
##
## Residuals:
## Min 1Q Median 3Q Max
## -3.1697 -1.6096 -0.1275 1.1873 3.8355
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 29.976395 18.535141 1.617 0.1215
## cyl -1.789619 1.183617 -1.512 0.1462
## disp -0.095947 0.049001 -1.958 0.0643 .
## hp -0.033409 0.020359 -1.641 0.1164
## drat -0.541227 1.584761 -0.342 0.7363
## wt -3.552721 1.717760 -2.068 0.0518 .
## qsec 0.698111 0.664203 1.051 0.3058
## vs 0.828745 1.918957 0.432 0.6705
## am 0.819051 1.997640 0.410 0.6862
## gear 1.554511 1.405425 1.106 0.2818
## carb 0.144212 0.764824 0.189 0.8523
## cyl:disp 0.013762 0.005825 2.363 0.0284 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 2.401 on 20 degrees of freedom
## Multiple R-squared: 0.8976, Adjusted R-squared: 0.8413
## F-statistic: 15.94 on 11 and 20 DF, p-value: 1.441e-07
For example, the coefficient of cyl:disp is 0.014. An interaction coefficient does not describe a one-unit change in both variables at once; rather, it means that the effect on mpg of a one-unit increase in disp itself increases by 0.014 for each additional unit of cyl (and vice versa).
However, for more complex models such as random forests and gradient boosting machines, it is not as straightforward to determine the importance of each variable.
6.1.1 Decision tree structure
Another method for interpreting the results of a model is to examine its decision tree structure. Of course, decision trees are far more accessible than most other ML methods. There are several packages in R that can be used to visualise the tree structure of a model; one such package is rpart.
In a random forest setting, the tree structures can be used to gauge the importance of each variable. A simple heuristic is to count the number of times each variable is used for splitting across the trees.
Example
Here we will use the getTree() function from the randomForest package to examine the structure of a single tree within a random forest model.
library(randomForest)
# Fit the random forest model
rf_model <- randomForest(mpg ~ ., data = mtcars)
# Extract the structure of the first tree in the forest
tree <- getTree(rf_model, k = 1, labelVar = TRUE)
# Examine the decision tree structure
tree
## left daughter right daughter split var split point status prediction
## 1 2 3 vs 0.5000 -3 18.84062
## 2 4 5 drat 4.3250 -3 16.39565
## 3 6 7 wt 2.1275 -3 25.08889
## 4 8 9 hp 192.5000 -3 14.95500
## 5 0 0 <NA> 0.0000 -1 26.00000
## 6 0 0 <NA> 0.0000 -1 29.50000
## 7 10 11 hp 77.5000 -3 22.88333
## 8 12 13 carb 4.5000 -3 17.27000
## 9 14 15 hp 230.0000 -3 12.64000
## 10 0 0 <NA> 0.0000 -1 24.40000
## 11 0 0 <NA> 0.0000 -1 22.12500
## 12 16 17 qsec 17.1600 -3 17.00000
## 13 0 0 <NA> 0.0000 -1 19.70000
## 14 0 0 <NA> 0.0000 -1 10.40000
## 15 18 19 drat 3.9750 -3 14.13333
## 16 0 0 <NA> 0.0000 -1 17.90000
## 17 0 0 <NA> 0.0000 -1 16.28000
## 18 0 0 <NA> 0.0000 -1 13.30000
## 19 0 0 <NA> 0.0000 -1 15.80000
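The forest also accumulates an impurity-based importance score for each variable over all of its trees. The table below can be printed with importance() from randomForest (a presumed reconstruction, as the call itself is not shown):
# Impurity-based variable importance, accumulated over all trees
importance(rf_model)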
## IncNodePurity
## cyl 173.98820
## disp 247.11280
## hp 167.31635
## drat 81.85837
## wt 267.73722
## qsec 37.52913
## vs 32.33025
## am 13.40816
## gear 18.20693
## carb 32.83072
For a random forest model, we can also count how many times each variable is used for splitting across all the trees in the forest.
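One way to obtain such counts is varUsed() from randomForest, which returns the number of times each predictor is used for splitting across the forest; the output below is consistent with a call like this (a sketch, as the original code is not shown):
# Predictor names, followed by how often each is used for splitting
names(mtcars)[-1]
varUsed(rf_model)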
## [1] "cyl" "disp" "hp" "drat" "wt" "qsec" "vs" "am" "gear" "carb"
## [1] 298 819 792 505 817 564 128 115 159 321
This might suggest that disp is the most important variable in the model. However, raw usage counts can mislead: a variable such as am might be highly informative, yet because its influence is captured by a single split early in a tree, it is used relatively rarely.
Instead, we might count the number of times each variable is used for the first split in a decision tree within the random forest:
# Count how many times each variable is used for the first (root) split
first.split <- sapply(1:rf_model$ntree, function(i) {
  # With labelVar = TRUE the 'split var' column already holds variable names
  as.character(getTree(rf_model, k = i, labelVar = TRUE)[1, "split var"])
})
table(first.split)
## first.split
## am cyl disp drat hp qsec vs wt
## 5 91 112 69 87 50 16 70
Now, since the first split tends to be the most informative for predicting the outcome, this suggests that disp and cyl are the most important variables in the model.
6.1.2 Permutation-based feature importance
One way to determine the importance of each variable in a random forest or gradient boosting machine is to use a permutation-based feature importance method. This method works as follows:
Fit the model to the data.
For each variable in turn,
- randomly permute the values of that variable in the data;
- calculate the change in the model’s performance.
The importance of each variable is then the average change in the model's performance across repeated permutations of that variable.
If a variable is important in predicting the outcome, then permuting the values of that variable should result in a large change in the model’s performance. Conversely, if a variable is not important in predicting the outcome, then permuting the values of that variable should result in a small change in the model’s performance.
The algorithm briefly outlined here is a one-at-a-time method. It is possible that such a method may miss important interactions between variables. There are other methods that can be used to determine the importance of interactions between variables.
Example
Here we will apply this strategy to a linear regression model and consider its utility in contrast to more traditional hypothesis tests. Again, instead of relying on a package, we will implement the method ourselves.
# Fit the model
lm_model <- lm(mpg ~ ., data = mtcars)
# Hypothesis test results for each variable
summary(lm_model)
##
## Call:
## lm(formula = mpg ~ ., data = mtcars)
##
## Residuals:
## Min 1Q Median 3Q Max
## -3.4506 -1.6044 -0.1196 1.2193 4.6271
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 12.30337 18.71788 0.657 0.5181
## cyl -0.11144 1.04502 -0.107 0.9161
## disp 0.01334 0.01786 0.747 0.4635
## hp -0.02148 0.02177 -0.987 0.3350
## drat 0.78711 1.63537 0.481 0.6353
## wt -3.71530 1.89441 -1.961 0.0633 .
## qsec 0.82104 0.73084 1.123 0.2739
## vs 0.31776 2.10451 0.151 0.8814
## am 2.52023 2.05665 1.225 0.2340
## gear 0.65541 1.49326 0.439 0.6652
## carb -0.19942 0.82875 -0.241 0.8122
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 2.65 on 21 degrees of freedom
## Multiple R-squared: 0.869, Adjusted R-squared: 0.8066
## F-statistic: 13.93 on 10 and 21 DF, p-value: 3.793e-07
# Calculate the performance of the model
performance <- summary(lm_model)$r.squared
# Permute the values of each explanatory variable in turn
permuted_performance <- sapply(names(mtcars)[-1], function(var) {
# Permute the values of this variable
permuted_data <- mtcars
permuted_data[[var]] <- sample(permuted_data[[var]])
# Refit the model to the permuted data
permuted_lm_model <- lm(mpg ~ ., data = permuted_data)
# Calculate the performance of the refitted model
permuted_r2 <- summary(permuted_lm_model)$r.squared
# Return the drop in performance
performance - permuted_r2
})
# Plot the drop in R-squared for each variable
barplot(permuted_performance, las = 2)
Note that we used only a single permutation per variable here; in practice the permutation is usually repeated several times and the changes averaged, as in the algorithm above. There are, of course, packages available in R that will implement this method for you. One such package is DALEX.
Example
Now, we will utilise the DALEX package to determine the importance of each variable in a random forest model for the Boston housing data (available in the MASS package).
## Registered S3 method overwritten by 'DALEX':
## method from
## print.description questionr
## Welcome to DALEX (version: 2.4.3).
## Find examples and detailed introduction at: http://ema.drwhy.ai/
library(DALEX)
library(MASS)  # provides the Boston housing data
# Fit the random forest model
rf_model <- randomForest(medv ~ ., data = Boston)
# Create an explainer object, supplying the predictors and the observed outcome
explainer <- explain(rf_model, data = Boston[-14], y = Boston$medv)
## Preparation of a new explainer is initiated
## -> model label : randomForest ( default )
## -> data : 506 rows 13 cols
## -> target variable : 506 values
## -> predict function : yhat.randomForest will be used ( default )
## -> predicted values : No value for predict function target column. ( default )
## -> model_info : package randomForest , ver. 4.7.1.1 , task regression ( default )
## -> predicted values : numerical, min = 6.763717 , mean = 22.53069 , max = 49.019
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -5.218475 , mean = 0.00211808 , max = 7.493544
## A new explainer has been created!
# Calculate the permutation-based importance of each variable
vi <- variable_importance(explainer)
vi
## variable mean_dropout_loss label
## 1 _full_model_ 1.365472 randomForest
## 2 chas 1.418405 randomForest
## 3 zn 1.433858 randomForest
## 4 rad 1.501006 randomForest
## 5 black 1.703902 randomForest
## 6 age 1.839409 randomForest
## 7 tax 1.900797 randomForest
## 8 indus 2.034476 randomForest
## 9 ptratio 2.284982 randomForest
## 10 crim 2.307187 randomForest
## 11 nox 2.525805 randomForest
## 12 dis 2.540718 randomForest
## 13 rm 5.093851 randomForest
## 14 lstat 6.428002 randomForest
## 15 _baseline_ 12.510559 randomForest
Here, mean_dropout_loss is the model's loss (root mean squared error by default for regression) after permuting the given variable, averaged over permutations. The larger the value, the more important the variable is in predicting the outcome. The _full_model_ row gives the loss with nothing permuted, and _baseline_ gives the loss when all the explanatory variables are permuted at once.
6.2 Main effect visualisations
6.2.1 Main effects
A main effect is the effect of a single variable on the outcome while all other variables are held constant. To understand the utility of main-effect visualisations, let's consider a situation where we know the true model generating the data.
Example
Here, we have one outcome variable and two explanatory variables. The true model is:
\[ Y = X_1 + X_2^2. \] We will have strongly correlated explanatory variables:
\[\begin{align*} X_1 &\sim \text{Uni}(0, 1), \\ X_2 &= X_1 + \epsilon, \quad \epsilon \sim \text{N}(0, 0.02). \end{align*}\]
Here’s some data for us to work with:
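A minimal simulation sketch (the sample size and seed are assumptions, and 0.02 is treated as the standard deviation of the noise):
# Simulate strongly correlated explanatory variables and the outcome
set.seed(1)   # assumed seed, for reproducibility
n <- 1000     # assumed sample size
x1 <- runif(n)
x2 <- x1 + rnorm(n, sd = 0.02)  # assuming 0.02 is the standard deviation
y <- x1 + x2^2
sim_data <- data.frame(y, x1, x2)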
It is trivial to derive the main effects for this model. Holding \(X_2\) fixed at \(x_2\), the effect of varying \(X_1\) is linear with unit slope:
\[ Y = X_1 + x_2^2. \]
Similarly, holding \(X_1\) fixed at \(x_1\), the effect of varying \(X_2\) is quadratic:
\[ Y = x_1 + X_2^2. \]
Another powerful tool for interpreting the results of a model is the iml package (Molnar, Casalicchio, and Bischl 2018), which provides a unified interface for interpreting a wide variety of models. The accumulated local effects method discussed in Section 6.2.4 stems from a 2020 paper by Apley (2020).
6.2.2 Partial dependence plots
One method for interpreting the results of a model is to use partial dependence plots. A partial dependence plot shows the average predicted outcome as a function of one variable, where the average is taken over the observed values of all the other variables. This allows us to see the marginal effect of a variable on the outcome while controlling for the effects of the others.
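As a sketch, partial dependence profiles can be computed with model_profile() from DALEX, using the explainer built above (the choice of variables is illustrative):
# Partial dependence profiles for two illustrative variables
pdp <- model_profile(explainer, variables = c("rm", "lstat"), type = "partial")
plot(pdp)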
6.2.3 Individual conditional expectation plots
Another method for interpreting the results of a model is to use individual conditional expectation (ICE) plots. An ICE plot shows how the model's prediction for a single observation changes as one variable is varied, with that observation's other values held fixed. Averaging ICE curves over all observations recovers the partial dependence plot.
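In DALEX these are called ceteris-paribus profiles and can be computed with predict_profile(); here is a sketch for the first observation in the Boston data (the observation and variables are illustrative):
# ICE (ceteris-paribus) profile for a single observation
cp <- predict_profile(explainer, new_observation = Boston[1, -14])
plot(cp, variables = c("rm", "lstat"))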
6.2.4 Accumulated local effects
Another method for interpreting the results of a model is to use accumulated local effects (ALE). An ALE plot averages the local changes in predictions over small intervals of a variable, using only the observations whose values actually fall in each interval, and then accumulates those changes. Unlike partial dependence plots, this avoids evaluating the model at unrealistic combinations of correlated variables.
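A sketch using the same DALEX interface, where type = "accumulated" requests ALE profiles (variables again illustrative):
# Accumulated local effects profiles
ale <- model_profile(explainer, variables = c("rm", "lstat"), type = "accumulated")
plot(ale)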