Spaces:
Runtime error
Runtime error
secilozksen
commited on
Commit
•
18665b8
1
Parent(s):
02ecb0f
Upload demo_dpr.py
Browse files- demo_dpr.py +7 -7
demo_dpr.py
CHANGED
@@ -10,8 +10,7 @@ import torch
|
|
10 |
from transformers import DPRQuestionEncoderTokenizer, AutoModel
|
11 |
from pathlib import Path
|
12 |
import base64
|
13 |
-
import
|
14 |
-
import tokenizers
|
15 |
|
16 |
st.set_page_config(layout="wide")
|
17 |
|
@@ -55,12 +54,16 @@ def retrieve_rerank(question):
|
|
55 |
top_5_scores.append(hit['cross-score'])
|
56 |
return top_5_contexes, top_5_scores
|
57 |
|
58 |
-
|
|
|
|
|
|
|
|
|
59 |
|
60 |
@st.cache(show_spinner=False, allow_output_mutation=True)
|
61 |
def load_paragraphs(path):
|
62 |
with open(path, "rb") as fIn:
|
63 |
-
cache_data =
|
64 |
corpus_sentences = cache_data['contexes']
|
65 |
corpus_embeddings = cache_data['embeddings']
|
66 |
|
@@ -276,6 +279,3 @@ dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenize
|
|
276 |
|
277 |
qa_main_widgetsv2()
|
278 |
|
279 |
-
#if __name__ == '__main__':
|
280 |
-
# search_pipeline('Life insurance is paid by insurance companies that pay for what?', 1)
|
281 |
-
|
|
|
10 |
from transformers import DPRQuestionEncoderTokenizer, AutoModel
|
11 |
from pathlib import Path
|
12 |
import base64
|
13 |
+
import io
|
|
|
14 |
|
15 |
st.set_page_config(layout="wide")
|
16 |
|
|
|
54 |
top_5_scores.append(hit['cross-score'])
|
55 |
return top_5_contexes, top_5_scores
|
56 |
|
57 |
+
class CPU_Unpickler(pickle.Unpickler):
|
58 |
+
def find_class(self, module, name):
|
59 |
+
if module == 'torch.storage' and name == '_load_from_bytes':
|
60 |
+
return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
|
61 |
+
else: return super().find_class(module, name)
|
62 |
|
63 |
@st.cache(show_spinner=False, allow_output_mutation=True)
|
64 |
def load_paragraphs(path):
|
65 |
with open(path, "rb") as fIn:
|
66 |
+
cache_data = CPU_Unpickler(fIn).load()
|
67 |
corpus_sentences = cache_data['contexes']
|
68 |
corpus_embeddings = cache_data['embeddings']
|
69 |
|
|
|
279 |
|
280 |
qa_main_widgetsv2()
|
281 |
|
|
|
|
|
|