lunarflu HF Staff commited on
Commit
c41bcc3
·
1 Parent(s): a1eb69d

lunarbot -> gradiotest

Browse files
Files changed (1) hide show
  1. app.py +198 -211
app.py CHANGED
@@ -1,47 +1,43 @@
1
  import discord
2
- from discord.ui import Button
3
- from discord.ext import commands #for buttons
4
- breaking the bot, fix exception later
5
- import gradio_client
6
- from gradio_client import Client
7
- import gradio as gr
8
  import os
9
  import threading
10
-
11
- #for deepfloydif
12
  import requests
13
  import json
14
  import random
15
- from PIL import Image
16
- import matplotlib.pyplot as plt
17
- import matplotlib.image as mpimg
18
  import time
19
-
 
 
 
 
 
 
 
 
 
 
20
  import asyncio
21
 
22
- # random + small llama #
23
-
24
-
25
- #todos
26
- #alert
27
- #fix error on first command on bot startup
28
- #stable diffusion upscale
29
- #buttons for deepfloydIF (1,2,3,4)
30
- #application commands instead of message content checks (more user-friendly)
31
 
 
 
32
 
 
 
 
33
 
 
 
 
34
 
35
- DFIF_TOKEN = os.getenv('DFIF_TOKEN')
36
 
37
- #deepfloydIF
38
- #df = Client("DeepFloyd/IF", DFIF_TOKEN) #not reliable at the moment
39
- df = Client("huggingface-projects/IF", DFIF_TOKEN)
40
 
41
- #stable diffusion upscaler
42
- sdlu = Client("huggingface-projects/stable-diffusion-latent-upscaler", DFIF_TOKEN)
43
-
44
- #---------------------------------------------------------------------------------------------------------------------------------------------------
45
  class ButtonView(discord.ui.View):
46
  def __init__(self, ctx, image_paths, stage_1_result_path):
47
  super().__init__()
@@ -53,167 +49,214 @@ class ButtonView(discord.ui.View):
53
  for child in self.children:
54
  child.disabled = True
55
  self.stop()
56
- #-------------------------------------------
57
- async def invoke_dfif2(self, image_path):
58
- ctx = await self.get_context(message, cls=commands.Context)
59
- await self.ctx.invoke(self.ctx.bot.get_command('dfif2'), image_path=image_path, stage_1_result_path=self.stage_1_result_path)
60
 
61
  @discord.ui.button(label='Image 1', style=discord.ButtonStyle.blurple)
62
  async def image1_button(self, button: discord.ui.Button, interaction: discord.Interaction):
63
- await self.invoke_dfif2(self.image_paths[0])
64
- self.stop()
65
 
66
  @discord.ui.button(label='Image 2', style=discord.ButtonStyle.blurple)
67
  async def image2_button(self, button: discord.ui.Button, interaction: discord.Interaction):
68
- await self.invoke_dfif2(self.image_paths[1])
69
- self.stop()
70
 
71
  @discord.ui.button(label='Image 3', style=discord.ButtonStyle.blurple)
72
  async def image3_button(self, button: discord.ui.Button, interaction: discord.Interaction):
73
- await self.invoke_dfif2(self.image_paths[2])
74
- self.stop()
75
 
76
  @discord.ui.button(label='Image 4', style=discord.ButtonStyle.blurple)
77
  async def image4_button(self, button: discord.ui.Button, interaction: discord.Interaction):
78
- await self.invoke_dfif2(self.image_paths[3])
79
- self.stop()
80
- #---------------------------------------------------------------------------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
 
82
 
83
- #---------------------------------------------------------------------------------------------------------------------------------------------------
84
- # Set up discord bot
85
- class MyClient(discord.Client):
86
- async def on_ready(self):
87
- print('Logged on as', self.user)
88
- self.log_channel = self.get_channel(1036960509586587689) # 1100458786826747945 = bot test
89
 
90
- async def on_message(self, message):
 
 
 
 
 
 
 
 
 
91
 
92
- #safety checks----------------------------------------------------------------------------------------------------
 
 
93
 
