import copy
import math
import numpy as np
from gradio import utils
from gradio.components import Label, Number
async def run_interpret(interface, raw_input):
"""
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
interpretation for a certain set of UI component types, as well as the custom interpretation case.
Parameters:
raw_input: a list of raw inputs to apply the interpretation(s) on.
"""
if isinstance(interface.interpretation, list): # Either "default" or "shap"
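        # Preprocess each raw input with its matching component so the original
        # (unperturbed) prediction can serve as the baseline for all perturbations.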
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)
]
original_output = await interface.call_function(0, processed_input)
original_output = original_output["prediction"]
if len(interface.output_components) == 1:
original_output = [original_output]
scores, alternative_outputs = [], []
for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
if interp == "default":
input_component = interface.input_components[i]
neighbor_raw_input = list(raw_input)
if input_component.interpret_by_tokens:
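                    # Token-based interpretation: split the input into tokens and
                    # score each token by re-running the model with it masked out.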
tokens, neighbor_values, masks = input_component.tokenize(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
                        processed_neighbor_input = [
                            comp.preprocess(neighbor_raw_input[j])
                            for j, comp in enumerate(interface.input_components)
                        ]
neighbor_output = await interface.call_function(
0, processed_neighbor_input
)
neighbor_output = neighbor_output["prediction"]
if len(interface.output_components) == 1:
neighbor_output = [neighbor_output]
                        processed_neighbor_output = [
                            comp.postprocess(neighbor_output[j])
                            for j, comp in enumerate(interface.output_components)
                        ]
alternative_output.append(processed_neighbor_output)
interface_scores.append(
quantify_difference_in_label(
interface, original_output, neighbor_output
)
)
alternative_outputs.append(alternative_output)
scores.append(
input_component.get_interpretation_scores(
raw_input[i],
neighbor_values,
interface_scores,
masks=masks,
tokens=tokens,
)
)
else:
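                    # Neighbor-based interpretation: the component supplies a set
                    # of perturbed "neighbor" values to evaluate against the baseline.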
(
neighbor_values,
interpret_kwargs,
) = input_component.get_interpretation_neighbors(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
                        processed_neighbor_input = [
                            comp.preprocess(neighbor_raw_input[j])
                            for j, comp in enumerate(interface.input_components)
                        ]
neighbor_output = await interface.call_function(
0, processed_neighbor_input
)
neighbor_output = neighbor_output["prediction"]
if len(interface.output_components) == 1:
neighbor_output = [neighbor_output]
                        processed_neighbor_output = [
                            comp.postprocess(neighbor_output[j])
                            for j, comp in enumerate(interface.output_components)
                        ]
alternative_output.append(processed_neighbor_output)
interface_scores.append(
quantify_difference_in_label(
interface, original_output, neighbor_output
)
)
alternative_outputs.append(alternative_output)
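                    # Negate the scores: neighbor-based components expect the
                    # opposite sign convention from the token-based branch above.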
interface_scores = [-score for score in interface_scores]
scores.append(
input_component.get_interpretation_scores(
raw_input[i],
neighbor_values,
interface_scores,
**interpret_kwargs
)
)
elif interp == "shap" or interp == "shapley":
try:
import shap # type: ignore
except (ImportError, ModuleNotFoundError):
raise ValueError(
"The package `shap` is required for this interpretation method. Try: `pip install shap`"
)
input_component = interface.input_components[i]
            if not input_component.interpret_by_tokens:
raise ValueError(
"Input component {} does not support `shap` interpretation".format(
input_component
)
)
tokens, _, masks = input_component.tokenize(x)
                # Closure handed to the SHAP explainer: it evaluates the model on
                # masked versions of the input and returns the resulting scores.
def get_masked_prediction(binary_mask):
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
preds = []
for masked_x in masked_xs:
processed_masked_input = copy.deepcopy(processed_input)
processed_masked_input[i] = input_component.preprocess(masked_x)
new_output = utils.synchronize_async(
interface.call_function, 0, processed_masked_input
)
new_output = new_output["prediction"]
if len(interface.output_components) == 1:
new_output = [new_output]
pred = get_regression_or_classification_value(
interface, original_output, new_output
)
preds.append(pred)
return np.array(preds)
num_total_segments = len(tokens)
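                # Background of zeros = all tokens masked; SHAP values are computed
                # at the all-ones vector (every token present), with nsamples
                # scaled by the number of segments.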
explainer = shap.KernelExplainer(
get_masked_prediction, np.zeros((1, num_total_segments))
)
shap_values = explainer.shap_values(
np.ones((1, num_total_segments)),
nsamples=int(interface.num_shap * num_total_segments),
silent=True,
)
scores.append(
input_component.get_interpretation_scores(
raw_input[i], None, shap_values[0], masks=masks, tokens=tokens
)
)
alternative_outputs.append([])
elif interp is None:
scores.append(None)
alternative_outputs.append([])
else:
raise ValueError("Unknown intepretation method: {}".format(interp))
return scores, alternative_outputs
else: # custom interpretation function
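        # Custom interpretation: delegate to the user-supplied function, which
        # receives the preprocessed inputs and returns its own scores.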
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)
]
interpreter = interface.interpretation
interpretation = interpreter(*processed_input)
if len(raw_input) == 1:
interpretation = [interpretation]
return interpretation, []
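# Usage sketch (hypothetical, not part of the module): assuming an Interface
# `iface` whose `interpretation` attribute has been normalized to a per-input
# list (e.g. ["default"]), scores for one sample could be obtained with:
#
#     scores, alternative_outputs = utils.synchronize_async(
#         run_interpret, iface, ["some input text"]
#     )
#
# `iface` and the sample input are placeholders; scores[0] would then hold the
# per-token (or per-neighbor) interpretation scores for the first component.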
def diff(original, perturbed):
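    """Returns the numerical difference between two outputs if both can be cast
    to float; otherwise returns 0 if the labels match and 1 if they differ."""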
try: # try computing numerical difference
score = float(original) - float(perturbed)
    except (ValueError, TypeError):  # otherwise, look at strict equality of the labels
score = int(not (original == perturbed))
return score
def quantify_difference_in_label(interface, original_output, perturbed_output):
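    """Scores how much a perturbed output diverges from the original output,
    comparing via the first output component (Label or Number)."""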
output_component = interface.output_components[0]
post_original_output = output_component.postprocess(original_output[0])
post_perturbed_output = output_component.postprocess(perturbed_output[0])
if isinstance(output_component, Label):
original_label = post_original_output["label"]
perturbed_label = post_perturbed_output["label"]
# Handle different return types of Label interface
if "confidences" in post_original_output:
original_confidence = original_output[0][original_label]
perturbed_confidence = perturbed_output[0][original_label]
score = original_confidence - perturbed_confidence
else:
score = diff(original_label, perturbed_label)
return score
elif isinstance(output_component, Number):
score = diff(post_original_output, post_perturbed_output)
return score
else:
raise ValueError(
"This interpretation method doesn't support the Output component: {}".format(
output_component
)
)
def get_regression_or_classification_value(
interface, original_output, perturbed_output
):
"""Used to combine regression/classification for Shap interpretation method."""
output_component = interface.output_components[0]
post_original_output = output_component.postprocess(original_output[0])
post_perturbed_output = output_component.postprocess(perturbed_output[0])
    if isinstance(output_component, Label):
original_label = post_original_output["label"]
perturbed_label = post_perturbed_output["label"]
# Handle different return types of Label interface
if "confidences" in post_original_output:
if math.isnan(perturbed_output[0][original_label]):
return 0
return perturbed_output[0][original_label]
else:
score = diff(
perturbed_label, original_label
) # Intentionally inverted order of arguments.
return score
else:
raise ValueError(
"This interpretation method doesn't support the Output component: {}".format(
output_component
)
)