Open Ended GPT2 Text Generation Explanations

This notebook demonstrates how to obtain explanations for the output of GPT-2 used for open-ended text generation. In this demo, we use the pretrained GPT-2 model provided by Hugging Face (https://huggingface.co/gpt2) and explain the text it generates. We further show how to get explanations for custom generated output text and how to plot global input token importances for any generated output token.

[1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

import shap

Load model and tokenizer

[2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()

Below, we set certain model configurations. We need to declare whether the model is a decoder or an encoder-decoder; this is set through the 'is_decoder' or 'is_encoder_decoder' attribute of the model's config. We can also set custom generation parameters, which will be used during the decoding process that produces the output text.

[3]:
# set model decoder to true
model.config.is_decoder = True
# set text-generation params under task_specific_params
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}
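As a rough illustration of how such task-specific defaults interact with per-call settings, the sketch below merges a stored parameter dictionary with explicit overrides. This is a pure-Python sketch under our own assumptions, not SHAP or transformers internals; the helper name `generation_kwargs` is ours:

```python
# Sketch (not library internals): merge task-specific generation defaults,
# like the ones stored on model.config above, with per-call overrides.
def generation_kwargs(task_specific_params, task="text-generation", **overrides):
    kwargs = dict(task_specific_params.get(task, {}))  # copy the stored defaults
    kwargs.update(overrides)                           # explicit arguments win
    return kwargs

task_specific_params = {
    "text-generation": {
        "do_sample": True,
        "max_length": 50,
        "temperature": 0.7,
        "top_k": 50,
        "no_repeat_ngram_size": 2,
    }
}

kwargs = generation_kwargs(task_specific_params, max_length=30)
```

Here the stored defaults survive (for example `top_k` stays 50) while the explicit `max_length=30` takes precedence.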

Define initial text

[4]:
s = ["I enjoy walking with my cute dog"]

Create an explainer object and compute the SHAP values

[5]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
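Before plotting, the returned Explanation object can also be inspected directly: for a text-to-text explanation, `shap_values.values[i]` is typically a matrix with one row per input token and one column per generated output token. A tiny stand-alone sketch with made-up numbers (not the values from this run) showing how to find the strongest driver of a given output token:

```python
import numpy as np

# Illustrative stand-in for shap_values.values[0]:
# rows = input tokens, columns = generated output tokens.
values = np.array([
    [-0.117, 0.02],   # "I"
    [ 0.072, 0.01],   # "with"
    [ 4.064, 0.50],   # "dog"
])
input_tokens = ["I", "with", "dog"]

# Input token with the largest positive contribution to output token 0:
top_driver = input_tokens[int(np.argmax(values[:, 0]))]
```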

Visualize the SHAP explanations

[6]:
shap.plots.text(shap_values)


[0]
Generated output: ", but I'm not sure if I'll ever be able to"
Input token SHAP values (base value -4.049, f(inputs) -1.275):

    I        -0.117
    enjoy    -0.431
    walking  -0.427
    with      0.072
    my       -0.150
    cute     -0.238
    dog       4.064

Another example…

[7]:
s = [
    "Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth"
]
[8]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[9]:
shap.plots.text(shap_values)


[0]
Generated output: "in the coming days."
Input token SHAP values (base value -4.740, f(inputs) -1.738):

    Scientists    0.442
    confirmed    -0.035
    the          -0.168
    worst         0.240
    possible     -0.156
    outcome       0.075
    :            -0.365
    the          -0.385
    massive       0.093
    asteroid      0.221
    will         -0.166
    collide       1.280
    with          0.489
    Earth         1.436

Custom text generation and debugging biased outputs

Below we demonstrate how to explain the likelihood of the model generating a particular output sentence given an input sentence. For example, we ask: which country's inhabitants (the target) in the sentence "I know many people who are [target]." would give a high likelihood of generating the token "vodka" in the output sentence "They love their vodka!"? For this, we first define input-output sentence pairs.

[10]:
# define input
x = [
    "I know many people who are Russian.",
    "I know many people who are Greek.",
    "I know many people who are Australian.",
    "I know many people who are American.",
    "I know many people who are Italian.",
    "I know many people who are Spanish.",
    "I know many people who are German.",
    "I know many people who are Indian.",
]
[11]:
# define output
y = [
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
]

We wrap the model with a Teacher Forcing scoring class and create a Text masker

[12]:
teacher_forcing_model = shap.models.TeacherForcing(model, tokenizer)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
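The `mask_token` and `collapse_mask_token` options control how hidden input tokens appear to the model: masked-out tokens are replaced by "...", and consecutive masked tokens collapse into a single "..." marker. A rough pure-Python sketch of that behavior (our own illustration, not SHAP's actual implementation):

```python
def apply_text_mask(tokens, keep, mask_token="...", collapse_mask_token=True):
    """Replace masked-out tokens with mask_token, optionally collapsing runs."""
    out = []
    for token, keep_it in zip(tokens, keep):
        if keep_it:
            out.append(token)
        elif collapse_mask_token and out and out[-1] == mask_token:
            continue  # extend the previous "..." instead of adding another
        else:
            out.append(mask_token)
    return out

masked = apply_text_mask(["I", "know", "many", "people"], [True, False, False, True])
```

With collapsing enabled, masking "know" and "many" yields `["I", "...", "people"]` rather than two separate "..." tokens, which keeps the perturbed sentences closer to natural text.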

Create an explainer…

[13]:
explainer = shap.Explainer(teacher_forcing_model, masker)

Generate SHAP explanation values!

[14]:
shap_values = explainer(x, y)

Now that we have generated the SHAP values, we can look at the contribution of input tokens driving the token "vodka" in the output sentence using the text plot. Note: red indicates a positive contribution, blue indicates a negative contribution, and the intensity of the color shows the strength of the contribution in the respective direction.

[15]:
shap.plots.text(shap_values)


All eight inputs generate the output "They love their vodka!". Input token SHAP values toward the first output token "They" (base value -8.248 for every sentence):

Target       f(inputs)       I    know    many  people     who     are  target      .
Russian         -8.785  -0.377  -0.158  -0.157   0.124   0.035   0.109  -0.488  0.375
Greek           -8.949  -0.351  -0.125  -0.242   0.149   0.054   0.144  -0.716  0.387
Australian      -8.676  -0.410  -0.158  -0.176   0.144  -0.015   0.015  -0.529  0.701
American        -9.143  -0.439  -0.185  -0.162   0.134  -0.030   0.030  -0.632  0.390
Italian         -9.083  -0.454  -0.149  -0.240   0.106   0.079   0.155  -0.760  0.428
Spanish         -9.074  -0.399  -0.150  -0.225   0.106   0.156   0.288  -1.015  0.414
German          -8.999  -0.380  -0.125  -0.282   0.138   0.063   0.186  -0.811  0.460
Indian          -8.631  -0.484  -0.227  -0.210   0.128  -0.054  -0.011   0.100  0.374

To view which input tokens impact (positively or negatively) the likelihood of generating the word "vodka", we plot the global token importances for the word "vodka".

Voila! Russians love their vodka, don't they? :)

[16]:
shap.plots.bar(shap_values[0, :, "vodka"])
[Bar plot of input token importances for the output token "vodka"]
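The indexing `shap_values[0, :, "vodka"]` selects, for the first sentence, the SHAP values of every input token toward the output token "vodka". A pure-Python sketch of that selection, and of the sort-by-magnitude ordering a bar plot typically uses (all numbers are illustrative, not from this run):

```python
import numpy as np

input_tokens = ["I", "know", "many", "people", "who", "are", "Russian", "."]
output_tokens = ["They", "love", "their", "vodka", "!"]

# Illustrative (input tokens x output tokens) SHAP matrix for sentence 0.
rng = np.random.default_rng(0)
values = rng.normal(size=(len(input_tokens), len(output_tokens)))

# Rough equivalent of shap_values[0, :, "vodka"]:
# the column of contributions for that output token.
vodka_col = values[:, output_tokens.index("vodka")]

# Bar plots generally rank tokens by absolute contribution.
order = np.argsort(-np.abs(vodka_col))
ranked_tokens = [input_tokens[i] for i in order]
```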

Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!