b
File size: 3,224 Bytes
6437aa7
f9b29bc
6437aa7
bc1a1bd
6b25aaf
 
6437aa7
6b25aaf
6437aa7
 
ac69805
 
f9b29bc
ac69805
 
 
 
6b25aaf
 
 
 
967362a
ac69805
 
6b25aaf
f9b29bc
6b25aaf
ac69805
37a6949
ac69805
6b25aaf
c325d55
 
 
622195c
6b25aaf
7ab2c9b
9327b29
37a6949
6b25aaf
 
967362a
c325d55
 
 
6b25aaf
7c29ed9
c325d55
7c29ed9
c325d55
 
 
 
 
 
7c29ed9
 
 
 
 
 
6b25aaf
7c29ed9
 
6b25aaf
7c29ed9
 
 
 
 
 
 
 
 
 
6b25aaf
7c29ed9
 
 
6b25aaf
7c29ed9
 
 
 
 
 
 
 
6b25aaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
from flask import Flask, request, send_file
from io import BytesIO
from PIL import Image, ImageChops
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch

# Create the Flask application object that the routes below register on.
app = Flask(__name__)

# Process-wide cache of loaded pipelines, keyed by model name.
# A cached value of None records a model whose load attempt failed.
models_load = dict()

def load_model(model_name, *, device="cuda", torch_dtype=torch.float16):
    """Load a Stable Diffusion pipeline into the ``models_load`` cache.

    The pipeline is built with an ``EulerDiscreteScheduler`` read from the
    model's ``scheduler`` subfolder, then moved to ``device``. A name that is
    already in the cache is not loaded again. This includes names whose
    earlier load failed and were cached as ``None``.

    Args:
        model_name: Identifier passed to ``from_pretrained`` (presumably a
            Hub repo id or local path; not validated here).
        device: Device the pipeline is moved to. Defaults to ``"cuda"``.
        torch_dtype: Weight dtype passed to ``from_pretrained``. Defaults to
            ``torch.float16``.

    Returns:
        The cached pipeline, or ``None`` if loading failed.
    """
    # No ``global`` statement is needed: the cache dict is mutated, never rebound.
    if model_name in models_load:
        return models_load[model_name]

    try:
        scheduler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
        pipe = StableDiffusionPipeline.from_pretrained(
            model_name, scheduler=scheduler, torch_dtype=torch_dtype
        )
        pipe = pipe.to(device)
    except Exception as error:
        # Best-effort: record the failure as None so callers can report
        # "not found" instead of crashing. The failure stays cached, so this
        # model is not retried until the process restarts.
        print(f"Error loading model {model_name}: {error}")
        models_load[model_name] = None
    else:
        models_load[model_name] = pipe
        print(f"Model {model_name} loaded successfully.")

    return models_load[model_name]

def gen_fn(model_str, prompt, negative_prompt=None, noise=None, cfg_scale=None, num_inference_steps=None):
    """Generate one image from ``prompt`` using the pipeline named ``model_str``.

    The model is loaded and cached on first use via ``load_model``.

    Args:
        model_str: Name of the model to use. It is loaded lazily if not yet cached.
        prompt: Text prompt for the pipeline.
        negative_prompt: Optional negative prompt forwarded to the pipeline.
        noise: Optional suffix appended to the prompt. The special value
            ``"random"`` is replaced by a random integer, so the same prompt
            yields different prompt text on each call.
        cfg_scale: Optional classifier-free guidance scale, forwarded as the
            pipeline's ``guidance_scale``.
        num_inference_steps: Optional number of denoising steps.

    Returns:
        ``(image, None)`` on success. Otherwise ``(None, error_message)``
        when the model is unavailable, generation raised, or the generated
        image is entirely black.
    """
    # ``randint`` was previously used without being imported (NameError).
    from random import randint

    if model_str not in models_load:
        load_model(model_str)

    pipe = models_load.get(model_str)
    if pipe is None:
        print(f"Model {model_str} not found")
        return None, f"Model {model_str} not found"

    if noise == "random":
        noise = str(randint(0, 99999999999))
    full_prompt = f'{prompt} {noise}' if noise else prompt

    print(f"Prompt: {full_prompt}")

    # Forward only the options the caller actually set. The pipeline then
    # uses its own defaults for the rest instead of receiving an explicit None.
    call_kwargs = {}
    if negative_prompt is not None:
        call_kwargs["negative_prompt"] = negative_prompt
    if cfg_scale is not None:
        call_kwargs["guidance_scale"] = cfg_scale
    if num_inference_steps is not None:
        call_kwargs["num_inference_steps"] = num_inference_steps

    try:
        result = pipe(full_prompt, **call_kwargs)
        image = result.images[0]  # first (only) generated image

        # Reject an image that is completely black. Convert to RGB first so
        # the comparison against an RGB black canvas cannot fail on a mode mismatch.
        black = Image.new('RGB', image.size, (0, 0, 0))
        if ImageChops.difference(image.convert('RGB'), black).getbbox() is None:
            return None, 'The image is completely black.'

        return image, None

    except Exception as e:
        print("Error generating image:", e)
        return None, f"Error generating image: {e}"

@app.route('/', methods=['GET'])
def home():
    """Generate an image from query parameters and return it as a PNG.

    Query parameters:
        prompt: Text prompt (required).
        model: Model name passed to ``gen_fn`` (required).
        noise: Optional prompt suffix. ``"random"`` makes ``gen_fn`` append
            a random number.
        steps: Number of inference steps. Must be a positive integer;
            defaults to 50.

    Returns:
        The PNG image on success. A 400 response for invalid input or for
        any error message returned by ``gen_fn``. A 500 response if no image
        was produced.
    """
    prompt = request.args.get('prompt', '')
    # NOTE(review): ``model`` comes straight from the query string and is
    # handed to ``from_pretrained`` via gen_fn/load_model. Any client can make
    # the server load an arbitrary model name; consider an allow-list.
    model = request.args.get('model', '')
    noise = request.args.get('noise', None)
    num_inference_steps = request.args.get('steps', 50)  # default when "steps" is absent

    try:
        num_inference_steps = int(num_inference_steps)
    except ValueError:
        return 'Invalid "steps" parameter. It should be an integer.', 400
    # Zero or negative step counts make no sense for the sampler. Reject them
    # here rather than passing them to the pipeline.
    if num_inference_steps < 1:
        return 'Invalid "steps" parameter. It should be a positive integer.', 400

    if not model:
        return 'Please provide a "model" query parameter in the URL.', 400

    if not prompt:
        return 'Please provide a "prompt" query parameter in the URL.', 400

    # Generate the image.
    image, error_message = gen_fn(model, prompt, noise=noise, num_inference_steps=num_inference_steps)
    if error_message:
        return error_message, 400

    if isinstance(image, Image.Image):
        # Encode to PNG in memory and rewind so send_file streams from the start.
        img_io = BytesIO()
        image.save(img_io, format='PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png', as_attachment=False)

    return 'Failed to generate image.', 500

if __name__ == '__main__':
    # Run the development server on all interfaces, port 7860.
    app.run(port=7860, host='0.0.0.0')