secilozksen commited on
Commit
18665b8
1 Parent(s): 02ecb0f

Upload demo_dpr.py

Browse files
Files changed (1) hide show
  1. demo_dpr.py +7 -7
demo_dpr.py CHANGED
@@ -10,8 +10,7 @@ import torch
10
  from transformers import DPRQuestionEncoderTokenizer, AutoModel
11
  from pathlib import Path
12
  import base64
13
- import regex
14
- import tokenizers
15
 
16
  st.set_page_config(layout="wide")
17
 
@@ -55,12 +54,16 @@ def retrieve_rerank(question):
55
  top_5_scores.append(hit['cross-score'])
56
  return top_5_contexes, top_5_scores
57
 
58
-
 
 
 
 
59
 
60
  @st.cache(show_spinner=False, allow_output_mutation=True)
61
  def load_paragraphs(path):
62
  with open(path, "rb") as fIn:
63
- cache_data = pickle.load(fIn)
64
  corpus_sentences = cache_data['contexes']
65
  corpus_embeddings = cache_data['embeddings']
66
 
@@ -276,6 +279,3 @@ dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenize
276
 
277
  qa_main_widgetsv2()
278
 
279
- #if __name__ == '__main__':
280
- # search_pipeline('Life insurance is paid by insurance companies that pay for what?', 1)
281
-
 
10
  from transformers import DPRQuestionEncoderTokenizer, AutoModel
11
  from pathlib import Path
12
  import base64
13
+ import io
 
14
 
15
  st.set_page_config(layout="wide")
16
 
 
54
  top_5_scores.append(hit['cross-score'])
55
  return top_5_contexes, top_5_scores
56
 
57
+ class CPU_Unpickler(pickle.Unpickler):
58
+ def find_class(self, module, name):
59
+ if module == 'torch.storage' and name == '_load_from_bytes':
60
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
61
+ else: return super().find_class(module, name)
62
 
63
  @st.cache(show_spinner=False, allow_output_mutation=True)
64
  def load_paragraphs(path):
65
  with open(path, "rb") as fIn:
66
+ cache_data = CPU_Unpickler(fIn).load()
67
  corpus_sentences = cache_data['contexes']
68
  corpus_embeddings = cache_data['embeddings']
69
 
 
279
 
280
  qa_main_widgetsv2()
281