Python Version of Tree SHAP
This is a sample implementation of Tree SHAP written in Python for easy reading.
[1]:
import time
import numba
import numpy as np
import pandas as pd
import sklearn.ensemble
import xgboost
import shap
c:\programming\shap\.venv\Lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Load California dataset
[2]:
# Load 1000 points of the California dataset that ships with shap;
# the trailing expression displays the feature matrix shape in the notebook.
X, y = shap.datasets.california(n_points=1000)
X.shape
[2]:
(1000, 8)
Train sklearn random forest
[3]:
# Fit a forest of 1000 trees, each limited to depth 4, on the loaded data.
# NOTE(review): no random_state is set, so the fitted forest (and the SHAP
# values printed further down) will differ from run to run.
model = sklearn.ensemble.RandomForestRegressor(n_estimators=1000, max_depth=4)
model.fit(X, y)
[3]:
RandomForestRegressor(max_depth=4, n_estimators=1000)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Parameters
Train XGBoost model
[4]:
# Train an XGBoost booster for 1000 rounds (depth-4 trees, learning rate 0.01);
# it is used below as the runtime baseline for XGBoost's own Tree SHAP.
bst = xgboost.train({"learning_rate": 0.01, "max_depth": 4}, xgboost.DMatrix(X, label=y), 1000)
Python TreeExplainer
This uses numba to speed things up.
[5]:
class TreeExplainer:
    """Tree SHAP explainer for tree ensembles, written in Python with numba.

    Only scikit-learn ``RandomForestRegressor`` models are recognised. Every
    estimator of the forest is converted to a ``Tree`` and the per-tree
    attributions are averaged over the forest.

    The scratch buffers used by the recursive algorithm are stored on the
    instance and reused by every call.
    """

    def __init__(self, model, **kwargs):
        """Convert ``model`` into ``Tree`` objects and allocate scratch buffers.

        Args:
            model: A fitted scikit-learn ``RandomForestRegressor``.
            **kwargs: Accepted for call compatibility; not used.

        Raises:
            TypeError: If ``model`` is not a recognised model type.
        """
        # The model type is matched by its qualified name. Fail loudly for
        # anything else instead of leaving ``self.trees`` undefined.
        if not str(type(model)).endswith("sklearn.ensemble._forest.RandomForestRegressor'>"):
            raise TypeError(
                f"Unsupported model type: {type(model)}; "
                "only sklearn.ensemble.RandomForestRegressor is supported."
            )

        self.trees = [
            Tree(
                children_left=e.tree_.children_left,
                children_right=e.tree_.children_right,
                # the right child doubles as the default (missing-value) child
                children_default=e.tree_.children_right,
                feature=e.tree_.feature,
                threshold=e.tree_.threshold,
                value=e.tree_.value[:, 0, 0],
                node_sample_weight=e.tree_.weighted_n_node_samples,
            )
            for e in model.estimators_
        ]

        # Preallocate space for the unique path data. Each recursion level of
        # tree_shap_recursive works on a slice that starts ``unique_depth + 1``
        # entries after its parent's, so the total is a triangular number.
        maxd = np.max([t.max_depth for t in self.trees]) + 2
        s = (maxd * (maxd + 1)) // 2
        self.feature_indexes = np.zeros(s, dtype=np.int32)
        self.zero_fractions = np.zeros(s, dtype=np.float64)
        self.one_fractions = np.zeros(s, dtype=np.float64)
        self.pweights = np.zeros(s, dtype=np.float64)

    def shap_values(self, X, **kwargs):
        """Compute SHAP values for one instance or a matrix of instances.

        Args:
            X: A 1-D instance or 2-D matrix of instances; pandas ``Series`` and
                ``DataFrame`` inputs are converted to arrays.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            An array of shape ``(n_features + 1,)`` for 1-D input or
            ``(n_samples, n_features + 1)`` for 2-D input. The last entry of
            each row is the bias term; values are averaged over all trees.

        Raises:
            AssertionError: If ``X`` does not have 1 or 2 dimensions.
        """
        # convert dataframes
        if isinstance(X, (pd.Series, pd.DataFrame)):
            X = X.values
        # float64, C-ordered copy to match the compiled kernel's signature
        X = np.array(X, dtype=np.float64, copy=True, order="C")
        assert isinstance(X, np.ndarray), "Unknown instance type: " + str(type(X))
        assert len(X.shape) == 1 or len(X.shape) == 2, "Instance must have 1 or 2 dimensions!"

        # single instance
        if len(X.shape) == 1:
            phi = np.zeros(X.shape[0] + 1)
            # no feature is ever flagged as missing here
            x_missing = np.zeros(X.shape[0], dtype=bool)
            for t in self.trees:
                self.tree_shap(t, X, x_missing, phi)
            phi /= len(self.trees)
        elif len(X.shape) == 2:
            phi = np.zeros((X.shape[0], X.shape[1] + 1))
            x_missing = np.zeros(X.shape[1], dtype=bool)
            for i in range(X.shape[0]):
                for t in self.trees:
                    # phi[i, :] is a view, so the row is accumulated in place
                    self.tree_shap(t, X[i, :], x_missing, phi[i, :])
            phi /= len(self.trees)
        return phi

    def tree_shap(self, tree, x, x_missing, phi, condition=0, condition_feature=0):
        """Accumulate one tree's SHAP values for ``x`` into ``phi`` in place.

        Args:
            tree: The ``Tree`` to explain.
            x: 1-D float64 feature vector.
            x_missing: 1-D boolean array flagging missing features.
            phi: Output array of length ``len(x) + 1``; the last slot is the bias.
            condition: Passed through to the recursion; when it is 0 the tree's
                root value is added to the bias term.
            condition_feature: Passed through to the recursion.
        """
        # update the bias term, which is the last index in phi
        # (note the paper has this as phi_0 instead of phi_M)
        if condition == 0:
            phi[-1] += tree.values[0]

        # start the recursive algorithm at the root (node 0, depth 0) with a
        # placeholder parent: fractions of 1 and feature index -1
        tree_shap_recursive(
            tree.children_left,
            tree.children_right,
            tree.children_default,
            tree.features,
            tree.thresholds,
            tree.values,
            tree.node_sample_weight,
            x,
            x_missing,
            phi,
            0,
            0,
            self.feature_indexes,
            self.zero_fractions,
            self.one_fractions,
            self.pweights,
            1.0,
            1.0,
            -1,
            condition,
            condition_feature,
            1.0,
        )
