Explaining a Question Answering Transformers Model

Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question.

[1]:
import numpy as np
import torch
import transformers

import shap

# load the model
pmodel = transformers.pipeline("question-answering")
tokenized_qs = None  # variable to store the tokenized data


# define two predictions, one that outputs the logits for the range start,
# and the other for the range end
def f(questions, tokenized_qs, start):
    outs = []
    for q in questions:
        # locate the first [SEP] token in the original encoding
        # (this code assumes a single question/context pair in the data)
        idx = np.argwhere(np.array(tokenized_qs["input_ids"]) == pmodel.tokenizer.sep_token_id)[0, 0]
        # swap the (possibly masked) token ids produced by the explainer into the
        # original encoding, keeping the first [SEP] fixed so the question/context
        # split is preserved
        d = tokenized_qs.copy()
        d["input_ids"][:idx] = q[:idx]
        d["input_ids"][idx + 1 :] = q[idx + 1 :]
        out = pmodel.model.forward(**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})
        logits = out.start_logits if start else out.end_logits
        outs.append(logits.reshape(-1).detach().numpy())
    return outs


def tokenize_data(data):
    # tokenize each question/context pair; since this example uses a single pair,
    # we simply return the last (and only) encoding
    for q in data:
        question, context = q.split("[SEP]")
        tokenized_data = pmodel.tokenizer(question, context)
    return tokenized_data


def f_start(questions):
    return f(questions, tokenized_qs, True)


def f_end(questions):
    return f(questions, tokenized_qs, False)


# attach a dynamic output_names property to the prediction functions so we can plot the tokens at each output position
def out_names(inputs):
    question, context = inputs.split("[SEP]")
    d = pmodel.tokenizer(question, context)
    return [pmodel.tokenizer.decode([id]) for id in d["input_ids"]]


f_start.output_names = out_names
f_end.output_names = out_names
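
As a quick sanity check (a minimal sketch, not part of the original notebook), out_names simply decodes each token id of the combined question/context encoding, so the explainer gets one human-readable name per output position:

sample = "What is on the table?[SEP]My cat is on the table."
# expected something like ['[CLS]', 'What', 'is', 'on', 'the', 'table', '?', '[SEP]', 'My', 'cat', ...]
print(out_names(sample))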

Explain the starting positions

Here we explain the starting range predictions of the model. Note that because the model output depends on the length of the model input, it is important that we pass the model’s native tokenizer for masking, so that when we hide portions of the text we retain the same number of tokens and hence the same meaning for each output position.
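
To see why this matters, here is a small illustration (a sketch using the default pipeline checkpoint, not part of the original notebook): replacing a word with the tokenizer’s mask token keeps the token count, and therefore the meaning of each output position, unchanged:

full = pmodel.tokenizer("What is on the table?", "My cat is on the table.")["input_ids"]
masked = pmodel.tokenizer("What is on the table?", f"My cat is on the {pmodel.tokenizer.mask_token}.")["input_ids"]
# the two lengths should match (assuming "table" is a single token in the model's vocabulary)
print(len(full), len(masked))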

[2]:
data = [
    "What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor.",
]  # this example assumes a single question/context pair in data
tokenized_qs = tokenize_data(data)

explainer_start = shap.Explainer(f_start, shap.maskers.Text(tokenizer=pmodel.tokenizer, output_type="ids"))
shap_values_start = explainer_start(data)

shap.plots.text(shap_values_start)
Partition explainer: 2it [00:32, 32.86s/it]


[Interactive shap.plots.text output: one column of SHAP values per output position, with the decoded tokens ([CLS], What, is, ..., [SEP]) as output names. For the [CLS] position, the base value is -1.26347 and f[CLS](inputs) is -3.97193; "What" (+0.498) and "on" (+0.214) contribute most strongly toward that logit, while "?" (-0.937) and "is" (-0.459) contribute most strongly against it.]
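
Instead of (or in addition to) the interactive plot, the per-token attributions can be inspected numerically on the returned Explanation object (a minimal sketch; exact values depend on the pipeline checkpoint):

# one row per input token, one column per output position
print(shap_values_start[0].values.shape)
# decoded token names for the output positions, as produced by out_names
print(shap_values_start[0].output_names[:8])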