heaversm commited on
Commit
cfaeada
·
1 Parent(s): 9748f58

fix toggle to show labels

Browse files
Files changed (1) hide show
  1. app.py +21 -27
app.py CHANGED
@@ -19,7 +19,7 @@ from dotenv import load_dotenv
19
  # stats stuff
20
  from pymongo.mongo_client import MongoClient
21
  from pymongo.server_api import ServerApi
22
- import time # Make sure to import the time module
23
 
24
 
25
 
@@ -51,28 +51,19 @@ image_labels_global = []
51
  image_paths_global = []
52
 
53
  def update_labels(show_labels):
54
- if show_labels:
55
- # return [(path, label) for path, label in zip(image_paths_global, image_labels_global)]
56
- updated_gallery = [(path, label) for path, label in zip(image_paths_global, image_labels_global)]
57
- else:
58
- # return [(path, "") for path in image_paths_global] # Empty string as label to hide them
59
- updated_gallery = [(path, "") for path in image_paths_global] # Empty string as label to hide them
60
  return updated_gallery
61
 
62
  def generate_images_wrapper(prompts, pw, model, show_labels):
63
  global image_paths_global, image_labels_global
64
  image_paths, image_labels = generate_images(prompts, pw, model)
65
- image_paths_global = image_paths # Store paths globally
66
-
67
- if show_labels:
68
- image_labels_global = image_labels # Store labels globally if showing labels is enabled
69
- else:
70
- image_labels_global = [""] * len(image_labels) # Use empty labels if showing labels is disabled
71
 
72
- # Modify the return statement to not use labels if show_labels is False
 
73
  image_data = [(path, label if show_labels else "") for path, label in zip(image_paths, image_labels)]
74
 
75
- return image_data # Return image paths with or without labels based on the toggle
76
 
77
  def download_image(url):
78
  response = requests.get(url)
@@ -86,7 +77,6 @@ def zip_images(image_paths_and_labels):
86
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
87
  for image_url, _ in image_paths_and_labels:
88
  image_content = download_image(image_url)
89
- # Generate a random filename for the image
90
  random_filename = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + ".png"
91
  # Write the image content to the zip file with the random filename
92
  zipf.writestr(random_filename, image_content)
@@ -108,19 +98,22 @@ def generate_images(prompts, pw, model):
108
  if pw != os.getenv("PW"):
109
  raise gr.Error("Invalid password. Please try again.")
110
 
111
- image_paths = [] # Initialize a list to hold paths of generated images
112
- image_labels = [] # Initialize a list to hold labels of generated images
113
- users = [] # Initialize a list to hold user initials
114
 
115
  # Split the prompts string into individual prompts based on semicolon separation
116
  prompts_list = prompts.split(';')
117
 
118
  for entry in prompts_list:
119
  entry_parts = entry.split('-', 1) # Split by the first dash found
120
- if len(entry_parts) != 2:
121
- raise gr.Error("Invalid prompt format. Please ensure it is in 'initials-prompt' format.")
 
 
 
 
122
 
123
- user_initials, text = entry_parts[0].strip(), entry_parts[1].strip() # Extract user initials and the prompt
124
  users.append(user_initials) # Append user initials to the list
125
 
126
  try:
@@ -138,9 +131,11 @@ def generate_images(prompts, pw, model):
138
  gen_time = end_time - start_time # total generation time
139
 
140
  image_url = response.data[0].url
141
- image_label = f"User: {user_initials}, Prompt: {text}" # Creating a label for the image including user initials
 
142
 
143
  try:
 
144
  mongo_collection.insert_one({"user": user_initials, "text": text, "model": model, "image_url": image_url, "gen_time": gen_time, "timestamp": time.time()})
145
  except Exception as e:
146
  print(e)
@@ -165,14 +160,13 @@ with gr.Blocks() as demo:
165
  placeholder="Enter your text and then click on the \"Image Generate\" button")
166
 
167
  model = gr.Dropdown(choices=["dall-e-2", "dall-e-3"], label="Model", value="dall-e-3")
168
- #MH TODO: not toggling properly
169
- show_labels = gr.Checkbox(label="Show Image Labels", value=True)
170
  btn = gr.Button("Generate Images")
171
  output_images = gr.Gallery(label="Image Outputs", show_label=True, columns=[3], rows=[1], object_fit="contain",
172
  height="auto", allow_preview=False)
