Explaining a Question Answering Transformers Model
Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question.
[1]:
import numpy as np
import torch
import transformers

import shap

# load the model
pmodel = transformers.pipeline("question-answering")

tokenized_qs = None  # variable to store the tokenized data


# define two prediction functions, one that outputs the logits for the range start,
# and the other for the range end
def f(questions, tokenized_qs, start):
    outs = []
    for q in questions:
        # position of the first [SEP] token, which separates the question from the
        # context (this code assumes that there is only one sentence in data)
        idx = np.argwhere(np.array(tokenized_qs["input_ids"]) == pmodel.tokenizer.sep_token_id)[0, 0]
        d = tokenized_qs.copy()
        d["input_ids"][:idx] = q[:idx]  # swap in the (possibly masked) question tokens
        d["input_ids"][idx + 1 :] = q[idx + 1 :]  # swap in the context tokens, keeping [SEP] in place
        out = pmodel.model.forward(**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})
        logits = out.start_logits if start else out.end_logits
        outs.append(logits.reshape(-1).detach().numpy())
    return outs


def tokenize_data(data):
    for q in data:
        question, context = q.split("[SEP]")
        tokenized_data = pmodel.tokenizer(question, context)
    return tokenized_data  # this code assumes that there is only one sentence in data


def f_start(questions):
    return f(questions, tokenized_qs, True)


def f_end(questions):
    return f(questions, tokenized_qs, False)


# attach a dynamic output_names property to the models so we can plot the tokens at each output position
def out_names(inputs):
    question, context = inputs.split("[SEP]")
    d = pmodel.tokenizer(question, context)
    return [pmodel.tokenizer.decode([id]) for id in d["input_ids"]]


f_start.output_names = out_names
f_end.output_names = out_names
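Before explaining anything, it can help to sanity-check the pipeline on the example we will explain below. This is a minimal check, not part of the original notebook, and the exact score and span depend on the pipeline's default model:

result = pmodel(
    question="What is on the table?",
    context="When I got home today I saw my cat on the table, and my frog on the floor.",
)
print(result)  # something like {'score': ..., 'start': ..., 'end': ..., 'answer': 'my cat'}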
Explain the starting positions
Here we explain the starting-position predictions of the model. Note that because the number of model outputs depends on the number of input tokens, it is important that we pass the model's native tokenizer for masking: when we hide portions of the text we then retain the same number of tokens, and hence the same meaning, for each output position.
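To make this concrete, here is a small illustrative check (an assumption of this write-up, not part of the original notebook): replacing a word with the tokenizer's own mask token leaves the token count, and therefore the alignment of the output positions, unchanged.

s = "What is on the table?"
ids = pmodel.tokenizer(s)["input_ids"]
# swap one word for the tokenizer's mask token; assuming "table" is a single token
# in the model's vocabulary, the tokenized length should be identical
masked_ids = pmodel.tokenizer(s.replace("table", pmodel.tokenizer.mask_token))["input_ids"]
print(len(ids), len(masked_ids))  # expect the two lengths to match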
[2]:
data = [
    "What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor.",
]  # this code assumes that there is only one sentence in data
tokenized_qs = tokenize_data(data)

explainer_start = shap.Explainer(f_start, shap.maskers.Text(tokenizer=pmodel.tokenizer, output_type="ids"))
shap_values_start = explainer_start(data)

shap.plots.text(shap_values_start)
Partition explainer: 2it [00:32, 32.86s/it]
[SHAP text plot: one panel per output position. The output tokens are [CLS] What is on the table ? [SEP] When I got home today I saw my cat on the table , and my frog on the floor . [SEP]; selecting an output token shows every input token colored by its SHAP value for that position's start logit.]
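The plot above is interactive, but the same information is available programmatically on the returned Explanation object. As a sketch, assuming the Explanation exposes .values, .data, and .output_names as in recent shap releases, the strongest attributions for the output position of the token "cat" could be listed like this:

expl = shap_values_start[0]  # explanation for the first (and only) input
cat_pos = list(expl.output_names).index("cat")  # assumes "cat" decodes to a single token
ranked = np.argsort(-np.abs(expl.values[:, cat_pos]))  # order input tokens by attribution magnitude
for i in ranked[:5]:
    # expl.data holds the input tokens (or token ids, depending on the masker's output_type)
    print(expl.data[i], expl.values[i, cat_pos])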