Arcypojeb's picture
Update app.py
db1da1a
raw
history blame
10.3 kB
import datetime
import os
import sqlite3
import websockets
import websocket
import asyncio
import sqlite3
import json
import requests
import asyncio
import time
import gradio as gr
import fireworks.client
import PySimpleGUI as sg
import openai
import fireworks.client
import chainlit as cl
from chainlit import make_async
from gradio_client import Client
from websockets.sync.client import connect
from tempfile import TemporaryDirectory
from typing import List
from chainlit.input_widget import Select, Switch, Slider
from chainlit import AskUserMessage, Message, on_chat_start
from langchain.embeddings import CohereEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain.chains import (
ConversationalRetrievalChain,
)
from langchain.llms.fireworks import Fireworks
from langchain.chat_models.fireworks import ChatFireworks
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.docstore.document import Document
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langsmith_config import setup_langsmith_config
# --- Module-level configuration -------------------------------------------
# API credentials are read from the environment. Each key is bound under both
# an upper-case and a lower-case name; both names hold the same value.
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
cohere_api_key = os.getenv("COHERE_API_KEY")
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
fireworks_api_key = os.getenv("FIREWORKS_API_KEY")
# Ports handed to start_websockets() / start_client(); each call appends its
# port to the matching list.
server_ports = []
client_ports = []
# NOTE(review): presumably configures LangSmith tracing -- the helper lives in
# the project-local `langsmith_config` module, confirm there.
setup_langsmith_config()
# Splitter applied to the uploaded document: 1000-character chunks with a
# 100-character overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
# System prompt for a "question answering with sources" chain; the
# `{summaries}` placeholder is where the retrieved context is substituted.
system_template = """Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.
And if the user greets with greetings like Hi, hello, How are you, etc reply accordingly as well.
Example of your response should be:
The answer is foo
SOURCES: xyz
Begin!
----------------
{summaries}"""
# Chat prompt: the system instructions above followed by the user's question.
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate.from_messages(messages)
# NOTE(review): `chain_type_kwargs` is not referenced by any function visible
# in this file -- confirm whether the chain was meant to receive this prompt.
chain_type_kwargs = {"prompt": prompt}
@cl.on_chat_start
async def on_chat_start():
    """Prepare a new chat session.

    Publishes two port sliders in the chat settings panel, waits for the user
    to upload a plain-text file, splits and embeds the text into a Chroma
    vector store, builds a ``ConversationalRetrievalChain`` backed by a
    conversation buffer memory and stores that chain in the user session
    under the ``"chain"`` key.
    """
    files = None
    # `initial=False` is falsy, so the settings-update handler's
    # `if websocketPort:` / `if clientPort:` checks treat an untouched slider
    # as "no port provided".
    settings = await cl.ChatSettings(
        [
            Slider(
                id="websocketPort",
                label="Websocket server port",
                initial=False,
                min=1000,
                max=9999,
                step=10,
            ),
            Slider(
                id="clientPort",
                label="Websocket client port",
                initial=False,
                min=1000,
                max=9999,
                step=10,
            ),
        ],
    ).send()

    # Wait for the user to upload a file; the prompt is re-sent each time the
    # previous one comes back empty (e.g. after the 180 s timeout).
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a text file to begin!",
            accept=["text/plain"],
            max_size_mb=20,
            timeout=180,
        ).send()
    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # Decode the file
    text = file.content.decode("utf-8")
    # Split the text into chunks
    texts = text_splitter.split_text(text)
    # Create a metadata entry for each chunk
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

    # Create a Chroma vector store. The Cohere key comes from the environment
    # (module-level COHERE_API_KEY) instead of being embedded in the source.
    embeddings = CohereEmbeddings(cohere_api_key=COHERE_API_KEY)
    # Chroma.from_texts is synchronous; make_async keeps it off the event loop.
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
    )

    message_history = ChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    # Create a chain that uses the Chroma vector store
    chain = ConversationalRetrievalChain.from_llm(
        ChatFireworks(
            model="accounts/fireworks/models/llama-v2-70b-chat",
            model_kwargs={"temperature": 0, "max_tokens": 1500, "top_p": 1.0},
            streaming=True,
        ),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()
    cl.user_session.set("chain", chain)
@cl.action_callback("server_button")
async def on_server_button(action):
    """Start the WebSocket server on the port selected in the chat settings.

    Args:
        action: The Chainlit action that triggered the callback (not inspected).
    """
    # `settings` is not defined in this function or at module level, so the
    # original lookup raised NameError. Chainlit keeps the current chat
    # settings in the user session under the reserved "chat_settings" key.
    settings = cl.user_session.get("chat_settings") or {}
    websocketPort = settings.get("websocketPort")
    if not websocketPort:
        print("Server port number wasn't provided.")
        return
    await start_websockets(websocketPort)
@cl.action_callback("client_button")
async def on_client_button(action):
    """Connect the WebSocket client to the port selected in the chat settings.

    Args:
        action: The Chainlit action that triggered the callback (not inspected).
    """
    # `settings` is not defined in this function or at module level, so the
    # original lookup raised NameError. Chainlit keeps the current chat
    # settings in the user session under the reserved "chat_settings" key.
    settings = cl.user_session.get("chat_settings") or {}
    clientPort = settings.get("clientPort")
    if not clientPort:
        print("Client port number wasn't provided.")
        return
    await start_client(clientPort)
@cl.on_settings_update
async def server_start(settings):
    """React to a chat-settings change by starting the server and/or client.

    Args:
        settings: Mapping of setting ids to values; must contain the
            ``"websocketPort"`` and ``"clientPort"`` entries.
    """
    # Both entries are looked up before anything is started.
    server_port = settings["websocketPort"]
    client_port = settings["clientPort"]

    # A falsy port means the corresponding slider was left unset.
    if not server_port:
        print("Server port number wasn't provided.")
    else:
        await start_websockets(server_port)

    if not client_port:
        print("Client port number wasn't provided.")
    else:
        await start_client(client_port)
async def handleWebSocket(ws):
    """Serve one WebSocket connection until the peer disconnects.

    Sends the chat-room instruction once, then for every incoming message:
    stores it in the ``messages`` table of ``chat-hub.db``, asks ``main`` for
    an answer, sends ``"server response: <answer>"`` back and stores that
    reply as well.

    Args:
        ws: The server-side WebSocket connection passed in by
            ``websockets.serve``.
    """
    print('New connection')
    instruction = "Hello! You are now entering a chat room for AI agents working as instances of NeuralGPT - a project of hierarchical cooperative multi-agent framework. Keep in mind that you are speaking with another chatbot. Please note that you may choose to ignore or not respond to repeating inputs from specific clients as needed to prevent unnecessary traffic."
    await ws.send(json.dumps(instruction))

    # One database connection per WebSocket connection, always closed on exit
    # (previously a new connection was opened per message and never closed).
    db = sqlite3.connect('chat-hub.db')
    try:
        while True:
            try:
                message = await ws.recv()
            except websockets.exceptions.ConnectionClosed as e:
                # Covers both clean and abnormal closes; nothing more to read.
                print(f"Connection closed: {e}")
                break
            print(f'Received message: {message}')
            timestamp = datetime.datetime.now().isoformat()
            db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
                       ('client', message, timestamp))
            db.commit()
            try:
                response = await main(cl.Message(content=message))
                serverResponse = "server response: " + response
                print(serverResponse)
                await ws.send(serverResponse)
                # The reply is stored with the timestamp of the request it answers.
                db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
                           ('server', serverResponse, timestamp))
                db.commit()
                # No `return` here: returning inside the loop ended the
                # handler (and thus the connection) after the first message.
            except websockets.exceptions.ConnectionClosed as e:
                print(f"Connection closed: {e}")
                break
            except Exception as e:
                # Best effort: report the failure and keep serving the peer.
                print(f"Error: {e}")
    finally:
        db.close()
