Python Version of Tree SHAP
This is a sample implementation of Tree SHAP written in Python for easy reading.
[1]:
import time
import numba
import numpy as np
import sklearn.ensemble
import xgboost
import shap
Load the California housing dataset
[2]:
X, y = shap.datasets.california(n_points=1000)
X.shape
[2]:
(1000, 8)
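For reference, here is a quick look at the eight columns (this assumes, as in recent shap releases, that shap.datasets.california returns a pandas DataFrame):
print(list(X.columns))  # the eight housing features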
Train sklearn random forest
[3]:
model = sklearn.ensemble.RandomForestRegressor(n_estimators=1000, max_depth=4)
model.fit(X, y)
[3]:
RandomForestRegressor(max_depth=4, n_estimators=1000)
Train XGBoost model
[4]:
bst = xgboost.train({"learning_rate": 0.01, "max_depth": 4}, xgboost.DMatrix(X, label=y), 1000)
Python TreeExplainer
This implementation uses numba to JIT-compile the core Tree SHAP routines for speed.
[5]:
class TreeExplainer:
def __init__(self, model, **kwargs):
if str(type(model)).endswith("sklearn.ensemble._forest.RandomForestRegressor'>"):
self.trees = [
Tree(
children_left=e.tree_.children_left,
children_right=e.tree_.children_right,
children_default=e.tree_.children_right,
feature=e.tree_.feature,
threshold=e.tree_.threshold,
value=e.tree_.value[:, 0, 0],
node_sample_weight=e.tree_.weighted_n_node_samples,
)
for e in model.estimators_
]
# Preallocate space for the unique path data
maxd = np.max([t.max_depth for t in self.trees]) + 2
s = (maxd * (maxd + 1)) // 2
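        # level d of the recursion consumes d + 1 slots of each buffer, so the
        # whole recursion fits in 1 + 2 + ... + maxd = maxd * (maxd + 1) / 2 entries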
self.feature_indexes = np.zeros(s, dtype=np.int32)
self.zero_fractions = np.zeros(s, dtype=np.float64)
self.one_fractions = np.zeros(s, dtype=np.float64)
self.pweights = np.zeros(s, dtype=np.float64)
def shap_values(self, X, **kwargs):
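        # returns phi with one SHAP value per feature plus the bias
        # (expected value) term in the last position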
# convert dataframes
if str(type(X)).endswith("pandas.core.series.Series'>"):
X = X.values
elif str(type(X)).endswith("'pandas.core.frame.DataFrame'>"):
X = X.values
assert str(type(X)).endswith("'numpy.ndarray'>"), "Unknown instance type: " + str(type(X))
assert len(X.shape) == 1 or len(X.shape) == 2, "Instance must have 1 or 2 dimensions!"
# single instance
if len(X.shape) == 1:
phi = np.zeros(X.shape[0] + 1)
x_missing = np.zeros(X.shape[0], dtype=bool)
for t in self.trees:
self.tree_shap(t, X, x_missing, phi)
phi /= len(self.trees)
elif len(X.shape) == 2:
phi = np.zeros((X.shape[0], X.shape[1] + 1))
x_missing = np.zeros(X.shape[1], dtype=bool)
for i in range(X.shape[0]):
for t in self.trees:
self.tree_shap(t, X[i, :], x_missing, phi[i, :])
phi /= len(self.trees)
return phi
def tree_shap(self, tree, x, x_missing, phi, condition=0, condition_feature=0):
# update the bias term, which is the last index in phi
# (note the paper has this as phi_0 instead of phi_M)
if condition == 0:
phi[-1] += tree.values[0]
# start the recursive algorithm
tree_shap_recursive(
tree.children_left,
tree.children_right,
tree.children_default,
tree.features,
tree.thresholds,
tree.values,
tree.node_sample_weight,
x,
x_missing,
phi,
0,
0,
self.feature_indexes,
self.zero_fractions,
self.one_fractions,
self.pweights,
1,
1,
-1,
condition,
condition_feature,
1,
)
[6]:
# extend our decision path with a fraction of one and zero extensions
@numba.jit(
numba.types.void(
numba.types.int32[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.int32,
numba.types.float64,
numba.types.float64,
numba.types.int32,
),
nopython=True,
nogil=True,
)
def extend_path(
feature_indexes,
zero_fractions,
one_fractions,
pweights,
unique_depth,
zero_fraction,
one_fraction,
feature_index,
):
feature_indexes[unique_depth] = feature_index
zero_fractions[unique_depth] = zero_fraction
one_fractions[unique_depth] = one_fraction
if unique_depth == 0:
pweights[unique_depth] = 1
else:
pweights[unique_depth] = 0
for i in range(unique_depth - 1, -1, -1):
pweights[i + 1] += one_fraction * pweights[i] * (i + 1) / (unique_depth + 1)
pweights[i] = zero_fraction * pweights[i] * (unique_depth - i) / (unique_depth + 1)
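# (pweights[i] holds the total permutation weight of all feature subsets of size i
# along the current unique path; e.g. extending an empty path with the root
# placeholder (fractions 1, 1, index -1) and then a split with zero_fraction=0.5,
# one_fraction=1.0 leaves pweights[:2] == [0.25, 0.5])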
# undo a previous extension of the decision path
@numba.jit(
numba.types.void(
numba.types.int32[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.int32,
numba.types.int32,
),
nopython=True,
nogil=True,
)
def unwind_path(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):
one_fraction = one_fractions[path_index]
zero_fraction = zero_fractions[path_index]
next_one_portion = pweights[unique_depth]
for i in range(unique_depth - 1, -1, -1):
if one_fraction != 0:
tmp = pweights[i]
pweights[i] = next_one_portion * (unique_depth + 1) / ((i + 1) * one_fraction)
next_one_portion = tmp - pweights[i] * zero_fraction * (unique_depth - i) / (unique_depth + 1)
else:
pweights[i] = (pweights[i] * (unique_depth + 1)) / (zero_fraction * (unique_depth - i))
for i in range(path_index, unique_depth):
feature_indexes[i] = feature_indexes[i + 1]
zero_fractions[i] = zero_fractions[i + 1]
one_fractions[i] = one_fractions[i + 1]
# determine what the total permutation weight would be if
# we unwound a previous extension in the decision path
@numba.jit(
numba.types.float64(
numba.types.int32[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.int32,
numba.types.int32,
),
nopython=True,
nogil=True,
)
def unwound_path_sum(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):
one_fraction = one_fractions[path_index]
zero_fraction = zero_fractions[path_index]
next_one_portion = pweights[unique_depth]
total = 0
for i in range(unique_depth - 1, -1, -1):
if one_fraction != 0:
tmp = next_one_portion * (unique_depth + 1) / ((i + 1) * one_fraction)
total += tmp
next_one_portion = pweights[i] - tmp * zero_fraction * ((unique_depth - i) / (unique_depth + 1))
else:
total += (pweights[i] / zero_fraction) / ((unique_depth - i) / (unique_depth + 1))
return total
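# (unwound_path_sum computes the total weight the path would carry if the feature
# at path_index were removed; at a leaf this is the factor w multiplying
# (one_fraction - zero_fraction) * leaf_value in the phi update)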
class Tree:
def __init__(
self,
children_left,
children_right,
children_default,
feature,
threshold,
value,
node_sample_weight,
):
self.children_left = children_left.astype(np.int32)
self.children_right = children_right.astype(np.int32)
self.children_default = children_default.astype(np.int32)
self.features = feature.astype(np.int32)
self.thresholds = threshold
self.values = value
self.node_sample_weight = node_sample_weight
self.max_depth = compute_expectations(
self.children_left,
self.children_right,
self.node_sample_weight,
self.values,
0,
)
@numba.jit(nopython=True)
def compute_expectations(children_left, children_right, node_sample_weight, values, i, depth=0):
    if children_right[i] == -1:
        # leaf node: values[i] already holds the leaf prediction
        return 0
else:
li = children_left[i]
ri = children_right[i]
depth_left = compute_expectations(children_left, children_right, node_sample_weight, values, li, depth + 1)
depth_right = compute_expectations(children_left, children_right, node_sample_weight, values, ri, depth + 1)
left_weight = node_sample_weight[li]
right_weight = node_sample_weight[ri]
v = (left_weight * values[li] + right_weight * values[ri]) / (left_weight + right_weight)
values[i] = v
return max(depth_left, depth_right) + 1
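# (compute_expectations overwrites each internal node's value with the
# sample-weighted average of its children, so values[i] becomes the expected
# model output over the training samples reaching node i; the return value is
# the depth of the subtree below node i)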
# recursive computation of SHAP values for a decision tree
@numba.jit(
numba.types.void(
numba.types.int32[:],
numba.types.int32[:],
numba.types.int32[:],
numba.types.int32[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.boolean[:],
numba.types.float64[:],
numba.types.int64,
numba.types.int64,
numba.types.int32[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.float64[:],
numba.types.float64,
numba.types.float64,
numba.types.int64,
numba.types.int64,
numba.types.int64,
numba.types.float64,
),
nopython=True,
nogil=True,
)
def tree_shap_recursive(
children_left,
children_right,
children_default,
features,
thresholds,
values,
node_sample_weight,
x,
x_missing,
phi,
node_index,
unique_depth,
parent_feature_indexes,
parent_zero_fractions,
parent_one_fractions,
parent_pweights,
parent_zero_fraction,
parent_one_fraction,
parent_feature_index,
condition,
condition_feature,
condition_fraction,
):
# stop if we have no weight coming down to us
if condition_fraction == 0:
return
# extend the unique path
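    # (each recursive call takes a fresh slice of the preallocated parent buffers
    # and copies the first unique_depth + 1 entries, so children never overwrite
    # their parent's path data)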
feature_indexes = parent_feature_indexes[unique_depth + 1 :]
feature_indexes[: unique_depth + 1] = parent_feature_indexes[: unique_depth + 1]
zero_fractions = parent_zero_fractions[unique_depth + 1 :]
zero_fractions[: unique_depth + 1] = parent_zero_fractions[: unique_depth + 1]
one_fractions = parent_one_fractions[unique_depth + 1 :]
one_fractions[: unique_depth + 1] = parent_one_fractions[: unique_depth + 1]
pweights = parent_pweights[unique_depth + 1 :]
pweights[: unique_depth + 1] = parent_pweights[: unique_depth + 1]
if condition == 0 or condition_feature != parent_feature_index:
extend_path(
feature_indexes,
zero_fractions,
one_fractions,
pweights,
unique_depth,
parent_zero_fraction,
parent_one_fraction,
parent_feature_index,
)
split_index = features[node_index]
# leaf node
if children_right[node_index] == -1:
for i in range(1, unique_depth + 1):
w = unwound_path_sum(
feature_indexes,
zero_fractions,
one_fractions,
pweights,
unique_depth,
i,
)
phi[feature_indexes[i]] += (
w * (one_fractions[i] - zero_fractions[i]) * values[node_index] * condition_fraction
)
# internal node
else:
# find which branch is "hot" (meaning x would follow it)
hot_index = 0
cleft = children_left[node_index]
cright = children_right[node_index]
if x_missing[split_index] == 1:
hot_index = children_default[node_index]
elif x[split_index] < thresholds[node_index]:
hot_index = cleft
else:
hot_index = cright
cold_index = cright if hot_index == cleft else cleft
w = node_sample_weight[node_index]
hot_zero_fraction = node_sample_weight[hot_index] / w
cold_zero_fraction = node_sample_weight[cold_index] / w
incoming_zero_fraction = 1
incoming_one_fraction = 1
# see if we have already split on this feature,
# if so we undo that split so we can redo it for this node
path_index = 0
while path_index <= unique_depth:
if feature_indexes[path_index] == split_index:
break
path_index += 1
if path_index != unique_depth + 1:
incoming_zero_fraction = zero_fractions[path_index]
incoming_one_fraction = one_fractions[path_index]
unwind_path(
feature_indexes,
zero_fractions,
one_fractions,
pweights,
unique_depth,
path_index,
)
unique_depth -= 1
# divide up the condition_fraction among the recursive calls
hot_condition_fraction = condition_fraction
cold_condition_fraction = condition_fraction
if condition > 0 and split_index == condition_feature:
cold_condition_fraction = 0
unique_depth -= 1
elif condition < 0 and split_index == condition_feature:
hot_condition_fraction *= hot_zero_fraction
cold_condition_fraction *= cold_zero_fraction
unique_depth -= 1
tree_shap_recursive(
children_left,
children_right,
children_default,
features,
thresholds,
values,
node_sample_weight,
x,
x_missing,
phi,
hot_index,
unique_depth + 1,
feature_indexes,
zero_fractions,
one_fractions,
pweights,
hot_zero_fraction * incoming_zero_fraction,
incoming_one_fraction,
split_index,
condition,
condition_feature,
hot_condition_fraction,
)
tree_shap_recursive(
children_left,
children_right,
children_default,
features,
thresholds,
values,
node_sample_weight,
x,
x_missing,
phi,
cold_index,
unique_depth + 1,
feature_indexes,
zero_fractions,
one_fractions,
pweights,
cold_zero_fraction * incoming_zero_fraction,
0,
split_index,
condition,
condition_feature,
cold_condition_fraction,
)
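As a sanity check (an addition to the original notebook, assuming shap's TreeExplainer API for sklearn regressors), the attributions from this pure Python implementation can be compared against the reference implementation; the per-feature values should agree to numerical precision, and the last entry of phi should match the reference expected value.
ref = shap.TreeExplainer(model)
x0 = X.iloc[0, :].values
phi = TreeExplainer(model).shap_values(x0)  # the pure Python explainer above
print(np.allclose(phi[:-1], ref.shap_values(x0.reshape(1, -1))[0]))  # per-feature values
print(np.allclose(phi[-1], ref.expected_value))  # bias term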
Compare the runtime of XGBoost's built-in Tree SHAP…
[7]:
start = time.time()
shap_values = bst.predict(xgboost.DMatrix(X), pred_contribs=True)
print(time.time() - start)
0.1391134262084961
Versus the Python (numba) Tree SHAP…
[8]:
x = np.ones(X.shape[1])
TreeExplainer(model).shap_values(x)
[8]:
array([-0.5140118 , 0.07819144, 0.09588052, -0.01004643, 0.21487504,
0.37846563, 0.15740923, -0.04002657, 2.04526993])
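By local accuracy, these values plus the bias term (the last entry) should sum to the model's prediction for x. A minimal check (my addition; sklearn may warn about missing feature names when predicting on a raw array):
phi = TreeExplainer(model).shap_values(x)
print(phi.sum(), model.predict(x.reshape(1, -1))[0])  # the two numbers should match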
[9]:
start = time.time()
ex = TreeExplainer(model)
print(time.time() - start)
start = time.time()
ex.shap_values(X.iloc[:, :])
print(time.time() - start)
0.0075643062591552734
10.171732664108276
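The pure Python version is much slower end to end. Most of that gap is likely per-row, per-tree dispatch overhead in the Python interpreter rather than the numba-compiled kernels themselves; the goal of this notebook is readability, not performance.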