Spaces:
Runtime error
Runtime error
santhosh97
commited on
Commit
•
ada7876
1
Parent(s):
51cf589
Update app.py
Browse files
app.py
CHANGED
@@ -9,8 +9,6 @@ from typing import List, Tuple, TYPE_CHECKING
|
|
9 |
import uuid
|
10 |
import argparse
|
11 |
import logging
|
12 |
-
import sendgrid
|
13 |
-
from sendgrid.helpers.mail import Mail, Email, To, Content
|
14 |
from enum import Enum
|
15 |
import tempfile
|
16 |
from pathlib import Path
|
@@ -34,60 +32,53 @@ logger.setLevel(logging.INFO)
|
|
34 |
load_dotenv()
|
35 |
|
36 |
_S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
|
37 |
-
|
38 |
-
|
|
|
39 |
|
40 |
class UxState(str, Enum):
|
41 |
-
|
42 |
-
VERIFY_EMAIL = "verify_email"
|
43 |
UPLOAD1 = "upload1"
|
44 |
UPLOAD2 = "upload2"
|
45 |
-
|
|
|
46 |
TRAIN = "train"
|
|
|
47 |
|
|
|
|
|
|
|
48 |
def setup_session_state():
|
49 |
if "key" not in st.session_state:
|
50 |
st.session_state["key"] = uuid.uuid4().hex
|
51 |
|
52 |
if "ux_state" not in st.session_state:
|
53 |
-
st.session_state["ux_state"] = UxState.
|
54 |
|
55 |
if "model_inputs" not in st.session_state:
|
56 |
st.session_state["model_inputs"] = None
|
57 |
|
58 |
-
if "
|
59 |
-
st.session_state["
|
60 |
-
|
61 |
-
if "initial_token" not in st.session_state:
|
62 |
-
st.session_state["initial_token"] = None
|
63 |
-
|
64 |
-
if "initial_class_token" not in st.session_state:
|
65 |
-
st.session_state["initial_class_token"] = None
|
66 |
-
|
67 |
-
if "secondary_concept_file_path" not in st.session_state:
|
68 |
-
st.session_state["secondary_concept_file_path"] = None
|
69 |
-
|
70 |
-
if "secondary_token" not in st.session_state:
|
71 |
-
st.session_state["secondary_token"] = None
|
72 |
-
|
73 |
-
if "secondary_class_token" not in st.session_state:
|
74 |
-
st.session_state["secondary_class_token"] = None
|
75 |
-
|
76 |
if "prompt_keywords" not in st.session_state:
|
77 |
st.session_state["prompt_keywords"] = None
|
|
|
|
|
|
|
78 |
|
79 |
if "view" not in st.session_state:
|
80 |
st.session_state["view"] = False
|
81 |
|
82 |
-
if "captcha_response" not in st.session_state:
|
83 |
-
st.session_state["captcha_response"] = None
|
84 |
-
|
85 |
-
if "captcha" not in st.session_state:
|
86 |
-
st.session_state["captcha"] = {}
|
87 |
-
|
88 |
if "user_email" not in st.session_state:
|
89 |
st.session_state["user_email"] = None
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
def bucket_parts(s3_path: str) -> Tuple[str, str]:
|
93 |
"""Split an S3 path into bucket and key.
|
@@ -162,15 +153,12 @@ def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
|
|
162 |
|
163 |
def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_type: str) -> str:
|
164 |
"""Save images as zip file to s3 for use in backend.
|
165 |
-
|
166 |
Blocks until images are processed, added to zip file, and uploaded to S3.
|
167 |
-
|
168 |
Args:
|
169 |
identifier: unique identifier for the run, used in s3 link
|
170 |
uploaded_files: BytesIO or UploadedFile from streamlit fileuploader
|
171 |
image_type: string to identify different batches of images used in the
|
172 |
backend model/training. Currently used values: "face", "theme"
|
173 |
-
|
174 |
Returns:
|
175 |
S3 location of zip file containing png images.
|
176 |
"""
|
@@ -201,268 +189,207 @@ def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_
|
|
201 |
|
202 |
return s3_path
|
203 |
|
204 |
-
def send_email(to_email, user_code):
|
205 |
-
sg = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
|
206 |
-
from_email = Email("santhosh@gretel.ai")
|
207 |
-
to_email = To(to_email)
|
208 |
-
subject = "One Time Code"
|
209 |
-
content = Content("text/plain", f"Here is your one-time code: {user_code}")
|
210 |
-
mail = Mail(from_email, to_email, subject, content)
|
211 |
-
mail_json = mail.get()
|
212 |
-
response = sg.client.mail.send.post(request_body=mail_json)
|
213 |
-
|
214 |
-
|
215 |
-
# Create a function to generate a captcha
|
216 |
-
def generate_captcha():
|
217 |
-
# Make a GET request to the API endpoint to generate a captcha
|
218 |
-
response = requests.get(CAPTCHA_ENDPOINT)
|
219 |
-
|
220 |
-
# If the request was successful, return the API response
|
221 |
-
if response.status_code == 200:
|
222 |
-
return response.json()
|
223 |
-
else:
|
224 |
-
logger.warn(f"Error from generate captcha request: {response.json()}")
|
225 |
-
# Otherwise, return an error message
|
226 |
-
return {"error": "Failed to generate captcha"}
|
227 |
-
|
228 |
-
# Create a function to verify the captcha
|
229 |
-
def verify_captcha(captcha_id, captcha_response):
|
230 |
-
# Make a POST request to the API endpoint with the captcha ID and response
|
231 |
-
verify_json = {"uuid": captcha_id, "captcha": captcha_response}
|
232 |
-
response = requests.post(
|
233 |
-
VERIFY_ENDPOINT, json=verify_json,
|
234 |
-
)
|
235 |
-
logger.info(f"Response from captcha verify: {response}")
|
236 |
-
|
237 |
-
# If the request was successful, return the API response
|
238 |
-
if response.status_code == 200:
|
239 |
-
return response.json()
|
240 |
-
|
241 |
-
# Otherwise, return an error message
|
242 |
-
return {"error": "Failed to verify captcha"}
|
243 |
-
|
244 |
def train_model(model_inputs):
|
245 |
-
api_key =
|
246 |
-
model_key =
|
247 |
st.markdown(str(model_inputs))
|
248 |
_ = banana.run(api_key, model_key, model_inputs)
|
249 |
|
250 |
-
def
|
251 |
-
|
252 |
-
|
253 |
-
text_input = st.text_input(label='Please Enter Your Email')
|
254 |
-
submit_button = st.form_submit_button(label='Submit')
|
255 |
-
if submit_button:
|
256 |
-
st.session_state["user_email"] = text_input
|
257 |
-
send_email(text_input, str(st.session_state["key"]))
|
258 |
-
st.session_state["ux_state"] = UxState.VERIFY_EMAIL
|
259 |
-
# TODO: alternately run this submit log in a callback to the input?
|
260 |
-
# or otherwise ensure we execute the runner for the new state
|
261 |
-
st.experimental_rerun()
|
262 |
-
|
263 |
-
|
264 |
-
def run_verify_email():
|
265 |
-
user_auth = st.empty()
|
266 |
-
with user_auth.form("one-code"):
|
267 |
-
text_input = st.text_input(label='Please Input One Time Code')
|
268 |
-
submit_button = st.form_submit_button(label='Submit')
|
269 |
-
if submit_button:
|
270 |
-
if text_input == st.session_state["key"]:
|
271 |
-
st.session_state["ux_state"] = UxState.UPLOAD1
|
272 |
-
st.experimental_rerun()
|
273 |
-
else:
|
274 |
-
st.markdown("Please Enter Correct Code!")
|
275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
|
277 |
def run_upload_initial():
|
278 |
identifier = st.session_state["key"]
|
279 |
-
|
280 |
-
with
|
281 |
uploaded_files = st.file_uploader(
|
282 |
-
"Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
)
|
284 |
-
initial_concept_token = st.text_input("Token Name")
|
285 |
-
initial_concept_class_token = st.text_input("Token Class")
|
286 |
submitted = st.form_submit_button(f"Upload")
|
287 |
if submitted:
|
288 |
with st.spinner('Uploading...'):
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
|
|
|
|
294 |
st.success(f'Uploading {len(uploaded_files)} files done!')
|
295 |
-
|
296 |
-
|
297 |
-
|
|
|
298 |
|
299 |
def run_upload_secondary():
|
300 |
identifier = st.session_state["key"]
|
301 |
-
|
302 |
-
with
|
303 |
-
|
304 |
-
"Choose
|
305 |
-
images=[
|
306 |
-
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
|
307 |
-
"https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
|
308 |
-
"https://ichef.bbci.co.uk/images/ic/640x360/p09t1hg0.jpg",
|
309 |
-
],
|
310 |
-
captions=["Game of Thrones", "Iron Man", "Thor"],
|
311 |
-
return_value="index",
|
312 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
"secondary_class_token": 'superhero',
|
335 |
-
"initial_concept_file_path": generate_s3_get_url(st.session_state["initial_concept_file_path"], expiration_seconds=3600),
|
336 |
-
"num_images": 50,
|
337 |
-
"prompt_keywords": prompt_keywords
|
338 |
-
}
|
339 |
-
st.success("Success!")
|
340 |
-
st.session_state["ux_state"] = UxState.CAPTCHA
|
341 |
-
st.experimental_rerun()
|
342 |
-
with col2:
|
343 |
-
submitted_4 = st.form_submit_button(
|
344 |
-
"If none of the themes interest you, click here!"
|
345 |
-
)
|
346 |
-
if submitted_4:
|
347 |
-
st.session_state["view"] = True
|
348 |
-
|
349 |
-
if st.session_state["view"]:
|
350 |
-
# TODO: split into it's own ux state and function?
|
351 |
-
custom_theme_images = st.empty()
|
352 |
-
with custom_theme_images.form("input_custom_themes"):
|
353 |
-
st.markdown("If none of the themes interest you, please input your own!")
|
354 |
-
uploaded_files_2 = st.file_uploader(
|
355 |
-
"Choose image files",
|
356 |
-
accept_multiple_files=True,
|
357 |
-
type=["png", "jpg", "jpeg"],
|
358 |
-
)
|
359 |
-
secondary_concept_token = st.text_input("Token Name")
|
360 |
-
secondary_concept_class_token = st.text_input("Token Class")
|
361 |
-
prompt_keywords = st.text_input("Prompt Keywords")
|
362 |
-
submitted_3 = st.form_submit_button("Submit!")
|
363 |
-
if submitted_3:
|
364 |
-
with st.spinner('Uploading...'):
|
365 |
-
st.session_state["secondary_concept_file_path"] = zip_and_upload_images(
|
366 |
-
identifier, uploaded_files_2, "theme"
|
367 |
-
)
|
368 |
-
#st.markdown(secondary_concept_file_path)
|
369 |
-
st.session_state["model_inputs"] = {
|
370 |
-
# Use presigned urls since backend does not have credentials
|
371 |
-
"initial_concept_file_path": generate_s3_get_url(st.session_state["initial_concept_file_path"], expiration_seconds=3600),
|
372 |
-
"secondary_concept_file_path": generate_s3_get_url(st.session_state["secondary_concept_file_path"], expiration_seconds=3600),
|
373 |
-
"initial_token": st.session_state["initial_token"],
|
374 |
-
"secondary_token": secondary_concept_token,
|
375 |
-
"initial_class_token": st.session_state["initial_class_token"],
|
376 |
-
"secondary_class_token": secondary_concept_class_token,
|
377 |
-
"num_images": 50,
|
378 |
-
"prompt_keywords": prompt_keywords
|
379 |
-
}
|
380 |
-
st.success('Done!')
|
381 |
-
st.session_state["ux_state"] = UxState.CAPTCHA
|
382 |
-
st.experimental_rerun()
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
def run_captcha():
|
387 |
-
captcha_form = st.empty()
|
388 |
-
with captcha_form.form("captcha_form", clear_on_submit=True):
|
389 |
-
# Create container to create image/text input out of order from the
|
390 |
-
# format submit button. Needed since we need to know the status of the
|
391 |
-
# form submit to know what the captcha should do.
|
392 |
-
captcha_container = st.container()
|
393 |
-
display_captcha = True
|
394 |
-
# TODO: Submit button renders first, then drops down once the image is
|
395 |
-
# fetched leading to page reflow. Would be nice to not have reflow, but
|
396 |
-
# we need to know if the submit button was previously pressed and if the
|
397 |
-
# captcha was solved to generate and display a new captcha or not.
|
398 |
-
# Possible solution is use an on_click callback to set a session_state
|
399 |
-
# variable to access whether the button was pushed or not instead of the
|
400 |
-
# return value here.
|
401 |
-
submitted = st.form_submit_button("Submit Captcha!")
|
402 |
-
|
403 |
if submitted:
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
# presigned urls for the backend to use to write and read
|
415 |
-
# the generated images.
|
416 |
-
st.session_state["model_inputs"]["output_s3_url_get"] = generate_s3_get_url(
|
417 |
-
s3_output_path, expiration_seconds=60 * 60 * 24,
|
418 |
-
)
|
419 |
-
st.session_state["model_inputs"]["output_s3_url_put"] = generate_s3_put_url(
|
420 |
-
s3_output_path, expiration_seconds=3600,
|
421 |
-
)
|
422 |
-
train_model(st.session_state["model_inputs"])
|
423 |
-
st.session_state["ux_state"] = UxState.TRAIN
|
424 |
-
st.experimental_rerun()
|
425 |
-
else:
|
426 |
-
st.error(result['error'])
|
427 |
-
|
428 |
-
if display_captcha:
|
429 |
-
# Generate new captcha and display. Occurs on first run of the
|
430 |
-
# captcha state, or after previously failed captcha attempts.
|
431 |
-
result = generate_captcha()
|
432 |
-
captcha_id = result['uuid']
|
433 |
-
captcha_image = result['captcha']
|
434 |
-
|
435 |
-
st.session_state['captcha']['uuid'] = captcha_id
|
436 |
-
st.session_state['captcha']['captcha'] = captcha_image
|
437 |
-
|
438 |
-
captcha_container.image(captcha_image, width=300)
|
439 |
-
|
440 |
-
captcha_container.text_input("Enter the captcha response", key="captcha_response")
|
441 |
-
# Submit button already setup previously.
|
442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
443 |
|
444 |
def run_train():
|
445 |
-
st.write(
|
446 |
st.write(f"We'll send an email to {st.session_state['user_email']} when it's finished, usually about 20-30 minutes.")
|
447 |
-
st.write("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
|
|
|
|
|
|
|
449 |
|
450 |
if __name__ == "__main__":
|
451 |
setup_session_state()
|
452 |
|
453 |
ux_state = st.session_state["ux_state"]
|
454 |
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
run_train()
|
467 |
else:
|
468 |
-
raise ValueError(f"Internal app error, unknown ux_state='{ux_state}'")
|
|
|
|
9 |
import uuid
|
10 |
import argparse
|
11 |
import logging
|
|
|
|
|
12 |
from enum import Enum
|
13 |
import tempfile
|
14 |
from pathlib import Path
|
|
|
32 |
load_dotenv()
|
33 |
|
34 |
_S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
|
35 |
+
_GRETEL_USERINFO_ENDPOINT = "https://api.gretel.cloud/users/me"
|
36 |
+
|
37 |
+
|
38 |
|
39 |
class UxState(str, Enum):
|
40 |
+
LOGIN_VIA_API_KEY = "login_via_api_key"
|
|
|
41 |
UPLOAD1 = "upload1"
|
42 |
UPLOAD2 = "upload2"
|
43 |
+
UPLOAD3 = "upload3"
|
44 |
+
PROMPT = "prompt"
|
45 |
TRAIN = "train"
|
46 |
+
FINISHED = "finished"
|
47 |
|
48 |
+
# Command-line arguments to control some stuff for easier local testing.
|
49 |
+
# Eventually may want to move everything into functions and have a
|
50 |
+
# if __name__ == "main" setup instead of everything inline.
|
51 |
def setup_session_state():
|
52 |
if "key" not in st.session_state:
|
53 |
st.session_state["key"] = uuid.uuid4().hex
|
54 |
|
55 |
if "ux_state" not in st.session_state:
|
56 |
+
st.session_state["ux_state"] = UxState.LOGIN_VIA_API_KEY
|
57 |
|
58 |
if "model_inputs" not in st.session_state:
|
59 |
st.session_state["model_inputs"] = None
|
60 |
|
61 |
+
if "concepts" not in st.session_state:
|
62 |
+
st.session_state["concepts"] = []
|
63 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
if "prompt_keywords" not in st.session_state:
|
65 |
st.session_state["prompt_keywords"] = None
|
66 |
+
|
67 |
+
if "prompt" not in st.session_state:
|
68 |
+
st.session_state["prompt"] = None
|
69 |
|
70 |
if "view" not in st.session_state:
|
71 |
st.session_state["view"] = False
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
if "user_email" not in st.session_state:
|
74 |
st.session_state["user_email"] = None
|
75 |
|
76 |
+
if "user_firstname" not in st.session_state:
|
77 |
+
st.session_state["user_firstname"] = None
|
78 |
+
|
79 |
+
if "user_verified" not in st.session_state:
|
80 |
+
st.session_state["user_verified"] = False
|
81 |
+
|
82 |
|
83 |
def bucket_parts(s3_path: str) -> Tuple[str, str]:
|
84 |
"""Split an S3 path into bucket and key.
|
|
|
153 |
|
154 |
def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_type: str) -> str:
|
155 |
"""Save images as zip file to s3 for use in backend.
|
|
|
156 |
Blocks until images are processed, added to zip file, and uploaded to S3.
|
|
|
157 |
Args:
|
158 |
identifier: unique identifier for the run, used in s3 link
|
159 |
uploaded_files: BytesIO or UploadedFile from streamlit fileuploader
|
160 |
image_type: string to identify different batches of images used in the
|
161 |
backend model/training. Currently used values: "face", "theme"
|
|
|
162 |
Returns:
|
163 |
S3 location of zip file containing png images.
|
164 |
"""
|
|
|
189 |
|
190 |
return s3_path
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
def train_model(model_inputs):
|
193 |
+
api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
|
194 |
+
model_key = "bd2c55f5-84bb-40f9-82fb-196ca68b1c1d"
|
195 |
st.markdown(str(model_inputs))
|
196 |
_ = banana.run(api_key, model_key, model_inputs)
|
197 |
|
198 |
+
def switch_ux_state(new_state: UxState):
|
199 |
+
st.session_state['ux_state'] = new_state
|
200 |
+
st.experimental_rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
+
def run_enter_api_key():
|
203 |
+
api_key_input = st.empty()
|
204 |
+
with api_key_input.form(key='user_auth_api_key'):
|
205 |
+
api_key_input = st.text_input(label='Please enter your Gretel API Key', type='password')
|
206 |
+
st.caption("Don't have a Gretel Cloud account yet? [Sign up](https://gretel.ai/signup) for free now!")
|
207 |
+
submit_button = st.form_submit_button(label='Submit', type='primary')
|
208 |
+
if submit_button:
|
209 |
+
r = requests.get(_GRETEL_USERINFO_ENDPOINT, headers={'authorization': api_key_input})
|
210 |
+
if r.status_code != 200:
|
211 |
+
st.error('API key could not be verified')
|
212 |
+
return
|
213 |
+
me = r.json().get('data', {}).get('me', {})
|
214 |
+
email = me.get('email')
|
215 |
+
if email is None:
|
216 |
+
st.error('No e-mail associated with this API key')
|
217 |
+
return
|
218 |
+
st.session_state["user_email"] = email
|
219 |
+
st.session_state["user_firstname"] = me.get('firstname')
|
220 |
+
st.session_state["user_verified"] = True
|
221 |
+
|
222 |
+
switch_ux_state(UxState.UPLOAD1)
|
223 |
|
224 |
def run_upload_initial():
|
225 |
identifier = st.session_state["key"]
|
226 |
+
images = st.empty()
|
227 |
+
with images.form("concept_one_form"):
|
228 |
uploaded_files = st.file_uploader(
|
229 |
+
"Choose first concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
|
230 |
+
)
|
231 |
+
|
232 |
+
token = st.text_input("Token Name")
|
233 |
+
st.caption(
|
234 |
+
"""
|
235 |
+
The `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.
|
236 |
+
"""
|
237 |
+
)
|
238 |
+
class_token = st.text_input("Token Class")
|
239 |
+
st.caption(
|
240 |
+
"""
|
241 |
+
The `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.
|
242 |
+
"""
|
243 |
+
)
|
244 |
+
concept = st.checkbox(
|
245 |
+
'Would you like to fine-tune on a second concept?',
|
246 |
)
|
|
|
|
|
247 |
submitted = st.form_submit_button(f"Upload")
|
248 |
if submitted:
|
249 |
with st.spinner('Uploading...'):
|
250 |
+
concept_information_dictionary = {
|
251 |
+
"file_path": generate_s3_get_url(zip_and_upload_images(
|
252 |
+
identifier, uploaded_files, "concept_one"), expiration_seconds=3600),
|
253 |
+
"token": token,
|
254 |
+
"class_token": class_token
|
255 |
+
}
|
256 |
+
st.session_state["concepts"].append(concept_information_dictionary)
|
257 |
st.success(f'Uploading {len(uploaded_files)} files done!')
|
258 |
+
if concept:
|
259 |
+
switch_ux_state(UxState.UPLOAD2)
|
260 |
+
else:
|
261 |
+
switch_ux_state(UxState.PROMPT)
|
262 |
|
263 |
def run_upload_secondary():
|
264 |
identifier = st.session_state["key"]
|
265 |
+
images = st.empty()
|
266 |
+
with images.form("concept_two_form"):
|
267 |
+
uploaded_files = st.file_uploader(
|
268 |
+
"Choose second concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
)
|
270 |
+
token = st.text_input("Token Name")
|
271 |
+
st.caption(
|
272 |
+
"""
|
273 |
+
The `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.
|
274 |
+
"""
|
275 |
+
)
|
276 |
+
class_token = st.text_input("Token Class")
|
277 |
+
st.caption(
|
278 |
+
"""
|
279 |
+
The `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.
|
280 |
+
"""
|
281 |
+
)
|
282 |
+
next_concept = st.checkbox(
|
283 |
+
'Would you like to fine-tune on a third concept?',
|
284 |
+
)
|
285 |
+
submitted = st.form_submit_button(f"Upload")
|
286 |
+
if submitted:
|
287 |
+
with st.spinner('Uploading...'):
|
288 |
+
concept_information_dictionary = {
|
289 |
+
"file_path": generate_s3_get_url(zip_and_upload_images(
|
290 |
+
identifier, uploaded_files, "concept_two"), expiration_seconds=3600),
|
291 |
+
"token": token,
|
292 |
+
"class_token": class_token
|
293 |
+
}
|
294 |
+
st.session_state["concepts"].append(concept_information_dictionary)
|
295 |
+
st.success(f'Uploading {len(uploaded_files)} files done!')
|
296 |
+
if next_concept:
|
297 |
+
switch_ux_state(UxState.UPLOAD3)
|
298 |
+
else:
|
299 |
+
switch_ux_state(UxState.PROMPT)
|
300 |
|
301 |
+
def run_upload_third():
|
302 |
+
identifier = st.session_state["key"]
|
303 |
+
images = st.empty()
|
304 |
+
with images.form("concept_three_form"):
|
305 |
+
uploaded_files = st.file_uploader(
|
306 |
+
"Choose third concept image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
|
307 |
+
)
|
308 |
+
token = st.text_input("Token Name")
|
309 |
+
st.caption(
|
310 |
+
"""
|
311 |
+
The `token name` you use to describe your training images should be in the format: `a [identifier] [class noun]`, where the `[identifier]` should be a rare token. Relatively short sequences with 1-3 letters work the best (e.g. `sks`, `xjy`). `[class noun]` is a coarse class descriptor of the subject (e.g. cat, dog, watch, etc.). For example, your `token` can be: `a sks dog`, or with some extra description `a photo of a sks dog`. The trained model will learn to bind a unique identifier with your specific subject in the `instance_data`.
|
312 |
+
"""
|
313 |
+
)
|
314 |
+
class_token = st.text_input(f"Token Class")
|
315 |
+
st.caption(
|
316 |
+
"""
|
317 |
+
The `token class` is a description of the coarse class of your training images, in the format of `a [class noun]`, optionally with some extra description. `token_class` is used to alleviate overfitting to your customised images (the trained model should still keep the learnt prior so that it can still generate different dogs when the `[identifier]` is not in the prompt). Corresponding to the examples of the `token` above, the `token_class` can be `a dog` or `a photo of a dog`.
|
318 |
+
"""
|
319 |
+
)
|
320 |
+
submitted = st.form_submit_button(f"Upload")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
if submitted:
|
322 |
+
with st.spinner('Uploading...'):
|
323 |
+
concept_information_dictionary = {
|
324 |
+
"file_path": generate_s3_get_url(zip_and_upload_images(
|
325 |
+
identifier, uploaded_files, "concept_three"), expiration_seconds=3600),
|
326 |
+
"token": token,
|
327 |
+
"class_token": class_token
|
328 |
+
}
|
329 |
+
st.session_state["concepts"].append(concept_information_dictionary)
|
330 |
+
st.success(f'Uploading {len(uploaded_files)} files done!')
|
331 |
+
switch_ux_state(UxState.PROMPT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
|
333 |
+
def run_prompts():
|
334 |
+
identifier = st.session_state["key"]
|
335 |
+
prompt_form = st.empty()
|
336 |
+
with prompt_form.form("prompt_form"):
|
337 |
+
#prompt = st.text_input("Token Name")
|
338 |
+
full_prompt = st.text_input("Prompt")
|
339 |
+
prompt_keywords = st.text_input(f"Prompt Keywords")
|
340 |
+
submitted = st.form_submit_button(f"Submit")
|
341 |
+
if submitted:
|
342 |
+
st.session_state["prompt_keywords"] = prompt_keywords
|
343 |
+
st.session_state["prompt"] = full_prompt
|
344 |
+
st.session_state["ux_state"] = UxState.TRAIN
|
345 |
|
346 |
def run_train():
|
347 |
+
st.write("Congratulations, your model is training.")
|
348 |
st.write(f"We'll send an email to {st.session_state['user_email']} when it's finished, usually about 20-30 minutes.")
|
349 |
+
st.write("Closing this tab will not affect the ongoing image generation.")
|
350 |
+
with st.spinner("Training in progress..."):
|
351 |
+
st.session_state["model_inputs"] = {
|
352 |
+
"concepts": st.session_state["concepts"],
|
353 |
+
"num_images": 50,
|
354 |
+
"prompt": st.session_state["prompt"],
|
355 |
+
"prompt_keywords": st.session_state["prompt_keywords"]
|
356 |
+
}
|
357 |
+
s3_output_path = _S3_PATH_OUTPUT.format(identifier=st.session_state["key"], image_type="generated")
|
358 |
+
st.session_state['model_inputs']['identifier'] = st.session_state["key"]
|
359 |
+
st.session_state['model_inputs']['email'] = st.session_state["user_email"]
|
360 |
+
# The backend does not have s3 credentials, so generate
|
361 |
+
# presigned urls for the backend to use to write and read
|
362 |
+
# the generated images.
|
363 |
+
st.session_state['model_inputs']['output_s3_url_get'] = generate_s3_get_url(
|
364 |
+
s3_output_path, expiration_seconds=60 * 60 * 24,
|
365 |
+
)
|
366 |
+
st.session_state['model_inputs']['output_s3_url_put'] = generate_s3_put_url(
|
367 |
+
s3_output_path, expiration_seconds=3600,
|
368 |
+
)
|
369 |
+
train_model(st.session_state['model_inputs'])
|
370 |
+
switch_ux_state(UxState.FINISHED)
|
371 |
+
|
372 |
|
373 |
+
def run_finished():
|
374 |
+
st.success('Image generation completed!')
|
375 |
+
st.write(f"We've sent an email to {st.session_state['user_email']} with a link to your generated images. Check it out!")
|
376 |
|
377 |
if __name__ == "__main__":
|
378 |
setup_session_state()
|
379 |
|
380 |
ux_state = st.session_state["ux_state"]
|
381 |
|
382 |
+
runners = {
|
383 |
+
UxState.LOGIN_VIA_API_KEY: run_enter_api_key,
|
384 |
+
UxState.UPLOAD1: run_upload_initial,
|
385 |
+
UxState.UPLOAD2: run_upload_secondary,
|
386 |
+
UxState.UPLOAD3: run_upload_third,
|
387 |
+
UxState.PROMPT: run_prompts,
|
388 |
+
UxState.TRAIN: run_train,
|
389 |
+
UxState.FINISHED: run_finished,
|
390 |
+
}
|
391 |
+
if (runner := runners.get(ux_state)) is not None:
|
392 |
+
runner()
|
|
|
393 |
else:
|
394 |
+
raise ValueError(f"Internal app error, unknown ux_state='{ux_state}'")
|
395 |
+
|