anpigon's picture
add langchain docs
ed4d993
import os
from contextlib import contextmanager
from typing import Generator
from unittest.mock import Mock
import pytest
from ai21 import AI21Client
from ai21.models import (
AnswerResponse,
ChatOutput,
ChatResponse,
Completion,
CompletionData,
CompletionFinishReason,
CompletionsResponse,
FinishReason,
Penalty,
RoleType,
SegmentationResponse,
)
from ai21.models.responses.segmentation_response import Segment
from pytest_mock import MockerFixture
J2_CHAT_MODEL_NAME = "j2-ultra"
JAMBA_CHAT_MODEL_NAME = "jamba-instruct-preview"
DUMMY_API_KEY = "test_api_key"
BASIC_EXAMPLE_LLM_PARAMETERS = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
"count_penalty": Penalty(
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
),
}
BASIC_EXAMPLE_CHAT_PARAMETERS = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
"count_penalty": Penalty(
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
),
"n": 3,
}
SEGMENTS = [
Segment(
segment_type="normal_text",
segment_text=(
"The original full name of the franchise is Pocket Monsters "
"(ポケットモンスター, Poketto Monsutā), which was abbreviated to "
"Pokemon during development of the original games.\n\nWhen the "
"franchise was released internationally, the short form of the "
"title was used, with an acute accent (´) over the e to aid "
"in pronunciation."
),
),
Segment(
segment_type="normal_text",
segment_text=(
"Pokémon refers to both the franchise itself and the creatures "
"within its fictional universe.\n\nAs a noun, it is identical in "
"both the singular and plural, as is every individual species "
'name;[10] it is grammatically correct to say "one Pokémon" '
'and "many Pokémon", as well as "one Pikachu" and "many '
'Pikachu".\n\nIn English, Pokémon may be pronounced either '
"/'powkɛmon/ (poe-keh-mon) or /'powkɪmon/ (poe-key-mon)."
),
),
]
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
"count_penalty": Penalty(
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
).to_dict(),
}
BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
"count_penalty": Penalty(
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
).to_dict(),
"n": 3,
}
@pytest.fixture
def mocked_completion_response(mocker: MockerFixture) -> Mock:
mocked_response = mocker.MagicMock(spec=CompletionsResponse)
mocked_response.prompt = "this is a test prompt"
mocked_response.completions = [
Completion(
data=CompletionData(text="test", tokens=[]),
finish_reason=CompletionFinishReason(reason=None, length=None),
)
]
return mocked_response
@pytest.fixture
def mock_client_with_completion(
mocker: MockerFixture, mocked_completion_response: Mock
) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.completion = mocker.MagicMock()
mock_client.completion.create.side_effect = [
mocked_completion_response,
mocked_completion_response,
]
mock_client.count_tokens.side_effect = [10, 20]
return mock_client
@pytest.fixture
def mock_client_with_chat(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.chat = mocker.MagicMock()
output = ChatOutput(
text="Hello Pickle Rick!",
role=RoleType.ASSISTANT,
finish_reason=FinishReason(reason="testing"),
)
mock_client.chat.create.return_value = ChatResponse(outputs=[output])
return mock_client
@contextmanager
def temporarily_unset_api_key() -> Generator:
"""
Unset and set environment key for testing purpose for when an API KEY is not set
"""
api_key = os.environ.pop("AI21_API_KEY", None)
yield
if api_key is not None:
os.environ["AI21_API_KEY"] = api_key
@pytest.fixture
def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.answer = mocker.MagicMock()
mock_client.answer.create.return_value = AnswerResponse(
id="some_id",
answer="some answer",
answer_in_context=False,
)
return mock_client
@pytest.fixture
def mock_client_with_semantic_text_splitter(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.segmentation = mocker.MagicMock()
mock_client.segmentation.create.return_value = SegmentationResponse(
id="12345",
segments=SEGMENTS,
)
return mock_client