Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import threading
|
8 |
+
from discord import app_commands
|
9 |
+
from discord.ext import commands
|
10 |
+
import discord
|
11 |
+
import gradio as gr
|
12 |
+
import discord
|
13 |
+
from gradio_client import Client
|
14 |
+
|
15 |
+
|
16 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
17 |
+
wuerstchen_client = Client("huggingface-projects/Wuerstchen-duplicate", HF_TOKEN)
|
18 |
+
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN")
|
19 |
+
|
20 |
+
#---------------------------------------------------------------------------------------------------------------------
|
21 |
+
intents = discord.Intents.all()
|
22 |
+
bot = commands.Bot(command_prefix="/", intents=intents)
|
23 |
+
#---------------------------------------------------------------------------------------------------------------------
|
24 |
+
@bot.event
|
25 |
+
async def on_ready():
|
26 |
+
print(f"Logged in as {bot.user} (ID: {bot.user.id})")
|
27 |
+
synced = await bot.tree.sync()
|
28 |
+
print(f"Synced commands: {', '.join([s.name for s in synced])}.")
|
29 |
+
print("------")
|
30 |
+
#---------------------------------------------------------------------------------------------------------------------
|
31 |
+
@client.hybrid_command(
|
32 |
+
name="wuerstchen",
|
33 |
+
description="Enter a prompt to generate art!",
|
34 |
+
)
|
35 |
+
@app_commands.guilds(MY_GUILD)
|
36 |
+
async def wuerstchen_command(ctx, prompt: str):
|
37 |
+
"""Wuerstchen generation"""
|
38 |
+
try:
|
39 |
+
await run_wuerstchen(ctx, prompt, client)
|
40 |
+
except Exception as e:
|
41 |
+
print(f"Error wuerstchen: (app.py){e}")
|
42 |
+
|
43 |
+
|
44 |
+
def wuerstchen_inference(prompt):
|
45 |
+
"""Inference for Wuerstchen"""
|
46 |
+
negative_prompt = ""
|
47 |
+
seed = random.randint(0, 1000)
|
48 |
+
width = 1024
|
49 |
+
height = 1024
|
50 |
+
prior_num_inference_steps = 60
|
51 |
+
prior_guidance_scale = 4
|
52 |
+
decoder_num_inference_steps = 12
|
53 |
+
decoder_guidance_scale = 0
|
54 |
+
num_images_per_prompt = 1
|
55 |
+
|
56 |
+
result_path = wuerstchen_client.predict(
|
57 |
+
prompt,
|
58 |
+
negative_prompt,
|
59 |
+
seed,
|
60 |
+
width,
|
61 |
+
height,
|
62 |
+
prior_num_inference_steps,
|
63 |
+
prior_guidance_scale,
|
64 |
+
decoder_num_inference_steps,
|
65 |
+
decoder_guidance_scale,
|
66 |
+
num_images_per_prompt,
|
67 |
+
api_name="/run",
|
68 |
+
)
|
69 |
+
png_file = list(glob.glob(f"{result_path}/**/*.png"))
|
70 |
+
return png_file[0]
|
71 |
+
|
72 |
+
|
73 |
+
async def run_wuerstchen(ctx, prompt, client):
|
74 |
+
"""Responds to /Wuerstchen command"""
|
75 |
+
try:
|
76 |
+
if ctx.author.id != BOT_USER_ID:
|
77 |
+
if ctx.channel.id == WUERSTCHEN_CHANNEL_ID:
|
78 |
+
channel = client.get_channel(WUERSTCHEN_CHANNEL_ID)
|
79 |
+
message = await ctx.send(f"**{prompt}** - {ctx.author.mention} <a:loading:1114111677990981692>")
|
80 |
+
|
81 |
+
loop = asyncio.get_running_loop()
|
82 |
+
result_path = await loop.run_in_executor(None, wuerstchen_inference, prompt)
|
83 |
+
|
84 |
+
await message.delete()
|
85 |
+
with open(result_path, "rb") as f:
|
86 |
+
await channel.send(f"**{prompt}** - {ctx.author.mention}", file=discord.File(f, "wuerstchen.png"))
|
87 |
+
except Exception as e:
|
88 |
+
print(f"Error: {e}")
|
89 |
+
|
90 |
+
|
91 |
+
def run_bot():
|
92 |
+
bot.run(DISCORD_TOKEN)
|
93 |
+
|
94 |
+
|
95 |
+
threading.Thread(target=run_bot).start()
|
96 |
+
"""This allows us to run the Discord bot in a Python thread"""
|
97 |
+
with gr.Blocks() as demo:
|
98 |
+
gr.Markdown("""
|
99 |
+
# Huggingbots Server
|
100 |
+
This space hosts the huggingbots discord bot.
|
101 |
+
Currently supported models are Falcon and DeepfloydIF
|
102 |
+
""")
|
103 |
+
demo.queue(concurrency_count=100)
|
104 |
+
demo.queue(max_size=100)
|
105 |
+
demo.launch()
|