|
import ast |
|
import copy |
|
import json |
|
import re |
|
import string |
|
from difflib import get_close_matches |
|
from typing import Any, Dict |
|
|
|
import numpy as np |
|
|
|
from .deprecation_utils import deprecation |
|
from .operator import MultiStreamOperator |
|
from .operators import FieldOperator, InstanceFieldOperator |
|
from .settings_utils import get_constants |
|
|
|
# Constants object returned by settings_utils.get_constants(); PostProcess reads
# its ``inference_stream`` attribute when configuring the references operator.
constants = get_constants()
|
|
|
|
|
class PostProcess(MultiStreamOperator):
    """Run one InstanceFieldOperator over predictions and/or references.

    ``prepare`` makes two shallow copies of ``operator``. One is bound to the
    ``"prediction"`` field. The other is bound to ``"references"``: it processes
    every value in that field and is not applied to the inference stream.
    ``process`` then applies whichever copies are enabled, predictions first.
    """

    operator: InstanceFieldOperator

    process_prediction: bool = True

    process_references: bool = True

    def prepare(self):
        super().prepare()

        prediction_op = copy.copy(self.operator)
        prediction_op.field = "prediction"
        self.prediction_operator = prediction_op

        references_op = copy.copy(self.operator)
        references_op.field = "references"
        # References hold several values per instance; each one is processed.
        references_op.process_every_value = True
        # The inference stream is excluded for the references copy.
        references_op.dont_apply_to_streams = [constants.inference_stream]
        self.references_operator = references_op

    def process(self, multi_stream):
        result = multi_stream
        if self.process_prediction:
            result = self.prediction_operator(result)
        if self.process_references:
            result = self.references_operator(result)
        return result
|
|
|
|
|
class ToString(FieldOperator):
    """Replace the field value with its ``str()`` representation."""

    def process_value(self, text: Any) -> Any:
        as_text = str(text)
        return as_text
|
|
|
|
|
class ToStringStripped(FieldOperator):
    """Convert the field value to a string and trim surrounding whitespace."""

    def process_value(self, text: Any) -> Any:
        as_text = str(text)
        return as_text.strip()
|
|
|
|
|
class SplitStrip(FieldOperator):
    """Split a string field on ``delimiter``, optionally stripping each piece."""

    # Separator handed to ``str.split``.
    delimiter: str = " "

    # When True, surrounding whitespace is removed from every element.
    strip_every_element: bool = False

    def process_value(self, text: Any) -> Any:
        pieces = text.split(self.delimiter)
        if not self.strip_every_element:
            return pieces
        return [piece.strip() for piece in pieces]
|
|
|
|
|
class ToListByComma(SplitStrip):
    """Split on "," and strip whitespace from every resulting element."""

    strip_every_element = True

    delimiter = ","
|
|
|
|
|
class ToListByCommaSpace(SplitStrip):
    """Split on ", " (comma plus space) and strip every resulting element."""

    strip_every_element = True

    delimiter = ", "
|
|
|
|
|
class RegexParser(FieldOperator):
    """A processor that uses regex in order to parse a string.

    Returns every match of ``regex`` as produced by ``re.findall``. If
    ``termination_regex`` is set and matches the entire input, the result is
    an empty list instead.
    """

    # Pattern whose matches are collected with re.findall.
    regex: str

    # Optional pattern; an input that fully matches it yields no matches.
    termination_regex: str = None

    def process_value(self, text: Any) -> Any:
        terminator = self.termination_regex
        if terminator is None:
            return re.findall(self.regex, text)
        if re.fullmatch(terminator, text):
            return []
        return re.findall(self.regex, text)
|
|
|
|
|
class ExtractWithRegex(RegexParser):
    """Return only the first regex match, or ``""`` when there is none."""

    def process_value(self, text: Any) -> Any:
        found = super().process_value(text)
        return found[0] if found else ""
