DotIN13 commited on
Commit
8f82b54
1 Parent(s): e3f0d7e

Fix inconsistent types in custom_autotune.py

Browse files
Files changed (1) hide show
  1. custom_autotune.py +7 -2
custom_autotune.py CHANGED
@@ -81,16 +81,21 @@ class Autotuner(triton.KernelInterface):
81
  # In my testing this gives decent results, and greatly reduces the amount of tuning required
82
  if self.nearest_power_of_two:
83
  key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
84
-
85
  if key not in self.cache:
86
  # prune configs
87
  pruned_configs = self.prune_configs(kwargs)
88
  bench_start = time.time()
89
  timings = {config: self._bench(*args, config=config, **kwargs)
90
  for config in pruned_configs}
 
 
 
 
 
91
  bench_end = time.time()
92
  self.bench_time = bench_end - bench_start
93
- self.cache[key] = builtins.min(timings, key=timings.get)
 
94
  self.hook(args)
95
  self.configs_timings = timings
96
  config = self.cache[key]
 
81
  # In my testing this gives decent results, and greatly reduces the amount of tuning required
82
  if self.nearest_power_of_two:
83
  key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
 
84
  if key not in self.cache:
85
  # prune configs
86
  pruned_configs = self.prune_configs(kwargs)
87
  bench_start = time.time()
88
  timings = {config: self._bench(*args, config=config, **kwargs)
89
  for config in pruned_configs}
90
+ temp = {}
91
+ for config in pruned_configs:
92
+ if isinstance(self._bench(*args, config=config, **kwargs), float) :
93
+ continue
94
+ temp[config] = {self._bench(*args, config=config, **kwargs)}
95
  bench_end = time.time()
96
  self.bench_time = bench_end - bench_start
97
+
98
+ self.cache[key] = builtins.min(temp, key=timings.get)
99
  self.hook(args)
100
  self.configs_timings = timings
101
  config = self.cache[key]