teddyllm commited on
Commit
4a91290
1 Parent(s): 7cbc1c8

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +111 -0
  2. documents.json +0 -0
  3. faiss_index.bin +3 -0
  4. requirements.txt +100 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import faiss
3
+ import numpy as np
4
+ import json
5
+ import gradio as gr
6
+ from openai import OpenAI
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ # Step 1: Set up OpenAI API key
10
+ openai_api_key = os.environ.get("OPENAI_API_KEY", "")
11
+
12
+ client = OpenAI(api_key=openai_api_key)
13
+
14
+ # Step 2: Load the pre-trained FAISS index and SentenceTransformer model
15
+ index = faiss.read_index("faiss_index.bin")
16
+ model = SentenceTransformer('all-MiniLM-L6-v2')
17
+
18
+ def load_documents(docs_path):
19
+ with open(docs_path, 'r', encoding='utf-8') as file:
20
+ return json.load(file)
21
+
22
+ # Specify the path to your JSON file
23
+ docs_path = 'documents.json'
24
+ documents = load_documents(docs_path)
25
+ dimension = 1536
26
+
27
+ def get_embeddings(text):
28
+ response = client.embeddings.create(
29
+ model="text-embedding-3-small",
30
+ input = [text]
31
+ )
32
+ embedding = response.data[0].embedding
33
+ return np.array(embedding, dtype='float32')
34
+
35
+ # Step 3: Function to search FAISS index
36
+ def search_index(query, k=3):
37
+ # Convert query to an embedding
38
+ query_vector = get_embeddings(query).reshape(1, -1).astype('float32')
39
+
40
+ # Check if the index is not empty before searching
41
+ if index.ntotal == 0:
42
+ return "No documents in the index."
43
+
44
+ # Search the FAISS index for the nearest neighbors
45
+ distances, indices = index.search(query_vector, k)
46
+
47
+ # Retrieve the top matching documents
48
+ results = [documents[i] for i in indices[0] if i != -1]
49
+
50
+ if results:
51
+ return "\n\n".join(results)
52
+ else:
53
+ return "No relevant documents found."
54
+
55
+ # Step 4: Function to generate a response using OpenAI's GPT
56
+ def generate_response(context, user_input):
57
+ prompt = f"{context}\n\nUser: {user_input}\nAssistant:"
58
+
59
+ response = client.chat.completions.create(
60
+ model="gpt-4o-mini",
61
+ messages=[{"role": "system", "content": "You are a helpful assistant."},
62
+ {"role": "user", "content": prompt}],
63
+ # stream=True,
64
+ )
65
+
66
+ # for chunk in stream:
67
+ # if chunk.choices[0].delta.content is not None:
68
+ # print(chunk.choices[0].delta.content, end="")
69
+ return response.choices[0].message.content
70
+
71
+ # Step 5: Gradio chatbot function
72
+ def chatbot_interface(user_input, chat_history):
73
+ # Step 5.1: Retrieve context using FAISS
74
+ context = search_index(user_input)
75
+
76
+ # Step 5.2: Generate a response using OpenAI GPT model
77
+ response = generate_response(context, user_input)
78
+
79
+ # Step 5.3: Update chat history
80
+ chat_history.append((user_input, response))
81
+ return chat_history, chat_history
82
+
83
+ def chat_gen(message, history):
84
+ history_openai_format = []
85
+ context = search_index(message)
86
+ prompt = f"{context}\n\nUser: {message}\nAssistant:"
87
+
88
+ response = client.chat.completions.create(
89
+ model="gpt-4o-mini",
90
+ messages=[{"role": "system", "content": "You are a helpful assistant."},
91
+ {"role": "user", "content": prompt}],
92
+ stream=True,
93
+ )
94
+ partial_message = ""
95
+ for chunk in response:
96
+ if chunk.choices[0].delta.content is not None:
97
+ partial_message = partial_message + chunk.choices[0].delta.content
98
+ yield partial_message
99
+
100
+ initial_msg = "Hello! I am DII assistant. You can ask me anything about DDI program. I am happy to assist you."
101
+ chatbot = gr.Chatbot(value = [[None, initial_msg]])
102
+ demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
103
+
104
+
105
+ try:
106
+ demo.launch(debug=True, share=False, show_api=False)
107
+ demo.close()
108
+ except Exception as e:
109
+ demo.close()
110
+ print(e)
111
+ raise e
documents.json ADDED
The diff for this file is too large to render. See raw diff
 
