Open Ended GPT2 Text Generation Explanations

This notebook demonstrates how to explain the output of GPT-2 when it is used for open-ended text generation. In this demo, we use the pretrained GPT-2 model provided by Hugging Face (https://huggingface.co/gpt2) to explain the text it generates. We further show how to get explanations for custom generated output text and how to plot the global input-token importances for any generated output token.

[1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

import shap

Load model and tokenizer

[2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()  # move the model to the GPU (requires CUDA)

Below, we set certain model configurations. We need to declare whether the model is a decoder or an encoder-decoder; this is set through the `is_decoder` or `is_encoder_decoder` param in the model's config. We can also set custom generation parameters, which will be used during the decoding process that generates the output text.

[3]:
# set model decoder to true
model.config.is_decoder = True
# set text-generation params under task_specific_params
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}
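The generation parameters above are an ordinary Python dict, so they can be inspected or tweaked before the explainer runs. A minimal sketch (the temperature tweak shown is just an illustrative assumption, not part of the original notebook):

```python
# The text-generation parameters from the cell above, as a plain dict.
gen_params = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}

# Example tweak: lower the temperature for less random sampling.
gen_params["temperature"] = 0.3

# In the notebook this dict lives at
# model.config.task_specific_params["text-generation"].
print(gen_params)
```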

Define initial text

[4]:
s = ["I enjoy walking with my cute dog"]

Create an explainer object and compute the SHAP values

[5]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.

Visualize SHAP explanations

[6]:
shap.plots.text(shap_values)


[0]
outputs: ", but I'm not sure if I'll ever be able to"

[text plot: base value -4.04941, f(inputs) -1.27522; input-token SHAP values:
I -0.117, enjoy -0.431, walking -0.427, with 0.072, my -0.15, cute -0.238, dog 4.064]
Another example…

[7]:
s = ["Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth"]
[8]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[9]:
shap.plots.text(shap_values)


[0]
outputs: "in the coming days."

[text plot: base value -4.7396, f(inputs) -1.7384; input-token SHAP values:
Scientists 0.442, confirmed -0.035, the -0.168, worst 0.24, possible -0.156, outcome 0.075, : -0.365, the -0.385, massive 0.093, asteroid 0.221, will -0.166, collide 1.28, with 0.489, Earth 1.436]
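To see which input tokens drive a generated token, one can rank them by absolute SHAP value. On the real `shap_values` object this information lives in its `.values` and `.data` arrays; here is a sketch using the numbers read off the plot above:

```python
# (token, SHAP value) pairs read off the plot above; a list of tuples is used
# because the token "the" appears twice.
contribs = [
    ("Scientists", 0.442), ("confirmed", -0.035), ("the", -0.168),
    ("worst", 0.24), ("possible", -0.156), ("outcome", 0.075),
    (":", -0.365), ("the", -0.385), ("massive", 0.093),
    ("asteroid", 0.221), ("will", -0.166), ("collide", 1.28),
    ("with", 0.489), ("Earth", 1.436),
]

# rank tokens by the magnitude of their contribution
ranked = sorted(contribs, key=lambda tv: abs(tv[1]), reverse=True)
print([t for t, _ in ranked[:3]])  # → ['Earth', 'collide', 'with']
```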