94
- # tldr, bot should run if
95
- #1) it does not have @offline role
96
- #2) user has @verified role
97
- #3) bot is in #bot-test channel
98
-
99
- # bot won't respond to itself, prevents feedback loop + API spam
100
- if message.author == self.user:
101
- return
102
-
103
- #antispam---------------------------------------------------
104
- if "discord.gg/" in message.content or "discordapp.com/invite/" in message.content:
105
- try:
106
- await message.delete()
107
- await asyncio.sleep(5) # safety mechanism, can only delete slowly
108
- except Exception as e:
109
- print(f"An error occurred while deleting the message: {e}")
110
- await asyncio.sleep(5)
111
- elif "@everyone" in message.content or "@here" in message.content:
112
- try:
113
- await message.delete()
114
- await asyncio.sleep(5) # safety mechanism, can only delete slowly
115
- except Exception as e:
116
- print(f"An error occurred while deleting the message: {e}")
117
- await asyncio.sleep(5)
118
- else:
119
- pass
120
-
121
- # if the bot has this role, it won't run
122
- OFFLINE_ROLE_ID = 1103676632667017266 # 1103676632667017266 = @offline / under maintenance
123
- guild = message.guild
124
- bot_member = guild.get_member(self.user.id)
125
- if any(role.id == OFFLINE_ROLE_ID for role in bot_member.roles):
126
- return
127
-
128
- # the message author needs this role in order to use the bot
129
- REQUIRED_ROLE_ID = 897376942817419265 # 900063512829755413 = @verified, 897376942817419265 = @huggingfolks
130
- if not any(role.id == REQUIRED_ROLE_ID for role in message.author.roles):
131
- return
132
-
133
- # channels where bot will accept commands
134
- ALLOWED_CHANNEL_IDS = [1100458786826747945] # 1100458786826747945 = #bot-test
135
- if message.channel.id not in ALLOWED_CHANNEL_IDS:
136
- return
137
-
138
- #deepfloydif----------------------------------------------------------------------------------------------------
139
-
140
- if message.content.startswith('!deepfloydif'): # change to application commands, more intuitive
141
-
142
- #(prompt, negative_prompt, seed, number_of_images, guidance_scale,custom_timesteps_1, number_of_inference_steps, api_name="/generate64")
143
- #-> (stage_1_results, stage_1_param_path, stage_1_result_path)
144
-
145
- # input prompt
146
- prompt = message.content[12:].strip()
147
-
148
- negative_prompt = ''
149
- seed = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  number_of_images = 4
151
- guidance_scale = 7
152
- custom_timesteps_1 = 'smart50'
153
- number_of_inference_steps = 50
154
-
155
- stage_1_results, stage_1_param_path, stage_1_result_path = df.predict(
156
- prompt,
157
- negative_prompt,
158
- seed,
159
- number_of_images,
160
- guidance_scale,
161
- custom_timesteps_1,
162
- number_of_inference_steps,
163
- api_name='/generate64')
164
-
165
- #stage_1_results, stage_1_param_path, stage_1_result_path = df.predict("gradio written on a wall", "blur", 1,1,7.0, 'smart100',50, api_name="/generate64")
166
-
167
- # stage_1_results -> path to directory with png files, so we isolate those
168
  png_files = [f for f in os.listdir(stage_1_results) if f.endswith('.png')]
169
 
170
- # merge images into larger, 2x2 image the way midjourney does it
171
  if png_files:
172
  first_png = png_files[0]
173
  second_png = png_files[1]
174
  third_png = png_files[2]
175
  fourth_png = png_files[3]
176
 
177
- '''
178
- [],[],[],[] -> [][]
179
- [][]
180
-
181
- '''
182
-
183
  first_png_path = os.path.join(stage_1_results, first_png)
184
  second_png_path = os.path.join(stage_1_results, second_png)
185
  third_png_path = os.path.join(stage_1_results, third_png)
186
  fourth_png_path = os.path.join(stage_1_results, fourth_png)
187
-
188
  img1 = Image.open(first_png_path)
189
  img2 = Image.open(second_png_path)
