Spaces:
Runtime error
Runtime error
File size: 4,864 Bytes
c69cba4 bfdf8df c69cba4 bfdf8df c69cba4 bfdf8df c69cba4 bfdf8df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import json
from typing import List, Optional
from urllib.parse import quote

import discord
import requests

from qa_engine import logger, QAEngine
from discord_bot.client.utils import split_text_into_chunks
class DiscordClient(discord.Client):
    """
    Discord Client class, used for interacting with a Discord server.

    Args:
        qa_engine (QAEngine): The question answering engine used to generate responses.
        channel_ids (list[int], optional): IDs of the channels the bot responds in.
            ``None`` or an empty list means the bot responds in every channel.
            Defaults to None.
        num_last_messages (int, optional): Size of the context window, counting the
            triggering message itself; up to ``num_last_messages - 1`` earlier
            messages are passed to the QA engine as context. Must be at least 1.
            Defaults to 5.
        use_names_in_context (bool, optional): Whether to prefix context messages with
            their author's name. Defaults to True.
        enable_commands (bool, optional): Whether to enable commands for the bot.
            Command handling is currently disabled in ``on_message``, so this flag
            is stored but not used. Defaults to True.
        debug (bool, optional): Debug flag; stored but not read by this class.
            Defaults to False.

    Attributes:
        qa_engine (QAEngine): The question answering engine.
        channel_ids (list[int]): Channels the bot responds in (empty = all channels).
        num_last_messages (int): Size of the context window, including the triggering message.
        use_names_in_context (bool): Whether to include user names in the message context.
        enable_commands (bool): Whether to enable commands for the bot.
        debug (bool): Debug flag.
        min_message_len (int): Minimum chunk size passed to the text splitter.
        max_message_len (int): Maximum chunk size passed to the text splitter.
    """

    def __init__(
        self,
        qa_engine: QAEngine,
        channel_ids: Optional[list[int]] = None,
        num_last_messages: int = 5,
        use_names_in_context: bool = True,
        enable_commands: bool = True,
        debug: bool = False
    ):
        logger.info('Initializing Discord client...')
        intents = discord.Intents.all()
        intents.message_content = True
        # NOTE(review): `command_prefix` is a commands.Bot option; discord.Client
        # receives it as an extra keyword option and presumably ignores it.
        super().__init__(intents=intents, command_prefix='!')
        assert num_last_messages >= 1, \
            'The number of last messages in context should be at least 1'
        self.qa_engine: QAEngine = qa_engine
        # A fresh list per instance avoids the shared mutable-default pitfall.
        self.channel_ids: list[int] = channel_ids if channel_ids is not None else []
        self.num_last_messages: int = num_last_messages
        self.use_names_in_context: bool = use_names_in_context
        self.enable_commands: bool = enable_commands
        self.debug: bool = debug
        self.min_message_len: int = 1800
        self.max_message_len: int = 2000

    @property
    def min_messgae_len(self) -> int:
        """Backward-compatible alias for the previously misspelled ``min_message_len``."""
        return self.min_message_len

    @min_messgae_len.setter
    def min_messgae_len(self, value: int) -> None:
        self.min_message_len = value

    async def on_ready(self):
        """
        Callback function to be called when the client is ready.

        Logs the logged-in user and sets the bot's presence to "Chatting...".
        """
        logger.info('Successfully logged in as: {0.user}'.format(self))
        await self.change_presence(activity=discord.Game(name='Chatting...'))

    async def get_last_messages(self, message) -> List[str]:
        """
        Fetch the messages that precede ``message`` in its channel.

        Args:
            message (Message): The discord Message object that triggered the bot;
                only messages sent before it in the same channel are returned.

        Returns:
            List[str]: Up to ``self.num_last_messages - 1`` earlier messages in
            chronological order (oldest first), excluding ``message`` itself.
            Each entry is prefixed with ``"<author>: "`` when
            ``self.use_names_in_context`` is True.
        """
        # The triggering message counts towards num_last_messages, so only
        # num_last_messages - 1 earlier messages fit in the context window.
        limit = self.num_last_messages - 1
        if limit < 1:
            return []
        last_messages: List[str] = []
        # Anchor the window on the triggering message with `before=` instead of
        # assuming it is still the newest message in the channel: other messages
        # may arrive while we await. Without `after=`, history() yields newest first.
        async for msg in message.channel.history(limit=limit, before=message):
            if self.use_names_in_context:
                last_messages.append(f'{msg.author}: {msg.content}')
            else:
                last_messages.append(msg.content)
        last_messages.reverse()  # oldest first
        return last_messages

    async def send_message(self, message, answer: str, sources: str):
        """
        Send an answer to the message's channel in chunks, followed by its sources.

        Args:
            message (Message): The message being answered; replies go to its channel.
            answer (str): Answer text, split by ``split_text_into_chunks`` using
                ``self.min_message_len`` / ``self.max_message_len`` as size bounds.
            sources (str): Sources text, sent as a separate final message.

        Empty or whitespace-only parts are skipped, because Discord rejects
        messages without content and the resulting error would abort the reply.
        """
        chunks = split_text_into_chunks(
            text=answer,
            split_characters=['. ', ', ', '\n'],
            min_size=self.min_message_len,
            max_size=self.max_message_len
        )
        for chunk in chunks:
            if chunk and chunk.strip():
                await message.channel.send(chunk)
        if sources and sources.strip():
            await message.channel.send(sources)

    async def on_message(self, message):
        """
        Answer an incoming message using the QA engine.

        Messages outside ``self.channel_ids`` (when that list is non-empty) and
        messages authored by this bot are ignored. Otherwise the preceding channel
        messages are joined into a context string, a response is requested from
        ``self.qa_engine``, and the answer plus sources are sent back. Failures
        while sending are logged rather than raised.
        """
        if self.channel_ids and message.channel.id not in self.channel_ids:
            return
        if message.author == self.user:
            return
        # Command handling (e.g. '!clear' purging the channel when
        # self.enable_commands is set) is currently disabled.
        last_messages = await self.get_last_messages(message)
        context = '\n'.join(last_messages)
        logger.info('Received message: {0.content}'.format(message))
        # NOTE(review): get_response is a synchronous call inside a coroutine, so it
        # blocks the event loop while the answer is generated. Consider
        # asyncio.to_thread once QAEngine is confirmed safe to call from a thread.
        response = self.qa_engine.get_response(
            question=message.content,
            messages_context=context
        )
        logger.info('Sending response: {0}'.format(response))
        try:
            await self.send_message(
                message,
                response.get_answer(),
                response.get_sources_as_text()
            )
        except Exception as e:
            logger.error('Failed to send response: {0}'.format(e))
|