Add quantize embedders/modulation to argparse options
main.py
CHANGED
@@ -129,6 +129,22 @@ def parse_args():
+        "and then saving the state_dict as a safetensors file), "
+        "which reduces the size of the checkpoint by about 50% & reduces startup time",
     )
+    parser.add_argument(
+        "-nqfm",
+        "--no-quantize-flow-modulation",
+        action="store_false",
+        default=True,
+        dest="quantize_modulation",
+        help="Disable quantization of the modulation layers in the flow model, adds ~2GB vram usage for moderate precision improvements",
+    )
+    parser.add_argument(
+        "-qfl",
+        "--quantize-flow-embedder-layers",
+        action="store_true",
+        default=False,
+        dest="quantize_flow_embedder_layers",
+        help="Quantize the flow embedder layers in the flow model, saves ~512MB vram usage, but precision loss is very noticeable",
+    )
     return parser.parse_args()
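A note on the flag semantics, since the inverted option is easy to misread: --no-quantize-flow-modulation uses action="store_false" with default=True, so modulation quantization stays enabled unless the flag is passed, while --quantize-flow-embedder-layers is a plain opt-in. A minimal, self-contained sketch of how the two flags parse (the flag names and dests are copied from the diff; the rest is illustrative):

import argparse

# Minimal reproduction of the two new flags, showing the
# store_false / store_true + dest pattern used in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-nqfm",
    "--no-quantize-flow-modulation",
    action="store_false",  # passing the flag stores False
    default=True,          # omitting it leaves the value True
    dest="quantize_modulation",
)
parser.add_argument(
    "-qfl",
    "--quantize-flow-embedder-layers",
    action="store_true",   # opt-in: passing the flag stores True
    default=False,
    dest="quantize_flow_embedder_layers",
)

print(parser.parse_args([]))
# Namespace(quantize_flow_embedder_layers=False, quantize_modulation=True)
print(parser.parse_args(["-nqfm", "-qfl"]))
# Namespace(quantize_flow_embedder_layers=True, quantize_modulation=False)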
@@ -171,6 +187,8 @@ def main():
         offload_ae=args.offload_ae,
         offload_text_enc=args.offload_text_enc,
         prequantized_flow=args.prequantized_flow,
+        quantize_modulation=args.quantize_modulation,
+        quantize_flow_embedder_layers=args.quantize_flow_embedder_layers,
     )
     app.state.model = FluxPipeline.load_pipeline_from_config(config)
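Taken together, the defaults quantize the modulation layers and leave the embedder layers unquantized. Per the help strings above, running python main.py -nqfm spends roughly 2GB of additional VRAM for moderate precision improvements, while python main.py -qfl reclaims about 512MB of VRAM at a very noticeable precision cost.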
util.py
CHANGED
@@ -135,6 +135,8 @@ def load_config(
     quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
     quant_ae: bool = False,
     prequantized_flow: bool = False,
+    quantize_modulation: bool = True,
+    quantize_flow_embedder_layers: bool = False,
 ) -> ModelSpec:
     """
     Load a model configuration using the passed arguments.

@@ -202,6 +204,8 @@ def load_config(
         }.get(quant_text_enc, None),
         ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
         prequantized_flow=prequantized_flow,
+        quantize_modulation=quantize_modulation,
+        quantize_flow_embedder_layers=quantize_flow_embedder_layers,
     )
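For completeness, a sketch of how the new keyword arguments reach the pipeline. Only the parameters visible in this diff are shown, the flux_pipeline import path is an assumption, and load_config takes further arguments not pictured here (assumed to have usable defaults), so treat this as illustrative rather than a verbatim call:

from util import load_config
from flux_pipeline import FluxPipeline  # import path is an assumption

# Flip both new options away from their defaults: keep the modulation
# layers in full precision (adds ~2GB VRAM, moderate precision gain)
# and quantize the embedder layers (saves ~512MB VRAM, noticeably lossier).
config = load_config(
    quant_text_enc="qint4",
    quant_ae=True,
    prequantized_flow=False,
    quantize_modulation=False,
    quantize_flow_embedder_layers=True,
)
model = FluxPipeline.load_pipeline_from_config(config)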