"""Discord bot that captions image attachments using a Gradio Space."""

import argparse
import asyncio
import os
import pathlib

import anyio
import discord
import gradio as gr
from discord.ext import commands

# Remote captioning model, loaded from a Hugging Face Space at import time.
captioner = gr.Interface.load("olivierdehaene/git-large-coco", src="spaces")

# NOTE(review): not referenced anywhere in the code visible here; kept because
# it is a module-level name that other code may rely on.
lock = asyncio.Lock()

# Empty command prefix; only the message and guild intents are enabled.
bot = commands.Bot("", intents=discord.Intents(messages=True, guilds=True))


@bot.event
async def on_ready():
    """Print the logged-in user and the number of joined servers."""
    print(f"Logged in as {bot.user}")
    print(f"Running in {len(bot.guilds)} servers...")


async def run_prediction(img):
    """Run the captioner on *img* in a worker thread and return its output.

    The captioner call is synchronous, so it is handed to
    ``anyio.to_thread.run_sync`` instead of being called directly from the
    coroutine.
    """
    output = await anyio.to_thread.run_sync(captioner, img)
    return output


def remove_tags(content: str) -> str:
    """Return *content* without the hard-coded mention tags, stripped.

    The two IDs are presumably this bot's own user IDs -- confirm before
    reusing the function for another bot.
    """
    content = content.replace("<@1040198143695933501>", "")
    content = content.replace("<@1057338428938788884>", "")
    return content.strip()


async def send_file_or_text(channel, file_or_text: str):
    """Send *file_or_text* to *channel* as a file upload or as plain text.

    If the value names an existing regular file it is uploaded as an
    attachment; otherwise its value is sent as the message text.

    Returns:
        Whatever ``channel.send`` returns.
    """
    # is_file() rather than exists(): a string that happens to name a
    # directory (including "", which resolves to ".") must not be passed to
    # open(), which would raise instead of sending the text.
    if pathlib.Path(str(file_or_text)).is_file():
        with open(file_or_text, "rb") as f:
            return await channel.send(file=discord.File(f))
    return await channel.send(file_or_text)


async def make_prediction(message: discord.Message, content: str):
    """Caption the first attachment of *message* and post the result.

    Args:
        message: The incoming Discord message.
        content: Message text with mention tags removed. Currently unused;
            kept so existing callers keep working.

    Messages without attachments are ignored: there is nothing to caption,
    and indexing ``message.attachments[0]`` would raise ``IndexError``.
    """
    if not message.attachments:
        return

    url = message.attachments[0].url
    predictions = await run_prediction(url)
    print(url)

    # The captioner may return a single value or several; send each one.
    if isinstance(predictions, (tuple, list)):
        for p in predictions:
            await send_file_or_text(message.channel, p)
    else:
        await send_file_or_text(message.channel, predictions)


@bot.event
async def on_message(message: discord.Message):
    """Handle an incoming message by captioning its attachment, if any."""
    # Ignore the bot's own messages so its replies do not trigger it again.
    if message.author == bot.user:
        return
    if message.content:
        content = remove_tags(message.content)
        await make_prediction(message, content)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--token",
        type=str,
        help="API key for the Discord bot. You can set this to your Discord token if you'd like to make your own clone of the Gradio Bot.",
        required=False,
        default=os.getenv("DISCORD_TOKEN", ""),
    )
    args = parser.parse_args()
    # Fail fast with a clear message instead of handing an empty token to
    # the Discord client.
    if not args.token:
        parser.error("no Discord token provided: pass --token or set DISCORD_TOKEN")
    bot.run(args.token)