santhosh97 commited on
Commit
ada7876
1 Parent(s): 51cf589

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -268
app.py CHANGED
@@ -9,8 +9,6 @@ from typing import List, Tuple, TYPE_CHECKING
9
  import uuid
10
  import argparse
11
  import logging
12
- import sendgrid
13
- from sendgrid.helpers.mail import Mail, Email, To, Content
14
  from enum import Enum
15
  import tempfile
16
  from pathlib import Path
@@ -34,60 +32,53 @@ logger.setLevel(logging.INFO)
34
  load_dotenv()
35
 
36
  _S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
37
- CAPTCHA_ENDPOINT = "https://captcha-api.akshit.me/v2/generate"
38
- VERIFY_ENDPOINT = "https://captcha-api.akshit.me/v2/verify"
 
39
 
40
  class UxState(str, Enum):
41
- LOGIN = "login"
42
- VERIFY_EMAIL = "verify_email"
43
  UPLOAD1 = "upload1"
44
  UPLOAD2 = "upload2"
45
- CAPTCHA = "captcha"
 
46
  TRAIN = "train"
 
47
 
 
 
 
48
  def setup_session_state():
49
  if "key" not in st.session_state:
50
  st.session_state["key"] = uuid.uuid4().hex
51
 
52
  if "ux_state" not in st.session_state:
53
- st.session_state["ux_state"] = UxState.LOGIN
54
 
55
  if "model_inputs" not in st.session_state:
56
  st.session_state["model_inputs"] = None
57
 
58
- if "initial_concept_file_path" not in st.session_state:
59
- st.session_state["initial_concept_file_path"] = None
60
-
61
- if "initial_token" not in st.session_state:
62
- st.session_state["initial_token"] = None
63
-
64
- if "initial_class_token" not in st.session_state:
65
- st.session_state["initial_class_token"] = None
66
-
67
- if "secondary_concept_file_path" not in st.session_state:
68
- st.session_state["secondary_concept_file_path"] = None
69
-
70
- if "secondary_token" not in st.session_state:
71
- st.session_state["secondary_token"] = None
72
-
73
- if "secondary_class_token" not in st.session_state:
74
- st.session_state["secondary_class_token"] = None
75
-
76
  if "prompt_keywords" not in st.session_state:
77
  st.session_state["prompt_keywords"] = None
 
 
 
78
 
79
  if "view" not in st.session_state:
80
  st.session_state["view"] = False
81
 
82
- if "captcha_response" not in st.session_state:
83
- st.session_state["captcha_response"] = None
84
-
85
- if "captcha" not in st.session_state:
86
- st.session_state["captcha"] = {}
87
-
88
  if "user_email" not in st.session_state:
89
  st.session_state["user_email"] = None
90
 
 
 
 
 
 
 
91
 
92
  def bucket_parts(s3_path: str) -> Tuple[str, str]:
93
  """Split an S3 path into bucket and key.
@@ -162,15 +153,12 @@ def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
162
 
163
  def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_type: str) -> str:
164
  """Save images as zip file to s3 for use in backend.
165
-
166
  Blocks until images are processed, added to zip file, and uploaded to S3.
167
-
168
  Args:
169
  identifier: unique identifier for the run, used in s3 link
170
  uploaded_files: BytesIO or UploadedFile from streamlit fileuploader
171
  image_type: string to identify different batches of images used in the
172
  backend model/training. Currently used values: "face", "theme"
173
-
174
  Returns:
175
  S3 location of zip file containing png images.
