File size: 2,083 Bytes
43953d1
 
 
f8e9d85
43953d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8e9d85
43953d1
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import asyncio
import os
import pathlib
import re


import discord
from discord.ext import commands
import gradio as gr
import anyio


# Image-captioning model loaded from the Gradio Space
# "olivierdehaene/git-large-coco". This runs at import time; loading from
# src="spaces" presumably requires network access to the Hub — confirm.
captioner = gr.Interface.load("olivierdehaene/git-large-coco", src="spaces")

# NOTE(review): this lock is not referenced by any code visible in this file;
# presumably it was meant to serialize calls to `captioner` — confirm or remove.
lock = asyncio.Lock()

# Bot with an empty command prefix, requesting only the `messages` and
# `guilds` gateway intents. NOTE(review): on discord.py 2.x, message text is
# gated behind the privileged `message_content` intent (mentions excepted),
# which is not requested here — confirm the targeted discord.py version.
bot = commands.Bot("", intents=discord.Intents(messages=True, guilds=True))


@bot.event
async def on_ready():
    """Report the logged-in bot identity and its guild count once connected."""
    identity = bot.user
    print(f"Logged in as {identity}")
    server_count = len(bot.guilds)
    print(f"Running in {server_count} servers...")


async def run_prediction(img):
    """Call ``captioner(img)`` on a worker thread and return its result.

    The call is dispatched with ``anyio.to_thread.run_sync`` so the
    (presumably blocking) captioner does not run on the event-loop thread.

    Args:
        img: Input forwarded unchanged to ``captioner``; callers in this file
            pass an attachment URL.

    Returns:
        Whatever ``captioner`` returns.
    """
    return await anyio.to_thread.run_sync(captioner, img)


# User IDs of the bot accounts whose mentions are stripped from incoming text.
_DEFAULT_BOT_IDS = ("1040198143695933501", "1057338428938788884")


def remove_tags(content: str, bot_ids=_DEFAULT_BOT_IDS) -> str:
    """Strip mentions of the given bot accounts from *content* and trim whitespace.

    Both Discord user-mention forms are removed: ``<@ID>`` and the
    nickname form ``<@!ID>``. Mentions of any other user are left intact.

    Args:
        content: Raw message text.
        bot_ids: Iterable of user-ID strings whose mentions should be removed.
            Defaults to the two bot IDs this module has always stripped.

    Returns:
        *content* with those mentions removed and surrounding whitespace stripped.
    """
    ids = [str(bot_id) for bot_id in bot_ids]
    if not ids:
        # An empty alternation would match the literal "<@>", so skip regex work.
        return content.strip()
    pattern = r"<@!?(?:" + "|".join(re.escape(bot_id) for bot_id in ids) + r")>"
    return re.sub(pattern, "", content).strip()


async def send_file_or_text(channel, file_or_text: str):
    """Send *file_or_text* to *channel*, as an attachment if it names a regular file.

    If ``str(file_or_text)`` is the path of an existing regular file, that file
    is uploaded. Otherwise the value is sent as message text.

    Args:
        channel: Destination exposing an async ``send`` method (a discord.py
            messageable such as ``message.channel``).
        file_or_text: A filesystem path or plain text (e.g. a caption).

    Returns:
        Whatever ``channel.send`` returns.
    """
    candidate = pathlib.Path(str(file_or_text))
    try:
        # is_file() rather than exists(): "" resolves to "." and directories
        # also "exist", and open() would then fail on them.
        names_a_file = candidate.is_file()
    except (OSError, ValueError):
        # Long captions can raise ENAMETOOLONG; odd characters can raise
        # ValueError. Either way the value is text, not a file.
        names_a_file = False

    if names_a_file:
        with open(candidate, "rb") as f:
            return await channel.send(file=discord.File(f))
    return await channel.send(file_or_text)


async def make_prediction(message: discord.Message, content: str):
    """Caption the first attachment of *message* and post the result to its channel.

    Messages without attachments are ignored. If the prediction is a tuple
    or list, each element is sent separately; otherwise the single result
    is sent.

    Args:
        message: The incoming Discord message.
        content: Mention-stripped message text; currently unused but kept
            for interface compatibility with callers.
    """
    if not message.attachments:
        # on_message forwards every non-empty message, including text-only
        # ones; there is nothing to caption, and indexing [0] would raise.
        return

    image_url = message.attachments[0].url
    predictions = await run_prediction(image_url)
    print(image_url)

    outputs = predictions if isinstance(predictions, (tuple, list)) else [predictions]
    for output in outputs:
        await send_file_or_text(message.channel, output)


@bot.event
async def on_message(message: discord.Message):
    """Forward non-empty messages not authored by the bot to ``make_prediction``.

    Mentions of the bot are stripped from the text before it is passed on.
    Messages sent by the bot itself, or with empty content, are ignored.
    """
    if message.author == bot.user or not message.content:
        return
    cleaned = remove_tags(message.content)
    await make_prediction(message, cleaned)


if __name__ == "__main__":
    # Script entry point: read the Discord token from --token or the
    # DISCORD_TOKEN environment variable, then start the bot.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--token",
        type=str,
        help="API key for the Discord bot. You can set this to your Discord token if you'd like to make your own clone of the Gradio Bot.",
        required=False,
        default=os.getenv("DISCORD_TOKEN", ""),
    )
    args = parser.parse_args()

    # With neither source set the token is "", and bot.run("") would only
    # fail later at login; exit now with a usage message instead.
    if not args.token:
        parser.error("a Discord token is required: pass --token or set DISCORD_TOKEN")

    bot.run(args.token)