---
license: bigscience-bloom-rail-1.0
datasets:
- squad
language:
- fr
- en
pipeline_tag: sentence-similarity
---

```python
import numpy as np
from transformers import pipeline
from scipy.spatial.distance import cdist

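retriever = pipeline('feature-extraction', 'cmarkea/bloomz-3b-retriever')

# Keep only the last token's hidden state as the sentence embedding,
# reshaped to (1, hidden_size) so the rows can be concatenated into a matrix.
infer = lambda x: [np.array(ii[0][-1]).reshape(1, -1) for ii in retriever(x)]

list_of_contexts = [...]  # context passages (strings)
emb_contexts = np.concatenate(infer(list_of_contexts), axis=0)
list_of_queries = [...]  # queries (strings)
emb_queries = np.concatenate(infer(list_of_queries), axis=0)

# Rank contexts by Euclidean (L2) distance to each query
dist = cdist(emb_queries, emb_contexts, 'euclidean')
top_k = lambda x: [
    [list_of_contexts[qq] for qq in ii]
    for ii in dist.argsort(axis=-1)[:, :x]
]

# top 5 nearest contexts for each query
top_contexts = top_k(5)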
```
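The two details that matter in the snippet above are using only the last token's vector as the embedding and ranking contexts by Euclidean distance. The sketch below reproduces just that ranking step without downloading the model: hand-made toy arrays stand in for real `retriever(...)` outputs, and the helper names `last_token_embeddings` and `top_k_contexts` are hypothetical, not part of the model or of `transformers`. It assumes each pipeline output is nested as `[1][seq_len][hidden_size]`, and it computes the pairwise L2 distances with NumPy broadcasting, which should match `cdist(..., 'euclidean')`.

```python
import numpy as np


def last_token_embeddings(pipeline_outputs):
    """Stack the last-token vector of each output into a (n_inputs, hidden_size) array.

    Assumption: each element of `pipeline_outputs` is nested as
    [1][seq_len][hidden_size], one entry per input string.
    """
    return np.stack([np.asarray(out[0][-1], dtype=float) for out in pipeline_outputs])


def top_k_contexts(emb_queries, emb_contexts, contexts, k):
    """Return, for each query, the k contexts with the smallest L2 distance."""
    # Pairwise Euclidean distances, shape (n_queries, n_contexts).
    diff = emb_queries[:, None, :] - emb_contexts[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Indices of the k nearest contexts per query, closest first.
    order = np.argsort(dist, axis=-1)[:, :k]
    return [[contexts[j] for j in row] for row in order]


# Toy stand-ins for pipeline outputs: 2 tokens per input, hidden size 3.
# Only the last token of each input is used; the first token is ignored.
fake_context_outputs = [
    [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
    [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]],
    [[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]],
]
fake_query_outputs = [
    [[[5.0, 5.0, 5.0], [0.9, 0.1, 0.0]]],
]
contexts = ["ctx A", "ctx B", "ctx C"]

emb_contexts = last_token_embeddings(fake_context_outputs)
emb_queries = last_token_embeddings(fake_query_outputs)
print(emb_contexts.shape)  # (3, 3)
print(top_k_contexts(emb_queries, emb_contexts, contexts, k=2))  # [['ctx A', 'ctx B']]
```

Stacking one row per input is what makes the embeddings a 2-D `(n_inputs, hidden_size)` matrix; concatenating the raw 1-D last-token vectors along `axis=0` would instead produce a single flat vector, which `cdist` rejects.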