Keras LSTM for IMDB Sentiment Classification

This is a simple example of how to explain a Keras LSTM model using DeepExplainer.

[1]:
# This model training code is directly from:
# https://github.com/keras-team/keras/blob/master/examples/imdb_lstm.py

"""Trains an LSTM model on the IMDB sentiment classification task.
The dataset is actually too small for LSTM to be of any advantage
compared to simpler, much faster methods such as TF-IDF + LogReg.
# Notes
- RNNs are tricky. Choice of batch size is important,
choice of loss and optimizer is critical, etc.
Some configurations won't converge.
- LSTM loss decrease patterns during training can be quite different
from what you see with CNNs/MLPs/etc.
"""

from keras.datasets import imdb
from keras.layers import LSTM, Dense, Embedding
from keras.models import Sequential
from keras.preprocessing import sequence

max_features = 20000
maxlen = 80  # keep at most this many words per review (pad_sequences keeps the last maxlen words by default)
batch_size = 32

print("Loading data...")
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), "train sequences")
print(len(x_test), "test sequences")

print("Pad sequences (samples x time)")
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print("x_train shape:", x_train.shape)
print("x_test shape:", x_test.shape)

print("Build model...")
model = Sequential()
model.add(Embedding(max_features, 128))
model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation="sigmoid"))

# try using different optimizers and different optimizer configs
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

print("Train...")
model.fit(
    x_train, y_train, batch_size=batch_size, epochs=15, validation_data=(x_test, y_test)
)
score, acc = model.evaluate(x_test, y_test, batch_size=batch_size)
print("Test score:", score)
print("Test accuracy:", acc)
Using TensorFlow backend.
Loading data...
25000 train sequences
25000 test sequences
Pad sequences (samples x time)
x_train shape: (25000, 80)
x_test shape: (25000, 80)
Build model...
Train...
Train on 25000 samples, validate on 25000 samples
Epoch 1/15
25000/25000 [==============================] - 113s 5ms/step - loss: 0.4577 - acc: 0.7825 - val_loss: 0.3970 - val_acc: 0.8246
Epoch 2/15
25000/25000 [==============================] - 110s 4ms/step - loss: 0.3048 - acc: 0.8752 - val_loss: 0.3794 - val_acc: 0.8330
Epoch 3/15
25000/25000 [==============================] - 109s 4ms/step - loss: 0.2210 - acc: 0.9129 - val_loss: 0.4197 - val_acc: 0.8300
Epoch 4/15
25000/25000 [==============================] - 113s 5ms/step - loss: 0.1557 - acc: 0.9433 - val_loss: 0.4687 - val_acc: 0.8279
Epoch 5/15
25000/25000 [==============================] - 114s 5ms/step - loss: 0.1057 - acc: 0.9615 - val_loss: 0.6095 - val_acc: 0.8240
Epoch 6/15
25000/25000 [==============================] - 136s 5ms/step - loss: 0.0790 - acc: 0.9720 - val_loss: 0.7360 - val_acc: 0.8177
Epoch 7/15
25000/25000 [==============================] - 127s 5ms/step - loss: 0.0755 - acc: 0.9746 - val_loss: 0.6201 - val_acc: 0.8180
Epoch 8/15
25000/25000 [==============================] - 121s 5ms/step - loss: 0.0436 - acc: 0.9854 - val_loss: 0.8128 - val_acc: 0.8169
Epoch 9/15
25000/25000 [==============================] - 110s 4ms/step - loss: 0.0312 - acc: 0.9895 - val_loss: 0.9553 - val_acc: 0.8145
Epoch 10/15
25000/25000 [==============================] - 114s 5ms/step - loss: 0.0283 - acc: 0.9909 - val_loss: 0.9576 - val_acc: 0.8126
Epoch 11/15
25000/25000 [==============================] - 108s 4ms/step - loss: 0.0172 - acc: 0.9949 - val_loss: 0.9107 - val_acc: 0.8117
Epoch 12/15
25000/25000 [==============================] - 108s 4ms/step - loss: 0.0156 - acc: 0.9954 - val_loss: 0.9634 - val_acc: 0.8096
Epoch 13/15
25000/25000 [==============================] - 110s 4ms/step - loss: 0.0119 - acc: 0.9962 - val_loss: 1.0733 - val_acc: 0.8123
Epoch 14/15
25000/25000 [==============================] - 113s 5ms/step - loss: 0.0117 - acc: 0.9964 - val_loss: 1.1165 - val_acc: 0.8106
Epoch 15/15
25000/25000 [==============================] - 111s 4ms/step - loss: 0.0107 - acc: 0.9970 - val_loss: 1.0867 - val_acc: 0.8091
25000/25000 [==============================] - 17s 688us/step
Test score: 1.0867270610725879
Test accuracy: 0.80912

Explain the model with DeepExplainer and visualize the first prediction

[3]:
import shap

# we use the first 100 training examples as our background dataset to integrate over
explainer = shap.DeepExplainer(model, x_train[:100])

# explain the first 10 predictions
# explaining each prediction requires 2 * background dataset size runs
shap_values = explainer.shap_values(x_test[:10])
[4]:
# init the JS visualization code
shap.initjs()

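# transform the indexes to words
# imdb.load_data() shifts every word index by 3 (its default index_from) and
# reserves 0, 1 and 2 for padding, start-of-sequence and out-of-vocabulary tokens,
# so the raw indexes from get_word_index() need the same shift before decoding
import numpy as np

words = imdb.get_word_index()
num2word = {index + 3: word for word, index in words.items()}
x_test_words = np.stack(
    [np.array([num2word.get(x, "NONE") for x in x_test[i]]) for i in range(10)]
)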

# plot the explanation of the first prediction
# Note the model is "multi-output" because it is rank-2 but only has one column
shap.force_plot(explainer.expected_value[0], shap_values[0][0], x_test_words[0])

Note that each sample is an IMDB review, represented as a sequence of words. This means “feature 0” is the word at the first position of the sequence, which will be different for different reviews. Calling summary_plot would therefore combine the importance of all words by their position in the text, which is likely not what you want for a global measure of feature importance (and is why we have not called summary_plot here). If you do want a global summary of a word’s importance, you could pull apart the feature attribution values and group them by word.
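
One way to do that grouping is sketched below. This is only an illustration, not part of the original notebook: the helper name `group_attributions_by_word`, its `skip` parameter, and the toy inputs are all hypothetical, and averaging over occurrences is just one of several reasonable aggregation choices (summing, or averaging absolute values, would work as well).

```python
from collections import defaultdict


def group_attributions_by_word(attributions, tokens, skip=("NONE",)):
    """Average the attribution each word receives over all of its occurrences.

    attributions: per-sample sequences of per-position attribution values
    tokens: per-sample sequences of the words at those positions (same shape)
    skip: placeholder tokens to ignore (hypothetical default matching "NONE")
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for sample_values, sample_tokens in zip(attributions, tokens):
        for value, token in zip(sample_values, sample_tokens):
            token = str(token)
            if token in skip:
                continue
            totals[token] += float(value)
            counts[token] += 1
    return {token: totals[token] / counts[token] for token in totals}


# Toy data standing in for the SHAP values and decoded words of two reviews
toy_values = [[0.25, -0.5, 0.0], [-0.25, 0.75, 0.125]]
toy_tokens = [["great", "boring", "NONE"], ["boring", "great", "film"]]

word_importance = group_attributions_by_word(toy_values, toy_tokens)

# Rank words by the magnitude of their mean attribution
ranked = sorted(word_importance.items(), key=lambda item: abs(item[1]), reverse=True)
print(ranked)
```

With the variables from this notebook, a call along the lines of `group_attributions_by_word(shap_values[0], x_test_words)` should apply, assuming `shap_values[0]` holds one row of attributions per explained review, as the indexing in the force plot call suggests. With only 10 explained reviews the averages will be noisy, so explaining more examples first is probably worthwhile.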