File size: 6,700 Bytes
9d8c3ac
 
d3a3ae9
529b6de
9d8c3ac
 
 
 
 
 
 
c6498ec
 
84e9994
 
c6498ec
 
 
9d8c3ac
 
c6498ec
9d8c3ac
 
 
84e9994
 
9d8c3ac
c6498ec
9d8c3ac
 
c6498ec
 
 
9d8c3ac
c6498ec
9d8c3ac
c6498ec
 
 
 
9d8c3ac
 
c6498ec
 
 
bb3038b
d3a3ae9
adde396
d3a3ae9
 
 
 
 
 
 
 
 
 
 
 
 
adde396
600dd5f
adde396
 
 
 
 
 
 
 
 
 
8002a33
adde396
 
 
 
 
 
 
 
 
 
 
 
529b6de
 
 
 
 
 
 
adde396
 
529b6de
 
c6498ec
 
84e9994
c6498ec
 
 
 
 
 
9d8c3ac
 
 
c6498ec
9d8c3ac
5c955c9
 
 
 
9d8c3ac
 
c6498ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84e9994
c6498ec
 
adde396
c6498ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d8c3ac
 
c6498ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import base64
import io
import logging
import os
import time

import torch
from diffusers import (
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)
from diffusers.models import UNet2DConditionModel
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image

# Hide every CUDA device from torch and make the CPU torch's default device.
# NOTE(review): these variables only matter if CUDA has not already been
# initialised (e.g. during the imports above) — confirm for the deployed stack.
os.environ.update(
    dict.fromkeys(
        ("CUDA_VISIBLE_DEVICES", "CUDA_DEVICE_ORDER", "TORCH_CUDA_ARCH_LIST"),
        "",
    )
)
torch.set_default_device("cpu")

# Flask app serving files from ./static, with flask-cors' default CORS policy.
app = Flask(__name__, static_folder='static')
CORS(app)

# Root logger: DEBUG and above, prefixed with a timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

logger.info(f"Device in use: {torch.device('cpu')}")

# Pipelines that load_model has already built, keyed by model id.
model_cache = {}

# Model id accepted by the API -> repo id / path handed to from_pretrained.
model_paths = {
    "ssd-1b": "remiai3/ssd-1b",
    "sd-v1-5": "remiai3/stable-diffusion-v1-5",
}

# Aspect ratio -> (width, height) in pixels; deliberately small for CPU inference.
ratio_to_dims = {
    "1:1": (256, 256),
    "3:4": (192, 256),
    "16:9": (256, 144),
}

def load_model(model_id):
    """Return the diffusers pipeline for *model_id*, loading and caching it on first use.

    A successfully loaded pipeline is stored in the module-level ``model_cache``
    so later calls reuse it. Every pipeline gets a ``DPMSolverMultistepScheduler``
    built from its original scheduler config, has attention slicing enabled, and
    is moved to the CPU.

    NOTE(review): there is no locking, so concurrent first calls for the same id
    may each load the model.

    Args:
        model_id: a key of ``model_paths`` (``"ssd-1b"`` or ``"sd-v1-5"``).

    Returns:
        The cached pipeline object.

    Raises:
        KeyError: if *model_id* is not a key of ``model_paths``.
        Exception: any loading failure is logged with its traceback and re-raised.
    """
    if model_id in model_cache:
        return model_cache[model_id]

    logger.info(f"Loading model {model_id}...")
    try:
        repo = model_paths[model_id]
        if model_id == "ssd-1b":
            pipe = _load_ssd_1b(model_id, repo)
        else:
            # Standard loading for sd-v1-5
            pipe = StableDiffusionPipeline.from_pretrained(repo, **_pretrained_kwargs())
        logger.info(f"Pipeline components loading for {model_id}...")
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_attention_slicing()
        pipe.to(torch.device("cpu"))
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(str(e)) dropped.
        logger.exception(f"Error loading model {model_id}: {str(e)}")
        raise

    model_cache[model_id] = pipe
    logger.info(f"Model {model_id} loaded successfully")
    return pipe


def _pretrained_kwargs():
    """Return the keyword arguments shared by every pipeline ``from_pretrained`` call."""
    return {
        "torch_dtype": torch.float32,
        # None when HF_TOKEN is unset.
        "use_auth_token": os.getenv("HF_TOKEN"),
        "use_safetensors": True,
        "low_cpu_mem_usage": True,
        # Bypasses the local download cache, so files are fetched again on each load.
        "force_download": True,
    }


