Spaces:
Runtime error
Runtime error
fix toggle to show labels
Browse files
app.py
CHANGED
@@ -19,7 +19,7 @@ from dotenv import load_dotenv
|
|
19 |
# stats stuff
|
20 |
from pymongo.mongo_client import MongoClient
|
21 |
from pymongo.server_api import ServerApi
|
22 |
-
import time
|
23 |
|
24 |
|
25 |
|
@@ -51,28 +51,19 @@ image_labels_global = []
|
|
51 |
image_paths_global = []
|
52 |
|
53 |
def update_labels(show_labels):
|
54 |
-
if show_labels
|
55 |
-
# return [(path, label) for path, label in zip(image_paths_global, image_labels_global)]
|
56 |
-
updated_gallery = [(path, label) for path, label in zip(image_paths_global, image_labels_global)]
|
57 |
-
else:
|
58 |
-
# return [(path, "") for path in image_paths_global] # Empty string as label to hide them
|
59 |
-
updated_gallery = [(path, "") for path in image_paths_global] # Empty string as label to hide them
|
60 |
return updated_gallery
|
61 |
|
62 |
def generate_images_wrapper(prompts, pw, model, show_labels):
|
63 |
global image_paths_global, image_labels_global
|
64 |
image_paths, image_labels = generate_images(prompts, pw, model)
|
65 |
-
image_paths_global = image_paths
|
66 |
-
|
67 |
-
if show_labels:
|
68 |
-
image_labels_global = image_labels # Store labels globally if showing labels is enabled
|
69 |
-
else:
|
70 |
-
image_labels_global = [""] * len(image_labels) # Use empty labels if showing labels is disabled
|
71 |
|
72 |
-
#
|
|
|
73 |
image_data = [(path, label if show_labels else "") for path, label in zip(image_paths, image_labels)]
|
74 |
|
75 |
-
return image_data
|
76 |
|
77 |
def download_image(url):
|
78 |
response = requests.get(url)
|
@@ -86,7 +77,6 @@ def zip_images(image_paths_and_labels):
|
|
86 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
87 |
for image_url, _ in image_paths_and_labels:
|
88 |
image_content = download_image(image_url)
|
89 |
-
# Generate a random filename for the image
|
90 |
random_filename = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + ".png"
|
91 |
# Write the image content to the zip file with the random filename
|
92 |
zipf.writestr(random_filename, image_content)
|
@@ -108,19 +98,22 @@ def generate_images(prompts, pw, model):
|
|
108 |
if pw != os.getenv("PW"):
|
109 |
raise gr.Error("Invalid password. Please try again.")
|
110 |
|
111 |
-
image_paths = [] #
|
112 |
-
image_labels = [] #
|
113 |
-
users = [] #
|
114 |
|
115 |
# Split the prompts string into individual prompts based on semicolon separation
|
116 |
prompts_list = prompts.split(';')
|
117 |
|
118 |
for entry in prompts_list:
|
119 |
entry_parts = entry.split('-', 1) # Split by the first dash found
|
120 |
-
if len(entry_parts)
|
121 |
-
raise gr.Error("Invalid prompt format. Please ensure it is in 'initials-prompt' format.")
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
user_initials, text = entry_parts[0].strip(), entry_parts[1].strip() # Extract user initials and the prompt
|
124 |
users.append(user_initials) # Append user initials to the list
|
125 |
|
126 |
try:
|
@@ -138,9 +131,11 @@ def generate_images(prompts, pw, model):
|
|
138 |
gen_time = end_time - start_time # total generation time
|
139 |
|
140 |
image_url = response.data[0].url
|
141 |
-
|
|
|
142 |
|
143 |
try:
|
|
|
144 |
mongo_collection.insert_one({"user": user_initials, "text": text, "model": model, "image_url": image_url, "gen_time": gen_time, "timestamp": time.time()})
|
145 |
except Exception as e:
|
146 |
print(e)
|
@@ -165,14 +160,13 @@ with gr.Blocks() as demo:
|
|
165 |
placeholder="Enter your text and then click on the \"Image Generate\" button")
|
166 |
|
167 |
model = gr.Dropdown(choices=["dall-e-2", "dall-e-3"], label="Model", value="dall-e-3")
|
168 |
-
|
169 |
-
show_labels = gr.Checkbox(label="Show Image Labels", value=True)
|
170 |
btn = gr.Button("Generate Images")
|
171 |
output_images = gr.Gallery(label="Image Outputs", show_label=True, columns=[3], rows=[1], object_fit="contain",
|
172 |
height="auto", allow_preview=False)
|
173 |
|
174 |
-
|
175 |
-
|
176 |
btn.click(fn=generate_images_wrapper, inputs=[text, pw, model, show_labels], outputs=output_images, api_name=False)
|
177 |
|
178 |
show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])
|
|
|
19 |
# stats stuff
|
20 |
from pymongo.mongo_client import MongoClient
|
21 |
from pymongo.server_api import ServerApi
|
22 |
+
import time
|
23 |
|
24 |
|
25 |
|
|
|
51 |
image_paths_global = []
|
52 |
|
53 |
def update_labels(show_labels):
|
54 |
+
updated_gallery = [(path, label if show_labels else "") for path, label in zip(image_paths_global, image_labels_global)]
|
|
|
|
|
|
|
|
|
|
|
55 |
return updated_gallery
|
56 |
|
57 |
def generate_images_wrapper(prompts, pw, model, show_labels):
|
58 |
global image_paths_global, image_labels_global
|
59 |
image_paths, image_labels = generate_images(prompts, pw, model)
|
60 |
+
image_paths_global = image_paths
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
+
# store this as a global so we can handle toggle state
|
63 |
+
image_labels_global = image_labels
|
64 |
image_data = [(path, label if show_labels else "") for path, label in zip(image_paths, image_labels)]
|
65 |
|
66 |
+
return image_data
|
67 |
|
68 |
def download_image(url):
|
69 |
response = requests.get(url)
|
|
|
77 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
78 |
for image_url, _ in image_paths_and_labels:
|
79 |
image_content = download_image(image_url)
|
|
|
80 |
random_filename = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + ".png"
|
81 |
# Write the image content to the zip file with the random filename
|
82 |
zipf.writestr(random_filename, image_content)
|
|
|
98 |
if pw != os.getenv("PW"):
|
99 |
raise gr.Error("Invalid password. Please try again.")
|
100 |
|
101 |
+
image_paths = [] # holds urls of images
|
102 |
+
image_labels = [] # shows the prompt in the gallery above the image
|
103 |
+
users = [] # adds the user to the label
|
104 |
|
105 |
# Split the prompts string into individual prompts based on semicolon separation
|
106 |
prompts_list = prompts.split(';')
|
107 |
|
108 |
for entry in prompts_list:
|
109 |
entry_parts = entry.split('-', 1) # Split by the first dash found
|
110 |
+
if len(entry_parts) == 2:
|
111 |
+
#raise gr.Error("Invalid prompt format. Please ensure it is in 'initials-prompt' format.")
|
112 |
+
user_initials, text = entry_parts[0].strip(), entry_parts[1].strip() # Extract user initials and the prompt
|
113 |
+
else:
|
114 |
+
text = entry.strip() # If no initials are provided, use the entire prompt as the text
|
115 |
+
user_initials = ""
|
116 |
|
|
|
117 |
users.append(user_initials) # Append user initials to the list
|
118 |
|
119 |
try:
|
|
|
131 |
gen_time = end_time - start_time # total generation time
|
132 |
|
133 |
image_url = response.data[0].url
|
134 |
+
# conditionally render the user to the label with the prompt
|
135 |
+
image_label = f"Prompt: {text}" if user_initials == "" else f"User: {user_initials}, Prompt: {text}"
|
136 |
|
137 |
try:
|
138 |
+
# Save the prompt, model, image URL, generation time and creation timestamp to the database
|
139 |
mongo_collection.insert_one({"user": user_initials, "text": text, "model": model, "image_url": image_url, "gen_time": gen_time, "timestamp": time.time()})
|
140 |
except Exception as e:
|
141 |
print(e)
|
|
|
160 |
placeholder="Enter your text and then click on the \"Image Generate\" button")
|
161 |
|
162 |
model = gr.Dropdown(choices=["dall-e-2", "dall-e-3"], label="Model", value="dall-e-3")
|
163 |
+
show_labels = gr.Checkbox(label="Show Image Labels", value=False)
|
|
|
164 |
btn = gr.Button("Generate Images")
|
165 |
output_images = gr.Gallery(label="Image Outputs", show_label=True, columns=[3], rows=[1], object_fit="contain",
|
166 |
height="auto", allow_preview=False)
|
167 |
|
168 |
+
#trigger generation either through hitting enter in the text field, or clicking the button.
|
169 |
+
text.submit(fn=generate_images_wrapper, inputs=[text, pw, model, show_labels], outputs=output_images, api_name="generate_image") # Generate an api endpoint in Gradio / HF
|
170 |
btn.click(fn=generate_images_wrapper, inputs=[text, pw, model, show_labels], outputs=output_images, api_name=False)
|
171 |
|
172 |
show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])
|