Safetensors
aredden commited on
Commit
289aa1f
·
1 Parent(s): f9ba912

Small fixes & clean up

Browse files
Files changed (4) hide show
  1. .gitignore +5 -1
  2. image_encoder.py +1 -2
  3. main_gr.py +1 -1
  4. util.py +6 -11
.gitignore CHANGED
@@ -9,4 +9,8 @@ __pycache__
9
  *.mp3
10
  *.mp3
11
  *.txt
12
- .copilotignore
 
 
 
 
 
9
  *.mp3
10
  *.mp3
11
  *.txt
12
+ .copilotignore
13
+ .misc
14
+ BFL-flux-diffusers
15
+ .env
16
+ .env.*
image_encoder.py CHANGED
@@ -54,7 +54,7 @@ def test_real_img():
54
  img_hwc = torch.from_numpy(im).cuda().type(torch.float32)
55
  img_chw = img_hwc.permute(2, 0, 1).contiguous()
56
  img_gray = img_hwc.mean(dim=2, keepdim=False).contiguous().clamp(0, 255)
57
- tj = TurboImage()
58
  o = tj.encode_torch(img_chw)
59
  o2 = tj.encode_torch(img_hwc)
60
  o3 = tj.encode_torch(img_gray)
@@ -64,7 +64,6 @@ def test_real_img():
64
  f.write(o)
65
  with open("out_gray.jpg", "wb") as f:
66
  f.write(o3)
67
- # print(o)
68
 
69
 
70
  if __name__ == "__main__":
 
54
  img_hwc = torch.from_numpy(im).cuda().type(torch.float32)
55
  img_chw = img_hwc.permute(2, 0, 1).contiguous()
56
  img_gray = img_hwc.mean(dim=2, keepdim=False).contiguous().clamp(0, 255)
57
+ tj = ImageEncoder()
58
  o = tj.encode_torch(img_chw)
59
  o2 = tj.encode_torch(img_hwc)
60
  o3 = tj.encode_torch(img_gray)
 
64
  f.write(o)
65
  with open("out_gray.jpg", "wb") as f:
66
  f.write(o3)
 
67
 
68
 
69
  if __name__ == "__main__":
main_gr.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
 
3
  from flux_pipeline import FluxPipeline
4
- import gradio as gr
5
  from PIL import Image
6
 
7
 
 
1
  import torch
2
 
3
  from flux_pipeline import FluxPipeline
4
+ import gradio as gr # type: ignore
5
  from PIL import Image
6
 
7
 
util.py CHANGED
@@ -236,18 +236,13 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
236
  missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
237
  print_load_warning(missing, unexpected)
238
  if config.ae_quantization_dtype is not None:
239
- from quantize_swap_and_dispatch import _full_quant, into_qtype
240
-
241
  ae.to(into_device(config.ae_device))
242
- _full_quant(
243
- ae,
244
- max_quants=8000,
245
- current_quants=0,
246
- quantization_dtype=into_qtype(config.ae_quantization_dtype),
247
- )
248
- if config.offload_vae:
249
- ae.to("cpu")
250
- torch.cuda.empty_cache()
251
  return ae
252
 
253
 
 
236
  missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
237
  print_load_warning(missing, unexpected)
238
  if config.ae_quantization_dtype is not None:
 
 
239
  ae.to(into_device(config.ae_device))
240
+ from float8_quantize import recursive_swap_linears
241
+
242
+ recursive_swap_linears(ae)
243
+ if config.offload_vae:
244
+ ae.to("cpu")
245
+ torch.cuda.empty_cache()
 
 
 
246
  return ae
247
 
248