faiss_index.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddae79299503fe170726dd027d6d4cf5a7057ebdcd36a186d0a643f981a1c4b0
3
+ size 755757
requirements.txt ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.6.2.post1
4
+ appnope==0.1.4
5
+ asttokens==2.4.1
6
+ certifi==2024.8.30
7
+ charset-normalizer==3.4.0
8
+ click==8.1.7
9
+ comm==0.2.2
10
+ debugpy==1.8.8
11
+ decorator==5.1.1
12
+ distro==1.9.0
13
+ et_xmlfile==2.0.0
14
+ executing==2.1.0
15
+ faiss-cpu==1.9.0
16
+ fastapi==0.115.4
17
+ ffmpy==0.4.0
18
+ filelock==3.16.1
19
+ fsspec==2024.10.0
20
+ gradio==5.5.0
21
+ gradio_client==1.4.2
22
+ h11==0.14.0
23
+ httpcore==1.0.6
24
+ httpx==0.27.2
25
+ huggingface-hub==0.26.2
26
+ idna==3.10
27
+ ipykernel==6.29.5
28
+ ipython==8.29.0
29
+ jedi==0.19.2
30
+ Jinja2==3.1.4
31
+ jiter==0.7.0
32
+ joblib==1.4.2
33
+ jupyter_client==8.6.3
34
+ jupyter_core==5.7.2
35
+ lxml==5.3.0
36
+ markdown-it-py==3.0.0
37
+ MarkupSafe==2.1.5
38
+ matplotlib-inline==0.1.7
39
+ mdurl==0.1.2
40
+ mpmath==1.3.0
41
+ nest-asyncio==1.6.0
42
+ networkx==3.4.2
43
+ numpy==2.1.3
44
+ openai==1.54.3
45
+ openpyxl==3.1.5
46
+ orjson==3.10.11
47
+ packaging==24.2
48
+ pandas==2.2.3
49
+ parso==0.8.4
50
+ pexpect==4.9.0
51
+ pillow==11.0.0
52
+ platformdirs==4.3.6
53
+ prompt_toolkit==3.0.48
54
+ psutil==6.1.0
55
+ ptyprocess==0.7.0
56
+ pure_eval==0.2.3
57
+ pydantic==2.9.2
58
+ pydantic_core==2.23.4
59
+ pydub==0.25.1
60
+ Pygments==2.18.0
61
+ PyPDF2==3.0.1
62
+ pytesseract==0.3.13
63
+ python-dateutil==2.9.0.post0
64
+ python-docx==1.1.2
65
+ python-multipart==0.0.12
66
+ pytz==2024.2
67
+ PyYAML==6.0.2
68
+ pyzmq==26.2.0
69
+ regex==2024.11.6
70
+ requests==2.32.3
71
+ rich==13.9.4
72
+ ruff==0.7.3
73
+ safehttpx==0.1.1
74
+ safetensors==0.4.5
75
+ scikit-learn==1.5.2
76
+ scipy==1.14.1
77
+ semantic-version==2.10.0
78
+ sentence-transformers==3.3.0
79
+ setuptools==75.4.0
80
+ shellingham==1.5.4
81
+ six==1.16.0
82
+ sniffio==1.3.1
83
+ stack-data==0.6.3
84
+ starlette==0.41.2
85
+ sympy==1.13.1
86
+ threadpoolctl==3.5.0
87
+ tokenizers==0.20.3
88
+ tomlkit==0.12.0
89
+ torch==2.5.1
90
+ tornado==6.4.1
91
+ tqdm==4.67.0
92
+ traitlets==5.14.3
93
+ transformers==4.46.2
94
+ typer==0.13.0
95
+ typing_extensions==4.12.2
96
+ tzdata==2024.2
97
+ urllib3==2.2.3
98
+ uvicorn==0.32.0
99
+ wcwidth==0.2.13
100
+ websockets==12.0