# b / app.py
# Author: soiz -- commit 58582ba ("Update app.py", verified), 2.7 kB.
import os
# Install Flask only if it is not already importable.  pip is run through the
# current interpreter (sys.executable) with an argument list rather than a
# shell string, so the package lands in this interpreter's environment.
try:
    import flask  # noqa: F401 -- availability probe; the real imports follow
except ImportError:
    import importlib
    import subprocess
    import sys

    return_code = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', 'flask'], check=False
    ).returncode
    if return_code != 0:
        raise RuntimeError("Failed to install Flask")
    # The failed probe above may have cached a stale directory listing of
    # site-packages; drop finder caches so the new package can be imported.
    importlib.invalidate_caches()
import gradio as gr
from random import randint
from all_models import models
from flask import Flask, request, send_file
from io import BytesIO
from PIL import Image
# Flask application object; the HTTP routes in this module are registered on it.
app = Flask(import_name=__name__)
def load_fn(models):
    """Load every model in *models* into the module-level ``models_load`` dict.

    ``models_load`` is rebuilt from scratch on each call and maps each model
    name to the object returned by ``gr.load(f'models/{name}')``.  A model
    that fails to load is replaced by a placeholder ``gr.Interface`` whose
    function always returns ``None``, so callers can still index it.

    Args:
        models: Iterable of model identifiers; duplicates are loaded once.
    """
    global models_load
    models_load = {}
    for model in models:
        if model in models_load:
            continue  # duplicate entry: already loaded (or stubbed) above
        try:
            m = gr.load(f'models/{model}')
        except Exception as error:
            # Best-effort: keep going with a stub, but report why this model
            # is unavailable instead of swallowing the error silently.
            print(f"Failed to load model {model}: {error}")
            m = gr.Interface(lambda txt: None, ['text'], ['image'])
        models_load[model] = m
# Number of image slots shown at once (see extend_choices / update_imgbox).
num_models = 6
# The first num_models entries of the model catalogue are the defaults.
default_models = models[:num_models]
# Populate models_load once, at import time.
load_fn(models)
def extend_choices(choices, size=None):
    """Pad *choices* with ``'NA'`` placeholders up to *size* entries.

    Args:
        choices: Selected model names (any sequence), or ``None`` for none.
        size: Target length; defaults to the module-level ``num_models``.

    Returns:
        A new list: the original choices followed by enough ``'NA'`` entries
        to reach *size*.  Inputs already at or beyond *size* are returned
        unpadded (never truncated).
    """
    target = num_models if size is None else size
    padded = list(choices) if choices is not None else []
    padded.extend(['NA'] * max(0, target - len(padded)))
    return padded
def update_imgbox(choices):
    """Build one empty image component per slot for the selected models.

    Slots that were padded with 'NA' are created hidden; every other slot is
    visible and labelled with its model name.
    """
    boxes = []
    for name in extend_choices(choices):
        is_real_model = name != 'NA'
        boxes.append(gr.Image(None, label=name, visible=is_real_model))
    return boxes
def gen_fn(model_str, prompt):
    """Generate an image for *prompt* with the loaded model *model_str*.

    Args:
        model_str: Key into ``models_load``; the placeholder ``'NA'`` means
            "no model" and short-circuits to ``None``.
        prompt: Text prompt passed to the model.

    Returns:
        A ``PIL.Image.Image`` on success, otherwise ``None`` (missing output
        file, unexpected result type, or any exception, including an unknown
        model key); each failure is reported with ``print``.
    """
    if model_str == 'NA':
        return None
    noise = ''  # Optional: str(randint(0, 99999999999))
    try:
        result = models_load[model_str](f'{prompt} {noise}')
        # The model may hand back either a path to an image file or an image.
        if isinstance(result, str):
            if not os.path.exists(result):
                print(f"File path not found: {result}")
                return None
            # Image.open is lazy and keeps the file handle open; copy() forces
            # the pixel data to load, so the handle is closed here and the
            # returned image stays usable even if the (possibly temporary)
            # file is removed afterwards.
            with Image.open(result) as img:
                return img.copy()
        if isinstance(result, Image.Image):
            return result
        print("Result is not an image:", type(result))
        return None
    except Exception as e:
        print("Error generating image:", e)
        return None
@app.route('/', methods=['GET'])
def home():
    """Serve a generated PNG for the ``prompt`` and ``model`` query parameters.

    Responses:
        200 image/png -- generation succeeded.
        400           -- empty/unknown model, or no ``prompt`` given.
        500           -- the model produced no usable image.
    """
    prompt = request.args.get('prompt', '')
    fallback_model = default_models[0] if default_models else 'NA'
    model = request.args.get('model', fallback_model)

    if not model or model not in models_load:
        return f'Invalid model: {model}', 400
    if not prompt:
        return 'Please provide a "prompt" query parameter in the URL.', 400

    image = gen_fn(model, prompt)
    if not isinstance(image, Image.Image):
        return 'Failed to generate image.', 500

    # Encode the PIL image as PNG in memory and stream it back inline.
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)
    return send_file(buffer, mimetype='image/png', as_attachment=False)
if __name__ == '__main__':
    # Serve the Flask app on all network interfaces, port 7860.
    app.run(port=7860, host='0.0.0.0')