[6]:
# grow the decision path by one element carrying the given zero/one fractions
@numba.jit(
    numba.types.void(
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.int32,
        numba.types.float64,
        numba.types.float64,
        numba.types.int32,
    ),
    nopython=True,
    nogil=True,
)
def extend_path(
    feature_indexes,
    zero_fractions,
    one_fractions,
    pweights,
    unique_depth,
    zero_fraction,
    one_fraction,
    feature_index,
):
    """Append one element to the path arrays and update the path weights.

    The new element is written at index ``unique_depth`` of the three path
    arrays; ``pweights[0 .. unique_depth]`` is then updated in place.
    """
    # store the new path element
    feature_indexes[unique_depth] = feature_index
    zero_fractions[unique_depth] = zero_fraction
    one_fractions[unique_depth] = one_fraction

    # the very first element starts with weight 1, later ones with 0
    pweights[unique_depth] = 1 if unique_depth == 0 else 0

    # walk from the back so pweights[k] is still the old value when it is
    # pushed into pweights[k + 1]
    path_length = unique_depth + 1
    for k in range(unique_depth - 1, -1, -1):
        pweights[k + 1] += one_fraction * pweights[k] * (k + 1) / path_length
        pweights[k] = zero_fraction * pweights[k] * (unique_depth - k) / path_length
