# app.py — author: zakerytclarke; commit e802c16 ("Update app.py", verified); 3.04 kB
import streamlit as st
import os
import aiohttp
import asyncio
import discord
import pandas as pd
import requests
from teapotai import TeapotAI, TeapotAISettings
st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")
DISCORD_TOKEN = os.environ.get("discord_key")
# ========= CONFIG =========
# Knowledge-base Google Sheet. Each tab (gid) is exported as CSV and is read
# through its "content" column, whose cells hold passages separated by blank lines.
SHEET_CSV_URL = (
    "https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk"
    "/export?gid={gid}&format=csv"
)


def _load_documents(gid):
    """Download one tab of the knowledge-base sheet and return its passages.

    Each non-empty ``content`` cell is split on blank lines ("\\n\\n"). Every
    resulting piece that is a non-whitespace string becomes one document.
    Empty cells (NaN) and blank pieces are skipped.
    """
    pieces = (
        pd.read_csv(SHEET_CSV_URL.format(gid=gid))["content"]
        .dropna()
        .str.split("\n\n")
        .explode()
    )
    # Non-string cells yield NaN from .str.split; keep only real text.
    return [piece for piece in pieces if isinstance(piece, str) and piece.strip()]


@st.cache_resource
def _build_config():
    """Build the server-name -> TeapotAI mapping once per Streamlit process.

    Streamlit re-executes this script on every page load/rerun. Caching keeps
    each rerun from re-downloading the sheet and re-creating the TeapotAI
    instances.
    """
    # Disabled entry: "OneTrainer" server, sheet gid 361556791, rag_num_results=7.
    return {
        "Teapot AI": TeapotAI(
            documents=_load_documents(1617599323),
            settings=TeapotAISettings(rag_num_results=7),
        ),
    }


# Discord server (guild) name -> TeapotAI instance. "Teapot AI" is used as the
# fallback for servers without their own entry.
CONFIG = _build_config()
# ========= SEARCH API =========
API_KEY = os.environ.get("brave_api_key")


def format_search_results(results):
    """Join Brave web-search results into a single context string.

    Each result becomes ``"<title>\\n<description>"``, and results are separated
    by a blank line. A missing or null ``title``/``description`` is rendered
    as an empty string instead of raising.
    """
    return "\n\n".join(
        f"{res.get('title') or ''}\n{res.get('description') or ''}" for res in results
    )


def brave_search_context(query, count=1, timeout=10):
    """Return Brave web-search snippets for ``query`` as one string.

    Args:
        query: Search text, sent as the ``q`` parameter.
        count: Number of results requested (``count`` parameter).
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The formatted results. Returns "" when the request fails, times out,
        returns a non-200 status, or the body is not valid JSON; the error is
        printed in each of those cases.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}
    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error: {err}")
        return ""
    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return ""
    try:
        payload = response.json()
    except ValueError:
        print(f"Error: invalid JSON from search API: {response.text}")
        return ""
    results = payload.get("web", {}).get("results", [])
    print(results)
    return format_search_results(results)
# ========= DISCORD CLIENT =========
# Start from discord.py's default intents and explicitly turn message events on,
# since the bot is driven entirely by on_message.
intents: discord.Intents = discord.Intents.default()
intents.messages = True
client: discord.Client = discord.Client(intents=intents)
async def handle_teapot_inference(server_name, user_input):
    """Answer ``user_input`` with the TeapotAI instance configured for ``server_name``.

    Uses the "Teapot AI" entry of CONFIG when the server has no entry of its own.
    The web search (a synchronous ``requests`` call) and the TeapotAI query both
    run in worker threads via ``asyncio.to_thread``, so the event loop is never
    blocked.

    Returns:
        Whatever ``teapot_instance.query`` returns.
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Evaluating brave_search_context inline (as a call argument) would run the
    # blocking HTTP request on the event loop thread; offload it as well.
    context = await asyncio.to_thread(brave_search_context, user_input)
    response = await asyncio.to_thread(teapot_instance.query, query=user_input, context=context)
    return response
@client.event
async def on_ready():
    """Print which bot account is connected when discord.py fires the ready event."""
    me = client.user
    print(f"Logged in as {me}")
@client.event
async def on_message(message):
    """Reply with a TeapotAI answer to messages whose text @-mentions the bot.

    Ignores the bot's own messages and any message without a mention of the bot
    in its text. The mention is stripped, the remaining text is answered via
    ``handle_teapot_inference``, and the answer is posted as a reply. The server
    name selects the CONFIG entry; DMs use "Teapot AI".
    """
    if message.author == client.user:
        return
    # Discord renders user mentions as <@id>; the legacy nickname form is <@!id>.
    mention_tokens = (f"<@{client.user.id}>", f"<@!{client.user.id}>")
    if not any(token in message.content for token in mention_tokens):
        return
    server_name = message.guild.name if message.guild else "Teapot AI"
    print(server_name, message.author, message.content)
    async with message.channel.typing():
        cleaned_message = message.content
        for token in mention_tokens:
            cleaned_message = cleaned_message.replace(token, "")
        cleaned_message = cleaned_message.strip()
        response = await handle_teapot_inference(server_name, cleaned_message)
        # Discord rejects empty messages and content longer than 2000 characters.
        reply_text = str(response) if response is not None else ""
        if not reply_text.strip():
            reply_text = "Sorry, I don't have an answer for that."
        await message.reply(reply_text[:2000])
# ========= STREAMLIT =========
@st.cache_resource
def discord_loop():
    """Start the Discord bot once per Streamlit server process.

    ``st.cache_resource`` makes later script runs reuse the cached result instead
    of starting another bot. ``client.run`` blocks until the bot shuts down, so it
    is launched on a daemon thread. Otherwise this call would never return and
    nothing after it in the script would render.

    Raises:
        RuntimeError: if the ``discord_key`` environment variable is unset. This
            is raised before the thread starts, so the failure is visible instead
            of being lost inside the background thread. Exceptions are not
            cached, so a later run retries.
    """
    import threading

    if not DISCORD_TOKEN:
        raise RuntimeError("The 'discord_key' environment variable is not set.")
    st.session_state["initialized"] = True
    bot_thread = threading.Thread(
        target=client.run, args=(DISCORD_TOKEN,), name="discord-bot", daemon=True
    )
    bot_thread.start()
    return


discord_loop()
st.write("418 I'm a teapot")