barunsaha committed on
Commit
8f5ac4a
·
1 Parent(s): 622c44e

Add tests for icon embeddings

Browse files
Files changed (1) hide show
  1. tests/unit/test_icons_embeddings.py +219 -0
tests/unit/test_icons_embeddings.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for the icons embeddings module.
3
+ """
4
+ import importlib
5
+ import sys
6
+ from pathlib import Path
7
+ from types import SimpleNamespace
8
+ from typing import Any, List
9
+
10
+ import numpy as np
11
+
12
+
13
def _reload_module_with_dummies(monkeypatch: Any, emb_dim: int = 4):
    """
    Import (or re-import) the icons_embeddings module while the
    Transformers ``from_pretrained`` constructors are replaced by
    lightweight stand-ins.

    The stand-ins avoid any model download or heavy initialization and
    make the produced embeddings deterministic. The module is reloaded
    when it is already imported, presumably because it builds its
    tokenizer and model at import time (confirm against the module).

    Args:
        monkeypatch: The pytest monkeypatch fixture.
        emb_dim: The embedding dimensionality that the dummy model
            should produce.

    Returns:
        The freshly imported or reloaded module object.
    """
    module_name = 'slidedeckai.helpers.icons_embeddings'

    class DummyTokenizer:
        """Tokenizer stand-in that simply forwards the input texts."""

        def __call__(self, texts, return_tensors=None, padding=None,
                     max_length=None, truncation=None, **_ignored: Any) -> dict:
            # Normalize to a list so the dummy model can count the inputs.
            # Unknown tokenizer options are accepted and ignored so that
            # a change in the options passed by the module does not
            # break the stub.
            if isinstance(texts, str):
                texts_list = [texts]
            else:
                texts_list = list(texts)

            return {'texts': texts_list}

    class DummyTensor:
        """Minimal tensor look-alike backed by a numpy array."""

        def __init__(self, arr: np.ndarray) -> None:
            self.arr = arr

        def mean(self, dim: int) -> 'DummyTensor':
            # Take numpy mean along the requested axis to emulate PyTorch.
            return DummyTensor(self.arr.mean(axis=dim))

        def detach(self) -> 'DummyTensor':
            return self

        def numpy(self) -> np.ndarray:
            return self.arr

    class DummyModel:
        """Model stand-in that returns a deterministic hidden state."""

        def __call__(self, **inputs: Any) -> SimpleNamespace:
            texts = inputs.get('texts', [])
            n = len(texts)
            seq_len = 3
            # Consecutive values reshaped to (n, seq_len, emb_dim) so
            # tests can predict the result of reducing the array.
            arr = np.arange(n * seq_len * emb_dim, dtype=float)
            arr = arr.reshape((n, seq_len, emb_dim))
            return SimpleNamespace(last_hidden_state=DummyTensor(arr))

    # The factories accept any positional/keyword arguments: the stubs
    # must not fail just because the caller passes extra options to
    # ``from_pretrained``.
    monkeypatch.setattr(
        'transformers.BertTokenizer.from_pretrained',
        lambda *args, **kwargs: DummyTokenizer(),
    )
    monkeypatch.setattr(
        'transformers.BertModel.from_pretrained',
        lambda *args, **kwargs: DummyModel(),
    )

    # Reload when already imported so the module body runs again with
    # the patched constructors; otherwise a plain import is enough.
    if module_name in sys.modules:
        mod = importlib.reload(sys.modules[module_name])
    else:
        mod = importlib.import_module(module_name)

    return mod
79
+
80
+
81
def test_get_icons_list(tmp_path: Path, monkeypatch: Any) -> None:
    """
    Only the PNG files found in the configured icons directory should
    be reported by get_icons_list, and they should be reported by stem.
    """
    mod = _reload_module_with_dummies(monkeypatch)

    # Two icon files plus one decoy with a non-PNG extension.
    icons_dir = tmp_path / 'icons'
    icons_dir.mkdir()
    fixture_files = (
        ('apple.png', 'x'),
        ('banana.png', 'y'),
        ('not_an_icon.txt', 'z'),
    )
    for file_name, content in fixture_files:
        (icons_dir / file_name).write_text(content)

    monkeypatch.setattr(mod.GlobalConfig, 'ICONS_DIR', icons_dir)

    found = mod.get_icons_list()
    assert set(found) == {'apple', 'banana'}
99
+
100
+
101
def test_get_embeddings_single_and_list(monkeypatch: Any) -> None:
    """
    get_embeddings should hand back a numpy array with one row per input
    text, whether it is given a single string or a list of strings.
    """
    emb_dim = 5
    mod = _reload_module_with_dummies(monkeypatch, emb_dim=emb_dim)

    # A lone string yields exactly one embedding row.
    single = mod.get_embeddings('hello')
    assert isinstance(single, np.ndarray)
    assert single.shape == (1, emb_dim)

    # Three strings yield three embedding rows.
    batch = mod.get_embeddings(['a', 'b', 'c'])
    assert batch.shape == (3, emb_dim)

    # The dummy model fills its output with consecutive integers, so the
    # hidden state of the first text is the first 3 * emb_dim values laid
    # out as (3, emb_dim); averaging its rows gives the expected
    # embedding for that text.
    first_hidden_state = np.arange(3 * emb_dim).reshape((3, emb_dim))
    expected_first_row = first_hidden_state.mean(axis=0)
    assert np.allclose(batch[0], expected_first_row)
123
+
124
+
125
def test_save_and_load_embeddings(tmp_path: Path, monkeypatch: Any) -> None:
    """
    Round trip: save_icons_embeddings must create the embeddings and
    file-name files at the configured paths, and load_saved_embeddings
    must read both of them back as numpy arrays of matching length.
    """
    emb_dim = 6
    mod = _reload_module_with_dummies(monkeypatch, emb_dim=emb_dim)

    # Icons directory holding two icon files.
    icons_dir = tmp_path / 'icons2'
    icons_dir.mkdir()
    for file_name, content in (('one.png', '1'), ('two.png', '2')):
        (icons_dir / file_name).write_text(content)

    emb_file = tmp_path / 'emb.npy'
    names_file = tmp_path / 'names.npy'

    # Point every path the module reads from its config at tmp_path.
    monkeypatch.setattr(mod.GlobalConfig, 'ICONS_DIR', icons_dir)
    monkeypatch.setattr(mod.GlobalConfig, 'EMBEDDINGS_FILE_NAME', str(emb_file))
    monkeypatch.setattr(mod.GlobalConfig, 'ICONS_FILE_NAME', str(names_file))

    # Embeddings come from the dummy tokenizer/model installed above.
    mod.save_icons_embeddings()

    assert emb_file.exists()
    assert names_file.exists()

    embeddings, icon_names = mod.load_saved_embeddings()
    assert isinstance(embeddings, np.ndarray)
    assert isinstance(icon_names, np.ndarray)
    assert embeddings.shape[0] == len(icon_names)
155
+
156
+
157
def test_find_icons(monkeypatch: Any, tmp_path: Path) -> None:
    """
    find_icons should return, for each keyword, the icon name whose
    saved embedding is the closest match to the keyword's embedding.
    """
    # get_embeddings is replaced further down, so the dummy model only
    # serves to make the module importable without a real model.
    mod = _reload_module_with_dummies(monkeypatch, emb_dim=3)

    # Two saved icons whose embeddings are orthogonal unit vectors.
    vec_a = [1.0, 0.0, 0.0]
    vec_b = [0.0, 1.0, 0.0]
    saved_embeddings = np.array([vec_a, vec_b])
    saved_names = np.array(['a_icon', 'b_icon'])

    emb_file = tmp_path / 'emb_s.npy'
    names_file = tmp_path / 'names_s.npy'
    np.save(str(emb_file), saved_embeddings)
    np.save(str(names_file), saved_names)

    monkeypatch.setattr(mod.GlobalConfig, 'EMBEDDINGS_FILE_NAME', str(emb_file))
    monkeypatch.setattr(mod.GlobalConfig, 'ICONS_FILE_NAME', str(names_file))

    def fake_get_embeddings(keywords: List[str]) -> np.ndarray:
        # 'match_a' maps onto the first saved vector, anything else onto
        # the second one.
        return np.array(
            [vec_a if keyword == 'match_a' else vec_b for keyword in keywords]
        )

    monkeypatch.setattr(mod, 'get_embeddings', fake_get_embeddings)

    matches = mod.find_icons(['match_a', 'other'])
    assert list(matches) == ['a_icon', 'b_icon']
192
+
193
+
194
def test_main_calls_and_prints(monkeypatch: Any, capsys: Any) -> None:
    """
    main should invoke both save_icons_embeddings and find_icons and
    write its summary line to stdout. Both functions are replaced with
    cheap recorders so that the test stays fast.
    """
    mod = _reload_module_with_dummies(monkeypatch)
    invoked = {'saved': False, 'found': False}

    def fake_save():
        invoked['saved'] = True

    def fake_find(keywords: List[str]) -> List[str]:
        invoked['found'] = True
        return ['x' for _ in keywords]

    for attr_name, replacement in (
        ('save_icons_embeddings', fake_save),
        ('find_icons', fake_find),
    ):
        monkeypatch.setattr(mod, attr_name, replacement)

    mod.main()

    printed = capsys.readouterr().out
    assert 'The relevant icon files are' in printed
    assert invoked['saved'] is True
    assert invoked['found'] is True