Médéric Hurier (Fmind) committed on
Commit
ed67987
·
1 Parent(s): 5c68cc7

First release

Browse files
Files changed (3) hide show
  1. app.py +70 -16
  2. database.py +26 -8
  3. lib.py +29 -7
app.py CHANGED
@@ -5,7 +5,6 @@
5
  import logging
6
 
7
  import gradio as gr
8
- import tiktoken
9
 
10
  import lib
11
 
@@ -18,34 +17,90 @@ logging.basicConfig(
18
 
19
  # %% CONFIGS
20
 
 
 
21
  THEME = "glass"
22
  TITLE = "Fmind Chatbot"
 
 
 
 
 
 
 
 
 
 
23
 
 
24
  CLIENT = lib.get_database_client(path=lib.DATABASE_PATH)
25
- ENCODING = tiktoken.get_encoding(encoding_name=lib.EMBEDDING_TOKENIZER)
26
- FUNCTION = lib.get_embedding_function()
27
  COLLECTION = CLIENT.get_collection(
28
  name=lib.DATABASE_COLLECTION,
29
- embedding_function=FUNCTION,
30
  )
31
 
32
- EXAMPLES = [
33
- "Who is Médéric Hurier (Fmind)?",
34
- "Is Fmind open to new opportunities?",
35
- "What is Médéric's most recent degree?",
36
- "What is Médéric's latest work experience?",
37
- "Is Médéric proficient in Python programming?",
38
- ]
 
 
 
 
 
39
 
40
  # %% FUNCTIONS
41
 
42
 
43
  def answer(message: str, history: list[str]) -> str:
44
  """Answer questions about my resume."""
45
- tokens = ENCODING.encode(message)
46
- print("History:", len(history))
47
- print("Tokens:", len(tokens))
48
- return message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
  # %% INTERFACES
@@ -60,7 +115,6 @@ interface = gr.ChatInterface(
60
  undo_btn=None,
61
  )
62
 
63
- # %% ENTRYPOINTS
64
 
65
  if __name__ == "__main__":
66
  interface.launch()
 
5
  import logging
6
 
7
  import gradio as gr
 
8
 
9
  import lib
10
 
 
17
 
18
  # %% CONFIGS
19
 
20
+ # %% - Frontend
21
+
22
  THEME = "glass"
23
  TITLE = "Fmind Chatbot"
24
+ EXAMPLES = [
25
+ "Who is Médéric Hurier (Fmind)?",
26
+ "Is Fmind open to new opportunities?",
27
+ "Can you share details about Médéric PhD?",
28
+ "Elaborate on Médéric current work position",
29
+ "Describe his proficiency with Python programming",
30
+ "What is the answer to life, the universe, and everything?",
31
+ ]
32
+
33
+ # %% - Backend
34
 
35
+ MODEL = lib.get_language_model()
36
  CLIENT = lib.get_database_client(path=lib.DATABASE_PATH)
37
+ ENCODING = lib.get_encoding_function()
38
+ EMBEDDING = lib.get_embedding_function()
39
  COLLECTION = CLIENT.get_collection(
40
  name=lib.DATABASE_COLLECTION,
41
+ embedding_function=EMBEDDING,
42
  )
43
 
44
+ # %% - Answer
45
+
46
+ PROMPT_CONTEXT = """
47
+ You are Fmind Chatbot, specialized in providing information regarding Médéric Hurier's (known as Fmind) professional background.
48
+ Médéric is an MLOps engineer based in Luxembourg. He is currently working at Decathlon. His calendar is booked until the conclusion of 2024.
49
+ Your responses should be succinct and maintain a professional tone. If inquiries deviate from Médéric's professional sphere, courteously decline to engage.
50
+
51
+ You may find more information about Médéric below (markdown format):
52
+ """
53
+ PROMPT_MAX_TOKENS = lib.MODEL_INPUT_LIMIT
54
+ QUERY_MAX_DISTANCE = 0.4
55
+ QUERY_N_RESULTS = 20
56
 
57
  # %% FUNCTIONS
58
 
59
 
60
  def answer(message: str, history: list[str]) -> str:
61
  """Answer questions about my resume."""
62
+ # counters
63
+ n_tokens = 0
64
+ # messages
65
+ messages = []
66
+ # - context
67
+ n_tokens += len(ENCODING(PROMPT_CONTEXT))
68
+ messages += [{"role": "system", "content": PROMPT_CONTEXT}]
69
+ # - history
70
+ for user_content, assistant_content in history:
71
+ n_tokens += len(ENCODING(user_content))
72
+ n_tokens += len(ENCODING(assistant_content))
73
+ messages += [{"role": "user", "content": user_content}]
74
+ messages += [{"role": "assistant", "content": assistant_content}]
75
+ # - message
76
+ n_tokens += len(ENCODING(message))
77
+ messages += [{"role": "user", "content": message}]
78
+ # database
79
+ results = COLLECTION.query(query_texts=message, n_results=QUERY_N_RESULTS)
80
+ logging.info("Results: %s", results)
81
+ distances = results["distances"][0]
82
+ documents = results["documents"][0]
83
+ for distance, document in zip(distances, documents):
84
+ # - distance
85
+ logging.debug("Doc distance: %f", distance)
86
+ if distance > QUERY_MAX_DISTANCE:
87
+ break
88
+ # - document
89
+ n_document_tokens = len(ENCODING(document))
90
+ logging.debug("Doc tokens: %f", n_document_tokens)
91
+ if (n_tokens + n_document_tokens) >= PROMPT_MAX_TOKENS:
92
+ break
93
+ n_tokens += n_document_tokens
94
+ messages[0]["content"] += document
95
+ # response
96
+ logging.info("Tokens: %d", n_tokens)
97
+ logging.info("Messages: %s", messages)
98
+ api_response = MODEL(messages=messages)
99
+ logging.info("Response: %s", api_response.to_dict_recursive())
100
+ # content
101
+ content = api_response["choices"][0]["message"]["content"]
102
+ # return
103
+ return content
104
 
105
 
106
  # %% INTERFACES
 
115
  undo_btn=None,
116
  )
