KonradSzafer's picture
channel id added to config
bfdf8df
raw
history blame
No virus
4.86 kB
import json
from typing import List, Optional
from urllib.parse import quote

import discord
import requests

from qa_engine import logger, QAEngine
from discord_bot.client.utils import split_text_into_chunks
class DiscordClient(discord.Client):
    """
    Discord Client class, used for interacting with a Discord server.

    Incoming messages in the configured channels are answered with the QA engine,
    using the preceding channel messages as conversational context.

    Args:
        qa_engine (QAEngine): The question answering engine used to generate responses.
        channel_ids (list[int], optional): IDs of the channels the bot responds in.
            ``None`` or an empty list (the default) means the bot responds in every channel.
        num_last_messages (int, optional): Size of the context window, counting the
            triggering message itself; since that message is excluded from the context,
            at most ``num_last_messages - 1`` earlier messages are used. Must be at
            least 1. Defaults to 5.
        use_names_in_context (bool, optional): Whether to include user names in the message context. Defaults to True.
        enable_commands (bool, optional): Whether to enable commands for the bot. Defaults to True.
            NOTE(review): command handling is currently disabled, so this flag has no effect yet.
        debug (bool, optional): Debug flag stored on the instance. Defaults to False.

    Raises:
        ValueError: If ``num_last_messages`` is smaller than 1.

    Attributes:
        qa_engine (QAEngine): The question answering engine.
        channel_ids (list[int]): Channels the bot responds in (empty = all channels).
        num_last_messages (int): Size of the context window (see Args).
        use_names_in_context (bool): Whether to include user names in the message context.
        enable_commands (bool): Whether to enable commands for the bot.
        debug (bool): Debug flag.
        min_messgae_len (int): Minimum chunk size passed to the answer splitter (name kept as-is for compatibility).
        max_message_len (int): The maximum length of a single sent message chunk
            (presumably chosen to match Discord's 2000-character limit).
    """

    def __init__(
        self,
        qa_engine: QAEngine,
        channel_ids: Optional[list[int]] = None,
        num_last_messages: int = 5,
        use_names_in_context: bool = True,
        enable_commands: bool = True,
        debug: bool = False
    ):
        logger.info('Initializing Discord client...')
        # Validate before initializing the underlying client, and with an explicit
        # raise so the check is not stripped when Python runs with -O.
        if num_last_messages < 1:
            raise ValueError(
                'The number of last messages in context should be at least 1'
            )

        intents = discord.Intents.all()
        intents.message_content = True
        super().__init__(intents=intents, command_prefix='!')

        self.qa_engine: QAEngine = qa_engine
        # A fresh list per instance: the previous `= []` default was a single list
        # object shared by every client created without explicit channel ids.
        self.channel_ids: list[int] = channel_ids if channel_ids is not None else []
        self.num_last_messages: int = num_last_messages
        self.use_names_in_context: bool = use_names_in_context
        self.enable_commands: bool = enable_commands
        self.debug: bool = debug
        # Attribute name (typo included) kept unchanged for backward compatibility.
        self.min_messgae_len: int = 1800
        self.max_message_len: int = 2000

    async def on_ready(self):
        """
        Callback function to be called when the client is ready.

        Logs the bot's user and sets its presence to "Chatting...".
        """
        logger.info('Successfully logged in as: {0.user}'.format(self))
        await self.change_presence(activity=discord.Game(name='Chatting...'))

    async def get_last_messages(self, message) -> List[str]:
        """
        Fetch the messages that precede ``message`` in its channel.

        Args:
            message (Message): The discord Message object whose channel and position
                identify the context.

        Returns:
            List[str]: Up to ``self.num_last_messages - 1`` messages sent before
            ``message``, oldest first. The input message itself is never included.
            Messages are prefixed with ``"<author>: "`` if
            ``self.use_names_in_context`` is True.
        """
        context_size = self.num_last_messages - 1
        if context_size < 1:
            return []

        # Ask only for messages strictly before the triggering one. The previous
        # approach fetched the newest N messages and dropped the newest, which raised
        # IndexError on an empty history. It also dropped the wrong message (leaving
        # the question in its own context) if a newer message had already arrived.
        last_messages: List[str] = []
        async for msg in message.channel.history(
            limit=context_size,
            before=message
        ):
            if self.use_names_in_context:
                last_messages.append(f'{msg.author}: {msg.content}')
            else:
                last_messages.append(msg.content)
        # history() yields newest first; the context should read chronologically.
        last_messages.reverse()
        return last_messages

    async def send_message(self, message, answer: str, sources: str):
        """
        Send an answer to the channel of ``message``, followed by its sources.

        The answer is split into chunks of at most ``self.max_message_len`` characters
        (preferring to break at sentence/clause/line boundaries) and sent in order.
        Empty or whitespace-only payloads are skipped. Sending them presumably fails
        on Discord's side, which previously aborted the rest of the reply.

        Args:
            message (Message): The message being replied to (only its channel is used).
            answer (str): The answer text.
            sources (str): The sources text, sent as a separate message.
        """
        chunks = split_text_into_chunks(
            text=answer,
            split_characters=['. ', ', ', '\n'],
            min_size=self.min_messgae_len,
            max_size=self.max_message_len
        )
        for chunk in chunks:
            if chunk and chunk.strip():
                await message.channel.send(chunk)
        # NOTE(review): sources are not chunked; a very long sources text may exceed
        # the per-message limit.
        if sources and sources.strip():
            await message.channel.send(sources)

    async def on_message(self, message):
        """
        Answer an incoming message with the QA engine.

        Messages outside ``self.channel_ids`` (when that list is non-empty) and the
        bot's own messages are ignored. Failures while sending the reply are logged
        rather than raised.

        Args:
            message (Message): The received discord message.
        """
        if self.channel_ids and message.channel.id not in self.channel_ids:
            return
        if message.author == self.user:
            return

        # TODO: command handling is disabled. The previous draft, when
        # `self.enable_commands` was set, handled '!clear' by purging the channel
        # via `message.channel.purge()` and returning.

        last_messages = await self.get_last_messages(message)
        context = '\n'.join(last_messages)
        logger.info('Received message: {0.content}'.format(message))
        # NOTE(review): get_response is called synchronously inside this coroutine, so
        # a slow engine blocks the event loop for the whole call. Consider offloading it
        # (e.g. asyncio.to_thread) once QAEngine is confirmed to be thread-safe.
        response = self.qa_engine.get_response(
            question=message.content,
            messages_context=context
        )
        logger.info('Sending response: {0}'.format(response))
        try:
            await self.send_message(
                message,
                response.get_answer(),
                response.get_sources_as_text()
            )
        except Exception as e:
            logger.error('Failed to send response: {0}'.format(e))