PokeGen / app.py
Ron Au
feat(errors): Add better error handling
9421891
raw
history blame
2.94 kB
from time import time
from statistics import mean
from flask import Flask, jsonify, render_template, request
from modules.details import rand_details
from modules.inference import generate_image
# Flask application object; the route handlers below are registered on it
# with the @app.route decorator.
app = Flask(__name__)
@app.route('/')
def index():
    """Render the landing page, passing a random set of details as template variables."""
    context = rand_details()
    return render_template('index.html', **context)
# In-memory task registry keyed by task_id. Each value is the task dict built
# in create_task and later updated by queue_task and poll_task. Nothing in
# this file removes entries, so it grows for the lifetime of the process.
tasks = {}
def place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among pending tasks.

    Pending tasks are considered in the order they were added to ``tasks``.

    Args:
        task_id: id of the task to locate.

    Returns:
        The 1-based place in the pending queue, or 0 when the task is unknown
        or is no longer pending.
    """
    # Collect ids (not the task dicts) so that ``index`` can actually match
    # the ``task_id`` string.
    pending_ids = [task["task_id"] for task in tasks.values()
                   if task["status"] == "pending"]
    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        # Not pending (or unknown): report "not in queue".
        return 0
def calculate_eta(task_id):
    """Estimate how many seconds ``task_id`` will take, from its initial queue place.

    The estimate is the average turnaround (``completed_at - created_at``) of
    successfully completed tasks, multiplied by the task's
    ``initial_place_in_queue`` (treated as at least 1). Before any task has
    completed, a default of 40 seconds per place is used.

    Args:
        task_id: id of a task present in ``tasks``.

    Returns:
        The estimated number of seconds (int or float).

    Raises:
        KeyError: if ``task_id`` is not in ``tasks``.
    """
    # Only successful runs are averaged. Failed tasks also receive
    # ``completed_at`` but may end early, which would skew the estimate.
    # This matches how poll_task computes its ETA.
    durations = [task["completed_at"] - task["created_at"]
                 for task in tasks.values()
                 if task["status"] == "completed" and "completed_at" in task]
    place = tasks[task_id]["initial_place_in_queue"] or 1
    if durations:
        return sum(durations) / len(durations) * place
    return 40 * place
@app.route('/task/create')
def create_task():
    """Register a new pending task and return it as JSON.

    Query parameters:
        prompt: text prompt for the task; "покемон" is used when the
            parameter is missing or empty.

    The task is stored in ``tasks`` under an id built from its creation
    timestamp and prompt. Generation is not started here; ``/task/queue``
    runs it for a given id.
    """
    prompt = request.args.get('prompt') or "покемон"
    created_at = time()
    task_id = f"{created_at}_{prompt}"
    tasks[task_id] = {
        "task_id": task_id,
        "created_at": created_at,
        "prompt": prompt,
        # Placeholder keeps key order; the real value is filled in below.
        "initial_place_in_queue": 0,
        "status": "pending",
        "poll_count": 0,
    }
    # The place must be computed only after the task is stored.
    # place_in_queue searches the pending entries of ``tasks``, so asking
    # before insertion could never find this task.
    tasks[task_id]["initial_place_in_queue"] = place_in_queue(task_id)
    print("Place in queue: ", place_in_queue(task_id))
    print("ETA: ", calculate_eta(task_id))
    return jsonify(tasks[task_id])
@app.route('/task/queue')
def queue_task():
    """Run generation for an existing task and return the updated task as JSON.

    Query parameters:
        task_id: id of a task previously created via ``/task/create``.

    On success, the result of ``generate_image`` is stored under ``"value"``
    and the status becomes ``"completed"``. If ``generate_image`` raises, the
    status becomes ``"failed"`` and ``repr`` of the exception is stored under
    ``"error"``. ``"completed_at"`` is set in both cases.

    A missing or unknown ``task_id`` yields a 404 JSON error instead of an
    unhandled ``KeyError``.
    """
    task_id = request.args.get('task_id')
    task = tasks.get(task_id)
    if task is None:
        return jsonify({"error": f"Unknown task_id: {task_id!r}"}), 404
    try:
        task["value"] = generate_image(task["prompt"])
    except Exception as ex:  # any generation failure is reported to the client
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        task["completed_at"] = time()
    return jsonify(task)
@app.route('/task/poll')
def poll_task():
    """Refresh a task's queue position and ETA, and return it as JSON.

    Query parameters:
        task_id: id of a task previously created via ``/task/create``.

    The handler updates three fields on the task:

    * ``"place_in_queue"``: 1-based place among pending tasks, or 0 when the
      task is not pending.
    * ``"eta"``: estimated seconds, rounded to one decimal place. It is the
      average turnaround of completed tasks times the current place (at
      least 1), or 40 seconds per place before any task has completed.
    * ``"poll_count"``: incremented by one.

    A missing or unknown ``task_id`` yields a 404 JSON error instead of an
    unhandled ``KeyError``.
    """
    task_id = request.args.get('task_id')
    if task_id not in tasks:
        return jsonify({"error": f"Unknown task_id: {task_id!r}"}), 404

    pending_ids = []
    completed_durations = []
    for task in tasks.values():
        if task["status"] == "pending":
            pending_ids.append(task["task_id"])
        elif task["status"] == "completed":
            completed_durations.append(
                task["completed_at"] - task["created_at"])

    # Named ``place`` so it does not shadow the module-level place_in_queue().
    try:
        place = pending_ids.index(task_id) + 1
    except ValueError:
        place = 0

    if completed_durations:
        eta = sum(completed_durations) / \
            len(completed_durations) * (place or 1)
    else:
        eta = 40 * (place or 1)

    task = tasks[task_id]
    task["place_in_queue"] = place
    task["eta"] = round(eta, 1)
    task["poll_count"] += 1
    return jsonify(task)
@app.route('/details')
def generate_details():
    """Return a freshly generated random set of details as JSON."""
    details = rand_details()
    return jsonify(details)
if __name__ == '__main__':
    # Development entry point: listen on every network interface, port 7860.
    app.run(port=7860, host='0.0.0.0')