Emotion classification multiclass example

This notebook demonstrates how to use the Partition explainer for a multiclass text classification scenario. Once the SHAP values are computed for a set of sentences, we visualize the feature attributions toward the individual classes. The text classification model we use is BERT fine-tuned on an emotion dataset to classify a sentence among six classes: joy, sadness, anger, fear, love and surprise.

[1]:
import datasets
import pandas as pd
import transformers

import shap

# load the emotion dataset
dataset = datasets.load_dataset("emotion", split="train")
data = pd.DataFrame({"text": dataset["text"], "emotion": dataset["label"]})
Using custom data configuration default
Reusing dataset emotion (/home/slundberg/.cache/huggingface/datasets/emotion/default/0.0.0/aa34462255cd487d04be8387a2d572588f6ceee23f784f37365aa714afeb8fe6)

Build a transformers pipeline

Note that we have set return_all_scores=True for the pipeline so that we can observe the model’s behavior for all classes, not just the top output (a quick check of this output format is sketched after the cell below).

[2]:
# load the model and tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "nateraw/bert-base-uncased-emotion", use_fast=True
)
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "nateraw/bert-base-uncased-emotion"
).cuda()

# build a pipeline object to do predictions
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0,
    return_all_scores=True,
)
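Before building an explainer it can help to confirm the output format the explainer will see. This is a sketch, not a cell from the original notebook, and assumes the model and dataset loaded above:

# a sketch: with return_all_scores=True the pipeline returns, for each input
# sentence, a list of {"label", "score"} dicts covering all six emotion classes
print(pred([data["text"][0]]))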

Create an explainer for the pipeline

A transformers pipeline object can be passed directly to shap.Explainer, which will then wrap the pipeline model as a shap.models.TransformersPipeline model and the pipeline tokenizer as a shap.maskers.Text masker.

[3]:
explainer = shap.Explainer(pred)
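For reference, the wrapping described above can also be written out explicitly. This is a sketch of a roughly equivalent construction, not what the notebook originally ran:

# roughly equivalent explicit construction (sketch): wrap the pipeline as a
# SHAP model and build a Text masker from the same tokenizer
wrapped_model = shap.models.TransformersPipeline(pred)
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(wrapped_model, masker)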

Compute SHAP values

Explainers have the same call signature as the models they explain, so we simply pass a list of strings whose classifications we want to explain.

[4]:
shap_values = explainer(data["text"][:3])
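The result is a multi-output Explanation object. A quick way to see how it is organized (a sketch, assuming the usual Explanation attributes output_names and data):

# the Explanation is indexed as [sentence, token, output class]
print(shap_values.output_names)  # the six emotion labels
print(shap_values[0].data)  # the tokens of the first sentence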

Visualize the impact on all the output classes

In the plots below, hovering your mouse over an output class shows the explanation for that class. Clicking an output class name keeps that class as the focus of the explanation visualization until you click another class.

The base value is what the model outputs when the entire input text is masked, while \(f_{\text{output class}}(\text{inputs})\) is the output of the model for the full original input. The SHAP values explain, in an additive way, how unmasking each word moves the model output from the base value (where the entire input is masked) to the final prediction value.
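These quantities are also stored directly on the Explanation object, so they can be inspected without the plot (a sketch, assuming the usual base_values attribute):

# base_values holds, for each sentence, the model output for every class when
# the whole input is masked; the per-token contributions are in shap_values.values
print(shap_values.base_values)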

[5]:
shap.plots.text(shap_values)


[Interactive text plots for the three sentences, with one tab per output class (sadness, joy, love, anger, fear, surprise), showing the base value, \(f_{\text{sadness}}(\text{inputs})\) for the selected class, and each word’s contribution. “humiliated” and “hopeless” dominate the sadness attributions for the first two sentences, while the third sentence receives a very low sadness score.]

Visualize the impact on a single class

Since Explanation objects are sliceable, we can slice out a single output class and visualize the model output toward that class alone (a sketch of slicing out a single sentence follows the plots below).

[11]:
shap.plots.text(shap_values[:, :, "anger"])


[Text plots of per-token SHAP values toward the “anger” class for the three sentences. “greedy”, “wrong” and “grabbing” drive a high anger output for the third sentence, while the first two sentences receive very low anger scores.]
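The same slicing also works per sentence. A sketch of pulling out the raw “anger” attributions for just the first sentence, assuming the shap_values computed above and the usual values and data attributes:

# slice one sentence and one class to get a one-dimensional explanation
anger_first = shap_values[0, :, "anger"]
print(anger_first.values)  # per-token attributions toward "anger"
print(anger_first.data)  # the corresponding tokens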

Plotting the top words impacting a specific class

In addition to slicing, Explanation objects also support a set of reducing methods. Here we use .mean(0) to take the average impact of all words toward the “joy” class. Note that we are only averaging over three examples here; to get a more representative summary you would want to use a larger portion of the dataset. (A variant that averages absolute impacts is sketched after these plots.)

[12]:
shap.plots.bar(shap_values[:, :, "joy"].mean(0))
[Bar plot of the mean SHAP value per word toward the “joy” class.]
[13]:
# we can sort the bar chart in descending order
shap.plots.bar(shap_values[:, :, "joy"].mean(0), order=shap.Explanation.argsort)
[The same bar plot sorted in descending order.]
[14]:
# ...or ascending order
shap.plots.bar(shap_values[:, :, "joy"].mean(0), order=shap.Explanation.argsort.flip)
[The same bar plot sorted in ascending order.]
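Because .mean(0) averages signed impacts, positive and negative contributions toward “joy” can cancel out. A sketch of ranking words by the magnitude of their impact instead, assuming the usual .abs property on Explanation objects:

# take absolute values before averaging so opposing contributions do not cancel
shap.plots.bar(shap_values[:, :, "joy"].abs.mean(0))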

Explain the log odds instead of the probabilities

In the examples above we explained the direct output of the pipeline object, which is a set of class probabilities. Sometimes it makes more sense to work in a log-odds space, where it is natural to add and subtract effects (addition and subtraction correspond to adding or removing bits of evidence). To work with logits we can use the rescale_to_logits parameter of the shap.models.TransformersPipeline object:
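For intuition, assuming rescale_to_logits applies the standard log-odds transform \(\log(p / (1 - p))\) to each class probability (an assumption used here only for illustration), a probability of 0.5 maps to 0 and a probability of 0.9 maps to roughly 2.2:

import numpy as np

# the standard log-odds (logit) transform of a probability
p = 0.9
print(np.log(p / (1 - p)))  # ~2.197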

[15]:
logit_explainer = shap.Explainer(
    shap.models.TransformersPipeline(pred, rescale_to_logits=True)
)

logit_shap_values = logit_explainer(data["text"][:3])
shap.plots.text(logit_shap_values)


[Interactive text plots of per-token SHAP values in log-odds units for the three sentences, again with one tab per output class. The attributions for “humiliated” and “hopeless” are now much larger in magnitude, and words like “hopeful” and “greedy” contribute negative log-odds toward sadness.]

Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!