Text to Text Explanation: Abstractive Summarization Example
This notebook demonstrates how to generate model explanations for a text-to-text scenario with a pretrained transformer model. Below we walk through generating explanations for the distilbart model fine-tuned on the Extreme Summarization (XSum) dataset and provided by Hugging Face (https://huggingface.co/sshleifer/distilbart-xsum-12-6).
The first example needs only the model and tokenizer, and we use the model decoder to generate the log odds of the output tokens being explained. In the second example, we demonstrate how to generate explanations for a model exposed as an API/function (text in -> text out). In that case we approximate the log odds using a text similarity model. The underlying explainer used to compute the SHAP values is the partition explainer.
[1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from datasets import load_dataset
import shap
Load model and tokenizer
[2]:
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-6").cuda()
Load data
[3]:
dataset = load_dataset('xsum', split='train')
Using custom data configuration default
Reusing dataset xsum (/home/slundberg/.cache/huggingface/datasets/xsum/default/1.2.0/f9abaabb5e2b2a1e765c25417264722d31877b34ec34b437c53242f6e5c30d6d)
[4]:
# slice inputs from dataset to run model inference on
s = dataset['document'][0:1]
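For reference, each XSum record also pairs the article with a human-written one-sentence summary. A quick, optional way to inspect the example being explained (the field names document and summary come from the XSum dataset):
[ ]:
# optional sanity check: look at the first article and its reference summary
example = dataset[0]
print(example['document'][:300])
print(example['summary'])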
Create an explainer object
[5]:
explainer = shap.Explainer(model, tokenizer)
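Passing the model and tokenizer directly lets shap.Explainer pick a suitable masker and algorithm automatically. A roughly equivalent explicit construction, as a sketch assuming the shap.maskers.Text masker (which is what the tokenizer is wrapped into):
[ ]:
# sketch: wrap the tokenizer in a text masker explicitly instead of passing it directly
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(model, masker)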
Compute shap values
[6]:
shap_values = explainer(s)
floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
Partition explainer: 2it [00:19, 9.52s/it]
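shap_values is a shap.Explanation object holding one attribution per (input token, output token) pair for each sample. A hedged sketch of how to inspect it (attribute names assume the standard shap.Explanation interface; exact shapes can vary across shap versions):
[ ]:
# sketch: inspect the structure of the explanation
print(shap_values.values[0].shape)  # (num input tokens, num output tokens) for the first sample
print(shap_values.output_names)     # generated output tokens being explained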
Visualize shap explanations
[7]:
shap.plots.text(shap_values)
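As mentioned in the introduction, the model can also be treated as an opaque text-in/text-out function, in which case the log odds of the output tokens are approximated with a text similarity model. A minimal sketch of that setup (the shap.models.TeacherForcing and shap.maskers.Text argument names below are assumptions about the shap API, and distilbart is simply reused as its own similarity model):
[ ]:
# sketch: explain a black-box summarization function instead of the model itself

def summarize(texts):
    # any opaque function mapping a list of input texts to a list of summaries
    inputs = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True).to(model.device)
    outputs = model.generate(**inputs)
    return [tokenizer.decode(t, skip_special_tokens=True) for t in outputs]

# approximate the log odds of the generated tokens with a text similarity model
# (here the same distilbart model and tokenizer are reused in that role)
wrapped_model = shap.models.TeacherForcing(summarize, similarity_model=model, similarity_tokenizer=tokenizer)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)

explainer_fn = shap.Explainer(wrapped_model, masker)
shap_values_fn = explainer_fn(s)
shap.plots.text(shap_values_fn)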