ilanaliouchouche committed on
Commit
ac1eff7
1 Parent(s): 2270397

V1 of my Assistant powered by RAG

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ chromadb/**/*.bin filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from typing import Generator
from typing import List
from typing import Optional

import cohere
import gradio as gr
from langchain.schema.document import Document
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
9
+
10
+
11
class HFSpaceChatBot:
    """
    A chatbot powered by Retrieval Augmented Generation (RAG) aimed
    to be deployed on the Hugging Face Space platform.

    For every user message the most similar documents are fetched from
    a Chroma vector database, merged with the system prompt into a
    single prompt, and streamed through the cohere chat API.
    """

    def __init__(self,
                 embedding_model_path: str,
                 vector_database_path: str,
                 top_k: int = 10,
                 embedding_model_name: Optional[str] = None,
                 api_key: Optional[str] = None,
                 device: Optional[str] = None,
                 system_prompt: str = "Answer the user's question",
                 **kwargs) -> None:
        """
        Constructor for the HFSpaceChatBot class.

        Args:
            embedding_model_path (str): The path to the embedding model
                (used as the cache folder of the embedding model).
            vector_database_path (str): The path to the vector database.
            top_k (int): The number of top documents to retrieve.
            embedding_model_name (str): The name of the embedding model.
                Defaults to the ``EMBEDDING_MODEL`` environment variable.
            api_key (str): The API key for the cohere API.
                Defaults to the ``CO_API_KEY`` environment variable.
            device (str): The device to run the model on.
                Defaults to the ``DEVICE`` environment variable.
            system_prompt (str): The system prompt for the chatbot.
            **kwargs: Additional keyword arguments (for the cohere API)
        """

        # Environment fallbacks are resolved on every call instead of in
        # the signature, where they would be frozen once at definition
        # time and miss variables set afterwards.
        if embedding_model_name is None:
            embedding_model_name = os.getenv("EMBEDDING_MODEL")
        if api_key is None:
            api_key = os.getenv("CO_API_KEY")
        if device is None:
            device = os.getenv("DEVICE")

        # Only used when `_fetch_response` is called without a history
        # argument (i.e. outside of the Gradio chat interface).
        self.chat_history = []
        self.cclient = cohere.Client(api_key=api_key)

        self.embedding_model = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True},
            cache_folder=embedding_model_path
        )

        self.vector_database = Chroma(
            persist_directory=vector_database_path,
            embedding_function=self.embedding_model
        )

        self.top_k = top_k

        self.system_prompt = system_prompt

        self.model_params = kwargs

    def _get_relevant_information(self,
                                  message: str) -> "List[Document]":
        """
        Get the documents of the vector database that are the most
        similar to the message.

        Args:
            message (str): The message to search for.

        Returns:
            List[Document]: A list of at most `top_k` relevant documents.
        """

        return self.vector_database.similarity_search(message, self.top_k)

    @staticmethod
    def _to_cohere_history(history: Optional[list]) -> List[dict]:
        """
        Convert a Gradio chat history to the format sent to cohere.

        Both Gradio history layouts are supported: a list of
        ``[user_message, bot_message]`` pairs and a list of
        ``{"role": ..., "content": ...}`` dictionaries. Entries whose
        content is not a string (e.g. a missing answer) are skipped.

        Args:
            history (list): The history given by Gradio (may be None).

        Returns:
            List[dict]: Entries with the keys "role" and "text".
        """

        roles = {"user": "USER", "assistant": "CHATBOT"}
        converted = []
        for turn in history or []:
            if isinstance(turn, dict):
                entries = [(roles.get(turn.get("role")),
                            turn.get("content"))]
            else:
                entries = zip(("USER", "CHATBOT"), turn)
            for role, text in entries:
                if role is not None and isinstance(text, str):
                    converted.append({"role": role, "text": text})
        return converted

    def _build_prompt(self,
                      message: str,
                      relevant_information: str) -> str:
        """
        Build the prompt sent to the cohere API.

        Args:
            message (str): The message of the user.
            relevant_information (str): The retrieved context.

        Returns:
            str: The system prompt, the context and the question.
        """

        # Adjacent literals (instead of backslash continuations inside
        # one literal) keep the source indentation out of the prompt.
        return (
            f"{self.system_prompt}\n"
            "With the help of the following context:\n"
            f"{relevant_information}\n"
            "Answer the following question:\n"
            f"{message}"
        )

    def _fetch_response(self,
                        message: str,
                        *args) -> Generator[str, None, None]:
        """
        Fetch the response from the cohere API.

        Args:
            message (str): The message of the user.
            *args: Extra positional arguments given by the caller. When
                present, the first one is treated as the chat history of
                the current session (this is what the Gradio chat
                interface passes) and is used instead of the history
                stored on the instance, so that sessions do not share
                their conversations.

        Returns:
            Generator[str, None, None]: A generator yielding the text
                generated so far after each received token.
        """

        session_history = args[0] if args else None
        use_instance_history = session_history is None
        if use_instance_history:
            chat_history = self.chat_history
        else:
            chat_history = self._to_cohere_history(session_history)

        docs = self._get_relevant_information(message)

        relevant_information = "\n".join(
            doc.page_content
            for doc in docs)

        final_message = self._build_prompt(message, relevant_information)

        response = self.cclient.chat_stream(
            message=final_message,
            chat_history=chat_history,
            **self.model_params
        )

        # The accumulated text (not the single token) is yielded because
        # each yielded value replaces the previously displayed answer.
        current_text = ""
        for event in response:
            if event.event_type == "text-generation":
                current_text += event.text
                yield current_text

        if use_instance_history:
            # Only the raw question is stored, not the augmented prompt.
            self.chat_history.append({
                "role": "USER",
                "text": message
            })

            self.chat_history.append({
                "role": "CHATBOT",
                "text": current_text
            })

    def launch(self,
               title: str,
               description: str) -> None:
        """
        Launch the chat interface.

        Args:
            title (str): The title of the chat interface.
            description (str): The description of the chat interface.
        """

        gr.ChatInterface(
            fn=self._fetch_response,
            title=title,
            description=description
        ).launch()
