Synced repo using 'sync_with_huggingface' GitHub Action
Browse files- app.py +15 -0
- wuerstchen.py +64 -0
app.py
CHANGED
@@ -11,6 +11,7 @@ from discord.ext import commands
|
|
11 |
from falcon import falcon_chat, continue_chat
|
12 |
from musicgen import music_create
|
13 |
from codellama import continue_codellama, try_codellama
|
|
|
14 |
|
15 |
# HF GUILD SETTINGS
|
16 |
MY_GUILD_ID = 1077674588122648679 if os.getenv("TEST_ENV", False) else 879548962464493619
|
@@ -117,6 +118,20 @@ async def audioldm2(ctx, prompt: str):
|
|
117 |
print(f"Error: (app.py){e}")
|
118 |
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
def run_bot():
|
121 |
client.run(DISCORD_TOKEN)
|
122 |
|
|
|
11 |
from falcon import falcon_chat, continue_chat
|
12 |
from musicgen import music_create
|
13 |
from codellama import continue_codellama, try_codellama
|
14 |
+
from wuerstchen import run_wuerstchen
|
15 |
|
16 |
# HF GUILD SETTINGS
|
17 |
MY_GUILD_ID = 1077674588122648679 if os.getenv("TEST_ENV", False) else 879548962464493619
|
|
|
118 |
print(f"Error: (app.py){e}")
|
119 |
|
120 |
|
121 |
+
@client.hybrid_command(
    name="wuerstchen",
    with_app_command=True,
    description="Enter a prompt to generate art!",
)
@app_commands.guilds(MY_GUILD)
async def wuerstchen_command(ctx, prompt: str):
    """Discord hybrid command: generate a Wuerstchen image from *prompt*.

    Delegates all channel/author filtering and generation work to
    run_wuerstchen (wuerstchen.py); the bot client is passed through so the
    helper can resolve the target channel.
    """
    try:
        await run_wuerstchen(ctx, prompt, client)
    except Exception as e:
        # Log-and-continue matches the other command handlers in this file
        # (see the audioldm2 handler above) so one failed command never
        # takes the bot down.
        print(f"Error wuerstchen: (app.py){e}")
|
133 |
+
|
134 |
+
|
135 |
def run_bot():
    """Start the Discord client; blocks until the bot shuts down.

    NOTE(review): DISCORD_TOKEN is defined elsewhere in app.py (not visible
    in this chunk) — presumably read from the environment; confirm.
    """
    client.run(DISCORD_TOKEN)
|
137 |
|
wuerstchen.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
|
6 |
+
import discord
|
7 |
+
from gradio_client import Client
|
8 |
+
|
9 |
+
|
10 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
11 |
+
wuerstchen_client = Client("warp-ai/Wuerstchen", HF_TOKEN)
|
12 |
+
|
13 |
+
BOT_USER_ID = 1102236653545861151
|
14 |
+
WUERSTCHEN_CHANNEL_ID = 1151792944676864041
|
15 |
+
|
16 |
+
|
17 |
+
def wuerstchen_inference(prompt, client):
    """Run a blocking Wuerstchen text-to-image job on the hosted Space.

    Intended to be called via ``loop.run_in_executor`` from the async
    handler, so blocking here is deliberate.

    Args:
        prompt: Text prompt to render.
        client: Unused in this helper; kept for signature parity with the
            async caller (the module-level ``wuerstchen_client`` is used).

    Returns:
        The finished gradio_client job; the caller reads ``job.outputs()``.
    """
    negative_prompt = ""
    seed = random.randint(0, 1000)
    width = 1024
    height = 1024
    prior_num_inference_steps = 60
    prior_guidance_scale = 4
    decoder_num_inference_steps = 12
    decoder_guidance_scale = 0
    num_images_per_prompt = 1

    job = wuerstchen_client.submit(
        prompt,
        negative_prompt,
        seed,
        width,
        height,
        prior_num_inference_steps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        decoder_guidance_scale,
        num_images_per_prompt,
        api_name="/run",
    )
    # Wait for completion with job.result() instead of the original
    # busy-wait (`while not job.done(): pass`), which spun a CPU core at
    # 100% for the whole generation. result() blocks efficiently and
    # re-raises any remote error instead of handing back a failed job.
    job.result()
    return job
45 |
+
|
46 |
+
|
47 |
+
async def run_wuerstchen(ctx, prompt, client):
    """Handle a wuerstchen request: generate an image and post it to Discord.

    Args:
        ctx: discord.py command context (supplies author, channel, send()).
        prompt: User-supplied text prompt.
        client: The Discord bot client, used to resolve the output channel.
    """
    try:
        # Ignore the bot's own messages so posting results cannot retrigger us.
        if ctx.author.id != BOT_USER_ID:
            # Only serve requests made in the dedicated wuerstchen channel.
            if ctx.channel.id == WUERSTCHEN_CHANNEL_ID:
                channel = client.get_channel(WUERSTCHEN_CHANNEL_ID)
                # Placeholder reply with an animated loading emoji; deleted
                # once the image is ready.
                message = await ctx.send(f"**{prompt}** - {ctx.author.mention} <a:loading:1114111677990981692>")

                # The gradio call blocks, so run it in the default executor
                # to keep the event loop responsive.
                loop = asyncio.get_running_loop()
                job = await loop.run_in_executor(None, wuerstchen_inference, prompt, client)

                # NOTE(review): assumes job.outputs()[-1] is a directory
                # containing the generated PNG(s) — confirm against the
                # Space's /run API output format.
                png_files = list(glob.glob(f"{job.outputs()[-1]}/**/*.png"))
                await message.delete()

                with open(png_files[0], "rb") as f:
                    await channel.send(f"**{prompt}** - {ctx.author.mention}", file=discord.File(f, "wuerstchen.png"))
    except Exception as e:
        # Log-and-swallow keeps the bot alive on per-request failures.
        print(f"Error: {e}")