MilesCranmer commited on
Commit
a06de5e
1 Parent(s): 1f4e612

Just add back py-generated eval tree array

Browse files
Files changed (2) hide show
  1. julia/sr.jl +2 -8
  2. pysr/sr.py +20 -12
julia/sr.jl CHANGED
@@ -285,9 +285,7 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
285
  return nothing
286
  end
287
  op_idx = tree.op
288
- @inbounds @simd for i=1:clen
289
- cumulator[i] = UNAOP(op_idx, cumulator[i])
290
- end
291
  @inbounds for i=1:clen
292
  if isinf(cumulator[i]) || isnan(cumulator[i])
293
  return nothing
@@ -303,12 +301,8 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
303
  if array2 === nothing
304
  return nothing
305
  end
306
-
307
  op_idx = tree.op
308
-
309
- @inbounds @simd for i=1:clen
310
- cumulator[i] = BINOP(op_idx, cumulator[i], array2[i])
311
- end
312
  @inbounds for i=1:clen
313
  if isinf(cumulator[i]) || isnan(cumulator[i])
314
  return nothing
 
285
  return nothing
286
  end
287
  op_idx = tree.op
288
+ UNAOP!(cumulator, op_idx, clen)
 
 
289
  @inbounds for i=1:clen
290
  if isinf(cumulator[i]) || isnan(cumulator[i])
291
  return nothing
 
301
  if array2 === nothing
302
  return nothing
303
  end
 
304
  op_idx = tree.op
305
+ BINOP!(cumulator, array2, op_idx, clen)
 
 
 
306
  @inbounds for i=1:clen
307
  if isinf(cumulator[i]) || isnan(cumulator[i])
308
  return nothing
pysr/sr.py CHANGED
@@ -286,27 +286,35 @@ const limitPowComplexity = {"true" if limitPowComplexity else "false"}
286
 
287
  op_runner = ""
288
  if len(binary_operators) > 0:
289
- op_runner += f"""
290
- @inline function BINOP(i::Int, x::Float32, y::Float32)::Float32
291
- if i == 1
292
- return @fastmath {binary_operators[0]}(x, y)"""
 
 
293
  for i in range(1, len(binary_operators)):
294
  op_runner += f"""
295
- elseif i == {i+1}
296
- return @fastmath {binary_operators[i]}(x, y)"""
 
 
297
  op_runner += """
298
  end
299
  end"""
300
 
301
  if len(unary_operators) > 0:
302
- op_runner += f"""
303
- @inline function UNAOP(i::Int, x::Float32)::Float32
304
- if i == 1
305
- return @fastmath {unary_operators[0]}(x)"""
 
 
306
  for i in range(1, len(unary_operators)):
307
  op_runner += f"""
308
- elseif i == {i+1}
309
- return @fastmath {unary_operators[i]}(x)"""
 
 
310
  op_runner += """
311
  end
312
  end"""
 
286
 
287
  op_runner = ""
288
  if len(binary_operators) > 0:
289
+ op_runner += """
290
+ @inline function BINOP!(x::Array{Float32, 1}, y::Array{Float32, 1}, i::Int, clen::Int)
291
+ if i === 1
292
+ @inbounds @simd for j=1:clen
293
+ x[j] = """f"{binary_operators[0]}""""(x[j], y[j])
294
+ end"""
295
  for i in range(1, len(binary_operators)):
296
  op_runner += f"""
297
+ elseif i === {i+1}
298
+ @inbounds @simd for j=1:clen
299
+ x[j] = {binary_operators[i]}(x[j], y[j])
300
+ end"""
301
  op_runner += """
302
  end
303
  end"""
304
 
305
  if len(unary_operators) > 0:
306
+ op_runner += """
307
+ @inline function UNAOP!(x::Array{Float32, 1}, i::Int, clen::Int)
308
+ if i === 1
309
+ @inbounds @simd for j=1:clen
310
+ x[j] = """f"{unary_operators[0]}(x[j])""""
311
+ end"""
312
  for i in range(1, len(unary_operators)):
313
  op_runner += f"""
314
+ elseif i === {i+1}
315
+ @inbounds @simd for j=1:clen
316
+ x[j] = {unary_operators[i]}(x[j])
317
+ end"""
318
  op_runner += """
319
  end
320
  end"""