How a squashing function can affect feature importance

The importance of a feature in a machine learning model can change significantly when you apply a non-linear function to the model’s output. The most common case where this matters is a “squashing” function. Squashing functions such as the logistic transform are often used to convert an unbounded “margin” space (for the logistic transform, the log-odds) into a bounded probability space. Values in margin space are in units of information, while values in probability space are in units of probability. Which space you care about depends on the situation. Margin space is better for adding and subtracting, and it corresponds directly to “evidence” in an information-theoretic sense. However, if you care only about changes in percentage probability, not evidence, you would be better off using probability space. By choosing probability space you are saying that a large amount of strong evidence that takes you from 98% to 99.99% probability is much less important than a smaller amount of evidence that takes you from 50% to 60% probability. Why does the first move take more evidence? Because in log-odds terms, 98% corresponds to odds of 49:1 and 99.99% to odds of 9999:1, a roughly 200-fold increase in the odds, while going from 50% to 60% only raises the odds from 1:1 to 1.5:1.
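To make the evidence comparison concrete: if the margin is the log-odds, the evidence needed for a move is the difference between the logit (inverse logistic) of the two probabilities. The sketch below uses SciPy’s `logit` and reports the result in bits; the helper name `evidence_bits` is hypothetical and not part of SHAP.

```python
import numpy as np
from scipy.special import logit


def evidence_bits(p_start, p_end):
    """Hypothetical helper: change in log-odds needed to move from p_start to p_end, in bits."""
    # logit(p) = ln(p / (1 - p)) is the margin in nats; dividing by ln(2) converts to bits
    return (logit(p_end) - logit(p_start)) / np.log(2)


small_move = evidence_bits(0.50, 0.60)    # about 0.58 bits
large_move = evidence_bits(0.98, 0.9999)  # about 7.67 bits
print(f"50% -> 60%:    {small_move:.2f} bits")
print(f"98% -> 99.99%: {large_move:.2f} bits")
```

Under this log-odds reading, the move from 98% to 99.99% needs roughly 13 times as much evidence as the move from 50% to 60%, even though it changes the probability by only about a fifth as much.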

Note that even though the logistic function is a monotonic transformation, it can still change the ordering of which features are most important in a model. The ordering can change because some features may be very important for getting to 99.9% probability, while others are usually helpful in getting to 60% probability. The simple example below shows how a squashing function can change the importance of a feature:
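A monotonic transform preserves the ordering of the outputs themselves, but not the ordering of the *changes* in output. The sketch below assumes two hypothetical features, one that moves a prediction from 50% to 60% and one that moves it from 98% to 99.99% (the feature names and numbers are illustrative, not from the model below), and ranks them by effect size in each space.

```python
from scipy.special import expit, logit

# Two hypothetical features, each described by the margin (log-odds) it starts
# from and the margin it moves the prediction to.
moves = {
    "common_feature": (logit(0.50), logit(0.60)),    # 50% -> 60%
    "strong_feature": (logit(0.98), logit(0.9999)),  # 98% -> 99.99%
}

# Size of each feature's effect in margin (log-odds) units...
margin_effect = {name: end - start for name, (start, end) in moves.items()}
# ...and the same effect in probability units after the logistic squash
prob_effect = {name: expit(end) - expit(start) for name, (start, end) in moves.items()}


def rank(effects):
    """Feature names ordered from largest to smallest effect."""
    return sorted(effects, key=effects.get, reverse=True)


print("margin ranking:     ", rank(margin_effect))
print("probability ranking:", rank(prob_effect))
```

Both features still push the prediction upward in both spaces; only their relative importance swaps, which is the same effect the SHAP example below demonstrates.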

[3]:
import numpy as np
import pandas as pd
import scipy.special

import shap
[4]:
shap.initjs()
[5]:
# build a simple dataset
N = 500
M = 4
X = np.random.randn(N, M)
# make the first row (the one explained below) have A = 0 and B = 0
X[0, 0] = 0
X[0, 1] = 0
X = pd.DataFrame(X, columns=["A", "B", "C", "D"])


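# a function (a made up ML model) with an output in "margin" space;
# KernelExplainer calls it with a NumPy array, so X[:, j] selects column j
def f(X):
    # +1 when A > 0, plus a rare but large +100 when B > 1.5
    return (X[:, 0] > 0) * 1 + (X[:, 1] > 1.5) * 100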


# ...and then also change its output to probability space
def f_logistic(X):
    return scipy.special.expit(f(X))
[7]:
# explain both functions
explainer = shap.KernelExplainer(f, X)
shap_values_f = explainer.shap_values(X.values[0:2, :])

explainer_logistic = shap.KernelExplainer(f_logistic, X)
shap_values_f_logistic = explainer_logistic.shap_values(X.values[0:2, :])
Using 500 background data samples could cause slower run times. Consider using shap.kmeans(data, K) to summarize the background as K weighted samples.
Using 500 background data samples could cause slower run times. Consider using shap.kmeans(data, K) to summarize the background as K weighted samples.


Margin space explanation

When thinking about margin space, feature B is very important: because B is 0 for this sample, we miss the +100 effect that occurs when B is greater than 1.5. Even though B being greater than 1.5 is rare, it matters a great deal because of its large impact. In the output below, B’s SHAP value (-6) is about twelve times larger in magnitude than A’s (-0.506).
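Because the margin-space model `f` is additive, each feature’s SHAP value reduces to the value of its term at the explained row minus the average of that term over the background data. The sketch below assumes the background is the standard normal population (using `scipy.stats.norm`) rather than the 500 random rows used in this notebook, so it gives population values rather than the sampled ones.

```python
from scipy.stats import norm

# f(x) = 1[x_A > 0] * 1 + 1[x_B > 1.5] * 100 is additive, so each SHAP value is
# (the feature's term at x) - (the mean of that term over the background).
# Assumption: background = standard normal population, not the 500 sampled rows.
p_A = norm.sf(0.0)  # P(A > 0) = 0.5
p_B = norm.sf(1.5)  # P(B > 1.5), roughly 0.067

x_A, x_B = 0.0, 0.0  # the explained row: X[0, 0] = X[0, 1] = 0

phi_A = 1 * (float(x_A > 0) - p_A)
phi_B = 100 * (float(x_B > 1.5) - p_B)
print(f"phi_A = {phi_A:.3f}, phi_B = {phi_B:.3f}")
```

This gives about -0.5 for A and -6.7 for B. The notebook’s -0.506 and -6 are consistent with the same formula applied to its 500 random background rows, where roughly 50.6% of rows had A > 0 and 6% had B > 1.5 in that run.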

[8]:
shap_values_f[0, :]
[8]:
array([-0.506, -6.   ,  0.   ,  0.   ])
[9]:
shap.force_plot(float(explainer.expected_value), shap_values_f[0, :], X.iloc[0, :])

Probability space explanation

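When thinking about probability space, feature B is no longer very important, because the logistic function squashes the +100 effect in margin space to a change of at most +1 in probability (the logistic output is bounded between 0 and 1). So now B being larger than 1.5 is both rare and of limited impact, and A becomes the more important feature: in the output below, A’s SHAP value (-0.113) is about four times larger in magnitude than B’s (-0.027).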
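To check that the flip is not an artifact of sampling, one can compute exact Shapley values for the probability-space model directly. This is a minimal sketch under stated assumptions: A and B are independent, only the indicators 1[A > 0] and 1[B > 1.5] matter (C and D never enter `f`), and the background is the standard normal population rather than the 500 sampled rows. The helper names `h`, `value`, and `shapley` are hypothetical and are not SHAP APIs.

```python
from itertools import combinations
from math import factorial

from scipy.special import expit
from scipy.stats import norm


# Probability-space model written in terms of the two indicators that matter:
# a = 1[x_A > 0], b = 1[x_B > 1.5]; C and D never affect the output.
def h(a, b):
    return expit(a * 1 + b * 100)


# Assumed background: independent indicators with standard-normal probabilities.
p = {"a": norm.sf(0.0), "b": norm.sf(1.5)}
x = {"a": 0, "b": 0}  # the explained row (A = 0, B = 0)


def value(coalition):
    """E[h] with features in `coalition` fixed to x and the rest drawn from the background."""
    total = 0.0
    for a in (0, 1):
        for b in (0, 1):
            vals = {"a": a, "b": b}
            weight = 1.0
            for name in ("a", "b"):
                if name in coalition:
                    # fixed features must match the explained row
                    if vals[name] != x[name]:
                        weight = 0.0
                else:
                    weight *= p[name] if vals[name] == 1 else 1 - p[name]
            total += weight * h(a, b)
    return total


def shapley(feature, features=("a", "b")):
    """Exact Shapley value via the weighted sum over subsets of the other features."""
    others = [name for name in features if name != feature]
    n = len(features)
    phi = 0.0
    for k in range(len(others) + 1):
        for subset in combinations(others, k):
            w = factorial(k) * factorial(n - k - 1) / factorial(n)
            phi += w * (value(set(subset) | {feature}) - value(set(subset)))
    return phi


phi_A, phi_B = shapley("a"), shapley("b")
print(f"phi_A = {phi_A:.4f}, phi_B = {phi_B:.4f}")
```

Under these assumptions A comes out near -0.112 and B near -0.030, close to the notebook’s sampled values of -0.113 and -0.027, so A outranks B in probability space even though B dominated in margin space.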

[10]:
shap_values_f_logistic[0, :]
[10]:
array([-0.11344976, -0.02653412,  0.        ,  0.        ])
[11]:
shap.force_plot(
    float(explainer_logistic.expected_value), shap_values_f_logistic[0, :], X.iloc[0, :]
)