Spaces:
Runtime error
Runtime error
File size: 2,738 Bytes
58d33f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
"""Test LLM Math functionality."""
import json
import pytest
from langchain import LLMChain
from langchain.chains.api.base import APIChain
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.requests import RequestsWrapper
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeRequestsChain(RequestsWrapper):
    """Stand-in requests wrapper that serves a canned response in tests."""

    # Text handed back by every ``get`` call, regardless of the URL requested.
    output: str

    def get(self, url: str) -> str:
        """Return the preconfigured ``output``; ``url`` is not used."""
        canned_response = self.output
        return canned_response
@pytest.fixture
def test_api_data() -> dict:
    """Provide canned API docs, question, URL, response and summary.

    The ``api_response`` entry is the JSON-encoded form of a small
    successful search result; every other entry is a plain string.
    """
    # Payload the fake endpoint "returns"; serialized to JSON below.
    notes_payload = {
        "success": True,
        "results": [{"id": 1, "content": "Langchain is awesome!"}],
    }
    # The string content starts at column 0 on purpose, so the literal's
    # value does not pick up the function's indentation.
    docs_text = """
This API endpoint will search the notes for a user.
Endpoint: https://thisapidoesntexist.com
GET /api/notes
Query parameters:
q | string | The search term for notes
"""
    serialized_response = json.dumps(notes_payload)
    return {
        "api_docs": docs_text,
        "question": "Search for notes containing langchain",
        "api_url": "https://thisapidoesntexist.com/api/notes?q=langchain",
        "api_response": serialized_response,
        "api_summary": "There is 1 note about langchain.",
    }
@pytest.fixture
def fake_llm_api_chain(test_api_data: dict) -> APIChain:
    """Build an ``APIChain`` backed by a fake LLM and a fake requests wrapper.

    The fake LLM receives a ``queries`` mapping from the two fully formatted
    prompts to their canned replies: the URL prompt maps to the fixture's
    API URL and the response prompt maps to the fixture's summary. The fake
    requests wrapper is configured with the fixture's API response.
    """
    docs = test_api_data["api_docs"]
    question = test_api_data["question"]
    url = test_api_data["api_url"]
    response = test_api_data["api_response"]
    summary = test_api_data["api_summary"]

    # Render both prompts exactly as the chain will, so they can key the
    # fake LLM's canned replies.
    url_prompt_text = API_URL_PROMPT.format(api_docs=docs, question=question)
    answer_prompt_text = API_RESPONSE_PROMPT.format(
        api_docs=docs,
        question=question,
        api_url=url,
        api_response=response,
    )
    llm = FakeLLM(queries={url_prompt_text: url, answer_prompt_text: summary})

    return APIChain(
        api_request_chain=LLMChain(llm=llm, prompt=API_URL_PROMPT),
        api_answer_chain=LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT),
        requests_wrapper=FakeRequestsChain(output=response),
        api_docs=docs,
    )
def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None:
    """Check that running the chain on the question yields the canned summary."""
    expected_summary = test_api_data["api_summary"]
    answer = fake_llm_api_chain.run(test_api_data["question"])
    assert answer == expected_summary
|