File size: 2,660 Bytes
58d33f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Test functionality related to ngram overlap based selector."""

import pytest

from langchain.prompts.example_selector.ngram_overlap import (
    NGramOverlapExampleSelector,
    ngram_overlap_score,
)
from langchain.prompts.prompt import PromptTemplate

# Example pool handed to the selector fixture. Each entry pairs an "input"
# sentence with an "output" label; the tests below express their expected
# selections by indexing into this list, so the order of entries matters.
EXAMPLES = [
    {"input": "See Spot run.", "output": "foo1"},
    {"input": "My dog barks.", "output": "foo2"},
    {"input": "Spot can run.", "output": "foo3"},
]


@pytest.fixture
def selector() -> NGramOverlapExampleSelector:
    """Build the NGramOverlapExampleSelector over EXAMPLES used by the tests."""
    return NGramOverlapExampleSelector(
        example_prompt=PromptTemplate(
            template="Input: {input}\nOutput: {output}",
            input_variables=["input", "output"],
        ),
        examples=EXAMPLES,
    )


def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
    """Check that select_examples returns all examples in the expected order."""
    expected = [EXAMPLES[idx] for idx in (2, 0, 1)]
    selected = selector.select_examples({"input": "Spot can run."})
    assert selected == expected


def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
    """Check that an example added via add_example shows up in the selection."""
    added = {"input": "Spot plays fetch.", "output": "foo4"}
    selector.add_example(added)
    selected = selector.select_examples({"input": "Spot can run."})
    # The added example is expected between EXAMPLES[0] and EXAMPLES[1].
    assert selected == [EXAMPLES[2], EXAMPLES[0], added, EXAMPLES[1]]


def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None:
    """Check the selection returned when the threshold is set to 0.0.

    Only EXAMPLES[2] and EXAMPLES[0] are expected back, in that order.
    """
    selector.threshold = 0.0
    expected = [EXAMPLES[idx] for idx in (2, 0)]
    selected = selector.select_examples({"input": "Spot can run."})
    assert selected == expected


def test_selector_threshold_more_than_one(
    selector: NGramOverlapExampleSelector,
) -> None:
    """Check that a threshold just above 1.0 yields an empty selection."""
    just_above_one = 1.0 + 1e-9
    selector.threshold = just_above_one
    selected = selector.select_examples({"input": "Spot can run."})
    assert selected == []


def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
    """Tests that ngram_overlap_score returns correct values.

    Scores the sentence "Spot can run." against three candidates and checks
    that the score is 0.0 for "My dog barks.", strictly between 0.0 and 1.0
    for "See Spot run.", and 1.0 for the identical sentence.

    The ``selector`` argument is kept so the test's signature is unchanged;
    ngram_overlap_score is called directly with lists and is not passed the
    selector, so no selector attribute is modified here.
    """
    source = ["Spot can run."]
    tolerance = 1e-9

    none = ngram_overlap_score(source, ["My dog barks."])
    some = ngram_overlap_score(source, ["See Spot run."])
    complete = ngram_overlap_score(source, ["Spot can run."])

    # One assert per case so a failure reports which case broke and the
    # offending score, rather than an opaque list of booleans.
    assert abs(none - 0.0) < tolerance
    assert 0.0 < some < 1.0
    assert abs(complete - 1.0) < tolerance