b / app.py
soiz's picture
Update app.py
de728d2 verified
raw
history blame
2.82 kB
import os
# Runs "pip install flask" through the shell every time this module is loaded,
# before the flask import below is attempted.
# NOTE(review): presumably a workaround for a missing dependency declaration;
# consider declaring flask as a regular dependency instead — confirm with the
# deployment setup.
os.system("pip install flask")
from random import randint
from flask import Flask, request, send_file
from io import BytesIO
from PIL import Image
import importlib
# Flask application object; the route defined below is registered on it.
app = Flask(__name__)
def load_model(model_name):
    """Load a model from the ``all_models`` package by name.

    The model is looked up as the submodule ``all_models.<model_name>`` and
    obtained by calling that module's ``get_model()``.

    Args:
        model_name: Name of a direct submodule of ``all_models``. Only
            non-empty strings made of letters, digits, ``_`` and ``-`` are
            accepted; anything else (dotted paths, empty strings,
            non-strings) is rejected without attempting an import.

    Returns:
        Whatever ``get_model()`` returns (assumed to be a callable that takes
        a prompt string). If the name is rejected or loading fails for any
        reason, a placeholder callable that takes one argument and returns
        ``None``.
    """
    # The name can originate from a request query parameter, so it must not
    # be able to steer the import into nested modules (e.g. "pkg.sub.deep").
    is_valid_name = (
        isinstance(model_name, str)
        and model_name != ''
        and all(ch.isalnum() or ch in '_-' for ch in model_name)
    )
    if not is_valid_name:
        print(f"Error loading model {model_name}: invalid model name")
        return lambda txt: None  # Return a placeholder function

    try:
        # Dynamically import the model module
        module = importlib.import_module(f'all_models.{model_name}')
        return module.get_model()  # Assume get_model() returns a callable model
    except Exception as error:
        # Deliberately best-effort: a model that fails to load degrades to a
        # placeholder instead of crashing the caller.
        print(f"Error loading model {model_name}: {error}")
        return lambda txt: None  # Return a placeholder function
# Number of model slots; extend_choices() pads choice lists up to this length.
num_models = 6
# Known model names. home() falls back to the first entry when the request
# carries no "model" query parameter.
# These are placeholder values: replace with actual model names.
default_models = ['model1', 'model2', 'model3', 'model4', 'model5', 'model6']
def extend_choices(choices):
    """Pad *choices* with ``'NA'`` entries up to ``num_models`` items.

    A new list is returned; *choices* itself is not modified. When *choices*
    already holds ``num_models`` or more items, nothing is appended.
    """
    missing = num_models - len(choices)
    padding = ['NA'] * missing  # a non-positive count yields an empty list
    return choices + padding
def update_imgbox(choices):
    """Build the initial image slot for every model slot.

    *choices* is padded to the full slot count with ``extend_choices``.

    Args:
        choices: List of selected model names.

    Returns:
        A list with one entry per slot: a blank 1x1 RGB placeholder image for
        each real model, ``None`` for each ``'NA'`` padding slot.
    """
    choices_plus = extend_choices(choices)
    # Image.open(None) raises for every real model, so the original could
    # never return for a non-'NA' slot; use an in-memory blank image instead.
    # NOTE(review): the intended placeholder content/size is not visible in
    # this file — confirm against whatever consumes these slots.
    return [Image.new('RGB', (1, 1)) if m != 'NA' else None for m in choices_plus]
def gen_fn(model_str, prompt, negative_prompt=None):
    """Generate an image for *prompt* using the model named *model_str*.

    Args:
        model_str: Model name handed to ``load_model``. The sentinel ``'NA'``
            means "no model" and yields ``None`` without loading anything.
        prompt: Text prompt passed to the model.
        negative_prompt: Optional text; when truthy it is appended to the
            prompt as `` -<negative_prompt>``.

    Returns:
        A PIL image when the model returns one, or returns the path of an
        existing image file. ``None`` when the model is ``'NA'``, the result
        is neither an image nor an existing path, or any error occurs (the
        error is printed, never raised).
    """
    if model_str == 'NA':
        return None

    # Optional random suffix, e.g. str(randint(0, 99999999999)); empty = off.
    noise = ''
    try:
        model = load_model(model_str)
        # Only add the separator when there is a suffix, so the prompt does
        # not end up with a stray trailing space.
        full_prompt = f'{prompt} {noise}' if noise else f'{prompt}'
        if negative_prompt:
            full_prompt += f' -{negative_prompt}'
        result = model(full_prompt)

        if isinstance(result, Image.Image):
            return result

        if isinstance(result, str):  # Assuming result might be a file path
            if not os.path.exists(result):
                print(f"File path not found: {result}")
                return None
            # Image.open is lazy: decode the pixels now so that a corrupt
            # file fails inside this try block (not later in the caller) and
            # the file handle is released when the with-block exits.
            with Image.open(result) as image:
                image.load()
            return image

        print("Result is not an image:", type(result))
        return None
    except Exception as e:
        # Best-effort boundary: callers treat None as "generation failed".
        print("Error generating image:", e)
        return None
@app.route('/', methods=['GET'])
def home():
    """Serve a generated PNG for the ``prompt`` query parameter.

    Query parameters:
        prompt: Text to generate from (required).
        model: Model name; defaults to the first entry of ``default_models``,
            or ``'NA'`` when that list is empty.
        Nprompt: Optional negative prompt.

    Returns:
        The PNG image on success; a 400 response when the model is empty or
        the prompt is missing; a 500 response when no image was produced.
    """
    fallback_model = default_models[0] if default_models else 'NA'
    query = request.args
    prompt = query.get('prompt', '')
    model = query.get('model', fallback_model)
    negative_prompt = query.get('Nprompt')

    # Guard clauses, checked in the same order as before: model, then prompt.
    if not model:
        return f'Invalid model: {model}', 400
    if not prompt:
        return 'Please provide a "prompt" query parameter in the URL.', 400

    image = gen_fn(model, prompt, negative_prompt)
    if not isinstance(image, Image.Image):
        return 'Failed to generate image.', 500

    # Encode the image as PNG into an in-memory buffer and stream it back.
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)
    return send_file(buffer, mimetype='image/png', as_attachment=False)
if __name__ == '__main__':
    # Start the Flask server only when this file is executed directly, not
    # when it is imported. Host '0.0.0.0' accepts connections on all network
    # interfaces; the server listens on port 7860.
    app.run(host='0.0.0.0', port=7860)