Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import aiohttp | |
| import asyncio | |
| import discord | |
| import pandas as pd | |
| import requests | |
| from teapotai import TeapotAI, TeapotAISettings | |
# Browser-tab title/icon and a wide page layout for the status page.
st.set_page_config(
    page_title="TeapotAI Discord Bot",
    page_icon=":robot_face:",
    layout="wide",
)

# Discord bot token taken from the "discord_key" environment variable;
# None when the variable is not set.
DISCORD_TOKEN = os.getenv("discord_key")
# ========= CONFIG =========
def _load_documents(csv_url):
    """Return the paragraphs of a CSV export's ``content`` column as a list of str.

    Every cell is split on blank lines so each paragraph becomes its own
    document. Empty cells and whitespace-only paragraphs are skipped. Without
    this, pandas turns an empty cell into a ``NaN`` float and a trailing blank
    line into an empty-string document, and both would be handed to TeapotAI.

    Args:
        csv_url: Anything ``pandas.read_csv`` accepts (here, a Google Sheets
            CSV export URL). The sheet must have a ``content`` column.
    """
    paragraphs = (
        pd.read_csv(csv_url)["content"]
        .dropna()
        .astype(str)
        .str.split("\n\n")
        .explode()
    )
    paragraphs = paragraphs[paragraphs.str.strip() != ""]
    return paragraphs.reset_index(drop=True).to_list()


# Discord server (guild) name -> TeapotAI instance answering there.
# handle_teapot_inference falls back to the "Teapot AI" entry for any server
# not listed here. Documents are downloaded once, when this module loads.
CONFIG = {
    # "OneTrainer": TeapotAI(
    #     documents=_load_documents("https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=361556791&format=csv"),
    #     settings=TeapotAISettings(rag_num_results=7)
    # ),
    "Teapot AI": TeapotAI(
        documents=_load_documents(
            "https://docs.google.com/spreadsheets/d/1NNbdQWIfVHq09lMhVSN36_SkGu6XgmKTXgBWPyQcBpk/export?gid=1617599323&format=csv"
        ),
        settings=TeapotAISettings(rag_num_results=7),
    ),
}
# ========= SEARCH API =========
# Brave Search subscription token taken from the "brave_api_key" environment
# variable; None when the variable is not set.
API_KEY = os.getenv("brave_api_key")
def brave_search_context(query, count=1, timeout=10):
    """Fetch Brave web-search results for *query* and format them as context text.

    Args:
        query: Search string, sent as the ``q`` parameter.
        count: Number of results to request.
        timeout: Seconds to wait on the HTTP request before giving up. Without
            a timeout, ``requests`` can wait forever on a stalled connection.

    Returns:
        Each result's title and description on consecutive lines, with results
        separated by a blank line. Returns ``""`` in four cases: no API key is
        configured, the request fails, the API answers with a non-200 status,
        or there are no results. Search is optional context, so failures
        degrade to "no context" instead of raising.
    """
    if not API_KEY:
        # The Brave API requires a subscription token, so skip the round-trip.
        return ""
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}
    try:
        response = requests.get(url, headers=headers, params=params, timeout=timeout)
    except requests.RequestException as err:
        print(f"Error: {err}")
        return ""
    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return ""
    results = response.json().get("web", {}).get("results", [])
    print(results)
    # .get(): a result missing one of these fields should not abort the answer.
    return "\n\n".join(
        f"{res.get('title', '')}\n{res.get('description', '')}" for res in results
    )
# ========= DISCORD CLIENT =========
# Start from discord.py's default intent set and explicitly switch on message
# events, which on_message relies on; the client is built with those intents.
intents = discord.Intents.default()
intents.messages = True

client = discord.Client(intents=intents)
async def handle_teapot_inference(server_name, user_input):
    """Answer *user_input* with the TeapotAI instance configured for *server_name*.

    Servers without their own entry in ``CONFIG`` use the "Teapot AI" instance.
    A Brave web search for the same text is passed to the model as context.

    Returns:
        Whatever ``teapot_instance.query`` returns (presumably the answer
        text; confirm against TeapotAI).
    """
    teapot_instance = CONFIG.get(server_name, CONFIG["Teapot AI"])
    print(f"Using Teapot instance for server: {server_name}")
    # Both calls block: the search uses synchronous `requests` and the model
    # query runs inference. Each runs in a worker thread so the Discord event
    # loop (gateway heartbeats, other messages) keeps running meanwhile.
    context = await asyncio.to_thread(brave_search_context, user_input)
    return await asyncio.to_thread(
        teapot_instance.query, query=user_input, context=context
    )
@client.event
async def on_ready():
    """Log the bot's account once discord.py reports the connection is ready.

    The @client.event registration is required: discord.py dispatches events
    to handlers registered on the client, not to module-level functions that
    merely share the event's name.
    """
    print(f'Logged in as {client.user}')
# Discord rejects message content longer than this many characters.
_DISCORD_MAX_MESSAGE_LENGTH = 2000


@client.event
async def on_message(message):
    """Reply with a TeapotAI answer to any message that mentions the bot.

    Ignores the bot's own messages and messages without a mention. The mention
    is stripped from the text before it is passed to handle_teapot_inference.
    Guild messages are routed by guild name, and direct messages use the
    "Teapot AI" configuration. The @client.event registration is what makes
    discord.py call this handler at all.
    """
    if message.author == client.user:
        return
    # A user mention is written <@id>. The deprecated nickname form <@!id> can
    # still appear in message text, so both are detected and stripped.
    mention_tokens = (f'<@{client.user.id}>', f'<@!{client.user.id}>')
    if not any(token in message.content for token in mention_tokens):
        return
    server_name = message.guild.name if message.guild else "Teapot AI"
    print(server_name, message.author, message.content)
    async with message.channel.typing():
        cleaned_message = message.content
        for token in mention_tokens:
            cleaned_message = cleaned_message.replace(token, "")
        cleaned_message = cleaned_message.strip()
        response = await handle_teapot_inference(server_name, cleaned_message)
        # Discord answers an empty or over-long reply with an HTTP error, which
        # would leave the user with no reply at all.
        reply_text = "" if response is None else str(response)
        if not reply_text.strip():
            reply_text = "Sorry, I couldn't come up with an answer."
        await message.reply(reply_text[:_DISCORD_MAX_MESSAGE_LENGTH])
# ========= STREAMLIT =========
def discord_loop():
    """Show the status page and start the Discord bot from the Streamlit script.

    ``client.run`` blocks until the bot shuts down, so the page text is written
    first. Otherwise it would never be displayed.
    """
    st.write("418 I'm a teapot")
    # Streamlit re-executes the script on every rerun. The session flag keeps
    # one browser session from launching a second bot connection.
    if "initialized" in st.session_state:
        return
    if not DISCORD_TOKEN:
        st.error("Environment variable 'discord_key' is not set; the Discord bot was not started.")
        return
    st.session_state["initialized"] = True
    client.run(DISCORD_TOKEN)


discord_loop()