Explain ResNet50 on ImageNet multi-class output using SHAP Partition Explainer
This notebook demonstrates how to use SHAP to explain image classification models. Here, we explain the output of a ResNet50 model that classifies images into the 1000 ImageNet classes.
import json

import numpy as np
import shap
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
Loading Model and Data
# load pre-trained model and data
model = ResNet50(weights='imagenet')
X, y = shap.datasets.imagenet50()
# getting ImageNet 1000 class names
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url)) as file:
    # each value is a [WordNet ID, class name] pair; keep only the name
    class_names = [v[1] for v in json.load(file).values()]
# print("Number of ImageNet classes:", len(class_names))
# print("Class names:", class_names)
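As a small illustration of the class-index file's layout: it maps each class index (as a string) to a pair of WordNet ID and human-readable name. The sketch below inlines two entries in that shape instead of downloading the file; the variable names `sample_index` and `sample_names` are made up for this example.

```python
import json

# Two entries inlined in the same shape as imagenet_class_index.json:
# "<class index>": ["<WordNet ID>", "<human-readable name>"]
sample_index = json.loads(
    '{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"]}'
)

# Keep only the human-readable name (second element of each pair).
sample_names = [entry[1] for entry in sample_index.values()]
print(sample_names)
```

The resulting list of plain strings is the form that can be passed as output names to the explainer.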
SHAP ResNet50 model explanation for images
Build a partition explainer with:
- the model (a python function)
- the masker (a python function)
- the partition tree (a python function)
- output names (a list of names of the output classes)
# python function to get model output; replace this function with your own model function.
def f(x):
    tmp = x.copy()  # copy so the caller's images are not modified
    tmp = preprocess_input(tmp)
    return model(tmp)


# define a masker that is used to mask out partitions of the input image.
# the masker expects the shape of a single image, not of the whole batch.
masker = shap.maskers.Image("inpaint_telea", X[0].shape)

# create an explainer with model and image masker
explainer = shap.Explainer(f, masker, output_names=class_names)

# here we explain two images using 500 evaluations of the underlying model to estimate the SHAP values
shap_values = explainer(X[1:3], max_evals=500, batch_size=50, outputs=shap.Explanation.argsort.flip[:4])
Partition explainer: 3it [01:01, 20.37s/it]
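For intuition about the preprocessing step inside the model function: the ResNet50 `preprocess_input` uses Keras's "caffe" mode, which reorders channels from RGB to BGR and subtracts per-channel ImageNet means without rescaling. The NumPy sketch below mimics that step. The helper name `caffe_preprocess` is hypothetical, and this is an illustration under those assumptions rather than the Keras implementation.

```python
import numpy as np

# Per-channel ImageNet means in BGR order, as used by Keras "caffe" mode.
BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_preprocess(images):
    """Hypothetical stand-in for preprocess_input: RGB -> BGR, then mean-center."""
    x = np.asarray(images, dtype=np.float32)
    x = x[..., ::-1]     # reverse the channel axis: RGB -> BGR
    return x - BGR_MEAN  # broadcasting subtracts one mean per channel


# A dummy batch of two uniformly gray 224x224 RGB images.
batch = np.full((2, 224, 224, 3), 128.0, dtype=np.float32)
out = caffe_preprocess(batch)
print(out.shape, out[0, 0, 0])
```

Unlike an in-place transform, this sketch returns a new array and leaves `batch` untouched, which is the same effect the copy in the model function is meant to guarantee.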
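The image masker above fills masked regions using an inpainting technique called “inpaint_telea”. Alternative masking options are available to experiment with, such as “inpaint_ns” (another inpainting method) and “blur(kernel_xsize, kernel_ysize)” (which replaces masked regions with a blurred version of the image).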
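To make the idea of masking concrete, here is a NumPy-only sketch of blur-style masking: hidden pixels are replaced by a box-blurred copy of the image, while kept pixels stay untouched. The helpers `box_blur` and `apply_blur_mask` are hypothetical names, and this is not SHAP's own implementation; it only illustrates the principle on a tiny random image.

```python
import numpy as np


def box_blur(image, k):
    """Mean filter with a k x k window (k odd), using edge padding."""
    pad = k // 2
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="edge")
    out = np.zeros(image.shape, dtype=np.float64)
    h, w = image.shape[:2]
    # Sum the k*k shifted copies of the image, then average.
    for dy in range(k):
        for dx in range(k):
            out += padded[dy:dy + h, dx:dx + w]
    return out / (k * k)


def apply_blur_mask(image, keep, k=5):
    """Keep pixels where `keep` is True; replace the rest with blurred values."""
    blurred = box_blur(image, k)
    return np.where(keep[..., None], image, blurred)


rng = np.random.default_rng(0)
image = rng.uniform(0, 255, size=(8, 8, 3))
keep = np.zeros((8, 8), dtype=bool)
keep[:, :4] = True  # keep the left half, hide the right half

masked = apply_blur_mask(image, keep)
```

The explainer then compares model outputs on images where different regions are hidden in this way to attribute the prediction to those regions.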
We recommend 300 to 500 evaluations to get explanations with sufficient granularity for the super pixels. More evaluations give finer granularity but also increase run time.
outputs=shap.Explanation.argsort.flip[:4] is used in the code above so that SHAP values are computed for the 4 most probable classes of each image, i.e. the top 4 classes in order of decreasing probability. Hence, a flipped argsort sliced to its first 4 entries is used.
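In plain NumPy terms, the same selection amounts to an ascending argsort, reversed, then sliced. The sketch below applies this to a made-up score vector; the values are illustrative only and are not model output.

```python
import numpy as np

# Made-up scores for six classes (illustrative values only).
probs = np.array([0.03, 0.40, 0.10, 0.25, 0.15, 0.07])

order = np.argsort(probs)  # indices from least to most probable
top4 = order[::-1][:4]     # flip to descending order, keep the first four
print(top4)                # indices of the four highest-scoring classes
```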
Visualizing SHAP values output
# output with shap values
shap.image_plot(shap_values)
Interpretation of SHAP output explanation:
In the first example, the bird image is classified as an American Egret, with the next most probable classes being Crane, Heron and Flamingo. It is the “bump” over the bird’s neck that causes it to be classified as an American Egret rather than a Crane, Heron or Flamingo. You can see the neck region of the bird highlighted in red super pixels.
In the second example, it is the shape of the boat which causes it to be classified as a speedboat instead of a fountain, lifeboat or snowplow (appropriately highlighted in red super pixels).