zwv9 commited on
Commit
fdc9e50
1 Parent(s): d6b300f
Files changed (1) hide show
  1. wfx.py +6 -1
wfx.py CHANGED
@@ -22,6 +22,8 @@ def parse_args():
22
  args.add_argument('--model', type=str, required=True)
23
  args.add_argument('--custom-pipeline', type=str, default=None)
24
  args.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
 
 
25
  return args.parse_args()
26
 
27
  def quantize_unet(m):
@@ -68,7 +70,10 @@ class WFX():
68
  except ImportError:
69
  logger.warning('triton not found, disabling triton')
70
 
71
- self.compiler_config.enable_cuda_graph = True
 
 
 
72
 
73
  for key in self.compiler_config.__dict__:
74
  logger.info(f'cc - {key}: {self.compiler_config.__dict__[key]}')
 
22
  args.add_argument('--model', type=str, required=True)
23
  args.add_argument('--custom-pipeline', type=str, default=None)
24
  args.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
25
+ args.add_argument('--enable-cuda-graph', action='store_true', default=False)
26
+ args.add_argument('--disable-prefer-lowp-gemm', action='store_true', default=False)
27
  return args.parse_args()
28
 
29
  def quantize_unet(m):
 
70
  except ImportError:
71
  logger.warning('triton not found, disabling triton')
72
 
73
+ self.compiler_config.enable_cuda_graph = args.enable_cuda_graph
74
+
75
+ if args.disable_prefer_lowp_gemm:
76
+ self.compiler_config.prefer_lowp_gemm = False
77
 
78
  for key in self.compiler_config.__dict__:
79
  logger.info(f'cc - {key}: {self.compiler_config.__dict__[key]}')