Spaces:
Runtime error
Runtime error
"""Test functionality related to ngram overlap based selector.""" | |
import pytest | |
from langchain.prompts.example_selector.ngram_overlap import ( | |
NGramOverlapExampleSelector, | |
ngram_overlap_score, | |
) | |
from langchain.prompts.prompt import PromptTemplate | |
# Example pool for the selector: (input sentence, output label) pairs,
# expanded into the {"input": ..., "output": ...} dicts the prompt expects.
EXAMPLES = [
    {"input": sentence, "output": label}
    for sentence, label in (
        ("See Spot run.", "foo1"),
        ("My dog barks.", "foo2"),
        ("Spot can run.", "foo3"),
    )
]
@pytest.fixture
def selector() -> NGramOverlapExampleSelector:
    """Get ngram overlap based selector to use in tests.

    Registered as a pytest fixture because every test in this module requests
    ``selector`` as an argument. Using pytest's default function scope gives
    each test its own instance, which matters because tests mutate it
    (``threshold``, ``add_example``).
    """
    example_prompt = PromptTemplate(
        input_variables=["input", "output"],
        template="Input: {input}\nOutput: {output}",
    )
    return NGramOverlapExampleSelector(
        # Shallow copy: if the selector appends to the list it was given,
        # add_example() in one test must not leak into the shared EXAMPLES.
        examples=list(EXAMPLES),
        example_prompt=example_prompt,
    )
def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
    """Check that select_examples returns the examples in the expected order.

    For the query "Spot can run.", the exact match comes first, the partial
    match ("See Spot run.") second, and the unrelated sentence last.
    """
    query = {"input": "Spot can run."}
    ranked = selector.select_examples(query)
    expected = [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]]
    assert ranked == expected
def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
    """Check that an example added via add_example takes part in selection.

    The new sentence shares the word "Spot" with the query, and it is expected
    to land between the two original overlapping examples and the unrelated one.
    """
    extra = {"input": "Spot plays fetch.", "output": "foo4"}
    selector.add_example(extra)
    ranked = selector.select_examples({"input": "Spot can run."})
    assert ranked == [EXAMPLES[2], EXAMPLES[0], extra, EXAMPLES[1]]
def test_selector_threshold_zero(
    selector: NGramOverlapExampleSelector,
) -> None:
    """Check selection with the threshold set to exactly 0.0.

    Only the two examples that overlap the query are expected back; the
    unrelated "My dog barks." example should be filtered out.
    """
    selector.threshold = 0.0
    kept = selector.select_examples({"input": "Spot can run."})
    assert kept == [EXAMPLES[2], EXAMPLES[0]]
def test_selector_threshold_more_than_one(
    selector: NGramOverlapExampleSelector,
) -> None:
    """Check that a threshold just above 1.0 selects nothing.

    Even the identical sentence is expected to score about 1.0, so no example
    should clear a bar set a hair above it.
    """
    selector.threshold = 1.0 + 1e-9
    kept = selector.select_examples({"input": "Spot can run."})
    assert kept == []
def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
    """Tests that ngram_overlap_score returns correct values.

    Covers three cases: no overlap scores 0.0, partial overlap lies strictly
    between 0.0 and 1.0, and an identical sentence scores 1.0.

    ``ngram_overlap_score`` is a free function, so the ``selector`` fixture is
    unused. The parameter is kept only to leave the test signature unchanged.
    """
    query = ["Spot can run."]
    none = ngram_overlap_score(query, ["My dog barks."])
    some = ngram_overlap_score(query, ["See Spot run."])
    complete = ngram_overlap_score(query, ["Spot can run."])
    # Separate asserts so a failure names the broken case and shows its value,
    # rather than an opaque list-of-bools mismatch. abs-only approx means no
    # relative tolerance is applied, matching the original 1e-9 bound.
    assert none == pytest.approx(0.0, abs=1e-9)
    assert 0.0 < some < 1.0
    assert complete == pytest.approx(1.0, abs=1e-9)