Files changed (1)
  1. src/pipeline.py +105 -107
src/pipeline.py CHANGED
@@ -2,139 +2,137 @@ import os
  import torch
  import torch._dynamo
  import gc
-
  import json
  import transformers
  from huggingface_hub.constants import HF_HUB_CACHE
  from transformers import T5EncoderModel, T5TokenizerFast
  from PIL.Image import Image
- from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
  from pipelines.models import TextToImageRequest
  from optimum.quanto import requantize
- import json
-
  from torch import Generator
- from diffusers import FluxTransformer2DModel, DiffusionPipeline
-
- # MYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMYMY
- # ApricityApricityApricityApricityApricityApricityApricityApricityApricityApricityApricityApricityApricity
-
  from torch._dynamo import config
  from torch._inductor import config as ind_config
- import torch
- import math
- from typing import Dict, Any
-
- torch._dynamo.config.suppress_errors = True
- os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
- os.environ["TOKENIZERS_PARALLELISM"] = "True"
-
- ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
- revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
- Pipeline = None
- use_com = False
-
-
- def optimize_torch():
-     torch.backends.cuda.matmul.allow_tf32 = True
-     torch.backends.cudnn.allow_tf32 = True
-     torch.backends.cudnn.benchmark = True
-     # torch.backends.cudnn.benchmark_limit = 20
-     torch.set_float32_matmul_precision("high")
-     # config.cache_size_limit = 10000000000
-     # ind_config.shape_padding = True
-
- try:
-     optimize_torch()
- except:
-     print("nothing wrong")
-
- def delete_ca_che():
-     torch.cuda.empty_cache()
-     torch.cuda.reset_max_memory_allocated()
-     torch.cuda.reset_peak_memory_stats()
-
-
-
- def pipeline_loader() -> Pipeline:
-
-     print("Loading text encoder...")
-     en = T5EncoderModel.from_pretrained(
-         "city96/t5-v1_1-xxl-encoder-bf16",
-         revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
-         torch_dtype=torch.bfloat16,
-     )
-
-     transformer_path_main = os.path.join(HF_HUB_CACHE, "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665")
-
-     transformer_model = FluxTransformer2DModel.from_pretrained(transformer_path_main, torch_dtype=torch.bfloat16, use_safetensors=False)
-
-     pipe = DiffusionPipeline.from_pretrained(ckpt_root,
-                                              revision=revision_root,
-                                              transformer=transformer_model,
-                                              torch_dtype=torch.bfloat16)
-     pipe.to("cuda")
-
-     try:
-         # fuse QKV projections in Transformer and VAE
          pipe.transformer.fuse_qkv_projections()
          pipe.vae.fuse_qkv_projections()

-         # switch memory layout to Torch's preferred, channels_last
          pipe.transformer.to(memory_format=torch.channels_last)
          pipe.vae.to(memory_format=torch.channels_last)

-         # set torch compile flags
          config = torch._inductor.config
-         config.disable_progress = False # show progress bar
-         config.conv_1x1_as_mm = True # treat 1x1 convolutions as matrix muls

-         # tag the compute-intensive modules, the Transformer and VAE decoder, for compilation
          pipe.transformer = torch.compile(
-             pipe.transformer, mode="max-autotune", fullgraph=True
          )
          pipe.vae.decode = torch.compile(
-             pipe.vae.decode, mode="max-autotune", fullgraph=True
          )

-         # trigger torch compilation
-         print("running torch compiliation..")
-
-         pipe(
-             "dummy prompt to trigger torch compilation",
-             output_type="pil",
-             num_inference_steps=4, # use ~50 for [dev], smaller for [schnell]
-         ).images[0]
-
-         print("finished torch compilation")

-     except:
          pipe(
-             "a beautiful girl",
              output_type="pil",
-             num_inference_steps=4, # use ~50 for [dev], smaller for [schnell]
          ).images[0]
-         print("Pass error")
-
-
-     return pipe
-
-
- @torch.no_grad()
- def inference(request: TextToImageRequest, pipeline: Pipeline) -> Image:
-
-     delete_ca_che()
-     generator = Generator(pipeline.device).manual_seed(request.seed)
-
-     return pipeline(
-         request.prompt,
-         generator=generator,
-         guidance_scale=0.0,
-         num_inference_steps=4,
-         max_sequence_length=256,
-         height=request.height,
-         width=request.width,
-     ).images[0]
 
  import torch
  import torch._dynamo
  import gc
  import json
  import transformers
  from huggingface_hub.constants import HF_HUB_CACHE
  from transformers import T5EncoderModel, T5TokenizerFast
  from PIL.Image import Image
+ from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny, FluxTransformer2DModel, DiffusionPipeline
  from pipelines.models import TextToImageRequest
  from optimum.quanto import requantize
  from torch import Generator
  from torch._dynamo import config
  from torch._inductor import config as ind_config
+ from typing import Dict, Any, Callable
+ from functools import wraps
+
+ def error_handler(func: Callable):
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         try:
+             return func(*args, **kwargs)
+         except Exception as e:
+             print(f"Error in {func.__name__}: {str(e)}")
+     return wrapper
+
+ class TorchOptimizer:
+     def optimize_settings(self):
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+         torch.backends.cudnn.benchmark = True
+         torch.set_float32_matmul_precision("high")
+
+     def clear_cache(self):
+         torch.cuda.empty_cache()
+         torch.cuda.reset_max_memory_allocated()
+         torch.cuda.reset_peak_memory_stats()
+
+ class PipelineManager:
+     def __init__(self):
+         self.ckpt_root = "MyApricity/FLUX_OPT_SCHNELL_1.2"
+         self.revision_root = "488528b6f815bff1bbc747cf1e0947c77c544665"
+         self.pipeline = None
+         self.optimizer = TorchOptimizer()
+
+         # Configure environment
+         torch._dynamo.config.suppress_errors = True
+         os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
+         os.environ["TOKENIZERS_PARALLELISM"] = "True"
+
+         # Initialize torch settings
+         self.optimizer.optimize_settings()
+
+
+     def load_transformer(self):
+         transformer_path = os.path.join(
+             HF_HUB_CACHE,
+             "models--MyApricity--FLUX_OPT_SCHNELL_1.2/snapshots/488528b6f815bff1bbc747cf1e0947c77c544665"
+         )
+         return FluxTransformer2DModel.from_pretrained(
+             transformer_path,
+             torch_dtype=torch.bfloat16,
+             use_safetensors=False
+         )

+     @error_handler
+     def optimize_pipeline(self, pipe):
+         # Fuse QKV projections
          pipe.transformer.fuse_qkv_projections()
          pipe.vae.fuse_qkv_projections()

+         # Optimize memory layout
          pipe.transformer.to(memory_format=torch.channels_last)
          pipe.vae.to(memory_format=torch.channels_last)

+         # Configure torch inductor
          config = torch._inductor.config
+         config.disable_progress = False
+         config.conv_1x1_as_mm = True

+         # Compile modules
          pipe.transformer = torch.compile(
+             pipe.transformer,
+             mode="max-autotune",
+             fullgraph=True
          )
          pipe.vae.decode = torch.compile(
+             pipe.vae.decode,
+             mode="max-autotune",
+             fullgraph=True
          )

+         return pipe
+
+     def load_pipeline(self):
+         # Load transformer model
+         transformer_model = self.load_transformer()
+
+         # Create pipeline
+         pipe = DiffusionPipeline.from_pretrained(
+             self.ckpt_root,
+             revision=self.revision_root,
+             transformer=transformer_model,
+             torch_dtype=torch.bfloat16
+         )
+         pipe.to("cuda")

+         # Optimize the pipeline; @error_handler swallows exceptions and returns
+         # None on failure, so keep the unoptimized pipe in that case
+         pipe = self.optimize_pipeline(pipe) or pipe

+         # Trigger compilation
+         print("Running torch compilation...")
          pipe(
+             "dummy prompt to trigger torch compilation",
              output_type="pil",
+             num_inference_steps=4
          ).images[0]
+         print("Finished torch compilation")
+
+         return pipe
+
+     def run_inference(self, request: TextToImageRequest) -> Image:
+         if self.pipeline is None:
+             self.pipeline = self.load_pipeline()
+
+         self.optimizer.clear_cache()
+         generator = Generator(self.pipeline.device).manual_seed(request.seed)
+
+         return self.pipeline(
+             request.prompt,
+             generator=generator,
+             guidance_scale=0.0,
+             num_inference_steps=4,
+             max_sequence_length=256,
+             height=request.height,
+             width=request.width,
+         ).images[0]
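For context, here is a minimal sketch of how the refactored class might be driven by a caller. The import path for PipelineManager and the TextToImageRequest constructor arguments are assumptions for illustration; the diff only shows the request's prompt, seed, height, and width attributes being read.

# Hypothetical caller, assuming src/pipeline.py imports as `pipeline` and
# TextToImageRequest accepts these fields as keyword arguments.
from pipelines.models import TextToImageRequest
from pipeline import PipelineManager

manager = PipelineManager()  # __init__ sets env vars and TF32/cuDNN flags once

request = TextToImageRequest(
    prompt="a lighthouse at dawn, watercolor",
    seed=42,
    height=1024,
    width=1024,
)

# The first call builds the pipeline lazily and pays the torch.compile
# warm-up cost; later calls reuse self.pipeline and run only the 4-step denoise.
image = manager.run_inference(request)
image.save("out.png")

Note the entry-point change: the old module exposed pipeline_loader() and inference(request, pipeline) separately, while run_inference() now owns both lazy loading and generation.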