remiai3 committed on
Commit
529b6de
·
verified ·
1 Parent(s): 63709f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -14
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from flask import Flask, request, jsonify
2
  from flask_cors import CORS
3
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 
4
  import torch
5
  import os
6
  from PIL import Image
@@ -42,20 +43,43 @@ def load_model(model_id):
42
  if model_id not in model_cache:
43
  logger.info(f"Loading model {model_id}...")
44
  try:
45
- pipe = StableDiffusionPipeline.from_pretrained(
46
- model_paths[model_id],
47
- torch_dtype=torch.float32,
48
- use_auth_token=os.getenv("HF_TOKEN"),
49
- use_safetensors=True,
50
- low_cpu_mem_usage=True
51
- )
52
- logger.info(f"Pipeline components loading for {model_id}...")
53
  if model_id == "ssd-1b":
54
- # Patch UNet config to remove problematic parameter
55
- if hasattr(pipe.unet.config, "reverse_transformer_layers_per_block"):
56
- logger.info(f"Original UNet config for {model_id}: {pipe.unet.config}")
57
- pipe.unet.config.reverse_transformer_layers_per_block = None
58
- logger.info(f"Patched UNet config for {model_id}: {pipe.unet.config}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
60
  pipe.enable_attention_slicing()
61
  pipe.to(torch.device("cpu"))
@@ -99,7 +123,7 @@ def generate():
99
  pipe.to(torch.device("cpu"))
100
 
101
  images = []
102
- num_inference_steps = 20 if model_id == 'ssd-1b' else 30
103
  for _ in range(num_images):
104
  image = pipe(
105
  prompt=prompt,
 
1
  from flask import Flask, request, jsonify
2
  from flask_cors import CORS
3
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
4
+ from diffusers.models import UNet2DConditionModel
5
  import torch
6
  import os
7
  from PIL import Image
 
43
  if model_id not in model_cache:
44
  logger.info(f"Loading model {model_id}...")
45
  try:
 
 
 
 
 
 
 
 
46
  if model_id == "ssd-1b":
47
+ # Preload UNet and patch configuration
48
+ logger.info(f"Preloading UNet for {model_id}")
49
+ unet_config = UNet2DConditionModel.load_config(
50
+ f"{model_paths[model_id]}/unet",
51
+ use_auth_token=os.getenv("HF_TOKEN")
52
+ )
53
+ if "reverse_transformer_layers_per_block" in unet_config:
54
+ logger.info(f"Original UNet config for {model_id}: {unet_config}")
55
+ unet_config["reverse_transformer_layers_per_block"] = None
56
+ logger.info(f"Patched UNet config for {model_id}: {unet_config}")
57
+ unet = UNet2DConditionModel.from_config(unet_config)
58
+ unet.load_state_dict(
59
+ torch.load(
60
+ f"{model_paths[model_id]}/unet/diffusion_pytorch_model.bin",
61
+ map_location="cpu"
62
+ )
63
+ )
64
+ # Load pipeline with patched UNet
65
+ pipe = StableDiffusionPipeline.from_pretrained(
66
+ model_paths[model_id],
67
+ unet=unet,
68
+ torch_dtype=torch.float32,
69
+ use_auth_token=os.getenv("HF_TOKEN"),
70
+ use_safetensors=True,
71
+ low_cpu_mem_usage=True
72
+ )
73
+ else:
74
+ # Standard loading for sd-v1-5
75
+ pipe = StableDiffusionPipeline.from_pretrained(
76
+ model_paths[model_id],
77
+ torch_dtype=torch.float32,
78
+ use_auth_token=os.getenv("HF_TOKEN"),
79
+ use_safetensors=True,
80
+ low_cpu_mem_usage=True
81
+ )
82
+ logger.info(f"Pipeline components loading for {model_id}...")
83
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
84
  pipe.enable_attention_slicing()
85
  pipe.to(torch.device("cpu"))
 
123
  pipe.to(torch.device("cpu"))
124
 
125
  images = []
126
+ num_inference_steps = 10 if model_id == 'ssd-1b' else 30
127
  for _ in range(num_images):
128
  image = pipe(
129
  prompt=prompt,