Explaining a Question Answering Transformers Model
Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question.
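What "predicts which range" means in practice: an extractive QA model emits one start logit and one end logit per token. An answer span is read off by choosing the start and end positions with the highest combined score. The sketch below uses a simplified, assumed decoding rule (pure NumPy, no softmax, no restriction to context tokens). It is not the exact post-processing the transformers pipeline performs. `best_span` and its `max_answer_len` cap are hypothetical names, and the logits are made-up toy values.

```python
import numpy as np

def best_span(start_logits, end_logits, max_answer_len=15):
    """Hypothetical helper: return (start, end) maximizing
    start_logits[start] + end_logits[end] with start <= end < start + max_answer_len."""
    start_logits = np.asarray(start_logits, dtype=float)
    end_logits = np.asarray(end_logits, dtype=float)
    n = len(start_logits)
    # scores[i, j] = start score of token i plus end score of token j
    scores = start_logits[:, None] + end_logits[None, :]
    # keep only spans that end at or after their start and are not too long
    i, j = np.indices((n, n))
    valid = (j >= i) & (j - i < max_answer_len)
    scores = np.where(valid, scores, -np.inf)
    start, end = np.unravel_index(np.argmax(scores), scores.shape)
    return int(start), int(end)

# toy tokens and made-up logits, for illustration only
tokens = ["[CLS]", "what", "is", "on", "the", "table", "?", "[SEP]", "my", "cat", "[SEP]"]
start_logits = [0, 0, 0, 0, 0, 0, 0, 0, 1.0, 3.0, 0]
end_logits = [0, 0, 0, 0, 0, 0, 0, 0, 0.5, 2.5, 0]
s, e = best_span(start_logits, end_logits)
print(tokens[s:e + 1])  # ['cat']
```

Scoring every (start, end) pair at once keeps the sketch short. Masking invalid pairs with `-inf` ensures a span can never end before it starts.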
import transformers
import shap
import torch

# load the default question-answering pipeline (downloads a pretrained model)
pmodel = transformers.pipeline('question-answering')

# define two prediction functions: one outputs the logits for the start of the
# answer range, the other outputs the logits for the end of the range
def f(questions, start):
    outs = []
    for q in questions:
        # each input is a single "question[SEP]context" string
        question, context = q.split("[SEP]")
        d = pmodel.tokenizer(question, context)
        # run the model on a batch of one, without tracking gradients
        with torch.no_grad():
            out = pmodel.model(**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})
        logits = out.start_logits if start else out.end_logits
        outs.append(logits.reshape(-1).numpy())
    return outs

def f_start(questions):
    return f(questions, True)

def f_end(questions):
    return f(questions, False)

# attach a dynamic output_names property to the prediction functions so we can
# plot the token at each output position
def out_names(inputs):
    question, context = inputs.split("[SEP]")
    d = pmodel.tokenizer(question, context)
    return [pmodel.tokenizer.decode([tok_id]) for tok_id in d["input_ids"]]

f_start.output_names = out_names
f_end.output_names = out_names
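The prediction-function pattern above can be exercised without downloading a model. This minimal sketch swaps in hypothetical stand-ins: `toy_tokenize` replaces the tokenizer and `toy_start_logits` replaces the model. It keeps the same contract: one output score per token, plus an `output_names` attribute that returns one label per output position. It splits on the first "[SEP]" only (`maxsplit=1`). That is a hardening choice made in this sketch, not something the original code does.

```python
import numpy as np

def toy_tokenize(question, context):
    # Hypothetical whitespace tokenizer standing in for pmodel.tokenizer;
    # it mimics the BERT-style "[CLS] question [SEP] context [SEP]" layout.
    return ["[CLS]"] + question.split() + ["[SEP]"] + context.split() + ["[SEP]"]

def toy_start_logits(tokens):
    # Hypothetical stand-in for the model: one score per token, highest for "cat".
    return np.array([2.0 if t == "cat" else 0.0 for t in tokens])

def toy_f_start(inputs):
    # Same shape of contract as f_start: list of strings in, list of score vectors out.
    outs = []
    for s in inputs:
        question, context = s.split("[SEP]", 1)
        outs.append(toy_start_logits(toy_tokenize(question, context)))
    return outs

def toy_out_names(s):
    # One label per output position, aligned with the score vector.
    question, context = s.split("[SEP]", 1)
    return toy_tokenize(question, context)

toy_f_start.output_names = toy_out_names

x = "What is on the table ?[SEP]I saw my cat on the table ."
logits = toy_f_start([x])[0]
names = toy_f_start.output_names(x)
print(names[int(np.argmax(logits))])  # cat
```

Because the score vector and the label list come from the same tokenization, position `i` in one always matches position `i` in the other. The plot relies on that alignment.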
Explain the starting positions
Here we explain the model's predictions for the start of the answer range. Because the model's output depends on the length of its input, it is important that we pass the model’s native tokenizer for masking. That way, when we hide portions of the text, the input keeps the same number of tokens, and each output position keeps the same meaning.
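The following example shows why token count matters when masking. If hidden tokens are replaced by a placeholder, every output position still refers to the same input token. If they are deleted, later positions shift. The `[MASK]` placeholder and both helper functions are hypothetical and only illustrate the idea. They are not how SHAP's text masker is implemented.

```python
def mask_by_replacement(tokens, keep, mask_token="[MASK]"):
    # Hide tokens by swapping them for a placeholder: length is preserved,
    # so output position i still refers to the i-th input token.
    return [t if k else mask_token for t, k in zip(tokens, keep)]

def mask_by_deletion(tokens, keep):
    # Hide tokens by dropping them: later tokens shift left,
    # so output positions no longer line up with the original tokens.
    return [t for t, k in zip(tokens, keep) if k]

tokens = ["[CLS]", "what", "is", "on", "the", "table", "?", "[SEP]",
          "i", "saw", "my", "cat", "[SEP]"]
keep = [True] * len(tokens)
keep[2] = False  # hide "is"

replaced = mask_by_replacement(tokens, keep)
deleted = mask_by_deletion(tokens, keep)
print(len(replaced), replaced.index("cat"))  # 13 11
print(len(deleted), deleted.index("cat"))    # 12 10
```

After deletion, the start logit at position 11 would describe a different token than it did on the unmasked input. Comparing outputs position by position would then be meaningless.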
data = ["What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor."]
explainer_start = shap.Explainer(f_start, pmodel.tokenizer)
shap_values_start = explainer_start(data)
shap.plots.text(shap_values_start)
[Interactive shap.plots.text output for example [0]. The output positions are the tokens [CLS] What is on the table ? [SEP] When I got home today I saw my cat on the table , and my frog on the floor . [SEP]. Selecting a position shows the SHAP value of each input token, or group of tokens, for that position's start logit.]
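Each value in the text plot is a SHAP value: the contribution of an input token, or group of tokens, to the selected output (here, a start logit). SHAP values are additive. For each output, they sum to the model's output on the full input minus the base value, which is the output with the input tokens masked. The sketch below computes exact Shapley values by brute-force enumeration for a hypothetical three-token toy function (`toy_value`) to show this additivity property. SHAP itself uses faster approximations for text models rather than this enumeration.

```python
from itertools import combinations
from math import factorial

def shapley_values(value, n):
    """Exact Shapley values for a set function `value` over players 0..n-1.

    `value` maps a frozenset of kept (unmasked) players to a number.
    Every coalition is enumerated, so this is only practical for small n.
    """
    phi = [0.0] * n
    for i in range(n):
        others = [p for p in range(n) if p != i]
        for size in range(n):
            # probability weight of a coalition of this size in the Shapley formula
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi

# Hypothetical toy "start logit" for the token "cat": it rises only when the
# question word "table" (player 0) and the context word "cat" (player 1) are
# both visible, and player 2 ("floor") lowers it slightly.
def toy_value(kept):
    score = 0.0
    if 0 in kept and 1 in kept:
        score += 2.0
    if 2 in kept:
        score -= 0.5
    return score

phi = shapley_values(toy_value, 3)
print([round(v, 6) for v in phi])  # [1.0, 1.0, -0.5]
```

The interaction bonus of 2.0 is split evenly between the two tokens that jointly produce it. The three values sum to 1.5, the output with all tokens kept minus the output with all tokens masked.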