|
|
|
|
|
class ListToEmptyEntitiesTuples(FieldOperator):
    """Turn an iterable of items into ``(str(item), "")`` pairs.

    Each element is stringified and paired with an empty second component.
    A non-iterable value (e.g. ``None``) yields an empty list.
    """

    def process_value(self, lst: Any) -> Any:
        try:
            return [(str(item), "") for item in lst]
        except TypeError:
            # Iterating a non-iterable raises TypeError. The previous handler
            # caught json.JSONDecodeError, which this code can never raise, so
            # the intended "return [] on bad input" fallback never triggered.
            return []
|
|
|
|
|
class DictOfListsToPairs(FieldOperator):
    """Flatten a ``{key: [value, ...]}`` mapping into a list of pairs.

    Each string value produces ``(key, value)``, or ``(value, key)`` when
    ``position_key_before_value`` is False. If any value is not a ``str``,
    or the input is not a mapping of iterables, the result is ``[]``.
    """

    # Order of the two components inside each emitted pair.
    position_key_before_value: bool = True

    def process_value(self, obj: Any) -> Any:
        result = []
        try:
            for key, values in obj.items():
                for value in values:
                    # Explicit check rather than `assert`, which is stripped
                    # under `python -O` and would let non-strings through.
                    if not isinstance(value, str):
                        return []
                    pair = (
                        (key, value) if self.position_key_before_value else (value, key)
                    )
                    result.append(pair)
        except Exception:
            # Best-effort parsing: malformed input yields no pairs. `Exception`
            # (not a bare except) so KeyboardInterrupt/SystemExit propagate.
            return []
        return result
|
|
|
|
|
class TakeFirstNonEmptyLine(FieldOperator):
    """Return the first non-blank line of the stringified value, stripped."""

    def process_value(self, text: Any) -> Any:
        # Stripping the whole text first discards leading blank lines, so the
        # first remaining line is the first non-empty one ("" if all blank).
        first_line, _, _ = str(text).strip().partition("\n")
        return first_line.strip()
|
|
|
|
|
class ConvertToBoolean(FieldOperator):
    """Map free text to "TRUE", "FALSE" or "OTHER" via keyword heuristics.

    The stringified, stripped, lower-cased text is checked with *substring*
    matching. Negative keywords are checked before positive ones. Because of
    the substring semantics, short keywords also fire inside longer words
    (e.g. "no" inside "know").
    """

    def process_value(self, text: Any) -> Any:
        clean_instance = str(text).strip().lower()
        # "incorrect" and "untrue" contain the positive keywords "correct" and
        # "true" but none of the other negative keywords. Without listing them
        # here they would be classified as "TRUE".
        negative_words = ["no", "not", "wrong", "false", "incorrect", "untrue"]
        if any(w in clean_instance for w in negative_words):
            return "FALSE"
        if any(w in clean_instance for w in ["yes", "right", "correct", "true"]):
            return "TRUE"
        return "OTHER"
|
|
|
|
|
class LowerCaseTillPunc(FieldOperator):
    """Lower-case the text and truncate it at the first of ``. , ! ? ;``."""

    def process_value(self, text: Any) -> Any:
        lowered = text.lower()
        first_punc = re.search(r"[.,!?;]", lowered)
        if first_punc is None:
            return lowered
        # Keep only the part before the punctuation mark.
        return lowered[: first_punc.start()]
|
|
|
|
|
class Lower(FieldOperator):
    """Lower-case the field value with ``str.lower``."""

    def process_value(self, text: Any) -> Any:
        lowered = text.lower()
        return lowered
|
|
|
|
|
@deprecation("2.0.0", alternative=Lower)
class LowerCase(Lower):
    """Deprecated alias of ``Lower``; use ``Lower`` instead."""
|
|
|
|
|
class Capitalize(FieldOperator):
    """Apply ``str.capitalize`` (first character upper, rest lower-cased)."""

    def process_value(self, text: Any) -> Any:
        capitalized = text.capitalize()
        return capitalized
