Cyrile commited on
Commit
83c1880
1 Parent(s): 3b50f73

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -51,14 +51,14 @@ from transformers import pipeline
51
  from scipy.spatial.distance import cdist
52
 
53
  retriever = pipeline('feature-extraction', 'cmarkea/bloomz-3b-retriever')
54
- infer = lambda x: [ii[0][-1] for ii in retriever(x)]
55
 
56
  list_of_contexts = [...]
57
  emb_contexts = np.concatenate(infer(list_of_contexts), axis=0)
58
  list_of_queries = [...]
59
  emb_queries = np.concatenate(infer(list_of_queries), axis=0)
60
 
61
- dist = cdist(emb_queries, emb_contexts, 'euclidean')
62
  top_k = lambda x: [
63
  [list_of_contexts[qq] for qq in ii]
64
  for ii in dist.argsort(axis=-1)[:,:x]
 
51
  from scipy.spatial.distance import cdist
52
 
53
  retriever = pipeline('feature-extraction', 'cmarkea/bloomz-3b-retriever')
54
+ infer = lambda x: [ii[0][-1] for ii in retriever(x)] # Inportant: take only last token!
55
 
56
  list_of_contexts = [...]
57
  emb_contexts = np.concatenate(infer(list_of_contexts), axis=0)
58
  list_of_queries = [...]
59
  emb_queries = np.concatenate(infer(list_of_queries), axis=0)
60
 
61
+ dist = cdist(emb_queries, emb_contexts, 'euclidean') # Important: take l2 distance!
62
  top_k = lambda x: [
63
  [list_of_contexts[qq] for qq in ii]
64
  for ii in dist.argsort(axis=-1)[:,:x]