Census income classification with Keras

To download a copy of this notebook, visit GitHub.

[1]:
from keras.layers import (
    Dense,
    Dropout,
    Embedding,
    Flatten,
    Input,
    concatenate,
)
from keras.models import Model
from sklearn.model_selection import train_test_split

import shap

# print the JS visualization code to the notebook
shap.initjs()
Using TensorFlow backend.

Load dataset

[2]:
X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)

# normalize data (this is important for model convergence)
dtypes = list(zip(X.dtypes.index, map(str, X.dtypes)))
for k, dtype in dtypes:
    if dtype == "float32":
        X[k] -= X[k].mean()
        X[k] /= X[k].std()

X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.2, random_state=7
)
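
As a quick sanity check (not part of the original run), the continuous columns should now be roughly zero-mean and unit-variance, while the `int8` columns hold the label-encoded categorical features that are fed through embeddings in the model below:

[ ]:
# illustrative check: continuous columns are standardized, int8 columns are categorical codes
print(X.select_dtypes("float32").describe().loc[["mean", "std"]].round(2))
print([k for k, dtype in dtypes if dtype == "int8"])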

Train Keras model

[3]:
# build model
input_els = []
encoded_els = []
for k, dtype in dtypes:
    input_els.append(Input(shape=(1,)))
    if dtype == "int8":
        # categorical column: map each integer code to a learned 1-dimensional embedding
        e = Flatten()(Embedding(X_train[k].max() + 1, 1)(input_els[-1]))
    else:
        # continuous column: pass the (standardized) value straight through
        e = input_els[-1]
    encoded_els.append(e)
encoded_els = concatenate(encoded_els)
layer1 = Dropout(0.5)(Dense(100, activation="relu")(encoded_els))
out = Dense(1)(layer1)

# train model
regression = Model(inputs=input_els, outputs=[out])
regression.compile(optimizer="adam", loss="binary_crossentropy")
regression.fit(
    [X_train[k].values for k, t in dtypes],
    y_train,
    epochs=50,
    batch_size=512,
    shuffle=True,
    validation_data=([X_valid[k].values for k, t in dtypes], y_valid),
)
Train on 26048 samples, validate on 6513 samples
Epoch 1/50
26048/26048 [==============================] - 1s 28us/step - loss: 2.3308 - val_loss: 0.4450
Epoch 2/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.5018 - val_loss: 0.5353
Epoch 3/50
26048/26048 [==============================] - 0s 9us/step - loss: 1.3662 - val_loss: 0.5634
Epoch 4/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.3522 - val_loss: 0.6502
Epoch 5/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.3053 - val_loss: 0.5451
Epoch 6/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.2348 - val_loss: 0.5146
Epoch 7/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.2083 - val_loss: 0.4880
Epoch 8/50
26048/26048 [==============================] - 0s 9us/step - loss: 1.2280 - val_loss: 0.7679
Epoch 9/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.1979 - val_loss: 0.4658
Epoch 10/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.1313 - val_loss: 0.5112
Epoch 11/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.1138 - val_loss: 0.5580
Epoch 12/50
26048/26048 [==============================] - 0s 9us/step - loss: 1.2020 - val_loss: 0.4981
Epoch 13/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0844 - val_loss: 0.4940
Epoch 14/50
26048/26048 [==============================] - 0s 10us/step - loss: 1.0802 - val_loss: 0.5090
Epoch 15/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0761 - val_loss: 0.5058
Epoch 16/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0470 - val_loss: 0.5143
Epoch 17/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0285 - val_loss: 0.5553
Epoch 18/50
26048/26048 [==============================] - 0s 7us/step - loss: 1.0215 - val_loss: 0.5479
Epoch 19/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.0137 - val_loss: 0.5628
Epoch 20/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.0022 - val_loss: 0.5426
Epoch 21/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.9641 - val_loss: 0.5291
Epoch 22/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.9765 - val_loss: 0.7090
Epoch 23/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.0097 - val_loss: 0.4819
Epoch 24/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.9100 - val_loss: 0.4874
Epoch 25/50
26048/26048 [==============================] - 0s 9us/step - loss: 0.8821 - val_loss: 0.4724
Epoch 26/50
26048/26048 [==============================] - 0s 9us/step - loss: 0.8653 - val_loss: 0.5671
Epoch 27/50
26048/26048 [==============================] - 0s 8us/step - loss: 1.0496 - val_loss: 0.6884
Epoch 28/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.9529 - val_loss: 0.5993
Epoch 29/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.9255 - val_loss: 0.5297
Epoch 30/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.8726 - val_loss: 0.4880
Epoch 31/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.8523 - val_loss: 0.4730
Epoch 32/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.8526 - val_loss: 0.4683
Epoch 33/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.7988 - val_loss: 0.4655
Epoch 34/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7920 - val_loss: 0.4560
Epoch 35/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7629 - val_loss: 0.4449
Epoch 36/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.7506 - val_loss: 0.4388
Epoch 37/50
26048/26048 [==============================] - 0s 8us/step - loss: 0.7266 - val_loss: 0.4366
Epoch 38/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7460 - val_loss: 0.4239
Epoch 39/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7268 - val_loss: 0.4159
Epoch 40/50
26048/26048 [==============================] - 0s 10us/step - loss: 0.7199 - val_loss: 0.4025
Epoch 41/50
26048/26048 [==============================] - 0s 9us/step - loss: 0.6725 - val_loss: 0.4090
Epoch 42/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7740 - val_loss: 0.4576
Epoch 43/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.7491 - val_loss: 0.4111
Epoch 44/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6639 - val_loss: 0.4068
Epoch 45/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6734 - val_loss: 0.4218
Epoch 46/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6580 - val_loss: 0.3993
Epoch 47/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6516 - val_loss: 0.4000
Epoch 48/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6464 - val_loss: 0.3989
Epoch 49/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6258 - val_loss: 0.4004
Epoch 50/50
26048/26048 [==============================] - 0s 7us/step - loss: 0.6157 - val_loss: 0.4005
[3]:
<keras.callbacks.History at 0x10d720390>
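
The History object above is the usual Keras return value from fit. If you want a quick quality check before explaining the model (this cell is an optional addition, not part of the original run), something along these lines reports validation ROC AUC, reusing the per-column input format from training and the sklearn already imported above:

[ ]:
from sklearn.metrics import roc_auc_score

# score the validation set with the same per-column input format used for training
val_pred = regression.predict([X_valid[k].values for k, t in dtypes]).flatten()
print("validation ROC AUC:", roc_auc_score(y_valid, val_pred))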

Explain predictions

Here we take the Keras model trained above and explain why it makes different predictions for different individuals. SHAP expects model functions to take a 2D NumPy array as input, so we define a wrapper function around the original Keras predict function.

[4]:
def f(X):
    # split the 2D array into the per-column inputs the Keras model expects
    return regression.predict([X[:, i] for i in range(X.shape[1])]).flatten()
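
A quick way to confirm the wrapper behaves as SHAP expects (a 2D array of samples in, a 1D array of scores out) is to call it on a small slice of the data; this check is illustrative and not part of the original notebook:

[ ]:
# f should map an (n_samples, n_features) array to an (n_samples,) array of scores
print(f(X.iloc[:3, :].values).shape)  # expected: (3,)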

Explain a single prediction

Here we use a selection of 50 samples from the dataset to represent “typical” feature values, and then use 500 perturbation samples to estimate the SHAP values for a given prediction. Note that this requires 500 * 50 evaluations of the model.

[5]:
explainer = shap.KernelExplainer(f, X.iloc[:50, :])
shap_values = explainer.shap_values(X.iloc[299, :], nsamples=500)
shap.force_plot(explainer.expected_value, shap_values, X_display.iloc[299, :])
[5]:
[Interactive SHAP force plot for individual 299; it renders when the notebook is run with shap.initjs() and trusted, but the JavaScript is stripped in static views such as GitHub.]
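
For a single explanation, KernelExplainer's local-accuracy property means the SHAP values should approximately add up to the difference between the model output for this individual and the expected value. The cell below is a hedged sanity check, not part of the original run, and assumes the single-output case where shap_values is a plain array:

[ ]:
# local accuracy: base value + sum of SHAP values should roughly equal the model output for this row
print(explainer.expected_value + shap_values.sum())
print(f(X.iloc[299:300, :].values)[0])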

Explain many predictions

Here we repeat the above explanation process for 50 individuals. Since we are using a sampling-based approximation, each explanation can take a couple of seconds depending on your machine setup.

[6]:
shap_values50 = explainer.shap_values(X.iloc[280:330, :], nsamples=500)
100%|██████████| 50/50 [00:53<00:00,  1.08s/it]
[7]:
shap.force_plot(explainer.expected_value, shap_values50, X_display.iloc[280:330, :])
[7]:
[Interactive SHAP force plot stacking the 50 individual explanations; it renders when the notebook is run with shap.initjs() and trusted, but the JavaScript is stripped in static views such as GitHub.]
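
Beyond the interactive force plot, the same matrix of SHAP values can be summarized with a static plot, which is often easier to read in a rendered copy of the notebook. This cell is an optional addition, not part of the original run:

[ ]:
# rank features by mean absolute SHAP value across the 50 explained individuals
shap.summary_plot(shap_values50, X.iloc[280:330, :], plot_type="bar")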