scatter plot

This notebook is designed to demonstrate (and so document) how to use the shap.plots.scatter function. It uses an XGBoost model trained on the classic UCI adult income dataset (which is a classification task to predict if people made over \$50k in the 90s).

import xgboost
import shap

# train XGBoost model
X, y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)

# compute SHAP values
explainer = shap.Explainer(model, X)
shap_values = explainer(X[:1000])

Simple dependence scatter plot

A dependence scatter plot shows the effect a single feature has on the predictions made by the model. In this example the log-odds of making over 50k increases significantly between age 20 and 40.

  • Each dot is a single prediction (row) from the dataset.

  • The x-axis is the value of the feature (from the X matrix, stored in shap_values.data).

  • The y-axis is the SHAP value for that feature (stored in shap_values.values), which represents how much knowing that feature’s value changes the output of the model for that sample’s prediction. For this model the units are log-odds of making over 50k annually.

  • The light grey area at the bottom of the plot is a histogram showing the distribution of data values.
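
Since the model's output is in log-odds units, it can help to see how a shift in log-odds moves the predicted probability. This is just standard logistic arithmetic, not part of the notebook's workflow; the numbers below are illustrative:

```python
import numpy as np

def sigmoid(log_odds):
    """Convert log-odds to a probability."""
    return 1.0 / (1.0 + np.exp(-log_odds))

# starting from a hypothetical base log-odds of -1 (about a 27% chance of
# making over $50k), a SHAP value of +1 shifts the prediction to 50%
base = -1.0
print(sigmoid(base))        # ~0.269
print(sigmoid(base + 1.0))  # 0.5
```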

# Note that we are slicing off the column of the shap_values Explanation
# corresponding to the "Age" feature
shap.plots.scatter(shap_values[:, "Age"])

Using color to highlight interaction effects

The vertical dispersion in the plot above shows that the same value for the Age feature can have a different impact on the model’s output for different people. This means there are non-linear interaction effects in the model between Age and other features (otherwise the scatter plot would perfectly follow the line given by shap.plots.partial_dependence).

To show which feature may be driving these interaction effects we can color our Age dependence scatter plot by another feature. If we pass the entire Explanation object to the color parameter then the scatter plot attempts to pick out the feature column with the strongest interaction with Age. If an interaction effect is present between this other feature and the feature we are plotting it will show up as a distinct vertical pattern of coloring. For the example below, 20-year-olds with a high level of education are less likely to make over \$50k than 20-year-olds with a low level of education. This suggests an interaction effect between Education-Num and Age.

shap.plots.scatter(shap_values[:, "Age"], color=shap_values)

To explicitly control which feature is used for coloring you can pass a specific feature column to the color parameter.

shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:,"Workclass"])

In the plot above we see that the Workclass feature is encoded with a number for the sake of the XGBoost model. When plotting, though, we would often rather use the original string values before they were categorically encoded. To do this we can set the .display_data property of the Explanation object to a parallel version of the data we would like displayed in plots.

X_display, y_display = shap.datasets.adult(display=True)
shap_values.display_data = X_display.values

shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:,"Workclass"])

Using global feature importance orderings

Sometimes we don’t know the name or index of the feature we want to plot; we just want to plot the most important features. To do that we can use the dot-chaining capability of the Explanation object: compute a measure of global feature importance, argsort the features by that measure (ascending), and take the last index to get the most important feature (which in this case is Age):

shap.plots.scatter(shap_values[:, shap_values.abs.mean(0).argsort[-1]])

Note that how we choose to measure the global importance of a feature will impact the ranking we get. In this example Age is the feature with the largest mean absolute SHAP value over the whole dataset, but Capital gain is the feature with the largest absolute impact on any single sample.

shap.plots.scatter(shap_values[:, shap_values.abs.max(0).argsort[-1]])
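
The difference between these two rankings can be sketched with plain NumPy on a made-up SHAP matrix (illustrative values only): a feature with moderate effects everywhere wins on mean |SHAP|, while a feature that is usually irrelevant but occasionally large wins on max |SHAP|.

```python
import numpy as np

# hypothetical SHAP values: rows are samples, columns are two features
sv = np.array([
    [0.5, 0.0],
    [0.6, 0.0],
    [0.4, 1.5],   # feature 1 is usually irrelevant but occasionally large
    [0.5, 0.0],
])

# argsort is ascending, so the last entry is the top-ranked feature
mean_rank = np.abs(sv).mean(0).argsort()
max_rank = np.abs(sv).max(0).argsort()

print(mean_rank[-1])  # 0: feature 0 has the larger mean |SHAP| (0.5 vs 0.375)
print(max_rank[-1])   # 1: feature 1 has the larger max |SHAP| (1.5 vs 0.6)
```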

The max function is potentially sensitive to outliers. A more robust option is the percentile function. Here we sort the features by their 95th percentile absolute value and find that Capital gain has the largest 95th percentile value:

shap.plots.scatter(shap_values[:, shap_values.abs.percentile(95, 0).argsort[-1]])
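
Why a high percentile is more robust than the max can be seen with plain NumPy on a made-up column of |SHAP| values (illustrative numbers only): a single extreme outlier dominates the max but barely moves the 95th percentile.

```python
import numpy as np

# hypothetical |SHAP| values for one feature across 100 samples:
# uniformly small, except for one extreme outlier
vals = np.full(100, 0.1)
vals[0] = 50.0

print(vals.max())               # 50.0: dominated by the single outlier
print(np.percentile(vals, 95))  # 0.1: the outlier barely matters
```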

Exploring different interaction colorings

# we can use shap.utils.potential_interactions to guess which features
# may interact with Age
inds = shap.utils.potential_interactions(shap_values[:, "Age"], shap_values)

# make plots colored by each of the top three possible interacting features
for i in range(3):
    shap.plots.scatter(shap_values[:,"Age"], color=shap_values[:,inds[i]])

Customizing the figure properties

# by passing show=False you can prevent shap.plots.scatter from calling
# the matplotlib show() function, and so you can keep customizing the plot
# before eventually calling show yourself
import matplotlib.pyplot as plt
shap.plots.scatter(shap_values[:,"Age"], show=False)
plt.title("Age dependence plot")
plt.ylabel("SHAP value for the 'Age' feature")
# plt.savefig("my_dependence_plot.pdf") # we can save a PDF of the figure if we want

# you can use xmax and xmin with percentile notation to hide outliers.
# note that the .percentile method applies to both the .values and .data properties
# of the Explanation object, and the scatter plot knows to use the .data property
# when an Explanation is passed to the xmin or xmax arguments.
age = shap_values[:,"Age"]
shap.plots.scatter(age, xmin=age.percentile(1), xmax=age.percentile(99))

# you can use ymax and ymin with percentile notation to hide vertical outliers.
# note that the scatter plot uses the .values property for ymin and ymax when
# an Explanation object is passed in those parameters.
age = shap_values[:,"Age"]
shap.plots.scatter(age, ymin=age.percentile(1), ymax=age.percentile(99))

# transparency can help reveal dense vs. sparse areas of the scatter plot
shap.plots.scatter(shap_values[:,"Age"], alpha=0.1)

# the dot_size parameter controls the size of each dot
shap.plots.scatter(shap_values[:,"Age"], dot_size=2, color=shap_values)

# for categorical (or binned) data adding a small amount of x-jitter makes
# thin columns of dots more readable
shap.plots.scatter(shap_values[:,"Age"], dot_size=2, x_jitter=1, color=shap_values)
shap.plots.scatter(shap_values[:,"Age"], dot_size=4, x_jitter=1, color=shap_values, xmin=20, xmax=60, ymin=-1, ymax=2)

# x-jitter is especially helpful for categorical features like Relationship
shap.plots.scatter(shap_values[:,"Relationship"], dot_size=2, x_jitter=0.5, color=shap_values)
import matplotlib.pyplot as plt

# you can use the cmap parameter to provide your own custom color map
shap.plots.scatter(shap_values[:,"Age"], color=shap_values, cmap=plt.get_cmap("cool"))

Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!