Geek7 commited on
Commit
eb0bba6
1 Parent(s): cd81248

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -8
app.py CHANGED
@@ -1,11 +1,60 @@
1
- # app.py
2
-
 
 
 
3
  import os
4
- import subprocess
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- if __name__ == "__main__":
7
- # Run awake.py in the background
8
- subprocess.Popen(["python", "wk.py"]) # Start awake.py
 
 
 
9
 
10
- # Run the Flask app using Gunicorn
11
- os.system("gunicorn -w 4 -b 0.0.0.0:5000 myapp:myapp") # 4 worker processes
 
 
1
+ import gradio as gr
2
+ from random import randint
3
+ from all_models import models
4
+ from externalmod import gr_Interface_load
5
+ import asyncio
6
  import os
7
+ from threading import RLock
8
+
9
+ lock = RLock()
10
+ HF_TOKEN = os.environ.get("HF_TOKEN")
11
+
12
+ def load_fn(models):
13
+ global models_load
14
+ models_load = {}
15
+
16
+ for model in models:
17
+ if model not in models_load.keys():
18
+ try:
19
+ m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
20
+ except Exception as error:
21
+ print(error)
22
+ m = gr.Interface(lambda: None, ['text'], ['image'])
23
+ models_load.update({model: m})
24
+
25
+ load_fn(models)
26
+
27
+ num_models = 6
28
+ MAX_SEED = 3999999999
29
+ default_models = models[:num_models]
30
+ inference_timeout = 600
31
+
32
+ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
33
+ kwargs = {"seed": seed}
34
+ task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
35
+ await asyncio.sleep(0)
36
+ try:
37
+ result = await asyncio.wait_for(task, timeout=timeout)
38
+ except (Exception, asyncio.TimeoutError) as e:
39
+ print(e)
40
+ print(f"Task timed out: {model_str}")
41
+ if not task.done():
42
+ task.cancel()
43
+ result = None
44
+ if task.done() and result is not None:
45
+ with lock:
46
+ png_path = "image.png"
47
+ result.save(png_path)
48
+ return png_path
49
+ return None
50
 
51
+ # Expose Gradio API
52
+ def generate_api(model_str, prompt, seed=1):
53
+ result = asyncio.run(infer(model_str, prompt, seed))
54
+ if result:
55
+ return result # Path to generated image
56
+ return None
57
 
58
+ # Launch Gradio API without frontend
59
+ iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
60
+ iface.launch(show_api=True, share=True)