"""Test few shot prompt template.""" import pytest from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate EXAMPLE_PROMPT = PromptTemplate( input_variables=["question", "answer"], template="{question}: {answer}" ) def test_suffix_only() -> None: """Test prompt works with just a suffix.""" suffix = "This is a {foo} test." input_variables = ["foo"] prompt = FewShotPromptTemplate( input_variables=input_variables, suffix=suffix, examples=[], example_prompt=EXAMPLE_PROMPT, ) output = prompt.format(foo="bar") expected_output = "This is a bar test." assert output == expected_output def test_prompt_missing_input_variables() -> None: """Test error is raised when input variables are not provided.""" # Test when missing in suffix template = "This is a {foo} test." with pytest.raises(ValueError): FewShotPromptTemplate( input_variables=[], suffix=template, examples=[], example_prompt=EXAMPLE_PROMPT, ) # Test when missing in prefix template = "This is a {foo} test." with pytest.raises(ValueError): FewShotPromptTemplate( input_variables=[], suffix="foo", examples=[], prefix=template, example_prompt=EXAMPLE_PROMPT, ) def test_prompt_extra_input_variables() -> None: """Test error is raised when there are too many input variables.""" template = "This is a {foo} test." input_variables = ["foo", "bar"] with pytest.raises(ValueError): FewShotPromptTemplate( input_variables=input_variables, suffix=template, examples=[], example_prompt=EXAMPLE_PROMPT, ) def test_few_shot_functionality() -> None: """Test that few shot works with examples.""" prefix = "This is a test about {content}." suffix = "Now you try to talk about {new_content}." examples = [ {"question": "foo", "answer": "bar"}, {"question": "baz", "answer": "foo"}, ] prompt = FewShotPromptTemplate( suffix=suffix, prefix=prefix, input_variables=["content", "new_content"], examples=examples, example_prompt=EXAMPLE_PROMPT, example_separator="\n", ) output = prompt.format(content="animals", new_content="party") expected_output = ( "This is a test about animals.\n" "foo: bar\n" "baz: foo\n" "Now you try to talk about party." ) assert output == expected_output def test_partial_init_string() -> None: """Test prompt can be initialized with partial variables.""" prefix = "This is a test about {content}." suffix = "Now you try to talk about {new_content}." examples = [ {"question": "foo", "answer": "bar"}, {"question": "baz", "answer": "foo"}, ] prompt = FewShotPromptTemplate( suffix=suffix, prefix=prefix, input_variables=["new_content"], partial_variables={"content": "animals"}, examples=examples, example_prompt=EXAMPLE_PROMPT, example_separator="\n", ) output = prompt.format(new_content="party") expected_output = ( "This is a test about animals.\n" "foo: bar\n" "baz: foo\n" "Now you try to talk about party." ) assert output == expected_output def test_partial_init_func() -> None: """Test prompt can be initialized with partial variables.""" prefix = "This is a test about {content}." suffix = "Now you try to talk about {new_content}." examples = [ {"question": "foo", "answer": "bar"}, {"question": "baz", "answer": "foo"}, ] prompt = FewShotPromptTemplate( suffix=suffix, prefix=prefix, input_variables=["new_content"], partial_variables={"content": lambda: "animals"}, examples=examples, example_prompt=EXAMPLE_PROMPT, example_separator="\n", ) output = prompt.format(new_content="party") expected_output = ( "This is a test about animals.\n" "foo: bar\n" "baz: foo\n" "Now you try to talk about party." ) assert output == expected_output def test_partial() -> None: """Test prompt can be partialed.""" prefix = "This is a test about {content}." suffix = "Now you try to talk about {new_content}." examples = [ {"question": "foo", "answer": "bar"}, {"question": "baz", "answer": "foo"}, ] prompt = FewShotPromptTemplate( suffix=suffix, prefix=prefix, input_variables=["content", "new_content"], examples=examples, example_prompt=EXAMPLE_PROMPT, example_separator="\n", ) new_prompt = prompt.partial(content="foo") new_output = new_prompt.format(new_content="party") expected_output = ( "This is a test about foo.\n" "foo: bar\n" "baz: foo\n" "Now you try to talk about party." ) assert new_output == expected_output output = prompt.format(new_content="party", content="bar") expected_output = ( "This is a test about bar.\n" "foo: bar\n" "baz: foo\n" "Now you try to talk about party." ) assert output == expected_output