manbeast3b commited on
Commit
ed2924c
·
verified ·
1 Parent(s): 1781e0b

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +3 -3
src/pipeline.py CHANGED
@@ -9,7 +9,7 @@ from torch import Generator
9
 
10
  Pipeline = None
11
  # Consistent environment variable setting
12
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024" # More robust memory management
13
 
14
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
15
 
@@ -18,7 +18,7 @@ def load_pipeline() -> Pipeline:
18
  # torch.cuda.empty_cache()
19
 
20
  dtype = torch.bfloat16
21
- text_encoder = T5EncoderModel.from_pretrained(
22
  "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype
23
  ).to(memory_format=torch.channels_last)
24
  vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(memory_format=torch.channels_last)
@@ -27,7 +27,7 @@ def load_pipeline() -> Pipeline:
27
  pipeline = DiffusionPipeline.from_pretrained(
28
  ckpt_id,
29
  vae=vae,
30
- text_encoder=text_encoder,
31
  torch_dtype=dtype,
32
  )#.to("cuda")
33
 
 
9
 
10
  Pipeline = None
11
  # Consistent environment variable setting
12
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" # More robust memory management
13
 
14
  ckpt_id = "black-forest-labs/FLUX.1-schnell"
15
 
 
18
  # torch.cuda.empty_cache()
19
 
20
  dtype = torch.bfloat16
21
+ text_encoder_2 = T5EncoderModel.from_pretrained(
22
  "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype
23
  ).to(memory_format=torch.channels_last)
24
  vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(memory_format=torch.channels_last)
 
27
  pipeline = DiffusionPipeline.from_pretrained(
28
  ckpt_id,
29
  vae=vae,
30
+ text_encoder_2=text_encoder_2,
31
  torch_dtype=dtype,
32
  )#.to("cuda")
33