"""Contains classes and methods related to interpretation for components in Gradio."""
from __future__ import annotations
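# run_interpret() below is the entry point used by gradio.Interface when an
# `interpretation` method ("default", "shap", or a custom function) is configured
# for the interface's input components.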
import copy
import math
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import numpy as np
from gradio import components, utils
if TYPE_CHECKING:  # Only import for type checking (TYPE_CHECKING is False at runtime).
from gradio import Interface
class Interpretable(ABC):
def __init__(self) -> None:
self.set_interpret_parameters()
def set_interpret_parameters(self):
"""
Set any parameters for interpretation. Properties can be set here to be
used in get_interpretation_neighbors and get_interpretation_scores.
"""
pass
def get_interpretation_scores(
self, x: Any, neighbors: List[Any] | None, scores: List[float], **kwargs
) -> List:
"""
Arrange the output values from the neighbors into interpretation scores for the interface to render.
Parameters:
x: Input to interface
neighbors: Neighboring values to input x used for interpretation.
scores: Output value corresponding to each neighbor in neighbors
Returns:
Arrangement of interpretation scores for interfaces to render.
"""
return scores
class TokenInterpretable(Interpretable, ABC):
@abstractmethod
def tokenize(self, x: Any) -> Tuple[List, List, None]:
"""
        Interprets an input data point x by splitting it into a list of tokens (e.g.,
        a string into words or an image into super-pixels).
"""
return [], [], None
@abstractmethod
    def get_masked_inputs(self, tokens: List, binary_mask_matrix: List[List]) -> List:
        """
        Constructs one masked input per row of binary_mask_matrix, keeping only the
        tokens whose corresponding mask value is truthy.
        """
        return []
class NeighborInterpretable(Interpretable, ABC):
@abstractmethod
def get_interpretation_neighbors(self, x: Any) -> Tuple[List, Dict]:
"""
Generates values similar to input to be used to interpret the significance of the input in the final output.
Parameters:
x: Input to interface
        Returns: (neighbor_values, interpret_kwargs)
neighbor_values: Neighboring values to input x to compute for interpretation
interpret_kwargs: Keyword arguments to be passed to get_interpretation_scores
"""
return [], {}
async def run_interpret(interface: Interface, raw_input: List):
"""
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
interpretation for a certain set of UI component types, as well as the custom interpretation case.
    Parameters:
        interface: the Interface whose prediction function is being interpreted.
        raw_input: a list of raw inputs to apply the interpretation(s) on.
    Returns:
        A tuple of (interpretation scores, alternative outputs) for the interface to render.
    """
if isinstance(interface.interpretation, list): # Either "default" or "shap"
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)
]
original_output = await interface.call_function(0, processed_input)
original_output = original_output["prediction"]
if len(interface.output_components) == 1:
original_output = [original_output]
scores, alternative_outputs = [], []
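        # "Default" and "shap" interpretation perturb one input at a time while holding
        # the others fixed, scoring each perturbation by how much the first output
        # component's value changes relative to the original prediction.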
for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
if interp == "default":
input_component = interface.input_components[i]
neighbor_raw_input = list(raw_input)
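                # Work on a copy of the raw inputs so that only position i is replaced
                # by each perturbed value below.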
if isinstance(input_component, TokenInterpretable):
tokens, neighbor_values, masks = input_component.tokenize(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
processed_neighbor_input = [
input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(
interface.input_components
)
]
neighbor_output = await interface.call_function(
0, processed_neighbor_input
)
neighbor_output = neighbor_output["prediction"]
if len(interface.output_components) == 1:
neighbor_output = [neighbor_output]
processed_neighbor_output = [
output_component.postprocess(neighbor_output[i])
for i, output_component in enumerate(
interface.output_components
)
]
alternative_output.append(processed_neighbor_output)
interface_scores.append(
quantify_difference_in_label(
interface, original_output, neighbor_output
)
)
alternative_outputs.append(alternative_output)
scores.append(
input_component.get_interpretation_scores(
raw_input[i],
neighbor_values,
interface_scores,
masks=masks,
tokens=tokens,
)
)
elif isinstance(input_component, NeighborInterpretable):
(
neighbor_values,
interpret_kwargs,
) = input_component.get_interpretation_neighbors(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
processed_neighbor_input = [
input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(
interface.input_components
)
]
neighbor_output = await interface.call_function(
0, processed_neighbor_input
)
neighbor_output = neighbor_output["prediction"]
if len(interface.output_components) == 1:
neighbor_output = [neighbor_output]
processed_neighbor_output = [
output_component.postprocess(neighbor_output[i])
for i, output_component in enumerate(
interface.output_components
)
]
alternative_output.append(processed_neighbor_output)
interface_scores.append(
quantify_difference_in_label(
interface, original_output, neighbor_output
)
)
alternative_outputs.append(alternative_output)
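                    # Flip the sign for neighbor-based components: a positive score then
                    # means moving to the neighboring value increased the output value
                    # (or the original label's confidence), rather than measuring the
                    # drop caused by removing a token.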
interface_scores = [-score for score in interface_scores]
scores.append(
input_component.get_interpretation_scores(
raw_input[i],
neighbor_values,
interface_scores,
**interpret_kwargs,
)
)
else:
raise ValueError(
f"Component {input_component} does not support interpretation"
)
elif interp == "shap" or interp == "shapley":
try:
import shap # type: ignore
except (ImportError, ModuleNotFoundError):
raise ValueError(
"The package `shap` is required for this interpretation method. Try: `pip install shap`"
)
input_component = interface.input_components[i]
if not isinstance(input_component, TokenInterpretable):
raise ValueError(
"Input component {} does not support `shap` interpretation".format(
input_component
)
)
tokens, _, masks = input_component.tokenize(x)
                # Run the model on masked variants of input i and collect the resulting
                # prediction scores; shap.KernelExplainer calls this with batches of binary masks.
def get_masked_prediction(binary_mask):
assert isinstance(input_component, TokenInterpretable)
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
preds = []
for masked_x in masked_xs:
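                        # Re-run the prediction with only input i swapped for its masked
                        # variant; all other inputs keep their original (processed) values.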
processed_masked_input = copy.deepcopy(processed_input)
processed_masked_input[i] = input_component.preprocess(masked_x)
new_output = utils.synchronize_async(
interface.call_function, 0, processed_masked_input
)
new_output = new_output["prediction"]
if len(interface.output_components) == 1:
new_output = [new_output]
pred = get_regression_or_classification_value(
interface, original_output, new_output
)
preds.append(pred)
return np.array(preds)
num_total_segments = len(tokens)
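                # shap.KernelExplainer treats each token as a binary feature: the all-zeros
                # background row means every token is masked out, and the all-ones row passed
                # to shap_values() is the actual input with every token present. nsamples
                # controls how many masked coalitions are evaluated per explanation.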
explainer = shap.KernelExplainer(
get_masked_prediction, np.zeros((1, num_total_segments))
)
shap_values = explainer.shap_values(
np.ones((1, num_total_segments)),
nsamples=int(interface.num_shap * num_total_segments),
silent=True,
)
assert shap_values is not None, "SHAP values could not be calculated"
scores.append(
input_component.get_interpretation_scores(
raw_input[i],
None,
shap_values[0].tolist(),
masks=masks,
tokens=tokens,
)
)
alternative_outputs.append([])
elif interp is None:
scores.append(None)
alternative_outputs.append([])
else:
                raise ValueError("Unknown interpretation method: {}".format(interp))
return scores, alternative_outputs
elif interface.interpretation: # custom interpretation function
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)
]
interpreter = interface.interpretation
interpretation = interpreter(*processed_input)
if len(raw_input) == 1:
interpretation = [interpretation]
return interpretation, []
else:
raise ValueError("No interpretation method specified.")
def diff(original: Any, perturbed: Any) -> int | float:
    """Quantify the difference between two output values: the numerical difference when
    both can be cast to float (e.g. 0.9 - 0.2 == 0.7), otherwise 0/1 label disagreement."""
    try:  # try computing numerical difference
score = float(original) - float(perturbed)
except ValueError: # otherwise, look at strict difference in label
score = int(not (original == perturbed))
return score
def quantify_difference_in_label(
    interface: Interface, original_output: List, perturbed_output: List
) -> int | float:
    """Scores how much the first output component's value changes between the original
    prediction and the prediction for a perturbed input."""
output_component = interface.output_components[0]
post_original_output = output_component.postprocess(original_output[0])
post_perturbed_output = output_component.postprocess(perturbed_output[0])
if isinstance(output_component, components.Label):
original_label = post_original_output["label"]
perturbed_label = post_perturbed_output["label"]
# Handle different return types of Label interface
if "confidences" in post_original_output:
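            # Score the perturbation as the drop in confidence assigned to the
            # originally predicted label (original minus perturbed confidence).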
original_confidence = original_output[0][original_label]
perturbed_confidence = perturbed_output[0][original_label]
score = original_confidence - perturbed_confidence
else:
score = diff(original_label, perturbed_label)
return score
elif isinstance(output_component, components.Number):
score = diff(post_original_output, post_perturbed_output)
return score
else:
raise ValueError(
"This interpretation method doesn't support the Output component: {}".format(
output_component
)
)
def get_regression_or_classification_value(
interface: Interface, original_output: List, perturbed_output: List
) -> int | float:
"""Used to combine regression/classification for Shap interpretation method."""
output_component = interface.output_components[0]
post_original_output = output_component.postprocess(original_output[0])
post_perturbed_output = output_component.postprocess(perturbed_output[0])
if isinstance(output_component, components.Label):
original_label = post_original_output["label"]
perturbed_label = post_perturbed_output["label"]
# Handle different return types of Label interface
if "confidences" in post_original_output:
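            # For SHAP, the model output for a masked input is the confidence the
            # perturbed prediction assigns to the original top label (NaN is treated as 0).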
if math.isnan(perturbed_output[0][original_label]):
return 0
return perturbed_output[0][original_label]
else:
score = diff(
perturbed_label, original_label
) # Intentionally inverted order of arguments.
return score
else:
raise ValueError(
"This interpretation method doesn't support the Output component: {}".format(
output_component
)
)