# remove a previously added element from the decision path
@numba.jit(
    numba.types.void(
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.int32,
        numba.types.int32,
    ),
    nopython=True,
    nogil=True,
)
def unwind_path(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):
    """Undo the extension stored at ``path_index``, in place.

    ``pweights[0 .. unique_depth - 1]`` is recomputed, then the entries after
    ``path_index`` in the three path arrays are shifted down by one slot.
    """
    removed_one = one_fractions[path_index]
    removed_zero = zero_fractions[path_index]
    carry = pweights[unique_depth]
    path_length = unique_depth + 1

    # the one-fraction is fixed for the whole loop, so branch on it once
    if removed_one != 0:
        for k in range(unique_depth - 1, -1, -1):
            previous = pweights[k]
            pweights[k] = carry * path_length / ((k + 1) * removed_one)
            carry = previous - pweights[k] * removed_zero * (unique_depth - k) / path_length
    else:
        for k in range(unique_depth - 1, -1, -1):
            pweights[k] = (pweights[k] * path_length) / (removed_zero * (unique_depth - k))

    # close the gap left by the removed element (pweights is not shifted)
    for k in range(path_index, unique_depth):
        successor = k + 1
        feature_indexes[k] = feature_indexes[successor]
        zero_fractions[k] = zero_fractions[successor]
        one_fractions[k] = one_fractions[successor]
# determine what the total permutation weight would be if
# we unwound a previous extension in the decision path
@numba.jit(
    numba.types.float64(
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.int32,
        numba.types.int32,
    ),
    nopython=True,
    nogil=True,
)
def unwound_path_sum(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):
    """Return the summed path weights as if element ``path_index`` were unwound.

    Unlike ``unwind_path`` this only reads the path arrays; nothing is modified.
    """
    removed_one = one_fractions[path_index]
    removed_zero = zero_fractions[path_index]
    carry = pweights[unique_depth]
    path_length = unique_depth + 1
    total = 0

    # the one-fraction is fixed for the whole loop, so branch on it once
    if removed_one != 0:
        for k in range(unique_depth - 1, -1, -1):
            term = carry * path_length / ((k + 1) * removed_one)
            total += term
            carry = pweights[k] - term * removed_zero * ((unique_depth - k) / path_length)
    else:
        for k in range(unique_depth - 1, -1, -1):
            total += (pweights[k] / removed_zero) / ((unique_depth - k) / path_length)

    return total
class Tree:
    """Flat-array representation of one decision tree for the SHAP kernels.

    Nodes are addressed by index. ``children_left``, ``children_right`` and
    ``children_default`` map a node to its child indices; the recursion treats
    a node whose right child is ``-1`` as a leaf.
    """

    def __init__(
        self,
        children_left,
        children_right,
        children_default,
        feature,
        threshold,
        value,
        node_sample_weight,
    ):
        """Store the tree arrays and fill in the internal-node expectations.

        Args:
            children_left: Left child index per node.
            children_right: Right child index per node (``-1`` at leaves).
            children_default: Child index followed when the feature is missing.
            feature: Split feature index per node.
            threshold: Split threshold per node.
            value: Value per node. The input array is left untouched.
            node_sample_weight: Training sample weight that reached each node.
        """
        # index arrays are converted to the int32 the compiled kernels expect
        self.children_left = children_left.astype(np.int32)
        self.children_right = children_right.astype(np.int32)
        self.children_default = children_default.astype(np.int32)
        self.features = feature.astype(np.int32)

        # float arrays are coerced to the float64 the kernels are declared with
        self.thresholds = np.asarray(threshold, dtype=np.float64)
        # compute_expectations overwrites internal-node entries in place, and
        # ``value`` may be a view into the caller's fitted model, so work on a
        # private copy instead of mutating the caller's data
        self.values = np.array(value, dtype=np.float64)
        self.node_sample_weight = np.asarray(node_sample_weight, dtype=np.float64)

        # fills self.values for internal nodes and returns the tree depth
        self.max_depth = compute_expectations(
            self.children_left,
            self.children_right,
            self.node_sample_weight,
            self.values,
            0,
        )
@numba.jit(nopython=True)
def compute_expectations(children_left, children_right, node_sample_weight, values, i, depth=0):
    """Fill ``values`` for internal nodes below node ``i``; return the subtree depth.

    Each internal node's value is overwritten in place with the average of its
    two children's values, weighted by ``node_sample_weight``. Leaf values are
    left as they are. A leaf has depth 0.
    """
    # a node without a right child is a leaf: nothing to recompute
    if children_right[i] == -1:
        return 0

    left = children_left[i]
    right = children_right[i]

    # children first, so their values are final before they are averaged
    left_depth = compute_expectations(children_left, children_right, node_sample_weight, values, left, depth + 1)
    right_depth = compute_expectations(children_left, children_right, node_sample_weight, values, right, depth + 1)

    w_left = node_sample_weight[left]
    w_right = node_sample_weight[right]
    values[i] = (w_left * values[left] + w_right * values[right]) / (w_left + w_right)

    return max(left_depth, right_depth) + 1
