Jordan Legg committed
Commit 8a31b39
Parent(s): 4e6d911
Files changed (1):
  1. app.py +34 -10

app.py CHANGED
@@ -5,10 +5,25 @@ import spaces
 import torch
 from diffusers import DiffusionPipeline
 
-dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
+# Load the model in FP16
+pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
+
+# Move the pipeline to GPU
+pipe = pipe.to(device)
+
+# Convert text encoders to full precision
+pipe.text_encoder = pipe.text_encoder.to(torch.float32)
+if hasattr(pipe, 'text_encoder_2'):
+    pipe.text_encoder_2 = pipe.text_encoder_2.to(torch.float32)
+
+# Enable memory efficient attention if available
+if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
+    pipe.enable_xformers_memory_efficient_attention()
+
+# Compile the UNet for potential speedups
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
@@ -17,15 +32,24 @@ MAX_IMAGE_SIZE = 2048
 def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator().manual_seed(seed)
-    image = pipe(
-        prompt = prompt,
-        width = width,
-        height = height,
-        num_inference_steps = num_inference_steps,
-        generator = generator,
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    # Use full precision for text encoding
+    with torch.no_grad():
+        text_inputs = pipe.tokenizer(prompt, return_tensors="pt").to(device)
+        text_embeddings = pipe.text_encoder(text_inputs.input_ids)[0]
+
+    # Use mixed precision for the rest of the pipeline
+    with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+        image = pipe(
+            prompt_embeds=text_embeddings,
+            width=width,
+            height=height,
+            num_inference_steps=num_inference_steps,
+            generator=generator,
         guidance_scale=0.0
-    ).images[0]
+        ).images[0]
+
     return image, seed
 
 examples = [
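
Review notes on this diff follow; the sketches are suggestions, not part of commit 8a31b39.

Note 1: the hasattr guard around xformers always passes, because enable_xformers_memory_efficient_attention is defined on the DiffusionPipeline base class whether or not xformers is installed. It is the call itself that raises when the package is missing, so a try/except expresses "enable if available" more accurately. A minimal sketch, assuming only that the failed call raises:

    # Guard the call, not the attribute: the method exists on every
    # DiffusionPipeline and raises only when xformers is unavailable.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        print(f"xformers unavailable, using default attention: {exc}")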
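
Note 2: FLUX.1-schnell loads as a FluxPipeline, which exposes its denoiser as pipe.transformer rather than pipe.unet, so the torch.compile line would raise AttributeError at startup. A hedged sketch that compiles whichever attribute the loaded pipeline actually exposes (attribute names as in current diffusers; verify against the installed version):

    # Compile the denoiser under its real attribute name.
    # FLUX pipelines use `transformer`; UNet-based pipelines use `unet`.
    name = "transformer" if hasattr(pipe, "transformer") else "unet"
    setattr(pipe, name, torch.compile(getattr(pipe, name),
                                      mode="reduce-overhead", fullgraph=True))

fullgraph=True is itself optimistic: any graph break in the model makes compilation fail outright, so dropping that flag is the safer default.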
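
Note 3: the manual text-encoding path in infer() is unlikely to run as written. FluxPipeline expects both prompt_embeds (a T5 sequence embedding) and pooled_prompt_embeds (a pooled CLIP embedding), while pipe.tokenizer/pipe.text_encoder here produce only CLIP hidden states. A sketch that keeps the commit's fp32-encode / fp16-denoise split but lets the pipeline build the tensors it needs; it assumes FluxPipeline.encode_prompt returns (prompt_embeds, pooled_prompt_embeds, text_ids), as in recent diffusers releases, so check the installed version:

    # Encode once in full precision with the pipeline's own helper ...
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
            prompt=prompt, prompt_2=prompt, device=device
        )
    # ... then denoise under fp16 autocast, as the commit intends.
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
        image = pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=0.0,
        ).images[0]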