File size: 37,131 Bytes
904ef7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
#include <stdint.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>

#include <cstdio>


#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")


template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
	return (val + divisor - 1) / divisor;
}

template <typename scalar_t>
__global__ void kernel_sh(
    const scalar_t * __restrict__ inputs, 
    scalar_t * outputs, 
    uint32_t B, uint32_t D, uint32_t C,
    scalar_t * dy_dx
) {
	const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;
	if (b >= B) return;

	const uint32_t C2 = C * C;

	// locate
	inputs += b * D;
	outputs += b * C2;

	scalar_t x = inputs[0], y = inputs[1], z = inputs[2];

	scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;
	scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2;
	scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2;

	auto write_sh = [&]() {
		outputs[0] = 0.28209479177387814f ;                          // 1/(2*sqrt(pi))
		if (C <= 1) { return; }
		outputs[1] = -0.48860251190291987f*y ;                               // -sqrt(3)*y/(2*sqrt(pi))
		outputs[2] = 0.48860251190291987f*z ;                                // sqrt(3)*z/(2*sqrt(pi))
		outputs[3] = -0.48860251190291987f*x ;                               // -sqrt(3)*x/(2*sqrt(pi))
		if (C <= 2) { return; }
		outputs[4] = 1.0925484305920792f*xy ;                                // sqrt(15)*xy/(2*sqrt(pi))
		outputs[5] = -1.0925484305920792f*yz ;                               // -sqrt(15)*yz/(2*sqrt(pi))
		outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ;                         // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))
		outputs[7] = -1.0925484305920792f*xz ;                               // -sqrt(15)*xz/(2*sqrt(pi))
		outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ;                              // sqrt(15)*(x2 - y2)/(4*sqrt(pi))
		if (C <= 3) { return; }
		outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ;                         // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
		outputs[10] = 2.8906114426405538f*xy*z ;                             // sqrt(105)*xy*z/(2*sqrt(pi))
		outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ;                                // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))
		outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ;                         // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))
		outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ;                                // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))
		outputs[14] = 1.4453057213202769f*z*(x2 - y2) ;                              // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))
		outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ;                                // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
		if (C <= 4) { return; }
		outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ;                             // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))
		outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ;                                // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))
		outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ;                               // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))
		outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ;                               // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))
		outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ;                                // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))
		outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ;                               // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))
		outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ;                                // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))
		outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ;                                // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))
		outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ;                         // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
		if (C <= 5) { return; }
		outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                            // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
		outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ;                           // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))
		outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ;                         // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
		outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ;                              // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))
		outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ;                             // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
		outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ;                            // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))
		outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ;                             // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
		outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ;                               // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))
		outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ;                         // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))
		outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ;                         // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
		outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                            // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
		if (C <= 6) { return; }
		outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                               // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
		outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                            // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
		outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ;                             // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
		outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ;                               // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
		outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                           // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
		outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                            // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
		outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ;                         // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))
		outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                            // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
		outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ;                               // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))
		outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ;                               // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))
		outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ;                          // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
		outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                            // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
		outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ;                         // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
		if (C <= 7) { return; }
		outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ;                              // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))
		outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                             // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
		outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ;                          // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))
		outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ;                           // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
		outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ;                              // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
		outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                              // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
		outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ;                              // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
		outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ;                              // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))
		outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ;                              // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
		outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ;                         // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))
		outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ;                              // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
		outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ;                         // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
		outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                          // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))
		outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ;                               // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
		outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ;                              // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))
	};

	write_sh();

	if (dy_dx) {
		scalar_t *dx = dy_dx + b * D * C2;
		scalar_t *dy = dx + C2;
		scalar_t *dz = dy + C2;

		auto write_sh_dx = [&]() {
			dx[0] = 0.0f ;                             // 0
			if (C <= 1) { return; }
			dx[1] = 0.0f ;                             // 0
			dx[2] = 0.0f ;                             // 0
			dx[3] = -0.48860251190291992f ;                          // -sqrt(3)/(2*sqrt(pi))
			if (C <= 2) { return; }
			dx[4] = 1.0925484305920792f*y ;                          // sqrt(15)*y/(2*sqrt(pi))
			dx[5] = 0.0f ;                             // 0
			dx[6] = 0.0f ;                             // 0
			dx[7] = -1.0925484305920792f*z ;                         // -sqrt(15)*z/(2*sqrt(pi))
			dx[8] = 1.0925484305920792f*x ;                          // sqrt(15)*x/(2*sqrt(pi))
			if (C <= 3) { return; }
			dx[9] = -3.5402615395598609f*xy ;                                // -3*sqrt(70)*xy/(4*sqrt(pi))
			dx[10] = 2.8906114426405538f*yz ;                                // sqrt(105)*yz/(2*sqrt(pi))
			dx[11] = 0.0f ;                            // 0
			dx[12] = 0.0f ;                            // 0
			dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ;                          // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
			dx[14] = 2.8906114426405538f*xz ;                                // sqrt(105)*xz/(2*sqrt(pi))
			dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ;                               // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
			if (C <= 4) { return; }
			dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ;                           // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))
			dx[17] = -10.620784618679583f*xy*z ;                             // -9*sqrt(70)*xy*z/(4*sqrt(pi))
			dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ;                         // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))
			dx[19] = 0.0f ;                            // 0
			dx[20] = 0.0f ;                            // 0
			dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ;                         // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
			dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ;                         // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
			dx[23] = 5.3103923093397913f*z*(-x2 + y2) ;                              // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
			dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ;                           // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
			if (C <= 5) { return; }
			dx[25] = 13.127641136803401f*xy*(-x2 + y2) ;                             // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))
			dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ;                          // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))
			dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ;                         // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))
			dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ;                         // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))
			dx[29] = 0.0f ;                            // 0
			dx[30] = 0.0f ;                            // 0
			dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ;                          // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
			dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ;                         // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
			dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ;                         // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi))
			dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
			dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ;                               // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
			if (C <= 6) { return; }
			dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ;                             // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
			dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ;                           // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))
			dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ;                           // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
			dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ;                              // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi))
			dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                             // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
			dx[41] = 0.0f ;                            // 0
			dx[42] = 0.0f ;                            // 0
			dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                              // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
			dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                             // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
			dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ;                              // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
			dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ;                           // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
			dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ;                           // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
			dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                             // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
			if (C <= 7) { return; }
			dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ;                         // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi))
			dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ;                            // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
			dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ;                             // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
			dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ;                          // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
			dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ;                            // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi))
			dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                         // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
			dx[55] = 0.0f ;                            // 0
			dx[56] = 0.0f ;                            // 0
			dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ;                         // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
			dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                         // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
			dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ;                              // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))
			dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
			dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ;                              // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))
			dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                            // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
			dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ;                         // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
		};

		auto write_sh_dy = [&]() {
			dy[0] = 0.0f ;                             // 0
			if (C <= 1) { return; }
			dy[1] = -0.48860251190291992f ;                          // -sqrt(3)/(2*sqrt(pi))
			dy[2] = 0.0f ;                             // 0
			dy[3] = 0.0f ;                             // 0
			if (C <= 2) { return; }
			dy[4] = 1.0925484305920792f*x ;                          // sqrt(15)*x/(2*sqrt(pi))
			dy[5] = -1.0925484305920792f*z ;                         // -sqrt(15)*z/(2*sqrt(pi))
			dy[6] = 0.0f ;                             // 0
			dy[7] = 0.0f ;                             // 0
			dy[8] = -1.0925484305920792f*y ;                         // -sqrt(15)*y/(2*sqrt(pi))
			if (C <= 3) { return; }
			dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ;                                // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
			dy[10] = 2.8906114426405538f*xz ;                                // sqrt(105)*xz/(2*sqrt(pi))
			dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ;                          // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
			dy[12] = 0.0f ;                            // 0
			dy[13] = 0.0f ;                            // 0
			dy[14] = -2.8906114426405538f*yz ;                               // -sqrt(105)*yz/(2*sqrt(pi))
			dy[15] = 3.5402615395598609f*xy ;                                // 3*sqrt(70)*xy/(4*sqrt(pi))
			if (C <= 4) { return; }
			dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ;                           // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
			dy[17] = 5.3103923093397913f*z*(-x2 + y2) ;                              // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
			dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ;                         // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
			dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ;                         // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
			dy[20] = 0.0f ;                            // 0
			dy[21] = 0.0f ;                            // 0
			dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ;                         // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))
			dy[23] = 10.620784618679583f*xy*z ;                              // 9*sqrt(70)*xy*z/(4*sqrt(pi))
			dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ;                          // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi))
			if (C <= 5) { return; }
			dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ;                               // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
			dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
			dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ;                         // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
			dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ;                         // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
			dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ;                          // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
			dy[30] = 0.0f ;                            // 0
			dy[31] = 0.0f ;                            // 0
			dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ;                         // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))
			dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ;                         // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))
			dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ;                         // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))
			dy[35] = 13.127641136803401f*xy*(x2 - y2) ;                              // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))
			if (C <= 6) { return; }
			dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                             // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
			dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ;                           // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
			dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ;                           // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
			dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ;                              // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
			dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                             // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
			dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                              // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
			dy[42] = 0.0f ;                            // 0
			dy[43] = 0.0f ;                            // 0
			dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ;                              // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))
			dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ;                              // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))
			dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ;                          // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
			dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ;                            // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))
			dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                              // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
			if (C <= 7) { return; }
			dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ;                         // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
			dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                            // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
			dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ;                                // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))
			dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
			dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ;                             // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
			dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                         // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
			dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ;                         // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
			dy[56] = 0.0f ;                            // 0
			dy[57] = 0.0f ;                            // 0
			dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ;                          // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
			dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ;                           // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
			dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ;                         // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
			dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ;                              // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
			dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                             // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
			dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                                // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
		};

		auto write_sh_dz = [&]() {
			dz[0] = 0.0f ;                             // 0
			if (C <= 1) { return; }
			dz[1] = 0.0f ;                             // 0
			dz[2] = 0.48860251190291992f ;                           // sqrt(3)/(2*sqrt(pi))
			dz[3] = 0.0f ;                             // 0
			if (C <= 2) { return; }
			dz[4] = 0.0f ;                             // 0
			dz[5] = -1.0925484305920792f*y ;                         // -sqrt(15)*y/(2*sqrt(pi))
			dz[6] = 1.8923493915151202f*z ;                          // 3*sqrt(5)*z/(2*sqrt(pi))
			dz[7] = -1.0925484305920792f*x ;                         // -sqrt(15)*x/(2*sqrt(pi))
			dz[8] = 0.0f ;                             // 0
			if (C <= 3) { return; }
			dz[9] = 0.0f ;                             // 0
			dz[10] = 2.8906114426405538f*xy ;                                // sqrt(105)*xy/(2*sqrt(pi))
			dz[11] = -4.5704579946446566f*yz ;                               // -5*sqrt(42)*yz/(4*sqrt(pi))
			dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ;                            // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))
			dz[13] = -4.5704579946446566f*xz ;                               // -5*sqrt(42)*xz/(4*sqrt(pi))
			dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ;                                // sqrt(105)*(x2 - y2)/(4*sqrt(pi))
			dz[15] = 0.0f ;                            // 0
			if (C <= 4) { return; }
			dz[16] = 0.0f ;                            // 0
			dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ;                          // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
			dz[18] = 13.246445740605839f*xy*z ;                              // 21*sqrt(5)*xy*z/(2*sqrt(pi))
			dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ;                          // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))
			dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ;                          // (105*z**3 - 45*z)/(4*sqrt(pi))
			dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ;                          // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))
			dz[22] = 6.6232228703029197f*z*(x2 - y2) ;                               // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))
			dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ;                          // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
			dz[24] = 0.0f ;                            // 0
			if (C <= 5) { return; }
			dz[25] = 0.0f ;                            // 0
			dz[26] = 8.3026492595241645f*xy*(x2 - y2) ;                              // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))
			dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ;                         // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))
			dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ;                         // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))
			dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ;                         // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))
			dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ;                           // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi))
			dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ;                         // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))
			dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ;                          // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))
			dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ;                         // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))
			dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ;                            // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
			dz[35] = 0.0f ;                            // 0
			if (C <= 6) { return; }
			dz[36] = 0.0f ;                            // 0
			dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                              // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
			dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ;                            // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))
			dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ;                          // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi))
			dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ;                              // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))
			dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ;                               // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
			dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ;                              // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))
			dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ;                               // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
			dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ;                               // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi))
			dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ;                          // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi))
			dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ;                           // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
			dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                              // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
			dz[48] = 0.0f ;                            // 0
			if (C <= 7) { return; }
			dz[49] = 0.0f ;                            // 0
			dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                                // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
			dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                             // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
			dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ;                              // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))
			dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ;                         // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))
			dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ;                            // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
			dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ;                           // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
			dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ;                           // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))
			dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ;                           // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
			dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ;                         // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))
			dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ;                         // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))
			dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ;                            // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
			dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                             // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
			dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ;                            // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
			dz[63] = 0.0f ;                            // 0
		};
		write_sh_dx();
		write_sh_dy();
		write_sh_dz();
	}
}