176
  """
@@ -201,268 +189,207 @@ def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_
201
 
202
  return s3_path
203
 
204
- def send_email(to_email, user_code):
205
- sg = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
206
- from_email = Email("santhosh@gretel.ai")
207
- to_email = To(to_email)
208
- subject = "One Time Code"
209
- content = Content("text/plain", f"Here is your one-time code: {user_code}")
210
- mail = Mail(from_email, to_email, subject, content)
211
- mail_json = mail.get()
212
- response = sg.client.mail.send.post(request_body=mail_json)
213
-
214
-
215
- # Create a function to generate a captcha
216
- def generate_captcha():
217
- # Make a GET request to the API endpoint to generate a captcha
218
- response = requests.get(CAPTCHA_ENDPOINT)
219
-
220
- # If the request was successful, return the API response
221
- if response.status_code == 200:
222
- return response.json()
223
- else:
224
- logger.warn(f"Error from generate captcha request: {response.json()}")
225
- # Otherwise, return an error message
226
- return {"error": "Failed to generate captcha"}
227
-
228
- # Create a function to verify the captcha
229
- def verify_captcha(captcha_id, captcha_response):
230
- # Make a POST request to the API endpoint with the captcha ID and response
231
- verify_json = {"uuid": captcha_id, "captcha": captcha_response}
232
- response = requests.post(
233
- VERIFY_ENDPOINT, json=verify_json,
234
- )
235
- logger.info(f"Response from captcha verify: {response}")
236
-
237
- # If the request was successful, return the API response
238
- if response.status_code == 200:
239
- return response.json()
240
-
241
- # Otherwise, return an error message
242
- return {"error": "Failed to verify captcha"}
243
-
244
  def train_model(model_inputs):
245
- api_key = os.environ.get('API_KEY')
246
- model_key = os.environ.get('MODEL_KEY')
247
  st.markdown(str(model_inputs))
248
  _ = banana.run(api_key, model_key, model_inputs)
249
 
250
- def run_login():
251
- user_email_input = st.empty()
252
- with user_email_input.form(key='user_auth'):
253
- text_input = st.text_input(label='Please Enter Your Email')
254
- submit_button = st.form_submit_button(label='Submit')
255
- if submit_button:
256
- st.session_state["user_email"] = text_input
257
- send_email(text_input, str(st.session_state["key"]))
258
- st.session_state["ux_state"] = UxState.VERIFY_EMAIL
259
- # TODO: alternately run this submit log in a callback to the input?
260
- # or otherwise ensure we execute the runner for the new state
261
- st.experimental_rerun()
262
-
263
-
264
- def run_verify_email():
265
- user_auth = st.empty()
266
- with user_auth.form("one-code"):
267
- text_input = st.text_input(label='Please Input One Time Code')
268
- submit_button = st.form_submit_button(label='Submit')
269
- if submit_button:
270
- if text_input == st.session_state["key"]:
271
- st.session_state["ux_state"] = UxState.UPLOAD1
272
- st.experimental_rerun()
273
- else:
274
- st.markdown("Please Enter Correct Code!")
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
  def run_upload_initial():
278
  identifier = st.session_state["key"]
279
- face_images = st.empty()
280
- with face_images.form("my_form"):
281
  uploaded_files = st.file_uploader(
282
- "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  )
284
- initial_concept_token = st.text_input("Token Name")
285
- initial_concept_class_token = st.text_input("Token Class")
286
  submitted = st.form_submit_button(f"Upload")
287
  if submitted:
288
  with st.spinner('Uploading...'):
289
- st.session_state["initial_concept_file_path"] = zip_and_upload_images(
290
- identifier, uploaded_files, "face"
291
- )
292
- st.session_state["initial_token"] = initial_concept_token
293
- st.session_state["initial_class_token"] = initial_concept_class_token
 
 
294
  st.success(f'Uploading {len(uploaded_files)} files done!')
295
- st.session_state["ux_state"] = UxState.UPLOAD2
296
- st.experimental_rerun()
297
-
 
298
 
299
  def run_upload_secondary():
300
  identifier = st.session_state["key"]
301
- preset_theme_images = st.empty()
302
- with preset_theme_images.form("choose-preset-theme"):
303
- img = image_select(
304
- "Choose a Theme!",
305
- images=[
306
- "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
307
- "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
308
- "https://ichef.bbci.co.uk/images/ic/640x360/p09t1hg0.jpg",
309
- ],
310
- captions=["Game of Thrones", "Iron Man", "Thor"],
311
- return_value="index",
312
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- col1, col2 = st.columns([0.17, 1])
315
- prompt_keywords = st.text_input("Prompt Keywords")
316
- with col1:
317
- submitted_3 = st.form_submit_button("Submit!")
318
- if submitted_3:
319
- with st.spinner():
320
- dictionary = {
321
- 0: [
322
- "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/game-of-thrones.zip",
323
- "game-of-thrones",
324
- ],
325
- 1: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/iron-man.zip", "iron-man"],
326
- 2: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/thor.zip", "thor"],
327
- }
328
- st.session_state["model_inputs"] = {
329
- "secondary_concept_file_path": dictionary[img][0],
330
- # Use presigned url since backend does not have credentials
331
- "initial_token": st.session_state["initial_token"],
332
- "secondary_token": dictionary[img][1],
333
- "initial_class_token": st.session_state["initial_class_token"],
334
- "secondary_class_token": 'superhero',
335
- "initial_concept_file_path": generate_s3_get_url(st.session_state["initial_concept_file_path"], expiration_seconds=3600),
336
- "num_images": 50,
337
- "prompt_keywords": prompt_keywords
338
- }
339
- st.success("Success!")
340
- st.session_state["ux_state"] = UxState.CAPTCHA
341
- st.experimental_rerun()
342
- with col2:
343
- submitted_4 = st.form_submit_button(
344
- "If none of the themes interest you, click here!"
345
- )
346
- if submitted_4:
347
- st.session_state["view"] = True
348
-
349
- if st.session_state["view"]:
350
- # TODO: split into it's own ux state and function?
351
- custom_theme_images = st.empty()
352
- with custom_theme_images.form("input_custom_themes"):
353
- st.markdown("If none of the themes interest you, please input your own!")
354
- uploaded_files_2 = st.file_uploader(
355
- "Choose image files",
356
- accept_multiple_files=True,
357
- type=["png", "jpg", "jpeg"],
358
- )
359
- secondary_concept_token = st.text_input("Token Name")
360
- secondary_concept_class_token = st.text_input("Token Class")
361
- prompt_keywords = st.text_input("Prompt Keywords")
362
- submitted_3 = st.form_submit_button("Submit!")
363
- if submitted_3:
364
- with st.spinner('Uploading...'):
365
- st.session_state["secondary_concept_file_path"] = zip_and_upload_images(
366
- identifier, uploaded_files_2, "theme"
367
- )
368
- #st.markdown(secondary_concept_file_path)
369
- st.session_state["model_inputs"] = {
370
- # Use presigned urls since backend does not have credentials
371
- "initial_concept_file_path": generate_s3_get_url(st.session_state["initial_concept_file_path"], expiration_seconds=3600),
372
- "secondary_concept_file_path": generate_s3_get_url(st.session_state["secondary_concept_file_path"], expiration_seconds=3600),
373
- "initial_token": st.session_state["initial_token"],
374
- "secondary_token": secondary_concept_token,
375
- "initial_class_token": st.session_state["initial_class_token"],
376
- "secondary_class_token": secondary_concept_class_token,
377
- "num_images": 50,
378
- "prompt_keywords": prompt_keywords
379
- }
380
- st.success('Done!')
381
- st.session_state["ux_state"] = UxState.CAPTCHA
382
- st.experimental_rerun()
383
-
384
-
385
-
386
- def run_captcha():
387
- captcha_form = st.empty()
388
- with captcha_form.form("captcha_form", clear_on_submit=True):
389
- # Create container to create image/text input out of order from the
390
- # format submit button. Needed since we need to know the status of the
391
- # form submit to know what the captcha should do.
392
- captcha_container = st.container()
393
- display_captcha = True
394
- # TODO: Submit button renders first, then drops down once the image is
395
- # fetched leading to page reflow. Would be nice to not have reflow, but
396
- # we need to know if the submit button was previously pressed and if the
397
- # captcha was solved to generate and display a new captcha or not.
398
- # Possible solution is use an on_click callback to set a session_state
399
- # variable to access whether the button was pushed or not instead of the
400
- # return value here.
401
- submitted = st.form_submit_button("Submit Captcha!")
402
-
403
  if submitted:
404
- result = verify_captcha(st.session_state['captcha']['uuid'], st.session_state["captcha_response"])
405
- del st.session_state["captcha_response"]
406
- if 'message' in result and result['message'] == 'CAPTCHA_SOLVED':
407
- st.session_state['captcha'] = {}
408
- display_captcha = False
409
- with st.spinner("Model Fine Tuning..."):
410
- st.session_state["model_inputs"]["identifier"] = st.session_state["key"]
411
- st.session_state["model_inputs"]["email"] = st.session_state["user_email"]
412
- s3_output_path = _S3_PATH_OUTPUT.format(identifier=st.session_state["key"], image_type="generated")
413
- # The backend does not have s3 credentials, so generate
414
- # presigned urls for the backend to use to write and read
415
- # the generated images.
416
- st.session_state["model_inputs"]["output_s3_url_get"] = generate_s3_get_url(
417
- s3_output_path, expiration_seconds=60 * 60 * 24,
418
- )
419
- st.session_state["model_inputs"]["output_s3_url_put"] = generate_s3_put_url(
420
- s3_output_path, expiration_seconds=3600,
421
- )
422
- train_model(st.session_state["model_inputs"])
423
- st.session_state["ux_state"] = UxState.TRAIN
424
- st.experimental_rerun()
425
- else:
426
- st.error(result['error'])
427
-
428
- if display_captcha:
429
- # Generate new captcha and display. Occurs on first run of the
430
- # captcha state, or after previously failed captcha attempts.
431
- result = generate_captcha()
432
- captcha_id = result['uuid']
433
- captcha_image = result['captcha']
434
-
435
- st.session_state['captcha']['uuid'] = captcha_id
436
- st.session_state['captcha']['captcha'] = captcha_image
437
-
438
- captcha_container.image(captcha_image, width=300)
439
-
440
- captcha_container.text_input("Enter the captcha response", key="captcha_response")
441
- # Submit button already setup previously.
442
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
  def run_train():
445
- st.write(f"Congratulations, your model is training.")
446
  st.write(f"We'll send an email to {st.session_state['user_email']} when it's finished, usually about 20-30 minutes.")
447
- st.write("You may close this browser window/tab.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
 
 
 
 
449
 
450
  if __name__ == "__main__":
451
  setup_session_state()
452
 
453
  ux_state = st.session_state["ux_state"]
454
 
455
- if ux_state == UxState.LOGIN:
456
- run_login()
457
- elif ux_state == UxState.VERIFY_EMAIL:
458
- run_verify_email()
459
- elif ux_state == UxState.UPLOAD1:
460
- run_upload_initial()
461
- elif ux_state == UxState.UPLOAD2:
462
- run_upload_secondary()
463
- elif ux_state == UxState.CAPTCHA:
464
- run_captcha()
465
- elif ux_state == UxState.TRAIN:
466
- run_train()
467
  else:
468
- raise ValueError(f"Internal app error, unknown ux_state='{ux_state}'")
 
 
9
  import uuid
10
  import argparse
11
  import logging
 
 
12
  from enum import Enum
13
  import tempfile
14
  from pathlib import Path
 
32
  load_dotenv()
33
 
34
  _S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
35
+ _GRETEL_USERINFO_ENDPOINT = "https://api.gretel.cloud/users/me"
36
+
37
+
38
 
39
  class UxState(str, Enum):
40
+ LOGIN_VIA_API_KEY = "login_via_api_key"
 
41
  UPLOAD1 = "upload1"
42
  UPLOAD2 = "upload2"
43
+ UPLOAD3 = "upload3"
44
+ PROMPT = "prompt"
45
  TRAIN = "train"
46
+ FINISHED = "finished"
47
 
48
+ # Command-line arguments to control some stuff for easier local testing.
49
+ # Eventually may want to move everything into functions and have a
50
+ # if __name__ == "main" setup instead of everything inline.
51
  def setup_session_state():
52
  if "key" not in st.session_state:
53
  st.session_state["key"] = uuid.uuid4().hex
54
 
55
  if "ux_state" not in st.session_state:
56
+ st.session_state["ux_state"] = UxState.LOGIN_VIA_API_KEY
57
 
58
  if "model_inputs" not in st.session_state:
59
  st.session_state["model_inputs"] = None
60
 
61
+ if "concepts" not in st.session_state:
62
+ st.session_state["concepts"] = []
63
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  if "prompt_keywords" not in st.session_state:
65
  st.session_state["prompt_keywords"] = None
66
+
67
+ if "prompt" not in st.session_state:
68
+ st.session_state["prompt"] = None
69
 
70
  if "view" not in st.session_state:
71
  st.session_state["view"] = False
72
 
 
 
 
 
 
 
73
  if "user_email" not in st.session_state:
74
  st.session_state["user_email"] = None
75
 
76
+ if "user_firstname" not in st.session_state:
77
+ st.session_state["user_firstname"] = None
78
+
79
+ if "user_verified" not in st.session_state:
80
+ st.session_state["user_verified"] = False
81
+
82
 
83
  def bucket_parts(s3_path: str) -> Tuple[str, str]:
84
  """Split an S3 path into bucket and key.
 
