# agent-flow/src/backend/tests/unit/components/models/test_chatollama_component.py
# Last commit d202ada by Tai Truong: "fix readme" (file size 4.46 kB)
from unittest.mock import MagicMock, patch
from urllib.parse import urljoin
import pytest
from langchain_ollama import ChatOllama
from langflow.components.models import ChatOllamaComponent
@pytest.fixture
def component():
    """Provide a freshly constructed ``ChatOllamaComponent`` to each test."""
    instance = ChatOllamaComponent()
    return instance


@patch("httpx.Client.get")
def test_get_model_success(mock_get, component):
    """``get_model`` issues one GET to ``<base_url>/api/tags`` and returns the model names."""
    base_url = "http://localhost:11434"

    # Fake a successful tags response listing two models.
    fake_response = MagicMock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
    mock_get.return_value = fake_response

    names = component.get_model(base_url)

    mock_get.assert_called_once_with(urljoin(base_url, "/api/tags"))
    assert names == ["model1", "model2"]


@patch("httpx.Client.get")
def test_get_model_failure(mock_get, component):
    """A failing HTTP request inside ``get_model`` is surfaced as a ``ValueError``."""
    # Make every GET issued through httpx.Client raise.
    mock_get.side_effect = Exception("HTTP request failed")

    # get_model takes the server *base* URL and appends /api/tags itself,
    # so pass a base URL here rather than the full tags endpoint.
    base_url = "http://localhost:11434"

    with pytest.raises(ValueError, match="Could not retrieve models"):
        component.get_model(base_url)

    # Ensure the error originated from the (mocked) request itself and not
    # from some earlier check that short-circuited before any HTTP call.
    mock_get.assert_called_once_with(urljoin(base_url, "/api/tags"))


def test_update_build_config_mirostat_disabled(component):
    """Choosing "Disabled" marks both mirostat tuning fields advanced and clears their values."""
    config = {
        "mirostat_eta": {"advanced": False, "value": 0.1},
        "mirostat_tau": {"advanced": False, "value": 5},
    }

    result = component.update_build_config(config, "Disabled", "mirostat")

    tuning_fields = ("mirostat_eta", "mirostat_tau")
    for key in tuning_fields:
        assert result[key]["advanced"] is True
    for key in tuning_fields:
        assert result[key]["value"] is None


def test_update_build_config_mirostat_enabled(component):
    """Choosing "Mirostat 2.0" keeps the tuning fields non-advanced and fills eta=0.2, tau=10."""
    config = {
        "mirostat_eta": {"advanced": False, "value": None},
        "mirostat_tau": {"advanced": False, "value": None},
    }

    result = component.update_build_config(config, "Mirostat 2.0", "mirostat")

    eta, tau = result["mirostat_eta"], result["mirostat_tau"]
    assert eta["advanced"] is False
    assert tau["advanced"] is False
    assert eta["value"] == 0.2
    assert tau["value"] == 10


@patch("httpx.Client.get")
def test_update_build_config_model_name(mock_get, component):
    """Updating ``model_name`` fills its options from the (mocked) model listing."""
    payload = {"models": [{"name": "model1"}, {"name": "model2"}]}
    fake_response = MagicMock()
    fake_response.json.return_value = payload
    fake_response.raise_for_status.return_value = None
    mock_get.return_value = fake_response

    # NOTE(review): base_url's value is None here; presumably the component
    # falls back to a default URL in that case — confirm against the component.
    config = {
        "base_url": {"load_from_db": False, "value": None},
        "model_name": {"options": []},
    }

    result = component.update_build_config(config, None, "model_name")

    assert result["model_name"]["options"] == ["model1", "model2"]


def test_update_build_config_keep_alive(component):
    """Each keep-alive choice sets the matching ``keep_alive`` value and marks it advanced."""
    field_name = "keep_alive_flag"
    cases = (
        ("Keep", "-1"),
        ("Immediately", "0"),
    )

    for field_value, expected_value in cases:
        # Build a fresh config per case. If update_build_config mutates its
        # argument in place, reusing one dict would let the second case inherit
        # ``advanced=True`` from the first, making its ``advanced`` check vacuous.
        build_config = {"keep_alive": {"value": None, "advanced": False}}

        updated_config = component.update_build_config(build_config, field_value, field_name)

        assert updated_config["keep_alive"]["value"] == expected_value, field_value
        assert updated_config["keep_alive"]["advanced"] is True, field_value


@patch(
    "langchain_community.chat_models.ChatOllama",
    return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
)
def test_build_model(_mock_chat_ollama, component):  # noqa: PT019
    """``build_model`` returns a ``ChatOllama`` carrying the configured base URL and model."""
    # NOTE(review): this module imports ChatOllama from ``langchain_ollama`` but
    # patches ``langchain_community.chat_models.ChatOllama``. If the component
    # also uses the langchain_ollama class, this patch never takes effect —
    # confirm the intended patch target.
    url = "http://localhost:11434"
    model_id = "llama3.1"

    # Assign the component inputs directly, in the same order as before.
    settings = {
        "base_url": url,
        "model_name": model_id,
        "mirostat": "Mirostat 2.0",
        "mirostat_eta": 0.2,  # deliberately a float
        "mirostat_tau": 10.0,  # deliberately a float
        "temperature": 0.2,
        "verbose": True,
    }
    for attr, value in settings.items():
        setattr(component, attr, value)

    model = component.build_model()

    assert isinstance(model, ChatOllama)
    assert model.base_url == url
    assert model.model == model_id