lunarflu HF Staff commited on
Commit
313dffd
Β·
1 Parent(s): 6dd568f

[dfif] def inference(prompt):

Browse files
Files changed (1) hide show
  1. app.py +32 -35
app.py CHANGED
@@ -89,7 +89,7 @@ async def safetychecks(ctx):
89
  print(f"Error: safetychecks failed somewhere, command will not continue.")
90
  await ctx.message.reply(f"<@811235357663297546> SC failed somewhere") # this will always ping, as long as the bot has access to the channel
91
  #----------------------------------------------------------------------------------------------------------------------------------------------
92
- # jojo
93
  @bot.command()
94
  async def jojo(ctx):
95
  # img + face βœ…
@@ -122,7 +122,7 @@ async def jojo(ctx):
122
  await ctx.message.add_reaction('❌')
123
 
124
  #----------------------------------------------------------------------------------------------------------------------------------------------
125
- # Disney
126
  @bot.command()
127
  async def disney(ctx):
128
  try:
@@ -145,7 +145,7 @@ async def disney(ctx):
145
  await ctx.message.add_reaction('❌')
146
 
147
  #----------------------------------------------------------------------------------------------------------------------------------------------
148
- # Spider-Verse
149
  @bot.command()
150
  async def spiderverse(ctx):
151
  try:
@@ -168,7 +168,7 @@ async def spiderverse(ctx):
168
  await ctx.message.add_reaction('❌')
169
 
170
  #----------------------------------------------------------------------------------------------------------------------------------------------
171
- # sketch
172
  @bot.command()
173
  async def sketch(ctx):
174
  try:
@@ -189,8 +189,23 @@ async def sketch(ctx):
189
  print(f"Error: {e}")
190
  await thread.send(f"{ctx.author.mention}Error: {e}")
191
  await ctx.message.add_reaction('❌')
192
- #----------------------------------------------------------------------------------------------------------------------------------------------
193
- # Stage 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  @bot.command()
195
  async def deepfloydif(ctx, *, prompt: str):
196
  try:
@@ -219,21 +234,25 @@ async def deepfloydif(ctx, *, prompt: str):
219
  print(f"Error: {e}")
220
  await ctx.reply('stage 1 error -> pre generation')
221
  await ctx.message.add_reaction('❌')
222
-
 
223
  try:
224
  #stage_1_results, stage_1_param_path, stage_1_result_path = df.predict(
225
  # prompt, negative_prompt, seed, number_of_images, guidance_scale, custom_timesteps_1, number_of_inference_steps, api_name='/generate64')
226
 
227
- stage_1_results, stage_1_param_path, stage_1_result_path = await asyncio.get_running_loop().run_in_executor(
228
- None, df.predict, prompt, negative_prompt, seed, number_of_images, guidance_scale, custom_timesteps_1, number_of_inference_steps,
229
- api_name)
 
 
230
 
231
  partialpath = stage_1_result_path[5:] #magic for later
 
232
  except Exception as e:
233
  print(f"Error: {e}")
234
  await ctx.reply('stage 1 error -> during generation')
235
  await ctx.message.add_reaction('❌')
236
-
237
  try:
238
  png_files = [f for f in os.listdir(stage_1_results) if f.endswith('.png')]
239
 
@@ -298,7 +317,7 @@ async def deepfloydif(ctx, *, prompt: str):
298
  await ctx.message.add_reaction('❌')
299
 
300
  #----------------------------------------------------------------------------------------------------------------------------
301
- # Stage 2
302
  async def dfif2(index: int, stage_1_result_path, thread, dfif_command_message_id): # add safetychecks
303
  try:
304
  await thread.send(f"inside dfif2, upscaling")