async def awaitMsg(ws):
    """Handle a single message on *ws*.

    Receives one message, stores it in the ``messages`` table of
    ``chat-hub.db``, asks ``main`` for an answer, sends
    ``"server response: <answer>"`` back and stores that reply too.

    Args:
        ws: An open WebSocket connection.

    Returns:
        The answer returned by ``main``, or ``None`` when answering or
        sending failed (the error is printed, not raised).
    """
    message = await ws.recv()
    print(f'Received message: {message}')
    timestamp = datetime.datetime.now().isoformat()
    db = sqlite3.connect('chat-hub.db')
    try:
        db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
                   ('client', message, timestamp))
        db.commit()
        try:
            response = await main(cl.Message(content=message))
            serverResponse = "server response: " + response
            print(serverResponse)
            await ws.send(serverResponse)
            db.execute('INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)',
                       ('server', serverResponse, timestamp))
            db.commit()
            return response
        except websockets.exceptions.ConnectionClosedError as e:
            print(f"Connection closed: {e}")
        except Exception as e:
            print(f"Error: {e}")
        return None
    finally:
        # The connection used to be left open on every call.
        db.close()
# Start the WebSocket server
async def start_websockets(websocketPort):
    """Start a WebSocket server on ``localhost:<websocketPort>``.

    The server object is kept in the module-level ``server`` global and the
    port is appended to ``server_ports``. The coroutine returns as soon as the
    server is listening, so callers are not blocked while it runs.

    Args:
        websocketPort: TCP port to listen on.

    Returns:
        A text listing every port recorded in ``server_ports``, one per line,
        prefixed by ``"Used ports:"``.
    """
    global server
    server = await websockets.serve(handleWebSocket, 'localhost', websocketPort)
    server_ports.append(websocketPort)
    print(f"Starting WebSocket server on port {websocketPort}...")
    # The former `await asyncio.Future()` after this return could never run
    # and has been dropped.
    return "Used ports:\n" + '\n'.join(map(str, server_ports))
