Ron Au commited on
Commit
4c519fd
1 Parent(s): b49a47d

feat(eta): Improve duration UX

Browse files

- Render result without waiting for last poll interval to complete
- Calculate ETA based on past completions
- Update ETA during generation based on place in queue

Files changed (4) hide show
  1. app.py +32 -3
  2. modules/inference.py +4 -8
  3. static/index.js +29 -21
  4. templates/index.html +1 -1
app.py CHANGED
@@ -1,4 +1,5 @@
1
- import time
 
2
  from flask import Flask, jsonify, render_template, request
3
 
4
  from modules.details import load_lists, rand_details
@@ -6,6 +7,8 @@ from modules.inference import generate_image
6
 
7
  app = Flask(__name__)
8
 
 
 
9
 
10
  @app.route('/')
11
  def index():
@@ -19,10 +22,13 @@ tasks = {}
19
  def create_task():
20
  prompt = request.args.get('prompt') or "покемон"
21
 
22
- task_id = f"{str(time.time())}_{prompt}"
 
 
23
 
24
  tasks[task_id] = {
25
  "task_id": task_id,
 
26
  "prompt": prompt,
27
  "status": "pending",
28
  "poll_count": 0,
@@ -37,7 +43,9 @@ def queue_task():
37
 
38
  tasks[task_id]["value"] = generate_image(tasks[task_id]["prompt"])
39
 
40
- tasks[task_id]["status"] = "complete"
 
 
41
 
42
  return jsonify(tasks[task_id])
43
 
@@ -46,6 +54,27 @@ def queue_task():
46
  def poll_task():
47
  task_id = request.args.get('task_id')
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  tasks[task_id]["poll_count"] += 1
50
 
51
  return jsonify(tasks[task_id])
1
+ from time import time
2
+ from statistics import mean
3
  from flask import Flask, jsonify, render_template, request
4
 
5
  from modules.details import load_lists, rand_details
7
 
8
  app = Flask(__name__)
9
 
10
+ TEMPLATES_AUTO_RELOAD = True
11
+
12
 
13
  @app.route('/')
14
  def index():
22
  def create_task():
23
  prompt = request.args.get('prompt') or "покемон"
24
 
25
+ created_at = time()
26
+
27
+ task_id = f"{str(created_at)}_{prompt}"
28
 
29
  tasks[task_id] = {
30
  "task_id": task_id,
31
+ "created_at": created_at,
32
  "prompt": prompt,
33
  "status": "pending",
34
  "poll_count": 0,
43
 
44
  tasks[task_id]["value"] = generate_image(tasks[task_id]["prompt"])
45
 
46
+ tasks[task_id]["status"] = "completed"
47
+
48
+ tasks[task_id]["completed_at"] = time()
49
 
50
  return jsonify(tasks[task_id])
51
 
54
  def poll_task():
55
  task_id = request.args.get('task_id')
56
 
57
+ pending_tasks = []
58
+ completed_durations = []
59
+
60
+ for task in tasks.values():
61
+ if task["status"] == "pending":
62
+ pending_tasks.append(task["task_id"])
63
+ elif task["status"] == "completed":
64
+ completed_durations.append(task["completed_at"] - task["created_at"])
65
+
66
+ try:
67
+ place_in_queue = pending_tasks.index(task_id) + 1
68
+ except:
69
+ place_in_queue = 0
70
+
71
+ if (len(completed_durations)):
72
+ eta = sum(completed_durations) / len(completed_durations) * (place_in_queue or 1)
73
+ else:
74
+ eta = 40 * (place_in_queue or 1)
75
+
76
+ tasks[task_id]["place_in_queue"] = place_in_queue
77
+ tasks[task_id]["eta"] = round(eta, 1)
78
  tasks[task_id]["poll_count"] += 1
79
 
80
  return jsonify(tasks[task_id])
modules/inference.py CHANGED
@@ -13,14 +13,11 @@ fp16 = torch.cuda.is_available()
13
 
14
  file_dir = "./models"
15
  file_name = "pytorch_model.bin"
16
- config_file_url = hf_hub_url(
17
- repo_id="minimaxir/ai-generated-pokemon-rudalle", filename=file_name)
18
  cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)
19
 
20
- model = get_rudalle_model('Malevich', pretrained=False,
21
- fp16=fp16, device=device)
22
- model.load_state_dict(torch.load(
23
- f"{file_dir}/{file_name}", map_location=f"{'cuda:0' if torch.cuda.is_available() else 'cpu'}"))
24
 
25
  vae = get_vae().to(device)
26
  tokenizer = get_tokenizer()
@@ -50,8 +47,7 @@ def generate_image(prompt):
50
  if prompt.lower() in ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness', 'metal', 'dragon', 'fairy']:
51
  prompt = english_to_russian(prompt)
52
 
53
- result, _ = generate_images(
54
- prompt, tokenizer, model, vae, top_k=2048, images_num=1, top_p=0.995)
55
 
56
  buffer = BytesIO()
57
  result[0].save(buffer, format="PNG")
13
 
14
  file_dir = "./models"
15
  file_name = "pytorch_model.bin"
16
+ config_file_url = hf_hub_url(repo_id="minimaxir/ai-generated-pokemon-rudalle", filename=file_name)
 
17
  cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)
