import asyncio
import os
import threading
import time

import discord
import gradio as gr
from discord.ext import commands
from gradio_client import Client
from PIL import Image


# Hugging Face token handed to the gradio_client Client below (presumably for
# authenticating with the Space); None when HF_TOKEN is unset.
DFIF_TOKEN = os.environ.get('HF_TOKEN')
# Discord bot token; None when GRADIOTEST_TOKEN is unset.
DISCORD_TOKEN = os.getenv("GRADIOTEST_TOKEN")

# Client for the akhaliq/JoJoGAN Space. It is created once at import time and
# shared by every request the bot handles.
jojogan = Client("akhaliq/JoJoGAN", DFIF_TOKEN)

# Start from the default intents and additionally enable message content,
# which prefix commands ("!...") rely on to read the command text.
intents = discord.Intents.default()
intents.message_content = True

bot = commands.Bot(intents=intents, command_prefix='!')


async def jojo(ctx):
    """Run one JoJoGAN stylization request and post the result to *ctx*.

    Sends a fixed sample image URL with the 'JoJo' style to the module-level
    ``jojogan`` client. It then replies in the invoking channel with the time
    the prediction took, followed by the generated file.

    Args:
        ctx: The command context used to send the replies.
    """
    style = 'JoJo'
    atchurl = 'https://cdn.discordapp.com/attachments/1100458786826747945/1111746037640601610/image.png'

    start_time = time.perf_counter()
    # ``predict`` is a blocking call. Run it in the loop's default executor so
    # it does not freeze the bot's event loop, and so that several requests
    # scheduled by ``command`` can actually overlap.
    loop = asyncio.get_running_loop()
    im = await loop.run_in_executor(None, jojogan.predict, atchurl, style)
    generation_time = time.perf_counter() - start_time

    await ctx.send(f"{style} image generated in {generation_time:.2f} seconds.")
    # NOTE(review): ``im`` is presumably a local file path returned by the
    # Space; discord.File accepts a path or a file object.
    await ctx.send(file=discord.File(im))


@bot.command()
async def command(ctx, num_requests: int):
    """Schedule ``num_requests`` JoJoGAN requests as tasks, then confirm.

    Invoked as ``!command <num_requests>``. Each request runs ``jojo`` in its
    own task. The confirmation message is sent only after all of them finish.
    If any task raises, ``gather`` propagates the first exception and the
    confirmation is not sent.
    """
    pending = [asyncio.create_task(jojo(ctx)) for _ in range(num_requests)]
    await asyncio.gather(*pending)
    await ctx.send("Command executed.")


def run_bot():
    """Connect the Discord bot and block until it shuts down.

    Uses ``DISCORD_TOKEN``, which is read from the GRADIOTEST_TOKEN
    environment variable at import time.
    """
    token = DISCORD_TOKEN
    bot.run(token)


async def run_gradio_interface():
    """Build the placeholder "greet" Gradio demo and start serving it.

    ``Interface.launch`` is synchronous and returns a plain tuple, so it must
    not be awaited (doing so raises TypeError). Passing
    ``prevent_thread_lock=True`` makes it return once the server is up instead
    of blocking the calling thread. As a result this coroutine finishes
    promptly and does not stall the event loop running it.

    The server stops when the process exits, so the caller must keep the
    process alive, for example by running the bot afterwards.
    """

    def greet(name):
        return "Hello " + name + "!"

    demo = gr.Interface(fn=greet, inputs="text", outputs="text")
    demo.launch(prevent_thread_lock=True)


def main():
    """Serve the Gradio demo in the background and run the Discord bot.

    ``run_bot`` wraps ``bot.run``, a blocking call that sets up and owns its
    own event loop. It therefore cannot be scheduled as a task on another loop:
    evaluating ``run_bot()`` to build such a task blocks immediately and never
    yields a coroutine.

    Instead, the Gradio coroutine gets a private event loop on a daemon
    thread, and the bot runs in the main thread until it disconnects or is
    interrupted. Because the thread is a daemon, it does not keep the process
    alive after the bot stops.
    """
    gradio_thread = threading.Thread(
        target=lambda: asyncio.run(run_gradio_interface()),
        name="gradio-interface",
        daemon=True,
    )
    gradio_thread.start()
    run_bot()


# Start the bot and web UI only when executed as a script, not on import.
if __name__ == "__main__":
    main()