Spaces:
Running
Running
Médéric Hurier (Fmind)
committed on
Commit
·
ed67987
1
Parent(s):
5c68cc7
First release
Browse files- app.py +70 -16
- database.py +26 -8
- lib.py +29 -7
app.py
CHANGED
@@ -5,7 +5,6 @@
|
|
5 |
import logging
|
6 |
|
7 |
import gradio as gr
|
8 |
-
import tiktoken
|
9 |
|
10 |
import lib
|
11 |
|
@@ -18,34 +17,90 @@ logging.basicConfig(
|
|
18 |
|
19 |
# %% CONFIGS
|
20 |
|
|
|
|
|
21 |
THEME = "glass"
|
22 |
TITLE = "Fmind Chatbot"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
|
|
24 |
CLIENT = lib.get_database_client(path=lib.DATABASE_PATH)
|
25 |
-
ENCODING =
|
26 |
-
|
27 |
COLLECTION = CLIENT.get_collection(
|
28 |
name=lib.DATABASE_COLLECTION,
|
29 |
-
embedding_function=
|
30 |
)
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
# %% FUNCTIONS
|
41 |
|
42 |
|
43 |
def answer(message: str, history: list[str]) -> str:
|
44 |
"""Answer questions about my resume."""
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
|
51 |
# %% INTERFACES
|
@@ -60,7 +115,6 @@ interface = gr.ChatInterface(
|
|
60 |
undo_btn=None,
|
61 |
)
|
62 |
|
63 |
-
# %% ENTRYPOINTS
|
64 |
|
65 |
if __name__ == "__main__":
|
66 |
interface.launch()
|
|
|
5 |
import logging
|
6 |
|
7 |
import gradio as gr
|
|
|
8 |
|
9 |
import lib
|
10 |
|
|
|
17 |
|
18 |
# %% CONFIGS

# %% - Frontend

# Look and feel of the chat interface.
THEME = "glass"
TITLE = "Fmind Chatbot"
# Sample questions for the chat interface.
EXAMPLES = [
    "Who is Médéric Hurier (Fmind)?",
    "Is Fmind open to new opportunities?",
    "Can you share details about Médéric PhD?",
    "Elaborate on Médéric current work position",
    "Describe his proficiency with Python programming",
    "What is the answer to life, the universe, and everything?",
]

# %% - Backend

# These objects are created once, when the module is imported.
MODEL = lib.get_language_model()
CLIENT = lib.get_database_client(path=lib.DATABASE_PATH)
ENCODING = lib.get_encoding_function()
EMBEDDING = lib.get_embedding_function()
COLLECTION = CLIENT.get_collection(
    name=lib.DATABASE_COLLECTION,
    embedding_function=EMBEDDING,
)

# %% - Answer

# System prompt: retrieved documents are appended to it for every question.
PROMPT_CONTEXT = """
You are Fmind Chatbot, specialized in providing information regarding Médéric Hurier's (known as Fmind) professional background.
Médéric is an MLOps engineer based in Luxembourg. He is currently working at Decathlon. His calendar is booked until the conclusion of 2024.
Your responses should be succinct and maintain a professional tone. If inquiries deviate from Médéric's professional sphere, courteously decline to engage.

You may find more information about Médéric below (markdown format):
"""
# NOTE(review): lib.MODEL_INPUT_LIMIT is presumably the whole context window
# of the model, which the completion shares with the prompt. Keep a part of
# it free so that a prompt filled up to the limit still leaves room for the
# answer (and for the per-message overhead that is not counted in answer()).
PROMPT_RESPONSE_TOKENS = 1_024
PROMPT_MAX_TOKENS = lib.MODEL_INPUT_LIMIT - PROMPT_RESPONSE_TOKENS
# Documents farther than this distance from the question are not used.
QUERY_MAX_DISTANCE = 0.4
# Number of documents requested from the collection for each question.
QUERY_N_RESULTS = 20
|
56 |
|
57 |
# %% FUNCTIONS
|
58 |
|
59 |
|
60 |
def answer(message: str, history: list[list[str]]) -> str:
    """Answer questions about my resume.

    The system prompt, the previous chat turns and the new message are
    assembled into a chat request. Documents returned by the collection
    query are appended to the system prompt, in the order returned, until
    one is farther than QUERY_MAX_DISTANCE or would bring the running token
    count up to PROMPT_MAX_TOKENS.

    Args:
        message: the latest user message.
        history: previous turns, each one a (user, assistant) pair.

    Returns:
        The message content of the first choice in the model response.
    """
    # messages
    # - context
    system_message = {"role": "system", "content": PROMPT_CONTEXT}
    messages = [system_message]
    n_tokens = len(ENCODING(PROMPT_CONTEXT))
    # - history
    for user_content, assistant_content in history:
        n_tokens += len(ENCODING(user_content))
        n_tokens += len(ENCODING(assistant_content))
        messages.append({"role": "user", "content": user_content})
        messages.append({"role": "assistant", "content": assistant_content})
    # - message
    n_tokens += len(ENCODING(message))
    messages.append({"role": "user", "content": message})
    # database
    results = COLLECTION.query(query_texts=message, n_results=QUERY_N_RESULTS)
    logging.info("Results: %s", results)
    distances = results["distances"][0]
    documents = results["documents"][0]
    for distance, document in zip(distances, documents):
        # - distance
        logging.debug("Doc distance: %f", distance)
        if distance > QUERY_MAX_DISTANCE:
            # NOTE(review): stops at the first distant document, which
            # assumes results are ordered by increasing distance - confirm.
            break
        # - document
        # Start every document on its own paragraph: without a separator its
        # first line is glued to the last line of the previous document.
        section = "\n\n" + document
        n_section_tokens = len(ENCODING(section))
        logging.debug("Doc tokens: %d", n_section_tokens)  # int, not float
        if (n_tokens + n_section_tokens) >= PROMPT_MAX_TOKENS:
            break
        n_tokens += n_section_tokens
        # same dict object as messages[0]
        system_message["content"] += section
    # response
    logging.info("Tokens: %d", n_tokens)
    logging.info("Messages: %s", messages)
    api_response = MODEL(messages=messages)
    logging.info("Response: %s", api_response.to_dict_recursive())
    # content
    return api_response["choices"][0]["message"]["content"]
|
104 |
|
105 |
|
106 |
# %% INTERFACES
|
|
|
115 |
undo_btn=None,
|
116 |
)
|
117 |
|
|
|
118 |
|
119 |
if __name__ == "__main__":
|
120 |
interface.launch()
|
database.py
CHANGED
@@ -35,9 +35,15 @@ def segment_text(text: str, pattern: str) -> T.Iterator[tuple[str, str]]:
|
|
35 |
return pairs
|
36 |
|
37 |
|
38 |
-
def import_file(
|
|
|
|
|
|
|
|
|
|
|
39 |
"""Import a markdown file to a database collection."""
|
40 |
-
|
|
|
41 |
text = file.read()
|
42 |
filename = file.name
|
43 |
segments_h1 = segment_text(text=text, pattern=r"^# (.+)")
|
@@ -45,14 +51,19 @@ def import_file(file: T.TextIO, collection: lib.Collection) -> int:
|
|
45 |
logging.debug('\t- H1: "%s" (%d)', h1, len(h1_text))
|
46 |
segments_h2 = segment_text(text=h1_text, pattern=r"^## (.+)")
|
47 |
for h2, content in segments_h2:
|
48 |
-
|
|
|
|
|
49 |
id_ = f"{filename} # {h1} ## {h2}" # unique doc id
|
50 |
document = f"# {h1}\n\n## {h2}\n\n{content.strip()}"
|
51 |
metadata = {"filename": filename, "h1": h1, "h2": h2}
|
52 |
-
assert
|
|
|
|
|
53 |
collection.add(ids=id_, documents=document, metadatas=metadata)
|
54 |
-
|
55 |
-
|
|
|
56 |
|
57 |
|
58 |
def main(args: list[str] | None = None) -> int:
|
@@ -64,6 +75,9 @@ def main(args: list[str] | None = None) -> int:
|
|
64 |
logging.info("Database path: %s", database_path)
|
65 |
client = lib.get_database_client(path=database_path)
|
66 |
logging.info("- Reseting database client: %s", client.reset())
|
|
|
|
|
|
|
67 |
# embedding
|
68 |
embedding_function = lib.get_embedding_function()
|
69 |
logging.info("Embedding function: %s", embedding_function)
|
@@ -76,8 +90,12 @@ def main(args: list[str] | None = None) -> int:
|
|
76 |
# files
|
77 |
for i, file in enumerate(opts.files):
|
78 |
logging.info("Importing file %d: %s", i, file.name)
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
81 |
# return
|
82 |
return 0
|
83 |
|
|
|
35 |
return pairs
|
36 |
|
37 |
|
38 |
+
def import_file(
|
39 |
+
file: T.TextIO,
|
40 |
+
collection: lib.Collection,
|
41 |
+
encoding_function: T.Callable,
|
42 |
+
max_output_tokens: int = lib.ENCODING_OUTPUT_LIMIT,
|
43 |
+
) -> tuple[int, int]:
|
44 |
"""Import a markdown file to a database collection."""
|
45 |
+
n_chars = 0
|
46 |
+
n_tokens = 0
|
47 |
text = file.read()
|
48 |
filename = file.name
|
49 |
segments_h1 = segment_text(text=text, pattern=r"^# (.+)")
|
|
|
51 |
logging.debug('\t- H1: "%s" (%d)', h1, len(h1_text))
|
52 |
segments_h2 = segment_text(text=h1_text, pattern=r"^## (.+)")
|
53 |
for h2, content in segments_h2:
|
54 |
+
content_chars = len(content)
|
55 |
+
content_tokens = len(encoding_function(content))
|
56 |
+
logging.debug('\t\t- H2: "%s" (%d)', h2, content_chars)
|
57 |
id_ = f"{filename} # {h1} ## {h2}" # unique doc id
|
58 |
document = f"# {h1}\n\n## {h2}\n\n{content.strip()}"
|
59 |
metadata = {"filename": filename, "h1": h1, "h2": h2}
|
60 |
+
assert (
|
61 |
+
content_tokens < max_output_tokens
|
62 |
+
), f"Content is too long ({content_tokens}): #{h1} ##{h2}"
|
63 |
collection.add(ids=id_, documents=document, metadatas=metadata)
|
64 |
+
n_tokens += content_tokens
|
65 |
+
n_chars += content_chars
|
66 |
+
return n_chars, n_tokens
|
67 |
|
68 |
|
69 |
def main(args: list[str] | None = None) -> int:
|
|
|
75 |
logging.info("Database path: %s", database_path)
|
76 |
client = lib.get_database_client(path=database_path)
|
77 |
logging.info("- Reseting database client: %s", client.reset())
|
78 |
+
# encoding
|
79 |
+
encoding_function = lib.get_encoding_function()
|
80 |
+
logging.info("Encoding function: %s", encoding_function)
|
81 |
# embedding
|
82 |
embedding_function = lib.get_embedding_function()
|
83 |
logging.info("Embedding function: %s", embedding_function)
|
|
|
90 |
# files
|
91 |
for i, file in enumerate(opts.files):
|
92 |
logging.info("Importing file %d: %s", i, file.name)
|
93 |
+
n_chars, n_tokens = import_file(
|
94 |
+
file=file, collection=collection, encoding_function=encoding_function
|
95 |
+
)
|
96 |
+
logging.info(
|
97 |
+
"- Docs imported from file %s: %d chars | %d tokens", i, n_chars, n_tokens
|
98 |
+
)
|
99 |
# return
|
100 |
return 0
|
101 |
|
lib.py
CHANGED
@@ -5,6 +5,7 @@
|
|
5 |
|
6 |
__import__("pysqlite3")
|
7 |
|
|
|
8 |
import os
|
9 |
import sys
|
10 |
|
@@ -12,6 +13,8 @@ import sys
|
|
12 |
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
|
13 |
|
14 |
import chromadb
|
|
|
|
|
15 |
from chromadb.utils import embedding_functions
|
16 |
|
17 |
# %% CONFIGS
|
@@ -20,7 +23,13 @@ DATABASE_COLLECTION = "resume"
|
|
20 |
DATABASE_PATH = "database"
|
21 |
|
22 |
EMBEDDING_MODEL = "text-embedding-ada-002"
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
|
26 |
|
@@ -31,20 +40,33 @@ Collection = chromadb.Collection
|
|
31 |
# %% FUNCTIONS
|
32 |
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
def get_database_client(path: str) -> chromadb.API:
    """Return a persistent Chroma DB client for the given path.

    The client is created with reset allowed and anonymized telemetry off.
    """
    options = {"allow_reset": True, "anonymized_telemetry": False}
    settings = chromadb.Settings(**options)
    return chromadb.PersistentClient(path=path, settings=settings)
|
41 |
|
42 |
|
|
|
|
|
|
|
|
|
|
|
43 |
def get_embedding_function(
    model_name: str = EMBEDDING_MODEL, api_key: str = OPENAI_API_KEY
) -> embedding_functions.EmbeddingFunction:
    """Build the OpenAI embedding function used by Chroma DB collections."""
    embedding_function = embedding_functions.OpenAIEmbeddingFunction(
        api_key=api_key,
        model_name=model_name,
    )
    return embedding_function
|
|
|
5 |
|
6 |
__import__("pysqlite3")
|
7 |
|
8 |
+
import functools
|
9 |
import os
|
10 |
import sys
|
11 |
|
|
|
13 |
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
|
14 |
|
15 |
import chromadb
|
16 |
+
import openai
|
17 |
+
import tiktoken
|
18 |
from chromadb.utils import embedding_functions
|
19 |
|
20 |
# %% CONFIGS
|
|
|
23 |
DATABASE_PATH = "database"
|
24 |
|
25 |
EMBEDDING_MODEL = "text-embedding-ada-002"
|
26 |
+
|
27 |
+
ENCODING_NAME = "cl100k_base"
|
28 |
+
ENCODING_OUTPUT_LIMIT = 8191
|
29 |
+
|
30 |
+
MODEL_NAME = "gpt-3.5-turbo-16k"
|
31 |
+
MODEL_INPUT_LIMIT = 16_385
|
32 |
+
MODEL_TEMPERATURE = 0.9
|
33 |
|
34 |
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
|
35 |
|
|
|
40 |
# %% FUNCTIONS
|
41 |
|
42 |
|
43 |
+
def get_language_model(
    model: str = MODEL_NAME,
    api_key: str = OPENAI_API_KEY,
    temperature: float = MODEL_TEMPERATURE,
) -> functools.partial:
    """Get a callable that creates OpenAI chat completions.

    Args:
        model: name of the OpenAI chat model bound to every call.
        api_key: OpenAI API key; it is set on the `openai` module itself.
        temperature: temperature bound to every call.

    Returns:
        `openai.ChatCompletion.create` with `model` and `temperature`
        already bound; the other arguments (e.g. `messages`) are given
        when the returned callable is called.
    """
    openai.api_key = api_key  # configure the API key globally
    return functools.partial(
        openai.ChatCompletion.create, model=model, temperature=temperature
    )
|
53 |
+
|
54 |
+
|
55 |
def get_database_client(path: str) -> chromadb.API:
    """Create a persistent Chroma DB client for `path`.

    Resetting the database is allowed and anonymized telemetry is disabled.
    """
    return chromadb.PersistentClient(
        path=path,
        settings=chromadb.Settings(
            anonymized_telemetry=False,
            allow_reset=True,
        ),
    )
|
59 |
|
60 |
|
61 |
+
def get_encoding_function(encoding_name: str = ENCODING_NAME) -> tiktoken.Encoding:
    """Look up a tiktoken encoding by name and return its `encode` method.

    Note that the value returned is the bound `encode` method of the
    encoding, not the encoding object itself.
    """
    encoding = tiktoken.get_encoding(encoding_name=encoding_name)
    return encoding.encode
|
64 |
+
|
65 |
+
|
66 |
def get_embedding_function(
    model_name: str = EMBEDDING_MODEL, api_key: str = OPENAI_API_KEY
) -> embedding_functions.EmbeddingFunction:
    """Create the OpenAI embedding function for Chroma DB collections."""
    openai_embedder = embedding_functions.OpenAIEmbeddingFunction
    return openai_embedder(api_key=api_key, model_name=model_name)
|