File size: 17,345 Bytes
9dd3461 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 |
#pragma once
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/OpMathType.h>
namespace at { namespace native {
namespace {
// Overload for plain TensorListMetadata (lists without per-tensor scalars).
// For each of the `depth` tensor lists, points args[list_idx] at tensor
// `tensor_loc` of that list, advanced by chunk_idx * chunk_size elements so it
// refers to the first element of this block's chunk.
// Returns true only when is_aligned() accepts every one of those pointers; the
// functors below use this flag to pick the vectorized load_store path.
template<int depth, typename T>
__device__ bool init_args(
    T** args,
    TensorListMetadata<depth>& tl,
    int chunk_idx,
    int chunk_size,
    int tensor_loc) {
  const int chunk_offset = chunk_idx * chunk_size;
  bool every_ptr_aligned = true;
  for (int list_idx = 0; list_idx < depth; ++list_idx) {
    T* tensor_base = static_cast<T*>(tl.addresses[list_idx][tensor_loc]);
    args[list_idx] = tensor_base + chunk_offset;
    if (!is_aligned(args[list_idx])) {
      every_ptr_aligned = false;
    }
  }
  return every_ptr_aligned;
}
// Overload for TensorListScalarListMetadata (lists that also carry one scalar
// of type T2 per tensor). Pointer setup is identical to the plain overload:
// args[k] becomes the address of tensor `tensor_loc` in list k, offset to the
// start of chunk `chunk_idx`. Returns whether is_aligned() holds for all of
// the resulting pointers.
template<int depth, typename T, typename T2>
__device__ bool init_args(
    T** args,
    TensorListScalarListMetadata<T2, depth>& tl,
    int chunk_idx,
    int chunk_size,
    int tensor_loc) {
  bool aligned = true;
  for (int k = 0; k < depth; ++k) {
    T* chunk_start =
        static_cast<T*>(tl.addresses[k][tensor_loc]) + chunk_idx * chunk_size;
    args[k] = chunk_start;
    if (!is_aligned(chunk_start)) {
      aligned = false;
    }
  }
  return aligned;
}
// Overload for FusedOptimizerTensorListMetadata. Sets each of the `depth`
// entries of args to the chunk-start address of tensor `tensor_loc` (its base
// address from tl.addresses plus chunk_idx * chunk_size elements) and reports
// whether all of those addresses pass is_aligned().
template<int depth, typename T>
__device__ bool init_args(
    T** args,
    FusedOptimizerTensorListMetadata<depth>& tl,
    int chunk_idx,
    int chunk_size,
    int tensor_loc) {
  const int elem_offset = chunk_idx * chunk_size;
  bool result = true;
  for (int which = 0; which < depth; ++which) {
    args[which] = static_cast<T*>(tl.addresses[which][tensor_loc]) + elem_offset;
    if (!is_aligned(args[which])) {
      result = false;
    }
  }
  return result;
}
// Element-wise (non-vectorized) gather into registers.
// For each of this thread's kILP slots, the element index is
//   i_start + threadIdx.x + slot * blockDim.x
// and, for each of the first `depth` argument arrays, r_args[r][slot] receives
// args[r][index] when index < n and index < chunk_size, and 0 otherwise, so
// out-of-range slots never read global memory.
template<int depth, typename T>
__device__ void load_args(T r_args[][kILP], T** args, int i_start, int chunk_size, int n) {
#pragma unroll
  for (int slot = 0; slot < kILP; slot++) {
    const int idx = i_start + threadIdx.x + slot * blockDim.x;
    const bool in_range = idx < n && idx < chunk_size;
    for (int r = 0; r < depth; r++) {
      if (in_range) {
        r_args[r][slot] = args[r][idx];
      } else {
        r_args[r][slot] = 0;
      }
    }
  }
}
// Element-wise (non-vectorized) scatter from registers: the counterpart of
// load_args. Slot `slot` of src is written to
//   dst[i_start + threadIdx.x + slot * blockDim.x]
// but only when that index is below both n and chunk_size.
template<typename T>
__device__ void store_args(T* dst, T* src, int i_start, int chunk_size, int n) {
#pragma unroll
  for (int slot = 0; slot < kILP; slot++) {
    const int idx = i_start + threadIdx.x + slot * blockDim.x;
    if (idx >= n || idx >= chunk_size) {
      continue;  // past the end of the tensor or of this chunk
    }
    dst[idx] = src[slot];
  }
}
// Applies op(x, scalar) to every element x of args[0] within this block's
// chunk. Results go to args[res_arg_index]. Only r_args[0] is used as the
// register buffer, and arithmetic is done in opmath_t before casting back to T.
//
// Two paths:
//  * vectorized: taken when n and chunk_size are multiples of kILP and every
//    pointer was aligned. It uses load_store (declared in MultiTensorApply.cuh),
//    whose offsets here count kILP-wide vectors.
//  * fallback: bounds-checked element-wise load_args/store_args.
template<int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void binary_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    int n,
    int chunk_size,
    bool all_aligned,
    Op op) {
  const bool vectorizable =
      all_aligned && (n % kILP == 0) && (chunk_size % kILP == 0);

  if (vectorizable) {
    // `vec` indexes kILP-element vectors, hence the `vec * kILP` bound checks.
    for (int vec = threadIdx.x;
         vec * kILP < n && vec * kILP < chunk_size;
         vec += blockDim.x) {
      load_store(r_args[0], args[0], 0, vec);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                          static_cast<opmath_t>(scalar)));
      }
      load_store(args[res_arg_index], r_args[0], vec, 0);
    }
    return;
  }

  // Fallback: r_args has depth 1 whether the op is in-place (depth 1) or
  // out-of-place (depth 2); only args[0] is read.
  for (int base = 0; base < n && base < chunk_size; base += blockDim.x * kILP) {
    load_args<1>(r_args, args, base, chunk_size, n);
#pragma unroll
    for (int ii = 0; ii < kILP; ii++) {
      r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                        static_cast<opmath_t>(scalar)));
    }
    store_args(args[res_arg_index], r_args[0], base, chunk_size, n);
  }
}
// For each element of this block's chunk, computes
//   out = args[0] + scalar * op(args[1], args[2])
// in opmath_t and stores it (cast back to T) into args[res_arg_index].
// r_args[0..2] hold the three inputs. The vectorized load_store path is used
// when n and chunk_size are multiples of kILP and all pointers are aligned;
// otherwise the bounds-checked element-wise path is used.
template<int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    int n,
    int chunk_size,
    bool all_aligned,
    Op op) {
  const bool vectorizable =
      all_aligned && (n % kILP == 0) && (chunk_size % kILP == 0);

  if (vectorizable) {
    // `vec` indexes kILP-element vectors, hence the `vec * kILP` bound checks.
    for (int vec = threadIdx.x;
         vec * kILP < n && vec * kILP < chunk_size;
         vec += blockDim.x) {
      load_store(r_args[0], args[0], 0, vec);
      load_store(r_args[1], args[1], 0, vec);
      load_store(r_args[2], args[2], 0, vec);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(static_cast<opmath_t>(r_args[0][ii]) +
            scalar * op(static_cast<opmath_t>(r_args[1][ii]),
                        static_cast<opmath_t>(r_args[2][ii])));
      }
      load_store(args[res_arg_index], r_args[0], vec, 0);
    }
    return;
  }

  // Fallback: r_args has depth 3 whether the op is in-place (depth 3) or
  // out-of-place (depth 4); only args[0..2] are read.
  for (int base = 0; base < n && base < chunk_size; base += blockDim.x * kILP) {
    load_args<3>(r_args, args, base, chunk_size, n);
#pragma unroll
    for (int ii = 0; ii < kILP; ii++) {
      r_args[0][ii] = static_cast<T>(static_cast<opmath_t>(r_args[0][ii]) +
          scalar * op(static_cast<opmath_t>(r_args[1][ii]),
                      static_cast<opmath_t>(r_args[2][ii])));
    }
    store_args(args[res_arg_index], r_args[0], base, chunk_size, n);
  }
}
//
// Binary Functors
//
// Per-block functor for "tensor list op single scalar".
// Each CUDA block looks up its tensor (tl.block_to_tensor) and chunk
// (tl.block_to_chunk), sets up chunk pointers, then delegates the element
// loop to binary_op_scalar with the same `scalar` for every tensor.
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int numel = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);

    // Elements of this tensor remaining from the start of this chunk onward.
    const int n = numel - chunk_idx * chunk_size;

    T r_args[r_args_depth][kILP];
    binary_op_scalar<res_arg_index>(r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
// Per-block functor for "tensor list op scalar list": like
// BinaryOpScalarFunctor, but each tensor uses its own scalar, read from
// tl.scalar_vals[tensor_loc].
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int numel = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    const opmath_t scalar = tl.scalar_vals[tensor_loc];

    // Elements of this tensor remaining from the start of this chunk onward.
    const int n = numel - chunk_idx * chunk_size;

    T r_args[r_args_depth][kILP];
    binary_op_scalar<res_arg_index>(r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
// Per-block functor for "tensor list op tensor list" scaled by alpha:
//   args[res_arg_index][i] = op(args[0][i], alpha * args[1][i])
// computed in opmath_t. r_args[0] and r_args[1] hold the two inputs.
// The vectorized load_store path is used when n and chunk_size are multiples
// of kILP and all pointers are aligned; otherwise it falls back to
// bounds-checked element-wise access.
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int numel = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    const int n = numel - chunk_idx * chunk_size;

    T r_args[r_args_depth][kILP];

    const bool vectorizable =
        all_aligned && (n % kILP == 0) && (chunk_size % kILP == 0);
    if (vectorizable) {
      // `vec` indexes kILP-element vectors.
      for (int vec = threadIdx.x;
           vec * kILP < n && vec * kILP < chunk_size;
           vec += blockDim.x) {
        load_store(r_args[0], args[0], 0, vec);
        load_store(r_args[1], args[1], 0, vec);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                            alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        load_store(args[res_arg_index], r_args[0], vec, 0);
      }
      return;
    }

    for (int base = 0; base < n && base < chunk_size; base += blockDim.x * kILP) {
      load_args<r_args_depth>(r_args, args, base, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                          alpha * static_cast<opmath_t>(r_args[1][ii])));
      }
      store_args(args[res_arg_index], r_args[0], base, chunk_size, n);
    }
  }
};
//
// Unary Functors
//
// Per-block functor that writes zeros over this block's chunk of tensor
// `tensor_loc` in a single tensor list.
//
// `tl` is always TensorListMetadata<1>, and init_args<depth> only accepts it
// when depth == 1. With a single list, the only valid output slot is args[0],
// so res_arg_index must be 0; any other value would index past `args`.
// Both constraints are enforced at compile time below.
//
// Uses the vectorized load_store path when n and chunk_size are multiples of
// kILP and the pointer is aligned. Otherwise it uses bounds-checked store_args.
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
  static_assert(depth == 1, "ZeroFunctor operates on exactly one tensor list");
  static_assert(res_arg_index == 0, "ZeroFunctor writes into args[0]; res_arg_index must be 0");
  static_assert(r_args_depth >= 1, "ZeroFunctor needs at least one register row");

  __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<1>& tl) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.numel_for_tensor[tensor_loc];
    T* args[depth];
    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    // Elements of this tensor remaining from the start of this chunk onward.
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // The register row is only ever read by the stores below, so fill it with
    // zeros once instead of on every loop iteration.
