# llm_classifier/tests/test_api.py
# Source: Hugging Face file view — author "argmin", commit 510a9b0 ("add files"), 697 bytes.
from unittest.mock import Mock
from utils.api import classify_row_chat
def test_classify_row_chat():
    """classify_row_chat should return the label produced by the chat client.

    The OpenAI client is replaced by a ``unittest.mock.Mock`` whose
    ``chat.completions.create`` call returns a fake response shaped so that
    ``response.choices[0].message.content == "Positive"``.
    """
    # Assemble the fake chat-completion response from the inside out:
    # message -> choice -> response.
    fake_message = Mock(content="Positive")
    fake_choice = Mock(message=fake_message)
    fake_response = Mock(choices=[fake_choice])

    # Stand-in client: any call to chat.completions.create yields fake_response.
    fake_client = Mock()
    fake_client.chat.completions.create.return_value = fake_response

    observation_prompt = "Classify the following observation: Age: 25, Weight: 70\nLabel:"

    label = classify_row_chat(
        prompt=observation_prompt,
        client=fake_client,
        model="gpt-3.5-turbo",
    )

    assert label == "Positive", "The classification should return 'Positive'"