Explaining a Question Answering Transformers Model
Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question.
[1]:
import torch
import transformers
import shap
# load the model
pmodel = transformers.pipeline("question-answering")
# define two predictions, one that outputs the logits for the range start,
# and the other for the range end
def f(questions, start):
    outs = []
    for q in questions:
        question, context = q.split("[SEP]")
        d = pmodel.tokenizer(question, context)
        out = pmodel.model.forward(**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})
        logits = out.start_logits if start else out.end_logits
        outs.append(logits.reshape(-1).detach().numpy())
    return outs
def f_start(questions):
    return f(questions, True)

def f_end(questions):
    return f(questions, False)
# attach a dynamic output_names property to the models so we can plot the tokens at each output position
def out_names(inputs):
    question, context = inputs.split("[SEP]")
    d = pmodel.tokenizer(question, context)
    return [pmodel.tokenizer.decode([id]) for id in d["input_ids"]]

f_start.output_names = out_names
f_end.output_names = out_names
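The wrapper functions above expect each input packed into a single string, with the question and context joined by the literal marker "[SEP]"; SHAP perturbs that combined string, and `f` and `out_names` split it back apart. A minimal sketch of that round trip (pure Python, no model required):

```python
# The question and context travel through SHAP as one string joined by
# the literal marker "[SEP]", then get split back into their two halves.
question = "What is on the table?"
context = "When I got home today I saw my cat on the table."
combined = question + "[SEP]" + context

# Inside f and out_names the two halves are recovered with str.split:
q, c = combined.split("[SEP]")
assert q == question
assert c == context
```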
Explain the starting positions
Here we explain the starting range predictions of the model. Note that because the model output depends on the length of the model input, it is important that we pass the model’s native tokenizer for masking, so that when we hide portions of the text we retain the same number of tokens and hence the same meaning for each output position.
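To see why length preservation matters, here is a toy sketch (whitespace tokenization stands in for the real tokenizer, and `mask_tokens` is a hypothetical helper for illustration, not part of SHAP): hiding tokens by replacing them keeps every output position pointing at the same token slot, whereas deleting them would shift all later positions.

```python
def mask_tokens(tokens, keep):
    # Replace hidden tokens instead of deleting them, so the sequence
    # length (and hence the meaning of each output position) is unchanged.
    return [t if k else "[MASK]" for t, k in zip(tokens, keep)]

tokens = "When I got home I saw my cat".split()
masked = mask_tokens(tokens, [True, False, False, True, True, True, True, True])

assert len(masked) == len(tokens)  # output positions still line up
assert masked[7] == "cat"          # position 7 still refers to "cat"
```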
[2]:
data = [
    "What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor."
]
explainer_start = shap.Explainer(f_start, pmodel.tokenizer)
shap_values_start = explainer_start(data)
shap.plots.text(shap_values_start)
Explain the ending positions
This is the same process as above, but now we explain the end tokens.
[3]:
explainer_end = shap.Explainer(f_end, pmodel.tokenizer)
shap_values_end = explainer_end(data)
shap.plots.text(shap_values_end)
Explain a matching function
In the example above we directly explained the output logits coming from the model. This required us to ensure that we only perturbed the input in length-preserving ways, so as to not change the meaning of the output logits. A less detailed but more flexible approach is to simply score whether specific candidate answers are produced by the model.
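The scorer defined below looks each candidate answer up in the pipeline's top-k results and returns its score, or 0 if the answer does not appear. That lookup logic in isolation, with a mocked result list standing in for a real pipeline call (the real pipeline returns dicts with "answer" and "score" keys in this shape):

```python
# Mocked top-k pipeline output, for illustration only.
results = [
    {"answer": "my cat", "score": 0.90},
    {"answer": "cat", "score": 0.05},
]

def answer_scores(answers, results):
    # A candidate answer that never appears in the results scores 0.
    return [next((r["score"] for r in results if r["answer"] == answer), 0)
            for answer in answers]

assert answer_scores(["my cat", "cat", "my frog"], results) == [0.90, 0.05, 0]
```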
[4]:
def make_answer_scorer(answers):
    def f(questions):
        out = []
        for q in questions:
            question, context = q.split("[SEP]")
            results = pmodel(question, context, topk=20)
            values = []
            for answer in answers:
                value = 0
                for result in results:
                    if result["answer"] == answer:
                        value = result["score"]
                        break
                values.append(value)
            out.append(values)
        return out

    f.output_names = answers
    return f
f_answers = make_answer_scorer(["my cat", "cat", "my frog"])
explainer_answers = shap.Explainer(f_answers, pmodel.tokenizer)
shap_values_answers = explainer_answers(data)
shap.plots.text(shap_values_answers)
Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!