Stefan committed
Commit • bb3407a
Parent(s): 7ce98a0
feat(setup): initial commit
- .gitignore +2 -0
- .vscode/settings.json +3 -0
- Pipfile +33 -0
- Pipfile.lock +0 -0
- embedding.py +48 -0
- main.py +28 -0
- pg.py +41 -0
- processing.py +95 -0
- requirements.txt +97 -0
- vectors.py +38 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+data*/
+.env
.vscode/settings.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "editor.defaultFormatter": "ms-python.black-formatter"
+}
Pipfile
ADDED
@@ -0,0 +1,33 @@
+[[source]]
+url = "https://pypi.org/simple"
+verify_ssl = true
+name = "pypi"
+
+[packages]
+numpy = "*"
+pandas = "*"
+torch = "*"
+transformers = "*"
+accelerate = "*"
+sentencepiece = "*"
+protobuf = "==3.20.1"
+aiohttp = "*"
+aiodns = "*"
+brotli = "*"
+python-dotenv = "*"
+openai = "*"
+nest-asyncio = "*"
+tqdm = "*"
+tiktoken = "*"
+instructorembedding = "*"
+markdown = "*"
+sentence-transformers = "*"
+pinecone-client = "*"
+psycopg2 = "*"
+gradio = "*"
+
+[dev-packages]
+ipykernel = "*"
+
+[requires]
+python_version = "3.11"
Pipfile.lock
ADDED
The diff for this file is too large to render.
See raw diff
embedding.py
ADDED
@@ -0,0 +1,48 @@
+from torch import Tensor
+import tiktoken
+from transformers import AutoTokenizer, AutoModel
+
+tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-large-v2")
+model = AutoModel.from_pretrained("intfloat/e5-large-v2")
+
+EMBEDDING_CHAR_LIMIT = 512
+
+
+def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+def strings_to_vectors(strings: list[str]):
+    passage_batch = tokenizer(
+        strings,
+        max_length=EMBEDDING_CHAR_LIMIT,
+        padding=True,
+        truncation=True,
+        return_tensors="pt",
+    )
+    passage_outputs = model(**passage_batch)
+    return average_pool(
+        passage_outputs.last_hidden_state, passage_batch["attention_mask"]
+    )
+
+
+def num_tokens_from_str(string, model="gpt-3.5-turbo"):
+    """Returns the number of tokens used by a string."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model == "gpt-3.5-turbo":  # note: future models may deviate from this
+        num_tokens = 0
+        num_tokens += (
+            4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
+        )
+        num_tokens += len(encoding.encode(string))
+        num_tokens += 2  # every reply is primed with <im_start>assistant
+        return num_tokens
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_str() is not presently implemented for model {model}.
+See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+        )
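
Note: strings_to_vectors mean-pools the E5 model's final hidden states via average_pool, masking padding tokens so strings of different lengths embed comparably. A minimal usage sketch (the passage text is illustrative; E5 expects a "passage: " or "query: " prefix on each input):

from embedding import strings_to_vectors

vectors = strings_to_vectors(["passage: the heart has four chambers."])
print(vectors.shape)  # torch.Size([1, 1024]) — e5-large-v2 has hidden size 1024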
main.py
ADDED
@@ -0,0 +1,28 @@
+import gradio as gr
+from processing import md_to_passages
+from pg import get_chapters
+from vectors import match_query
+
+
+def find_embedding(query: str):
+    top_res = match_query(query, 3)
+    # print(top_res)
+
+    chapters = get_chapters(list(map(lambda x: x["metadata"]["chapterId"], top_res)))
+
+    output = ""
+
+    for res, chapter in zip(top_res, chapters):
+        passages = md_to_passages(chapter["explanation"])
+        output += f"{res['id']}\t| score: {res['score']:.2f}%\n{passages[res['passage_idx']]}\n\n"
+
+    return output
+
+
+with gr.Blocks() as quesbook_search:
+    question = gr.Text(label="question")
+    answer = gr.Text(label="answer")
+    submit = gr.Button("Submit")
+    submit.click(fn=find_embedding, inputs=question, outputs=answer)
+
+quesbook_search.launch()
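
Note: quesbook_search.launch() runs at module level, so `python main.py` starts the UI immediately. For a quick check without the browser, the handler could be called just before the launch line, e.g. (a hypothetical smoke test; assumes the Pinecone index and Postgres tables are populated):

# hypothetical smoke test, placed above quesbook_search.launch()
print(find_embedding("what causes atrial fibrillation?"))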
pg.py
ADDED
@@ -0,0 +1,41 @@
+import psycopg2
+import os
+
+pg = psycopg2.connect(
+    dbname=os.getenv("POSTGRES_DB"),
+    user=os.getenv("POSTGRES_USER"),
+    password=os.getenv("POSTGRES_PASSWORD"),
+    port=os.getenv("POSTGRES_PORT"),
+    host=os.getenv("POSTGRES_HOST"),
+)
+
+
+def get_chapters(ids: list[int]):
+    cur = pg.cursor()
+    cur.execute(
+        """
+        SELECT
+            ch.id,
+            ch.explanation
+        FROM
+            chapters ch
+        WHERE
+            ch.id = ANY (%s);
+        """,
+        (ids,),
+    )
+    data = cur.fetchall()
+    cur.close()
+
+    chapters = list(map(lambda x: {"id": x[0], "explanation": x[1]}, data))
+
+    ordered_chapters = []
+    for id in ids:
+        chapter = next(
+            (ch for ch in chapters if ch["id"] == id),
+            None,
+        )
+        if chapter:
+            ordered_chapters.append(chapter)
+
+    return ordered_chapters
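
Note: `ch.id = ANY (%s)` lets psycopg2 adapt a Python list to a Postgres array, but the rows come back in arbitrary order; the loop at the end restores the caller's id order so chapters line up with the ranked matches. A usage sketch (the ids are hypothetical):

from pg import get_chapters

for ch in get_chapters([42, 7, 13]):  # hypothetical chapter ids
    print(ch["id"], ch["explanation"][:60])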
processing.py
ADDED
@@ -0,0 +1,95 @@
+from markdown import Markdown
+from io import StringIO
+import re
+from embedding import num_tokens_from_str, EMBEDDING_CHAR_LIMIT
+
+HTMLR = re.compile(r"<.*?>")
+WS = re.compile(r"\s+")
+LIGHTGALLERY = re.compile(r"\[lightgallery.*\]")
+
+
+def unmark_element(element, stream=None):
+    if stream is None:
+        stream = StringIO()
+    if element.text:
+        stream.write(element.text)
+    for sub in element:
+        unmark_element(sub, stream)
+    if element.tail:
+        stream.write(element.tail)
+    return stream.getvalue()
+
+
+# patching Markdown to emit plain text instead of HTML
+Markdown.output_formats["plain"] = unmark_element
+__md = Markdown(output_format="plain", extensions=["tables"])
+__md.stripTopLevelTags = False
+
+
+def unmark(text):
+    return __md.convert(text)
+
+
+def clean_md(text: str) -> list[str]:
+    cleantext = re.sub(HTMLR, "", text)
+    cleantext = re.sub(LIGHTGALLERY, "", cleantext)
+    para = cleantext.split("\n#")
+    para = [unmark(p) for p in para]
+    para = [re.sub(WS, " ", p.lower()) for p in para]
+    return para
+
+
+start_seq_length = num_tokens_from_str("passage: ")
+
+
+def truncate_to_sequences(text: str, max_char=EMBEDDING_CHAR_LIMIT) -> list[str]:
+    sequence_length = num_tokens_from_str(text) // (max_char - start_seq_length) + 1
+    length = len(text)
+    separator = length // sequence_length
+
+    sequences = []
+    base = 0
+    while base < length:
+        count = len(sequences) + 1
+        end = min(separator * count, length)
+        found = False
+
+        if end == length:
+            found = True
+
+        if found is False:
+            # walk backwards looking for a sentence boundary (". " reversed)
+            section = text[base:end]
+            section_rev = section[::-1]
+            for i in range(len(section_rev)):
+                if section_rev[i : i + 2] == " .":
+                    found = True
+                    end -= 1
+                    break
+                end -= 1
+
+        if found is False:
+            # no sentence boundary found; fall back to the nearest word boundary
+            end = separator * count
+            for i in range(len(section_rev)):
+                if section_rev[i] == " ":
+                    found = True
+                    break
+                end -= 1
+
+        if num_tokens_from_str(text[base:end]) > max_char:
+            sub_sequences = truncate_to_sequences(text[base:end])
+            sequences += sub_sequences
+        else:
+            sequences.append(text[base:end])
+
+        base = end  # end is an absolute index into text, so continue from it
+    return sequences
+
+
+def md_to_passages(md: str) -> list[str]:
+    initial_passages = clean_md(md)
+    passages = []
+    for p in initial_passages:
+        sequences = truncate_to_sequences(p)
+        passages += sequences
+
+    return passages
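
Note: md_to_passages strips HTML tags and [lightgallery] shortcodes, splits a chapter on markdown headings, flattens each piece to lower-cased plain text, then cuts anything over EMBEDDING_CHAR_LIMIT tokens at sentence or word boundaries. A minimal sketch (the markdown is illustrative; importing processing also loads the E5 model via embedding):

from processing import md_to_passages

md = "# Anatomy\nThe heart has **four** chambers.\n# Physiology\nInsulin lowers blood glucose."
print(md_to_passages(md))
# -> list of lower-cased plain-text passages, one per heading section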
requirements.txt
ADDED
@@ -0,0 +1,97 @@
+-i https://pypi.org/simple
+accelerate==0.19.0
+aiodns==3.0.0
+aiofiles==23.1.0 ; python_version >= '3.7' and python_version < '4.0'
+aiohttp==3.8.4
+aiosignal==1.3.1 ; python_version >= '3.7'
+altair==5.0.0 ; python_version >= '3.7'
+anyio==3.6.2 ; python_full_version >= '3.6.2'
+async-timeout==4.0.2 ; python_version >= '3.6'
+attrs==23.1.0 ; python_version >= '3.7'
+brotli==1.0.9
+certifi==2023.5.7 ; python_version >= '3.6'
+cffi==1.15.1
+charset-normalizer==3.1.0 ; python_full_version >= '3.7.0'
+click==8.1.3 ; python_version >= '3.7'
+contourpy==1.0.7 ; python_version >= '3.8'
+cycler==0.11.0 ; python_version >= '3.6'
+dnspython==2.3.0 ; python_version >= '3.7' and python_version < '4.0'
+fastapi==0.95.2 ; python_version >= '3.7'
+ffmpy==0.3.0
+filelock==3.12.0 ; python_version >= '3.7'
+fonttools==4.39.4 ; python_version >= '3.8'
+frozenlist==1.3.3 ; python_version >= '3.7'
+fsspec==2023.5.0 ; python_version >= '3.8'
+gradio==3.32.0
+gradio-client==0.2.5 ; python_version >= '3.7'
+h11==0.14.0 ; python_version >= '3.7'
+httpcore==0.17.2 ; python_version >= '3.7'
+httpx==0.24.1 ; python_version >= '3.7'
+huggingface-hub==0.14.1 ; python_full_version >= '3.7.0'
+idna==3.4 ; python_version >= '3.5'
+instructorembedding==1.0.0
+jinja2==3.1.2 ; python_version >= '3.7'
+joblib==1.2.0 ; python_version >= '3.7'
+jsonschema==4.17.3 ; python_version >= '3.7'
+kiwisolver==1.4.4 ; python_version >= '3.7'
+linkify-it-py==2.0.2
+loguru==0.7.0 ; python_version >= '3.5'
+markdown==3.4.3
+markdown-it-py[linkify]==2.2.0 ; python_version >= '3.7'
+markupsafe==2.1.2 ; python_version >= '3.7'
+matplotlib==3.7.1 ; python_version >= '3.8'
+mdit-py-plugins==0.3.3 ; python_version >= '3.7'
+mdurl==0.1.2 ; python_version >= '3.7'
+mpmath==1.3.0
+multidict==6.0.4 ; python_version >= '3.7'
+nest-asyncio==1.5.6
+networkx==3.1 ; python_version >= '3.8'
+nltk==3.8.1 ; python_version >= '3.7'
+numpy==1.24.3
+openai==0.27.7
+orjson==3.8.13 ; python_version >= '3.7'
+packaging==23.1 ; python_version >= '3.7'
+pandas==2.0.1
+pillow==9.5.0 ; python_version >= '3.7'
+pinecone-client==2.2.1
+protobuf==3.20.1
+psutil==5.9.5 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
+psycopg2==2.9.6
+pycares==4.3.0
+pycparser==2.21
+pydantic==1.10.8 ; python_version >= '3.7'
+pydub==0.25.1
+pygments==2.15.1 ; python_version >= '3.7'
+pyparsing==3.0.9 ; python_full_version >= '3.6.8'
+pyrsistent==0.19.3 ; python_version >= '3.7'
+python-dateutil==2.8.2 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
+python-dotenv==1.0.0
+python-multipart==0.0.6 ; python_version >= '3.7'
+pytz==2023.3
+pyyaml==6.0 ; python_version >= '3.6'
+regex==2023.5.5 ; python_version >= '3.6'
+requests==2.31.0 ; python_version >= '3.7'
+scikit-learn==1.2.2 ; python_version >= '3.8'
+scipy==1.10.1 ; python_version < '3.12' and python_version >= '3.8'
+semantic-version==2.10.0 ; python_version >= '2.7'
+sentence-transformers==2.2.2
+sentencepiece==0.1.99
+six==1.16.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
+sniffio==1.3.0 ; python_version >= '3.7'
+starlette==0.27.0 ; python_version >= '3.7'
+sympy==1.12 ; python_version >= '3.8'
+threadpoolctl==3.1.0 ; python_version >= '3.6'
+tiktoken==0.4.0
+tokenizers==0.13.3
+toolz==0.12.0 ; python_version >= '3.5'
+torch==2.0.1
+torchvision==0.15.2 ; python_version >= '3.8'
+tqdm==4.65.0
+transformers==4.29.2
+typing-extensions==4.6.1 ; python_version >= '3.7'
+tzdata==2023.3 ; python_version >= '2'
+uc-micro-py==1.0.2 ; python_version >= '3.7'
+urllib3==2.0.2 ; python_version >= '3.7'
+uvicorn==0.22.0 ; python_version >= '3.7'
+websockets==11.0.3 ; python_version >= '3.7'
+yarl==1.9.2 ; python_version >= '3.7'
vectors.py
ADDED
@@ -0,0 +1,38 @@
+from embedding import strings_to_vectors
+import pinecone
+import os
+
+PINECONE_API = os.getenv("PINECONE_API")
+
+pinecone.init(api_key=PINECONE_API, environment="us-west4-gcp-free")
+
+vector_index = pinecone.Index("quesmed")
+
+
+def scored_vector_todict(scored_vector):
+    x = {
+        "id": scored_vector["id"],
+        "metadata": {
+            "topicId": int(scored_vector["metadata"]["topicId"]),
+            "chapterId": int(scored_vector["metadata"]["chapterId"]),
+            "conceptId": int(scored_vector["metadata"]["conceptId"]),
+        },
+        "score": scored_vector["score"] * 100,
+        "values": scored_vector["values"],
+    }
+    for k, v in x["metadata"].items():
+        x[k] = int(v)
+    x["passage_idx"] = int(x["id"][-1])
+    return x
+
+
+def match_query(query: str, n_res=3):
+    queries = [f"query: {query.replace('?','').lower()}"]
+    query_embeddings = strings_to_vectors(queries)
+    result = vector_index.query(
+        query_embeddings[0].tolist(),
+        top_k=n_res,
+        include_metadata=True,
+        namespace="quesbook",
+    )
+    return list(map(scored_vector_todict, result["matches"]))
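
Note: match_query mirrors the E5 convention from embedding.py, prefixing the lower-cased question with "query: " before embedding it and searching the quesmed index's quesbook namespace; scores are rescaled to percentages. A usage sketch (assumes PINECONE_API is set and the index is populated):

from vectors import match_query

for res in match_query("what causes atrial fibrillation?"):
    print(res["id"], f"{res['score']:.1f}%", res["chapterId"])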