Fitting a Linear Simulation with XGBoost

This notebook uses shap to demonstrate how XGBoost behaves when we fit it to simulated data where the label has a linear relationship to the features.

[1]:
import numpy as np
import sklearn
import xgboost
from sklearn.model_selection import train_test_split

import shap
/Users/slundberg/anaconda3/lib/python3.7/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.
  from numpy.core.umath_tests import inner1d

Build a simulated dataset with continuous labels that are a linear function of the features

[2]:
N = 10000
M = 10
np.random.seed(0)
X = np.random.randn(N, M)
beta = np.random.randn(M)
y_margin = X @ beta
y = y_margin

X_train, X_test, y_train, y_test = train_test_split(X, y)
X_strain, X_valid, y_strain, y_valid = train_test_split(X_train, y_train)
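
Since the label is an exact, noise-free linear function of the features, an ordinary least-squares fit should recover `beta` almost perfectly. A quick optional sanity check (not part of the original analysis):

[ ]:
beta_hat, *_ = np.linalg.lstsq(X_train, y_train, rcond=None)
print("max |beta_hat - beta| =", np.abs(beta_hat - beta).max())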

Build an XGBoost regressor

Train a depth 1 model

[3]:
model_depth1 = xgboost.XGBRegressor(
    max_depth=1,
    learning_rate=0.01,
    subsample=0.5,
    n_estimators=10000,
    base_score=y_strain.mean(),
)
model_depth1.fit(
    X_strain,
    y_strain,
    eval_set=[(X_valid, y_valid)],
    eval_metric="logloss",
    verbose=1000,
    early_stopping_rounds=20,
)
[0]     validation_0-logloss:-1.6359
Will train until validation_0-logloss hasn't improved in 20 rounds.
[1000]  validation_0-logloss:-45.6925
Stopping. Best iteration:
[1395]  validation_0-logloss:-47.0788

[3]:
XGBRegressor(base_score=0.013271975120564434, booster='gbtree',
       colsample_bylevel=1, colsample_bytree=1, gamma=0,
       learning_rate=0.01, max_delta_step=0, max_depth=1,
       min_child_weight=1, missing=None, n_estimators=10000, n_jobs=1,
       nthread=None, objective='reg:linear', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=True,
       subsample=0.5)

Train a depth 3 model

[4]:
model_depth3 = xgboost.XGBRegressor(
    learning_rate=0.02,
    subsample=0.2,
    colsample_bytree=0.5,
    n_estimators=5000,
    base_score=y_strain.mean(),
)
model_depth3.fit(
    X_strain,
    y_strain,
    eval_set=[(X_valid, y_valid)],
    eval_metric="logloss",
    verbose=500,
    early_stopping_rounds=20,
)
[0]     validation_0-logloss:-13.3182
Will train until validation_0-logloss hasn't improved in 20 rounds.
[500]   validation_0-logloss:-48.8915
Stopping. Best iteration:
[612]   validation_0-logloss:-49.0554

[4]:
XGBRegressor(base_score=0.013271975120564434, booster='gbtree',
       colsample_bylevel=1, colsample_bytree=0.5, gamma=0,
       learning_rate=0.02, max_delta_step=0, max_depth=3,
       min_child_weight=1, missing=None, n_estimators=5000, n_jobs=1,
       nthread=None, objective='reg:linear', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=True,
       subsample=0.2)

Explain the depth 1 model

[6]:
shap_values = shap.TreeExplainer(model_depth1).shap_values(X_test)
shap_interaction_values = shap.TreeExplainer(model_depth1).shap_interaction_values(
    X_test
)

Summary bar plot shows the global importance of each feature

[7]:
shap.summary_plot(shap_values, X_test, plot_type="bar")
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_12_0.png
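
For an exactly linear model with independent standard normal features, the mean absolute SHAP value of feature j is roughly |beta_j| * sqrt(2/pi), so the bar heights should track |beta|. A minimal sketch of that comparison (using the arrays defined above):

[ ]:
# E|x - mean| for a standard normal feature is sqrt(2/pi) ~= 0.80
expected = np.abs(beta) * np.sqrt(2 / np.pi)
observed = np.abs(shap_values).mean(axis=0)
for j in np.argsort(-expected):
    print(f"feature {j}: expected ~{expected[j]:.3f}, observed {observed[j]:.3f}")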

Summary bee-swarm plot shows the global importance of each feature and the distribution of effect sizes

[8]:
shap.summary_plot(shap_values, X_test)
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_14_0.png

The dependence plot for the top feature shows that XGBoost captured most of the linear relationship

It is important to note that XGBoost (and other gradient boosted tree models) is biased towards flat regions, which can be seen below by the flattened tails of the linear function.

[9]:
shap.dependence_plot(8, shap_values, X_test)
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_16_0.png
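
To make the flattening visible, we can overlay the true linear effect on the dependence plot. A minimal sketch, assuming matplotlib is available and centering the true effect at the training mean of feature 8:

[ ]:
import matplotlib.pyplot as plt

shap.dependence_plot(8, shap_values, X_test, show=False)
xs = np.linspace(X_test[:, 8].min(), X_test[:, 8].max(), 100)
# SHAP values are centered on the expected prediction, so center the true effect as well
plt.plot(xs, beta[8] * (xs - X_strain[:, 8].mean()), "k--", label="true linear effect")
plt.legend()
plt.show()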

As expected, there are no interactions for the depth-1 model:

We can also see this from the lack of any vertical dispersion in the dependence plot above.

[10]:
shap.dependence_plot((8, 1), shap_interaction_values, X_test)
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_18_0.png
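
Since every tree in the depth-1 model splits on a single feature, all off-diagonal SHAP interaction values should be numerically zero. A quick check:

[ ]:
off_diag = shap_interaction_values.copy()
off_diag[:, np.arange(M), np.arange(M)] = 0  # zero out the main-effect diagonal
print("max |off-diagonal interaction value| =", np.abs(off_diag).max())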

The tail flattening behavior is consistent across all other features.

[11]:
shap.dependence_plot(1, shap_values, X_test)
/Users/slundberg/anaconda3/lib/python3.7/site-packages/numpy/lib/function_base.py:2400: RuntimeWarning: invalid value encountered in true_divide
  c /= stddev[:, None]
/Users/slundberg/anaconda3/lib/python3.7/site-packages/numpy/lib/function_base.py:2401: RuntimeWarning: invalid value encountered in true_divide
  c /= stddev[None, :]
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_20_1.png

Note that weaker signals lead to more variability in the fits:

Since XGBoost favors flat regions, this variability often looks like step functions. Remember that in the plot below the SHAP values are correctly telling you what the model learned; the model simply did not learn a smooth line.

[12]:
shap.dependence_plot(4, shap_values, X_test)
/Users/slundberg/anaconda3/lib/python3.7/site-packages/numpy/lib/function_base.py:2400: RuntimeWarning: invalid value encountered in true_divide
  c /= stddev[:, None]
/Users/slundberg/anaconda3/lib/python3.7/site-packages/numpy/lib/function_base.py:2401: RuntimeWarning: invalid value encountered in true_divide
  c /= stddev[None, :]
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_22_1.png
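
Because the data is simulated, we can check this directly: printing the true coefficients sorted by magnitude shows which features carry a weaker signal (an optional check, not in the original analysis):

[ ]:
for j in np.argsort(-np.abs(beta)):
    print(f"feature {j}: |beta| = {np.abs(beta[j]):.3f}")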

Explain the depth 3 model

In this simulation we know that the true relationships are linear, with no interactions. However, when we fit trees of depth greater than 1, we are telling the model to look for interactions. When we explain our depth-3 model, we see that it did learn some weak (and incorrect) interactions.

[13]:
e3 = shap.TreeExplainer(model_depth3)
shap_values3 = e3.shap_values(X_test)
shap_interaction_values3 = shap.TreeExplainer(model_depth3).shap_interaction_values(
    X_test
)
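
One way to quantify the (spurious) interactions the depth-3 model learned is to average the absolute off-diagonal interaction values over the test set. A minimal sketch:

[ ]:
interaction_strength = np.abs(shap_interaction_values3).mean(axis=0)
np.fill_diagonal(interaction_strength, 0)  # keep only the off-diagonal (interaction) terms
i, j = np.unravel_index(interaction_strength.argmax(), interaction_strength.shape)
print(f"strongest learned interaction: features {i} and {j}, "
      f"mean |interaction value| = {interaction_strength[i, j]:.4f}")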

The bar chart of global importance is essentially the same as for the depth-1 model.

[14]:
shap.summary_plot(shap_values3, X_test, plot_type="bar")
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_26_0.png

The bee-swarm summary plots are smoother than with the depth-1 model (see the dependence plots below for why).

[15]:
shap.summary_plot(shap_values3, X_test)
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_28_0.png

The vertical interaction dispersion from the depth-3 trees smooths over any small steps.

This is what made the bee-swarm summary plot look more even. Note also that when we color by feature 1 we seem to see a consistent interaction effect.

[16]:
shap.dependence_plot(8, shap_values3, X_test)
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_30_0.png

If we look more closely and plot the interaction value between features 8 and 1, we see a seemingly clear pattern:

[17]:
shap.dependence_plot((8, 1), shap_interaction_values3, X_test)
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_32_0.png

The same interaction effect is observed on a new data set:

This means the model really did learn this interaction, even though there was no interaction there to learn. How can we protect ourselves from jumping to the conclusion that this is a real interaction (without using the fact that we simulated this data)?

[18]:
X_tmp = np.random.randn(*X_test.shape)
tmp_values = shap.TreeExplainer(model_depth3).shap_interaction_values(X_tmp)
shap.dependence_plot((8, 1), tmp_values, X_tmp)
../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_34_0.png
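
A complementary numeric check (optional) is to compare the strength of the (8, 1) interaction on the original test set with the fresh random draw; similar magnitudes indicate the model itself encodes the interaction, rather than the plot reflecting noise in one particular sample:

[ ]:
print("mean |interaction(8, 1)| on X_test:", np.abs(shap_interaction_values3[:, 8, 1]).mean())
print("mean |interaction(8, 1)| on X_tmp: ", np.abs(tmp_values[:, 8, 1]).mean())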

The same interaction for retrained models.

The structure (what little there was) seems to go away when we retrain the model on bootstrap resamples.

[19]:
for i in range(5):
    print(i)
    X_strain_tmp, y_strain_tmp = sklearn.utils.resample(X_strain, y_strain)
    X_valid_tmp, y_valid_tmp = sklearn.utils.resample(X_valid, y_valid)
    X_test_tmp, y_test_tmp = sklearn.utils.resample(X_test, y_test)

    model_tmp = xgboost.XGBRegressor(
        learning_rate=0.01, subsample=0.5, n_estimators=5000, base_score=y_strain.mean()
    )
    model_tmp.fit(
        X_strain_tmp,
        y_strain_tmp,
        eval_set=[(X_valid_tmp, y_valid_tmp)],
        eval_metric="logloss",
        verbose=500,
        early_stopping_rounds=20,
    )

    tmp_values = shap.TreeExplainer(model_tmp).shap_interaction_values(X_test_tmp)
    shap.dependence_plot((8, 1), tmp_values, X_test_tmp)
0
[0]     validation_0-logloss:-15.2753
Will train until validation_0-logloss hasn't improved in 20 rounds.
[500]   validation_0-logloss:-46.3412
Stopping. Best iteration:
[532]   validation_0-logloss:-46.6178

../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_36_1.png
1
[0]     validation_0-logloss:-10.6659
Will train until validation_0-logloss hasn't improved in 20 rounds.
[500]   validation_0-logloss:-48.2467
[1000]  validation_0-logloss:-49.8021
Stopping. Best iteration:
[984]   validation_0-logloss:-49.8074

../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_36_3.png
2
[0]     validation_0-logloss:-14.5348
Will train until validation_0-logloss hasn't improved in 20 rounds.
[500]   validation_0-logloss:-46.558
[1000]  validation_0-logloss:-48.0845
Stopping. Best iteration:
[1310]  validation_0-logloss:-48.3214

../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_36_5.png
3
[0]     validation_0-logloss:-15.4186
Will train until validation_0-logloss hasn't improved in 20 rounds.
[500]   validation_0-logloss:-47.0986
Stopping. Best iteration:
[946]   validation_0-logloss:-48.5674

../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_36_7.png
4
[0]     validation_0-logloss:-13.7396
Will train until validation_0-logloss hasn't improved in 20 rounds.
[500]   validation_0-logloss:-47.5456
Stopping. Best iteration:
[755]   validation_0-logloss:-48.7873

../../../_images/example_notebooks_tabular_examples_tree_based_models_Fitting_a_Linear_Simulation_with_XGBoost_36_9.png

Visualize a single prediction

[31]:
shap.initjs()
[32]:
e3 = shap.TreeExplainer(model_depth3)
shap_values3 = e3.shap_values(X_test)
[33]:
shap.force_plot(e3.expected_value, shap_values3[0, :], X_test[0, :])
[33]:
Visualization omitted: the interactive JavaScript force plot does not render in this static documentation build; run the notebook live (after `shap.initjs()`) to display it.
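
If the JavaScript widget cannot render (as in this static build), `shap.force_plot` can draw a single explanation with matplotlib instead by passing `matplotlib=True` (this option works for one sample at a time, not for the stacked plot below):

[ ]:
shap.force_plot(e3.expected_value, shap_values3[0, :], X_test[0, :], matplotlib=True)
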
[34]:
shap.force_plot(e3.expected_value, shap_values3[0:500, :], X_test[0:500, :])
[34]:
Visualization omitted: the interactive JavaScript force plot does not render in this static documentation build; run the notebook live (after `shap.initjs()`) to display it.
[ ]: