kaisugi commited on
Commit
b62fe6e
1 Parent(s): bf27a63
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -43,13 +43,13 @@ def load_title_embeddings():
43
 
44
 
45
  def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
46
- batch_dict = tokenizer(f"query: {input_text}", max_length=512, padding=True, truncation=True, return_tensors='pt')
47
  with torch.no_grad():
48
  outputs = model(**batch_dict)
49
  query_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
50
  query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
51
 
52
- _, ids = index.search(x=np.array([query_embeddings]), k=top_k)
53
  retrieved_titles = []
54
  retrieved_pids = []
55
 
 
43
 
44
 
45
  def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
46
+ batch_dict = tokenizer([f"query: {input_text}"], max_length=512, padding=True, truncation=True, return_tensors='pt')
47
  with torch.no_grad():
48
  outputs = model(**batch_dict)
49
  query_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
50
  query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
51
 
52
+ _, ids = index.search(x=query_embeddings.detach().numpy().copy(), k=top_k)
53
  retrieved_titles = []
54
  retrieved_pids = []
55