Geek7 commited on
Commit
3b41fd8
·
verified ·
1 Parent(s): 0d0d249

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -37
app.py CHANGED
@@ -1,58 +1,44 @@
1
- import gradio as gr
2
- from random import randint
3
- from all_models import models # Import the list of available models
4
- from externalmod import gr_Interface_load
5
- import asyncio
6
- import os
7
- from threading import RLock
8
  from flask import Flask, request, jsonify, send_file
9
  from flask_cors import CORS
 
10
  import tempfile
11
-
12
- lock = RLock()
13
- HF_TOKEN = os.environ.get("HF_TOKEN")
 
14
 
15
  app = Flask(__name__)
16
  CORS(app) # Enable CORS for all routes
17
 
18
- # Function to load models
19
- def load_fn(models):
20
- global models_load
21
- models_load = {}
22
-
23
- for model in models:
24
- if model not in models_load.keys():
25
- try:
26
- m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
27
- except Exception as error:
28
- print(error)
29
- m = gr.Interface(lambda: None, ['text'], ['image'])
30
- models_load.update({model: m})
31
 
32
- load_fn(models)
33
 
34
- num_models = 6 # Number of models to load initially
35
- MAX_SEED = 3999999999
36
- default_models = models[:num_models] # Load the first few models for inference
37
- inference_timeout = 600
38
 
39
  # Asynchronous function to perform inference
40
- async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
41
- kwargs = {"seed": seed}
42
- task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
 
43
  await asyncio.sleep(0)
44
  try:
45
  result = await asyncio.wait_for(task, timeout=timeout)
46
  except (Exception, asyncio.TimeoutError) as e:
47
  print(e)
48
- print(f"Task timed out: {model_str}")
49
- if not task.done():
50
  task.cancel()
51
  result = None
 
52
  if task.done() and result is not None:
53
  with lock:
54
  temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
55
- result.save(temp_image.name) # Save result as a temporary file
 
56
  return temp_image.name # Return the path to the saved image
57
  return None
58
 
@@ -62,16 +48,24 @@ def generate_api():
62
  data = request.get_json()
63
 
64
  # Extract required fields from the request
65
- model_str = data.get('model_str', default_models[0]) # Default to first model if not provided
66
  prompt = data.get('prompt', '')
67
  seed = data.get('seed', 1)
 
68
 
69
  if not prompt:
70
  return jsonify({"error": "Prompt is required"}), 400
71
-
 
 
 
 
 
72
  try:
 
 
 
73
  # Call the async inference function
74
- result_path = asyncio.run(infer(model_str, prompt, seed))
75
  if result_path:
76
  return send_file(result_path, mimetype='image/png') # Send back the generated image file
77
  else:
 
 
 
 
 
 
 
 
1
  from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
3
+ import asyncio
4
  import tempfile
5
+ import os
6
+ from threading import RLock
7
+ from huggingface_hub import InferenceClient
8
+ from all_models import models # Importing models from all_models
9
 
10
  app = Flask(__name__)
11
  CORS(app) # Enable CORS for all routes
12
 
13
+ lock = RLock()
14
+ HF_TOKEN = os.environ.get("HF_TOKEN") # Hugging Face token
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ inference_timeout = 600 # Set timeout for inference
17
 
18
+ # Function to dynamically load models from the "models" list
19
+ def get_model_from_name(model_name):
20
+ return model_name if model_name in models else None
 
21
 
22
  # Asynchronous function to perform inference
23
+ async def infer(client, prompt, seed=1, timeout=inference_timeout, model="prompthero/openjourney-v4"):
24
+ task = asyncio.create_task(
25
+ asyncio.to_thread(client.text_to_image, prompt=prompt, seed=seed, model=model)
26
+ )
27
  await asyncio.sleep(0)
28
  try:
29
  result = await asyncio.wait_for(task, timeout=timeout)
30
  except (Exception, asyncio.TimeoutError) as e:
31
  print(e)
32
+ print(f"Task timed out for model: {model}")
33
+ if not task.done():
34
  task.cancel()
35
  result = None
36
+
37
  if task.done() and result is not None:
38
  with lock:
39
  temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
40
+ with open(temp_image.name, "wb") as f:
41
+ f.write(result) # Save the result image as a temporary file
42
  return temp_image.name # Return the path to the saved image
43
  return None
44
 
 
48
  data = request.get_json()
49
 
50
  # Extract required fields from the request
 
51
  prompt = data.get('prompt', '')
52
  seed = data.get('seed', 1)
53
+ model_name = data.get('model', 'prompthero/openjourney-v4') # Default to "prompthero/openjourney-v4" if not provided
54
 
55
  if not prompt:
56
  return jsonify({"error": "Prompt is required"}), 400
57
+
58
+ # Get the model from all_models
59
+ model = get_model_from_name(model_name)
60
+ if not model:
61
+ return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400
62
+
63
  try:
64
+ # Create a generic InferenceClient for the model
65
+ client = InferenceClient(token=HF_TOKEN) # Pass Hugging Face token if needed
66
+
67
  # Call the async inference function
68
+ result_path = asyncio.run(infer(client, prompt, seed, model=model))
69
  if result_path:
70
  return send_file(result_path, mimetype='image/png') # Send back the generated image file
71
  else: