Ron Au commited on
Commit
c6fcf99
1 Parent(s): 0c0e375

refactor(FastAPI): Flask -> FastAPI

Browse files
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +35 -26
  3. requirements.txt +3 -2
  4. start.py +4 -0
  5. {templates → static}/index.html +2 -2
README.md CHANGED
@@ -4,6 +4,6 @@ emoji: 🧬
4
  colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
- app_file: app.py
8
  pinned: false
9
  ---
4
  colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
+ app_file: start.py
8
  pinned: false
9
  ---
app.py CHANGED
@@ -1,16 +1,17 @@
 
 
 
 
 
1
  from time import time
2
  from statistics import mean
3
- from flask import Flask, jsonify, render_template, request
4
 
5
  from modules.details import rand_details
6
  from modules.inference import generate_image
7
 
8
- app = Flask(__name__)
9
 
10
-
11
- @app.route('/')
12
- def index():
13
- return render_template('index.html', **rand_details())
14
 
15
 
16
  tasks = {}
@@ -34,15 +35,18 @@ def calculate_eta(task_id):
34
  place = tasks[task_id]["initial_place_in_queue"] or 1
35
 
36
  if len(total_durations):
37
- return sum(total_durations) / len(total_durations) * place
38
  else:
39
  return 40 * place
40
 
41
 
42
- @app.route('/task/create')
43
- def create_task():
44
- prompt = request.args.get('prompt') or "покемон"
 
45
 
 
 
46
  created_at = time()
47
 
48
  task_id = f"{str(created_at)}_{prompt}"
@@ -59,13 +63,11 @@ def create_task():
59
  print("Place in queue: ", place_in_queue(task_id))
60
  print("ETA: ", calculate_eta(task_id))
61
 
62
- return jsonify(tasks[task_id])
63
-
64
 
65
- @app.route('/task/queue')
66
- def queue_task():
67
- task_id = request.args.get('task_id')
68
 
 
 
69
  try:
70
  tasks[task_id]["value"] = generate_image(tasks[task_id]["prompt"])
71
  except Exception as ex:
@@ -76,13 +78,11 @@ def queue_task():
76
  finally:
77
  tasks[task_id]["completed_at"] = time()
78
 
79
- return jsonify(tasks[task_id])
80
-
81
 
82
- @app.route('/task/poll')
83
- def poll_task():
84
- task_id = request.args.get('task_id')
85
 
 
 
86
  pending_tasks = []
87
  completed_durations = []
88
 
@@ -108,13 +108,22 @@ def poll_task():
108
  tasks[task_id]["eta"] = round(eta, 1)
109
  tasks[task_id]["poll_count"] += 1
110
 
111
- return jsonify(tasks[task_id])
 
 
 
 
 
 
112
 
113
 
114
- @app.route('/details')
115
- def generate_details():
116
- return jsonify(rand_details())
 
117
 
118
 
119
- if __name__ == '__main__':
120
- app.run(host='0.0.0.0', port=7860)
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.staticfiles import StaticFiles
3
+ from fastapi.responses import FileResponse
4
+
5
+
6
  from time import time
7
  from statistics import mean
 
8
 
9
  from modules.details import rand_details
10
  from modules.inference import generate_image
11
 
12
+ app = FastAPI()
13
 
14
+ app.mount("/static", StaticFiles(directory="static"), name="static")
 
 
 
15
 
16
 
17
  tasks = {}
35
  place = tasks[task_id]["initial_place_in_queue"] or 1
36
 
37
  if len(total_durations):
38
+ return mean(total_durations) * place
39
  else:
40
  return 40 * place
41
 
42
 
43
+ @app.get('/')
44
+ def index():
45
+ return FileResponse(path="static/index.html", media_type="text/html")
46
+
47
 
48
+ @app.get('/task/create')
49
+ def create_task(prompt: str = "покемон"):
50
  created_at = time()
51
 
52
  task_id = f"{str(created_at)}_{prompt}"
63
  print("Place in queue: ", place_in_queue(task_id))
64
  print("ETA: ", calculate_eta(task_id))
65
 
66
+ return tasks[task_id]
 
67
 
 
 
 
68
 
69
+ @app.get('/task/queue')
70
+ def queue_task(task_id: str):
71
  try:
72
  tasks[task_id]["value"] = generate_image(tasks[task_id]["prompt"])
73
  except Exception as ex:
78
  finally:
79
  tasks[task_id]["completed_at"] = time()
80
 
81
+ return tasks[task_id]
 
82
 
 
 
 
83
 
84
+ @app.get('/task/poll')
85
+ def poll_task(task_id: str):
86
  pending_tasks = []
87
  completed_durations = []
88
 
108
  tasks[task_id]["eta"] = round(eta, 1)
109
  tasks[task_id]["poll_count"] += 1
110
 
111
+ return tasks[task_id]
112
+
113
+
114
+ # @app.route('/details')
115
+ @app.get('/details')
116
+ async def generate_details():
117
+ return rand_details()
118
 
119
 
120
+ @app.get('/duck/quack')
121
+ async def test(query: str = "quack"):
122
+ print(query)
123
+ return {"duck": query}
124
 
125
 
126
+ @app.get('/test')
127
+ async def test(query: str = "test"):
128
+ print(query)
129
+ return {"query": query}
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
- Flask
2
- rudalle==1.0.0
 
1
+ rudalle==1.0.*
2
+ fastapi==0.74.*
3
+ uvicorn[standard]==0.17.*
start.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+
2
+ import subprocess
3
+
4
+ subprocess.run("uvicorn app:app", shell=True)
{templates → static}/index.html RENAMED
@@ -9,13 +9,13 @@
9
  opacity: 0;
10
  }
11
  </style>
12
- <link rel="shortcut icon" href="{{ url_for('static', filename='favicon.ico') }}" />
13
  <link rel="stylesheet" id="stylesheet-tag" />
14
  <script type="module" id="script-tag"></script>
15
  <script>
16
  const basePath = document.location.origin + document.location.pathname;
17
- document.getElementById('script-tag').src = `${basePath}static/js/index.js`;
18
  document.getElementById('stylesheet-tag').href = `${basePath}static/style.css`;
 
19
  </script>
20
  </head>
21
  <body>
9
  opacity: 0;
10
  }
11
  </style>
12
+ <link rel="shortcut icon" href="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" />
13
  <link rel="stylesheet" id="stylesheet-tag" />
14
  <script type="module" id="script-tag"></script>
15
  <script>
16
  const basePath = document.location.origin + document.location.pathname;
 
17
  document.getElementById('stylesheet-tag').href = `${basePath}static/style.css`;
18
+ document.getElementById('script-tag').src = `${basePath}static/js/index.js`;
19
  </script>
20
  </head>
21
  <body>