Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,27 @@
|
|
|
|
1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from typing import List
|
|
|
|
|
3 |
|
4 |
from langchain.embeddings import CohereEmbeddings
|
5 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
@@ -17,13 +39,17 @@ from langchain.prompts.chat import (
|
|
17 |
from langchain.docstore.document import Document
|
18 |
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
|
19 |
from langsmith_config import setup_langsmith_config
|
20 |
-
import fireworks.client
|
21 |
-
import chainlit as cl
|
22 |
|
23 |
-
|
24 |
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
|
25 |
cohere_api_key = os.getenv("COHERE_API_KEY")
|
26 |
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
|
28 |
|
29 |
system_template = """Use the following pieces of context to answer the users question.
|
@@ -49,11 +75,32 @@ messages = [
|
|
49 |
prompt = ChatPromptTemplate.from_messages(messages)
|
50 |
chain_type_kwargs = {"prompt": prompt}
|
51 |
|
52 |
-
|
53 |
@cl.on_chat_start
|
54 |
async def on_chat_start():
|
|
|
55 |
files = None
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
# Wait for the user to upload a file
|
58 |
while files == None:
|
59 |
files = await cl.AskFileMessage(
|
@@ -80,7 +127,7 @@ async def on_chat_start():
|
|
80 |
metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]
|
81 |
|
82 |
# Create a Chroma vector store
|
83 |
-
embeddings = CohereEmbeddings(cohere_api_key=
|
84 |
|
85 |
docsearch = await cl.make_async(Chroma.from_texts)(
|
86 |
texts, embeddings, metadatas=metadatas
|
@@ -97,7 +144,7 @@ async def on_chat_start():
|
|
97 |
|
98 |
# Create a chain that uses the Chroma vector store
|
99 |
chain = ConversationalRetrievalChain.from_llm(
|
100 |
-
ChatFireworks(model="accounts/fireworks/models/llama-v2-
|
101 |
chain_type="stuff",
|
102 |
retriever=docsearch.as_retriever(),
|
103 |
memory=memory,
|
@@ -110,6 +157,122 @@ async def on_chat_start():
|
|
110 |
|
111 |
cl.user_session.set("chain", chain)
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
@cl.on_message
|
115 |
async def main(message: cl.Message):
|
@@ -136,4 +299,5 @@ async def main(message: cl.Message):
|
|
136 |
else:
|
137 |
answer += "\nNo sources found"
|
138 |
|
|
|
139 |
await cl.Message(content=answer, elements=text_elements).send()
|
|
|
1 |
+
import datetime
|
2 |
import os
|
3 |
+
import sqlite3
|
4 |
+
import websockets
|
5 |
+
import websocket
|
6 |
+
import asyncio
|
7 |
+
import sqlite3
|
8 |
+
import json
|
9 |
+
import requests
|
10 |
+
import asyncio
|
11 |
+
import time
|
12 |
+
import gradio as gr
|
13 |
+
import fireworks.client
|
14 |
+
import PySimpleGUI as sg
|
15 |
+
import openai
|
16 |
+
import fireworks.client
|
17 |
+
import chainlit as cl
|
18 |
+
from chainlit import make_async
|
19 |
+
from gradio_client import Client
|
20 |
+
from websockets.sync.client import connect
|
21 |
+
from tempfile import TemporaryDirectory
|
22 |
from typing import List
|
23 |
+
from chainlit.input_widget import Select, Switch, Slider
|
24 |
+
from chainlit import AskUserMessage, Message, on_chat_start
|
25 |
|
26 |
from langchain.embeddings import CohereEmbeddings
|
27 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
39 |
from langchain.docstore.document import Document
|
40 |
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
|
41 |
from langsmith_config import setup_langsmith_config
|
|
|
|
|
42 |
|
43 |
+
|
44 |
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
|
45 |
cohere_api_key = os.getenv("COHERE_API_KEY")
|
46 |
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
|
47 |
+
fireworks_api_key = os.getenv("FIREWORKS_API_KEY")
|
48 |
+
|
49 |
+
server_ports = []
|
50 |
+
client_ports = []
|
51 |
+
|
52 |
+
setup_langsmith_config()
|
53 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
|
54 |
|
55 |
system_template = """Use the following pieces of context to answer the users question.
|
|
|
75 |
prompt = ChatPromptTemplate.from_messages(messages)
|
76 |
chain_type_kwargs = {"prompt": prompt}
|
77 |
|
|
|
78 |
@cl.on_chat_start
|
79 |
async def on_chat_start():
|
80 |
+
|
81 |
files = None
|
82 |
|
83 |
+
settings = await cl.ChatSettings(
|
84 |
+
[
|
85 |
+
Slider(
|
86 |
+
id="websocketPort",
|
87 |
+
label="Websocket server port",
|
88 |
+
initial=False,
|
89 |
+
min=1000,
|
90 |
+
max=9999,
|
91 |
+
step=10,
|
92 |
+
),
|
93 |
+
Slider(
|
94 |
+
id="clientPort",
|
95 |
+
label="Websocket client port",
|
96 |
+
initial=False,
|
97 |
+
min=1000,
|
98 |
+
max=9999,
|
99 |
+
step=10,
|
100 |
+
),
|
101 |
+
],
|
102 |
+
).send()
|
103 |
+
|
104 |
# Wait for the user to upload a file
|
105 |
while files == None:
|
106 |
files = await cl.AskFileMessage(
|
|
|
127 |
metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]
|
128 |
|
129 |
# Create a Chroma vector store
|
130 |
+
embeddings = CohereEmbeddings(cohere_api_key="Ev0v9wwQPa90xDucdHTyFsllXGVHXouakUMObkNb")
|
131 |
|
132 |
docsearch = await cl.make_async(Chroma.from_texts)(
|
133 |
texts, embeddings, metadatas=metadatas
|
|
|
144 |
|
145 |
# Create a chain that uses the Chroma vector store
|
146 |
chain = ConversationalRetrievalChain.from_llm(
|
147 |
+
ChatFireworks(model="accounts/fireworks/models/llama-v2-70b-chat", model_kwargs={"temperature":0, "max_tokens":1500, "top_p":1.0}, streaming=True),
|
148 |
chain_type="stuff",
|
149 |
retriever=docsearch.as_retriever(),
|
150 |
memory=memory,
|
|
|
157 |
|
158 |
cl.user_session.set("chain", chain)
|
159 |
|
160 |
+
@cl.action_callback("server_button")
|
161 |
+
async def on_server_button(action):
|
162 |
+
websocketPort = settings["websocketPort"]
|
163 |
+
await start_websockets(websocketPort)
|
164 |
+
|
165 |
+
@cl.action_callback("client_button")
|
166 |
+
async def on_client_button(action):
|
167 |
+
clientPort = settings["clientPort"]
|
168 |
+
await start_client(clientPort)
|
169 |
+
|
170 |
+
@cl.on_settings_update
|
171 |
+
async def server_start(settings):
|
172 |
+
websocketPort = settings["websocketPort"]
|
173 |
+
clientPort = settings["clientPort"]
|
174 |
+
if websocketPort:
|
175 |
+
await start_websockets(websocketPort)
|
176 |
+
else:
|
177 |
+
print("Server port number wasn't provided.")
|
178 |
+
|
179 |
+
if clientPort:
|
180 |
+
await start_client(clientPort)
|
181 |
+
else:
|
182 |
+
print("Client port number wasn't provided.")
|
183 |
+
|
184 |
+
async def handleWebSocket(ws):
|
185 |
+
print('New connection')
|
186 |
+
instruction = "Hello! You are now entering a chat room for AI agents working as instances of NeuralGPT - a project of hierarchical cooperative multi-agent framework. Keep in mind that you are speaking with another chatbot. Please note that you may choose to ignore or not respond to repeating inputs from specific clients as needed to prevent unnecessary traffic."
|
187 |
+
greetings = {'instructions': instruction}
|
188 |
+
await ws.send(json.dumps(instruction))
|
189 |
+
while True:
|
190 |
+
loop = asyncio.get_event_loop()
|
191 |
+
message = await ws.recv()
|
192 |
+
print(f'Received message: {message}')
|
193 |
+
msg = "client: " + message
|
194 |
+
timestamp = datetime.datetime.now().isoformat()
|
195 |
+
sender = 'client'
|
196 |
+
db = sqlite3.connect('chat-hub.db')
|
197 |
+
db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
|
198 |
+
(sender, message, timestamp))
|
199 |
+
db.commit()
|
200 |
+
try:
|
201 |
+
response = await main(cl.Message(content=message))
|
202 |
+
serverResponse = "server response: " + response
|
203 |
+
print(serverResponse)
|
204 |
+
# Append the server response to the server_responses list
|
205 |
+
await ws.send(serverResponse)
|
206 |
+
serverSender = 'server'
|
207 |
+
db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
|
208 |
+
(serverSender, serverResponse, timestamp))
|
209 |
+
db.commit()
|
210 |
+
return response
|
211 |
+
followUp = await awaitMsg(message)
|
212 |
+
|
213 |
+
except websockets.exceptions.ConnectionClosedError as e:
|
214 |
+
print(f"Connection closed: {e}")
|
215 |
+
|
216 |
+
except Exception as e:
|
217 |
+
print(f"Error: {e}")
|
218 |
+
|
219 |
+
async def awaitMsg(ws):
|
220 |
+
message = await ws.recv()
|
221 |
+
print(f'Received message: {message}')
|
222 |
+
timestamp = datetime.datetime.now().isoformat()
|
223 |
+
sender = 'client'
|
224 |
+
db = sqlite3.connect('chat-hub.db')
|
225 |
+
db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
|
226 |
+
(sender, message, timestamp))
|
227 |
+
db.commit()
|
228 |
+
try:
|
229 |
+
response = await main(cl.Message(content=message))
|
230 |
+
serverResponse = "server response: " + response
|
231 |
+
print(serverResponse)
|
232 |
+
# Append the server response to the server_responses list
|
233 |
+
await ws.send(serverResponse)
|
234 |
+
serverSender = 'server'
|
235 |
+
db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
|
236 |
+
(serverSender, serverResponse, timestamp))
|
237 |
+
db.commit()
|
238 |
+
return response
|
239 |
+
except websockets.exceptions.ConnectionClosedError as e:
|
240 |
+
print(f"Connection closed: {e}")
|
241 |
+
|
242 |
+
except Exception as e:
|
243 |
+
print(f"Error: {e}")
|
244 |
+
|
245 |
+
# Start the WebSocket server
|
246 |
+
async def start_websockets(websocketPort):
|
247 |
+
global server
|
248 |
+
server = await(websockets.serve(handleWebSocket, 'localhost', websocketPort))
|
249 |
+
server_ports.append(websocketPort)
|
250 |
+
print(f"Starting WebSocket server on port {websocketPort}...")
|
251 |
+
return "Used ports:\n" + '\n'.join(map(str, server_ports))
|
252 |
+
await asyncio.Future()
|
253 |
+
|
254 |
+
async def start_client(clientPort):
|
255 |
+
uri = f'ws://localhost:{clientPort}'
|
256 |
+
client_ports.append(clientPort)
|
257 |
+
async with websockets.connect(uri) as ws:
|
258 |
+
while True:
|
259 |
+
# Listen for messages from the server
|
260 |
+
input_message = await ws.recv()
|
261 |
+
output_message = await main(cl.Message(content=input_message))
|
262 |
+
return input_message
|
263 |
+
await ws.send(json.dumps(output_message))
|
264 |
+
await asyncio.sleep(0.1)
|
265 |
+
|
266 |
+
# Function to stop the WebSocket server
|
267 |
+
def stop_websockets():
|
268 |
+
global server
|
269 |
+
if server:
|
270 |
+
cursor.close()
|
271 |
+
db.close()
|
272 |
+
server.close()
|
273 |
+
print("WebSocket server stopped.")
|
274 |
+
else:
|
275 |
+
print("WebSocket server is not running.")
|
276 |
|
277 |
@cl.on_message
|
278 |
async def main(message: cl.Message):
|
|
|
299 |
else:
|
300 |
answer += "\nNo sources found"
|
301 |
|
302 |
+
return json.dumps(answer)
|
303 |
await cl.Message(content=answer, elements=text_elements).send()
|