190
  img3 = Image.open(third_png_path)
191
  img4 = Image.open(fourth_png_path)
192
-
193
- # create a new blank image with the size of the combined images (2x2)
194
  combined_image = Image.new('RGB', (img1.width * 2, img1.height * 2))
195
-
196
- # paste the individual images into the combined image
197
  combined_image.paste(img1, (0, 0))
198
  combined_image.paste(img2, (img1.width, 0))
199
  combined_image.paste(img3, (0, img1.height))
200
  combined_image.paste(img4, (img1.width, img1.height))
201
-
202
- # save the combined image
203
  combined_image_path = os.path.join(stage_1_results, 'combined_image.png')
204
  combined_image.save(combined_image_path)
205
-
206
-
207
- # Send the combined image file as a discord attachment with the button view
208
- with open(combined_image_path, 'rb') as f:
209
- view = ButtonView(ctx, [first_png_path, second_png_path, third_png_path, fourth_png_path], stage_1_result_path)
210
- await message.reply('Here is the combined image', file=discord.File(f, 'combined_image.png'), view=view)
211
 
 
 
212
 
 
 
 
213
 
 
 
 
214
 
215
- #stage 2, can be stable diffusion too---------------------------------------------------------------------------------------------------------------------------------------------------
216
- async def dfif2(self, ctx, image_path, stage_1_result_path):
 
 
 
 
217
  selected_index_for_stage_2 = 0
218
  seed_2 = 0
219
  guidance_scale_2 = 4
@@ -221,78 +264,22 @@ class MyClient(discord.Client):
221
  number_of_inference_steps_2 = 50
222
  result_path = df.predict(stage_1_result_path, selected_index_for_stage_2, seed_2, guidance_scale_2, custom_timesteps_2, number_of_inference_steps_2, api_name='/upscale256')
223
 
 
 
224
  with open(result_path, 'rb') as f:
225
  await ctx.reply('Here is the result of the second stage', file=discord.File(f, 'result.png'))
226
-
227
- #---------------------------------------------------------------------------------------------------------------------------------------------------
228
- async def on_message_delete(self, message):
229
- if message.author == self.user:
230
- return
231
-
232
- embed = Embed(color=Color.red())
233
- embed.set_author(name=f"{message.author} ID: {message.author.id}", icon_url=message.author.avatar.url)
234
- embed.title = "Message Deleted"
235
- embed.description = message.content or "*(empty message)*"
236
- embed.add_field(name="Author Username", value=message.author.name, inline=True)
237
- embed.add_field(name="Channel", value=message.channel.mention, inline=True)
238
- embed.add_field(name="Message Created On", value=convert_to_timezone(message.created_at, zurich_tz), inline=True)
239
- embed.add_field(name="Message ID", value=message.id, inline=True)
240
- embed.add_field(name="Message Jump URL", value=f"[Jump to message!](https://discord.com/channels/{message.guild.id}/{message.channel.id}/{message.id})", inline=True)
241
-
242
- if message.attachments:
243
- attachment_urls = "\n".join([attachment.url for attachment in message.attachments])
244
- embed.add_field(name="Attachments", value=attachment_urls, inline=False)
245
-
246
- #embed.set_footer(text=f"{datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')}")
247
- embed.set_footer(text=f"{convert_to_timezone(datetime.utcnow(), zurich_tz)}")
248
 
249
- await self.log_channel.send(embed=embed)
 
 
250
 
251
 
