Text to Text Explanation: Abstractive Summarization Example
This notebook demonstrates how to generate model explanations for a text-to-text scenario using a pretrained transformer model. Below we generate explanations for the pretrained distilbart model (https://huggingface.co/sshleifer/distilbart-xsum-12-6) on the Extreme Summarization (XSum) dataset provided by Hugging Face.
The first example needs only the model and tokenizer; we use the model decoder to generate the log odds of the output tokens to be explained. In the second example, we demonstrate how to generate explanations for a model exposed as an API/function (text in -> text out). In that case we approximate the log odds using a text similarity model. The underlying explainer used to compute the SHAP values is the Partition explainer.
[1]:
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import shap
Load model and tokenizer
[2]:
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-6").cuda()
Load data
[3]:
dataset = load_dataset("xsum", split="train")
Using custom data configuration default
Reusing dataset xsum (/home/slundberg/.cache/huggingface/datasets/xsum/default/1.2.0/f9abaabb5e2b2a1e765c25417264722d31877b34ec34b437c53242f6e5c30d6d)
[4]:
# slice inputs from dataset to run model inference on
s = dataset["document"][0:1]
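As a quick sanity check (a hypothetical inspection step, not part of the original run), we can peek at the document we are about to explain:

[ ]:
# preview the first few hundred characters of the input document
print(s[0][:300])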
Create an explainer object
[5]:
explainer = shap.Explainer(model, tokenizer)
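Because the tokenizer yields a text masker, shap.Explainer dispatches to the Partition explainer for this model (as the progress bar below confirms). If you want to verify which algorithm was selected, a minimal check:

[ ]:
# confirm which explainer algorithm shap selected for this model/masker pair
print(type(explainer))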
Compute SHAP values
[6]:
shap_values = explainer(s)
floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
Partition explainer: 2it [00:19, 9.52s/it]
Visualize the SHAP explanations
[7]:
shap.plots.text(shap_values)
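shap_values is a shap.Explanation object whose attribution matrix has one row per input token and one column per generated output token. A minimal inspection sketch, assuming the standard Explanation layout for text-to-text models:

[ ]:
# inspect the pieces of the Explanation object
print(shap_values.data[0])          # tokenized input segments
print(shap_values.output_names)     # generated output tokens being explained
print(shap_values.values[0].shape)  # (num input tokens, num output tokens)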
API
Below we demonstrate generating explanations for a model exposed as an API/function. Since this is a model-agnostic case, we use a text similarity model to approximate the log odds of generating the output text, which are then used to compute the SHAP explanations.
[8]:
# define a generation function that maps input text to output summary text
def f(x):
    inputs = tokenizer(x.tolist(), return_tensors="pt", padding=True).to("cuda")
    with torch.no_grad():
        out = model.generate(**inputs)
    sentence = [tokenizer.decode(g, skip_special_tokens=True) for g in out]
    return np.array(sentence)
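Before wrapping f, it is worth running it once to confirm it returns a summary for our input (the exact text will vary with the model version):

[ ]:
# sanity-check the generation function on our input slice
print(f(np.array(s))[0])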
For the model-agnostic case, we wrap the model to be explained with the shap.models.TeacherForcing class and provide the text similarity model and tokenizer. The TeacherForcing class uses the similarity model to approximate the log odds of generating the output text from the model (the function f).

We also define a Text masker with mask_token="..." and collapse_mask_token=True, which cues the algorithm to use text infilling while masking.
[9]:
# wrap the model with the TeacherForcing class
teacher_forcing_model = shap.models.TeacherForcing(
f, similarity_model=model, similarity_tokenizer=tokenizer, device=model.device
)
# create a Text masker
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
Create an explainer object using the wrapped model and the Text masker
[10]:
explainer_model_agnostic = shap.Explainer(teacher_forcing_model, masker)
Compute SHAP values
[11]:
shap_values_model_agnostic = explainer_model_agnostic(s)
Partition explainer: 2it [00:34, 17.39s/it]
Visualize the SHAP explanations
[12]:
shap.plots.text(shap_values_model_agnostic)
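If you want to keep the visualization outside the notebook, recent shap versions let shap.plots.text return the raw HTML instead of rendering it (check that your installed version supports the display argument; the output path below is hypothetical):

[ ]:
# save the interactive explanation plot as a standalone HTML file
html = shap.plots.text(shap_values_model_agnostic, display=False)
with open("xsum_explanation.html", "w") as fout:  # hypothetical output path
    fout.write(html)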
Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!