shap.models.TextGeneration

class shap.models.TextGeneration(model=None, tokenizer=None, target_sentences=None, device=None)

Generates target sentences/ids using a base model.

It generates target sentences/ids for a model, which can be either a pretrained transformer model or a function.

__init__(model=None, tokenizer=None, target_sentences=None, device=None)

Create a text generator model from a pretrained transformer model or a function.

For a pretrained transformer model, a tokenizer should be passed.

Parameters:
model: object or function

An object of any pretrained transformer model, or a function, for which target sentences/ids are to be generated.

tokenizer: object

A tokenizer object (PreTrainedTokenizer/PreTrainedTokenizerFast) used to tokenize sentences.

target_sentences: list

A target sentence for every explanation row.

device: str

By default, it infers whether the system has a GPU and sets the device accordingly. Should be ‘cpu’ or ‘cuda’ for PyTorch models.

Returns:
numpy.ndarray

Array of target sentences/ids.
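The default device selection described for the device parameter can be pictured with the sketch below. infer_device is a hypothetical helper, not part of shap; it assumes that an explicit device always wins and that, otherwise, "cuda" is chosen only when PyTorch is installed and torch.cuda.is_available() reports a GPU.

```python
import importlib.util
from typing import Optional


def infer_device(device: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the documented default for `device`."""
    # An explicit choice from the caller always wins.
    if device is not None:
        return device
    # Only consult PyTorch if it is installed; otherwise fall back to CPU.
    if importlib.util.find_spec("torch") is not None:
        import torch
        return "cuda" if torch.cuda.is_available() else "cpu"
    return "cpu"


print(infer_device("cpu"))  # explicit override
print(infer_device())       # 'cuda' only if PyTorch sees a GPU
```

Checking for the torch package before importing it keeps the sketch usable on machines without PyTorch, where CPU is the only sensible default.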

Methods

__init__([model, tokenizer, ...])

Create a text generator model from a pretrained transformer model or a function.

get_inputs(X[, padding_side])

The function tokenizes source sentences.

load(in_file[, instantiate])

This is meant to be overridden by subclasses and called with super.

model_generate(X)

This function performs text generation for TensorFlow and PyTorch models.

parse_prefix_suffix_for_model_generate_output(output)

Calculates whether special tokens are present at the beginning/end of the model-generated output.

save(out_file)

Save the model to the given file stream.

get_inputs(X, padding_side='right')

The function tokenizes source sentences.

In the model-agnostic case, the function calls model(X), which is expected to return a batch of output sentences; these output sentences are then tokenized to compute the inputs.

Parameters:
X: numpy.ndarray

X is a batch of sentences.

Returns:
dict

Dictionary of padded source sentence ids and attention mask as tensors (“pt” or “tf”, depending on model_type).
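The model-agnostic path of get_inputs (call model(X), tokenize the returned sentences, pad them, and build an attention mask) can be sketched without any deep learning framework. Everything here is an assumption for illustration: toy_model stands in for a user-supplied function, a whitespace tokenizer with an ad hoc vocabulary replaces the Hugging Face tokenizer, and NumPy arrays replace the "pt"/"tf" tensors the real method returns.

```python
import numpy as np


def toy_model(X):
    """Hypothetical model-agnostic 'model': maps each sentence to an output sentence."""
    return np.array([s.upper() + " !" for s in X])


def toy_get_inputs(model, X, vocab, pad_id=0, padding_side="right"):
    """Sketch: call model(X), tokenize the outputs, pad them, and build a mask."""
    outputs = model(X)
    # Whitespace tokenization; unseen words get the next free id (ids start at 1).
    ids = [[vocab.setdefault(tok, len(vocab) + 1) for tok in s.split()] for s in outputs]
    width = max(len(row) for row in ids)
    input_ids = np.full((len(ids), width), pad_id, dtype=np.int64)
    attention_mask = np.zeros((len(ids), width), dtype=np.int64)
    for i, row in enumerate(ids):
        if padding_side == "right":
            span = slice(0, len(row))
        else:  # "left": real tokens are pushed to the end of the row
            span = slice(width - len(row), width)
        input_ids[i, span] = row
        attention_mask[i, span] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}


out = toy_get_inputs(toy_model, np.array(["hello world", "hi"]), {}, padding_side="left")
print(out["input_ids"].tolist())       # [[1, 2, 3], [0, 4, 3]]
print(out["attention_mask"].tolist())  # [[1, 1, 1], [0, 1, 1]]
```

The attention mask marks real tokens with 1 and padding with 0, so the padding side only changes where the zeros land, not which tokens the model attends to.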

classmethod load(in_file, instantiate=True)

This is meant to be overridden by subclasses and called with super.

We return the constructor argument values when not instantiating. Since the Serializable class has no constructor arguments, we just return an empty dictionary.
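The "override and call super" pattern for load can be illustrated with two small classes. MiniSerializable and MiniTextGen are hypothetical stand-ins written for this sketch, and pickle is an assumed storage format; this is not shap's serialization code. The base class contributes an empty dictionary of constructor arguments, and the subclass extends it with its own arguments before deciding whether to instantiate.

```python
import io
import pickle


class MiniSerializable:
    """Hypothetical stand-in for a Serializable base class (not shap's code)."""

    def save(self, out_file):
        pickle.dump(type(self).__name__, out_file)  # simple header

    @classmethod
    def load(cls, in_file, instantiate=True):
        pickle.load(in_file)  # consume the header written by save()
        kwargs = {}  # the base class has no constructor arguments
        return cls(**kwargs) if instantiate else kwargs


class MiniTextGen(MiniSerializable):
    """Hypothetical subclass that persists its own constructor arguments."""

    def __init__(self, target_sentences=None, device=None):
        self.target_sentences = target_sentences
        self.device = device

    def save(self, out_file):
        super().save(out_file)
        pickle.dump({"target_sentences": self.target_sentences, "device": self.device}, out_file)

    @classmethod
    def load(cls, in_file, instantiate=True):
        kwargs = super().load(in_file, instantiate=False)  # base contributes {}
        kwargs.update(pickle.load(in_file))
        return cls(**kwargs) if instantiate else kwargs


buf = io.BytesIO()
MiniTextGen(["a target"], "cpu").save(buf)
buf.seek(0)
print(MiniTextGen.load(buf, instantiate=False))
# {'target_sentences': ['a target'], 'device': 'cpu'}
```

Because each level only adds its own arguments to the dictionary returned by super, a subclass never needs to know what its parents stored.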

model_generate(X)

This function performs text generation for TensorFlow and PyTorch models.

Parameters:
X: dict

Dictionary of padded source sentence ids and attention mask as tensors.

Returns:
numpy.ndarray

Returns target sentence ids.
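To make "text generation" concrete, the sketch below runs a framework-free greedy decoding loop over the same kind of input dictionary and returns target ids as a NumPy array. It is only an illustration of the idea, not shap's implementation: a real PyTorch or TensorFlow transformer would use its framework's own generation routine. greedy_generate and toy_logits are hypothetical, and right-padding finished rows with the end-of-sequence id is an assumption made so the batch fits in one array.

```python
import numpy as np


def greedy_generate(inputs, next_token_logits, eos_id, max_new_tokens=5):
    """Greedy decoding sketch: repeatedly append the highest-scoring next token."""
    results = []
    for ids, mask in zip(inputs["input_ids"], inputs["attention_mask"]):
        prefix = [int(t) for t, m in zip(ids, mask) if m]  # drop padding
        generated = []
        for _ in range(max_new_tokens):
            tok = int(np.argmax(next_token_logits(prefix + generated)))
            generated.append(tok)
            if tok == eos_id:
                break
        results.append(generated)
    # Right-pad shorter rows with eos so the batch fits in one numpy array.
    width = max(len(r) for r in results)
    return np.array([r + [eos_id] * (width - len(r)) for r in results])


def toy_logits(seq, vocab_size=6, eos_id=5):
    """Hypothetical scorer: predicts last token + 1, capped at the eos id."""
    scores = np.zeros(vocab_size)
    scores[min(seq[-1] + 1, eos_id)] = 1.0
    return scores


inputs = {"input_ids": np.array([[1, 2], [3, 0]]), "attention_mask": np.array([[1, 1], [1, 0]])}
print(greedy_generate(inputs, toy_logits, eos_id=5).tolist())  # [[3, 4, 5], [4, 5, 5]]
```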

parse_prefix_suffix_for_model_generate_output(output)

Calculates whether special tokens are present at the beginning/end of the model-generated output.

Parameters:
output: list

A list of output token ids.

Returns:
dict

Dictionary of prefix and suffix lengths concerning special tokens in output ids.
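The idea of measuring special tokens at the start and end of generated ids can be sketched as counting runs of special ids from each side. prefix_suffix_lengths is a hypothetical helper; the set of special ids is passed in directly instead of being read from a tokenizer, and the keys prefix_len/suffix_len are assumptions, not necessarily the keys shap returns.

```python
def prefix_suffix_lengths(output, special_ids):
    """Count special-token ids at the start and at the end of a list of output ids."""
    prefix = 0
    while prefix < len(output) and output[prefix] in special_ids:
        prefix += 1
    suffix = 0
    # Stop before re-counting tokens already attributed to the prefix.
    while suffix < len(output) - prefix and output[len(output) - 1 - suffix] in special_ids:
        suffix += 1
    return {"prefix_len": prefix, "suffix_len": suffix}


# Hypothetical ids: 0 is a start token and 2 an end-of-sequence token.
print(prefix_suffix_lengths([0, 5, 6, 7, 2], {0, 2}))  # {'prefix_len': 1, 'suffix_len': 1}
```

Knowing these lengths lets a caller slice output[prefix_len:len(output) - suffix_len] to keep only the generated content tokens.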

save(out_file)

Save the model to the given file stream.