Explaining a Question Answering Transformers Model

Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question.
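A model of this kind emits two scores per token: a start logit and an end logit. A common way to turn them into an answer is to pick the span (i, j) with i <= j that maximizes start[i] + end[j]. This is equivalent to maximizing the product of the two softmax probabilities. The sketch below shows that rule on plain Python lists. It is a simplified assumption, not the transformers pipeline's exact decoder, which also restricts candidates to context tokens. `best_span` and `max_answer_len` are hypothetical names, and the logits are made up for illustration.

```python
from typing import List, Tuple

def best_span(start_logits: List[float], end_logits: List[float],
              max_answer_len: int = 15) -> Tuple[int, int]:
    """Hypothetical helper: return inclusive (start, end) token indices
    maximizing start_logits[start] + end_logits[end], with start <= end
    and span length <= max_answer_len. Does not mask question tokens."""
    best, best_score = (0, 0), float("-inf")
    for i, s in enumerate(start_logits):
        # only consider end positions at or after the start, within the length limit
        for j in range(i, min(i + max_answer_len, len(end_logits))):
            score = s + end_logits[j]
            if score > best_score:
                best, best_score = (i, j), score
    return best


tokens = ["[CLS]", "what", "is", "here", "?", "[SEP]", "my", "cat", "[SEP]"]
# made-up logits, for illustration only
start = [0.1, -2.0, -2.0, -2.0, -3.0, -5.0, 3.0, 1.0, -5.0]
end = [0.1, -2.0, -2.0, -2.0, -3.0, -5.0, 0.5, 4.0, -5.0]
i, j = best_span(start, end)
print(tokens[i:j + 1])  # ['my', 'cat']
```

Scoring the pair jointly, rather than taking the two argmaxes independently, guarantees the end never precedes the start.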

[1]:
import transformers
import shap
import torch

# load the default question answering pipeline (model + tokenizer)
pmodel = transformers.pipeline('question-answering')

# define two prediction functions: one returns the logits for the start of the
# answer range, the other the logits for its end (one logit per token position)
def f(questions, start):
    outs = []
    for q in questions:
        # each input string packs the question and the context, separated by "[SEP]"
        question, context = q.split("[SEP]")
        d = pmodel.tokenizer(question, context, return_tensors="pt")
        with torch.no_grad():  # inference only, no gradients needed
            out = pmodel.model(**d)
        logits = out.start_logits if start else out.end_logits
        outs.append(logits.reshape(-1).numpy())
    return outs
def f_start(questions):
    return f(questions, True)
def f_end(questions):
    return f(questions, False)

# attach a dynamic output_names attribute to both prediction functions so the plot
# can label each output position with its token
def out_names(inputs):
    question, context = inputs.split("[SEP]")
    d = pmodel.tokenizer(question, context)
    return [pmodel.tokenizer.decode([token_id]) for token_id in d["input_ids"]]
f_start.output_names = out_names
f_end.output_names = out_names
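The `output_names` function attached in the cell above must return exactly one label per output position for a given input. This is what lets the text plot name each start or end logit by its token. The toy sketch below checks that contract, with a whitespace tokenizer standing in for the real one. `toy_tokenize`, `toy_model`, and `toy_out_names` are hypothetical stand-ins, not part of SHAP or transformers.

```python
from typing import List

def toy_tokenize(question: str, context: str) -> List[str]:
    # stand-in for pmodel.tokenizer: whitespace split plus BERT-style special tokens
    return ["[CLS]"] + question.split() + ["[SEP]"] + context.split() + ["[SEP]"]

def toy_model(inputs: List[str]) -> List[List[float]]:
    # one score per token position, so the output length varies with the input
    outs = []
    for s in inputs:
        question, context = s.split("[SEP]")
        outs.append([float(k) for k in range(len(toy_tokenize(question, context)))])
    return outs

def toy_out_names(s: str) -> List[str]:
    # mirrors out_names: one label per output position
    question, context = s.split("[SEP]")
    return toy_tokenize(question, context)

# attach the names as a function attribute, as done for f_start and f_end
toy_model.output_names = toy_out_names

x = "What is here ?[SEP]my cat"
print(toy_model.output_names(x))
# ['[CLS]', 'What', 'is', 'here', '?', '[SEP]', 'my', 'cat', '[SEP]']
assert len(toy_model([x])[0]) == len(toy_model.output_names(x))
```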

Explain the starting positions

Here we explain the model's predictions for the start of the answer range. Because the model produces one output per input token, the length of its output depends on the length of its input. It is therefore important to pass the model's native tokenizer as the masker. That way, when we hide portions of the text, we keep the same number of tokens, and hence the same meaning for each output position.
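A toy example shows why the token count matters. If masking deletes tokens, every later output position shifts and no longer refers to the same word. If masked tokens are instead replaced with a placeholder, each position stays aligned. The sketch below illustrates that principle, assuming a `[MASK]` placeholder. It is not SHAP's `Text` masker implementation, and `mask_by_replacing` and `mask_by_deleting` are hypothetical helpers.

```python
from typing import List

MASK = "[MASK]"  # assumed placeholder token

def mask_by_replacing(tokens: List[str], keep: List[bool]) -> List[str]:
    # hidden tokens become a placeholder, so the sequence length is unchanged
    return [t if k else MASK for t, k in zip(tokens, keep)]

def mask_by_deleting(tokens: List[str], keep: List[bool]) -> List[str]:
    # hidden tokens are dropped, so later tokens shift to earlier positions
    return [t for t, k in zip(tokens, keep) if k]

tokens = ["I", "saw", "my", "cat", "on", "the", "table"]
keep = [True, False, False, True, True, True, True]  # hide "saw my"

replaced = mask_by_replacing(tokens, keep)
deleted = mask_by_deleting(tokens, keep)
# position 6 still holds "table" after replacing, but not after deleting
print(replaced[6], len(replaced))  # table 7
print(len(deleted), deleted.index("table"))  # 5 4
```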

[2]:
data = ["What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor."]

explainer_start = shap.Explainer(f_start, pmodel.tokenizer)
shap_values_start = explainer_start(data)

shap.plots.text(shap_values_start)


[Text plot of shap_values_start, showing output position 0 ([CLS]): base value -1.26347, f[CLS](inputs) = -3.97193. Every input token, from [CLS] through the final [SEP], is listed as a selectable output position. The largest attributions are "?" (-0.937), "What" (+0.486), the question's "table" (-0.374), and "on the floor." (-0.341).]
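SHAP values are additive. For the displayed output position, the base value plus the sum of all token attributions equals the model output. The check below uses the rounded values reported in the text plot for the [CLS] output (base value -1.26347, f[CLS](inputs) = -3.97193). Grouped tokens such as "cat on" carry a single combined value. With live results, the same check could be written against the `values` and `base_values` attributes of the returned `shap.Explanation`, though the indexing needed for that is not shown here.

```python
import math

base_value = -1.26347
model_output = -3.97193  # f[CLS](inputs) as reported in the plot

# (token or token group, SHAP value) for the [CLS] output, rounded to 3 decimals
contributions = [
    ("", 0.012),  # first position, shown without a label in the plot
    ("What", 0.486), ("is", -0.278), ("on", 0.034), ("the", 0.022),
    ("table", -0.374), ("?", -0.937), ("[SEP]", 0.0),
    ("When I", -0.041), ("got home", -0.009), ("today I", -0.074),
    ("saw my", -0.128), ("cat on", -0.324), ("the table", -0.268),
    (",", -0.126), ("and", -0.137), ("my frog", -0.226),
    ("on the floor.", -0.341), ("", 0.0),  # final position
]

reconstructed = base_value + sum(v for _, v in contributions)
print(round(reconstructed, 5))  # -3.97247
# the small gap to model_output comes only from rounding each value to 3 decimals
assert math.isclose(reconstructed, model_output, abs_tol=0.005)
```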