radames HF staff commited on
Commit
183a6f7
1 Parent(s): d06eb92

fix diffusers

Browse files
Files changed (1) hide show
  1. app.py +5 -15
app.py CHANGED
@@ -1,11 +1,7 @@
1
  import numpy as np
2
  import PIL.Image
3
  import torch
4
- from typing import List
5
- from diffusers.utils import numpy_to_pil
6
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
7
- from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
8
- from fastapi import FastAPI
9
  import uvicorn
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.responses import RedirectResponse, StreamingResponse
@@ -16,7 +12,6 @@ from db import Database
16
  import uuid
17
  import logging
18
  from fastapi import FastAPI, Request, HTTPException
19
- from fastapi.middleware.cors import CORSMiddleware
20
  from asyncio import Lock
21
 
22
 
@@ -40,13 +35,11 @@ dtype = torch.bfloat16
40
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41
  if torch.cuda.is_available():
42
  prior_pipeline = StableCascadePriorPipeline.from_pretrained(
43
- "stabilityai/stable-cascade-prior", torch_dtype=dtype
44
- ) # .to(device)
45
  decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
46
- "stabilityai/stable-cascade", torch_dtype=dtype
47
- ) # .to(device)
48
- prior_pipeline.to(device)
49
- decoder_pipeline.to(device)
50
 
51
  if USE_TORCH_COMPILE:
52
  prior_pipeline.prior = torch.compile(
@@ -67,16 +60,14 @@ def generate(
67
  prior_guidance_scale: float = 4.0,
68
  decoder_num_inference_steps: int = 10,
69
  decoder_guidance_scale: float = 0.0,
70
- num_images_per_prompt: int = 2,
71
  ) -> PIL.Image.Image:
72
-
73
  generator = torch.Generator().manual_seed(seed)
74
  prior_output = prior_pipeline(
75
  prompt=prompt,
76
  height=height,
77
  width=width,
78
  num_inference_steps=prior_num_inference_steps,
79
- timesteps=DEFAULT_STAGE_C_TIMESTEPS,
80
  negative_prompt=negative_prompt,
81
  guidance_scale=prior_guidance_scale,
82
  num_images_per_prompt=num_images_per_prompt,
@@ -133,7 +124,6 @@ async def generate_image(
133
 
134
  logging.info(f"Image not found in cache, generating new image")
135
  async with generate_lock:
136
-
137
  pil_image = generate(prompt, negative_prompt, seed)
138
  img_id = str(uuid.uuid4())
139
  img_path = IMGS_PATH / f"{img_id}.jpg"
 
1
  import numpy as np
2
  import PIL.Image
3
  import torch
 
 
4
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
 
 
5
  import uvicorn
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from fastapi.responses import RedirectResponse, StreamingResponse
 
12
  import uuid
13
  import logging
14
  from fastapi import FastAPI, Request, HTTPException
 
15
  from asyncio import Lock
16
 
17
 
 
35
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
36
  if torch.cuda.is_available():
37
  prior_pipeline = StableCascadePriorPipeline.from_pretrained(
38
+ "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
39
+ ).to(device)
40
  decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
41
+ "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
42
+ ).to(device)
 
 
43
 
44
  if USE_TORCH_COMPILE:
45
  prior_pipeline.prior = torch.compile(
 
60
  prior_guidance_scale: float = 4.0,
61
  decoder_num_inference_steps: int = 10,
62
  decoder_guidance_scale: float = 0.0,
63
+ num_images_per_prompt: int = 1,
64
  ) -> PIL.Image.Image:
 
65
  generator = torch.Generator().manual_seed(seed)
66
  prior_output = prior_pipeline(
67
  prompt=prompt,
68
  height=height,
69
  width=width,
70
  num_inference_steps=prior_num_inference_steps,
 
71
  negative_prompt=negative_prompt,
72
  guidance_scale=prior_guidance_scale,
73
  num_images_per_prompt=num_images_per_prompt,
 
124
 
125
  logging.info(f"Image not found in cache, generating new image")
126
  async with generate_lock:
 
127
  pil_image = generate(prompt, negative_prompt, seed)
128
  img_id = str(uuid.uuid4())
129
  img_path = IMGS_PATH / f"{img_id}.jpg"