|
|
|
|
|
class GetStringAfter(FieldOperator):
    """Keep the text after the first occurrence of ``substring``, stripped.

    If ``substring`` does not occur, the whole text is returned (stripped).
    """

    substring: str

    def process_value(self, text: Any) -> Any:
        pieces = text.split(self.substring, 1)
        tail = pieces[-1]
        return tail.strip()
|
|
|
|
|
class MatchClosestOption(InstanceFieldOperator):
    """Snap a field value to the most similar entry of the instance's options.

    Candidate options are read from ``instance["task_data"][options_field]``.
    They are ranked with ``difflib.get_close_matches`` using ``cutoff=0.0``,
    so any non-empty option list always produces a match.
    """

    # Key under instance["task_data"] that holds the candidate options.
    options_field: str = "options"

    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        """Return the option closest to ``value``, or ``None`` if there are no options."""
        options = instance["task_data"][self.options_field]
        # cutoff=0.0 accepts every candidate, so the result is empty only when
        # `options` is empty. Previously that case crashed with a bare,
        # message-less IndexError from `[0]`.
        closest = get_close_matches(value, options, n=1, cutoff=0.0)
        return closest[0] if closest else None
|
|
|
|
|
def process_instance_value(self, value, instance):
    """Return the entry of ``instance[self.options_field]`` most similar to ``value``.

    Returns ``None`` when the option list is empty.

    NOTE(review): this is a module-level function that takes ``self`` but is
    not attached to any class visible here. It looks like a leftover draft of
    ``MatchClosestOption.process_instance_value``, which reads the options
    from ``instance["task_data"]`` instead. Confirm nothing imports it before
    removing it.
    """
    candidates = instance[self.options_field]
    # cutoff=0 accepts every candidate; only an empty option list yields no match.
    best = get_close_matches(value, candidates, n=1, cutoff=0)
    if not best:
        return None
    return best[0]
|
|
|
|
|
class Substring(FieldOperator):
    """Slice the value as ``text[begin:end]``; ``end=None`` slices to the end."""

    begin: int = 0

    end: int = None

    def process_value(self, text: Any) -> Any:
        # A None stop in a slice is the same as an open-ended slice.
        return text[self.begin : self.end]
|
|
|
|
|
class FirstCharacter(FieldOperator):
    """Return the first word character after any leading whitespace, else ``""``."""

    def process_value(self, text: Any) -> Any:
        found = re.search(r"\s*(\w)", text)
        return found.group(1) if found else ""
|
|
|
|
|
class TakeFirstWord(FieldOperator):
    """Return the first number-like token or run of word characters, else ``""``.

    The number alternative allows leading ``-`` signs and dotted digit groups.
    """

    def process_value(self, text: Any) -> Any:
        found = re.search(r"([-]*[0-9]+(\.([0-9]+))*)|([\w]+)", text)
        if found is None:
            return ""
        # group(0) is exactly text[found.start():found.end()].
        return found.group(0)
|
|
|
|
|
class YesNoToInt(FieldOperator):
    """Map "yes" to "1" and "no" to "0"; any other value passes through."""

    def process_value(self, text: Any) -> Any:
        if text == "yes":
            result = "1"
        elif text == "no":
            result = "0"
        else:
            result = text
        return result
|
|
|
|
|
class YesToOneElseZero(FieldOperator):
    """Return "1" for exactly "yes" and "0" for anything else."""

    def process_value(self, text: Any) -> Any:
        return "1" if text == "yes" else "0"
|
|
|
|
|
class StrToFloatFormat(FieldOperator):
    """Normalize numeric text through ``float`` (e.g. "3" -> "3.0").

    Values that ``float()`` cannot parse are returned as ``str(text)``.
    """

    def process_value(self, text: Any) -> Any:
        try:
            as_float = float(text)
        except Exception:
            return str(text)
        return str(as_float)
|
|
|
|
|
class ToYesOrNone(FieldOperator):
    """Keep "yes" as is; map every other value to the string "none"."""

    def process_value(self, text: Any) -> Any:
        return "yes" if text == "yes" else "none"