18
 
19
+ model = get_rudalle_model('Malevich', pretrained=False, fp16=fp16, device=device)
20
+ model.load_state_dict(torch.load(f"{file_dir}/{file_name}", map_location=f"{'cuda:0' if torch.cuda.is_available() else 'cpu'}"))
 
 
21
 
22
  vae = get_vae().to(device)
23
  tokenizer = get_tokenizer()
47
  if prompt.lower() in ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness', 'metal', 'dragon', 'fairy']:
48
  prompt = english_to_russian(prompt)
49
 
50
+ result, _ = generate_images(prompt, tokenizer, model, vae, top_k=2048, images_num=1, top_p=0.995)
 
51
 
52
  buffer = BytesIO()
53
  result[0].save(buffer, format="PNG")
static/index.js CHANGED
@@ -129,8 +129,9 @@ const createTask = async (prompt) => {
129
  return task;
130
  };
131
 
132
- const queueTask = (task_id) => {
133
- fetch(`${getBasePath()}task/queue?task_id=${task_id}`);
 
134
  };
135
 
136
  const pollTask = async (task) => {
@@ -140,18 +141,16 @@ const pollTask = async (task) => {
140
  };
141
 
142
  const longPollTask = async (task, interval = 10_000, max) => {
143
- if (task.status === 'complete' || (max && task.poll_count > max)) {
144
- return task;
145
- }
146
 
147
- const taskResponse = await fetch(`${getBasePath()}task/poll?task_id=${task.task_id}`);
148
 
149
- task = await taskResponse.json();
150
-
151
- if (task.status === 'complete' || task.poll_count > max) {
152
  return task;
153
  }
154
 
 
 
155
  await new Promise((resolve) => setTimeout(resolve, interval));
156
 
157
  return await longPollTask(task, interval, max);
@@ -162,26 +161,34 @@ const longPollTask = async (task, interval = 10_000, max) => {
162
  const generateButton = document.querySelector('button.generate');
163
 
164
  const durationTimer = () => {
 
165
  let duration = 0.0;
166
 
167
- return (secondsElement) => {
168
  const startTime = performance.now();
169
 
170
  const incrementSeconds = setInterval(() => {
171
  duration += 0.1;
172
- secondsElement.textContent = duration.toFixed(1);
173
  }, 100);
174
 
175
- const updateDuration = () => (duration = Number(((performance.now() - startTime) / 1_000).toFixed(1)));
 
 
 
 
 
 
 
176
 
177
  window.addEventListener('focus', updateDuration);
178
 
179
  return {
180
- cleanup: () => {
181
- updateDuration();
182
  clearInterval(incrementSeconds);
183
  window.removeEventListener('focus', updateDuration);
184
- secondsElement.textContent = duration.toFixed(1);
185
  },
186
  };
187
  };
@@ -238,7 +245,7 @@ generateButton.addEventListener('click', async () => {
238
  }
239
 
240
  const renderSection = document.querySelector('section.render');
241
- const durationSeconds = document.querySelector('.duration > .seconds');
242
  const initialiseCardRotation = cardRotationInitiator(renderSection);
243
 
244
  try {
@@ -246,14 +253,15 @@ generateButton.addEventListener('click', async () => {
246
 
247
  const details = await generateDetails();
248
  const task = await createTask(details.energy_type);
249
- queueTask(task.task_id);
250
 
251
- const timer = durationTimer();
252
- const cleanupTimer = timer(durationSeconds).cleanup;
 
 
 
253
 
254
- const completedTask = await longPollTask(task);
255
  generating = false;
256
- cleanupTimer();
257
 
258
  renderSection.innerHTML = cardHTML(details);
259
  const picture = document.querySelector('img.picture');
129
  return task;
130
  };
131
 
132
+ const queueTask = async (task_id) => {
133
+ const queueResponse = await fetch(`${getBasePath()}task/queue?task_id=${task_id}`);
134
+ return queueResponse.json();
135
  };
136
 
137
  const pollTask = async (task) => {
141
  };
142
 
143
  const longPollTask = async (task, interval = 10_000, max) => {
144
+ const etaDisplay = document.querySelector('.eta');
 
 
145
 
146
+ task = await pollTask(task);
147
 
148
+ if (task.status === 'completed' || (max && task.poll_count > max)) {
 
 
149
  return task;
150
  }
151
 
152
+ etaDisplay.textContent = Math.round(task.eta);
153
+
154
  await new Promise((resolve) => setTimeout(resolve, interval));
155
 
156
  return await longPollTask(task, interval, max);
161
  const generateButton = document.querySelector('button.generate');
162
 
163
  const durationTimer = () => {
164
+ const elapsedDisplay = document.querySelector('.elapsed');
165
  let duration = 0.0;
166
 
167
+ return () => {
168
  const startTime = performance.now();
169
 
170
  const incrementSeconds = setInterval(() => {
171
  duration += 0.1;
172
+ elapsedDisplay.textContent = duration.toFixed(1);
173
  }, 100);
174
 
175
+ const updateDuration = (task) => {
176
+ if (task?.status == 'completed') {
177
+ duration = task.completed_at - task.created_at;
178
+ return;
179
+ }
180
+
181
+ duration = Number(((performance.now() - startTime) / 1_000).toFixed(1));
182
+ };
183
 
184
  window.addEventListener('focus', updateDuration);
185
 
186
  return {
187
+ cleanup: (completedTask) => {
188
+ updateDuration(completedTask);
189
  clearInterval(incrementSeconds);
190
  window.removeEventListener('focus', updateDuration);
191
+ elapsedDisplay.textContent = duration.toFixed(1);
192
  },
193
  };
194
  };
245
  }
246
 
247
  const renderSection = document.querySelector('section.render');
248
+ const durationDisplay = document.querySelector('.duration');
249
  const initialiseCardRotation = cardRotationInitiator(renderSection);
250
 
251
  try {
253
 
254
  const details = await generateDetails();
255
  const task = await createTask(details.energy_type);
 
256
 
257
+ const timer = durationTimer(durationDisplay);
258
+ const timerCleanup = timer().cleanup;
259
+
260
+ const longPromises = [queueTask(task.task_id), longPollTask(task)];
261
+ const completedTask = await Promise.any(longPromises);
262
 
 
263
  generating = false;
264
+ timerCleanup(completedTask);
265
 
266
  renderSection.innerHTML = cardHTML(details);
267
  const picture = document.querySelector('img.picture');
templates/index.html CHANGED
@@ -15,7 +15,7 @@
15
  </head>
16
  <body>
17
  <h1>This Pokémon Does Not Exist</h1>
18
- <div class="duration"><span class="seconds">0.0</span>s (ETA: 490s)</div>
19
  <button class="generate">Generate Pokémon Card with AI</button>
20
  <section class="render"></section>
21
  </body>
15
  </head>
16
  <body>
17
  <h1>This Pokémon Does Not Exist</h1>
18
+ <div class="duration"><span class="elapsed">0.0</span>s (ETA: <span class="eta">40</span>s)</div>
19
  <button class="generate">Generate Pokémon Card with AI</button>
20
  <section class="render"></section>
21
  </body>