136
+
137
+
138
# Persona given to the model. Written as adjacent literals with explicit
# single spaces: stripping the newlines of a triple-quoted string either
# glues the last word of a line to the first word of the next one or
# leaves the source indentation inside the prompt.
SYSTEM_PROMPT = (
    "You are now assuming the role of the personal assistant "
    "of Ilan ALIOUCHOUCHE, a French Computer Science student. "
    "Your task is to assist users by answering their "
    "questions about Ilan. You have access to comprehensive "
    "details about Ilan's education, skills, professional "
    "experience, and interests."
)

# Title shown at the top of the chat interface.
TITLE = "🤖 Ilan's Personal Agent 🤖"

# Description shown under the title.
DESCRIPTION = (
    "\n"
    "You can ask my assistant (almost) anything about me! :D\n"
    "\n"
    "You are currently using the Hugging Face Space version 🤗. "
    "A Docker image 🐳 for local use, utilizing a GGUF model is also "
    "available [here]"
    "(https://github.com/ilanaliouchouche/my-ai-cv/pkgs/container/my-cv)\n"
    "\n"
)


def main() -> None:
    """
    Build the chatbot from the local resources and launch its interface.

    The embedding model cache ("model") and the vector database
    ("chromadb") are both looked up in the current working directory.
    """

    chatbot = HFSpaceChatBot(
        embedding_model_path=os.path.join(os.getcwd(), "model"),
        vector_database_path=os.path.join(os.getcwd(), "chromadb"),
        system_prompt=SYSTEM_PROMPT,
        # Forwarded to the cohere API through **kwargs.
        temperature=0.4
    )

    chatbot.launch(
        title=TITLE,
        description=DESCRIPTION
    )


if __name__ == "__main__":
    main()