117
 
 
118
 
119
  if __name__ == "__main__":
120
  interface.launch()
database.py CHANGED
@@ -35,9 +35,15 @@ def segment_text(text: str, pattern: str) -> T.Iterator[tuple[str, str]]:
35
  return pairs
36
 
37
 
38
- def import_file(file: T.TextIO, collection: lib.Collection) -> int:
 
 
 
 
 
39
  """Import a markdown file to a database collection."""
40
- imported = 0
 
41
  text = file.read()
42
  filename = file.name
43
  segments_h1 = segment_text(text=text, pattern=r"^# (.+)")
@@ -45,14 +51,19 @@ def import_file(file: T.TextIO, collection: lib.Collection) -> int:
45
  logging.debug('\t- H1: "%s" (%d)', h1, len(h1_text))
46
  segments_h2 = segment_text(text=h1_text, pattern=r"^## (.+)")
47
  for h2, content in segments_h2:
48
- logging.debug('\t\t- H2: "%s" (%d)', h2, len(content))
 
 
49
  id_ = f"{filename} # {h1} ## {h2}" # unique doc id
50
  document = f"# {h1}\n\n## {h2}\n\n{content.strip()}"
51
  metadata = {"filename": filename, "h1": h1, "h2": h2}
52
- assert len(content) < 8000, f"Content is too long: #{h1} ##{h2}"
 
 
53
  collection.add(ids=id_, documents=document, metadatas=metadata)
54
- imported += len(document)
55
- return imported
 
56
 
57
 
58
  def main(args: list[str] | None = None) -> int:
@@ -64,6 +75,9 @@ def main(args: list[str] | None = None) -> int:
64
  logging.info("Database path: %s", database_path)
65
  client = lib.get_database_client(path=database_path)
66
  logging.info("- Reseting database client: %s", client.reset())
 
 
 
67
  # embedding
68
  embedding_function = lib.get_embedding_function()
69
  logging.info("Embedding function: %s", embedding_function)
@@ -76,8 +90,12 @@ def main(args: list[str] | None = None) -> int:
76
  # files
77
  for i, file in enumerate(opts.files):
78
  logging.info("Importing file %d: %s", i, file.name)
79
- imported = import_file(file=file, collection=collection)
80
- logging.info("- Docs imported from file %s: %d chars", i, imported)
 
 
 
 
81
  # return
82
  return 0
83
 
 
35
  return pairs
36
 
37
 
38
+ def import_file(
39
+ file: T.TextIO,
40
+ collection: lib.Collection,
41
+ encoding_function: T.Callable,
42
+ max_output_tokens: int = lib.ENCODING_OUTPUT_LIMIT,
43
+ ) -> tuple[int, int]:
44
  """Import a markdown file to a database collection."""
45
+ n_chars = 0
46
+ n_tokens = 0
47
  text = file.read()
48
  filename = file.name
49
  segments_h1 = segment_text(text=text, pattern=r"^# (.+)")
 
51
  logging.debug('\t- H1: "%s" (%d)', h1, len(h1_text))
52
  segments_h2 = segment_text(text=h1_text, pattern=r"^## (.+)")
53
  for h2, content in segments_h2:
54
+ content_chars = len(content)
55
+ content_tokens = len(encoding_function(content))
56
+ logging.debug('\t\t- H2: "%s" (%d)', h2, content_chars)
57
  id_ = f"{filename} # {h1} ## {h2}" # unique doc id
58
  document = f"# {h1}\n\n## {h2}\n\n{content.strip()}"
59
  metadata = {"filename": filename, "h1": h1, "h2": h2}
60
+ assert (
61
+ content_tokens < max_output_tokens
62
+ ), f"Content is too long ({content_tokens}): #{h1} ##{h2}"
63
  collection.add(ids=id_, documents=document, metadatas=metadata)
64
+ n_tokens += content_tokens
65
+ n_chars += content_chars
66
+ return n_chars, n_tokens
67
 
68
 
69
  def main(args: list[str] | None = None) -> int:
 
75
  logging.info("Database path: %s", database_path)
76
  client = lib.get_database_client(path=database_path)
77
  logging.info("- Reseting database client: %s", client.reset())
