from io import BytesIO
from random import randint

from flask import Flask, request, send_file
from PIL import Image, ImageChops
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
# Initialize the Flask application
app = Flask(__name__)

# Global cache of loaded pipelines, keyed by model name
models_load = {}
def load_model(model_name):
    """Load a Stable Diffusion pipeline once and cache it in models_load."""
    global models_load
    if model_name not in models_load:
        try:
            scheduler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
            # Use float16 on GPU; fall back to float32 on CPU when CUDA is unavailable
            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.float16 if device == "cuda" else torch.float32
            pipe = StableDiffusionPipeline.from_pretrained(model_name, scheduler=scheduler, torch_dtype=dtype)
            pipe = pipe.to(device)
            models_load[model_name] = pipe
            print(f"Model {model_name} loaded successfully.")
        except Exception as error:
            print(f"Error loading model {model_name}: {error}")
            models_load[model_name] = None
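# Optional (an assumption, not in the original code): on GPUs with limited VRAM
# you could reduce memory use before caching the pipeline, e.g.:
#   pipe.enable_attention_slicing()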
def gen_fn(model_str, prompt, negative_prompt=None, noise=None, cfg_scale=None, num_inference_steps=None):
    """Generate an image with the requested model; returns (image, error_message)."""
    if model_str not in models_load:
        load_model(model_str)
    if model_str in models_load and models_load[model_str] is not None:
        # noise="random" appends a random number to the prompt to vary the output
        if noise == "random":
            noise = str(randint(0, 99999999999))
        full_prompt = f'{prompt} {noise}' if noise else prompt
        print(f"Prompt: {full_prompt}")
        try:
            # Call the pipeline, forwarding only the options that were supplied
            call_kwargs = {}
            if num_inference_steps is not None:
                call_kwargs["num_inference_steps"] = num_inference_steps
            if negative_prompt:
                call_kwargs["negative_prompt"] = negative_prompt
            if cfg_scale is not None:
                call_kwargs["guidance_scale"] = cfg_scale
            result = models_load[model_str](full_prompt, **call_kwargs)
            image = result.images[0]  # the generated image
            # Reject images that came out completely black (e.g. a censored result)
            black = Image.new('RGB', image.size, (0, 0, 0))
            if ImageChops.difference(image, black).getbbox() is None:
                return None, 'The image is completely black.'
            return image, None
        except Exception as e:
            print("Error generating image:", e)
            return None, f"Error generating image: {e}"
    else:
        print(f"Model {model_str} not found")
        return None, f"Model {model_str} not found"
@app.route('/', methods=['GET'])
def home():
    prompt = request.args.get('prompt', '')
    model = request.args.get('model', '')
    noise = request.args.get('noise', None)
    negative_prompt = request.args.get('negative_prompt', None)
    num_inference_steps = request.args.get('steps', 50)  # default number of denoising steps
    try:
        num_inference_steps = int(num_inference_steps)
    except ValueError:
        return 'Invalid "steps" parameter. It should be an integer.', 400
    cfg_scale = request.args.get('cfg_scale', None)
    if cfg_scale is not None:
        try:
            cfg_scale = float(cfg_scale)
        except ValueError:
            return 'Invalid "cfg_scale" parameter. It should be a number.', 400
    if not model:
        return 'Please provide a "model" query parameter in the URL.', 400
    if not prompt:
        return 'Please provide a "prompt" query parameter in the URL.', 400

    # Generate the image
    image, error_message = gen_fn(model, prompt, negative_prompt=negative_prompt, noise=noise,
                                  cfg_scale=cfg_scale, num_inference_steps=num_inference_steps)
    if error_message:
        return error_message, 400
    if isinstance(image, Image.Image):
        img_io = BytesIO()
        image.save(img_io, format='PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png', as_attachment=False)
    return 'Failed to generate image.', 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
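# Example usage (a sketch; the model id below is an assumption — substitute any
# Stable Diffusion checkpoint the Space can access):
#
#   curl -o out.png "http://localhost:7860/?model=runwayml/stable-diffusion-v1-5&prompt=a%20red%20fox%20in%20the%20snow&steps=30&noise=random"
#
# The first request for a given model is slow because the pipeline is downloaded
# and moved to the device; subsequent requests reuse the cached pipeline.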