remiai3 commited on
Commit
adde396
·
verified ·
1 Parent(s): 25f5903

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -27
app.py CHANGED
@@ -44,32 +44,47 @@ def load_model(model_id):
44
  logger.info(f"Loading model {model_id}...")
45
  try:
46
  if model_id == "ssd-1b":
47
- # Preload UNet and patch configuration
48
- logger.info(f"Preloading UNet for {model_id}")
49
- unet_config = UNet2DConditionModel.load_config(
50
- f"{model_paths[model_id]}/unet",
51
- use_auth_token=os.getenv("HF_TOKEN")
52
- )
53
- if "reverse_transformer_layers_per_block" in unet_config:
54
- logger.info(f"Original UNet config for {model_id}: {unet_config}")
55
- unet_config["reverse_transformer_layers_per_block"] = None
56
- logger.info(f"Patched UNet config for {model_id}: {unet_config}")
57
- unet = UNet2DConditionModel.from_config(unet_config)
58
- unet.load_state_dict(
59
- torch.load(
60
- f"{model_paths[model_id]}/unet/diffusion_pytorch_model.bin",
61
- map_location="cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
63
- )
64
- # Load pipeline with patched UNet
65
- pipe = StableDiffusionPipeline.from_pretrained(
66
- model_paths[model_id],
67
- unet=unet,
68
- torch_dtype=torch.float32,
69
- use_auth_token=os.getenv("HF_TOKEN"),
70
- use_safetensors=True,
71
- low_cpu_mem_usage=True
72
- )
73
  else:
74
  # Standard loading for sd-v1-5
75
  pipe = StableDiffusionPipeline.from_pretrained(
@@ -77,7 +92,8 @@ def load_model(model_id):
77
  torch_dtype=torch.float32,
78
  use_auth_token=os.getenv("HF_TOKEN"),
79
  use_safetensors=True,
80
- low_cpu_mem_usage=True
 
81
  )
82
  logger.info(f"Pipeline components loading for {model_id}...")
83
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
@@ -123,7 +139,7 @@ def generate():
123
  pipe.to(torch.device("cpu"))
124
 
125
  images = []
126
- num_inference_steps = 10 if model_id == 'ssd-1b' else 30
127
  for _ in range(num_images):
128
  image = pipe(
129
  prompt=prompt,
 
44
  logger.info(f"Loading model {model_id}...")
45
  try:
46
  if model_id == "ssd-1b":
47
+ # Try loading UNet from xet/unet
48
+ try:
49
+ logger.info(f"Preloading UNet for {model_id} from {model_paths[model_id]}/xet/unet")
50
+ unet_config = UNet2DConditionModel.load_config(
51
+ f"{model_paths[model_id]}/xet/unet",
52
+ use_auth_token=os.getenv("HF_TOKEN"),
53
+ force_download=True
54
+ )
55
+ if "reverse_transformer_layers_per_block" in unet_config:
56
+ logger.info(f"Original UNet config for {model_id}: {unet_config}")
57
+ unet_config["reverse_transformer_layers_per_block"] = None
58
+ logger.info(f"Patched UNet config for {model_id}: {unet_config}")
59
+ unet = UNet2DConditionModel.from_config(unet_config)
60
+ unet.load_state_dict(
61
+ torch.load(
62
+ f"{model_paths[model_id]}/xet/unet/diffusion_pytorch_model.bin",
63
+ map_location="cpu"
64
+ )
65
+ )
66
+ # Load pipeline with patched UNet
67
+ pipe = StableDiffusionPipeline.from_pretrained(
68
+ model_paths[model_id],
69
+ unet=unet,
70
+ torch_dtype=torch.float32,
71
+ use_auth_token=os.getenv("HF_TOKEN"),
72
+ use_safetensors=True,
73
+ low_cpu_mem_usage=True,
74
+ force_download=True
75
+ )
76
+ except Exception as e:
77
+ logger.warning(f"Failed to load UNet for {model_id}: {str(e)}")
78
+ logger.info(f"Falling back to standard pipeline loading for {model_id}")
79
+ # Fallback to standard pipeline
80
+ pipe = StableDiffusionPipeline.from_pretrained(
81
+ model_paths[model_id],
82
+ torch_dtype=torch.float32,
83
+ use_auth_token=os.getenv("HF_TOKEN"),
84
+ use_safetensors=True,
85
+ low_cpu_mem_usage=True,
86
+ force_download=True
87
  )
 
 
 
 
 
 
 
 
 
 
88
  else:
89
  # Standard loading for sd-v1-5
90
  pipe = StableDiffusionPipeline.from_pretrained(
 
92
  torch_dtype=torch.float32,
93
  use_auth_token=os.getenv("HF_TOKEN"),
94
  use_safetensors=True,
95
+ low_cpu_mem_usage=True,
96
+ force_download=True
97
  )
98
  logger.info(f"Pipeline components loading for {model_id}...")
99
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
 
139
  pipe.to(torch.device("cpu"))
140
 
141
  images = []
142
+ num_inference_steps = 20 if model_id == 'ssd-1b' else 30
143
  for _ in range(num_images):
144
  image = pipe(
145
  prompt=prompt,