@@ -342,33 +361,11 @@ async def dfif2(index: int, stage_1_result_path, thread, dfif_command_message_id
342
  #await ctx.reply('An error occured in stage 2') need to fix
343
  #await ctx.message.add_reaction('❌')
344
  #----------------------------------------------------------------------------------------------------------------------------
 
345
  @bot.event
346
  async def on_reaction_add(reaction, user): # ctx = await bot.get_context(reaction.message)? could try later, might simplify
347
  try:
348
  # safety checks first ❌
349
- '''
350
- if user.bot:
351
- return
352
-
353
- #offline bot check ❌
354
- offline_bot_role_id = 1103676632667017266
355
- bot_member = reaction.message.guild.get_member(bot.user.id)
356
- if any(role.id == offline_bot_role_id for role in bot_member.roles):
357
- return
358
-
359
- # verified role check ❌
360
- guild = reaction.message.guild
361
- required_role_id = 900063512829755413 # @verified for now
362
- required_role = guild.get_role(required_role_id)
363
- if required_role not in user.roles:
364
- return
365
-
366
- #channel check ❌
367
- if reaction.message.channel.id != 1100458786826747945:
368
- return
369
-
370
- '''
371
-
372
  thread = reaction.message.channel
373
  threadparentid = thread.parent.id
374
 
 
89
  print(f"Error: safetychecks failed somewhere, command will not continue.")
90
  await ctx.message.reply(f"<@811235357663297546> SC failed somewhere") # this will always ping, as long as the bot has access to the channel
91
  #----------------------------------------------------------------------------------------------------------------------------------------------
92
+ # jojo βœ…
93
  @bot.command()
94
  async def jojo(ctx):
95
  # img + face βœ…
 
122
  await ctx.message.add_reaction('❌')
123
 
124
  #----------------------------------------------------------------------------------------------------------------------------------------------
125
+ # Disney ❌
126
  @bot.command()
127
  async def disney(ctx):
128
  try:
 
145
  await ctx.message.add_reaction('❌')
146
 
147
  #----------------------------------------------------------------------------------------------------------------------------------------------
148
+ # Spider-Verse βœ…
149
  @bot.command()
150
  async def spiderverse(ctx):
151
  try:
 
168
  await ctx.message.add_reaction('❌')
169
 
170
  #----------------------------------------------------------------------------------------------------------------------------------------------
171
+ # sketch βœ…
172
  @bot.command()
173
  async def sketch(ctx):
174
  try:
 
189
  print(f"Error: {e}")
190
  await thread.send(f"{ctx.author.mention}Error: {e}")
191
  await ctx.message.add_reaction('❌')
192
+ #----------------------------------------------------------------------------------------------------------------------------------------------
193
+ # deepfloydif stage 1 generation ❌
194
+ def inference(prompt):
195
+ negative_prompt = ''
196
+ seed = random.randint(0, 1000)
197
+ #seed = 1
198
+ number_of_images = 4
199
+ guidance_scale = 7
200
+ custom_timesteps_1 = 'smart50'
201
+ number_of_inference_steps = 50
202
+
203
+ stage_1_results, stage_1_param_path, stage_1_result_path = df.predict(
204
+ prompt, negative_prompt, seed, number_of_images, guidance_scale, custom_timesteps_1, number_of_inference_steps, api_name='/generate64')
205
+
206
+ return [stage_1_results, stage_1_param_path, stage_1_result_path]
207
+ #----------------------------------------------------------------------------------------------------------------------------------------------
208
+ # Stage 1 ❌
209
  @bot.command()
210
  async def deepfloydif(ctx, *, prompt: str):
211
  try:
 
234
  print(f"Error: {e}")
235
  await ctx.reply('stage 1 error -> pre generation')
236
  await ctx.message.add_reaction('❌')
237
+
238
+ #generation❌-------------------------------------------------------
239
  try:
240
  #stage_1_results, stage_1_param_path, stage_1_result_path = df.predict(
241
  # prompt, negative_prompt, seed, number_of_images, guidance_scale, custom_timesteps_1, number_of_inference_steps, api_name='/generate64')
242
 
243
+ # run blocking function in executor
244
+ loop = asyncio.get_running_loop()
245
+ result = await loop.run_in_executor(None, inference, prompt)
246
+ stage_1_results = result[0]
247
+ stage_1_result_path = result[2]
248
 
249
  partialpath = stage_1_result_path[5:] #magic for later
250
+
251
  except Exception as e:
252
  print(f"Error: {e}")
253
  await ctx.reply('stage 1 error -> during generation')
254
  await ctx.message.add_reaction('❌')
255
+ #posting images----------------------------------------------------------------
256
  try:
257
  png_files = [f for f in os.listdir(stage_1_results) if f.endswith('.png')]
258
 
 
317
  await ctx.message.add_reaction('❌')
318
 
319
  #----------------------------------------------------------------------------------------------------------------------------
320
+ # Stage 2 ❌
321
  async def dfif2(index: int, stage_1_result_path, thread, dfif_command_message_id): # add safetychecks
322
  try:
323
  await thread.send(f"inside dfif2, upscaling")
 
361
  #await ctx.reply('An error occured in stage 2') need to fix
362
  #await ctx.message.add_reaction('❌')
363
  #----------------------------------------------------------------------------------------------------------------------------
364
+ # react detector for stage 2 ❌
365
  @bot.event
366
  async def on_reaction_add(reaction, user): # ctx = await bot.get_context(reaction.message)? could try later, might simplify
367
  try:
368
  # safety checks first ❌
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  thread = reaction.message.channel
370
  threadparentid = thread.parent.id
371