File size: 2,738 Bytes
58d33f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Test API chain functionality."""

import json

import pytest

from langchain import LLMChain
from langchain.chains.api.base import APIChain
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.requests import RequestsWrapper
from tests.unit_tests.llms.fake_llm import FakeLLM


class FakeRequestsChain(RequestsWrapper):
    """Requests wrapper stand-in that serves a canned response in tests."""

    # Text handed back by every ``get`` call, whatever URL is requested.
    output: str

    def get(self, url: str) -> str:
        """Return the preconfigured ``output``; ``url`` is ignored."""
        return self.output


@pytest.fixture
def test_api_data() -> dict:
    """Provide canned docs, question, URL, response and summary for the tests."""
    notes_api_docs = """
    This API endpoint will search the notes for a user.

    Endpoint: https://thisapidoesntexist.com
    GET /api/notes

    Query parameters:
    q | string | The search term for notes
    """
    # JSON text stored under "api_response": one successful result.
    serialized_response = json.dumps(
        {
            "success": True,
            "results": [{"id": 1, "content": "Langchain is awesome!"}],
        }
    )
    return dict(
        api_docs=notes_api_docs,
        question="Search for notes containing langchain",
        api_url="https://thisapidoesntexist.com/api/notes?q=langchain",
        api_response=serialized_response,
        api_summary="There is 1 note about langchain.",
    )


@pytest.fixture
def fake_llm_api_chain(test_api_data: dict) -> APIChain:
    """Build an APIChain wired to a fake LLM and a fake requests wrapper."""
    docs = test_api_data["api_docs"]
    question = test_api_data["question"]
    url = test_api_data["api_url"]
    raw_response = test_api_data["api_response"]
    summary = test_api_data["api_summary"]

    # Render both prompts exactly as they are formatted from the test data,
    # then map each rendered prompt to the reply handed to FakeLLM via
    # ``queries``.
    url_prompt_text = API_URL_PROMPT.format(api_docs=docs, question=question)
    answer_prompt_text = API_RESPONSE_PROMPT.format(
        api_docs=docs,
        question=question,
        api_url=url,
        api_response=raw_response,
    )
    scripted_replies = {url_prompt_text: url, answer_prompt_text: summary}
    llm = FakeLLM(queries=scripted_replies)

    # Both sub-chains share the same fake LLM; the requests wrapper returns
    # the canned API response instead of using the real ``get``.
    return APIChain(
        api_request_chain=LLMChain(llm=llm, prompt=API_URL_PROMPT),
        api_answer_chain=LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT),
        requests_wrapper=FakeRequestsChain(output=raw_response),
        api_docs=docs,
    )


def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None:
    """Check that running the chain on the question yields the canned summary."""
    answer = fake_llm_api_chain.run(test_api_data["question"])
    expected_summary = test_api_data["api_summary"]
    assert answer == expected_summary