Chapter 6 Model Interpretation

A general criticism of machine learning methods is that they are often black boxes: it is difficult to disentangle the influence of one variable from another. This is particularly true for complex models such as random forests and gradient boosting machines. However, there are several methods that can be used to interpret the results of these models.

There are several mechanistic insights that we may want to have given a set of data and a model:

  1. Variable importance: Which variables are most important in predicting the outcome?

  2. Variable interactions: Are there interactions between variables that are important in predicting the outcome?

  3. Local interpretability: Given a single observation, can we understand why the model made the prediction it did?

  4. Model structure: What does the model look like? What are the decision rules that the model is using?

In this chapter, we will explore several methods for interpreting the results of a model, using the DALEX and iml packages. Both provide a unified interface for interpreting the results of a wide variety of models.

6.1 Variable importance

In a simple linear regression model, the coefficients of the model give us a direct measure of the importance of each variable. For example, consider the following linear regression model:

lm_model <- lm(mpg ~ ., data = mtcars)

summary(lm_model)
## 
## Call:
## lm(formula = mpg ~ ., data = mtcars)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -3.4506 -1.6044 -0.1196  1.2193  4.6271 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)  
## (Intercept) 12.30337   18.71788   0.657   0.5181  
## cyl         -0.11144    1.04502  -0.107   0.9161  
## disp         0.01334    0.01786   0.747   0.4635  
## hp          -0.02148    0.02177  -0.987   0.3350  
## drat         0.78711    1.63537   0.481   0.6353  
## wt          -3.71530    1.89441  -1.961   0.0633 .
## qsec         0.82104    0.73084   1.123   0.2739  
## vs           0.31776    2.10451   0.151   0.8814  
## am           2.52023    2.05665   1.225   0.2340  
## gear         0.65541    1.49326   0.439   0.6652  
## carb        -0.19942    0.82875  -0.241   0.8122  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 2.65 on 21 degrees of freedom
## Multiple R-squared:  0.869,  Adjusted R-squared:  0.8066 
## F-statistic: 13.93 on 10 and 21 DF,  p-value: 3.793e-07

The coefficients give us a direct measure of the effect of each variable on the outcome. For example, the coefficient of wt is -3.72, which means that, holding all other variables constant, a one unit increase in wt is associated with a decrease of 3.72 in the predicted value of mpg.

It is also straightforward to assess important interactions between variables.

lm_model <- lm(mpg ~ . + cyl:disp, data = mtcars)

summary(lm_model)
## 
## Call:
## lm(formula = mpg ~ . + cyl:disp, data = mtcars)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -3.1697 -1.6096 -0.1275  1.1873  3.8355 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)  
## (Intercept) 29.976395  18.535141   1.617   0.1215  
## cyl         -1.789619   1.183617  -1.512   0.1462  
## disp        -0.095947   0.049001  -1.958   0.0643 .
## hp          -0.033409   0.020359  -1.641   0.1164  
## drat        -0.541227   1.584761  -0.342   0.7363  
## wt          -3.552721   1.717760  -2.068   0.0518 .
## qsec         0.698111   0.664203   1.051   0.3058  
## vs           0.828745   1.918957   0.432   0.6705  
## am           0.819051   1.997640   0.410   0.6862  
## gear         1.554511   1.405425   1.106   0.2818  
## carb         0.144212   0.764824   0.189   0.8523  
## cyl:disp     0.013762   0.005825   2.363   0.0284 *
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 2.401 on 20 degrees of freedom
## Multiple R-squared:  0.8976, Adjusted R-squared:  0.8413 
## F-statistic: 15.94 on 11 and 20 DF,  p-value: 1.441e-07

For example, the coefficient of cyl:disp is 0.014, which means that the effect of a one unit increase in disp on the predicted value of mpg increases by 0.014 for every one unit increase in cyl (and vice versa).

However, for more complex models such as random forests and gradient boosting machines, it is not as straightforward to determine the importance of each variable.

6.1.1 Decision tree structure

Another method for interpreting a model is to examine its decision tree structure. Of course, a single decision tree is far more accessible than most other machine learning methods. There are several packages in R that can be used to fit and visualise decision trees, such as rpart, and the randomForest package allows us to extract and inspect the individual trees of a fitted forest.

In a random forest setting, the structure of the individual trees can be used to gauge the importance of each variable. One simple measure of importance is the number of times each variable is used for splitting across the trees of the forest.

Example

Here we will use the getTree() function from the randomForest package to examine the structure of the individual trees in a random forest model.

library(randomForest)

# Fit the random forest model
rf_model <- randomForest(mpg ~ .,
                         data = mtcars)

# Extract the decision tree structure of the model
tree <- getTree(rf_model, k = 1, labelVar = TRUE)

# Examine the decision tree structure
tree
##    left daughter right daughter split var split point status prediction
## 1              2              3        vs      0.5000     -3   18.84062
## 2              4              5      drat      4.3250     -3   16.39565
## 3              6              7        wt      2.1275     -3   25.08889
## 4              8              9        hp    192.5000     -3   14.95500
## 5              0              0      <NA>      0.0000     -1   26.00000
## 6              0              0      <NA>      0.0000     -1   29.50000
## 7             10             11        hp     77.5000     -3   22.88333
## 8             12             13      carb      4.5000     -3   17.27000
## 9             14             15        hp    230.0000     -3   12.64000
## 10             0              0      <NA>      0.0000     -1   24.40000
## 11             0              0      <NA>      0.0000     -1   22.12500
## 12            16             17      qsec     17.1600     -3   17.00000
## 13             0              0      <NA>      0.0000     -1   19.70000
## 14             0              0      <NA>      0.0000     -1   10.40000
## 15            18             19      drat      3.9750     -3   14.13333
## 16             0              0      <NA>      0.0000     -1   17.90000
## 17             0              0      <NA>      0.0000     -1   16.28000
## 18             0              0      <NA>      0.0000     -1   13.30000
## 19             0              0      <NA>      0.0000     -1   15.80000
# Extract the importance of each variable
importance <- importance(rf_model)
importance
##      IncNodePurity
## cyl      173.98820
## disp     247.11280
## hp       167.31635
## drat      81.85837
## wt       267.73722
## qsec      37.52913
## vs        32.33025
## am        13.40816
## gear      18.20693
## carb      32.83072

For a random forest model, we can also count how many times variables are used in the decision tree structure.

# Count how many times variables are used in the decision tree structure
variable.names(mtcars)[-1]
##  [1] "cyl"  "disp" "hp"   "drat" "wt"   "qsec" "vs"   "am"   "gear" "carb"
varUsed(rf_model)
##  [1] 298 819 792 505 817 564 128 115 159 321

This might suggest that disp is the most important variable in the model. However, a variable such as am could be highly influential despite being used relatively rarely, if its influence is captured early (near the root) in the decision tree structure.

Instead, we might count the number of times each variable is used for the first split in a decision tree within the random forest:

# Count how many times each variable is used for the first split in a decision tree
first.split <- sapply(1:rf_model$ntree, function(i) {
  rfTree <- getTree(rf_model, k = i, labelVar = TRUE)[1,3]
  variable.names(mtcars)[-1][rfTree]
})

table(first.split)
## first.split
##   am  cyl disp drat   hp qsec   vs   wt 
##    5   91  112   69   87   50   16   70

Now, as the first split tends to be the most important in terms of predicting the outcome, we can see that disp and cyl appear to be the most important variables in the model.

6.1.2 Permutation-based feature importance

One way to determine the importance of each variable in a random forest or gradient boosting machine is to use a permutation-based feature importance method. This method works as follows:

  1. Fit the model to the data.

  2. For each variable in turn,

    • randomly permute the values of that variable in the data;
    • calculate the change in the model’s performance.
  3. The importance of each variable is the average change in the model’s performance across all permutations.

If a variable is important in predicting the outcome, then permuting the values of that variable should result in a large change in the model’s performance. Conversely, if a variable is not important in predicting the outcome, then permuting the values of that variable should result in a small change in the model’s performance.

The algorithm briefly outlined here is a one-at-a-time method. It is possible that such a method may miss important interactions between variables. There are other methods that can be used to determine the importance of interactions between variables.

Example

Here we will look at this strategy for a linear regression model and consider its utility in contrast to the more traditional hypothesis tests. Again, instead of relying on a package, we will implement this method ourselves.

# Fit the model
lm_model <- lm(mpg ~ ., data = mtcars)

# Hypothesis test results for each variable
summary(lm_model)
## 
## Call:
## lm(formula = mpg ~ ., data = mtcars)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -3.4506 -1.6044 -0.1196  1.2193  4.6271 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)  
## (Intercept) 12.30337   18.71788   0.657   0.5181  
## cyl         -0.11144    1.04502  -0.107   0.9161  
## disp         0.01334    0.01786   0.747   0.4635  
## hp          -0.02148    0.02177  -0.987   0.3350  
## drat         0.78711    1.63537   0.481   0.6353  
## wt          -3.71530    1.89441  -1.961   0.0633 .
## qsec         0.82104    0.73084   1.123   0.2739  
## vs           0.31776    2.10451   0.151   0.8814  
## am           2.52023    2.05665   1.225   0.2340  
## gear         0.65541    1.49326   0.439   0.6652  
## carb        -0.19942    0.82875  -0.241   0.8122  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 2.65 on 21 degrees of freedom
## Multiple R-squared:  0.869,  Adjusted R-squared:  0.8066 
## F-statistic: 13.93 on 10 and 21 DF,  p-value: 3.793e-07
# Calculate the performance of the model
performance <- summary(lm_model)$r.squared

# Permute the values of each explanatory variable in turn
permuted_performance <- sapply(names(mtcars)[-1], function(var) {
  # Permute the values of the variable
  permuted_data <- mtcars
  permuted_data[[var]] <- sample(permuted_data[[var]])
  
  # Fit the model to the permuted data
  permuted_lm_model <- lm(mpg ~ ., data = permuted_data)
  
  # Calculate the performance of the model
  permuted_performance <- summary(permuted_lm_model)$r.squared
  
  # Return the change in performance
  return(performance - permuted_performance)
})

# Plot the results
barplot(permuted_performance, las = 2)

Figure 6.1: The change in performance of the model when the values of each variable are permuted.

There are, of course, packages available in R that will implement this method for you. One such package is DALEX.

Example

Now, we will utilise the DALEX package to determine the importance of each variable in a random forest model for the Boston housing data (available in the MASS package).

library(DALEX)
## Registered S3 method overwritten by 'DALEX':
##   method            from     
##   print.description questionr
## Welcome to DALEX (version: 2.4.3).
## Find examples and detailed introduction at: http://ema.drwhy.ai/
# Load the Boston housing data from the MASS package
data(Boston, package = "MASS")

# Fit the random forest model
rf_model <- randomForest(medv ~ ., data = Boston)

# Create an explainer object
explainer <- explain(rf_model, data = Boston[-14], y = Boston$medv)
## Preparation of a new explainer is initiated
##   -> model label       :  randomForest  (  default  )
##   -> data              :  506  rows  13  cols 
##   -> target variable   :  506  values 
##   -> predict function  :  yhat.randomForest  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package randomForest , ver. 4.7.1.1 , task regression (  default  ) 
##   -> predicted values  :  numerical, min =  6.763717 , mean =  22.53069 , max =  49.019  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -5.218475 , mean =  0.00211808 , max =  7.493544  
##   A new explainer has been created!
# Calculate the importance of each variable
variable_importance <- variable_importance(explainer)
variable_importance
##        variable mean_dropout_loss        label
## 1  _full_model_          1.365472 randomForest
## 2          chas          1.418405 randomForest
## 3            zn          1.433858 randomForest
## 4           rad          1.501006 randomForest
## 5         black          1.703902 randomForest
## 6           age          1.839409 randomForest
## 7           tax          1.900797 randomForest
## 8         indus          2.034476 randomForest
## 9       ptratio          2.284982 randomForest
## 10         crim          2.307187 randomForest
## 11          nox          2.525805 randomForest
## 12          dis          2.540718 randomForest
## 13           rm          5.093851 randomForest
## 14        lstat          6.428002 randomForest
## 15   _baseline_         12.510559 randomForest

Here, mean_dropout_loss is the model's average loss (by default, root mean squared error for a regression task) after the values of that variable have been permuted; the _full_model_ row gives the loss with no variables permuted. The larger the increase over this full-model loss, the more important the variable is in predicting the outcome.
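
The object returned by variable_importance() also has a plot method, which gives a quick visual summary of these permutation-based importances:

# Plot the permutation-based variable importance
plot(variable_importance)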

6.2 Main effect visualisations

6.2.1 Main effects

A main effect is the effect of a variable on the outcome, while holding all other variables constant. To help us understand the utility of these, let's consider a situation where we know the true model generating the data.

Example

Here, we have one outcome variable and two explanatory variables. The true model is:

\[ Y = X_1 + X_2^2. \] We will have strongly correlated explanatory variables:

\[\begin{align*} X_1 &\sim \text{Uni}(0, 1), \\ X_2 &= X_1 + \epsilon, \quad \epsilon \sim \text{N}(0, 0.02). \end{align*}\]

Here’s some data for us to work with:
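A minimal simulation consistent with the model above might look like the following (the sample size, the random seed, and reading 0.02 as a standard deviation are assumptions, and the object names are illustrative):

# Simulate the toy data: Y = X1 + X2^2 with strongly correlated X1 and X2
set.seed(42)                                # arbitrary seed, for reproducibility
n <- 500                                    # assumed sample size
x1 <- runif(n, min = 0, max = 1)
x2 <- x1 + rnorm(n, mean = 0, sd = 0.02)    # assuming 0.02 is a standard deviation
y  <- x1 + x2^2
toy_data <- data.frame(y = y, x1 = x1, x2 = x2)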


Figure 6.2: The relationship between the variables in the data.

It is trivial to derive the main effects of this model. For a fixed value \(X_2 = x_2\), the model reduces to \(Y = X_1 + x_2^2\), so the main effect of \(X_1\) is linear with slope one.

Similarly, for a fixed value \(X_1 = x_1\), the model reduces to \(Y = x_1 + X_2^2\), so the main effect of \(X_2\) is quadratic.

Another powerful tool for interpreting the results of a model is the iml package (Molnar, Bischl, and Casalicchio 2018), which provides a unified interface for interpreting the results of a wide variety of models. Several of the methods it implements, notably accumulated local effects plots, stem from Apley (2020).
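
As a starting point for the examples below, here is a minimal sketch of how a decision tree fitted to the toy data could be wrapped in an iml Predictor object, the common interface used by all of the package's interpretation methods (the object names tree_model and predictor are illustrative, and toy_data is the simulated data from above):

library(iml)
library(rpart)

# Fit a decision tree to the simulated toy data
tree_model <- rpart(y ~ x1 + x2, data = toy_data)

# Wrap the model, the explanatory variables and the outcome in a Predictor object
predictor <- Predictor$new(tree_model,
                           data = toy_data[c("x1", "x2")],
                           y = toy_data$y)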

6.2.2 Partial dependence plots

One method for interpreting the results of a model is to use partial dependence plots. A partial dependence plot shows the average relationship between a variable and the model's predictions: the variable of interest is varied over its range while the predictions are averaged over the observed values of all other variables. This allows us to see the overall effect of a variable on the outcome, averaged over the effects of the other variables.

Example

Here, we will use the iml package to create a partial dependence plot for a decision tree model for our toy example.

[Add in example]
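
A minimal sketch of this, assuming the predictor object created above; FeatureEffect with method = "pdp" computes the partial dependence of the prediction on a single variable:

# Partial dependence of the prediction on x1
pdp_x1 <- FeatureEffect$new(predictor, feature = "x1", method = "pdp")
plot(pdp_x1)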

6.2.3 Individual conditional expectation plots

Another method for interpreting the results of a model is to use individual conditional expectation plots. An individual conditional expectation plot shows the relationship between a variable and the outcome for a single observation, while holding all other variables constant. This allows us to see the effect of a variable on the outcome for a single observation, while controlling for the effects of other variables.

Example

Here, we will use the iml package to create an individual conditional expectation plot for a decision tree model for our toy example.

[Add in example]
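
A minimal sketch, again assuming the predictor object from above; the same FeatureEffect interface produces ICE curves, and method = "pdp+ice" overlays the average (partial dependence) curve on the individual curves:

# ICE curves for x1, with the partial dependence curve overlaid
ice_x1 <- FeatureEffect$new(predictor, feature = "x1", method = "pdp+ice")
plot(ice_x1)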

6.2.4 Accumulated local effects

Another method for interpreting the results of a model is to use accumulated local effects (ALE). An ALE plot shows the effect of a variable on the model's predictions by accumulating the local changes in prediction over small intervals of that variable, using only the observations that fall within each interval. This makes ALE plots more reliable than partial dependence plots when the explanatory variables are strongly correlated, as they are in our toy example.

Example

Here, we will use the iml package to create an ALE plot for a decision tree model for our toy example.

[Add in example]
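
A minimal sketch, again assuming the predictor object from above (method = "ale" is the default for FeatureEffect):

# Accumulated local effects of x1 on the prediction
ale_x1 <- FeatureEffect$new(predictor, feature = "x1", method = "ale")
plot(ale_x1)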

6.2.5 Shapley plots

Another method for interpreting the results of a model is to use Shapley plots. Shapley values decompose a single prediction into contributions from each variable, averaging each variable's marginal contribution over the different orders in which the variables could be added. A Shapley plot displays these contributions for one observation, making it a local interpretability method.

Example

Here, we will use the iml package to create a Shapley plot for a decision tree model for our toy example.

[Add in example]
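
A minimal sketch, assuming the predictor object and toy data from above; Shapley$new explains the prediction for a single observation (here, arbitrarily, the first row of the toy data):

# Shapley values for the prediction of the first observation
shap <- Shapley$new(predictor, x.interest = toy_data[1, c("x1", "x2")])
plot(shap)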