from io import BytesIO
from random import randint

import torch
from flask import Flask, request, send_file
from PIL import Image, ImageChops
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
# Initialize the Flask application
app = Flask(__name__)

# Global cache of loaded pipelines, keyed by model name
models_load = {}

def load_model(model_name):
    """Load a Stable Diffusion pipeline once and cache it in models_load."""
    global models_load
    if model_name not in models_load:
        try:
            scheduler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
            pipe = StableDiffusionPipeline.from_pretrained(model_name, scheduler=scheduler, torch_dtype=torch.float16)
            pipe = pipe.to("cuda")
            models_load[model_name] = pipe
            print(f"Model {model_name} loaded successfully.")
        except Exception as error:
            print(f"Error loading model {model_name}: {error}")
            models_load[model_name] = None

def gen_fn(model_str, prompt, negative_prompt=None, noise=None, cfg_scale=None, num_inference_steps=None):
    if model_str not in models_load:
        load_model(model_str)
    if model_str in models_load and models_load[model_str] is not None:
        if noise == "random":
            noise = str(randint(0, 99999999999))
        full_prompt = f'{prompt} {noise}' if noise else prompt
        print(f"Prompt: {full_prompt}")
        try:
            # Invoke the pipeline; only forward optional arguments that were supplied
            kwargs = {}
            if negative_prompt:
                kwargs['negative_prompt'] = negative_prompt
            if cfg_scale is not None:
                kwargs['guidance_scale'] = cfg_scale
            result = models_load[model_str](full_prompt, num_inference_steps=num_inference_steps, **kwargs)
            image = result.images[0]  # Retrieve the generated image
            # Reject images that are completely black (e.g. safety-filtered output)
            black = Image.new('RGB', image.size, (0, 0, 0))
            if ImageChops.difference(image, black).getbbox() is None:
                return None, 'The image is completely black.'
            return image, None
        except Exception as e:
            print("Error generating image:", e)
            return None, f"Error generating image: {e}"
    else:
        print(f"Model {model_str} not found")
        return None, f"Model {model_str} not found"

@app.route('/', methods=['GET'])
def home():
    prompt = request.args.get('prompt', '')
    model = request.args.get('model', '')
    noise = request.args.get('noise', None)
    num_inference_steps = request.args.get('steps', 50)  # Default number of denoising steps
    try:
        num_inference_steps = int(num_inference_steps)
    except ValueError:
        return 'Invalid "steps" parameter. It should be an integer.', 400
    if not model:
        return 'Please provide a "model" query parameter in the URL.', 400
    if not prompt:
        return 'Please provide a "prompt" query parameter in the URL.', 400
    # Generate the image
    image, error_message = gen_fn(model, prompt, noise=noise, num_inference_steps=num_inference_steps)
    if error_message:
        return error_message, 400
    if isinstance(image, Image.Image):
        img_io = BytesIO()
        image.save(img_io, format='PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png', as_attachment=False)
    return 'Failed to generate image.', 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
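
# Example requests (a sketch; assumes the server is running locally on port 7860 and
# uses a hypothetical Hub model id for illustration):
#   curl "http://localhost:7860/?model=runwayml/stable-diffusion-v1-5&prompt=a+red+fox&steps=30" -o fox.png
#   curl "http://localhost:7860/?model=runwayml/stable-diffusion-v1-5&prompt=a+red+fox&noise=random" -o fox2.png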