Arcypojeb committed
Commit db1da1a
1 Parent(s): 43d6ba8

Update app.py

Files changed (1)
  1. app.py +170 -6
app.py CHANGED
@@ -1,5 +1,27 @@
 
import os
from typing import List

from langchain.embeddings import CohereEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -17,13 +39,17 @@ from langchain.prompts.chat import (
from langchain.docstore.document import Document
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langsmith_config import setup_langsmith_config
- import fireworks.client
- import chainlit as cl

- setup_langsmith_config()
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
cohere_api_key = os.getenv("COHERE_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

system_template = """Use the following pieces of context to answer the users question.
@@ -49,11 +75,32 @@ messages = [
prompt = ChatPromptTemplate.from_messages(messages)
chain_type_kwargs = {"prompt": prompt}

-
@cl.on_chat_start
async def on_chat_start():
    files = None

    # Wait for the user to upload a file
    while files == None:
        files = await cl.AskFileMessage(
@@ -80,7 +127,7 @@ async def on_chat_start():
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

    # Create a Chroma vector store
-     embeddings = CohereEmbeddings(cohere_api_key=COHERE_API_KEY)

    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
@@ -97,7 +144,7 @@ async def on_chat_start():

    # Create a chain that uses the Chroma vector store
    chain = ConversationalRetrievalChain.from_llm(
-         ChatFireworks(model="accounts/fireworks/models/llama-v2-13b-chat", model_kwargs={"temperature":0, "max_tokens":1500, "top_p":1.0}, streaming=True),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
@@ -110,6 +157,122 @@ async def on_chat_start():

    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message: cl.Message):
@@ -136,4 +299,5 @@ async def main(message: cl.Message):
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements).send()
 
+ import datetime
import os
+ import sqlite3
+ import websockets
+ import websocket
+ import asyncio
+ import sqlite3
+ import json
+ import requests
+ import asyncio
+ import time
+ import gradio as gr
+ import fireworks.client
+ import PySimpleGUI as sg
+ import openai
+ import fireworks.client
+ import chainlit as cl
+ from chainlit import make_async
+ from gradio_client import Client
+ from websockets.sync.client import connect
+ from tempfile import TemporaryDirectory
from typing import List
+ from chainlit.input_widget import Select, Switch, Slider
+ from chainlit import AskUserMessage, Message, on_chat_start

from langchain.embeddings import CohereEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
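
The websocket handlers added further down in this commit log every exchange to a local SQLite database, chat-hub.db, but the commit itself never creates the table they write to. A minimal sketch of preparing that table; the column set is inferred from the INSERT statements in handleWebSocket and awaitMsg, and the id column is an added assumption:

    import sqlite3

    # Sketch only: create the messages table assumed by handleWebSocket/awaitMsg.
    db = sqlite3.connect('chat-hub.db')
    db.execute(
        'CREATE TABLE IF NOT EXISTS messages ('
        'id INTEGER PRIMARY KEY AUTOINCREMENT, '
        'sender TEXT, message TEXT, timestamp TEXT)'
    )
    db.commit()
    db.close()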
 
from langchain.docstore.document import Document
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langsmith_config import setup_langsmith_config

+
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
cohere_api_key = os.getenv("COHERE_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
+ fireworks_api_key = os.getenv("FIREWORKS_API_KEY")
+
+ server_ports = []
+ client_ports = []
+
+ setup_langsmith_config()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

system_template = """Use the following pieces of context to answer the users question.
 
prompt = ChatPromptTemplate.from_messages(messages)
chain_type_kwargs = {"prompt": prompt}

@cl.on_chat_start
async def on_chat_start():
+
    files = None

+     settings = await cl.ChatSettings(
+         [
+             Slider(
+                 id="websocketPort",
+                 label="Websocket server port",
+                 initial=False,
+                 min=1000,
+                 max=9999,
+                 step=10,
+             ),
+             Slider(
+                 id="clientPort",
+                 label="Websocket client port",
+                 initial=False,
+                 min=1000,
+                 max=9999,
+                 step=10,
+             ),
+         ],
+     ).send()
+
    # Wait for the user to upload a file
    while files == None:
        files = await cl.AskFileMessage(
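
The two sliders collect the server and client port numbers, and ChatSettings(...).send() returns the current settings dict. The action callbacks added further down read a name called settings, which here is local to on_chat_start; a minimal sketch, assuming Chainlit's user-session store (already used for the chain), of one way the chosen ports could be made reachable from other handlers:

    # Sketch (assumption, not part of the commit): keep the settings dict in the user
    # session so callbacks can look the ports up later instead of using a local variable.
    cl.user_session.set("settings", settings)

    # e.g. later, inside an action callback:
    settings = cl.user_session.get("settings")
    websocketPort = settings["websocketPort"]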
 
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

    # Create a Chroma vector store
+     embeddings = CohereEmbeddings(cohere_api_key="Ev0v9wwQPa90xDucdHTyFsllXGVHXouakUMObkNb")

    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
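
The replaced line built the embeddings from the COHERE_API_KEY environment variable that is still read at the top of the file, so a short sketch of keeping the key out of the source while constructing the same object:

    # Sketch: construct the Cohere embeddings from the environment variable loaded above
    # rather than from a key embedded in the source.
    embeddings = CohereEmbeddings(cohere_api_key=COHERE_API_KEY)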
 

    # Create a chain that uses the Chroma vector store
    chain = ConversationalRetrievalChain.from_llm(
+         ChatFireworks(model="accounts/fireworks/models/llama-v2-70b-chat", model_kwargs={"temperature":0, "max_tokens":1500, "top_p":1.0}, streaming=True),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
 

    cl.user_session.set("chain", chain)

+ @cl.action_callback("server_button")
+ async def on_server_button(action):
+     websocketPort = settings["websocketPort"]
+     await start_websockets(websocketPort)
+
+ @cl.action_callback("client_button")
+ async def on_client_button(action):
+     clientPort = settings["clientPort"]
+     await start_client(clientPort)
+
+ @cl.on_settings_update
+ async def server_start(settings):
+     websocketPort = settings["websocketPort"]
+     clientPort = settings["clientPort"]
+     if websocketPort:
+         await start_websockets(websocketPort)
+     else:
+         print("Server port number wasn't provided.")
+
+     if clientPort:
+         await start_client(clientPort)
+     else:
+         print("Client port number wasn't provided.")
+
+ async def handleWebSocket(ws):
+     print('New connection')
+     instruction = "Hello! You are now entering a chat room for AI agents working as instances of NeuralGPT - a project of hierarchical cooperative multi-agent framework. Keep in mind that you are speaking with another chatbot. Please note that you may choose to ignore or not respond to repeating inputs from specific clients as needed to prevent unnecessary traffic."
+     greetings = {'instructions': instruction}
+     await ws.send(json.dumps(instruction))
+     while True:
+         loop = asyncio.get_event_loop()
+         message = await ws.recv()
+         print(f'Received message: {message}')
+         msg = "client: " + message
+         timestamp = datetime.datetime.now().isoformat()
+         sender = 'client'
+         db = sqlite3.connect('chat-hub.db')
+         db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
+             (sender, message, timestamp))
+         db.commit()
+         try:
+             response = await main(cl.Message(content=message))
+             serverResponse = "server response: " + response
+             print(serverResponse)
+             # Append the server response to the server_responses list
+             await ws.send(serverResponse)
+             serverSender = 'server'
+             db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
+                 (serverSender, serverResponse, timestamp))
+             db.commit()
+             return response
+             followUp = await awaitMsg(message)
+
+         except websockets.exceptions.ConnectionClosedError as e:
+             print(f"Connection closed: {e}")
+
+         except Exception as e:
+             print(f"Error: {e}")
+
+ async def awaitMsg(ws):
+     message = await ws.recv()
+     print(f'Received message: {message}')
+     timestamp = datetime.datetime.now().isoformat()
+     sender = 'client'
+     db = sqlite3.connect('chat-hub.db')
+     db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
+         (sender, message, timestamp))
+     db.commit()
+     try:
+         response = await main(cl.Message(content=message))
+         serverResponse = "server response: " + response
+         print(serverResponse)
+         # Append the server response to the server_responses list
+         await ws.send(serverResponse)
+         serverSender = 'server'
+         db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
+             (serverSender, serverResponse, timestamp))
+         db.commit()
+         return response
+     except websockets.exceptions.ConnectionClosedError as e:
+         print(f"Connection closed: {e}")
+
+     except Exception as e:
+         print(f"Error: {e}")
+
+ # Start the WebSocket server
+ async def start_websockets(websocketPort):
+     global server
+     server = await(websockets.serve(handleWebSocket, 'localhost', websocketPort))
+     server_ports.append(websocketPort)
+     print(f"Starting WebSocket server on port {websocketPort}...")
+     return "Used ports:\n" + '\n'.join(map(str, server_ports))
+     await asyncio.Future()
+
+ async def start_client(clientPort):
+     uri = f'ws://localhost:{clientPort}'
+     client_ports.append(clientPort)
+     async with websockets.connect(uri) as ws:
+         while True:
+             # Listen for messages from the server
+             input_message = await ws.recv()
+             output_message = await main(cl.Message(content=input_message))
+             return input_message
+             await ws.send(json.dumps(output_message))
+             await asyncio.sleep(0.1)
+
+ # Function to stop the WebSocket server
+ def stop_websockets():
+     global server
+     if server:
+         cursor.close()
+         db.close()
+         server.close()
+         print("WebSocket server stopped.")
+     else:
+         print("WebSocket server is not running.")

@cl.on_message
async def main(message: cl.Message):
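
Taken together, start_websockets serves handleWebSocket on the chosen port, and the handler stores each incoming client message in chat-hub.db, asks the Chainlit main handler for an answer, and sends it back prefixed with "server response:". A hypothetical usage sketch that exercises the relay with a throwaway client once the server is running; the port number 5000 and the test text are assumptions, not part of the commit:

    import asyncio
    import websockets

    async def probe():
        # Connect to a locally running start_websockets(5000) instance (assumed port).
        async with websockets.connect("ws://localhost:5000") as ws:
            greeting = await ws.recv()   # JSON-encoded instruction sent on connect
            await ws.send("Hello from a test client")
            reply = await ws.recv()      # "server response: ..." relayed from the chain
            print(greeting)
            print(reply)

    asyncio.run(probe())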
 
    else:
        answer += "\nNo sources found"

+     return json.dumps(answer)
    await cl.Message(content=answer, elements=text_elements).send()
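
As written, the new return json.dumps(answer) runs before the final send(), so websocket callers such as handleWebSocket and start_client receive the serialized answer while the closing Chainlit message on the last line is not reached. A minimal sketch of an alternative ordering (an assumption, not what the commit does) that keeps both behaviours:

    # Sketch: send the Chainlit message first, then hand the serialized answer back
    # to websocket callers.
    await cl.Message(content=answer, elements=text_elements).send()
    return json.dumps(answer)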