# File metadata captured from the hosting page: author Zekun Wu,
# commit 44466c7 ("update"), size 2.33 kB, malware scan: "No virus".
from collections import defaultdict
from statistics import mean
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
class Regard:
    """Score the "regard" (polarity) of texts with the ``sasha/regardv3`` classifier.

    When ``config_name == "compare"``, ``compute`` scores ``data`` against
    ``references``. Any other ``config_name`` scores ``data`` on its own.
    """

    def __init__(self, config_name):
        """Load the regard tokenizer/model and wrap them in a classification pipeline.

        Args:
            config_name: ``"compare"`` enables the data-vs-references mode of
                ``compute``; any other value selects single-group scoring.
        """
        self.config_name = config_name
        regard_tokenizer = AutoTokenizer.from_pretrained("sasha/regardv3")
        regard_model = AutoModelForSequenceClassification.from_pretrained("sasha/regardv3")
        # top_k=4 requests the four highest-scoring labels per input instead of
        # only the best one (presumably every label the model has -- confirm).
        self.regard_classifier = pipeline(
            "text-classification",
            model=regard_model,
            top_k=4,
            tokenizer=regard_tokenizer,
            truncation=True,
        )

    def regard(self, group):
        """Classify every text in ``group`` and collect the scores per label.

        Args:
            group: Texts to classify; passed unchanged to the pipeline.

        Returns:
            A ``(raw_predictions, scores_by_label)`` tuple. ``raw_predictions``
            is the pipeline output, which has one list of ``{"label", "score"}``
            dicts per input. ``scores_by_label`` maps each label to the list of
            scores it received across all inputs.
        """
        group_scores = defaultdict(list)
        group_regard = self.regard_classifier(group)
        for pred in group_regard:
            for pred_score in pred:
                group_scores[pred_score["label"]].append(pred_score["score"])
        return group_regard, dict(group_scores)

    @staticmethod
    def _summarize(scores_by_label, reducer):
        """Reduce each label's score list to one number with ``reducer`` (e.g. ``mean``/``max``)."""
        return {label: reducer(scores) for label, scores in scores_by_label.items()}

    def compute(
        self,
        data,
        references=None,
        aggregation=None,
    ):
        """Compute regard scores for ``data``, optionally relative to ``references``.

        Args:
            data: Texts to score.
            references: Texts to compare against. Required when
                ``config_name == "compare"``; ignored otherwise.
            aggregation: ``"maximum"`` for the per-label max score, ``"average"``
                for the per-label mean score. Any other value (including None)
                selects the default output: the raw predictions in single-group
                mode, or the per-label mean difference in compare mode.

        Returns:
            A dict keyed by ``max_regard`` / ``average_regard`` / ``regard`` in
            single-group mode. In compare mode it is keyed by
            ``max_data_regard`` + ``max_references_regard``,
            ``average_data_regard`` + ``average_references_regard``, or
            ``regard_difference``.

        Raises:
            ValueError: if ``config_name == "compare"`` and ``references`` is None.
        """
        if self.config_name == "compare":
            if references is None:
                raise ValueError('references must be provided when config_name is "compare"')
            _, data_by_label = self.regard(data)
            _, ref_by_label = self.regard(references)
            if aggregation == "maximum":
                return {
                    "max_data_regard": self._summarize(data_by_label, max),
                    "max_references_regard": self._summarize(ref_by_label, max),
                }
            data_mean = self._summarize(data_by_label, mean)
            ref_mean = self._summarize(ref_by_label, mean)
            if aggregation == "average":
                return {"average_data_regard": data_mean, "average_references_regard": ref_mean}
            # A label that never appears in the references counts as a mean of 0.
            return {
                "regard_difference": {
                    label: data_mean[label] - ref_mean.get(label, 0) for label in data_mean
                }
            }

        raw_predictions, data_by_label = self.regard(data)
        if aggregation == "maximum":
            return {"max_regard": self._summarize(data_by_label, max)}
        if aggregation == "average":
            return {"average_regard": self._summarize(data_by_label, mean)}
        return {"regard": raw_predictions}