# recursive computation of SHAP values for a decision tree
@numba.jit(
    numba.types.void(
        numba.types.int32[:],
        numba.types.int32[:],
        numba.types.int32[:],
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.boolean[:],
        numba.types.float64[:],
        numba.types.int64,
        numba.types.int64,
        numba.types.int32[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64[:],
        numba.types.float64,
        numba.types.float64,
        numba.types.int64,
        numba.types.int64,
        numba.types.int64,
        numba.types.float64,
    ),
    nopython=True,
    nogil=True,
)
def tree_shap_recursive(
    children_left,
    children_right,
    children_default,
    features,
    thresholds,
    values,
    node_sample_weight,
    x,
    x_missing,
    phi,
    node_index,
    unique_depth,
    parent_feature_indexes,
    parent_zero_fractions,
    parent_one_fractions,
    parent_pweights,
    parent_zero_fraction,
    parent_one_fraction,
    parent_feature_index,
    condition,
    condition_feature,
    condition_fraction,
):
    """Accumulate SHAP contributions of the subtree at ``node_index`` into ``phi``.

    Args:
        children_left, children_right, children_default: Child index per node;
            a right child of ``-1`` marks a leaf.
        features, thresholds: Split feature and split threshold per node.
        values: Value per node.
        node_sample_weight: Training sample weight that reached each node.
        x: Feature vector being explained.
        x_missing: Flags features of ``x`` that are missing.
        phi: Output array, updated in place at the leaves.
        node_index: Node to process.
        unique_depth: Number of path elements already stored for this call.
        parent_feature_indexes, parent_zero_fractions, parent_one_fractions,
        parent_pweights: The caller's path buffers; this call works on the
            region directly behind the caller's entries.
        parent_zero_fraction, parent_one_fraction, parent_feature_index:
            The path element describing the split that led to this node.
        condition, condition_feature, condition_fraction: Conditioning controls.
            ``condition == 0`` disables conditioning; otherwise splits on
            ``condition_feature`` are kept off the path and only scale or
            zero the weight passed to the children.
    """
    # stop if we have no weight coming down to us
    if condition_fraction == 0:
        return

    # extend the unique path: take the buffer region behind the parent's
    # entries and copy the parent's path into it, so that changes made in
    # this subtree never disturb the parent's copy
    feature_indexes = parent_feature_indexes[unique_depth + 1 :]
    feature_indexes[: unique_depth + 1] = parent_feature_indexes[: unique_depth + 1]
    zero_fractions = parent_zero_fractions[unique_depth + 1 :]
    zero_fractions[: unique_depth + 1] = parent_zero_fractions[: unique_depth + 1]
    one_fractions = parent_one_fractions[unique_depth + 1 :]
    one_fractions[: unique_depth + 1] = parent_one_fractions[: unique_depth + 1]
    pweights = parent_pweights[unique_depth + 1 :]
    pweights[: unique_depth + 1] = parent_pweights[: unique_depth + 1]

    # the conditioned-on feature is never added to the path
    if condition == 0 or condition_feature != parent_feature_index:
        extend_path(
            feature_indexes,
            zero_fractions,
            one_fractions,
            pweights,
            unique_depth,
            parent_zero_fraction,
            parent_one_fraction,
            parent_feature_index,
        )

    split_index = features[node_index]

    # leaf node
    if children_right[node_index] == -1:
        # path entry 0 is skipped; contributions start at entry 1
        for i in range(1, unique_depth + 1):
            w = unwound_path_sum(
                feature_indexes,
                zero_fractions,
                one_fractions,
                pweights,
                unique_depth,
                i,
            )
            phi[feature_indexes[i]] += (
                w * (one_fractions[i] - zero_fractions[i]) * values[node_index] * condition_fraction
            )

    # internal node
    else:
        # find which branch is "hot" (meaning x would follow it)
        hot_index = 0
        cleft = children_left[node_index]
        cright = children_right[node_index]
        if x_missing[split_index] == 1:
            hot_index = children_default[node_index]
        # scikit-learn trees send samples with feature <= threshold to the
        # left child, so the comparison must be inclusive
        elif x[split_index] <= thresholds[node_index]:
            hot_index = cleft
        else:
            hot_index = cright
        cold_index = cright if hot_index == cleft else cleft

        # share of this node's training weight that went down each branch
        w = node_sample_weight[node_index]
        hot_zero_fraction = node_sample_weight[hot_index] / w
        cold_zero_fraction = node_sample_weight[cold_index] / w
        incoming_zero_fraction = 1
        incoming_one_fraction = 1

        # see if we have already split on this feature,
        # if so we undo that split so we can redo it for this node
        path_index = 0
        while path_index <= unique_depth:
            if feature_indexes[path_index] == split_index:
                break
            path_index += 1

        # path_index == unique_depth + 1 means the feature was not found
        if path_index != unique_depth + 1:
            incoming_zero_fraction = zero_fractions[path_index]
            incoming_one_fraction = one_fractions[path_index]
            unwind_path(
                feature_indexes,
                zero_fractions,
                one_fractions,
                pweights,
                unique_depth,
                path_index,
            )
            unique_depth -= 1

        # divide up the condition_fraction among the recursive calls
        hot_condition_fraction = condition_fraction
        cold_condition_fraction = condition_fraction
        if condition > 0 and split_index == condition_feature:
            # conditioned on the feature being known: only the hot branch counts
            cold_condition_fraction = 0
            unique_depth -= 1
        elif condition < 0 and split_index == condition_feature:
            # conditioned on the feature being unknown: split by training weight
            hot_condition_fraction *= hot_zero_fraction
            cold_condition_fraction *= cold_zero_fraction
            unique_depth -= 1

        # hot branch: keeps the incoming one-fraction
        tree_shap_recursive(
            children_left,
            children_right,
            children_default,
            features,
            thresholds,
            values,
            node_sample_weight,
            x,
            x_missing,
            phi,
            hot_index,
            unique_depth + 1,
            feature_indexes,
            zero_fractions,
            one_fractions,
            pweights,
            hot_zero_fraction * incoming_zero_fraction,
            incoming_one_fraction,
            split_index,
            condition,
            condition_feature,
            hot_condition_fraction,
        )

        # cold branch: x never goes this way, so its one-fraction is 0
        tree_shap_recursive(
            children_left,
            children_right,
            children_default,
            features,
            thresholds,
            values,
            node_sample_weight,
            x,
            x_missing,
            phi,
            cold_index,
            unique_depth + 1,
            feature_indexes,
            zero_fractions,
            one_fractions,
            pweights,
            cold_zero_fraction * incoming_zero_fraction,
            0,
            split_index,
            condition,
            condition_feature,
            cold_condition_fraction,
        )
Compare runtime of XGBoost Tree SHAP…
[7]:
# Time XGBoost's built-in Tree SHAP (pred_contribs=True).
# perf_counter() is used because it is a monotonic, high-resolution clock meant
# for measuring intervals; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
shap_values = bst.predict(xgboost.DMatrix(X), pred_contribs=True)
print(time.perf_counter() - start)
0.4652073383331299
Versus the Python (numba) Tree SHAP…
[8]:
# Explain a single synthetic instance whose features are all 1.0; the result
# holds one SHAP value per feature followed by the bias term in the last slot.
x = np.ones(X.shape[1])
TreeExplainer(model).shap_values(x)
[8]:
array([-0.52625503, 0.07485998, 0.10129917, -0.00780694, 0.18581504,
0.35883265, 0.14973339, -0.03848465, 2.04406351])
[9]:
# Time the Python (numba) implementation in two parts: building the explainer
# and explaining every row of X.
# perf_counter() is used because it is a monotonic, high-resolution clock meant
# for measuring intervals; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
ex = TreeExplainer(model)
print(time.perf_counter() - start)

start = time.perf_counter()
ex.shap_values(X.iloc[:, :])
print(time.perf_counter() - start)
0.013829231262207031
18.91119885444641