153
 
154
  def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_type: str) -> str:
155
  """Save images as zip file to s3 for use in backend.
 
156
  Blocks until images are processed, added to zip file, and uploaded to S3.
 
157
  Args:
158
  identifier: unique identifier for the run, used in s3 link
159
  uploaded_files: BytesIO or UploadedFile from streamlit fileuploader
160
  image_type: string to identify different batches of images used in the
161
  backend model/training. Currently used values: "face", "theme"
 
162
  Returns:
163
  S3 location of zip file containing png images.
164
  """
 
189
 
190
  return s3_path
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  def train_model(model_inputs):
193
+ api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
194
+ model_key = "bd2c55f5-84bb-40f9-82fb-196ca68b1c1d"
195
  st.markdown(str(model_inputs))
196
  _ = banana.run(api_key, model_key, model_inputs)
197
 
198
+ def switch_ux_state(new_state: UxState):
199
+ st.session_state['ux_state'] = new_state
200
+ st.experimental_rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
+ def run_enter_api_key():
203
+ api_key_input = st.empty()
204
+ with api_key_input.form(key='user_auth_api_key'):
205
+ api_key_input = st.text_input(label='Please enter your Gretel API Key', type='password')
206
+ st.caption("Don't have a Gretel Cloud account yet? [Sign up](https://gretel.ai/signup) for free now!")
207
+ submit_button = st.form_submit_button(label='Submit', type='primary')
208
+ if submit_button:
209
+ r = requests.get(_GRETEL_USERINFO_ENDPOINT, headers={'authorization': api_key_input})
210
+ if r.status_code != 200:
211
+ st.error('API key could not be verified')
212
+ return
213
+ me = r.json().get('data', {}).get('me', {})
214
+ email = me.get('email')
215
+ if email is None:
216
+ st.error('No e-mail associated with this API key')
217
+ return
218
+ st.session_state["user_email"] = email
219
+ st.session_state["user_firstname"] = me.get('firstname')
220
+ st.session_state["user_verified"] = True
221
+
222
+ switch_ux_state(UxState.UPLOAD1)
223
 
