Spaces:
Runtime error
Runtime error
File size: 2,083 Bytes
43953d1 f8e9d85 43953d1 f8e9d85 43953d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import asyncio
import argparse
import pathlib
import os
import discord
from discord.ext import commands
import gradio as gr
import anyio
# Remote captioning model, loaded from the "olivierdehaene/git-large-coco"
# Hugging Face Space when this module is imported.
captioner = gr.Interface.load("olivierdehaene/git-large-coco", src="spaces")

# NOTE(review): `lock` is not referenced anywhere in the code visible here;
# presumably it was meant to serialize captioner calls — confirm or remove.
lock = asyncio.Lock()

# Bot with an empty command prefix, subscribed to guild and message events only.
bot = commands.Bot(
    command_prefix="",
    intents=discord.Intents(guilds=True, messages=True),
)
@bot.event
async def on_ready():
    """Print the logged-in bot identity and the number of guilds it is in."""
    me = bot.user
    server_count = len(bot.guilds)
    print(f"Logged in as {me}")
    print(f"Running in {server_count} servers...")
async def run_prediction(img):
    """Call the blocking ``captioner`` in a worker thread and return its result.

    ``img`` is forwarded unchanged to ``captioner``; in this file the caller
    passes an attachment URL.
    """
    return await anyio.to_thread.run_sync(captioner, img)
# Discord user IDs whose mentions are stripped by default
# (presumably this bot's own accounts — confirm).
_DEFAULT_BOT_IDS = ("1040198143695933501", "1057338428938788884")


def remove_tags(content: str, bot_ids=_DEFAULT_BOT_IDS) -> str:
    """Remove mentions of the given user IDs from *content*.

    Both the plain ``<@ID>`` mention and the legacy nickname form ``<@!ID>``
    are removed. Leading and trailing whitespace is then trimmed.

    Args:
        content: Raw message text.
        bot_ids: Iterable of user-ID strings whose mentions are removed.
            Defaults to the IDs this function has always stripped.

    Returns:
        The text with those mentions removed, stripped of surrounding whitespace.
    """
    for bot_id in bot_ids:
        content = content.replace(f"<@{bot_id}>", "").replace(f"<@!{bot_id}>", "")
    return content.strip()
async def send_file_or_text(channel, file_or_text: str):
    """Send *file_or_text* to *channel*, uploading it if it names a file.

    A value is uploaded as an attachment only when it is the path of an
    existing regular file. Anything else is posted as a text message. That
    includes directories such as ``"."``, the empty string (which pathlib
    treats as ``"."``), and text that is not a usable path at all.

    Args:
        channel: Destination object with an async ``send`` method
            (e.g. a discord.py messageable).
        file_or_text: Path of a file to upload, or text to post.

    Returns:
        Whatever ``channel.send`` returns.
    """
    text = str(file_or_text)
    if _is_regular_file(text):
        path = pathlib.Path(text)
        with path.open("rb") as f:
            # Upload under the bare file name rather than the full local path.
            return await channel.send(file=discord.File(f, filename=path.name))
    return await channel.send(file_or_text)


def _is_regular_file(candidate: str) -> bool:
    """Return True if *candidate* is the path of an existing regular file."""
    try:
        return pathlib.Path(candidate).is_file()
    except (OSError, ValueError):
        # e.g. a name too long for the OS, or an embedded NUL byte: not a path.
        return False
async def make_prediction(message: discord.Message, content: str):
    """Caption the first attachment of *message* and post the output(s).

    Messages without attachments are ignored, since there is no URL to caption.

    Args:
        message: The incoming Discord message.
        content: Message text with bot mentions removed. Not used here; kept so
            existing callers keep working.
    """
    if not message.attachments:
        # Text-only message: indexing attachments[0] would raise IndexError.
        return
    url = message.attachments[0].url
    # Log the URL before the (possibly failing) remote call so it is visible on error.
    print(url)
    predictions = await run_prediction(url)
    # The model may return one output or a tuple/list of them; post each one.
    if not isinstance(predictions, (tuple, list)):
        predictions = [predictions]
    for prediction in predictions:
        await send_file_or_text(message.channel, prediction)
@bot.event
async def on_message(message: discord.Message):
    """Caption messages that carry text, ignoring the bot's own messages."""
    if message.author == bot.user or not message.content:
        return
    cleaned = remove_tags(message.content)
    await make_prediction(message, cleaned)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--token",
        type=str,
        help="API key for the Discord bot. You can set this to your Discord token if you'd like to make your own clone of the Gradio Bot.",
        required=False,
        default=os.getenv("DISCORD_TOKEN", ""),
    )
    args = parser.parse_args()
    # Fail fast with a usage message instead of handing bot.run an empty token
    # when neither --token nor DISCORD_TOKEN was provided.
    if not args.token:
        parser.error("no Discord token given; pass --token or set DISCORD_TOKEN")
    bot.run(args.token)