# Explaining a Question Answering Transformers Model

Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question.
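For context, here is a sketch of how per-token start and end logits are commonly turned into an answer range. It picks the pair (start, end) with start ≤ end that maximizes the summed logits. This is a simplified illustration under stated assumptions: `best_span` is a hypothetical helper. The `transformers` question-answering pipeline's own post-processing also does things like restricting candidates to context tokens and normalizing scores.

```python
import numpy as np

def best_span(start_logits, end_logits, max_answer_len=15):
    """Hypothetical helper: return (start, end) token indices with start <= end
    and span length <= max_answer_len that maximize start_logit + end_logit."""
    start_logits = np.asarray(start_logits, dtype=float)
    end_logits = np.asarray(end_logits, dtype=float)
    n = len(start_logits)
    best, best_score = (0, 0), -np.inf
    for i in range(n):
        # only consider ends at or after the start, within the length limit
        for j in range(i, min(n, i + max_answer_len)):
            score = start_logits[i] + end_logits[j]
            if score > best_score:
                best, best_score = (i, j), score
    return best

# made-up logits for a tiny hand-tokenized example
tokens = ["[CLS]", "where", "?", "[SEP]", "the", "cat", "sat", "[SEP]"]
start = [0.1, 0.0, 0.0, 0.0, 2.0, 0.5, 0.0, 0.0]
end = [0.1, 0.0, 0.0, 0.0, 0.2, 3.0, 0.1, 0.0]
i, j = best_span(start, end)
print(tokens[i:j + 1])  # ['the', 'cat']
```

This is why the model is explained twice below, once for the start logits and once for the end logits: each is a separate per-token output.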

[1]:

import transformers
import shap
import torch

# load a question answering pipeline; its .model and .tokenizer are used below
pmodel = transformers.pipeline("question-answering")

# define two prediction functions, one that outputs the logits for the range start,
# and the other for the range end
def f(questions, start):
    outs = []
    for q in questions:
        # each input string packs the question and the context as "question[SEP]context"
        question, context = q.split("[SEP]")
        d = pmodel.tokenizer(question, context)
        # run the model on a batch containing this single sequence
        out = pmodel.model.forward(**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})
        logits = out.start_logits if start else out.end_logits
        # one logit per input token
        outs.append(logits.reshape(-1).detach().numpy())
    return outs

def f_start(questions):
    return f(questions, True)

def f_end(questions):
    return f(questions, False)

# attach a dynamic output_names property to the models so we can plot the tokens at each output position
def out_names(inputs):
    question, context = inputs.split("[SEP]")
    d = pmodel.tokenizer(question, context)
    return [pmodel.tokenizer.decode([token_id]) for token_id in d["input_ids"]]

f_start.output_names = out_names
f_end.output_names = out_names
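The functions above follow a simple contract: they take a list of `question[SEP]context` strings, return one logit vector per string with one entry per token, and carry an `output_names` callable that labels each of those positions. The toy stand-in below checks that contract without loading a model. Everything in it is hypothetical: a whitespace split plays the role of the tokenizer, and the "logits" are made up (1.0 for the token `cat`, 0.0 elsewhere).

```python
import numpy as np

def toy_tokenize(question, context):
    # hypothetical stand-in for pmodel.tokenizer: whitespace split plus special tokens
    return ["[CLS]"] + question.split() + ["[SEP]"] + context.split() + ["[SEP]"]

def toy_f_start(questions):
    outs = []
    for q in questions:
        question, context = q.split("[SEP]")
        tokens = toy_tokenize(question, context)
        # made-up "start logits": one score per token, as the real model returns
        outs.append(np.array([1.0 if t == "cat" else 0.0 for t in tokens]))
    return outs

def toy_out_names(inputs):
    question, context = inputs.split("[SEP]")
    return toy_tokenize(question, context)

toy_f_start.output_names = toy_out_names

x = "Where is it?[SEP]the cat sat"
logits = toy_f_start([x])[0]
names = toy_f_start.output_names(x)
print(len(logits), len(names))  # 9 9
print(names[int(np.argmax(logits))])  # cat
```

Because the number of outputs tracks the number of tokens, the output names have to be computed per input rather than fixed once.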


## Explain the starting positions

Here we explain the starting range predictions of the model. Because the model produces one logit per input token, the length of its output depends on the length of its input. It is therefore important that we pass the model's native tokenizer for masking: when we hide portions of the text we then retain the same number of tokens, and hence the same meaning for each output position.
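A toy illustration of why the token count matters. It assumes a hand-tokenized sequence and a hypothetical `[MASK]` replacement string, and it is not shap's masker implementation. Replacing hidden tokens keeps every surviving token at its original position. Deleting them shifts later tokens, so a given output position would no longer refer to the same token.

```python
def mask_tokens(tokens, keep, mask_token="[MASK]"):
    """Replace hidden tokens with mask_token; the sequence length is unchanged."""
    return [t if k else mask_token for t, k in zip(tokens, keep)]

def drop_tokens(tokens, keep):
    """Delete hidden tokens; the sequence gets shorter and later positions shift."""
    return [t for t, k in zip(tokens, keep) if k]

tokens = ["[CLS]", "what", "is", "on", "the", "table", "?", "[SEP]", "my", "cat", "[SEP]"]
# hide "is" and "my"
keep = [True, True, False, True, True, True, True, True, False, True, True]
masked = mask_tokens(tokens, keep)
dropped = drop_tokens(tokens, keep)
print(len(tokens), len(masked), len(dropped))  # 11 11 9
print(masked.index("cat"), dropped.index("cat"))  # 9 7
```

With masking, output position 9 still refers to `cat`; after deletion it would point at a different token.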

[2]:

data = ["What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor."]

explainer_start = shap.Explainer(f_start, pmodel.tokenizer)
shap_values_start = explainer_start(data)

shap.plots.text(shap_values_start)


[Output: interactive `shap.plots.text` plot. Each output position corresponds to one token of the tokenized input (`[CLS] What is on the table ? [SEP] When I got home today I saw my cat on the table , and my frog on the floor . [SEP]`). The view shown is for output position 0, with each input token, or merged group of tokens, shaded by its SHAP value.]
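The values in the plot are SHAP values. For each output position, they split that position's logit among the input tokens so that the base value plus the sum of the token values equals the model output. For a handful of "players" they can be computed exactly by enumerating every coalition. shap's explainer for text estimates them instead, since enumeration grows exponentially. The sketch below uses a hypothetical toy value function with made-up scores standing in for one start logit. It includes an interaction between tokens 1 and 2, which the Shapley values split evenly.

```python
from itertools import combinations
from math import factorial

def exact_shapley(value_fn, n):
    """Exact Shapley values for n players; value_fn maps a frozenset of kept indices to a number."""
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions of this size that exclude player i
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value_fn(s | {i}) - value_fn(s))
    return phi

def toy_logit(kept):
    # hypothetical scores: tokens 0 and 1 add directly, tokens 1 and 2 interact
    score = 0.5 * (0 in kept) + 1.0 * (1 in kept)
    if 1 in kept and 2 in kept:
        score += 2.0
    return score

phi = exact_shapley(toy_logit, 3)
base = toy_logit(frozenset())
full = toy_logit(frozenset({0, 1, 2}))
print([round(p, 3) for p in phi])  # [0.5, 2.0, 1.0]
print(abs(base + sum(phi) - full) < 1e-9)  # True
```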