abidlabs HF staff commited on
Commit
d115945
1 Parent(s): 81bbd31

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import argparse
3
+ from collections import Counter
4
+ import json
5
+ import pathlib
6
+ import re
7
+
8
+
9
+ import discord
10
+ from discord.ext import commands
11
+ import gradio as gr
12
+ from gradio import utils
13
+ import requests
14
+
15
+ from typing import Dict, List
16
+
17
+ from utils import *
18
+
19
+
20
+ lock = asyncio.Lock()
21
+
22
+ bot = commands.Bot("", intents=discord.Intents(messages=True, guilds=True))
23
+
24
+
25
+ GUILD_SPACES_FILE = "guild_spaces.pkl"
26
+
27
+
28
+ if pathlib.Path(GUILD_SPACES_FILE).exists():
29
+ guild_spaces = read_pickle_file(GUILD_SPACES_FILE)
30
+ assert isinstance(guild_spaces, dict), f"{GUILD_SPACES_FILE} in invalid format."
31
+ guild_blocks = {}
32
+ delete_keys = []
33
+ for k, v in guild_spaces.items():
34
+ try:
35
+ guild_blocks[k] = gr.Interface.load(v, src="spaces")
36
+ except ValueError:
37
+ delete_keys.append(k)
38
+ for k in delete_keys:
39
+ del guild_spaces[k]
40
+ else:
41
+ guild_spaces: Dict[int, str] = {}
42
+ guild_blocks: Dict[int, gr.Blocks] = {}
43
+
44
+
45
+ HASHED_USERS_FILE = "users.pkl"
46
+
47
+ if pathlib.Path(HASHED_USERS_FILE).exists():
48
+ hashed_users = read_pickle_file(HASHED_USERS_FILE)
49
+ assert isinstance(hashed_users, list), f"{HASHED_USERS_FILE} in invalid format."
50
+ else:
51
+ hashed_users: List[str] = []
52
+
53
+
54
+ @bot.event
55
+ async def on_ready():
56
+ print(f"Logged in as {bot.user}")
57
+ print(f"Running in {len(bot.guilds)} servers...")
58
+
59
+
60
+ async def run_prediction(space: gr.Blocks, *inputs):
61
+ inputs = list(inputs)
62
+ fn_index = 0
63
+ processed_inputs = space.serialize_data(fn_index=fn_index, inputs=inputs)
64
+ batch = space.dependencies[fn_index]["batch"]
65
+
66
+ if batch:
67
+ processed_inputs = [[inp] for inp in processed_inputs]
68
+
69
+ outputs = await space.process_api(
70
+ fn_index=fn_index, inputs=processed_inputs, request=None, state={}
71
+ )
72
+ outputs = outputs["data"]
73
+
74
+ if batch:
75
+ outputs = [out[0] for out in outputs]
76
+
77
+ processed_outputs = space.deserialize_data(fn_index, outputs)
78
+ processed_outputs = utils.resolve_singleton(processed_outputs)
79
+
80
+ return processed_outputs
81
+
82
+
83
+ async def display_stats(message: discord.Message):
84
+ await message.channel.send(
85
+ f"Running in {len(bot.guilds)} servers\n"
86
+ f"Total # of users: {len(hashed_users)}\n"
87
+ f"------------------"
88
+ )
89
+ await message.channel.send(f"Most popular spaces:")
90
+ # display the top 10 most frequently occurring strings and their counts
91
+ spaces = guild_spaces.values()
92
+ counts = Counter(spaces)
93
+ for space, count in counts.most_common(10):
94
+ await message.channel.send(f"- {space}: {count}")
95
+
96
+
97
+ async def load_space(guild: discord.Guild, message: discord.Message, content: str):
98
+ iframe_url = (
99
+ requests.get(f"https://huggingface.co/api/spaces/{content}/host")
100
+ .json()
101
+ .get("host")
102
+ )
103
+ if iframe_url is None:
104
+ return await message.channel.send(
105
+ f"Space: {content} not found. If you'd like to make a prediction, enclose the inputs in quotation marks."
106
+ )
107
+ else:
108
+ await message.channel.send(
109
+ f"Loading Space: https://huggingface.co/spaces/{content}..."
110
+ )
111
+ interface = gr.Interface.load(content, src="spaces")
112
+ guild_spaces[guild.id] = content
113
+ guild_blocks[guild.id] = interface
114
+ asyncio.create_task(update_pickle_file(guild_spaces, GUILD_SPACES_FILE))
115
+ if len(content) > 32 - len(f"{bot.name} []"): # type: ignore
116
+ nickname = content[: 32 - len(f"{bot.name} []") - 3] + "..." # type: ignore
117
+ else:
118
+ nickname = content
119
+ nickname = f"{bot.name} [{nickname}]" # type: ignore
120
+ await guild.me.edit(nick=nickname)
121
+ await message.channel.send(
122
+ "Ready to make predictions! Type in your inputs and enclose them in quotation marks."
123
+ )
124
+
125
+
126
+ async def disconnect_space(bot: commands.Bot, guild: discord.Guild):
127
+ guild_spaces.pop(guild.id, None)
128
+ guild_blocks.pop(guild.id, None)
129
+ asyncio.create_task(update_pickle_file(guild_spaces, GUILD_SPACES_FILE))
130
+ await guild.me.edit(nick=bot.name) # type: ignore
131
+
132
+
133
+ async def make_prediction(guild: discord.Guild, message: discord.Message, content: str):
134
+ if guild.id in guild_spaces:
135
+ params = re.split(r' (?=")', content)
136
+ params = [p.strip("'\"") for p in params]
137
+ space = guild_blocks[guild.id]
138
+ predictions = await run_prediction(space, *params)
139
+ if isinstance(predictions, (tuple, list)):
140
+ for p in predictions:
141
+ await send_file_or_text(message.channel, p)
142
+ else:
143
+ await send_file_or_text(message.channel, predictions)
144
+ return
145
+ else:
146
+ await message.channel.send(
147
+ "No Space is currently running. Please type in the name of a Hugging Face Space name first, e.g. abidlabs/en2fr"
148
+ )
149
+ await guild.me.edit(nick=bot.name) # type: ignore
150
+
151
+
152
+ @bot.event
153
+ async def on_message(message: discord.Message):
154
+ if message.author == bot.user:
155
+ return
156
+ h = hash_user_id(message.author.id)
157
+ if h not in hashed_users:
158
+ hashed_users.append(h)
159
+ asyncio.create_task(update_pickle_file(hashed_users, HASHED_USERS_FILE))
160
+ else:
161
+ if message.content:
162
+ content = remove_tags(message.content)
163
+ guild = message.channel.guild
164
+ assert guild, "Message not sent in a guild."
165
+
166
+ if content.strip() == "exit":
167
+ await disconnect_space(bot, guild)
168
+ elif content.strip() == "stats":
169
+ await display_stats(message)
170
+ elif content.startswith('"') or content.startswith("'"):
171
+ await make_prediction(guild, message, content)
172
+ else:
173
+ await load_space(guild, message, content)
174
+
175
+ bot.run(os.getenv("discord_token"))