File size: 1,918 Bytes
a85c9b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from unittest.mock import MagicMock

import pytest
from chromadb.api.types import Documents, Embeddings

from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder


@pytest.fixture
def base_embedder():
    """Provide a fresh ``BaseEmbedder`` constructed with no explicit config."""
    embedder = BaseEmbedder()
    return embedder


def test_initialization(base_embedder):
    """A new embedder carries a config but no embedding fn or vector dimension yet."""
    assert isinstance(base_embedder.config, BaseEmbedderConfig)
    # Neither attribute should exist until its corresponding setter is called.
    for attr_name in ("embedding_fn", "vector_dimension"):
        assert not hasattr(base_embedder, attr_name)


def test_set_embedding_fn(base_embedder):
    """``set_embedding_fn`` stores the given callable so it can be invoked later."""

    def embedding_function(texts: Documents) -> Embeddings:
        # Return one numeric vector per input text, so the fake embedder's
        # output actually matches its ``Embeddings`` annotation.
        return [[float(len(text)), float(index)] for index, text in enumerate(texts)]

    base_embedder.set_embedding_fn(embedding_function)
    assert hasattr(base_embedder, "embedding_fn")
    assert callable(base_embedder.embedding_fn)
    embeddings = base_embedder.embedding_fn(["text1", "text2"])
    assert embeddings == [[5.0, 0.0], [5.0, 1.0]]


def test_set_embedding_fn_when_not_a_function(base_embedder):
    """Passing a non-callable (``None``) to ``set_embedding_fn`` raises ValueError."""
    not_callable = None
    with pytest.raises(ValueError):
        base_embedder.set_embedding_fn(not_callable)


def test_set_vector_dimension(base_embedder):
    """``set_vector_dimension`` records the given integer on the embedder."""
    expected_dimension = 256
    base_embedder.set_vector_dimension(expected_dimension)
    assert hasattr(base_embedder, "vector_dimension")
    assert expected_dimension == base_embedder.vector_dimension


def test_set_vector_dimension_type_error(base_embedder):
    """A ``None`` dimension is rejected with TypeError."""
    invalid_dimension = None
    with pytest.raises(TypeError):
        base_embedder.set_vector_dimension(invalid_dimension)


def test_langchain_default_concept():
    """The langchain adapter forwards texts to ``embed_documents`` and returns its result.

    A ``MagicMock`` stands in for a langchain embeddings object. The test
    checks both the value returned by the wrapped function and that the
    input texts were passed through to ``embed_documents`` exactly once.
    """
    texts = ["text1", "text2"]
    embeddings = MagicMock()
    embeddings.embed_documents.return_value = ["Embedding1", "Embedding2"]
    embed_function = BaseEmbedder._langchain_default_concept(embeddings)
    result = embed_function(texts)
    assert result == ["Embedding1", "Embedding2"]
    # Checking only the return value would pass even if the adapter ignored
    # its input; pin the delegation itself as well.
    embeddings.embed_documents.assert_called_once_with(texts)


def test_embedder_with_config():
    """An explicitly passed config object is the one stored on the embedder."""
    config = BaseEmbedderConfig()
    embedder = BaseEmbedder(config)
    assert isinstance(embedder.config, BaseEmbedderConfig)
    # An embedder built without arguments also holds a BaseEmbedderConfig,
    # so the isinstance check alone cannot show that the argument was used.
    assert embedder.config is config