224
  def run_upload_initial():
225
  identifier = st.session_state["key"]
226
+ images = st.empty()
227
+ with images.form("concept_one_form"):
228
  uploaded_files = st.file_uploader(
229
+ "Choose first concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
230
+ )
231
+
232
+ token = st.text_input("Token Name")
233
+ st.caption(
234
+ """
235
+ The `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.
236
+ """
237
+ )
238
+ class_token = st.text_input("Token Class")
239
+ st.caption(
240
+ """
241
+ The `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.
242
+ """
243
+ )
244
+ concept = st.checkbox(
245
+ 'Would you like to fine-tune on a second concept?',
246
  )
 
 
247
  submitted = st.form_submit_button(f"Upload")
248
  if submitted:
249
  with st.spinner('Uploading...'):
250
+ concept_information_dictionary = {
251
+ "file_path": generate_s3_get_url(zip_and_upload_images(
252
+ identifier, uploaded_files, "concept_one"), expiration_seconds=3600),
253
+ "token": token,
254
+ "class_token": class_token
255
+ }
256
+ st.session_state["concepts"].append(concept_information_dictionary)
257
  st.success(f'Uploading {len(uploaded_files)} files done!')
258
+ if concept:
259
+ switch_ux_state(UxState.UPLOAD2)
260
+ else:
261
+ switch_ux_state(UxState.PROMPT)
262
 
