File size: 5,587 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from typing import Generator, cast
import numpy as np
import pytest
import chromadb
from chromadb.api.types import (
    Embeddable,
    EmbeddingFunction,
    Embeddings,
    Image,
    Document,
)
from chromadb.test.property.strategies import hashing_embedding_function
from chromadb.test.property.invariants import _exact_distances


# A 'standard' multimodal embedding function, which converts inputs to strings
# then hashes them to a fixed dimension.
class hashing_multimodal_ef(EmbeddingFunction[Embeddable]):
    """Deterministic multimodal embedding function used by the tests below.

    Each input item is converted to its ``str`` representation, embedded
    with ``hashing_embedding_function`` (10 dimensions) and then scaled to
    unit L2 norm.
    """

    def __init__(self) -> None:
        # np.float64 is the type the np.float_ alias pointed to; the alias
        # itself was removed in NumPy 2.0, so spell the dtype explicitly.
        self._hef = hashing_embedding_function(dim=10, dtype=np.float64)

    def __call__(self, input: Embeddable) -> Embeddings:
        """Embed every item of ``input`` and return unit-length vectors.

        Args:
            input: Iterable of items to embed; each item is passed through
                ``str`` before hashing.

        Returns:
            One embedding (a list of floats) per input item.
        """
        to_texts = [str(i) for i in input]
        embeddings = np.array(self._hef(to_texts))
        # Normalize the embeddings
        # This is so we can generate random unit vectors and have them be close to the embeddings
        embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
        return cast(Embeddings, embeddings.tolist())


def random_image() -> Image:
    """Return a random ``int32`` array of shape ``(10, 10, 3)``.

    Values are drawn from numpy's global random generator over ``[0, 255)``;
    the upper bound is exclusive, so 255 itself is never produced.
    """
    image_shape = (10, 10, 3)
    return np.random.randint(low=0, high=255, size=image_shape, dtype=np.int32)


def random_document() -> Document:
    """Return a random document: the ``str`` rendering of a fresh random image."""
    source_image = random_image()
    return str(source_image)


@pytest.fixture
def multimodal_collection(
    default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(),
) -> Generator[chromadb.Collection, None, None]:
    """Yield a new collection named ``multimodal_collection``.

    The collection is created on a new ``chromadb.Client()`` with
    ``default_ef`` as its embedding function. Once the consuming test is
    done, the client's system cache is cleared.
    """
    chroma_client = chromadb.Client()
    yield chroma_client.create_collection(
        embedding_function=default_ef,
        name="multimodal_collection",
    )
    # Teardown: runs after the test that requested the fixture has finished.
    chroma_client.clear_system_cache()


# Test adding and querying of a multimodal collection consisting of images and documents
def test_multimodal(
    multimodal_collection: chromadb.Collection,
    default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(),
    n_examples: int = 10,
    n_query_results: int = 3,
) -> None:
    """Check add/get/query on a collection holding both images and documents.

    Args:
        multimodal_collection: Collection under test (pytest fixture).
        default_ef: Embedding function used to compute the expected
            embeddings; it is expected to match the one the collection was
            created with.
        n_examples: Number of images and, separately, of documents to add.
        n_query_results: Number of nearest neighbours requested per query.
    """
    # Fix numpy's random seed for reproducibility. The previous global state
    # is restored in the ``finally`` below so that a failing assertion does
    # not leave the fixed seed in place for whatever runs afterwards.
    random_state = np.random.get_state()
    np.random.seed(0)
    try:
        image_ids = [str(i) for i in range(n_examples)]
        images = [random_image() for _ in range(n_examples)]
        image_embeddings = default_ef(images)

        document_ids = [str(i) for i in range(n_examples, 2 * n_examples)]
        documents = [random_document() for _ in range(n_examples)]
        document_embeddings = default_ef(documents)

        # Trying to add a document and an image at the same time should fail
        with pytest.raises(
            ValueError, match="You can only provide documents or images, not both."
        ):
            multimodal_collection.add(
                ids=image_ids[0], documents=documents[0], images=images[0]
            )

        # Add some documents
        multimodal_collection.add(ids=document_ids, documents=documents)
        # Add some images
        multimodal_collection.add(ids=image_ids, images=images)

        # get() should return all the documents and images
        # ids corresponding to images should not have documents
        get_result = multimodal_collection.get(include=["documents"])
        assert len(get_result["ids"]) == len(document_ids) + len(image_ids)
        returned_documents = get_result["documents"]
        assert returned_documents is not None
        expected_documents = dict(zip(document_ids, documents))
        image_id_set = set(image_ids)
        for entry_id, returned_document in zip(get_result["ids"], returned_documents):
            assert entry_id in expected_documents or entry_id in image_id_set
            if entry_id in expected_documents:
                assert returned_document == expected_documents[entry_id]
            if entry_id in image_id_set:
                assert returned_document is None

        # Images come first, then documents, in both lists, so a neighbour
        # index into ``all_embeddings`` maps straight to ``all_ids``.
        all_embeddings = image_embeddings + document_embeddings
        all_ids = image_ids + document_ids

        # Generate a random query image
        query_image = random_image()
        query_image_embedding = default_ef([query_image])

        image_neighbor_indices, _ = _exact_distances(
            query_image_embedding, all_embeddings
        )
        # Get the ids of the nearest neighbors
        nearest_image_neighbor_ids = [
            all_ids[i] for i in image_neighbor_indices[0][:n_query_results]
        ]

        # Generate a random query document
        query_document = random_document()
        query_document_embedding = default_ef([query_document])
        document_neighbor_indices, _ = _exact_distances(
            query_document_embedding, all_embeddings
        )
        nearest_document_neighbor_ids = [
            all_ids[i] for i in document_neighbor_indices[0][:n_query_results]
        ]

        # Querying with both images and documents should fail
        with pytest.raises(ValueError):
            multimodal_collection.query(
                query_images=[query_image], query_texts=[query_document]
            )

        # Query with images
        query_result = multimodal_collection.query(
            query_images=[query_image],
            n_results=n_query_results,
            include=["documents"],
        )

        assert query_result["ids"][0] == nearest_image_neighbor_ids

        # Query with documents
        query_result = multimodal_collection.query(
            query_texts=[query_document],
            n_results=n_query_results,
            include=["documents"],
        )

        assert query_result["ids"][0] == nearest_document_neighbor_ids
    finally:
        np.random.set_state(random_state)


@pytest.mark.xfail
def test_multimodal_update_with_image(
    multimodal_collection: chromadb.Collection,
) -> None:
    """Updating an entry that has a document with an image should remove the document.

    The test is marked ``xfail``, i.e. it is currently expected to fail.
    """
    entry_id = "0"
    original_document = random_document()
    replacement_image = random_image()

    # Store the entry as a document, then overwrite it with an image.
    multimodal_collection.add(ids=entry_id, documents=original_document)
    multimodal_collection.update(ids=entry_id, images=replacement_image)

    stored = multimodal_collection.get(ids=entry_id, include=["documents"])
    stored_documents = stored["documents"]
    assert stored_documents is not None
    assert stored_documents[0] is None