soiz commited on
Commit
58582ba
1 Parent(s): de728d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -22
app.py CHANGED
@@ -1,42 +1,49 @@
1
  import os
2
- os.system("pip install flask")
 
 
 
 
 
 
3
  from random import randint
 
4
  from flask import Flask, request, send_file
5
  from io import BytesIO
6
  from PIL import Image
7
- import importlib
8
 
9
  app = Flask(__name__)
10
 
11
- def load_model(model_name):
12
- try:
13
- # Dynamically import the model module
14
- module = importlib.import_module(f'all_models.{model_name}')
15
- return module.get_model() # Assume get_model() returns a callable model
16
- except Exception as error:
17
- print(f"Error loading model {model_name}: {error}")
18
- return lambda txt: None # Return a placeholder function
 
 
 
 
 
19
 
20
  num_models = 6
21
- default_models = ['model1', 'model2', 'model3', 'model4', 'model5', 'model6'] # Replace with actual model names
22
 
23
  def extend_choices(choices):
24
  return choices + (num_models - len(choices)) * ['NA']
25
 
26
  def update_imgbox(choices):
27
  choices_plus = extend_choices(choices)
28
- return [Image.open(None) if m != 'NA' else None for m in choices_plus]
29
 
30
- def gen_fn(model_str, prompt, negative_prompt=None):
31
  if model_str == 'NA':
32
  return None
33
  noise = str('') # Optional: str(randint(0, 99999999999))
34
  try:
35
- model = load_model(model_str)
36
- full_prompt = f'{prompt} {noise}'
37
- if negative_prompt:
38
- full_prompt += f' -{negative_prompt}'
39
- result = model(full_prompt)
40
  # Check if result is an image or a file path
41
  if isinstance(result, str): # Assuming result might be a file path
42
  if os.path.exists(result):
@@ -57,14 +64,13 @@ def gen_fn(model_str, prompt, negative_prompt=None):
57
  def home():
58
  prompt = request.args.get('prompt', '')
59
  model = request.args.get('model', default_models[0] if default_models else 'NA')
60
- negative_prompt = request.args.get('Nprompt', None)
61
 
62
- if not model:
63
  return f'Invalid model: {model}', 400
64
 
65
  if prompt:
66
  # Generate the image
67
- image = gen_fn(model, prompt, negative_prompt)
68
  if isinstance(image, Image.Image): # Ensure the result is a PIL image
69
  # Save image to BytesIO object
70
  img_io = BytesIO()
@@ -76,4 +82,4 @@ def home():
76
 
77
  if __name__ == '__main__':
78
  # Launch Flask app
79
- app.run(host='0.0.0.0', port=7860) # Run Flask app
 
1
  import os
2
+
3
+ # Install Flask if not already installed
4
+ return_code = os.system('pip install flask')
5
+ if return_code != 0:
6
+ raise RuntimeError("Failed to install Flask")
7
+
8
+ import gradio as gr
9
  from random import randint
10
+ from all_models import models
11
  from flask import Flask, request, send_file
12
  from io import BytesIO
13
  from PIL import Image
 
14
 
15
  app = Flask(__name__)
16
 
17
+ def load_fn(models):
18
+ global models_load
19
+ models_load = {}
20
+
21
+ for model in models:
22
+ if model not in models_load.keys():
23
+ try:
24
+ m = gr.load(f'models/{model}')
25
+ except Exception as error:
26
+ m = gr.Interface(lambda txt: None, ['text'], ['image'])
27
+ models_load.update({model: m})
28
+
29
+ load_fn(models)
30
 
31
  num_models = 6
32
+ default_models = models[:num_models]
33
 
34
  def extend_choices(choices):
35
  return choices + (num_models - len(choices)) * ['NA']
36
 
37
  def update_imgbox(choices):
38
  choices_plus = extend_choices(choices)
39
+ return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
40
 
41
+ def gen_fn(model_str, prompt):
42
  if model_str == 'NA':
43
  return None
44
  noise = str('') # Optional: str(randint(0, 99999999999))
45
  try:
46
+ result = models_load[model_str](f'{prompt} {noise}')
 
 
 
 
47
  # Check if result is an image or a file path
48
  if isinstance(result, str): # Assuming result might be a file path
49
  if os.path.exists(result):
 
64
  def home():
65
  prompt = request.args.get('prompt', '')
66
  model = request.args.get('model', default_models[0] if default_models else 'NA')
 
67
 
68
+ if not model or model not in models_load:
69
  return f'Invalid model: {model}', 400
70
 
71
  if prompt:
72
  # Generate the image
73
+ image = gen_fn(model, prompt)
74
  if isinstance(image, Image.Image): # Ensure the result is a PIL image
75
  # Save image to BytesIO object
76
  img_io = BytesIO()
 
82
 
83
  if __name__ == '__main__':
84
  # Launch Flask app
85
+ app.run(host='0.0.0.0', port=7860) # Run Flask app