263
  def run_upload_secondary():
264
  identifier = st.session_state["key"]
265
+ images = st.empty()
266
+ with images.form("concept_two_form"):
267
+ uploaded_files = st.file_uploader(
268
+ "Choose second concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
 
 
 
 
 
 
 
269
  )
270
+ token = st.text_input("Token Name")
271
+ st.caption(
272
+ """
273
+ The `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.
274
+ """
275
+ )
276
+ class_token = st.text_input("Token Class")
277
+ st.caption(
278
+ """
279
+ The `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.
280
+ """
281
+ )
282
+ next_concept = st.checkbox(
283
+ 'Would you like to fine-tune on a third concept?',
284
+ )
285
+ submitted = st.form_submit_button(f"Upload")
286
+ if submitted:
287
+ with st.spinner('Uploading...'):
288
+ concept_information_dictionary = {
289
+ "file_path": generate_s3_get_url(zip_and_upload_images(
290
+ identifier, uploaded_files, "concept_two"), expiration_seconds=3600),
291
+ "token": token,
292
+ "class_token": class_token
293
+ }
294
+ st.session_state["concepts"].append(concept_information_dictionary)
295
+ st.success(f'Uploading {len(uploaded_files)} files done!')
296
+ if next_concept:
297
+ switch_ux_state(UxState.UPLOAD3)
298
+ else:
299
+ switch_ux_state(UxState.PROMPT)
300
 