252
- async def on_guild_channel_create(channel):
253
- # Channel creations
254
- embed = Embed(description=f'Channel {channel.mention} was created', color=Color.green())
255
- await bot.log_channel.send(embed=embed)
256
-
257
-
258
- async def on_guild_channel_delete(channel):
259
- # Channel deletions
260
- embed = Embed(description=f'Channel {channel.name} ({channel.mention}) was deleted', color=Color.red())
261
- await bot.log_channel.send(embed=embed)
262
-
263
-
264
- async def on_guild_role_create(role):
265
- # Creating roles
266
- embed = Embed(description=f'Role {role.mention} was created', color=Color.green())
267
- await bot.log_channel.send(embed=embed)
268
-
269
-
270
- async def on_guild_role_delete(role):
271
- # Deleting roles
272
- embed = Embed(description=f'Role {role.name} ({role.mention}) was deleted', color=Color.red())
273
- await bot.log_channel.send(embed=embed)
274
-
275
-
276
- async def on_guild_role_update(before, after):
277
- # Editing roles
278
- if before.name != after.name:
279
- embed = Embed(description=f'Role {before.mention} was renamed to {after.name}', color=Color.orange())
280
- await bot.log_channel.send(embed=embed)
281
-
282
- if before.permissions.administrator != after.permissions.administrator:
283
- # Changes involving the administrator permission
284
- certain_role_id = 1106995261487710411 # Replace with the actual role ID
285
- certain_role = after.guild.get_role(certain_role_id)
286
- embed = Embed(description=f'Role {after.mention} had its administrator permission {"enabled" if after.permissions.administrator else "disabled"}', color=Color.red())
287
- await bot.log_channel.send(content=certain_role.mention, embed=embed)
288
-
289
- DISCORD_TOKEN = os.environ.get("GRADIOTEST_TOKEN", None)
290
- intents = discord.Intents.default()
291
- intents.message_content = True
292
- client = MyClient(intents=intents)
293
 
294
  def run_bot():
295
- client.run(DISCORD_TOKEN)
296
 
297
  threading.Thread(target=run_bot).start()
298
 
@@ -300,4 +287,4 @@ def greet(name):
300
  return "Hello " + name + "!"
301
 
302
  demo = gr.Interface(fn=greet, inputs="text", outputs="text")
303
- demo.launch()
 
1
  import discord
 
 
 
 
 
 
2
  import os
3
  import threading
4
+ import gradio as gr
 
5
  import requests
6
  import json
7
  import random
 
 
 
8
  import time
9
+ import re
10
+ from discord import Embed, Color
11
+ from discord.ext import commands
12
+ # running?
13
+ from gradio_client import Client
14
+ from PIL import Image
15
+ from ratelimiter import RateLimiter
16
+ #
17
+ from datetime import datetime
18
+ from pytz import timezone
19
+ #
20
  import asyncio
21
 
22
+ zurich_tz = timezone("Europe/Zurich")
 
 
 
 
 
 
 
 
23
 
24
+ def convert_to_timezone(dt, tz):
25
+ return dt.astimezone(tz).strftime("%Y-%m-%d %H:%M:%S %Z")
26
 
27
+ DFIF_TOKEN = os.getenv('HF_TOKEN')
28
+ df = Client("huggingface-projects/IF", DFIF_TOKEN)
29
+ sdlu = Client("huggingface-projects/stable-diffusion-latent-upscaler", DFIF_TOKEN)
30
 
31
+ DISCORD_TOKEN = os.environ.get("LUNARBOT_TOKEN", None)
32
+ intents = discord.Intents.default()
33
+ intents.message_content = True
34
 
35
+ bot = commands.Bot(command_prefix='!', intents=intents)
36
 
37
+ rate_limiter = RateLimiter(max_calls=10, period=60) # 10 calls per minute
 
 
38
 
39
+ #buttons----------------------------------------------------------------------------------------------------------------------------------------------
40
+ #new
 
 
41
  class ButtonView(discord.ui.View):
42
  def __init__(self, ctx, image_paths, stage_1_result_path):
43
  super().__init__()
 
49
  for child in self.children:
50
  child.disabled = True
51
  self.stop()
 
 
 
 
52
 
53
  @discord.ui.button(label='Image 1', style=discord.ButtonStyle.blurple)
54
  async def image1_button(self, button: discord.ui.Button, interaction: discord.Interaction):
55
+ await self.ctx.invoke(self.ctx.bot.get_command('dfif2'), image_path=self.image_paths[0], stage_1_result_path=self.stage_1_result_path)
56
+ self.stop()
57
 
58
  @discord.ui.button(label='Image 2', style=discord.ButtonStyle.blurple)
