File size: 1,909 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from dspy.primitives.prediction import Prediction, Completions
from dsp.utils import normalize_text
def default_normalize(s):
    """
    Normalize ``s`` with ``normalize_text``, mapping a falsy result to ``None``.

    Returning ``None`` (instead of, e.g., an empty string) lets ``majority``
    skip the completion, since it ignores values that normalize to ``None``.
    """
    return normalize_text(s) or None
def majority(prediction_or_completions, normalize=default_normalize, field=None):
    """
    Return the most common completion for the target field.

    The target field is ``field`` when given; otherwise it is the last field of
    the completions' signature, or, when no usable signature is available, the
    last key of the first completion.

    Args:
        prediction_or_completions: A ``Prediction`` (its ``.completions`` are
            used), a ``Completions`` object, or a list of completions that can
            be indexed by field name.
        normalize: Callable applied to each completion's field value before
            counting. Completions for which it returns ``None`` are ignored,
            unless every completion normalizes to ``None``. A falsy
            ``normalize`` disables normalization (values are compared as-is).
        field: Name of the field to vote on, or a falsy value to infer it.

    Returns:
        A ``Prediction`` built from the first completion whose normalized field
        value equals the majority value. In case of a tie, the value that
        appears earliest among the completions wins.

    Raises:
        ValueError: If there are no completions to vote on.
    """
    assert isinstance(prediction_or_completions, (Prediction, Completions, list))

    # Get the completions
    if isinstance(prediction_or_completions, Prediction):
        completions = prediction_or_completions.completions
    else:
        completions = prediction_or_completions

    # Plain lists carry no signature attribute.
    signature = getattr(completions, "signature", None)

    if not field:
        try:
            field = signature.fields[-1].output_variable
        except Exception:
            # Best-effort fallback when the signature is missing or has an
            # unexpected shape: use the last key of the first completion.
            # (Exception, not a bare except, so Ctrl-C still propagates.)
            field = list(completions[0].keys())[-1]

    # Normalize (once per completion; the results are reused for selection below)
    normalize = normalize if normalize else lambda x: x
    normalized_values = [normalize(completion[field]) for completion in completions]
    kept_values = [value for value in normalized_values if value is not None]

    # Count; if every value normalized to None, fall back to counting them all.
    value_counts = {}
    for value in (kept_values or normalized_values):
        value_counts[value] = value_counts.get(value, 0) + 1

    if not value_counts:
        raise ValueError("majority() requires at least one completion")

    # dicts preserve insertion order and max() returns the first maximal key,
    # so ties resolve in favor of the earliest value.
    majority_value = max(value_counts, key=value_counts.get)

    # Pick the first completion with the majority value in the field
    for completion, value in zip(completions, normalized_values):
        if value == majority_value:
            break

    # A Prediction is returned regardless of the input type.
    return Prediction.from_completions([completion], signature=signature)
|