Text to Multiclass Explanation: Language Modeling Example

This notebook demonstrates how to get explanations for the top-k next words generated by a language model. In this demo, we use the pretrained gpt2 model provided by Hugging Face (https://huggingface.co/gpt2) to predict the top-k next words. We treat these top-k next words as k separate classes and then learn an explanation for each of the k words. This lets us explain how much each word in the input contributes to the likelihood of each of the top-k next words being predicted.

[1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

import shap

Load model and tokenizer

[2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
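Note that .cuda() assumes a GPU is available. If you are running on CPU only, a minimal alternative (a sketch using PyTorch, which transformers already depends on) is to pick the device dynamically:

import torch

# use a GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)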

Next, we wrap the model in shap.models.TopKLM, which extracts the log odds of the top-k next words. We also create a Text masker, initialized with mask_token="..." and collapse_mask_token=True, which is used to infill text when the inputs are perturbed.

[3]:
wrapped_model = shap.models.TopKLM(model, tokenizer, k=100)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)

Define data

Here we set the initial text for which we want the gpt2 model to predict the next word.

[4]:
s = ["In a shocking finding, scientists discovered a herd of unicorns living in a"]

Create explainer object

[5]:
explainer = shap.Explainer(wrapped_model, masker)

Compute SHAP values

[6]:
shap_values = explainer(s)
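The result is a shap.Explanation object. As a quick sanity check, you can inspect its shape and the candidate next words being explained (a sketch; for this example the values have one row per input sentence, one column per input token, and one slice per candidate output word):

# rows: inputs, columns: input tokens, slices: the k candidate next words
print(shap_values.shape)

# the top-k next words that are treated as output classes
print(shap_values.output_names)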

Visualize the SHAP values across the input sentence for the top-k next words

In the plot below, the top-k next words predicted by gpt2 appear under “Output Text”. Hover over each output token to see which words in the input sentence drive the prediction of that particular word.

[7]:
shap.plots.text(shap_values)


[Interactive SHAP text plot. The top-k next words predicted by gpt2 (cave, forest, small, desert, tiny, remote, zoo, tree, field, house, ...) are listed under “Output Text”. Hovering over an output word highlights the SHAP value of each input token for that word. For the top prediction “cave” (base value -12.57, f(inputs) -2.91), the largest positive contributions come from the input tokens “a” (4.80), “in” (2.87), and “living” (1.27), while tokens such as “,” and “discovered” contribute slightly negatively.]
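To focus on a single candidate word, the explanation can also be sliced by output name and shown as a bar plot. This is a sketch that assumes “cave” appears among the top-k words on your run (the exact set can vary across transformers/shap versions):

# SHAP value of each input token toward the candidate next word "cave"
shap.plots.bar(shap_values[0, :, "cave"])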
Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!