File size: 1,941 Bytes
cfc4d1a
 
 
 
 
 
 
 
 
ff3b08e
cfc4d1a
ff3b08e
 
 
cfc4d1a
 
ff3b08e
 
 
 
 
 
 
 
 
 
cfc4d1a
 
 
 
 
ff3b08e
cfc4d1a
 
 
 
ff3b08e
cfc4d1a
 
 
ff3b08e
cfc4d1a
ff3b08e
 
cfc4d1a
 
 
 
 
ff3b08e
cfc4d1a
 
 
 
 
 
 
 
ff3b08e
 
 
cfc4d1a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# + tags=["hide_inp"]
# Markdown description shown in the Gradio demo header.
# Fix: "embeedding" -> "embedding" (user-facing typo).
desc = """
### Question Answering with Retrieval

Chain that answers questions with embedding based retrieval. [[Code](https://github.com/srush/MiniChain/blob/main/examples/qa.py)]

(Adapted from [OpenAI Notebook](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).)
"""
# -

# $

import datasets
import numpy as np
from minichain import prompt, show, OpenAIEmbed, OpenAI
from manifest import Manifest

# We use Hugging Face Datasets as the database by assigning
# a FAISS index.

# Load the pre-embedded Olympics dataset from local disk and attach a
# FAISS index over the "embeddings" column so nearest-neighbor lookups
# (see get_neighbors below) are fast.
# NOTE(review): assumes "olympics.data" was produced beforehand with an
# "embeddings" column — confirm against the dataset-building script.
olympics = datasets.load_from_disk("olympics.data")
olympics.add_faiss_index("embeddings")


# Fast KNN retrieval prompt

@prompt(OpenAIEmbed())
def get_neighbors(model, inp, k):
    """Embed ``inp`` and return the contents of its ``k`` nearest documents.

    ``model`` is the embedding callable injected by the ``@prompt``
    decorator; retrieval runs against the module-level ``olympics``
    dataset's FAISS index.
    """
    query_vector = np.array(model(inp))
    hits = olympics.get_nearest_examples("embeddings", query_vector, k)
    return hits.examples["content"]

@prompt(OpenAI(),
        template_file="qa.pmpt.tpl")
def get_result(model, query, neighbors):
    """Ask the LLM to answer ``query`` given the retrieved ``neighbors``.

    The question and documents are rendered through ``qa.pmpt.tpl``
    before being sent to the model.
    """
    return model({"question": query, "docs": neighbors})

def qa(query):
    """Answer ``query``: retrieve the 3 nearest documents, then prompt over them."""
    neighbors = get_neighbors(query, 3)
    return get_result(query, neighbors)

# $


# Example questions pre-filled in the demo UI.  The last one is outside
# the Olympics corpus — presumably included to show how the chain handles
# unanswerable queries; confirm against the prompt template's behavior.
questions = ["Who won the 2020 Summer Olympics men's high jump?",
             "Why was the 2020 Summer Olympics originally postponed?",
             "In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?",
             "What is the total number of medals won by France?",
             "What is the tallest mountain in the world?"]

# Read back the portion of this file between the "$" markers so the demo
# can display its own source alongside the UI.  Fix: the original used
# open(...).read() inline, leaking the file handle; a `with` block closes
# it deterministically.
with open("qa.py", "r") as _source:
    _code = _source.read().split("$")[1].strip().strip("#").strip()

gradio = show(qa,
              examples=questions,
              subprompts=[get_neighbors, get_result],
              description=desc,
              code=_code,
              )
if __name__ == "__main__":
    gradio.launch()



# # + tags=["hide_inp"]
# QAPrompt().show(
#     {"question": "Who won the race?", "docs": ["doc1", "doc2", "doc3"]}, "Joe Bob"
# )
# # -

# show_log("qa.log")