Emotion classification multiclass example

This notebook demonstrates how to use the Partition explainer for a multiclass text classification scenario. Once the SHAP values are computed for a set of sentences, we visualize the feature attributions towards individual classes. The text classification model we use is BERT fine-tuned on an emotion dataset to classify a sentence into one of six classes: joy, sadness, anger, fear, love, and surprise.

[1]:
import datasets
import pandas as pd
import transformers

import shap

# load the emotion dataset
dataset = datasets.load_dataset("emotion", split="train")
data = pd.DataFrame({"text": dataset["text"], "emotion": dataset["label"]})
Using custom data configuration default
Reusing dataset emotion (/home/slundberg/.cache/huggingface/datasets/emotion/default/0.0.0/aa34462255cd487d04be8387a2d572588f6ceee23f784f37365aa714afeb8fe6)

Build a transformers pipeline

Note that we have set return_all_scores=True for the pipeline so we can observe the model’s behavior for all classes, not just the top output.

[2]:
# load the model and tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "nateraw/bert-base-uncased-emotion", use_fast=True
)
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "nateraw/bert-base-uncased-emotion"
).cuda()

# build a pipeline object to do predictions
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0,
    return_all_scores=True,
)
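As a rough illustration of what "all classes" means here: when all scores are requested, a text-classification pipeline returns one list per input sentence, holding one `{"label", "score"}` entry per class. The sketch below uses invented scores (not real model output) and a hypothetical helper, `scores_by_label`, to show how that nested layout can be flattened into a per-class dictionary.

```python
# Hypothetical scores for a single sentence, laid out as a list (one item per
# sentence) of lists (one dict per class). The numbers are invented for
# illustration only and do not come from the model used in this notebook.
all_scores = [
    [
        {"label": "sadness", "score": 0.90},
        {"label": "joy", "score": 0.04},
        {"label": "love", "score": 0.02},
        {"label": "anger", "score": 0.02},
        {"label": "fear", "score": 0.01},
        {"label": "surprise", "score": 0.01},
    ]
]


def scores_by_label(sentence_scores):
    """Turn one sentence's list of {"label", "score"} dicts into a dict."""
    return {entry["label"]: entry["score"] for entry in sentence_scores}


per_class = scores_by_label(all_scores[0])
top_label = max(per_class, key=per_class.get)
print(top_label, per_class[top_label])
```

Keeping every class score, rather than only the top one, is what lets the explainer attribute each token's effect to each of the six outputs separately.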

Create an explainer for the pipeline

A transformers pipeline object can be passed directly to shap.Explainer, which will then wrap the pipeline model as a shap.models.TransformersPipeline model and the pipeline tokenizer as a shap.maskers.Text masker.

[3]:
explainer = shap.Explainer(pred)

Compute SHAP values

Explainers have the same method signature as the models they are explaining, so we simply pass the sentences whose classifications we want to explain.

[4]:
shap_values = explainer(data["text"][:3])

Visualize the impact on all the output classes

In the plots below, hovering your mouse over an output class shows the explanation for that class. Clicking an output class name keeps that class as the focus of the visualization until you click another class.

The base value is what the model outputs when the entire input text is masked, while \(f_{output class}(inputs)\) is the output of the model for the full original input. The SHAP values explain in an additive way how unmasking each word changes the model output from the base value (where the entire input is masked) to the final prediction value.
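To make the terms "base value" and "unmasking" concrete, here is a minimal sketch on a toy scoring function. It computes exact Shapley values by enumerating every subset of tokens, which is not the Partition algorithm SHAP uses for text but satisfies the same additive property. The scorer `toy_model`, its weights, and the `[MASK]` placeholder are all assumptions made up for this illustration.

```python
from itertools import combinations
from math import factorial

MASK = "[MASK]"  # hypothetical placeholder for a hidden token

# Invented word weights standing in for a real classifier.
WEIGHTS = {"humiliated": 0.8, "feel": -0.05}


def toy_model(tokens):
    """Toy 'sadness' score: a constant, per-word weights, and one interaction."""
    score = 0.1 + sum(WEIGHTS.get(t, 0.0) for t in tokens)
    if "feel" in tokens and "humiliated" in tokens:
        score += 0.1  # interaction term, so the model is not purely additive
    return score


def masked(tokens, keep):
    """Replace every token whose index is not in `keep` with MASK."""
    return [t if i in keep else MASK for i, t in enumerate(tokens)]


def shapley_values(model, tokens):
    """Exact Shapley values by enumerating all subsets of the other tokens."""
    n = len(tokens)
    values = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                without_i = model(masked(tokens, set(subset)))
                with_i = model(masked(tokens, set(subset) | {i}))
                values[i] += weight * (with_i - without_i)
    return values


tokens = ["i", "feel", "humiliated"]
base_value = toy_model(masked(tokens, set()))  # everything masked
full_output = toy_model(tokens)                # nothing masked
phi = shapley_values(toy_model, tokens)

# Additivity: base value plus all attributions recovers the full output.
print(base_value + sum(phi), full_output)
```

Exhaustive enumeration grows exponentially with the number of tokens, which is why it only serves as an illustration here and not as a practical method for real sentences.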

[5]:
shap.plots.text(shap_values)


[Interactive text plot. Output classes: sadness, joy, love, anger, fear, surprise. The values shown below are for the sadness class.]

[0] base value = 0.131672, f_sadness(inputs) = 0.996408
Token SHAP values, in input order:
i 0.003, didn 0.009, t 0.001, feel -0.004, humiliated 0.855

[1] base value = 0.144195, f_sadness(inputs) = 0.995292
Token SHAP values, in input order:
i 0.002, can -0.003, go 0.0, from 0.004, feeling 0.28, so 0.039, hopeless 0.599, to 0.001, so -0.001, damned 0.004, hopeful -0.045, just -0.006, from 0.002, being -0.002, around -0.002, someone -0.004, who 0.001, cares -0.011, and -0.002, is -0.006, awake 0.002
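As a quick sanity check of the additive property, the values the plot reports for sentence [0] and the sadness class can be summed by hand. The sketch below copies those displayed numbers as literals; because the plot rounds each token attribution to three decimals, the reconstruction is only expected to match the reported output up to rounding error.

```python
# Values copied from the plot output for sentence [0], class "sadness".
base_value = 0.131672
model_output = 0.996408  # f_sadness(inputs) as displayed

# Per-token SHAP values as displayed (rounded to three decimals by the plot).
token_shap = {
    "i": 0.003,
    "didn": 0.009,
    "t": 0.001,
    "feel": -0.004,
    "humiliated": 0.855,
}

# Base value plus the sum of attributions should approximate the model output.
reconstructed = base_value + sum(token_shap.values())
gap = abs(reconstructed - model_output)
print(f"reconstructed={reconstructed:.6f} reported={model_output:.6f} gap={gap:.6f}")
```

The remaining gap is under 0.001, which is consistent with the rounding of the displayed attributions. Nearly all of the shift from the base value comes from the single token "humiliated".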