def _load_ssd_1b(model_id, repo):
    """Load SSD-1B with StableDiffusionXLPipeline, else a patched StableDiffusionPipeline.

    Args:
        model_id: the API model id, used only in log messages.
        repo: the repo id / path from ``model_paths``.

    Returns:
        The loaded pipeline (not yet configured; ``load_model`` does that).
    """
    try:
        logger.info(f"Attempting StableDiffusionXLPipeline for {model_id}")
        return StableDiffusionXLPipeline.from_pretrained(repo, **_pretrained_kwargs())
    except Exception as e:
        logger.warning(f"StableDiffusionXLPipeline failed for {model_id}: {str(e)}")
        logger.info(f"Falling back to StableDiffusionPipeline for {model_id}")

    # NOTE(review): the XL pipeline is tried first, which suggests an SDXL-style
    # checkpoint; confirm a plain StableDiffusionPipeline can actually drive it.

    # subfolder= resolves for both a Hub repo id and a local directory, whereas a
    # "<repo>/unet" string is not a valid Hub repo id.
    unet_config = UNet2DConditionModel.load_config(
        repo,
        subfolder="unet",
        use_auth_token=os.getenv("HF_TOKEN"),
        force_download=True,
    )
    if "reverse_transformer_layers_per_block" in unet_config:
        logger.info(f"Original UNet config for {model_id}: {unet_config}")
        # Presumably neutralises a key the installed UNet2DConditionModel
        # cannot handle — confirm against the diffusers version in use.
        unet_config["reverse_transformer_layers_per_block"] = None
        logger.info(f"Patched UNet config for {model_id}: {unet_config}")
    unet = UNet2DConditionModel.from_config(unet_config)

    # NOTE(review): this .bin path is opened on the local filesystem, so it only
    # resolves when `repo` is a local directory, not a Hub repo id.
    # weights_only=True refuses to unpickle anything beyond tensors/containers.
    state_dict = torch.load(
        f"{repo}/unet/diffusion_pytorch_model.bin",
        map_location="cpu",
        weights_only=True,
    )
    unet.load_state_dict(state_dict)
    return StableDiffusionPipeline.from_pretrained(repo, unet=unet, **_pretrained_kwargs())

@app.route('/')
def index():
    """Serve ``index.html`` from the app's static folder for the root URL."""
    page = 'index.html'
    return app.send_static_file(page)

@app.route('/assets/<path:filename>')
def serve_assets(filename):
    """Serve ``<static folder>/assets/<filename>``.

    The path is built with a literal ``/`` instead of ``os.path.join``.
    ``send_static_file`` expects a URL-style path, and Werkzeug's ``safe_join``
    rejects the backslash ``os.path.join`` inserts on Windows (causing a 404).
    ``safe_join`` also rejects paths that escape the static folder.
    """
    return app.send_static_file(f"assets/{filename}")

@app.route('/generate', methods=['POST'])
def generate():
    """Generate images for a JSON request and return them as base64 PNG data URLs.

    Request body (a JSON object):
        model (str): key of ``model_paths``; defaults to ``'ssd-1b'``.
        prompt (str): required, non-empty text prompt.
        ratio (str): key of ``ratio_to_dims``; unknown ratios fall back to 256x256.
        num_images (int): number of images; capped at 4 and must be >= 1.
        guidance_scale (float): passed to the pipeline; defaults to 7.5.

    Returns:
        ``{"images": [...]}`` (200) with one ``data:image/png;base64,...`` string
        per image; ``{"error": ...}`` with 400 for an invalid request; or
        ``{"error": ...}`` with 500 if model loading or inference fails.
    """
    # silent=True: a missing/malformed/non-JSON body yields None instead of
    # raising, so it is reported as a client error rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _json_error("Request body must be a JSON object")

    model_id = data.get('model', 'ssd-1b')
    prompt = data.get('prompt', '')
    ratio = data.get('ratio', '1:1')
    try:
        num_images = min(int(data.get('num_images', 1)), 4)
        guidance_scale = float(data.get('guidance_scale', 7.5))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: Python's JSON parser accepts Infinity, and int(inf) overflows.
        return _json_error("num_images must be an integer and guidance_scale a number")

    if not prompt:
        return _json_error("Prompt is required")
    if not isinstance(prompt, str):
        return _json_error("Prompt must be a string")
    if not isinstance(model_id, str) or model_id not in model_paths:
        return _json_error(f"Unknown model; expected one of: {', '.join(model_paths)}")
    if not isinstance(ratio, str):
        return _json_error("Ratio must be a string")
    if num_images < 1:
        return _json_error("num_images must be at least 1")

    if model_id == 'ssd-1b' and num_images > 1:
        return _json_error("SSD-1B allows only 1 image per generation")
    if model_id == 'ssd-1b' and ratio != '1:1':
        return _json_error("SSD-1B supports only 1:1 ratio")
    # Cheap pre-check that avoids loading the model for obviously long prompts:
    # each whitespace-separated word normally yields at least one token.
    if model_id == 'sd-v1-5' and len(prompt.split()) > 77:
        return _json_error("Prompt exceeds 77 tokens for Stable Diffusion v1.5")

    width, height = ratio_to_dims.get(ratio, (256, 256))
    num_inference_steps = 20 if model_id == 'ssd-1b' else 30

    try:
        pipe = load_model(model_id)  # already placed on the CPU by load_model
        if model_id == 'sd-v1-5':
            # Exact check with the pipeline's own tokenizer; the count includes
            # the start/end tokens, matching model_max_length.
            tokenizer = pipe.tokenizer
            if len(tokenizer(prompt).input_ids) > tokenizer.model_max_length:
                return _json_error("Prompt exceeds 77 tokens for Stable Diffusion v1.5")

        image_urls = []
        for _ in range(num_images):
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            ).images[0]
            image_urls.append(_png_data_url(image))
    except Exception as e:
        # Top-level request boundary: log with traceback, report to the client.
        logger.exception(f"Image generation failed: {str(e)}")
        return jsonify({"error": f"Image generation failed: {str(e)}"}), 500

    return jsonify({"images": image_urls})


def _json_error(message, status=400):
    """Return a Flask ``(response, status)`` tuple carrying ``{"error": message}``."""
    return jsonify({"error": message}), status


def _png_data_url(image):
    """Encode a PIL image as a ``data:image/png;base64,...`` URL, entirely in memory.

    No temporary file is used, so concurrent requests cannot collide on a
    shared file name.
    """
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"

if __name__ == '__main__':
    # Run Flask's built-in server on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**server_options)