301
+ def run_upload_third():
302
+ identifier = st.session_state["key"]
303
+ images = st.empty()
304
+ with images.form("concept_three_form"):
305
+ uploaded_files = st.file_uploader(
306
+ "Choose third concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
307
+ )
308
+ token = st.text_input("Token Name")
309
+ st.caption(
310
+ """
311
+ The `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.
312
+ """
313
+ )
314
+ class_token = st.text_input(f"Token Class")
315
+ st.caption(
316
+ """
317
+ The `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.
318
+ """
319
+ )
320
+ submitted = st.form_submit_button(f"Upload")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  if submitted:
322
+ with st.spinner('Uploading...'):
323
+ concept_information_dictionary = {
324
+ "file_path": generate_s3_get_url(zip_and_upload_images(
325
+ identifier, uploaded_files, "concept_three"), expiration_seconds=3600),
326
+ "token": token,
327
+ "class_token": class_token
328
+ }
329
+ st.session_state["concepts"].append(concept_information_dictionary)
330
+ st.success(f'Uploading {len(uploaded_files)} files done!')
331
+ switch_ux_state(UxState.PROMPT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
+ def run_prompts():
334
+ identifier = st.session_state["key"]
335
+ prompt_form = st.empty()
336
+ with prompt_form.form("prompt_form"):
337
+ #prompt = st.text_input("Token Name")
338
+ full_prompt = st.text_input("Prompt")
339
+ prompt_keywords = st.text_input(f"Prompt Keywords")
340
+ submitted = st.form_submit_button(f"Submit")
341
+ if submitted:
342
+ st.session_state["prompt_keywords"] = prompt_keywords
343
+ st.session_state["prompt"] = full_prompt
344
+ st.session_state["ux_state"] = UxState.TRAIN
345
 
346
  def run_train():
347
+ st.write("Congratulations, your model is training.")
348
  st.write(f"We'll send an email to {st.session_state['user_email']} when it's finished, usually about 20-30 minutes.")
349
+ st.write("Closing this tab will not affect the ongoing image generation.")
350
+ with st.spinner("Training in progress..."):
351
+ st.session_state["model_inputs"] = {
352
+ "concepts": st.session_state["concepts"],
353
+ "num_images": 50,
354
+ "prompt": st.session_state["prompt"],
355
+ "prompt_keywords": st.session_state["prompt_keywords"]
356
+ }
357
+ s3_output_path = _S3_PATH_OUTPUT.format(identifier=st.session_state["key"], image_type="generated")
358
+ st.session_state['model_inputs']['identifier'] = st.session_state["key"]
359
+ st.session_state['model_inputs']['email'] = st.session_state["user_email"]
360
+ # The backend does not have s3 credentials, so generate
361
+ # presigned urls for the backend to use to write and read
362
+ # the generated images.
363
+ st.session_state['model_inputs']['output_s3_url_get'] = generate_s3_get_url(
364
+ s3_output_path, expiration_seconds=60 * 60 * 24,
365
+ )
366
+ st.session_state['model_inputs']['output_s3_url_put'] = generate_s3_put_url(
367
+ s3_output_path, expiration_seconds=3600,
368
+ )
369
+ train_model(st.session_state['model_inputs'])
370
+ switch_ux_state(UxState.FINISHED)
371
+
372
 
373
+ def run_finished():
374
+ st.success('Image generation completed!')
375
+ st.write(f"We've sent an email to {st.session_state['user_email']} with a link to your generated images. Check it out!")
376
 
377
  if __name__ == "__main__":
378
  setup_session_state()
379
 
380
  ux_state = st.session_state["ux_state"]
381
 
382
+ runners = {
383
+ UxState.LOGIN_VIA_API_KEY: run_enter_api_key,
384
+ UxState.UPLOAD1: run_upload_initial,
385
+ UxState.UPLOAD2: run_upload_secondary,
386
+ UxState.UPLOAD3: run_upload_third,
387
+ UxState.PROMPT: run_prompts,
388
+ UxState.TRAIN: run_train,
389
+ UxState.FINISHED: run_finished,
390
+ }
391
+ if (runner := runners.get(ux_state)) is not None:
392
+ runner()
 
393
  else:
394
+ raise ValueError(f"Internal app error, unknown ux_state='{ux_state}'")
395
+