llm_classifier / tests /test_prompt.py
argmin's picture
add files
510a9b0
raw
history blame
891 Bytes
import pytest
from utils.prompt import generate_prompts
def test_generate_prompts():
    """Check that generate_prompts renders example rows and a label marker.

    Passes two labelled example rows, a feature list, label descriptions
    and one unlabelled row to generate_prompts, then asserts that the first
    example's feature line and label appear in the system prompt and that
    the user prompt contains a "Label:" marker.
    """
    feature_names = ["Age", "Weight", "Location"]
    descriptions = {
        "Positive": "The sentiment is positive.",
        "Negative": "The sentiment is negative.",
    }

    # Labelled examples: each pairs a feature mapping with its known label.
    positive_example = {
        "features": {"Age": 34, "Weight": 70, "Location": "Urban"},
        "label": "Positive",
    }
    negative_example = {
        "features": {"Age": 25, "Weight": 60, "Location": "Rural"},
        "label": "Negative",
    }

    # The unlabelled row the prompts are generated for.
    target_row = {"Age": 30, "Weight": 65, "Location": "Suburban"}

    system_text, user_text = generate_prompts(
        row=target_row,
        example_rows=[positive_example, negative_example],
        features=feature_names,
        label_descriptions=descriptions,
    )

    # The first example must show up in the system prompt with its label.
    expected_in_system = (
        "Age: 34; Weight: 70; Location: Urban",
        "Label: Positive",
    )
    for fragment in expected_in_system:
        assert fragment in system_text

    assert "Label:" in user_text