freddyaboulton HF staff commited on
Commit
43953d1
1 Parent(s): 98b1aea
Files changed (2) hide show
  1. app.py +85 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import argparse
3
+ import pathlib
4
+
5
+
6
+ import discord
7
+ from discord.ext import commands
8
+ import gradio as gr
9
+ import anyio
10
+
11
+
12
+ captioner = gr.Interface.load("olivierdehaene/git-large-coco", src="spaces")
13
+
14
+ lock = asyncio.Lock()
15
+
16
+ bot = commands.Bot("", intents=discord.Intents(messages=True, guilds=True))
17
+
18
+
19
+ @bot.event
20
+ async def on_ready():
21
+ print(f"Logged in as {bot.user}")
22
+ print(f"Running in {len(bot.guilds)} servers...")
23
+
24
+
25
+ async def run_prediction(img):
26
+ output = await anyio.to_thread.run_sync(captioner, img)
27
+ return output
28
+
29
+
30
+ def remove_tags(content: str) -> str:
31
+ content = content.replace("<@1040198143695933501>", "")
32
+ content = content.replace("<@1057338428938788884>", "")
33
+ return content.strip()
34
+
35
+
36
+ async def send_file_or_text(channel, file_or_text: str):
37
+ # if the file exists, send as a file
38
+ if pathlib.Path(str(file_or_text)).exists():
39
+ with open(file_or_text, "rb") as f:
40
+ return await channel.send(file=discord.File(f))
41
+ else:
42
+ return await channel.send(file_or_text)
43
+
44
+
45
+ async def make_prediction(message: discord.Message, content: str):
46
+
47
+ # params = re.split(r' (?=")', content)
48
+ # params = [p.strip("'\"") for p in params]
49
+ # print(params)
50
+ predictions = await run_prediction(message.attachments[0].url)
51
+ print(message.attachments[0].url)
52
+ if isinstance(predictions, (tuple, list)):
53
+ for p in predictions:
54
+ await send_file_or_text(message.channel, p)
55
+ else:
56
+ await send_file_or_text(message.channel, predictions)
57
+ return
58
+
59
+
60
+ @bot.event
61
+ async def on_message(message: discord.Message):
62
+ if message.author == bot.user:
63
+ return
64
+ if message.content:
65
+ content = remove_tags(message.content)
66
+ await make_prediction(message, content)
67
+
68
+
69
+
70
+
71
+
72
+
73
+ if __name__ == "__main__":
74
+ parser = argparse.ArgumentParser()
75
+ parser.add_argument(
76
+ "--token",
77
+ type=str,
78
+ help="API key for the Discord bot. You can set this to your Discord token if you'd like to make your own clone of the Gradio Bot.",
79
+ required=False,
80
+ default="",
81
+ )
82
+ args = parser.parse_args()
83
+
84
+
85
+ bot.run(args.token)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ discord.py
2
+ https://gradio-builds.s3.amazonaws.com/5c7a924962474c5b981f9629bddd15eadab8b8c9/gradio-3.16.1-py3-none-any.whl
3
+ requests