chromadb/c94c08ce-73f6-4e33-9eb4-4e32e42ae999/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95042e844cfb77b20e578cf65635282a99d7c4dd20e589ac062f38bc389f8e58
3
+ size 4236000
chromadb/c94c08ce-73f6-4e33-9eb4-4e32e42ae999/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcc596bc1909f7cc610d5839236c90513b4fbad06776c253fa1b21bfd712e940
3
+ size 100
chromadb/c94c08ce-73f6-4e33-9eb4-4e32e42ae999/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc19b1997119425765295aeab72d76faa6927d4f83985d328c26f20468d6cc76
3
+ size 4000
chromadb/c94c08ce-73f6-4e33-9eb4-4e32e42ae999/link_lists.bin ADDED
File without changes
chromadb/chroma.sqlite3 ADDED
Binary file (365 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
+ altair==5.3.0
5
+ annotated-types==0.7.0
6
+ anyio==4.4.0
7
+ asgiref==3.8.1
8
+ attrs==23.2.0
9
+ backoff==2.2.1
10
+ bcrypt==4.1.3
11
+ boto3==1.34.122
12
+ botocore==1.34.122
13
+ build==1.2.1
14
+ cachetools==5.3.3
15
+ certifi==2024.6.2
16
+ charset-normalizer==3.3.2
17
+ chroma-hnswlib==0.7.3
18
+ chromadb==0.5.0
19
+ click==8.1.7
20
+ cohere==5.5.6
21
+ coloredlogs==15.0.1
22
+ contourpy==1.2.1
23
+ cycler==0.12.1
24
+ dataclasses-json==0.6.6
25
+ Deprecated==1.2.14
26
+ dnspython==2.6.1
27
+ email_validator==2.1.1
28
+ fastapi==0.111.0
29
+ fastapi-cli==0.0.4
30
+ fastavro==1.9.4
31
+ ffmpy==0.3.2
32
+ filelock==3.14.0
33
+ flatbuffers==24.3.25
34
+ fonttools==4.53.0
35
+ frozenlist==1.4.1
36
+ fsspec==2024.6.0
37
+ google-auth==2.30.0
38
+ googleapis-common-protos==1.63.1
39
+ grpcio==1.64.1
40
+ h11==0.14.0
41
+ httpcore==1.0.5
42
+ httptools==0.6.1
43
+ httpx==0.27.0
44
+ httpx-sse==0.4.0
45
+ huggingface-hub==0.23.3
46
+ humanfriendly==10.0
47
+ idna==3.7
48
+ importlib_resources==6.4.0
49
+ ipywidgets==8.1.3
50
+ Jinja2==3.1.4
51
+ jmespath==1.0.1
52
+ joblib==1.4.2
53
+ jsonpatch==1.33
54
+ jsonpointer==2.4
55
+ jsonschema==4.22.0
56
+ jsonschema-specifications==2023.12.1
57
+ jupyterlab_widgets==3.0.11
58
+ kiwisolver==1.4.5
59
+ kubernetes==30.1.0
60
+ langchain==0.2.3
61
+ langchain-chroma==0.1.1
62
+ langchain-community==0.2.4
63
+ langchain-core==0.2.5
64
+ langchain-huggingface==0.0.3
65
+ langchain-text-splitters==0.2.1
66
+ langsmith==0.1.75
67
+ markdown-it-py==3.0.0
68
+ MarkupSafe==2.1.5
69
+ marshmallow==3.21.3
70
+ matplotlib==3.9.0
71
+ mdurl==0.1.2
72
+ mmh3==4.1.0
73
+ monotonic==1.6
74
+ mpmath==1.3.0
75
+ multidict==6.0.5
76
+ mypy-extensions==1.0.0
77
+ networkx==3.3
78
+ numpy==1.26.4
79
+ oauthlib==3.2.2
80
+ onnxruntime==1.18.0
81
+ opentelemetry-api==1.25.0
82
+ opentelemetry-exporter-otlp-proto-common==1.25.0
83
+ opentelemetry-exporter-otlp-proto-grpc==1.25.0
84
+ opentelemetry-instrumentation==0.46b0
85
+ opentelemetry-instrumentation-asgi==0.46b0
86
+ opentelemetry-instrumentation-fastapi==0.46b0
87
+ opentelemetry-proto==1.25.0
88
+ opentelemetry-sdk==1.25.0
89
+ opentelemetry-semantic-conventions==0.46b0
90
+ opentelemetry-util-http==0.46b0
91
+ orjson==3.10.3
92
+ overrides==7.7.0
93
+ packaging==23.2
94
+ pandas==2.2.2
95
+ parameterized==0.9.0
96
+ pillow==10.3.0
97
+ posthog==3.5.0
98
+ protobuf==4.25.3
99
+ pyasn1==0.6.0
100
+ pyasn1_modules==0.4.0
101
+ pydantic==2.7.3
102
+ pydantic_core==2.18.4
103
+ pydub==0.25.1
104
+ pyparsing==3.1.2
105
+ PyPika==0.48.9
106
+ pyproject_hooks==1.1.0
107
+ python-dotenv==1.0.1
108
+ python-multipart==0.0.9
109
+ pytz==2024.1
110
+ PyYAML==6.0.1
111
+ referencing==0.35.1
112
+ regex==2024.5.15
113
+ requests==2.32.3
114
+ requests-oauthlib==2.0.0
115
+ rich==13.7.1
116
+ rpds-py==0.18.1
117
+ rsa==4.9
118
+ ruff==0.4.8
119
+ s3transfer==0.10.1
120
+ safetensors==0.4.3
121
+ scikit-learn==1.5.0
122
+ scipy==1.13.1
123
+ semantic-version==2.10.0
124
+ sentence-transformers==3.0.1
125
+ shellingham==1.5.4
126
+ sniffio==1.3.1
127
+ SQLAlchemy==2.0.30
128
+ starlette==0.37.2
129
+ sympy==1.12.1
130
+ tenacity==8.3.0
131
+ threadpoolctl==3.5.0
132
+ tokenizers==0.19.1
133
+ tomlkit==0.12.0
134
+ toolz==0.12.1
135
+ torch==2.3.1
136
+ tqdm==4.66.4
137
+ transformers==4.41.2
138
+ typer==0.12.3
139
+ types-requests==2.32.0.20240602
140
+ typing-inspect==0.9.0
141
+ tzdata==2024.1
142
+ ujson==5.10.0
143
+ urllib3==2.2.1
144
+ uvicorn==0.30.1
145
+ uvloop==0.19.0
146
+ watchfiles==0.22.0
147
+ websocket-client==1.8.0
148
+ websockets==11.0.3
149
+ widgetsnbextension==4.0.11
150
+ wrapt==1.16.0
151
+ yarl==1.9.4