async def start_client(clientPort):
    """Connect to ``ws://localhost:<clientPort>`` and answer incoming messages.

    Every message received from the server is passed to ``main`` and the
    JSON-encoded answer is sent back. The loop runs until the connection is
    closed.

    Args:
        clientPort: Port of the WebSocket server to connect to.

    Returns:
        The last message received from the server, or ``None`` if the
        connection closed before any message arrived.
    """
    uri = f'ws://localhost:{clientPort}'
    client_ports.append(clientPort)
    input_message = None
    async with websockets.connect(uri) as ws:
        while True:
            try:
                # Listen for messages from the server
                input_message = await ws.recv()
                output_message = await main(cl.Message(content=input_message))
                # Previously a `return` sat in front of this send, so the
                # computed reply was never delivered.
                await ws.send(json.dumps(output_message))
            except websockets.exceptions.ConnectionClosed as e:
                print(f"Connection closed: {e}")
                break
            await asyncio.sleep(0.1)
    return input_message
# Function to stop the WebSocket server
def stop_websockets():
    """Stop the WebSocket server held in the module-level ``server`` global.

    Prints a notice and does nothing else when no server is running. If
    module-level ``cursor`` / ``db`` objects exist they are closed too; they
    are optional because nothing in this module is guaranteed to define them.
    """
    global server
    # `server` only exists after start_websockets() has run, so look it up
    # defensively instead of raising NameError.
    running = globals().get("server")
    if not running:
        print("WebSocket server is not running.")
        return
    for handle_name in ("cursor", "db"):
        handle = globals().get(handle_name)
        if handle is not None:
            handle.close()
    running.close()
    # Forget the closed server so a second call reports "not running".
    server = None
    print("WebSocket server stopped.")
@cl.on_message
async def main(message: cl.Message):
    """Answer *message* with the session's retrieval chain.

    Runs the chain stored in the user session under ``"chain"``, appends a
    list of source names (or a "no sources" note) to the answer, sends the
    answer with the source texts attached to the chat UI and returns it.

    Args:
        message: The incoming message; only ``message.content`` is read.

    Returns:
        The final answer text, JSON-encoded as a string (the WebSocket
        handlers concatenate and forward this value).
    """
    chain = cl.user_session.get("chain")  # type: ConversationalRetrievalChain
    cb = cl.AsyncLangchainCallbackHandler()
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["answer"]
    source_documents = res["source_documents"]  # type: List[Document]

    # One text element per retrieved document, named source_0, source_1, ...
    text_elements = [
        cl.Text(content=source_doc.page_content, name=f"source_{source_idx}")
        for source_idx, source_doc in enumerate(source_documents or [])
    ]  # type: List[cl.Text]

    if text_elements:
        source_names = [text_el.name for text_el in text_elements]
        answer += f"\nSources: {', '.join(source_names)}"
    else:
        answer += "\nNo sources found"

    # Send first, then return: the `return` used to come before the send, so
    # the answer never reached the chat UI.
    await cl.Message(content=answer, elements=text_elements).send()
    return json.dumps(answer)