# nazneen's picture
# interactive model card streamlit app
# 90f4ec6
# raw history blame
# No virus
# 4.93 kB
"""
rg_utils: helper functions for building and manipulating Robustness Gym (rg) dev benches, slices and data panels.
"""
import pandas as pd
import re
import robustnessgym as rg
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
def update_pred(dp, model, dp_only=False):
    """Run the model over a data panel and attach its predictions.

    Parameters
    ----------
    dp : rg.DataPanel
        Data panel with a ``"sentence"`` column to score.
    model :
        Object exposing ``predict_batch(batch, columns)``; it is applied to
        ``dp`` in batches of 4 through ``dp.update``.
    dp_only : bool
        If True, return only the updated data panel.

    Returns
    -------
    The updated data panel when ``dp_only`` is True, otherwise a tuple
    ``(dp, pred)`` where ``pred`` is a two-row DataFrame with columns
    ``"Label"`` and ``"Probability"`` built from the first row of the
    panel's ``"probs"`` data.
    """
    # The batched update is the only prediction pass needed. A separate
    # full-panel predict_batch call beforehand would repeat the same model
    # work and discard its result.
    dp = dp.update(
        lambda batch: model.predict_batch(batch, ["sentence"]),
        batch_size=4,
        is_batched_fn=True,
        pbar=True,
    )
    if dp_only:
        return dp
    labels = pd.Series(["Negative Sentiment", "Positive Sentiment"])
    # NOTE(review): reads the panel's private storage; assumes predict_batch
    # produced a "probs" column whose first row is ordered [neg, pos] to
    # line up with `labels` -- confirm against the model wrapper.
    probs = pd.Series(dp.__dict__["_data"]["probs"][0])
    pred = pd.concat([labels, probs], axis=1)
    pred.columns = ["Label", "Probability"]
    return dp, pred
def remove_slice(bench, slice_name="user_data", pattern="new_words"):
    """Remove user-added slices from an rg dev bench, in place.

    Every slice whose identifier (as a string) contains ``pattern`` is
    dropped from the bench's slices, its slice identifiers, and the
    ``bench.metrics["model"]`` mapping.

    Parameters
    ----------
    bench : rg.DevBench
        Bench to prune; it is mutated in place and also returned.
    slice_name : str
        Currently unused; kept for backward compatibility with callers.
    pattern : str
        Plain substring (not a regex) marking slices to remove. Defaults to
        ``"new_words"``, the value that used to be hard-coded.

    Returns
    -------
    The same ``bench`` object.
    """
    # Slices and their identifiers are filtered together so the two sets
    # stay consistent with each other.
    kept_slices = set()
    kept_ids = set()
    for sl in bench.__dict__["_slices"]:
        name = str(sl.__dict__["_identifier"])
        if pattern not in name:
            kept_slices.add(sl)
            kept_ids.add(name)

    # Metrics are keyed separately (and not necessarily in slice order),
    # so they are filtered by key on their own.
    metrics = {
        key: value
        for key, value in bench.metrics["model"].items()
        if pattern not in str(key)
    }

    # NOTE(review): any "_slice_table" kept by the bench is not pruned here;
    # earlier (commented-out) code considered doing so -- confirm whether
    # it needs to be.
    bench.__dict__["_slices"] = kept_slices
    bench.__dict__["_slice_identifiers"] = kept_ids
    bench.metrics["model"] = metrics
    return bench
def add_slice(bench, table, model, slice_name="user_data"):
    """Build an rg DataPanel from the user's table.

    Despite its name, this does not add anything to ``bench``: ``bench``,
    ``model`` and ``slice_name`` are currently unused. Only the
    ``"sentence"``, ``"label"`` and ``"pred"`` columns of ``table`` are
    copied into the returned DataPanel, each converted with ``.tolist()``.
    """
    # Columns are handed over as plain Python lists; the original author
    # noted rg "complains" when given the columns any other way.
    wanted = ("sentence", "label", "pred")
    return rg.DataPanel({column: table[column].tolist() for column in wanted})
