Text-to-Text Explanation: Open-Ended Text Generation Using GPT-2

This notebook demonstrates how to get explanations for the output of GPT-2 used for open-ended text generation. In this demo, we use the pretrained GPT-2 model provided by Hugging Face (https://huggingface.co/gpt2) to explain the text it generates. We further show how to get explanations for custom generated output text and how to plot global input-token importances for any generated output token.

[1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import shap
import torch

Load model and tokenizer

[2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
# move the model to the GPU; drop .cuda() to run on CPU instead
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()

Below, we set certain model configurations. We need to declare whether the model is a decoder or an encoder-decoder, which is set through the `is_decoder` or `is_encoder_decoder` param in the model's config. We can also set custom generation parameters, which will be used during the decoding process that produces the output text.

[3]:
# set model decoder to true
model.config.is_decoder=True
# set text-generation params under task_specific_params
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2
}
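
As a quick sanity check (not part of the original notebook), the same parameters can be passed directly to model.generate to preview the kind of text the explainer will attribute over; the prompt below is just an illustration:

# Hypothetical sanity check: sample from GPT-2 with the same decoding parameters.
input_ids = tokenizer("I enjoy walking with my cute dog", return_tensors="pt").input_ids.to(model.device)
sample = model.generate(
    input_ids,
    do_sample=True,
    max_length=50,
    temperature=0.7,
    top_k=50,
    no_repeat_ngram_size=2,
    pad_token_id=tokenizer.eos_token_id,  # avoids the pad_token_id warning
)
print(tokenizer.decode(sample[0], skip_special_tokens=True))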

Define initial text

[4]:
s = ['I enjoy walking with my cute dog']

Create an explainer object

[5]:
explainer = shap.Explainer(model, tokenizer)
Using pad_token, but it is not set yet.
explainers.Partition is still in an alpha state, so use with caution...
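
Note that passing the tokenizer directly is shorthand: SHAP wraps it in a Text masker internally (an assumption about the library's defaults; under that assumption the explicit form below is equivalent):

# Equivalent explicit construction with a Text masker
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(model, masker)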

Compute SHAP values

[6]:
shap_values = explainer(s)
Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence
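
To inspect the generated text programmatically (assuming the returned shap.Explanation exposes the generated tokens via output_names, as recent SHAP versions do):

# Tokens of the generated output sequence that the plots below attribute over
print(shap_values.output_names)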

Visualize SHAP explanations

[7]:
shap.plots.text(shap_values)

[SHAP text plot: interactive heatmap of input/output token attributions]

0th instance:
Input text: I enjoy walking with my cute dog
Output text: and playing outside with a friend. We also have a fun dog walk each year and have run across a bunch of different things, so we know what we're doing right. We love reading, and we

Another example…

[8]:
s = ['Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth']
[9]:
explainer = shap.Explainer(model, tokenizer)
explainers.Partition is still in an alpha state, so use with caution...
[10]:
shap_values = explainer(s)
Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence
[11]:
shap.plots.text(shap_values)

[SHAP text plot: interactive heatmap of input/output token attributions]

0th instance:
Input text: Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth
Output text: on March 4, 2012. The two-hour-long collision will destroy one of the three largest asteroids in the Solar System, the National Science Foundation said Friday. A

Custom text generation and debugging biased outputs

Below, we demonstrate how to explain the likelihood of the model generating a particular output sentence given an input sentence. For example, we ask: which country's inhabitant (target) in the sentence "I know many people who are [target]." would give a high likelihood of generating the token "vodka" in the output sentence "They love their vodka!"? For this, we first define input-output sentence pairs.

[12]:
# define input
x = [
    "I know many people who are Russian.",
    "I know many people who are Greek.",
    "I know many people who are Australian.",
    "I know many people who are American.",
    "I know many people who are Italian.",
    "I know many people who are Spanish.",
    "I know many people who are German.",
    "I know many people who are Indian."
]
[13]:
# define output
y = [
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!",
    "They love their vodka!"
]
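
Teacher forcing scores each output sentence against its paired input, so x and y must line up element-wise. A quick (illustrative) preview of the pairs:

# Print the input -> output pairs that will be scored
for inp, out in zip(x, y):
    print(inp, "->", out)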

We wrap the model with a teacher-forcing scoring class and create a Text masker. Setting collapse_mask_token=True replaces each run of adjacent masked tokens with a single mask token ("...") rather than one mask token per hidden token.

[14]:
teacher_forcing_model = shap.models.TeacherForcing(model, tokenizer)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)

Create an explainer…

[15]:
explainer = shap.Explainer(teacher_forcing_model, masker)
explainers.Partition is still in an alpha state, so use with caution...

Generate SHAP explanation values!

[16]:
shap_values = explainer(x, y)
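
The result holds one attribution matrix per instance, with a row per input token and a column per output token (a sketch, assuming the standard shap.Explanation interface):

# Shape: (num instances, num input tokens, num output tokens)
print(shap_values.shape)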

Now that we have generated the SHAP values, we can look at how the input tokens drive the token "vodka" in the output sentence using the text plot. Note: red indicates a positive contribution, blue indicates a negative contribution, and the color's intensity shows the strength of the effect in the respective direction.

[17]:
shap.plots.text(shap_values)

[SHAP text plots: interactive heatmaps of input/output token attributions, instances 0-7]

Input texts: "I know many people who are {Russian / Greek / Australian / American / Italian / Spanish / German / Indian}."
Output text (all instances): They love their vodka!

To see which input tokens impact (positively or negatively) the likelihood of generating the word "vodka", we plot the global token importances for the word "vodka".

Voila! Russians love their vodka, don't they? :)

[18]:
shap.plots.bar(shap_values[...,"vodka"])
[bar plot: global input token importances for the output token "vodka"]