|
import html
import importlib.util
import os
import subprocess
import sys
|
|
|
|
|
# Bootstrap: make sure Flask is importable before the `from flask import ...`
# statement below is reached.
return_code = 0
if importlib.util.find_spec('flask') is None:
    # Run pip through the interpreter that is executing this file (a bare
    # `pip` on PATH may belong to a different Python) and pass an argument
    # list so no shell is involved.
    return_code = subprocess.call(
        [sys.executable, '-m', 'pip', 'install', 'flask']
    )
    # Let the import system see packages installed after startup.
    importlib.invalidate_caches()

if return_code != 0:
    raise RuntimeError("Failed to install Flask")
|
|
|
import gradio as gr |
|
from random import randint |
|
from all_models import models |
|
from flask import Flask, request, send_file |
|
from io import BytesIO |
|
from PIL import Image |
|
|
|
# WSGI application object; the route handler below registers itself on it.
app = Flask(import_name=__name__)
|
|
|
def load_fn(models):
    """Rebuild the global ``models_load`` registry from ``models``.

    Each distinct name in ``models`` is loaded with
    ``gr.load(f'models/{name}')`` and stored under that name. A model that
    fails to load is reported on stdout and replaced by a placeholder
    ``gr.Interface`` whose function returns ``None``, so one bad model does
    not stop the remaining ones from loading.

    Args:
        models: Iterable of model names. Repeated names are loaded once.

    Side effects:
        Rebinds the module-level ``models_load`` to a fresh dict.
    """
    global models_load
    models_load = {}

    for model in models:
        if model in models_load:
            # Same name listed twice: keep the entry that is already loaded.
            continue
        try:
            loaded = gr.load(f'models/{model}')
        except Exception as error:
            # Best-effort by design, but say which model failed and why
            # instead of swapping in the placeholder silently.
            print(f"Failed to load model {model}: {error}")
            loaded = gr.Interface(lambda txt: None, ['text'], ['image'])
        models_load[model] = loaded
|
|
|
# Fill the global model registry once, while the module is being imported.
load_fn(models)

# `default_models` is the leading `num_models` entries of the configured list.
num_models = 6
default_models = models[0:num_models]
|
|
|
def extend_choices(choices):
    """Return ``choices`` followed by enough ``'NA'`` entries to reach ``num_models``.

    When ``choices`` already holds ``num_models`` or more entries, the
    repeat count is zero or negative, the padding is empty, and the result
    has the same contents as ``choices``.
    """
    missing = num_models - len(choices)
    padding = ['NA'] * missing
    return choices + padding
|
|
|
def update_imgbox(choices):
    """Build one empty ``gr.Image`` per slot of the padded ``choices`` list.

    Each image is labelled with its slot's entry; slots holding the
    ``'NA'`` placeholder are created with ``visible=False``.
    """
    image_boxes = []
    for label in extend_choices(choices):
        is_real_model = label != 'NA'
        image_boxes.append(gr.Image(None, label=label, visible=is_real_model))
    return image_boxes
|
|
|
def gen_fn(model_str, prompt):
    """Run ``prompt`` through the loaded model ``model_str`` and return an image.

    Args:
        model_str: Key into ``models_load``, or ``'NA'`` for an unused slot.
        prompt: Text handed to the model (a single space and an empty
            ``noise`` suffix are appended).

    Returns:
        A ``PIL.Image.Image`` when the model returns one directly or
        returns a path to an existing file; otherwise ``None``. ``None`` is
        also returned for ``'NA'`` and whenever the call raises, in which
        case the error is printed.
    """
    if model_str == 'NA':
        return None

    noise = ''
    try:
        output = models_load[model_str](f'{prompt} {noise}')

        if isinstance(output, str):
            # A string result is treated as a path to the generated file.
            if not os.path.exists(output):
                print(f"File path not found: {output}")
                return None
            return Image.open(output)

        if isinstance(output, Image.Image):
            return output

        print("Result is not an image:", type(output))
        return None
    except Exception as e:
        print("Error generating image:", e)
        return None
|
|
|
@app.route('/', methods=['GET'])
def home():
    """Generate an image for the ``prompt`` query parameter and return it as PNG.

    Query parameters:
        prompt: Text to generate from. Required.
        model: Key into ``models_load``. Defaults to the first entry of
            ``default_models``, or ``'NA'`` when that list is empty.

    Returns:
        The PNG image on success; a ``400`` response when the model is
        unknown or the prompt is missing; a ``500`` response when
        ``gen_fn`` does not produce a ``PIL.Image.Image``.
    """
    prompt = request.args.get('prompt', '')
    model = request.args.get('model', default_models[0] if default_models else 'NA')

    if not model or model not in models_load:
        # `model` comes straight from the query string and a plain string
        # response is served as HTML, so escape it before echoing it back.
        return f'Invalid model: {html.escape(model)}', 400

    if not prompt:
        return 'Please provide a "prompt" query parameter in the URL.', 400

    image = gen_fn(model, prompt)
    if not isinstance(image, Image.Image):
        return 'Failed to generate image.', 500

    # Encode into an in-memory buffer and rewind it so send_file reads
    # from the start.
    img_io = BytesIO()
    image.save(img_io, format='PNG')
    img_io.seek(0)
    return send_file(img_io, mimetype='image/png', as_attachment=False)
|
|
|
if __name__ == '__main__':
    # Started as a script: serve the Flask app on all interfaces, port 7860.
    listen_options = {'host': '0.0.0.0', 'port': 7860}
    app.run(**listen_options)