def new_bench():
    """Create a fresh rg DevBench with recall, precision and accuracy.

    The three aggregators are registered under the ``"model"`` key. Each one
    compares the rounded ``"label"`` column with the ``"pred"`` column;
    recall and precision are macro-averaged with ``zero_division=1``.
    """

    def _truth(dp):
        # Labels are rounded before scoring; predictions are used as-is.
        return dp["label"].round()

    def _recall(dp):
        return recall_score(
            _truth(dp), dp["pred"], average="macro", zero_division=1
        )

    def _precision(dp):
        return precision_score(
            _truth(dp), dp["pred"], average="macro", zero_division=1
        )

    def _accuracy(dp):
        return accuracy_score(_truth(dp), dp["pred"])

    bench = rg.DevBench()
    bench.add_aggregators(
        {
            "model": {
                "recall": _recall,
                "precision": _precision,
                "accuracy": _accuracy,
            }
        }
    )
    return bench
def get_sliceid(slices):
    """Return the identifiers of ``slices`` as a list, in iteration order.

    rg keeps each slice's identifier in the private ``_identifier``
    attribute, so this simply collects that attribute from every slice.
    ``slices`` may be any iterable (set, list, generator).
    """
    return [sl._identifier for sl in slices]
def get_sliceidx(slice_ids, name):
    """Return the position of the first slice identifier matching ``name``.

    ``"xyz_train"`` and ``"xyz_test"`` are aliases that match identifiers
    containing ``"split=train"`` and ``"split=test"`` respectively; any
    other ``name`` is matched as a substring of ``str(identifier)``.

    Raises
    ------
    IndexError
        If no identifier matches (same exception type as before, now with a
        descriptive message).
    """
    aliases = {"xyz_train": "split=train", "xyz_test": "split=test"}
    needle = aliases.get(name, name)
    # Stop at the first hit instead of collecting every match.
    for idx, elem in enumerate(slice_ids):
        if needle in str(elem):
            return idx
    raise IndexError(f"no slice identifier matches {name!r}")
def get_prob(x, i):
    """Return element ``i`` of ``x`` converted to a Python float."""
    value = x[i]
    return float(value)
def slice_to_df(data):
    """Convert an rg slice (or any column mapping) to a pandas DataFrame.

    Reads the ``"sentence"``, ``"label"`` and ``"probs"`` columns of
    ``data``. Each label is rounded to 0/1. The result has these columns:

    * ``"sentence"``: the input text.
    * ``"model label"``: ``"Positive Sentiment"`` for 1, otherwise
      ``"Negative Sentiment"``.
    * ``"model binary"``: the rounded label.
    * ``"probability"``: ``float(probs[row][binary])``, the probability at
      the rounded label's index.

    Raises
    ------
    ValueError
        If ``"probs"`` and ``"label"`` have different lengths.
    """
    # Round each label once; it drives both the text label and the
    # probability lookup.
    binary = [int(round(x)) for x in data["label"]]
    probs = data["probs"]
    if len(probs) != len(binary):
        raise ValueError(
            f"'probs' has {len(probs)} rows but 'label' has {len(binary)}"
        )
    df = pd.DataFrame(
        {
            "sentence": list(data["sentence"]),
            "model label": [
                "Positive Sentiment" if b == 1 else "Negative Sentiment"
                for b in binary
            ],
            "model binary": binary,
        }
    )
    df["probability"] = [float(probs[row][b]) for row, b in enumerate(binary)]
    return df
def metrics_to_dict(metrics, slice_name):
    """Wrap one slice's metrics in a nested dict tagged as a custom slice.

    Returns ``{slice_name: {"metrics": metrics[slice_name],
    "source": "Custom Slice"}}``. The metrics object is not copied. Raises
    ``KeyError`` if ``slice_name`` is missing from ``metrics``.
    """
    entry = {"metrics": metrics[slice_name], "source": "Custom Slice"}
    return {slice_name: entry}