aifartist committed
Commit 3a69d7a · verified · 1 Parent(s): 39de971

Update gradio-app.py

Files changed (1): gradio-app.py (+35 -5)
gradio-app.py CHANGED
@@ -33,12 +33,42 @@ pipe.set_progress_bar_config(disable=True)
 
 if TORCH_COMPILE:
     #optmode = 'max-autotune'
-    optmode = 'reduce-overhead'
-    pipe.text_encoder = torch.compile(pipe.text_encoder, mode=optmode)
-    pipe.tokenizer = torch.compile(pipe.tokenizer, mode=optmode)
-    pipe.unet = torch.compile(pipe.unet, mode=optmode)
-    pipe.vae = torch.compile(pipe.vae, mode=optmode)
+    #optmode = 'reduce-overhead'
+    #pipe.text_encoder = torch.compile(pipe.text_encoder, mode=optmode)
+    #pipe.tokenizer = torch.compile(pipe.tokenizer, mode=optmode)
+    #pipe.unet = torch.compile(pipe.unet, mode=optmode)
+    #pipe.vae = torch.compile(pipe.vae, mode=optmode)
+    doCompile = True
+    if doCompile:
+        config = CompilationConfig.Default()
+
+        try:
+            import xformers
+            config.enable_xformers = True
+        except ImportError:
+            print('xformers not installed, skipping')
+        try:
+            import triton
+            config.enable_triton = True
+        except ImportError:
+            print('Triton not installed, skipping')
+
+        config.enable_cuda_graph = True
+        config.enable_jit = True
+        config.enable_jit_freeze = True
+        config.trace_scheduler = False  # True CHECK THIS AGAIN
+        config.enable_cnn_optimization = True
+        config.preserve_parameters = False
+        config.prefer_lowp_gemm = True
+        config.enable_fused_linear_geglu = True
+
+        torch.jit.optimize_for_inference = True
+        torch.jit.enable_onednn_fusion = True
+        torch.jit.set_fusion_strategy([('STATIC', 1), ('DYNAMIC', 1)])
+
+        for p in pipe.text_encoder.parameters(): p.requires_grad=False
+        for p in pipe.vae.decoder.parameters(): p.requires_grad=False
+        for p in pipe.unet.parameters(): p.requires_grad=False
 
 def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231231):
     torch.manual_seed(seed)
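
Note on applying the config (not part of the commit): CompilationConfig.Default() matches the API of the stable-fast (sfast) library, which provides a compile() helper that takes a diffusers pipeline plus this config. The diff above only builds the config and freezes parameters; the sketch below shows how such a config is typically applied. The import path, model id, and prompt are assumptions for illustration, not taken from this repo, and the sfast module path differs between releases.

    # Hedged sketch, assuming CompilationConfig comes from the stable-fast library.
    # Import path varies between sfast releases; model id and prompt are placeholders.
    import torch
    from diffusers import AutoPipelineForText2Image
    from sfast.compilers.diffusion_pipeline_compiler import (
        compile as sfast_compile,
        CompilationConfig,
    )

    pipe = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo",        # placeholder model id, not from the diff
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")

    config = CompilationConfig.Default()
    config.enable_cuda_graph = True      # same flags the commit sets
    config.enable_jit = True
    config.enable_jit_freeze = True

    pipe = sfast_compile(pipe, config)   # wraps unet / vae / text_encoder

    # The first call is slow (tracing and CUDA-graph capture); later calls are fast.
    image = pipe("a cat wearing sunglasses",
                 num_inference_steps=4, guidance_scale=0.0).images[0]
    image.save("out.png")

The requires_grad=False loops in the diff fit this setup: JIT freezing (enable_jit_freeze) can only fold module weights into constants when they are not trainable, so the commit disables gradients on the text encoder, VAE decoder, and UNet before compilation.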