Geek7 committed on
Commit b5c8a42
1 Parent(s): 44aff97

Update myapp.py

Files changed (1)
  1. myapp.py +33 -65
myapp.py CHANGED
@@ -1,77 +1,45 @@
 from flask import Flask, request, jsonify, send_file
-import gradio as gr
-from random import randint
-from all_models import models
-from externalmod import gr_Interface_load
-import asyncio
-import os
-from threading import RLock
-from PIL import Image
-
-myapp = Flask(__name__)
-
-lock = RLock()
-HF_TOKEN = os.environ.get("HF_TOKEN")
-
-# Load models
-def load_fn(models):
-    global models_load
-    models_load = {}
-
-    for model in models:
-        if model not in models_load.keys():
-            try:
-                m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
-            except Exception as error:
-                print(error)
-                m = gr.Interface(lambda: None, ['text'], ['image'])
-            models_load.update({model: m})
-
-load_fn(models)
-
-num_models = 6
-MAX_SEED = 3999999999
-default_models = models[:num_models]
-inference_timeout = 600
-
-# Gradio inference function
-async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
-    kwargs = {"seed": seed}
-    task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
-    await asyncio.sleep(0)
-    try:
-        result = await asyncio.wait_for(task, timeout=timeout)
-    except (Exception, asyncio.TimeoutError) as e:
-        print(e)
-        print(f"Task timed out: {model_str}")
-        if not task.done():
-            task.cancel()
-        result = None
-    if task.done() and result is not None:
-        with lock:
-            png_path = "generated_image.png"
-            result.save(png_path)  # Save the result as an image
-            return png_path
-    return None
-
-# API function to perform inference
-@myapp.route('/generate-image', methods=['POST'])
-def generate_image():
-    data = request.get_json()
+from flask_cors import CORS
+from gradio_client import Client
+from all_models import models  # Import the models list
+
+app = Flask(__name__)
+CORS(app)
+
+# Initialize the Gradio Client for the hosted Space
+client = Client("Geek7/mdztxi2")
+
+@app.route('/predict', methods=['POST'])
+def predict():
+    data = request.get_json()
+
+    # Validate required fields
+    if not data or 'model_str' not in data or 'prompt' not in data or 'seed' not in data:
+        return jsonify({"error": "Missing required fields"}), 400
+
     model_str = data['model_str']
     prompt = data['prompt']
-    seed = data.get('seed', 1)
-
-    # Run Gradio inference
-    result_path = asyncio.run(infer(model_str, prompt, seed))
-
-    if result_path:
-        # Send back the generated image file
+    seed = data['seed']
+
+    # Check if the model_str exists in the models list
+    if model_str not in models:
+        return jsonify({"error": f"Model '{model_str}' is not available."}), 400
+
+    try:
+        # Send a request to the Gradio Client and get the result
+        result = client.predict(
+            model_str=model_str,
+            prompt=prompt,
+            seed=seed,
+            api_name="/predict"
+        )
+
+        # client.predict returns a local file path for the generated image
+        result_path = result
         return send_file(result_path, mimetype='image/png')
-    else:
-        return jsonify({"error": "Failed to generate image."}), 500
-
-
-# Add this block to make sure your app runs when called
-if __name__ == "__main__":
-    myapp.run(host='0.0.0.0', port=7860)  # Run directly
+
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+if __name__ == '__main__':
+    app.run(debug=True)
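
A minimal sketch of calling the new /predict route, for reference. It assumes the server is started locally with `python myapp.py` (Flask's debug server listens on 127.0.0.1:5000 by default) and that the example model id is a placeholder for an entry in `all_models.models`; neither the host/port nor the model id comes from this commit. The `result_path = result` line in the new code relies on gradio_client downloading file-typed outputs to a local temporary path, which is why the value can be handed straight to `send_file`.

# call_predict.py - hypothetical caller for the /predict endpoint (not part of this commit)
import requests

payload = {
    "model_str": "example-org/example-model",  # placeholder: must match an entry in all_models.models
    "prompt": "a watercolor fox in a forest",
    "seed": 42,
}

# Assumption: app.run(debug=True) serves on Flask's default 127.0.0.1:5000
resp = requests.post("http://127.0.0.1:5000/predict", json=payload, timeout=600)

if resp.ok:
    # Success path: the server streams back the PNG bytes via send_file
    with open("output.png", "wb") as f:
        f.write(resp.content)
else:
    # Error paths return JSON: 400 for validation failures, 500 for upstream errors
    print(resp.status_code, resp.json())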
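Both the old and new versions import `models` from `all_models`, which this commit does not touch. From how it is used (the membership check above, and the old code's `models[:num_models]` slice and `f'models/{model}'` lookups), it is presumably a plain list of Hugging Face model ids; the sketch below is a hypothetical illustration with placeholder ids.

# all_models.py - hypothetical shape inferred from usage; the ids below are placeholders
models = [
    "example-org/text-to-image-model-a",
    "example-org/text-to-image-model-b",
]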