Math behind LinearExplainer with correlation feature perturbation

When we use LinearExplainer(model, prior, feature_perturbation="correlation_dependent") we do not use \(E[f(x) \mid do(X_S = x_S)]\) to measure the impact of a set \(S\) of features, but rather use \(E[f(x) \mid X_S = x_S]\) under the assumption that the random variable \(X\) (representing the input features) follows a multivariate Gaussian distribution. To compute SHAP values this way we need to compute conditional expectations under the multivariate Gaussian distribution for all subsets of features. This would require a lot of matrix math for an exponential number of terms, and it is hence intractable for models with more than just a few features.

This document briefly outlines the math we have used to precompute all of the required linear algebra using a sampling procedure that can be done just once, and then applied to as many samples as we like. This drastically speeds up the computation compared to a brute-force approach. Note that all these calculations depend on the fact that we are explaining a linear model \(f(x) = \beta x + b\).

The permutation definition of SHAP values in the interventional form used by most explainers is

\[\phi_i = \frac{1}{M!} \sum_R E[f(X) \mid do(X_{S_i^R \cup i} = x_{S_i^R \cup i})] - E[f(X) \mid do(X_{S_i^R} = x_{S_i^R})]\]

but here we will use the non-interventional conditional expectation form (where we have simplified the notation by dropping the explicit reference to the random variable \(X\)).

\[\phi_i = \frac{1}{M!} \sum_R E[f(x) \mid x_{S_i^R \cup i}] - E[f(x) \mid x_{S_i^R}]\]

where \(f(x) = \beta x + b\) with \(\beta\) a row vector and \(b\) a scalar.

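If we replace \(f(x)\) with the linear function definition, the intercept \(b\) cancels in the difference and \(\beta\) factors out by linearity of expectation, so we get: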

\begin{align} \phi_i = \frac{1}{M!} \sum_R E[\beta x + b \mid x_{S_i^R \cup i}] - E[\beta x + b \mid x_{S_i^R}] \\ = \beta \frac{1}{M!} \sum_R E[x \mid x_{S_i^R \cup i}] - E[x \mid x_{S_i^R}] \end{align}
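
As a point of reference, the brute-force approach described above can be sketched for a small number of features. This is a minimal illustration and not the library's implementation: it enumerates all \(M!\) permutations and uses the standard conditional mean of a multivariate Gaussian. The helper names `conditional_mean` and `brute_force_shap`, and the randomly generated `mu`, `cov`, `beta` and `x`, are assumptions made for the example.

```python
import itertools
import numpy as np

def conditional_mean(x, mu, cov, S):
    """E[X | X_S = x_S] for X ~ N(mu, cov); S lists the observed indices."""
    M = len(mu)
    out = np.array(mu, dtype=float)
    S = list(S)
    if not S:
        return out  # nothing observed: the conditional mean is just mu
    rest = [j for j in range(M) if j not in S]
    out[S] = x[S]  # observed coordinates are fixed at their values
    if rest:
        # Gaussian conditioning: mu_rest + cov[rest, S] cov[S, S]^{-1} (x_S - mu_S)
        w = np.linalg.solve(cov[np.ix_(S, S)], x[S] - mu[S])
        out[rest] = mu[rest] + cov[np.ix_(rest, S)] @ w
    return out

def brute_force_shap(beta, x, mu, cov):
    """Exact permutation average; cost grows like M!, so small M only."""
    M = len(beta)
    phi = np.zeros(M)
    perms = list(itertools.permutations(range(M)))
    for perm in perms:
        for pos, i in enumerate(perm):
            S = list(perm[:pos])  # features that precede i in this ordering
            with_i = conditional_mean(x, mu, cov, S + [i])
            without_i = conditional_mean(x, mu, cov, S)
            # the intercept b cancels in the difference, so only beta appears
            phi[i] += beta @ (with_i - without_i)
    return phi / len(perms)

# Hypothetical example data (not from the original document)
rng = np.random.default_rng(0)
M = 4
A = rng.normal(size=(M, M))
cov = A @ A.T + M * np.eye(M)  # a random positive-definite covariance
mu = rng.normal(size=M)
beta = rng.normal(size=M)
x = rng.normal(size=M)

phi = brute_force_shap(beta, x, mu, cov)
```

Within each permutation the differences telescope, so the values in `phi` should sum to \(\beta (x - \mu)\). With an identity covariance the sketch should reduce to \(\phi_i = \beta_i (x_i - \mu_i)\).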

Assume the inputs \(x\) follow a multivariate normal distribution with mean \(\mu\) and covariance \(\Sigma\). Denote the projection matrix that selects a set \(S\) as \(P_S\), and write \(\bar S\) for the complement of \(S\); then we get:

\begin{align} E[x \mid x_S] = P_{\bar S}^T [P_{\bar S} \mu + P_{\bar S} \Sigma P_S^T (P_S \Sigma P_S^T)^{-1} ( P_S x - P_S \mu)] + P_S^T P_S x \\ = P_{\bar S}^T [P_{\bar S} \mu + P_{\bar S} \Sigma P_S^T (P_S \Sigma P_S^T)^{-1} P_S (x - \mu)] + P_S^T P_S x \\ = P_{\bar S}^T P_{\bar S} [\mu + \Sigma P_S^T (P_S \Sigma P_S^T)^{-1} P_S (x - \mu)] + P_S^T P_S x \\ = P_{\bar S}^T P_{\bar S} \mu + P_{\bar S}^T P_{\bar S} \Sigma P_S^T (P_S \Sigma P_S^T)^{-1} P_S x - P_{\bar S}^T P_{\bar S} \Sigma P_S^T (P_S \Sigma P_S^T)^{-1} P_S \mu + P_S^T P_S x \\ = [P_{\bar S}^T P_{\bar S} - P_{\bar S}^T P_{\bar S} \Sigma P_S^T (P_S \Sigma P_S^T)^{-1} P_S] \mu + [P_S^T P_S + P_{\bar S}^T P_{\bar S} \Sigma P_S^T (P_S \Sigma P_S^T)^{-1} P_S] x \end{align}

If we let \(R_S = P_{\bar S}^T P_{\bar S} \Sigma P_S^T (P_S \Sigma P_S^T)^{-1} P_S\) and \(Q_S = P_S^T P_S\) (so that \(Q_{\bar S} = P_{\bar S}^T P_{\bar S}\)) then we can write

\begin{align} E[x \mid x_S] = [Q_{\bar S} - R_S] \mu + [Q_S + R_S] x \end{align}
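
The compact form above can be checked numerically against the textbook conditional mean of a Gaussian. The following is a small sketch under assumed inputs: the helper names `selector`, `Q_mat` and `R_mat`, the example subset `S`, and the random `mu`, `cov` and `x` are all hypothetical and chosen only for illustration.

```python
import numpy as np

def selector(S, M):
    """Projection matrix P_S of shape (|S|, M) that picks the coordinates in S."""
    return np.eye(M)[np.asarray(S, dtype=int), :]

def Q_mat(S, M):
    """Q_S = P_S^T P_S: a diagonal 0/1 matrix with ones at the indices in S."""
    P = selector(S, M)
    return P.T @ P

def R_mat(S, cov):
    """R_S = P_Sbar^T P_Sbar Sigma P_S^T (P_S Sigma P_S^T)^{-1} P_S."""
    M = cov.shape[0]
    if len(S) == 0:
        return np.zeros((M, M))
    S_bar = [j for j in range(M) if j not in S]
    P = selector(S, M)
    return Q_mat(S_bar, M) @ cov @ P.T @ np.linalg.inv(P @ cov @ P.T) @ P

# Hypothetical example data
rng = np.random.default_rng(1)
M = 5
A = rng.normal(size=(M, M))
cov = A @ A.T + M * np.eye(M)  # a random positive-definite covariance
mu = rng.normal(size=M)
x = rng.normal(size=M)
S = [0, 3]
S_bar = [1, 2, 4]

# Matrix form: E[x | x_S] = (Q_Sbar - R_S) mu + (Q_S + R_S) x
matrix_form = (Q_mat(S_bar, M) - R_mat(S, cov)) @ mu + (Q_mat(S, M) + R_mat(S, cov)) @ x

# Direct form: observed entries are kept, the rest use Gaussian conditioning
direct = mu.copy()
direct[S] = x[S]
direct[S_bar] = mu[S_bar] + cov[np.ix_(S_bar, S)] @ np.linalg.solve(
    cov[np.ix_(S, S)], x[S] - mu[S]
)
```

The two vectors `matrix_form` and `direct` should agree up to floating-point error. Note also that `Q_mat(S, M) + Q_mat(S_bar, M)` is the identity, which is why the \(\mu\) and \(x\) coefficients sum to the identity matrix.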

or

\begin{align} E[x \mid x_{S_i^R \cup i}] = [Q_{\bar{S_i^R \cup i}} - R_{S_i^R \cup i}] \mu + [Q_{S_i^R \cup i} + R_{S_i^R \cup i}] x \end{align}

leading to the Shapley equation of

\begin{align} \phi_i = \beta \frac{1}{M!} \sum_R [Q_{\bar{S_i^R \cup i}} - R_{S_i^R \cup i}] \mu + [Q_{S_i^R \cup i} + R_{S_i^R \cup i}] x - [Q_{\bar{S_i^R}} - R_{S_i^R}] \mu - [Q_{S_i^R} + R_{S_i^R}] x \\ = \beta \frac{1}{M!} \sum_R ([Q_{\bar{S_i^R \cup i}} - R_{S_i^R \cup i}] - [Q_{\bar{S_i^R}} - R_{S_i^R}]) \mu + ([Q_{S_i^R \cup i} + R_{S_i^R \cup i}] - [Q_{S_i^R} + R_{S_i^R}]) x \\ = \beta \left [ \frac{1}{M!} \sum_R ([Q_{\bar{S_i^R \cup i}} - R_{S_i^R \cup i}] - [Q_{\bar{S_i^R}} - R_{S_i^R}]) \right ] \mu + \beta \left [ \frac{1}{M!} \sum_R ([Q_{S_i^R \cup i} + R_{S_i^R \cup i}] - [Q_{S_i^R} + R_{S_i^R}]) \right ] x \end{align}

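Because \(Q_{\bar S} = I - Q_S\), the bracket multiplying \(\mu\) is the negative of the bracket multiplying \(x\), so each \(\phi_i\) is a fixed linear function of \(x - \mu\) whose coefficients do not depend on the sample being explained. Writing \(T\) for the averaged bracket, and taking \(x\) to be centered (that is, with \(\mu\) subtracted), this is, in shorthand,

\[\phi = \beta T x\]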

This means that we can precompute the transform matrix \(T\) by drawing random permutations \(R\) many times and averaging our results. Once we have computed \(T\) we can explain any number of samples (or models for that matter) by just using matrix multiplication.
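
The sampling procedure can be sketched as follows. This is a simplified illustration rather than the library's actual code: it recomputes a matrix inverse for every prefix of every permutation instead of updating it incrementally, it folds \(\beta\) into the rows of the transform, and it centers the inputs by \(\mu\). The function names `cond_matrix` and `estimate_transform`, the sample count, and the random data are assumptions made for the example.

```python
import numpy as np

def cond_matrix(S, cov):
    """Q_S + R_S, i.e. the matrix A with E[x | x_S] - mu = A (x - mu)."""
    M = cov.shape[0]
    A = np.zeros((M, M))
    S = list(S)
    if not S:
        return A  # nothing observed: the conditional mean stays at mu
    rest = [j for j in range(M) if j not in S]
    A[S, S] = 1.0  # Q_S: observed coordinates pass through unchanged
    if rest:
        # R_S: rows in the complement, columns in S
        A[np.ix_(rest, S)] = cov[np.ix_(rest, S)] @ np.linalg.inv(cov[np.ix_(S, S)])
    return A

def estimate_transform(beta, cov, nsamples, rng):
    """Monte Carlo estimate of an M x M matrix T with phi = T (x - mu)."""
    M = len(beta)
    T = np.zeros((M, M))
    for _ in range(nsamples):
        perm = rng.permutation(M)
        prev = cond_matrix([], cov)
        for pos in range(M):
            i = perm[pos]
            cur = cond_matrix(perm[: pos + 1], cov)
            # row i accumulates beta (Q + R)_{S u i} - beta (Q + R)_S
            T[i] += beta @ (cur - prev)
            prev = cur
    return T / nsamples

# Hypothetical example data
rng = np.random.default_rng(0)
M = 4
A0 = rng.normal(size=(M, M))
cov = A0 @ A0.T + M * np.eye(M)  # a random positive-definite covariance
mu = rng.normal(size=M)
beta = rng.normal(size=M)

T = estimate_transform(beta, cov, nsamples=200, rng=rng)

# Once T is available, explaining a batch is a single matrix product
X = rng.normal(size=(10, M))
phi = (X - mu) @ T.T  # one row of SHAP value estimates per sample
```

Each sampled permutation telescopes from the empty set to the full feature set, so the columns of `T` sum to `beta` regardless of how many permutations are drawn. Each row of `phi` therefore sums to \(f(x) - f(\mu)\), while the individual entries are only Monte Carlo estimates.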