File size: 2,623 Bytes
1e57a2c
 
 
 
 
 
 
 
 
 
 
 
e7ded97
1e57a2c
 
 
 
 
 
8fc12bf
 
 
 
1e57a2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import streamlit as st
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository that hosts the precomputed embedding file.
embedding_path = "abokbot/wikipedia-embedding"

st.header("Wikipedia Search Engine app")

# Temporary status line shown while the embedding is being loaded; it is
# blanked out again after loading completes.
st_model_load = st.text("Loading wikipedia embedding...")

@st.cache_resource
def load_embedding():
    """Fetch the precomputed Simple-Wikipedia embedding and load it with torch.

    The file ``simple_wikipedia_embedding.pt`` is downloaded from the Hugging
    Face Hub repository named by the module-level ``embedding_path``.
    ``st.cache_resource`` keeps the returned object cached across Streamlit
    reruns, so the download and load are not repeated on every interaction.

    Returns:
        The object deserialized from ``simple_wikipedia_embedding.pt``
        (presumably a tensor of passage embeddings — TODO confirm).
    """
    print("Loading embedding...")
    # hf_hub_download stores the file in the local Hugging Face cache and
    # returns its path. Load from that path rather than from a hard-coded
    # relative directory, which only exists if the repo was cloned beside
    # the app.
    local_path = hf_hub_download(
        repo_id=embedding_path,
        filename="simple_wikipedia_embedding.pt",
    )
    wikipedia_embedding = torch.load(local_path)
    print("Embedding loaded!")
    return wikipedia_embedding

# Obtain the embedding (served from Streamlit's resource cache after the first
# run), report success, then blank out the "Loading..." placeholder.
wikipedia_embedding = load_embedding()
st.success("Embedding loaded!")
st_model_load.text("")

# NOTE(review): everything between the triple quotes below is disabled draft
# code. It is a bare string literal, so Python evaluates and discards it and
# none of it runs. If it is ever re-enabled, be aware that:
#   * `corpus_embeddings` and `dataset` are not defined anywhere in the visible
#     file (the embedding loaded earlier is bound to `wikipedia_embedding`);
#   * `question_embedding.cuda()` unconditionally requires a CUDA device;
#   * `top_k` is 32, although the comment above the cross-encoder says 100;
#   * `search` only prints the top-3 re-ranked hits and returns None.
"""


#We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
# cf https://www.sbert.net/docs/pretrained-models/msmarco-v3.html
bi_encoder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
bi_encoder.max_seq_length = 256     #Truncate long passages to 256 tokens
top_k = 32                          #Number of passages we want to retrieve with the bi-encoder

#The bi-encoder will retrieve 100 documents. We use a cross-encoder, to re-rank the results list to improve the quality
cross_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')

def search(query):
    print("Input question:", query)
    ##### Sematic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    question_embedding = question_embedding.cuda()
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    cross_inp = [[query, dataset["text"][hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Sort results by the cross-encoder scores
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]

    # Output of top-3 hits from re-ranker
    print("\n-------------------------\n")
    print("Top-3 Cross-Encoder Re-ranker hits")
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    for hit in hits[0:3]:
        print("score: ",  round(hit['cross-score'], 3),"\n",
              "title: ", dataset["title"][hit['corpus_id']], "\n", 
              "substract: ", dataset["text"][hit['corpus_id']].replace("\n", " "), "\n", 
              "link: ", dataset["url"][hit['corpus_id']],"\n")


"""