# -*- coding: utf-8 -*-
"""wiki_chat.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1P5rJeCXRSsDJw_1ksnHmodH6ng2Ot5NW
"""

# !pip install gradio

# !pip install -U sentence-transformers

# !pip install datasets


import gradio as gr
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from torch import tensor as torch_tensor
from datasets import load_dataset

"""# import models"""

bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256     # Truncate long passages to 256 tokens

# The bi-encoder retrieves the top_k candidate passages; the cross-encoder
# then re-ranks that list to improve result quality.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

"""# import datasets"""
dataset = load_dataset("gfhayworth/wiki_mini", split='train')
mypassages = list(dataset.to_pandas()['psg'])

dataset_embed = load_dataset("gfhayworth/wiki_mini_embed", split='train')
dataset_embed_pd = dataset_embed.to_pandas()
mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
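
# The embedding dataset was presumably built offline with the same bi-encoder,
# roughly like the following (hypothetical sketch, not executed here):
#
#   corpus_embeddings = bi_encoder.encode(mypassages, convert_to_tensor=True,
#                                         show_progress_bar=True)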

def search(query, top_k=20, top_n=1):
  # Stage 1: embed the query and retrieve the top_k nearest passages from the
  # precomputed corpus embeddings. (Append .cuda() to the embedding when
  # running on GPU.)
  question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
  hits = util.semantic_search(question_embedding, mycorpus_embeddings, top_k=top_k)
  hits = hits[0]  # Get the hits for the first (and only) query

  # Stage 2: re-rank the candidates with the cross-encoder, which scores each
  # (query, passage) pair jointly.
  cross_inp = [[query, mypassages[hit['corpus_id']]] for hit in hits]
  cross_scores = cross_encoder.predict(cross_inp)

  # Attach the cross-encoder scores and sort the hits by them.
  for idx in range(len(cross_scores)):
    hits[idx]['cross-score'] = cross_scores[idx]

  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
  return hits[:top_n]
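
# Each returned hit is a dict from util.semantic_search plus the added
# 'cross-score' key, e.g. (values hypothetical):
#   [{'corpus_id': 1234, 'score': 0.61, 'cross-score': 7.2}]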

def get_text(qry):
  # Return the passage text for each top-ranked hit.
  predictions = search(qry)
  return [mypassages[hit['corpus_id']] for hit in predictions]

# Optional smoke test (kept commented out so it does not run at app startup):
# def prt_rslt(qry):
#   for r in get_text(qry):
#     print(r)
#
# prt_rslt("who is the best rapper in the world?")

"""# chat example"""

def chat(message, history):
  # Gradio passes None for the state on the first call, so default to an
  # empty list of (user_message, bot_response) tuples.
  history = history or []
  message = message.lower()

  responses = get_text(message)
  for response in responses:
    history.append((message, response))
  return history, history
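
# For illustration, the handler can be called directly; the returned passage
# text below is hypothetical:
#   new_history, _ = chat("Who wrote Hamlet?", None)
#   # new_history == [("who wrote hamlet?", "<top-ranked passage text>")]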

css=".gradio-container {background-color: lightgray}"

with gr.Blocks(css=css) as demo:
  history_state = gr.State()
  gr.Markdown('# WikiBot')
  gr.Markdown('A chatbot with search on Wikipedia.')
  with gr.Row():
    chatbot = gr.Chatbot()
  with gr.Row():
    message = gr.Textbox(label='Input your question here:',
                         placeholder='How many countries are in Europe?',
                         lines=1)
    submit = gr.Button(value='Send',
                       variant='secondary').style(full_width=False)
  submit.click(chat,
               inputs=[message, history_state],
               outputs=[chatbot, history_state])
  gr.Examples(examples=["How many countries are in Europe?",
                        "Was Roman Emperor Constantine I a Christian?",
                        "Who is the best rapper in the world?"],
              inputs=message)

demo.launch()
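
# Note: demo.launch(share=True) would additionally create a temporary public
# link; the plain launch() above serves the app locally.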