Safetensors
aredden committed
Commit 604f17d · Parent: 1f9e684

Add quantize embedders/modulation to argparse options

Files changed (2):
  1. main.py +18 -0
  2. util.py +4 -0
main.py CHANGED

@@ -129,6 +129,22 @@ def parse_args():
         + "and then saving the state_dict as a safetensors file), "
         + "which reduces the size of the checkpoint by about 50% & reduces startup time",
     )
+    parser.add_argument(
+        "-nqfm",
+        "--no-quantize-flow-modulation",
+        action="store_false",
+        default=True,
+        dest="quantize_modulation",
+        help="Disable quantization of the modulation layers in the flow model, adds ~2GB vram usage for moderate precision improvements",
+    )
+    parser.add_argument(
+        "-qfl",
+        "--quantize-flow-embedder-layers",
+        action="store_true",
+        default=False,
+        dest="quantize_flow_embedder_layers",
+        help="Quantize the flow embedder layers in the flow model, saves ~512MB vram usage, but precision loss is very noticeable",
+    )
     return parser.parse_args()

@@ -171,6 +187,8 @@ def main():
         offload_ae=args.offload_ae,
         offload_text_enc=args.offload_text_enc,
         prequantized_flow=args.prequantized_flow,
+        quantize_modulation=args.quantize_modulation,
+        quantize_flow_embedder_layers=args.quantize_flow_embedder_layers,
     )
     app.state.model = FluxPipeline.load_pipeline_from_config(config)
 
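The two new flags use opposite argparse actions, which is easy to misread: store_false means passing --no-quantize-flow-modulation turns modulation quantization off (it defaults to on), while store_true means --quantize-flow-embedder-layers opts in to embedder quantization (it defaults to off). A minimal, self-contained sketch of just these two options, separate from the commit's actual parser:

import argparse

# Standalone sketch of the two new options (illustrative, not the repo's parser).
parser = argparse.ArgumentParser()
parser.add_argument(
    "-nqfm",
    "--no-quantize-flow-modulation",
    action="store_false",  # passing the flag stores False...
    default=True,          # ...into a value that otherwise stays True
    dest="quantize_modulation",
)
parser.add_argument(
    "-qfl",
    "--quantize-flow-embedder-layers",
    action="store_true",   # opt-in: passing the flag stores True
    default=False,
    dest="quantize_flow_embedder_layers",
)

print(parser.parse_args([]))
# Namespace(quantize_flow_embedder_layers=False, quantize_modulation=True)
print(parser.parse_args(["-nqfm", "-qfl"]))
# Namespace(quantize_flow_embedder_layers=True, quantize_modulation=False)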
util.py CHANGED

@@ -135,6 +135,8 @@ def load_config(
     quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
     quant_ae: bool = False,
     prequantized_flow: bool = False,
+    quantize_modulation: bool = True,
+    quantize_flow_embedder_layers: bool = False,
 ) -> ModelSpec:
     """
     Load a model configuration using the passed arguments.

@@ -202,6 +204,8 @@ def load_config(
         }.get(quant_text_enc, None),
         ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
         prequantized_flow=prequantized_flow,
+        quantize_modulation=quantize_modulation,
+        quantize_flow_embedder_layers=quantize_flow_embedder_layers,
     )
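Both keyword defaults (True and False) match the argparse defaults above, so existing load_config callers keep the current behavior without changes. A hedged sketch of how the values thread from the CLI into the pipeline, mirroring the main.py hunk; the exact call site and the elided keyword arguments are assumptions, not shown in this diff:

# Assumed wiring based on the main.py hunk; other keyword arguments elided.
config = load_config(
    prequantized_flow=args.prequantized_flow,
    quantize_modulation=args.quantize_modulation,  # False only when -nqfm is passed
    quantize_flow_embedder_layers=args.quantize_flow_embedder_layers,  # True only when -qfl is passed
)
app.state.model = FluxPipeline.load_pipeline_from_config(config)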