# Simple Kernel SHAP

This notebook provides a simple brute-force version of Kernel SHAP that enumerates the entire $2^M$ sample space. We also compare to the full KernelExplainer implementation. Note that KernelExplainer does a sampling approximation for large values of $M$, but for small values it is exact.

## Brute Force Kernel SHAP

[1]:

import scipy.special
import numpy as np
import itertools

def powerset(iterable):
    """Return an iterator over every subset of ``iterable`` as a tuple.

    Subsets are produced in order of increasing size (the empty tuple
    first, the full set last); within one size they follow the order
    generated by ``itertools.combinations``.
    """
    # Materialise the input once so its length is known and it can be
    # traversed once per subset size.
    pool = list(iterable)
    per_size = [itertools.combinations(pool, k) for k in range(len(pool) + 1)]
    return itertools.chain(*per_size)

def shapley_kernel(M, s, boundary_weight=10000):
    """Shapley kernel weight for a coalition of size ``s`` out of ``M`` features.

    For ``0 < s < M`` the weight is
    ``(M - 1) / (binom(M, s) * s * (M - s))``.

    Parameters
    ----------
    M : int
        Total number of features.
    s : int
        Number of features present in the coalition.
    boundary_weight : float, optional
        Weight returned for the empty (``s == 0``) and full (``s == M``)
        coalitions. The closed-form expression divides by zero there, so
        a large finite value is substituted. Defaults to 10000.

    Returns
    -------
    float
        The kernel weight for the coalition size.
    """
    # s * (M - s) is zero at both boundaries, so the formula below cannot
    # be evaluated; use the large finite stand-in instead.
    if s == 0 or s == M:
        return boundary_weight
    return (M - 1) / (scipy.special.binom(M, s) * s * (M - s))

def f(X):
    """Toy linear model: ``X @ beta + 10`` with fixed pseudo-random ``beta``.

    ``beta`` has one coefficient per entry along the last axis of ``X``
    and is drawn from a generator seeded with 0, so every call with the
    same feature count uses the same coefficients.

    Parameters
    ----------
    X : numpy.ndarray
        Input whose last axis indexes features (a single sample or a
        batch of samples).

    Returns
    -------
    numpy.ndarray or float
        ``np.dot(X, beta) + 10``.
    """
    # Use a private generator instead of np.random.seed(0): it yields the
    # same coefficients but leaves NumPy's global random state untouched,
    # so calling f() never perturbs random draws made by the caller.
    beta = np.random.RandomState(0).rand(X.shape[-1])
    return np.dot(X, beta) + 10

def kernel_shap(f, x, reference, M):
    """Brute-force Kernel SHAP over all ``2**M`` feature coalitions.

    For every subset of the ``M`` features a masked input is built in
    which the features in the subset take their values from ``x`` and
    the remaining features take their values from ``reference``. The
    model outputs on these inputs are then regressed, by weighted least
    squares with ``shapley_kernel`` weights, onto the binary coalition
    indicators plus a constant intercept column.

    Parameters
    ----------
    f : callable
        Model to explain; called once with an array of shape
        ``(2**M, M)``.
    x : array-like of length ``M``
        The instance being explained.
    reference : array-like
        Background values used for features absent from a coalition;
        must broadcast against a row of length ``M``.
    M : int
        Number of features.

    Returns
    -------
    numpy.ndarray
        Regression coefficients of length ``M + 1``: the first ``M``
        entries are the per-feature coefficients and the last entry is
        the intercept.
    """
    x = np.asarray(x)  # allow plain sequences; list-indexing below needs an array
    n_coalitions = 2 ** M

    # Design matrix: one indicator column per feature plus a trailing
    # all-ones intercept column.
    X = np.zeros((n_coalitions, M + 1))
    X[:, -1] = 1
    weights = np.zeros(n_coalitions)

    # Masked model inputs: start every row at the reference values.
    V = np.zeros((n_coalitions, M))
    V[:] = reference

    for i, s in enumerate(powerset(range(M))):
        s = list(s)
        V[i, s] = x[s]  # features in the coalition take the values from x
        X[i, s] = 1
        weights[i] = shapley_kernel(M, len(s))

    y = f(V)

    # Weighted normal equations (X^T W X) phi = X^T W y. Scaling the
    # columns of X^T by the weights equals X^T @ diag(weights) without
    # allocating a dense 2**M x 2**M matrix; solve() avoids forming an
    # explicit inverse.
    XtW = X.T * weights
    return np.linalg.solve(np.dot(XtW, X), np.dot(XtW, y))

# Problem setup: a random instance with M features, explained against an
# all-zeros reference.
M = 4
np.random.seed(1)
x = np.random.randn(M)
reference = np.zeros(M)

# The solution vector holds the M per-feature values followed by the
# intercept, which serves as the base value.
phi = kernel_shap(f, x, reference, M)
shap_values, base_value = phi[:-1], phi[-1]

# Report the results; sum(phi) is printed next to f(x) for comparison.
for _label, _value in (
    ("  reference =", reference),
    ("          x =", x),
    ("shap_values =", shap_values),
    (" base_value =", base_value),
    ("   sum(phi) =", np.sum(phi)),
    ("       f(x) =", f(x)),
):
    print(_label, _value)

  reference = [0. 0. 0. 0.]
x = [ 1.62434536 -0.61175641 -0.52817175 -1.07296862]
shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]
base_value = 9.999999999999996
sum(phi) = 9.55093584211863
f(x) = 9.55093584213122


## Using KernelExplainer

[2]:

import shap

# Pass the reference point to KernelExplainer as a one-row 2-D background
# array, then explain the same instance x as in the brute-force cell.
explainer = shap.KernelExplainer(f, reference.reshape(1, len(reference)))
shap_values = explainer.shap_values(x)

print("shap_values =", shap_values)
print("base value =", explainer.expected_value)

shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]
base value = 10.0