songxxzp commited on
Commit
c7d8998
1 Parent(s): 3485994

Update CPU kernel loading method

Browse files
Files changed (1) hide show
  1. quantization.py +49 -14
quantization.py CHANGED
@@ -103,7 +103,7 @@ class CPUKernel:
103
  self.int8WeightExtractionFloat = None
104
  self.int4WeightExtractionFloat = None
105
  self.int4WeightCompression = None
106
- self.SetNumThreads = None
107
 
108
  try:
109
  if not os.path.exists(default_cpu_kernel_code_path):
@@ -127,38 +127,74 @@ class CPUKernel:
127
  if compile_parallel_kernel and source_code == default_cpu_kernel_code_path:
128
  source_code = default_cpu_parallel_kernel_code_path
129
 
 
 
130
  if (not kernel_file) or (not os.path.exists(kernel_file)):
131
  print("No compiled kernel found.")
132
  try:
133
  if os.path.exists(source_code):
134
  print("Compiling kernels :", source_code)
135
  kernel_file = source_code[:-2] + ".so"
 
136
  if compile_parallel_kernel:
137
  compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(source_code, kernel_file)
138
  print("Compiling", compile_command)
139
  exit_state = os.system(compile_command)
140
- if exit_state:
141
- print("Compile failed, using default cpu kernel code.")
 
 
 
 
 
 
 
 
 
 
 
 
142
  compile_parallel_kernel = False
143
  source_code = default_cpu_kernel_code_path
144
  kernel_file = source_code[:-2] + ".so"
145
- compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file)
146
- print("Compiling", compile_command)
147
- exit_state = os.system(compile_command)
148
- else:
149
  compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file)
150
  print("Compiling", compile_command)
151
  exit_state = os.system(compile_command)
152
-
153
- print("Kernels compiled :", kernel_file)
 
 
 
 
 
 
 
 
 
 
154
  else:
155
  print("Kernel source code not found.")
156
  return
157
  except:
158
- print("Failed to build kernel.")
 
 
 
159
  return
160
- if kernel_file:
161
- kernels = ctypes.cdll.LoadLibrary(kernel_file)
 
 
 
 
 
 
 
 
 
 
162
  self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float
163
  self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float
164
  self.int4WeightCompression = kernels.compress_int4_weight
@@ -167,11 +203,10 @@ class CPUKernel:
167
  self.SetNumThreads = kernels.set_num_threads
168
  except:
169
  print("No set_num_threads() found in kernel.")
170
- self.SetNumThreads = lambda x: x
171
  self.load = True
172
- print("Load kernel :", kernel_file)
173
  else:
174
  print("Failed to load kernel.")
 
175
 
176
  if compile_parallel_kernel:
177
  if parallel_num is None:
 
103
  self.int8WeightExtractionFloat = None
104
  self.int4WeightExtractionFloat = None
105
  self.int4WeightCompression = None
106
+ self.SetNumThreads = lambda x: x
107
 
108
  try:
109
  if not os.path.exists(default_cpu_kernel_code_path):
 
127
  if compile_parallel_kernel and source_code == default_cpu_kernel_code_path:
128
  source_code = default_cpu_parallel_kernel_code_path
129
 
130
+ kernels = None
131
+
132
  if (not kernel_file) or (not os.path.exists(kernel_file)):
133
  print("No compiled kernel found.")
134
  try:
135
  if os.path.exists(source_code):
136
  print("Compiling kernels :", source_code)
137
  kernel_file = source_code[:-2] + ".so"
138
+
139
  if compile_parallel_kernel:
140
  compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(source_code, kernel_file)
141
  print("Compiling", compile_command)
142
  exit_state = os.system(compile_command)
143
+ if not exit_state:
144
+ try:
145
+ kernels = ctypes.cdll.LoadLibrary(kernel_file)
146
+ print("Load kernel :", kernel_file)
147
+ except:
148
+ kernels = None
149
+ print("Load parallel cpu kernel failed, using default cpu kernel code:")
150
+ import traceback
151
+ exception = traceback.format_exc()
152
+ print(exception)
153
+ else:
154
+ print("Compile default cpu kernel failed, using default cpu kernel code.")
155
+
156
+ if kernels is None: # adjust config, use default cpu kernel
157
  compile_parallel_kernel = False
158
  source_code = default_cpu_kernel_code_path
159
  kernel_file = source_code[:-2] + ".so"
160
+
161
+ if kernels is None:
 
 
162
  compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file)
163
  print("Compiling", compile_command)
164
  exit_state = os.system(compile_command)
165
+ if not exit_state:
166
+ try:
167
+ kernels = ctypes.cdll.LoadLibrary(kernel_file)
168
+ print("Load kernel :", kernel_file)
169
+ except:
170
+ kernels = None
171
+ print("Load default cpu kernel failed:")
172
+ import traceback
173
+ exception = traceback.format_exc()
174
+ print(exception)
175
+ else:
176
+ print("Compile default cpu kernel failed.")
177
  else:
178
  print("Kernel source code not found.")
179
  return
180
  except:
181
+ print("Failed to build cpu kernel:")
182
+ import traceback
183
+ exception = traceback.format_exc()
184
+ print(exception)
185
  return
186
+ else:
187
+ try:
188
+ kernels = ctypes.cdll.LoadLibrary(kernel_file)
189
+ print("Load kernel :", kernel_file)
190
+ except:
191
+ kernels = None
192
+ print("Load custom cpu kernel failed:")
193
+ import traceback
194
+ exception = traceback.format_exc()
195
+ print(exception)
196
+
197
+ if kernels is not None:
198
  self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float
199
  self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float
200
  self.int4WeightCompression = kernels.compress_int4_weight
 
203
  self.SetNumThreads = kernels.set_num_threads
204
  except:
205
  print("No set_num_threads() found in kernel.")
 
206
  self.load = True
 
207
  else:
208
  print("Failed to load kernel.")
209
+ return
210
 
211
  if compile_parallel_kernel:
212
  if parallel_num is None: