Spaces:
Sleeping
Sleeping
File size: 2,480 Bytes
899c526 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
#ifndef DISPATCH_H
#define DISPATCH_H
#include <torch/extension.h>
#include "so3.h"
#include "rxso3.h"
#include "se3.h"
#include "sim3.h"
// Emits one `case enum_type:` label for a switch over at::ScalarType.
// Inside that case:
//   - `scalar_t` is an alias for `type`;
//   - `group_t` is the Lie group template instantiated with `type`, selected by
//     `group_index`: 1 -> SO3, 2 -> RxSO3, 3 -> SE3, 4 -> Sim3.
// The callable passed in __VA_ARGS__ is invoked with no arguments and its
// result is returned from the enclosing function/lambda.
//
// `group_index` is evaluated exactly once. Any value other than 1..4 raises an
// error via TORCH_CHECK rather than falling through into the next `case` of
// the enclosing switch.
//
// NOTE: the last line of this macro must not end in a backslash, otherwise the
// following source line is spliced into the macro's replacement list.
#define PRIVATE_CASE_TYPE(group_index, enum_type, type, ...)                  \
  case enum_type: {                                                           \
    using scalar_t = type;                                                    \
    /* evaluate once: reused by the error message in the default branch */   \
    const auto _group_index = (group_index);                                  \
    switch (_group_index) {                                                   \
      case 1: {                                                               \
        using group_t = SO3<type>;                                            \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 2: {                                                               \
        using group_t = RxSO3<type>;                                          \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 3: {                                                               \
        using group_t = SE3<type>;                                            \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      case 4: {                                                               \
        using group_t = Sim3<type>;                                           \
        return __VA_ARGS__();                                                 \
      }                                                                       \
      default:                                                                \
        TORCH_CHECK(false, "invalid Lie group index ", _group_index,          \
                    " (expected 1=SO3, 2=RxSO3, 3=SE3, 4=Sim3)");             \
    }                                                                         \
  }
// Dispatches a callable over the supported (floating dtype, Lie group) pairs.
//
// Parameters:
//   GROUP_INDEX - integer group id forwarded to PRIVATE_CASE_TYPE
//                 (1=SO3, 2=RxSO3, 3=SE3, 4=Sim3).
//   TYPE        - any value accepted by ::detail::scalar_type(); evaluated once.
//   NAME        - label included in the error for unsupported dtypes
//                 (presumably a string literal, as with ATen's AT_DISPATCH_*
//                 macros -- confirm against callers).
//   ...         - callable invoked with no arguments; when written inline as a
//                 lambda its body can refer to `scalar_t` and `group_t`.
//
// Expands to an immediately-invoked lambda whose value is the callable's
// return value. Only Double and Float are handled; any other dtype raises an
// error via TORCH_CHECK instead of falling off the end of the lambda.
#define DISPATCH_GROUP_AND_FLOATING_TYPES(GROUP_INDEX, TYPE, NAME, ...)            \
  [&] {                                                                            \
    const auto& the_type = TYPE;                                                   \
    /* don't use TYPE again in case it is an expensive or side-effect op */        \
    const at::ScalarType _st = ::detail::scalar_type(the_type);                    \
    switch (_st) {                                                                 \
      PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Double, double, __VA_ARGS__)  \
      PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Float, float, __VA_ARGS__)    \
      default:                                                                     \
        TORCH_CHECK(false, NAME, " not implemented for '",                         \
                    c10::toString(_st), "'");                                      \
    }                                                                              \
  }()
#endif
|