59
  async def image2_button(self, button: discord.ui.Button, interaction: discord.Interaction):
60
+ await self.ctx.invoke(self.ctx.bot.get_command('dfif2'), image_path=self.image_paths[1], stage_1_result_path=self.stage_1_result_path)
61
+ self.stop()
62
 
63
  @discord.ui.button(label='Image 3', style=discord.ButtonStyle.blurple)
64
  async def image3_button(self, button: discord.ui.Button, interaction: discord.Interaction):
65
+ await self.ctx.invoke(self.ctx.bot.get_command('dfif2'), image_path=self.image_paths[2], stage_1_result_path=self.stage_1_result_path)
66
+ self.stop()
67
 
68
  @discord.ui.button(label='Image 4', style=discord.ButtonStyle.blurple)
69
  async def image4_button(self, button: discord.ui.Button, interaction: discord.Interaction):
70
+ await self.ctx.invoke(self.ctx.bot.get_command('dfif2'), image_path=self.image_paths[3], stage_1_result_path=self.stage_1_result_path)
71
+ self.stop()
72
+
73
+ #new
74
+ def create_button_row(ctx, image_paths, stage_1_result_path):
75
+ view = ButtonView(ctx, image_paths, stage_1_result_path)
76
+ return view
77
+ #----------------------------------------------------------------------------------------------------------------------------------------------
78
+
79
+ @bot.event
80
+ async def on_ready():
81
+ print('Logged on as', bot.user)
82
+ bot.log_channel = bot.get_channel(1107006391547342910) # 1100458786826747945 = bot-test, 1107006391547342910 = lunarbot server
83
+
84
+ @bot.event
85
+ async def on_message_edit(before, after):
86
+ if before.author == bot.user:
87
+ return
88
+
89
+ if before.content != after.content:
90
+ embed = Embed(color=Color.orange())
91
+ embed.set_author(name=f"{before.author} ID: {before.author.id}", icon_url=before.author.avatar.url)
92
+ embed.title = "Message Edited"
93
+ embed.description = f"**Before:** {before.content or '*(empty message)*'}\n**After:** {after.content or '*(empty message)*'}"
94
+ embed.add_field(name="Author Username", value=before.author.name, inline=True)
95
+ embed.add_field(name="Channel", value=before.channel.mention, inline=True)
96
+ #embed.add_field(name="Message Created On", value=before.created_at.strftime("%Y-%m-%d %H:%M:%S UTC"), inline=True)
97
+ embed.add_field(name="Message Created On", value=convert_to_timezone(before.created_at, zurich_tz), inline=True)
98
+ embed.add_field(name="Message ID", value=before.id, inline=True)
99
+ embed.add_field(name="Message Jump URL", value=f"[Jump to message!](https://discord.com/channels/{before.guild.id}/{before.channel.id}/{before.id})", inline=True)
100
+
101
+ if before.attachments:
102
+ attachment_urls = "\n".join([attachment.url for attachment in before.attachments])
103
+ embed.add_field(name="Attachments", value=attachment_urls, inline=False)
104
+
105
+ #embed.set_footer(text=f"{datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')}")
106
+ embed.set_footer(text=f"{convert_to_timezone(datetime.utcnow(), zurich_tz)}")
107
 
108
+ await bot.log_channel.send(embed=embed)
109
 
110
+ @bot.event
111
+ async def on_message_delete(message):
112
+ if message.author == bot.user:
113
+ return
 
 
114
 
115
+ embed = Embed(color=Color.red())
116
+ embed.set_author(name=f"{message.author} ID: {message.author.id}", icon_url=message.author.avatar.url)
117
+ embed.title = "Message Deleted"
118
+ embed.description = message.content or "*(empty message)*"
119
+ embed.add_field(name="Author Username", value=message.author.name, inline=True)
120
+ embed.add_field(name="Channel", value=message.channel.mention, inline=True)
121
+ #embed.add_field(name="Message Created On", value=message.created_at.strftime("%Y-%m-%d %H:%M:%S UTC"), inline=True)
122
+ embed.add_field(name="Message Created On", value=convert_to_timezone(message.created_at, zurich_tz), inline=True)
123
+ embed.add_field(name="Message ID", value=message.id, inline=True)
124
+ embed.add_field(name="Message Jump URL", value=f"[Jump to message!](https://discord.com/channels/{message.guild.id}/{message.channel.id}/{message.id})", inline=True)
125
 