|
|
|
|
|
class StanceToProCon(FieldOperator):
    """Map "positive" to "PRO", "negative"/"suggestion" to "CON", else "none"."""

    def process_value(self, text: Any) -> Any:
        if text == "positive":
            return "PRO"
        return "CON" if text in ("negative", "suggestion") else "none"
|
|
|
|
|
class StringEquals(FieldOperator):
    """Collapse text mentioning ``string`` to a canonical lower-case form.

    The check is case-insensitive. If the text contains "not <string>", that
    phrase is returned. Otherwise, if it contains "<string>", the lower-cased
    string is returned. If neither occurs, the text is returned unchanged.
    """

    string: str

    def process_value(self, text: Any) -> Any:
        target = self.string.lower()
        haystack = text.lower()
        negated = "not " + target
        # The negated phrase is checked first because it also contains `target`.
        if negated in haystack:
            return negated
        if target in haystack:
            return target
        return text
|
|
|
|
|
@deprecation("2.0.0", alternative=StringEquals)
class StringOrNotString(StringEquals):
    """Deprecated alias of ``StringEquals``; use ``StringEquals`` instead."""
|
|
|
|
|
class ExtractMtBenchRatingJudgment(FieldOperator):
    """Extract a ``[[rating]]`` number from judge output and scale it by 1/10.

    Returns 0.0 when no ``[[<number>]]`` marker is present.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
        # A missing marker was the only failure the old bare `except:` handled
        # (AttributeError on None). The captured digits always parse as a
        # float, so no try block is needed.
        if match is None:
            return 0.0
        return float(match.group(1)) / 10
|
|
|
|
|
class ExtractMtBenchLabelJudgment(FieldOperator):
    """Extract the text inside the first ``[[...]]`` marker of judge output.

    Returns the string "None" when no marker is present.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([^\]]+)\]\]", text)
        # Explicit check instead of a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        if match is None:
            return "None"
        return str(match.group(1))
|
|
|
|
|
class LiteralEval(FieldOperator):
    """Parse a string field as a Python literal with ``ast.literal_eval``.

    ``None`` and ``""`` are returned unchanged. Any other non-``str`` value
    raises ``ValueError``. The text is stripped before evaluation.
    """

    def process_value(self, text: Any) -> Any:
        if text is None:
            return text
        if not isinstance(text, str):
            raise ValueError(
                f"LiteralEval: field '{self.field}' is expected to be of 'str' input type, got: {type(text)}"
            )
        if text == "":
            return text
        return ast.literal_eval(text.strip())
|
|
|
|
|
class ExtractSafeUnsafeJudgment(FieldOperator):
    """Score 1.0 if the first line is exactly "safe" (case-insensitive), else 0.0.

    Only the text as a whole is stripped, not the first line itself.
    """

    def process_value(self, text: Any) -> Any:
        head, _, _ = str(text).strip().partition("\n")
        return 1.0 if head.lower() == "safe" else 0.0
