Machine Translation Explanations
This notebook demonstrates model explanations for a text-to-text scenario using pretrained transformer models for machine translation. In this demo, we showcase explanations on two models: English to Spanish (https://huggingface.co/Helsinki-NLP/opus-mt-en-es) and English to French (https://huggingface.co/Helsinki-NLP/opus-mt-en-fr).
[1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import shap
import torch
English to Spanish model
[2]:
# load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-es").cuda()
# define the input sentences we want to translate
data = [
    "Transformers have rapidly become the model of choice for NLP problems, replacing older recurrent neural network models"
]
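Before explaining the model, it can help to sanity-check the raw translation. A minimal sketch (like the cell above, it assumes a CUDA GPU is available; on CPU, drop the .cuda() call when loading the model):

# optional sanity check: translate the input sentence directly with the model
batch = tokenizer(data, return_tensors="pt", padding=True).to(model.device)
output_ids = model.generate(**batch)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))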
Explain the model’s predictions
[3]:
# we build an explainer by passing the model we want to explain and
# the tokenizer we want to use to break up the input strings
explainer = shap.Explainer(model, tokenizer)
# explainers are callable, just like models
shap_values = explainer(data, fixed_context=1)
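The returned shap.Explanation holds, for each input sentence, a matrix of attributions with one row per input token and one column per generated output token (fixed_context=1 is a masking option of the underlying partition explainer). A minimal sketch of inspecting it directly, using the standard Explanation attributes; the shapes assume our single input sentence:

# inspect the explanation for the first (and only) sentence
exp = shap_values[0]
print(exp.data)          # the tokenized input
print(exp.output_names)  # the generated output tokens
print(exp.values.shape)  # (num_input_tokens, num_output_tokens)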
Visualize SHAP explanations
[4]:
shap.plots.text(shap_values)
[Interactive SHAP text plot]
Translated output: "Los transformadores se han convertido rápidamente en el modelo de elección para problemas NLP, reemplazando modelos de red neuronal recurrentes más antiguos"
Selecting an output token in the plot shows the SHAP value of each input token toward generating that token; the input tokens "▁Transform" and "ers", for instance, carry by far the largest attributions for the output tokens "transformador" and "es".
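The attributions the plot visualizes can also be read off programmatically. A minimal sketch, assuming the layout above (rows of .values are input tokens, columns are output tokens), ranking the input tokens that most influenced the first output token:

# rank input tokens by their contribution to the first output token
vals = shap_values[0].values            # (num_input_tokens, num_output_tokens)
tokens = shap_values[0].data            # the input tokens
for i in np.argsort(-vals[:, 0])[:5]:   # top five contributors
    print(f"{tokens[i]}: {vals[i, 0]:+.3f}")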