Text to Multiclass Explanation: Language Modeling Example
This notebook demonstrates how to get explanations for the top-k next words generated by a language model. In this demo, we use the pretrained GPT-2 model provided by Hugging Face (https://huggingface.co/gpt2) to predict the top-k next words. We treat the top-k next words as k separate classes and then learn an explanation for each of these k words. This lets us explain which words in the input are responsible for the likelihood of each of the top-k next words being predicted.
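Treating each candidate next word as its own class means the quantity being explained is the log odds of that word. As a toy illustration of the log-odds transform (the words and probabilities below are made up, not actual GPT-2 output):

```python
import math

def log_odds(p):
    """Log odds of a probability p: log(p / (1 - p))."""
    return math.log(p / (1.0 - p))

# Hypothetical next-word probabilities, for illustration only.
next_word_probs = {"field": 0.30, "valley": 0.20, "forest": 0.10}

# Each of the k candidate words becomes one "class" with its own log odds.
next_word_log_odds = {w: log_odds(p) for w, p in next_word_probs.items()}
```

SHAP values are then computed per class, so each candidate word gets its own attribution over the input tokens.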
[1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import shap
import torch
Load model and tokenizer
[2]:
tokenizer = AutoTokenizer.from_pretrained('gpt2', use_fast=True)
model = AutoModelForCausalLM.from_pretrained('gpt2').cuda()
We next wrap the model with the TopKLM model, which extracts the log odds of the top-k next words. We also create a Text masker, initializing it with mask_token = "..." and setting collapse_mask_token = True; the mask token is used to infill text when the inputs are perturbed.
[3]:
wrapped_model = shap.models.TopKLM(model, tokenizer, k=100)
masker = shap.maskers.Text(tokenizer, mask_token = "...", collapse_mask_token=True)
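With collapse_mask_token=True, a run of consecutive masked tokens is replaced by a single "..." rather than one mask token per position. A minimal sketch of that behaviour (an illustration, not the actual shap.maskers.Text implementation):

```python
def apply_mask(tokens, keep, mask_token="...", collapse=True):
    """Replace hidden tokens with mask_token, optionally collapsing runs."""
    out = []
    for tok, keep_tok in zip(tokens, keep):
        if keep_tok:
            out.append(tok)
        elif not (collapse and out and out[-1] == mask_token):
            out.append(mask_token)  # start of a new masked run
    return out

tokens = ["scientists", "discovered", "a", "herd", "of", "unicorns"]
# Hide "a herd of": with collapsing, three masked tokens become one "...".
masked = apply_mask(tokens, [True, True, False, False, False, True])
# masked == ["scientists", "discovered", "...", "unicorns"]
```

Collapsing keeps the perturbed text closer to natural language, which suits a language model better than a long run of repeated mask tokens.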
Define data
Here we set the initial text for which we want the GPT-2 model to predict the next word.
[4]:
s = ["In a shocking finding, scientists discovered a herd of unicorns living in a"]
Create explainer object
[5]:
explainer = shap.Explainer(wrapped_model, masker)
Compute SHAP values
[6]:
shap_values = explainer(s)
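The resulting Explanation object holds one attribution per input token per output class, and SHAP values are additive: for each of the k output words, the base value plus the per-token attributions recovers the model's log-odds output for that word. A toy numeric sketch of this property (the values are hypothetical, not taken from this run):

```python
# Hypothetical base value and per-input-token SHAP values for one output word.
base_value = -2.0
token_attributions = [0.5, 1.2, -0.1, 0.0]

# Additivity: the explained model output equals the base value
# plus the sum of all token attributions.
explained_output = base_value + sum(token_attributions)  # -0.4
```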
Visualize the SHAP values across the input sentence for the top-k next words
We can now see the top-k next words predicted by GPT-2 under "Output Text" in the plot below. Hover over each output token to see which words in the input sentence drive the generation of that particular output word.
[7]:
shap.plots.text(shap_values)