lunarflu HF Staff commited on
Commit
fb515d0
·
1 Parent(s): 8a0bc2e

[deepfloydif.py] 1.1 PR

Browse files
Files changed (1) hide show
  1. deepfloydif.py +195 -0
deepfloydif.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import discord
2
+ from gradio_client import Client
3
+ import os
4
+ import random
5
+ from PIL import Image
6
+ import asyncio
7
+ import glob
8
+ import pathlib
9
+
10
+ HF_TOKEN = os.getenv("HF_TOKEN")
11
+ deepfloydif_client = Client("huggingface-projects/IF", HF_TOKEN)
12
+
13
+ BOT_USER_ID = 1086256910572986469 if os.getenv("TEST_ENV", False) else 1102236653545861151
14
+ DEEPFLOYDIF_CHANNEL_ID = 1121834257959092234 if os.getenv("TEST_ENV", False) else 1119313215675973714
15
+
16
+
17
+ def deepfloydif_stage_1_inference(prompt):
18
+ """Generates an image based on a prompt"""
19
+ negative_prompt = ""
20
+ seed = random.randint(0, 1000)
21
+ number_of_images = 4
22
+ guidance_scale = 7
23
+ custom_timesteps_1 = "smart50"
24
+ number_of_inference_steps = 50
25
+ (
26
+ stage_1_images,
27
+ stage_1_param_path,
28
+ path_for_stage_2_upscaling,
29
+ ) = deepfloydif_client.predict(
30
+ prompt,
31
+ negative_prompt,
32
+ seed,
33
+ number_of_images,
34
+ guidance_scale,
35
+ custom_timesteps_1,
36
+ number_of_inference_steps,
37
+ api_name="/generate64",
38
+ )
39
+ return [stage_1_images, stage_1_param_path, path_for_stage_2_upscaling]
40
+
41
+
42
+ def deepfloydif_stage_2_inference(index, path_for_stage_2_upscaling):
43
+ """Upscales one of the images from deepfloydif_stage_1_inference based on the chosen index"""
44
+ selected_index_for_stage_2 = index
45
+ seed_2 = 0
46
+ guidance_scale_2 = 4
47
+ custom_timesteps_2 = "smart50"
48
+ number_of_inference_steps_2 = 50
49
+ result_path = deepfloydif_client.predict(
50
+ path_for_stage_2_upscaling,
51
+ selected_index_for_stage_2,
52
+ seed_2,
53
+ guidance_scale_2,
54
+ custom_timesteps_2,
55
+ number_of_inference_steps_2,
56
+ api_name="/upscale256",
57
+ )
58
+ return result_path
59
+
60
+
61
+ async def react_1234(reaction_emojis, combined_image_dfif):
62
+ """Sets up 4 reaction emojis so the user can choose an image to upscale for deepfloydif"""
63
+ for emoji in reaction_emojis:
64
+ await combined_image_dfif.add_reaction(emoji)
65
+
66
+
67
+ def load_image(png_files, stage_1_images):
68
+ """Opens images as variables so we can combine them later"""
69
+ results = []
70
+ for file in png_files:
71
+ png_path = os.path.join(stage_1_images, file)
72
+ results.append(Image.open(png_path))
73
+ return results
74
+
75
+
76
+ def combine_images(png_files, stage_1_images, partial_path):
77
+ if os.environ.get("TEST_ENV") == "True":
78
+ print("Combining images for deepfloydif_stage_1")
79
+ images = load_image(png_files, stage_1_images)
80
+ combined_image = Image.new("RGB", (images[0].width * 2, images[0].height * 2))
81
+ combined_image.paste(images[0], (0, 0))
82
+ combined_image.paste(images[1], (images[0].width, 0))
83
+ combined_image.paste(images[2], (0, images[0].height))
84
+ combined_image.paste(images[3], (images[0].width, images[0].height))
85
+ combined_image_path = os.path.join(stage_1_images, f"{partial_path}.png")
86
+ combined_image.save(combined_image_path)
87
+ return combined_image_path
88
+
89
+
90
+ async def deepfloydif_stage_1(interaction, prompt, client):
91
+ """DeepfloydIF command (generate images with realistic text using slash commands)"""
92
+ try:
93
+ if interaction.user.id != BOT_USER_ID:
94
+ if interaction.channel.id == DEEPFLOYDIF_CHANNEL_ID:
95
+ if os.environ.get("TEST_ENV") == "True":
96
+ print("Safety checks passed for deepfloydif_stage_1")
97
+ await interaction.response.send_message("Working on it!")
98
+ channel = interaction.channel
99
+ # interaction.response message can't be used to create a thread, so we create another message
100
+ message = await channel.send("DeepfloydIF Thread")
101
+ thread = await message.create_thread(name=f"{prompt}", auto_archive_duration=60)
102
+ await thread.send(
103
+ "[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; Additional information on the"
104
+ " DeepfloydIF model can be found here: https://huggingface.co/spaces/DeepFloyd/IF"
105
+ )
106
+ await thread.send(f"{interaction.user.mention} Generating images in thread, can take ~1 minute...")
107
+
108
+ loop = asyncio.get_running_loop()
109
+ result = await loop.run_in_executor(None, deepfloydif_stage_1_inference, prompt)
110
+ stage_1_images = result[0]
111
+ path_for_stage_2_upscaling = result[2]
112
+
113
+ partial_path = pathlib.Path(path_for_stage_2_upscaling).name
114
+ png_files = list(glob.glob(f"{stage_1_images}/**/*.png"))
115
+
116
+ if png_files:
117
+ combined_image_path = combine_images(png_files, stage_1_images, partial_path)
118
+ if os.environ.get("TEST_ENV") == "True":
119
+ print("Images combined for deepfloydif_stage_1")
120
+ with open(combined_image_path, "rb") as f:
121
+ combined_image_dfif = await thread.send(
122
+ f"{interaction.user.mention} React with the image number you want to upscale!",
123
+ file=discord.File(f, f"{partial_path}.png"),
124
+ )
125
+ emoji_list = ["↖️", "↗️", "↙️", "↘️"]
126
+ await react_1234(emoji_list, combined_image_dfif)
127
+ else:
128
+ await thread.send(f"{interaction.user.mention} No PNG files were found, cannot post them!")
129
+ except Exception as e:
130
+ print(f"Error: {e}")
131
+
132
+
133
+ async def deepfloydif_stage_2_react_check(reaction, user):
134
+ """Checks for a reaction in order to call dfif2"""
135
+ try:
136
+ if os.environ.get("TEST_ENV") == "True":
137
+ print("Running deepfloydif_stage_2_react_check")
138
+ global BOT_USER_ID
139
+ global DEEPFLOYDIF_CHANNEL_ID
140
+ if user.id != BOT_USER_ID:
141
+ thread = reaction.message.channel
142
+ thread_parent_id = thread.parent.id
143
+ if thread_parent_id == DEEPFLOYDIF_CHANNEL_ID:
144
+ if reaction.message.attachments:
145
+ if user.id == reaction.message.mentions[0].id:
146
+ attachment = reaction.message.attachments[0]
147
+ image_name = attachment.filename
148
+ partial_path = image_name[:-4]
149
+ full_path = "/tmp/" + partial_path
150
+ emoji = reaction.emoji
151
+ if emoji == "↖️":
152
+ index = 0
153
+ elif emoji == "↗️":
154
+ index = 1
155
+ elif emoji == "↙️":
156
+ index = 2
157
+ elif emoji == "↘️":
158
+ index = 3
159
+ path_for_stage_2_upscaling = full_path
160
+ thread = reaction.message.channel
161
+ await deepfloydif_stage_2(
162
+ index,
163
+ path_for_stage_2_upscaling,
164
+ thread,
165
+ )
166
+ except Exception as e:
167
+ print(f"Error: {e} (known error, does not cause issues, low priority)")
168
+
169
+
170
+ async def deepfloydif_stage_2(index: int, path_for_stage_2_upscaling, thread):
171
+ """upscaling function for images generated using /deepfloydif"""
172
+ try:
173
+ if os.environ.get("TEST_ENV") == "True":
174
+ print("Running deepfloydif_stage_2")
175
+ if index == 0:
176
+ position = "top left"
177
+ elif index == 1:
178
+ position = "top right"
179
+ elif index == 2:
180
+ position = "bottom left"
181
+ elif index == 3:
182
+ position = "bottom right"
183
+ await thread.send(f"Upscaling the {position} image...")
184
+
185
+ # run blocking function in executor
186
+ loop = asyncio.get_running_loop()
187
+ result_path = await loop.run_in_executor(
188
+ None, deepfloydif_stage_2_inference, index, path_for_stage_2_upscaling
189
+ )
190
+
191
+ with open(result_path, "rb") as f:
192
+ await thread.send("Here is the upscaled image!", file=discord.File(f, "result.png"))
193
+ await thread.edit(archived=True)
194
+ except Exception as e:
195
+ print(f"Error: {e}")