Explaining a Question Answering Transformers Model

Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question.

[1]:
import torch
import transformers

import shap

# load the model
pmodel = transformers.pipeline("question-answering")


# define two prediction functions, one that outputs the logits for the range
# start, and the other for the range end. SHAP passes plain strings, so we pack
# the question and context into a single string joined by "[SEP]" and split
# them apart again here.
def f(questions, start):
    outs = []
    for q in questions:
        question, context = q.split("[SEP]")
        d = pmodel.tokenizer(question, context)
        # run the model on a batch of size one and keep the requested logits
        out = pmodel.model.forward(**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})
        logits = out.start_logits if start else out.end_logits
        outs.append(logits.reshape(-1).detach().numpy())
    return outs


def f_start(questions):
    return f(questions, True)


def f_end(questions):
    return f(questions, False)


# attach a dynamic output_names property to the prediction functions so we can
# plot the tokens at each output position
def out_names(inputs):
    question, context = inputs.split("[SEP]")
    d = pmodel.tokenizer(question, context)
    return [pmodel.tokenizer.decode([token_id]) for token_id in d["input_ids"]]


f_start.output_names = out_names
f_end.output_names = out_names
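
As a quick sanity check, we can verify that the prediction function and the output-names function agree on the number of output positions for an example input:

example = "What is on the table?[SEP]When I got home today I saw my cat on the table."
assert len(f_start([example])[0]) == len(out_names(example))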

Explain the starting positions

Here we explain the starting range predictions of the model. Note that because the model output depends on the length of the model input, it is important that we pass the model’s native tokenizer for masking, so that when we hide portions of the text we retain the same number of tokens and hence the same meaning for each output position.
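
Passing the raw tokenizer lets shap build a text masker for us. A minimal, equivalent sketch that constructs the masker explicitly (a shap.maskers.Text wrapped around the same tokenizer) looks like this:

masker = shap.maskers.Text(pmodel.tokenizer)
explainer_start = shap.Explainer(f_start, masker)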

[2]:
data = [
    "What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor."
]

explainer_start = shap.Explainer(f_start, pmodel.tokenizer)
shap_values_start = explainer_start(data)

shap.plots.text(shap_values_start)


[shap.plots.text output: an interactive plot of per-token SHAP values for each start-logit output position]

Explain the ending positions

This is the same process as above, but now we explain the end tokens.

[3]:
explainer_end = shap.Explainer(f_end, pmodel.tokenizer)
shap_values_end = explainer_end(data)

shap.plots.text(shap_values_end)


[shap.plots.text output: an interactive plot of per-token SHAP values for each end-logit output position]
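
As an illustrative aside, the two logit vectors can be combined to recover the span the model predicts. Below is a simplified sketch (the pipeline’s real decoding additionally restricts candidates to valid spans inside the context):

import numpy as np

start_logits = f_start(data)[0]
end_logits = f_end(data)[0]
tokens = out_names(data[0])
# take the most likely start and end positions and decode the tokens between them
start_idx = int(np.argmax(start_logits))
end_idx = int(np.argmax(end_logits))
print(" ".join(tokens[start_idx : end_idx + 1]))  # e.g. "my cat"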

Explain a matching function

In the example above we directly explained the output logits coming from the model. This required us to ensure that we only perturbed the input in length-preserving ways, so as not to change the meaning of the output logits. A less detailed but more flexible approach is to simply score whether specific answers are produced by the model.

[4]:
def make_answer_scorer(answers):
    def f(questions):
        out = []
        for q in questions:
            question, context = q.split("[SEP]")
            results = pmodel(question, context, top_k=20)  # top_k supersedes the deprecated topk argument
            values = []
            for answer in answers:
                value = 0
                for result in results:
                    if result["answer"] == answer:
                        value = result["score"]
                        break
                values.append(value)
            out.append(values)
        return out

    f.output_names = answers
    return f


f_answers = make_answer_scorer(["my cat", "cat", "my frog"])
explainer_answers = shap.Explainer(f_answers, pmodel.tokenizer)
shap_values_answers = explainer_answers(data)

shap.plots.text(shap_values_answers)


[shap.plots.text output: an interactive plot of per-token SHAP values for the outputs "my cat", "cat", and "my frog"]
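
Since f_answers is an ordinary Python function, we can also call it directly to inspect the scores it assigns to each candidate answer (the exact values depend on the pipeline’s default model):

print(f_answers(data))  # one probability-like score per candidate answer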

Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!