|
import asyncio |
|
import os |
|
import threading |
|
import random |
|
from threading import Event |
|
from typing import Optional |
|
|
|
import discord |
|
import gradio as gr |
|
from discord import Permissions |
|
from discord.ext import commands |
|
from discord.utils import oauth_url |
|
|
|
import gradio_client as grc |
|
from gradio_client.utils import QueueError |
|
|
|
# Readiness flag shared between the bot thread and the main thread: the main
# thread blocks on it before launching the Gradio UI.
event = threading.Event()

# Discord bot token read from the environment; None when the variable is unset.
DISCORD_TOKEN = os.environ.get("DISCORD_TOKEN")
|
|
|
async def wait(job):
    """Asynchronously poll ``job`` until it reports completion.

    Checks ``job.done()`` and, while it is still False, yields control to the
    event loop for 0.2 seconds before checking again.

    Args:
        job: any object exposing a ``done()`` method returning a bool.

    Returns:
        None, once ``job.done()`` is truthy.
    """
    poll_interval = 0.2
    while True:
        if job.done():
            return
        await asyncio.sleep(poll_interval)
|
|
|
def get_client(session: Optional[str] = None) -> grc.Client:
    """Create a gradio client bound to the transformers-musicgen Space.

    Args:
        session: optional session hash. When truthy it overrides the new
            client's ``session_hash``; otherwise the client keeps its own.

    Returns:
        A ``gradio_client.Client`` authenticated with the ``HF_TOKEN``
        environment variable (``None`` if unset).
    """
    space_client = grc.Client(
        "huggingface-projects/transformers-musicgen",
        hf_token=os.getenv("HF_TOKEN"),
    )
    if not session:
        return space_client
    space_client.session_hash = session
    return space_client
|
|
|
# Start from discord.py's default intents and additionally enable message
# content (presumably so the "/"-prefixed text form of the hybrid command can
# read message text -- confirm against the bot's Developer Portal settings).
intents = discord.Intents.default()
intents.message_content = True

# Bot instance; "/" is the prefix for text-invoked commands.
bot = commands.Bot(
    intents=intents,
    command_prefix="/",
)
|
|
|
@bot.event
async def on_ready():
    """Log the login, sync application commands, and signal readiness.

    The main thread blocks on ``event.wait()`` before launching the Gradio UI,
    so ``event`` is set in a ``finally`` clause: even if ``bot.tree.sync()``
    raises, the main thread is released instead of hanging forever. The
    exception itself still propagates to discord.py's event error handling.
    """
    print(f"Logged in as {bot.user} (ID: {bot.user.id})")
    try:
        synced = await bot.tree.sync()
        print(f"Synced commands: {', '.join([s.name for s in synced])}.")
    finally:
        event.set()
    print("------")
|
|
|
|
|
|
|
@bot.hybrid_command(
    name="testm",
    description="Enter a prompt to generate music!",
)
async def musicgen_command(ctx, prompt: str, seed: int = None):
    """Generates music based on a prompt"""
    # NOTE: the docstring above is left verbatim because discord.py uses it as
    # the command's help text at runtime.

    # Ignore invocations coming from the bot's own account.
    if ctx.author.id == bot.user.id:
        return

    # Fall back to a random seed in [1, 10000] when the caller gave none.
    seed = random.randint(1, 10000) if seed is None else seed

    try:
        await music_create(ctx, prompt, seed)
    except Exception as e:
        print(f"Error: {e}")
|
|
|
async def music_create(ctx, prompt, seed):
    """Run ``music_create_job`` off the event loop and post its results.

    Sends a "Generating..." notice, runs the blocking ``music_create_job`` in
    the loop's default executor, then uploads the generated video and audio
    files to the invoking channel.

    Args:
        ctx: command context used to reply.
        prompt: text prompt forwarded to the Space; its first 20 characters
            also name the uploaded files.
        seed: seed forwarded to the Space.

    Failures are reported with ``print`` rather than raised to the caller.
    """
    try:
        await ctx.send(f"**{prompt}** - {ctx.author.mention} Generating...")

        loop = asyncio.get_running_loop()
        job = await loop.run_in_executor(None, music_create_job, prompt, seed)
        if job is None:
            # music_create_job prints its own error and returns None; there is
            # no result to fetch, so tell the user instead of crashing on
            # ``None.result()`` and leaving them at "Generating...".
            await ctx.send("Something went wrong while generating your music. Please try again later!")
            return

        try:
            job.result()
            media_files = job.outputs()[0]
        except QueueError:
            await ctx.send("The gradio space powering this bot is really busy! Please try again later!")
            # Stop here: without this return the code below would read the
            # never-assigned ``media_files``.
            return

        audio = media_files[0]
        video = media_files[1]

        # NOTE(review): the prompt is used verbatim in the filename and may
        # contain characters such as "/" -- consider sanitizing.
        short_filename = prompt[:20]
        audio_filename = f"{short_filename}.mp3"
        video_filename = f"{short_filename}.mp4"

        # Send while the file is still open; discord.File reads from the handle.
        with open(video, "rb") as file:
            await ctx.send(file=discord.File(file, filename=video_filename))

        with open(audio, "rb") as file:
            await ctx.send(file=discord.File(file, filename=audio_filename))

    except Exception as e:
        print(f"music_create Error: {e}")
|
|
|
|
|
def music_create_job(prompt, seed):
    """Submit a generation request to the MusicGen Space and wait for it.

    Intended to run in a worker thread (``music_create`` calls it through
    ``run_in_executor``), so blocking here does not stall the bot's loop.

    Args:
        prompt: text prompt sent to the Space's ``/predict`` endpoint.
        seed: seed sent alongside the prompt.

    Returns:
        The finished gradio_client job, or None if submitting or waiting
        raised (the error is printed).
    """
    import time  # only needed for the polling sleep below

    try:
        # No module-level client exists in this file, so build one per request
        # via get_client().
        job = get_client().submit(prompt, seed, api_name="/predict")
        # Sleep between polls instead of a bare ``pass`` spin loop, which would
        # pin a CPU core for the whole generation time.
        while not job.done():
            time.sleep(0.1)
        return job

    except Exception as e:
        print(f"music_create_job Error: {e}")
        return None
|
|
|
|
|
|
|
def run_bot():
    """Run the Discord bot, or skip it when no token is configured.

    The main thread waits on ``event`` before launching the Gradio UI. The
    event is normally set by ``on_ready``. It is also set here when the token
    is missing, and in a ``finally`` clause so that a failed login (or any
    exception from ``bot.run`` before ``on_ready``) cannot leave the main
    thread blocked forever. Setting an already-set event is harmless.
    """
    if not DISCORD_TOKEN:
        print("DISCORD_TOKEN NOT SET")
        event.set()
        return

    try:
        bot.run(DISCORD_TOKEN)
    finally:
        event.set()
|
|
|
|
|
# bot.run() blocks, so the Discord client runs on its own thread.
threading.Thread(target=run_bot).start()

# Hold the UI launch until the bot thread signals via ``event``
# (set in on_ready, or in run_bot when no token is configured).
event.wait()

# Minimal Gradio page describing the bot.
with gr.Blocks() as demo:
    gr.Markdown(
        value="""
    # Discord bot of https://huggingface.co/spaces/facebook/MusicGen
    """
    )

demo.launch()