126
+ if message.attachments:
127
+ attachment_urls = "\n".join([attachment.url for attachment in message.attachments])
128
+ embed.add_field(name="Attachments", value=attachment_urls, inline=False)
129
 
130
+ #embed.set_footer(text=f"{datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')}")
131
+ embed.set_footer(text=f"{convert_to_timezone(datetime.utcnow(), zurich_tz)}")
132
+
133
+ await bot.log_channel.send(embed=embed)
134
+
135
+
136
+
137
+
138
+
139
+ #new
140
+ @bot.event
141
+ async def on_voice_state_update(member, before, after):
142
+ if before.channel != after.channel:
143
+ # Moving members in voice chat
144
+ embed = Embed(description=f'{member} moved in voice chat from {before.channel} to {after.channel}', color=Color.blue())
145
+ await bot.log_channel.send(embed=embed)
146
+
147
+ if before.mute != after.mute:
148
+ # Muting members in voice chat
149
+ embed = Embed(description=f'{member} was {"muted" if after.mute else "unmuted"} in voice chat', color=Color.orange())
150
+ await bot.log_channel.send(embed=embed)
151
+
152
+ if before.deaf != after.deaf:
153
+ # Deafening members in voice chat
154
+ embed = Embed(description=f'{member} was {"deafened" if after.deaf else "undeafened"} in voice chat', color=Color.orange())
155
+ await bot.log_channel.send(embed=embed)
156
+
157
+ @bot.event
158
+ async def on_member_update(before, after):
159
+ if before.nick != after.nick:
160
+ # Nickname changes
161
+ embed = Embed(description=f'{before} changed their nickname to {after.nick}', color=Color.blue())
162
+ await bot.log_channel.send(embed=embed)
163
+
164
+ @bot.event
165
+ async def on_guild_channel_create(channel):
166
+ # Channel creations
167
+ embed = Embed(description=f'Channel {channel.mention} was created', color=Color.green())
168
+ await bot.log_channel.send(embed=embed)
169
+
170
+ @bot.event
171
+ async def on_guild_channel_delete(channel):
172
+ # Channel deletions
173
+ embed = Embed(description=f'Channel {channel.name} ({channel.mention}) was deleted', color=Color.red())
174
+ await bot.log_channel.send(embed=embed)
175
+
176
+ @bot.event
177
+ async def on_guild_role_create(role):
178
+ # Creating roles
179
+ embed = Embed(description=f'Role {role.mention} was created', color=Color.green())
180
+ await bot.log_channel.send(embed=embed)
181
+
182
+ @bot.event
183
+ async def on_guild_role_delete(role):
184
+ # Deleting roles
185
+ embed = Embed(description=f'Role {role.name} ({role.mention}) was deleted', color=Color.red())
186
+ await bot.log_channel.send(embed=embed)
187
+
188
+ @bot.event
189
+ async def on_guild_role_update(before, after):
190
+ # Editing roles
191
+ if before.name != after.name:
192
+ embed = Embed(description=f'Role {before.mention} was renamed to {after.name}', color=Color.orange())
193
+ await bot.log_channel.send(embed=embed)
194
+
195
+ if before.permissions.administrator != after.permissions.administrator:
196
+ # Changes involving the administrator permission
197
+ certain_role_id = 1106995261487710411 # Replace with the actual role ID
198
+ certain_role = after.guild.get_role(certain_role_id)
199
+ embed = Embed(description=f'Role {after.mention} had its administrator permission {"enabled" if after.permissions.administrator else "disabled"}', color=Color.red())
200
+ await bot.log_channel.send(content=certain_role.mention, embed=embed)
201
+
202
+ @bot.command()
203
+ @commands.cooldown(1, 5, commands.BucketType.user)
204
+ async def deepfloydif(ctx, *, prompt: str):
205
+ try:
206
+ prompt = prompt.strip()[:100] # Limit the prompt length to 100 characters
207
+ prompt = re.sub(r'[^\w\s]', '', prompt) # Remove special characters
208
+
209
+ with rate_limiter:
210
  number_of_images = 4