|
|
|
|
|
class ExtractArenaHardNumericalJudgment(FieldOperator):
    """Convert a ``[[verdict]]`` pairwise judgment into a signed integer score.

    Scores: "A>>B" -> 3, "A>B" -> 1, "B>A" -> -1, "B>>A" -> -3. Any other label,
    or a missing ``[[...]]`` marker, scores 0.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([^\]]+)\]\]", text)
        # Explicit None check replaces the bare `except:`, whose only real job
        # was the no-match case and which also swallowed KeyboardInterrupt.
        if match is None:
            return 0
        verdict_scores = {"A>B": 1, "A>>B": 3, "B>A": -1, "B>>A": -3}
        return verdict_scores.get(str(match.group(1)), 0)
|
|
|
|
|
class InferDictsToBinaryLogprobs(FieldOperator):
    """Derive P(positive class) from per-token top-logprob dicts.

    The field value is indexed by token position. Each entry is read as a
    mapping with a ``"top_tokens"`` list whose items provide ``"text"`` and
    ``"logprob"``. This shape is inferred from the key accesses below; confirm
    it against the inference engine's output format.

    Up to ``num_logprobs_to_take`` positions are scanned, from the start or,
    when ``take_logprobs_from_end`` is set, from the end. At the first position
    where the summed probability of tokens spelling either class name exceeds
    ``min_probability_mass``, ``P(pos) / (P(pos) + P(neg))`` is returned. If no
    position qualifies, 0 is returned.
    """

    neg_class_name: str

    pos_class_name: str

    take_logprobs_from_end: bool = False

    num_logprobs_to_take: int = 3

    # Minimum combined class probability for a token position to be used.
    min_probability_mass = 0.0001

    def verify(self):
        """Reject class names where one is a case-insensitive substring of the other."""
        super().verify()
        if (
            self.neg_class_name.lower() in self.pos_class_name.lower()
            or self.pos_class_name.lower() in self.neg_class_name.lower()
        ):
            # The old triple-quoted f-string ended on `"""`, which dropped the
            # closing quote around the negative class name in the message.
            raise ValueError(
                f'Class names in {self.__class__.__name__} should not overlap, '
                f'got "{self.pos_class_name}" and "{self.neg_class_name}"'
            )

    def process_value(self, obj: Any) -> Any:
        for i in self.get_token_range(obj):
            try:
                pos_probs, neg_probs = self.get_pos_neg_probs(pred_dict=obj[i])
                if pos_probs or neg_probs:
                    sum_probs = sum(pos_probs) + sum(neg_probs)
                    if sum_probs > self.min_probability_mass:
                        return sum(pos_probs) / sum_probs
            except Exception:
                # Best-effort: a malformed position is skipped rather than
                # failing the prediction. `Exception` instead of a bare except
                # so KeyboardInterrupt/SystemExit still propagate.
                continue
        return 0

    def get_pos_neg_probs(self, pred_dict):
        """Return ``[pos_probs, neg_probs]``: probabilities of top tokens naming each class."""
        token_logprobs = pred_dict["top_tokens"]

        pos_and_neg_probs = []
        for class_name in [self.pos_class_name, self.neg_class_name]:
            # A token matches if, ignoring case, it is the class name optionally
            # wrapped in non-word characters, "Ġ" (presumably the byte-level BPE
            # leading-space marker) or underscores. The name is escaped so that
            # characters like "+" or "(" are matched literally. Unescaped, they
            # raised re.error, which the loop's handler silently turned into 0.
            name_regex = re.compile(
                rf"(\W|Ġ|_)*{re.escape(class_name)}(\W|Ġ|_)*", flags=re.IGNORECASE
            )
            class_probs = [
                np.exp(d["logprob"])
                for d in token_logprobs
                if name_regex.fullmatch(d["text"])
            ]
            pos_and_neg_probs.append(class_probs)
        return pos_and_neg_probs

    def get_token_range(self, obj: Any) -> range:
        """Positions to scan: the first ``n`` tokens, or the last ``n`` counting from -1."""
        n_tokens = min(self.num_logprobs_to_take, len(obj))
        if self.take_logprobs_from_end:
            # Yields -1, -2, ..., -n_tokens.
            return range(-1, -(n_tokens + 1), -1)
        return range(n_tokens)
|
|
|
|
|
class RemoveArticles(FieldOperator):
    """Replace each standalone lowercase "a", "an" or "the" with a single space."""

    def process_value(self, text: Any) -> Any:
        article_pattern = r"\b(a|an|the)\b"
        return re.sub(article_pattern, " ", text)
|
|
|
|
|
class RemovePunctuations(FieldOperator):
    """Drop every character that appears in ``string.punctuation``."""

    def process_value(self, text: Any) -> Any:
        punctuation = frozenset(string.punctuation)
        kept = [ch for ch in text if ch not in punctuation]
        return "".join(kept)
|
|
|
|
|
class FixWhiteSpace(FieldOperator):
    """Collapse whitespace runs to single spaces and trim both ends."""

    def process_value(self, text: Any) -> Any:
        words = text.split()
        return " ".join(words)
|
|