File size: 1,909 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from dspy.primitives.prediction import Prediction, Completions
from dsp.utils import normalize_text


def default_normalize(s):
    """Normalize ``s`` with ``normalize_text``, mapping falsy results to ``None``.

    Returning ``None`` (rather than e.g. an empty string) lets callers tell
    "nothing left after normalization" apart from a real value.

    Args:
        s: The raw value to normalize; passed straight to ``normalize_text``.

    Returns:
        The result of ``normalize_text(s)``, or ``None`` if that result is falsy.
    """
    return normalize_text(s) or None


def majority(prediction_or_completions, normalize=default_normalize, field=None):
    """
    Return the most common completion for the target field as a Prediction.

    Args:
        prediction_or_completions: A ``Prediction`` (its ``.completions`` are
            used), a ``Completions`` object, or a list of completions that can
            be indexed by field name.
        normalize: Callable applied to each completion's field value before
            counting. Completions for which it returns ``None`` are ignored,
            unless it returns ``None`` for every completion. A falsy
            ``normalize`` disables normalization and raw values are counted.
            The normalized values are used as dict keys, so they must be
            hashable.
        field: Name of the field to vote on. When falsy, the last field of the
            completions' signature is used, falling back to the last key of
            the first completion.

    Returns:
        A ``Prediction`` built with ``Prediction.from_completions`` from the
        first completion whose normalized value equals the majority value.
        In case of a tie, the value that appears earliest is prioritized.
        The result is a ``Prediction`` regardless of the input type.

    Raises:
        AssertionError: If the input is not a Prediction, Completions or list.
        ValueError: If there are no completions to vote on.
    """

    assert isinstance(prediction_or_completions, (Prediction, Completions, list)), \
        "majority() expects a Prediction, Completions, or list"

    # Get the completions
    if isinstance(prediction_or_completions, Prediction):
        completions = prediction_or_completions.completions
    else:
        completions = prediction_or_completions

    # Plain lists carry no signature; treat a missing attribute as "unknown".
    signature = getattr(completions, "signature", None)

    if not field:
        try:
            field = signature.fields[-1].output_variable
        except Exception:
            # Best effort: the signature may be None or shaped differently than
            # expected, so fall back to the last key of the first completion.
            field = list(completions[0].keys())[-1]

    # Normalize (once per completion; the values are reused for selection below)
    normalize = normalize if normalize else lambda x: x
    normalized_values = [normalize(completion[field]) for completion in completions]
    if not normalized_values:
        raise ValueError("majority() requires at least one completion")

    # Ignore values normalized to None, unless that would leave nothing to count.
    candidates = [x for x in normalized_values if x is not None] or normalized_values

    # Count
    value_counts = {}
    for value in candidates:
        value_counts[value] = value_counts.get(value, 0) + 1

    # dicts preserve insertion order and max() keeps the first maximum it sees,
    # so ties resolve to the value that appeared earliest.
    majority_value = max(value_counts, key=value_counts.get)

    # Pick the first completion with the majority value in the field
    for completion, value in zip(completions, normalized_values):
        if value == majority_value:
            break

    return Prediction.from_completions([completion], signature=signature)