shap.Cohorts

class shap.Cohorts(**kwargs: Explanation)

A collection of Explanation objects, typically each explaining a cluster of similar samples.

Examples

A Cohorts object can be initialized in a variety of ways.

By explicitly specifying the cohorts:

>>> import numpy as np
>>> from shap import Cohorts, Explanation
>>> exp = Explanation(
...     values=np.random.uniform(low=-1, high=1, size=(500, 5)),
...     data=np.random.normal(loc=1, scale=3, size=(500, 5)),
...     feature_names=list("abcde"),
... )
>>> cohorts = Cohorts(
...     col_a_neg=exp[exp[:, "a"].data < 0],
...     col_a_pos=exp[exp[:, "a"].data >= 0],
... )
>>> cohorts
<shap._explanation.Cohorts object with 2 cohorts of sizes: [(198, 5), (302, 5)]>
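The two masks above are complementary. Every one of the 500 rows therefore lands in exactly one cohort, and the cohort sizes always sum to the row count. The exact 198/302 split depends on the random draw. The minimal NumPy-only sketch below reproduces that partition without shap. The seeded generator is an assumption added for reproducibility, so its split will differ from the output shown.

```python
import numpy as np

# Seeded generator (an assumption for reproducibility; the example above is unseeded).
rng = np.random.default_rng(0)
data = rng.normal(loc=1, scale=3, size=(500, 5))
feature_names = list("abcde")

# Column "a" is the first feature; build the same two boolean masks as above.
col_a = data[:, feature_names.index("a")]
neg_mask = col_a < 0
pos_mask = col_a >= 0

# Each mask selects the rows belonging to one cohort.
cohort_rows = {"col_a_neg": data[neg_mask], "col_a_pos": data[pos_mask]}
sizes = {name: rows.shape for name, rows in cohort_rows.items()}
print(sizes)

# The masks never overlap and together cover every row.
assert not np.any(neg_mask & pos_mask)
assert sum(rows.shape[0] for rows in cohort_rows.values()) == data.shape[0]
```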

Or using the Explanation.cohorts() method:

>>> cohorts2 = exp.cohorts(3)
>>> cohorts2
<shap._explanation.Cohorts object with 3 cohorts of sizes: [(182, 5), (12, 5), (306, 5)]>
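According to its docstring, passing an integer to Explanation.cohorts() builds that many cohorts automatically using a decision tree, so the cohort sizes need not be balanced. The method also accepts an array of per-sample cohort labels instead of an integer. In that case each distinct label becomes one cohort. The sketch below reproduces that label-based grouping on a plain NumPy array. group_by_labels is a hypothetical helper, not part of shap.

```python
import numpy as np

def group_by_labels(rows: np.ndarray, labels) -> dict:
    """Hypothetical helper: one entry per distinct label, holding the matching rows."""
    labels = np.asarray(labels)
    if labels.shape[0] != rows.shape[0]:
        raise ValueError("need exactly one label per row")
    # np.unique returns the labels sorted, so the key order is deterministic.
    return {str(name): rows[labels == name] for name in np.unique(labels)}

rows = np.arange(12).reshape(6, 2)
labels = ["low", "high", "low", "mid", "high", "low"]
groups = group_by_labels(rows, labels)
print({name: g.shape for name, g in groups.items()})
# {'high': (2, 2), 'low': (3, 2), 'mid': (1, 2)}
```

On a real Explanation, the corresponding call would be exp.cohorts(labels), with one label per row of exp.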

Most of the Explanation interface is also exposed in Cohorts. For example, to select the SHAP values for column ‘a’ across all cohorts, you can use the following. Note that the result is another Cohorts object, not an array:

>>> cohorts[..., 'a'].values
<shap._explanation.Cohorts object with 2 cohorts of sizes: [(198,), (302,)]>

To retrieve the values of a particular Explanation, access it through the Cohorts.cohorts property, a dictionary keyed by cohort name:

>>> cohorts.cohorts["col_a_neg"][..., 'a'].values
array([...])  # truncated
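Judging from the outputs above, Cohorts exposes the Explanation interface by forwarding indexing and attribute access to every member and wrapping the results in a new collection. That is why cohorts[..., 'a'].values yields another Cohorts rather than an array. The class below is a hypothetical, simplified sketch of that forwarding pattern. MiniCohorts is not shap's implementation, and NumPy arrays stand in for the Explanation members.

```python
import numpy as np

class MiniCohorts:
    """Hypothetical sketch: forward indexing, attributes and calls to every member."""

    def __init__(self, **members):
        self._members = dict(members)

    @property
    def cohorts(self) -> dict:
        # Direct access to the underlying members, mirroring Cohorts.cohorts.
        return self._members

    def __getitem__(self, item) -> "MiniCohorts":
        return MiniCohorts(**{k: v[item] for k, v in self._members.items()})

    def __getattr__(self, name: str) -> "MiniCohorts":
        # Only reached when normal lookup fails; private names are refused to
        # avoid infinite recursion (e.g. before _members exists during copying).
        if name.startswith("_"):
            raise AttributeError(name)
        return MiniCohorts(**{k: getattr(v, name) for k, v in self._members.items()})

    def __call__(self, *args, **kwargs) -> "MiniCohorts":
        # Lets forwarded methods be called, e.g. mc.sum(axis=1).
        return MiniCohorts(**{k: v(*args, **kwargs) for k, v in self._members.items()})

    def __repr__(self) -> str:
        return f"<MiniCohorts {list(self._members)}>"

mc = MiniCohorts(neg=np.arange(6).reshape(2, 3), pos=np.arange(9).reshape(3, 3))
col0 = mc[..., 0]                      # still a MiniCohorts, one column per member
print(col0)                            # <MiniCohorts ['neg', 'pos']>
print(col0.cohorts["pos"])             # [0 3 6]
print(mc.sum(axis=1).cohorts["neg"])   # [ 3 12]
```

The design choice is that every forwarded operation returns a new collection, so operations can be chained. The cost is that raw results are only reachable through the dictionary, as in the cohorts.cohorts["col_a_neg"] example above.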

__init__(**kwargs: Explanation) → None

Methods

__init__(**kwargs)

Attributes

property cohorts: dict[str, Explanation]

Internal collection of cohorts, stored as a dictionary.
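Because the cohorts property is an ordinary dictionary, per-cohort summaries are plain loops over its items. The minimal sketch below computes the mean absolute SHAP value of each feature per cohort. To keep the snippet runnable without shap, NumPy arrays stand in for each cohort's Explanation.values. mean_abs_by_cohort is a hypothetical helper.

```python
import numpy as np

def mean_abs_by_cohort(values_by_cohort: dict) -> dict:
    """Hypothetical helper: mean |SHAP value| per feature for each cohort.

    Each value is an (n_samples, n_features) array standing in for
    cohorts.cohorts[name].values from a real shap Cohorts object.
    """
    return {name: np.abs(vals).mean(axis=0) for name, vals in values_by_cohort.items()}

values_by_cohort = {
    "col_a_neg": np.array([[0.5, -1.0], [-0.5, 0.0]]),
    "col_a_pos": np.array([[1.0, 0.25]]),
}
summary = mean_abs_by_cohort(values_by_cohort)
for name, per_feature in summary.items():
    print(name, per_feature)
```

On a real Cohorts object, you would pass {name: e.values for name, e in cohorts.cohorts.items()} instead of the stand-in dictionary.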