Front Page DeepExplainer MNIST Example

A simple example showing how to explain an MNIST CNN trained using Keras with DeepExplainer.

[1]:
# model and training code adapted from https://github.com/keras-team/keras/blob/master/examples/demo_mnist_convnet.py
import keras
import numpy as np
from keras import layers
from keras.utils import to_categorical

import shap

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


# convert class vectors to binary class matrices
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

batch_size = 128
epochs = 3

model = keras.Sequential(
    [
        layers.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 conv2d (Conv2D)             (None, 26, 26, 32)        320

 max_pooling2d (MaxPooling2  (None, 13, 13, 32)        0
 D)

 conv2d_1 (Conv2D)           (None, 11, 11, 64)        18496

 max_pooling2d_1 (MaxPoolin  (None, 5, 5, 64)          0
 g2D)

 flatten (Flatten)           (None, 1600)              0

 dropout (Dropout)           (None, 1600)              0

 dense (Dense)               (None, 10)                16010

=================================================================
Total params: 34826 (136.04 KB)
Trainable params: 34826 (136.04 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Epoch 1/3
422/422 [==============================] - 13s 27ms/step - loss: 0.3786 - accuracy: 0.8848 - val_loss: 0.0791 - val_accuracy: 0.9792
Epoch 2/3
422/422 [==============================] - 11s 27ms/step - loss: 0.1127 - accuracy: 0.9652 - val_loss: 0.0569 - val_accuracy: 0.9847
Epoch 3/3
422/422 [==============================] - 11s 27ms/step - loss: 0.0863 - accuracy: 0.9738 - val_loss: 0.0473 - val_accuracy: 0.9860
Test loss: 0.04649536311626434
Test accuracy: 0.9850000143051147
[2]:
# select a set of background examples to take an expectation over
background = x_train[np.random.choice(x_train.shape[0], 100, replace=False)]

# explain the model's predictions on the first five test images
e = shap.DeepExplainer(model, background)
# ...or pass tensors directly
# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
shap_values = e.shap_values(x_test[0:5])
[3]:
# plot the feature attributions (the images are negated so the digits render dark-on-light)
shap.image_plot(shap_values, -x_test[0:5])
[figure: shap.image_plot output showing per-class attributions for the first five test images]

The plot above shows the explanations for each class for the first five test images. Each row corresponds to one input image, and within a row the explanations for classes 0-9 are arranged left to right. Red pixels increase the model's output for that class, while blue pixels decrease it.
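
As a quick follow-up (a small sketch, not part of the original example, assuming the variables from the cells above are still in scope), you can print the model's predicted and true classes for the same five images and compare them with the columns that receive the strongest red attributions in the plot:

[ ]:
# compare the model's predictions with the plot columns
probs = model.predict(x_test[0:5])  # class probabilities, shape (5, 10)
print("predicted classes:", probs.argmax(axis=1))
print("true classes:     ", y_test[0:5].argmax(axis=1))  # y_test was one-hot encoded above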