template <typename scalar_t>
__global__ void kernel_sh_backward(
    const scalar_t * __restrict__ grad,
	const scalar_t * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t C,
    const scalar_t * __restrict__ dy_dx,
    scalar_t * grad_inputs
) {
	const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
	const uint32_t b = t / D;
	if (b >= B) return;

	const uint32_t d = t - b * D;
	const uint32_t C2 = C * C;

	// locate
	grad += b * C2;
	dy_dx += b * D * C2 + d * C2;

	for (int ch = 0; ch < C2; ch++) {
		grad_inputs[t] += grad[ch] * dy_dx[ch];
		//printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);
	}

}

// inputs: [B, D], float, in [0, 1]
// outputs: [B, L * C], float
template <typename scalar_t>
void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) {
	static constexpr uint32_t N_THREADS = 256;
	kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx);
}


template <typename scalar_t>
void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {
	static constexpr uint32_t N_THREADS = 256;
	kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);
}


void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);
    // CHECK_CUDA(dy_dx);
    
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);
    // CHECK_CONTIGUOUS(dy_dx);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);
    // CHECK_IS_FLOATING(dy_dx);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    inputs.scalar_type(), "sh_encode_forward_cuda", ([&] {
		sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr);
    }));	
}

void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {    
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(dy_dx);
    CHECK_CUDA(grad_inputs);
    
    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(dy_dx);
    CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(dy_dx);
    CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "sh_encode_backward_cuda", ([&] {
    	sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());
    }));	
}