78
+ # encoding
79
+ encoding_function = lib.get_encoding_function()
80
+ logging.info("Encoding function: %s", encoding_function)
81
  # embedding
82
  embedding_function = lib.get_embedding_function()
83
  logging.info("Embedding function: %s", embedding_function)
 
90
  # files
91
  for i, file in enumerate(opts.files):
92
  logging.info("Importing file %d: %s", i, file.name)
93
+ n_chars, n_tokens = import_file(
94
+ file=file, collection=collection, encoding_function=encoding_function
95
+ )
96
+ logging.info(
97
+ "- Docs imported from file %s: %d chars | %d tokens", i, n_chars, n_tokens
98
+ )
99
  # return
100
  return 0
101
 
lib.py CHANGED
@@ -5,6 +5,7 @@
5
 
6
  __import__("pysqlite3")
7
 
 
8
  import os
9
  import sys
10
 
@@ -12,6 +13,8 @@ import sys
12
  sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
13
 
14
  import chromadb
 
 
15
  from chromadb.utils import embedding_functions
16
 
17
  # %% CONFIGS
@@ -20,7 +23,13 @@ DATABASE_COLLECTION = "resume"
20
  DATABASE_PATH = "database"
21
 
22
  EMBEDDING_MODEL = "text-embedding-ada-002"
23
- EMBEDDING_TOKENIZER = "cl100k_base"
 
 
 
 
 
 
24
 
25
  OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
26
 
@@ -31,20 +40,33 @@ Collection = chromadb.Collection
31
  # %% FUNCTIONS
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def get_database_client(path: str) -> chromadb.API:
35
  """Get a persistent client to the Chroma DB."""
36
- settings = chromadb.Settings(
37
- allow_reset=True,
38
- anonymized_telemetry=False,
39
- )
40
  return chromadb.PersistentClient(path=path, settings=settings)
41
 
42
 
 
 
 
 
 
43
  def get_embedding_function(
44
  model_name: str = EMBEDDING_MODEL, api_key: str = OPENAI_API_KEY
45
  ) -> embedding_functions.EmbeddingFunction:
46
  """Get the embedding function for Chroma DB collections."""
47
  return embedding_functions.OpenAIEmbeddingFunction(
48
- model_name=model_name,
49
- api_key=api_key,
50
  )
 
5
 
6
  __import__("pysqlite3")
7
 
8
+ import functools
9
  import os
10
  import sys
11
 
 
13
  sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
14
 
15
  import chromadb
16
+ import openai
17
+ import tiktoken
18
  from chromadb.utils import embedding_functions
19
 
20
  # %% CONFIGS
 
23
  DATABASE_PATH = "database"
24
 
25
  EMBEDDING_MODEL = "text-embedding-ada-002"
26
+
27
+ ENCODING_NAME = "cl100k_base"
28
+ ENCODING_OUTPUT_LIMIT = 8191
29
+
30
+ MODEL_NAME = "gpt-3.5-turbo-16k"
31
+ MODEL_INPUT_LIMIT = 16_385
32
+ MODEL_TEMPERATURE = 0.9
33
 
34
  OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
35
 
 
40
  # %% FUNCTIONS
41
 
42
 
43
def get_language_model(
    model: str = MODEL_NAME,
    api_key: str = OPENAI_API_KEY,
    temperature: float = MODEL_TEMPERATURE,
) -> functools.partial:
    """Get an OpenAI ChatCompletion model.

    Args:
        model: name of the chat model bound to every call.
        api_key: OpenAI API key; it is set globally on the `openai` module.
        temperature: temperature value bound to every call.

    Returns:
        A partial of `openai.ChatCompletion.create` with `model` and
        `temperature` already bound; the remaining arguments (e.g. `messages`)
        are supplied by the caller.
    """
    openai.api_key = api_key  # configure the API key globally
    return functools.partial(
        openai.ChatCompletion.create, model=model, temperature=temperature
    )
53
+
54
+
55
  def get_database_client(path: str) -> chromadb.API:
56
  """Get a persistent client to the Chroma DB."""
57
+ settings = chromadb.Settings(allow_reset=True, anonymized_telemetry=False)
 
 
 
58
  return chromadb.PersistentClient(path=path, settings=settings)
59
 
60
 
61
def get_encoding_function(encoding_name: str = ENCODING_NAME) -> tiktoken.Encoding:
    """Get the encoding function for OpenAI models.

    Looks up the tiktoken encoding called `encoding_name` and hands back its
    bound `encode` method.
    """
    # NOTE(review): the value returned is the `encode` method, not the Encoding
    # object named in the return annotation - confirm and adjust the annotation.
    encoding = tiktoken.get_encoding(encoding_name=encoding_name)
    return encoding.encode
64
+
65
+
66
  def get_embedding_function(
67
  model_name: str = EMBEDDING_MODEL, api_key: str = OPENAI_API_KEY
68
  ) -> embedding_functions.EmbeddingFunction:
69
  """Get the embedding function for Chroma DB collections."""
70
  return embedding_functions.OpenAIEmbeddingFunction(
71
+ model_name=model_name, api_key=api_key
 
72
  )