173
 
174
- text.submit(fn=generate_images_wrapper, inputs=[text, pw, model], outputs=output_images, api_name="generate_image")
175
- # btn.click(fn=generate_images_wrapper, inputs=[text, pw, model], outputs=output_images, api_name=False)
176
  btn.click(fn=generate_images_wrapper, inputs=[text, pw, model, show_labels], outputs=output_images, api_name=False)
177
 
178
  show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])
 
19
  # stats stuff
20
  from pymongo.mongo_client import MongoClient
21
  from pymongo.server_api import ServerApi
22
+ import time
23
 
24
 
25
 
 
51
  image_paths_global = []
52
 
53
  def update_labels(show_labels):
54
+ updated_gallery = [(path, label if show_labels else "") for path, label in zip(image_paths_global, image_labels_global)]
 
 
 
 
 
55
  return updated_gallery
56
 
57
  def generate_images_wrapper(prompts, pw, model, show_labels):
58
  global image_paths_global, image_labels_global
59
  image_paths, image_labels = generate_images(prompts, pw, model)
60
+ image_paths_global = image_paths
 
 
 
 
 
61
 
62
+ # store this as a global so we can handle toggle state
63
+ image_labels_global = image_labels
64
  image_data = [(path, label if show_labels else "") for path, label in zip(image_paths, image_labels)]
65
 
66
+ return image_data
67
 
68
  def download_image(url):
69
  response = requests.get(url)
 
77
  with zipfile.ZipFile(zip_file_path, 'w') as zipf:
78
  for image_url, _ in image_paths_and_labels:
79
  image_content = download_image(image_url)
 
80
  random_filename = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + ".png"
81
  # Write the image content to the zip file with the random filename
82
  zipf.writestr(random_filename, image_content)
 
98
  if pw != os.getenv("PW"):
99
  raise gr.Error("Invalid password. Please try again.")
100
 
101
+ image_paths = [] # holds urls of images
102
+ image_labels = [] # shows the prompt in the gallery above the image
103
+ users = [] # adds the user to the label
104
 
105
  # Split the prompts string into individual prompts based on semicolon separation
106
  prompts_list = prompts.split(';')
107
 
108
  for entry in prompts_list:
109
  entry_parts = entry.split('-', 1) # Split by the first dash found
110
+ if len(entry_parts) == 2:
111
+ #raise gr.Error("Invalid prompt format. Please ensure it is in 'initials-prompt' format.")
112
+ user_initials, text = entry_parts[0].strip(), entry_parts[1].strip() # Extract user initials and the prompt
113
+ else:
114
+ text = entry.strip() # If no initials are provided, use the entire prompt as the text
115
+ user_initials = ""
116
 
 
117
  users.append(user_initials) # Append user initials to the list
118
 
119
  try:
 
131
  gen_time = end_time - start_time # total generation time
132
 
133
  image_url = response.data[0].url
134
+ # conditionally render the user to the label with the prompt
135
+ image_label = f"Prompt: {text}" if user_initials == "" else f"User: {user_initials}, Prompt: {text}"
136
 
137
  try:
138
+ # Save the prompt, model, image URL, generation time and creation timestamp to the database
139
  mongo_collection.insert_one({"user": user_initials, "text": text, "model": model, "image_url": image_url, "gen_time": gen_time, "timestamp": time.time()})
140
  except Exception as e:
141
  print(e)
 
160
  placeholder="Enter your text and then click on the \"Image Generate\" button")
161
 
162
  model = gr.Dropdown(choices=["dall-e-2", "dall-e-3"], label="Model", value="dall-e-3")
163
+ show_labels = gr.Checkbox(label="Show Image Labels", value=False)
 
164
  btn = gr.Button("Generate Images")
165
  output_images = gr.Gallery(label="Image Outputs", show_label=True, columns=[3], rows=[1], object_fit="contain",
166
  height="auto", allow_preview=False)
167
 
168
+ #trigger generation either through hitting enter in the text field, or clicking the button.
169
+ text.submit(fn=generate_images_wrapper, inputs=[text, pw, model, show_labels], outputs=output_images, api_name="generate_image") # Generate an api endpoint in Gradio / HF
170
  btn.click(fn=generate_images_wrapper, inputs=[text, pw, model, show_labels], outputs=output_images, api_name=False)
171
 
172
  show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])