Spaces:
Running
Running
from unittest.mock import MagicMock, patch | |
from urllib.parse import urljoin | |
import pytest | |
from langchain_ollama import ChatOllama | |
from langflow.components.models import ChatOllamaComponent | |
@pytest.fixture
def component():
    """Provide a ``ChatOllamaComponent`` instance to tests that request ``component``.

    Registered as a pytest fixture so that test functions naming ``component``
    as a parameter receive the object returned here.
    """
    return ChatOllamaComponent()
def test_get_model_success(mock_get, component):
    """Check that ``get_model`` requests ``/api/tags`` and returns the listed names.

    NOTE(review): nothing visible in this file supplies ``mock_get``; presumably
    a ``@patch(...)`` decorator on the HTTP GET used by ``get_model`` belongs on
    this test -- confirm the patch target against the component.
    """
    server = "http://localhost:11434"
    tags_endpoint = urljoin(server, "/api/tags")

    # Canned reply: the status check does not raise and two models are listed.
    reply = MagicMock()
    reply.raise_for_status.return_value = None
    reply.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
    mock_get.return_value = reply

    found = component.get_model(server)

    mock_get.assert_called_once_with(tags_endpoint)
    assert found == ["model1", "model2"]
def test_get_model_failure(mock_get, component):
    """Check that ``get_model`` raises ``ValueError`` when the mocked GET fails.

    NOTE(review): nothing visible in this file supplies ``mock_get``; presumably
    a ``@patch(...)`` decorator on the HTTP GET used by ``get_model`` belongs on
    this test -- confirm the patch target against the component.
    """
    endpoint = "http://localhost:11434/api/tags"

    # Every call to the mocked GET raises instead of returning a response.
    mock_get.side_effect = Exception("HTTP request failed")

    # The failure must surface as a ValueError whose message matches the pattern.
    with pytest.raises(ValueError, match="Could not retrieve models"):
        component.get_model(endpoint)
def test_update_build_config_mirostat_disabled(component):
    """Check that "Disabled" marks eta/tau as advanced and clears their values."""
    initial = {
        "mirostat_eta": {"advanced": False, "value": 0.1},
        "mirostat_tau": {"advanced": False, "value": 5},
    }

    result = component.update_build_config(initial, "Disabled", "mirostat")

    tuned_fields = ("mirostat_eta", "mirostat_tau")
    # Both fields must be flagged as advanced ...
    for field in tuned_fields:
        assert result[field]["advanced"] is True
    # ... and both values must be reset to None.
    for field in tuned_fields:
        assert result[field]["value"] is None
def test_update_build_config_mirostat_enabled(component):
    """Check that "Mirostat 2.0" keeps eta/tau non-advanced and sets 0.2 and 10."""
    initial = {
        "mirostat_eta": {"advanced": False, "value": None},
        "mirostat_tau": {"advanced": False, "value": None},
    }

    result = component.update_build_config(initial, "Mirostat 2.0", "mirostat")

    expected_values = (("mirostat_eta", 0.2), ("mirostat_tau", 10))
    # Neither field may be flagged as advanced.
    for field, _ in expected_values:
        assert result[field]["advanced"] is False
    # Each field must carry the value asserted for this mode.
    for field, wanted in expected_values:
        assert result[field]["value"] == wanted
def test_update_build_config_model_name(mock_get, component):
    """Check that updating ``model_name`` fills its options from the mocked listing.

    NOTE(review): nothing visible in this file supplies ``mock_get``; presumably
    a ``@patch(...)`` decorator on the HTTP GET used by the component belongs on
    this test -- confirm the patch target against the component.
    """
    # Canned reply for the mocked HTTP GET: two models, status check passes.
    listing = MagicMock()
    listing.raise_for_status.return_value = None
    listing.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
    mock_get.return_value = listing

    config = {
        "base_url": {"load_from_db": False, "value": None},
        "model_name": {"options": []},
    }

    refreshed = component.update_build_config(config, None, "model_name")

    assert refreshed["model_name"]["options"] == ["model1", "model2"]
def test_update_build_config_keep_alive(component):
    """Check the ``keep_alive`` value produced by each ``keep_alive_flag`` choice.

    "Keep" must yield "-1" and "Immediately" must yield "0"; in both cases the
    ``keep_alive`` field must be flagged as advanced. The same config dict is
    passed to both calls, in that order.
    """
    config = {"keep_alive": {"value": None, "advanced": False}}

    for choice, wanted in (("Keep", "-1"), ("Immediately", "0")):
        result = component.update_build_config(config, choice, "keep_alive_flag")
        assert result["keep_alive"]["value"] == wanted
        assert result["keep_alive"]["advanced"] is True
def test_build_model(_mock_chat_ollama, component):  # noqa: PT019
    """Check that ``build_model`` returns a ``ChatOllama`` with the configured URL and model.

    NOTE(review): nothing visible in this file supplies ``_mock_chat_ollama``;
    presumably a ``@patch(...)`` decorator belongs on this test -- confirm.
    """
    settings = {
        "base_url": "http://localhost:11434",
        "model_name": "llama3.1",
        "mirostat": "Mirostat 2.0",
        "mirostat_eta": 0.2,  # deliberately a float
        "mirostat_tau": 10.0,  # deliberately a float
        "temperature": 0.2,
        "verbose": True,
    }
    # Assign the attributes in the order listed above.
    for attribute, value in settings.items():
        setattr(component, attribute, value)

    built = component.build_model()

    assert isinstance(built, ChatOllama)
    assert built.base_url == "http://localhost:11434"
    assert built.model == "llama3.1"