#pragma unroll
    for (int ii = 0; ii < kILP; ii++) {
      r_args[0][ii] = 0;
    }

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      // i_start counts kILP-element vectors here.
      for (int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};
// Per-block functor for element-wise unary ops:
//   args[res_arg_index][i] = op(args[0][i])
// evaluated in opmath_t and cast back to T. It uses the vectorized load_store
// path when n and chunk_size are multiples of kILP and all pointers are
// aligned; otherwise it falls back to bounds-checked element-wise access.
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int numel = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    const int n = numel - chunk_idx * chunk_size;

    T r_args[r_args_depth][kILP];

    const bool vectorizable =
        all_aligned && (n % kILP == 0) && (chunk_size % kILP == 0);
    if (vectorizable) {
      // `vec` indexes kILP-element vectors.
      for (int vec = threadIdx.x;
           vec * kILP < n && vec * kILP < chunk_size;
           vec += blockDim.x) {
        load_store(r_args[0], args[0], 0, vec);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        load_store(args[res_arg_index], r_args[0], vec, 0);
      }
      return;
    }

    for (int base = 0; base < n && base < chunk_size; base += blockDim.x * kILP) {
      load_args<r_args_depth>(r_args, args, base, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
      }
      store_args(args[res_arg_index], r_args[0], base, chunk_size, n);
    }
  }
};
//
// Pointwise Functors
//
// Per-block functor for ternary pointwise ops with a single scalar shared by
// all tensors. It resolves this block's tensor/chunk, sets up chunk pointers,
// and delegates to pointwise_op_scalar, which computes
// args[0] + scalar * op(args[1], args[2]) into args[res_arg_index].
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int numel = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);

    // Elements of this tensor remaining from the start of this chunk onward.
    const int n = numel - chunk_idx * chunk_size;

    T r_args[r_args_depth][kILP];
    pointwise_op_scalar<res_arg_index>(r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
// Per-block functor for ternary pointwise ops with one scalar per tensor
// (read from tl.scalar_vals[tensor_loc]). The element loop is delegated to
// pointwise_op_scalar, exactly as in PointwiseOpScalarFunctor.
template<typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int numel = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    const opmath_t scalar = tl.scalar_vals[tensor_loc];

    // Elements of this tensor remaining from the start of this chunk onward.
    const int n = numel - chunk_idx * chunk_size;

    T r_args[r_args_depth][kILP];
    pointwise_op_scalar<res_arg_index>(r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};
// Per-block functor for binary list-list pointwise ops:
//   args[2][i] = op(args[0][i], args[1][i])
// evaluated in opmath_t. The output slot is hard-coded to args[2], so
// depth must be at least 3 (presumably exactly 3: two inputs, one output).
// r_args holds the depth - 1 input rows.
template<typename T, int depth>
struct PointwiseOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template<typename Op> __device__ __forceinline__ void operator() (
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int numel = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    const int n = numel - chunk_idx * chunk_size;

    T r_args[depth - 1][kILP];

    const bool vectorizable =
        all_aligned && (n % kILP == 0) && (chunk_size % kILP == 0);
    if (vectorizable) {
      // `vec` indexes kILP-element vectors.
      for (int vec = threadIdx.x;
           vec * kILP < n && vec * kILP < chunk_size;
           vec += blockDim.x) {
        load_store(r_args[0], args[0], 0, vec);
        load_store(r_args[1], args[1], 0, vec);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                            static_cast<opmath_t>(r_args[1][ii])));
        }
        load_store(args[2], r_args[0], vec, 0);
      }
      return;
    }

    for (int base = 0; base < n && base < chunk_size; base += blockDim.x * kILP) {
      load_args<depth - 1>(r_args, args, base, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii]),
                                          static_cast<opmath_t>(r_args[1][ii])));
      }
      store_args(args[2], r_args[0], base, chunk_size, n);
    }
  }
};
} // namespace
}} // namespace at::native
|