API Reference

This page contains the API reference for public objects and functions in SHAP. There are also example notebooks available that demonstrate how to use the API of each object/function.

Explanation

shap.Explanation(values[, base_values, ...])

A sliceable set of parallel arrays representing a SHAP explanation.
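The sketch below shows what "parallel arrays" means using a hypothetical `MiniExplanation` dataclass in plain Python. It is not SHAP's implementation and models only three arrays (`values`, `base_values`, `data`). Slicing one row slices every array the same way. The `outputs` helper checks the local-accuracy property of SHAP values: the base value plus the sum of the attributions equals the model output.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MiniExplanation:
    """Hypothetical stand-in for an explanation object: three row-aligned arrays."""
    values: List[List[float]]   # one row of feature attributions per sample
    base_values: List[float]    # one expected model output per sample
    data: List[List[float]]     # the original feature values per sample

    def __getitem__(self, row: int) -> "MiniExplanation":
        # Selecting a sample slices every parallel array identically
        return MiniExplanation([self.values[row]], [self.base_values[row]], [self.data[row]])

    def outputs(self) -> List[float]:
        # Local accuracy: base value + sum of attributions = model output
        return [b + sum(v) for b, v in zip(self.base_values, self.values)]


exp = MiniExplanation(
    values=[[0.5, -0.25], [1.0, 0.25]],
    base_values=[2.0, 2.0],
    data=[[1.0, 3.0], [2.0, 4.0]],
)
print(exp[1].outputs())  # [3.25]
```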

explainers

shap.Explainer(model[, masker, link, ...])

Uses Shapley values to explain any machine learning model or Python function.

shap.TreeExplainer(model[, data, ...])

Uses Tree SHAP algorithms to explain the output of ensemble tree models.

shap.GPUTreeExplainer(model[, data, ...])

Experimental GPU accelerated version of TreeExplainer.

shap.LinearExplainer(model, masker[, link, ...])

Computes SHAP values for a linear model, optionally accounting for inter-feature correlations.
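For a linear model with independent features, the SHAP value of feature i reduces to w_i * (x_i - E[x_i]). The base value is the model output at the feature means. The sketch below computes that closed form in plain Python. `linear_shap` is a hypothetical helper, and it does not implement the correlation-aware option mentioned above.

```python
from typing import List, Tuple


def linear_shap(weights: List[float], intercept: float, x: List[float],
                background: List[List[float]]) -> Tuple[List[float], float]:
    """SHAP values for f(x) = intercept + sum_i w_i * x_i, assuming independent features."""
    n = len(background)
    means = [sum(row[j] for row in background) / n for j in range(len(weights))]
    # Each attribution is the weight times the feature's deviation from its background mean
    phi = [w * (xi - mu) for w, xi, mu in zip(weights, x, means)]
    # The base value is the model output evaluated at the background means
    base = intercept + sum(w * mu for w, mu in zip(weights, means))
    return phi, base


phi, base = linear_shap([2.0, -1.0], 0.5, x=[3.0, 0.0], background=[[1.0, 0.0], [3.0, 2.0]])
print(phi, base)  # [2.0, 1.0] 3.5
```

As a check, base + sum(phi) = 6.5, which equals the model output 0.5 + 2.0 * 3.0 - 1.0 * 0.0.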

shap.PermutationExplainer(model, masker[, ...])

This method approximates the Shapley values by iterating through permutations of the inputs.
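Below is a minimal plain-Python sketch of permutation sampling. It assumes features are "masked" by setting them to a single baseline vector, which is a simplification: SHAP maskers usually integrate over a background dataset. The sketch walks each random ordering forward and then in reverse (antithetic sampling), and credits each feature with the change in output when that feature is switched to its actual value. `permutation_shap` and the toy model are hypothetical.

```python
import random
from typing import Callable, List, Sequence


def permutation_shap(f: Callable[[Sequence[float]], float], x: List[float],
                     baseline: List[float], n_perms: int = 50, seed: int = 0) -> List[float]:
    """Estimate Shapley values by averaging marginal contributions over random orderings."""
    rng = random.Random(seed)
    m = len(x)
    phi = [0.0] * m
    for _ in range(n_perms):
        order = list(range(m))
        rng.shuffle(order)
        for seq in (order, order[::-1]):  # antithetic pair: forward, then reversed
            z = list(baseline)            # start with every feature at its baseline value
            prev = f(z)
            for i in seq:                 # switch features to their x values one at a time
                z[i] = x[i]
                cur = f(z)
                phi[i] += cur - prev      # marginal contribution of feature i in this ordering
                prev = cur
    return [p / (2 * n_perms) for p in phi]


f = lambda z: z[0] * z[1] + z[2]
print(permutation_shap(f, [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]))  # [1.0, 1.0, 3.0]
```

Each ordering telescopes from f(baseline) to f(x), so the estimates always sum to f(x) - f(baseline). This toy model has only a pairwise interaction, so the forward/reverse pairing happens to make the estimate exact. For general models the result is only an approximation.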

shap.PartitionExplainer(model, masker, *[, ...])

Uses the Partition SHAP method to explain the output of any function.

shap.SamplingExplainer(model, data, **kwargs)

Computes SHAP values using an extension of the Shapley sampling values explanation method (also known as IME).

shap.AdditiveExplainer(model, masker[, ...])

Computes SHAP values for generalized additive models.

shap.DeepExplainer(model, data[, session, ...])

Approximates SHAP values for deep learning models.

shap.KernelExplainer(model, data[, ...])

Uses the Kernel SHAP method to explain the output of any function.
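Kernel SHAP fits a weighted linear regression over sampled feature coalitions. A coalition of size s out of M features receives the Shapley kernel weight (M - 1) / (C(M, s) * s * (M - s)). The sketch below evaluates only that weight; the regression step is omitted, and `shapley_kernel_weight` is a hypothetical name.

```python
from math import comb, inf


def shapley_kernel_weight(M: int, s: int) -> float:
    """Shapley kernel weight for a coalition of s present features out of M."""
    if s == 0 or s == M:
        # Empty and full coalitions get infinite weight; in practice they act as hard constraints
        return inf
    return (M - 1) / (comb(M, s) * s * (M - s))


print([shapley_kernel_weight(4, s) for s in range(5)])  # [inf, 0.25, 0.125, 0.25, inf]
```

Among the finite weights, very small and very large coalitions receive the most weight, because they reveal the most about individual features.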

shap.GradientExplainer(model, data[, ...])

Explains a model using expected gradients (an extension of integrated gradients).
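Integrated gradients attributes (x_i - b_i) times the average partial derivative along the straight path from a baseline b to the input x. Expected gradients averages this over baselines drawn from background data. The expected gradients method samples both the baseline and the interpolation point at random. This sketch replaces that sampling with a deterministic average over a fixed list of baselines and a midpoint Riemann sum. It uses a toy model with a hand-written gradient; `grad_f` and `expected_gradients` are hypothetical names.

```python
from typing import Callable, List


def grad_f(z: List[float]) -> List[float]:
    # Analytic gradient of the toy model f(z) = z[0]**2 + 3*z[1]
    return [2.0 * z[0], 3.0]


def expected_gradients(grad: Callable[[List[float]], List[float]], x: List[float],
                       baselines: List[List[float]], steps: int = 100) -> List[float]:
    """Integrated gradients averaged over several baselines (midpoint Riemann sum)."""
    m = len(x)
    attr = [0.0] * m
    for b in baselines:
        for k in range(steps):
            alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
            point = [bi + alpha * (xi - bi) for bi, xi in zip(b, x)]
            g = grad(point)
            for i in range(m):
                # (x_i - b_i) times the path-averaged partial derivative
                attr[i] += (x[i] - b[i]) * g[i] / steps
    return [v / len(baselines) for v in attr]


attr = expected_gradients(grad_f, x=[2.0, 1.0], baselines=[[0.0, 0.0], [1.0, 1.0]])
print(attr)  # approximately [3.5, 1.5]
```

The attributions sum to f(x) minus the average baseline output: 7 - (0 + 4) / 2 = 5.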

shap.ExactExplainer(model, masker[, link, ...])

Computes SHAP values via an optimized exact enumeration.
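The following is a naive reference version of exact enumeration, assuming a single baseline vector for masked features. For each feature, it sums weighted marginal contributions over all coalitions S of the other features, using the Shapley weight |S|! (M - |S| - 1)! / M!. This costs on the order of M * 2^(M-1) model evaluations, so it is only practical for a handful of features. SHAP's optimized enumeration is not reproduced here, and `exact_shap` is a hypothetical name.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence, Set


def exact_shap(f: Callable[[Sequence[float]], float], x: List[float],
               baseline: List[float]) -> List[float]:
    """Exact Shapley values by enumerating every coalition of the other features."""
    m = len(x)

    def value(coalition: Set[int]) -> float:
        # Features in the coalition take their values from x; the rest stay at the baseline
        z = [x[j] if j in coalition else baseline[j] for j in range(m)]
        return f(z)

    phi = [0.0] * m
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            # Shapley weight |S|! (M - |S| - 1)! / M! shared by all coalitions of this size
            weight = factorial(size) * factorial(m - size - 1) / factorial(m)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


f = lambda z: z[0] * z[1] + z[2]
print(exact_shap(f, [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]))  # approximately [1.0, 1.0, 3.0]
```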

shap.explainers.other.Coefficient(model)

Simply returns the model coefficients as the feature attributions.

shap.explainers.other.Random(model, masker)

Simply returns random (normally distributed) feature attributions.

shap.explainers.other.LimeTabular(model, data)

Simply wraps lime.lime_tabular.LimeTabularExplainer into the common SHAP interface.

shap.explainers.other.Maple(model, data)

Simply wraps MAPLE into the common SHAP interface.

shap.explainers.other.TreeMaple(model, data)

Simply wraps tree MAPLE into the common SHAP interface.

shap.explainers.other.TreeGain(model)

Simply returns the global gain/gini feature importances for tree models.

plots

shap.plots.bar(shap_values[, max_display, ...])

Create a bar plot of a set of SHAP values.
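When given many rows of SHAP values, a bar plot is typically used as a global importance chart. Our understanding is that the default importance is each feature's mean absolute SHAP value. The sketch below computes that ranking in plain Python without plotting. The helper name, feature names, and values are made up for illustration.

```python
from typing import Dict, List


def mean_abs_importance(values: List[List[float]], names: List[str]) -> Dict[str, float]:
    """Mean absolute SHAP value per feature, ordered from most to least important."""
    n = len(values)
    scores = {name: sum(abs(row[j]) for row in values) / n for j, name in enumerate(names)}
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


shap_rows = [[0.5, -2.0, 0.0], [-1.5, 1.0, 0.25]]
print(mean_abs_importance(shap_rows, ["age", "income", "tenure"]))
# {'income': 1.5, 'age': 1.0, 'tenure': 0.125}
```

Taking absolute values first means a feature that pushes predictions strongly in both directions still ranks as important.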

shap.plots.waterfall(shap_values[, ...])

Plots an explanation of a single prediction as a waterfall plot.
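A waterfall plot starts at the base value (the expected model output) and adds one feature's SHAP value at a time until it reaches the model output for the prediction being explained. The sketch below computes those running totals without plotting. Ordering the bars by absolute SHAP value is an assumption about presentation, and `waterfall_steps` is a hypothetical helper.

```python
from typing import List, Tuple


def waterfall_steps(base_value: float, values: List[float],
                    names: List[str]) -> List[Tuple[str, float, float]]:
    """Running totals from the base value to the model output, largest effects first."""
    order = sorted(range(len(values)), key=lambda i: abs(values[i]), reverse=True)
    steps, total = [], base_value
    for i in order:
        start = total
        total += values[i]
        steps.append((names[i], start, total))  # each bar spans start -> total
    return steps


for name, start, end in waterfall_steps(0.25, [0.5, -1.0, 0.125], ["a", "b", "c"]):
    print(f"{name}: {start} -> {end}")
# b: 0.25 -> -0.75
# a: -0.75 -> -0.25
# c: -0.25 -> -0.125
```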

shap.plots.scatter(shap_values[, color, ...])

Create a SHAP dependence scatter plot, colored by an interaction feature.

shap.plots.heatmap(shap_values[, ...])

Create a heatmap plot of a set of SHAP values.

shap.plots.force(base_value[, shap_values, ...])

Visualize the given SHAP values with an additive force layout.

shap.plots.text(shap_values[, ...])

Plots an explanation of a string of text using coloring and interactive labels.

shap.plots.image(shap_values[, ...])

Plots SHAP values for image inputs.

shap.plots.partial_dependence(ind, model, data)

A basic partial dependence plot function.

shap.plots.decision(base_value, shap_values)

Visualize model decisions using cumulative SHAP values.

shap.plots.embedding(ind, shap_values[, ...])

Use the SHAP values as an embedding which we project to 2D for visualization.

shap.plots.initjs()

Initialize the necessary JavaScript libraries for interactive force plots.

shap.plots.group_difference(shap_values, ...)

This plots the difference in mean SHAP values between two groups.
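The snippet below sketches the quantity being plotted: for each feature, the mean SHAP value in one group minus the mean in the other. The sign convention (True group minus False group) and the helper name are assumptions and may not match what `group_difference` uses.

```python
from typing import List


def group_mean_difference(values: List[List[float]], in_group: List[bool]) -> List[float]:
    """Per-feature mean SHAP value of the True group minus that of the False group."""
    a = [row for row, g in zip(values, in_group) if g]
    b = [row for row, g in zip(values, in_group) if not g]

    def col_mean(rows: List[List[float]], j: int) -> float:
        return sum(r[j] for r in rows) / len(rows)

    return [col_mean(a, j) - col_mean(b, j) for j in range(len(values[0]))]


vals = [[1.0, 0.0], [3.0, 1.0], [0.0, 2.0], [1.0, 4.0]]
print(group_mean_difference(vals, [True, True, False, False]))  # [1.5, -2.5]
```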

shap.plots.image_to_text(shap_values)

Plots SHAP values for image inputs with text outputs.

shap.plots.monitoring(ind, shap_values, features)

Create a SHAP monitoring plot.

shap.plots.beeswarm(shap_values[, ...])

Create a SHAP beeswarm plot, colored by feature values when they are provided.

shap.plots.violin(shap_values[, features, ...])

Create a SHAP violin plot, colored by feature values when they are provided.

maskers

shap.maskers.Masker()

This is the superclass of all maskers.

shap.maskers.Independent(data[, max_samples])

This masks out tabular features by integrating over the given background dataset.
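The following sketch illustrates integrating over a background dataset. For a given mask, kept features take their values from the input being explained, masked features take their values from each background row in turn, and the model outputs are averaged. This is a plain-Python illustration with a hypothetical `masked_expectation` helper, not the masker's actual interface.

```python
from typing import Callable, List, Sequence


def masked_expectation(f: Callable[[Sequence[float]], float], x: List[float],
                       mask: List[bool], background: List[List[float]]) -> float:
    """Average model output with kept features from x and masked ones from background rows."""
    total = 0.0
    for row in background:
        # True in the mask keeps the input's value; False substitutes the background value
        z = [xi if keep else bi for xi, bi, keep in zip(x, row, mask)]
        total += f(z)
    return total / len(background)


f = lambda z: 2.0 * z[0] + z[1]
bg = [[0.0, 1.0], [2.0, 3.0]]
print(masked_expectation(f, x=[5.0, 5.0], mask=[True, False], background=bg))  # 12.0
```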

shap.maskers.Partition(data[, max_samples, ...])

This masks out tabular features by integrating over the given background dataset, while respecting a hierarchical clustering of the features.

shap.maskers.Impute(data[, method])

This imputes the values of missing features using the values of the observed features.

shap.maskers.Fixed()

This leaves the input unchanged during masking, and is used for things like scoring labels.

shap.maskers.Composite(*maskers)

This merges several maskers for different inputs together into a single composite masker.

shap.maskers.FixedComposite(masker)

A masker that outputs both the masked data and the original data as a pair.

shap.maskers.OutputComposite(masker, model)

A masker that combines a masker with a model and outputs both the masked arguments and the model's output.

shap.maskers.Text([tokenizer, mask_token, ...])

This masks out tokens according to the given tokenizer.

shap.maskers.Image(mask_value[, shape])

Masks out image regions with blurring or inpainting.

models

shap.models.Model([model])

This is the superclass of all models.

shap.models.TeacherForcing(model[, ...])

Generates scores (log odds) for output text explanation algorithms using the teacher forcing technique.

shap.models.TextGeneration([model, ...])

Generates target sentence/ids using a base model.

shap.models.TopKLM(model, tokenizer[, k, ...])

Generates scores (log odds) for the top-k tokens for Causal/Masked LM.

shap.models.TransformersPipeline(pipeline[, ...])

This wraps a transformers pipeline object for easy explanations.

utils

shap.utils.hclust(X[, y, linkage, metric, ...])

Fit a hierarchical clustering model for features X relative to target variable y.

shap.utils.hclust_ordering(X[, metric, ...])

A leaf ordering is under-defined; this picks the ordering that keeps nearby samples similar.

shap.utils.partition_tree(X[, metric])

shap.utils.partition_tree_shuffle(indexes, ...)

Randomly shuffle the indexes in a way that is consistent with the given partition tree.
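One way to picture a shuffle that is "consistent with the partition tree" is to randomize only the order of sibling clusters, so the members of every cluster stay next to each other. For readability, the sketch represents the tree as nested tuples of feature indexes. This is a hypothetical format, not the clustering-matrix form used elsewhere in this API, and `tree_consistent_shuffle` is a hypothetical name.

```python
import random
from typing import List, Union

# A leaf is a feature index; an internal node is a tuple of subtrees (hypothetical format)
Tree = Union[int, tuple]


def tree_consistent_shuffle(tree: Tree, rng: random.Random) -> List[int]:
    """Random ordering of the leaves in which every cluster stays contiguous."""
    if isinstance(tree, int):
        return [tree]
    children = list(tree)
    rng.shuffle(children)  # randomize the order of sibling clusters only
    order: List[int] = []
    for child in children:
        order.extend(tree_consistent_shuffle(child, rng))
    return order


tree = ((0, 1), (2, (3, 4)))
print(tree_consistent_shuffle(tree, random.Random(0)))
```

An ordering such as [3, 0, 1, 2, 4] can never occur here, because it would split the cluster (2, (3, 4)).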

shap.utils.delta_minimization_order(all_masks)

shap.utils.approximate_interactions(index, ...)

Order other features by how much interaction they seem to have with the feature at the given index.

shap.utils.potential_interactions(...)

Order other features by how much interaction they seem to have with the feature at the given index.

shap.utils.sample(X[, nsamples, random_state])

Performs sampling without replacement of the input data X.

shap.utils.shapley_coefficients(n)

shap.utils.convert_name(ind, shap_values, ...)

shap.utils.OpChain([root_name])

A way to represent a set of dot chained operations on an object without actually running them.
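The idea of recording dot-chained operations and running them later can be sketched with `__getattr__` and `__call__`. `LazyChain` below is a hypothetical class, not SHAP's `OpChain`. Each attribute access or call returns a new chain with one more recorded step, and `apply` replays the steps on a concrete object.

```python
from typing import Any, Tuple


class LazyChain:
    """Hypothetical sketch: record attribute accesses and calls, then replay them on an object."""

    def __init__(self, ops: Tuple = ()):
        self._ops = ops  # recorded ("attr", name) and ("call", args, kwargs) steps

    def __getattr__(self, name: str) -> "LazyChain":
        # Only reached for names not defined on the class, e.g. .strip or .upper
        return LazyChain(self._ops + (("attr", name),))

    def __call__(self, *args: Any, **kwargs: Any) -> "LazyChain":
        return LazyChain(self._ops + (("call", args, kwargs),))

    def apply(self, obj: Any) -> Any:
        # Replay the recorded steps, in order, against a real object
        for op in self._ops:
            if op[0] == "attr":
                obj = getattr(obj, op[1])
            else:
                obj = obj(*op[1], **op[2])
        return obj


chain = LazyChain().strip().upper().replace("A", "4")
print(chain.apply("  banana "))  # B4N4N4
```

Because every step returns a new object, a partially built chain can be reused as the prefix of several longer chains.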

shap.utils.show_progress(iterable[, total, ...])

shap.utils.MaskedModel(model, masker, link, ...)

This is a utility class that combines a model, a masker object, and a current input.

shap.utils.make_masks(cluster_matrix)

Builds a sparse CSR mask matrix from the given clustering.

datasets

shap.datasets.a1a([n_points])

A sparse dataset in scipy csr matrix format.

shap.datasets.adult([display, n_points])

Return the Adult census data in a nice package.

shap.datasets.california([display, n_points])

Return the California housing data in a nice package.

shap.datasets.communitiesandcrime([display, ...])

Predict the total number of non-violent crimes per 100K population.

shap.datasets.corrgroups60([display, n_points])

Correlated Groups 60: a simulated dataset with tight correlations among distinct groups of features.

shap.datasets.diabetes([display, n_points])

Return the diabetes data in a nice package.

shap.datasets.imagenet50([display, ...])

This is a set of 50 images representative of ImageNet images.

shap.datasets.imdb([display, n_points])

Return the classic IMDB sentiment analysis training data in a nice package.

shap.datasets.independentlinear60([display, ...])

A simulated dataset of 60 independent features with a linear target.

shap.datasets.iris([display, n_points])

Return the classic iris data in a nice package.

shap.datasets.linnerud([display, n_points])

Return the linnerud data in a nice package (multi-target regression).

shap.datasets.nhanesi([display, n_points])

A nicely packaged version of NHANES I data with survival times as labels.

shap.datasets.rank()

Ranking datasets from the LightGBM repository.