211
+ current_time = int(time.time())
212
+ random.seed(current_time)
213
+ seed = random.randint(0, 2**32 - 1)
214
+ stage_1_results, stage_1_param_path, stage_1_result_path = df.predict(prompt, "blur", seed, number_of_images, 7.0, 'smart100', 50, api_name="/generate64")
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  png_files = [f for f in os.listdir(stage_1_results) if f.endswith('.png')]
216
 
 
217
  if png_files:
218
  first_png = png_files[0]
219
  second_png = png_files[1]
220
  third_png = png_files[2]
221
  fourth_png = png_files[3]
222
 
 
 
 
 
 
 
223
  first_png_path = os.path.join(stage_1_results, first_png)
224
  second_png_path = os.path.join(stage_1_results, second_png)
225
  third_png_path = os.path.join(stage_1_results, third_png)
226
  fourth_png_path = os.path.join(stage_1_results, fourth_png)
227
+
228
  img1 = Image.open(first_png_path)
229
  img2 = Image.open(second_png_path)
230
  img3 = Image.open(third_png_path)
231
  img4 = Image.open(fourth_png_path)
232
+
 
233
  combined_image = Image.new('RGB', (img1.width * 2, img1.height * 2))
234
+
 
235
  combined_image.paste(img1, (0, 0))
236
  combined_image.paste(img2, (img1.width, 0))
237
  combined_image.paste(img3, (0, img1.height))
238
  combined_image.paste(img4, (img1.width, img1.height))
239
+
 
240
  combined_image_path = os.path.join(stage_1_results, 'combined_image.png')
241
  combined_image.save(combined_image_path)
 
 
 
 
 
 
242
 
243
+ # Trigger the second stage prediction
244
+ #await dfif2(ctx, stage_1_result_path)
245
 
246
+ await ctx.reply('Here is the combined image. Select an option quickly!')
247
+ with open(combined_image_path, 'rb') as f:
248
+ await ctx.send(file=discord.File(f, 'combined_image.png'), view=create_button_row(ctx, [first_png_path, second_png_path, third_png_path, fourth_png_path], stage_1_result_path))
249
 
250
+ except Exception as e:
251
+ print(f"Error: {e}")
252
+ await ctx.reply('An error occurred while processing your request. Please wait 5 seconds before retrying.')
253
 
254
+ #new stage 2----------------------------------------------------------------------------------------------------------------------------------------------
255
+ # Stage 2
256
+ @bot.command()
257
+ @commands.cooldown(1, 5, commands.BucketType.user)
258
+ async def dfif2(ctx, image_path, stage_1_result_path):
259
+ try:
260
  selected_index_for_stage_2 = 0
261
  seed_2 = 0
262
  guidance_scale_2 = 4
 
264
  number_of_inference_steps_2 = 50
265
  result_path = df.predict(stage_1_result_path, selected_index_for_stage_2, seed_2, guidance_scale_2, custom_timesteps_2, number_of_inference_steps_2, api_name='/upscale256')
266
 
267
+ # Process the result_path or perform any additional operations
268
+
269
  with open(result_path, 'rb') as f:
270
  await ctx.reply('Here is the result of the second stage', file=discord.File(f, 'result.png'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
+ except Exception as e:
273
+ print(f"Error: {e}")
274
+ await ctx.reply('An error occurred while processing stage 2 upscaling. Please try again later.')
275
 
276
 
277
+
278
+
279
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  def run_bot():
282
+ bot.run(DISCORD_TOKEN)
283
 
284
  threading.Thread(target=run_bot).start()
285
 
 
287
  return "Hello " + name + "!"
288
 
289
  demo = gr.Interface(fn=greet, inputs="text", outputs="text")
290
+ demo.launch()