zaydzuhri's picture
Add files using upload-large-folder tool
199e69f verified
raw
history blame
1.36 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import triton
import triton.language as tl
@triton.jit
def get_tid():
    """Read the PTX ``%tid`` special registers through inline assembly.

    Returns:
        The three outputs of the inline-asm call, each requested as
        ``tl.uint32``: ``%tid.x``, ``%tid.y`` and ``%tid.z``, in that order.
    """
    # The asm takes no inputs (args is empty) and writes three 32-bit
    # register outputs ($0, $1, $2), one per "=r" constraint.
    return tl.inline_asm_elementwise(
        asm="""
mov.u32 $0, %tid.x;
mov.u32 $1, %tid.y;
mov.u32 $2, %tid.z;
""",
        constraints="=r,=r,=r",
        args=[],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )
@triton.jit
def get_ntid():
    """Read the PTX ``%ntid`` special registers through inline assembly.

    Returns:
        The three outputs of the inline-asm call, each requested as
        ``tl.uint32``: ``%ntid.x``, ``%ntid.y`` and ``%ntid.z``, in that
        order.
    """
    # The asm takes no inputs (args is empty) and writes three 32-bit
    # register outputs ($0, $1, $2), one per "=r" constraint.
    return tl.inline_asm_elementwise(
        asm="""
mov.u32 $0, %ntid.x;
mov.u32 $1, %ntid.y;
mov.u32 $2, %ntid.z;
""",
        constraints="=r,=r,=r",
        args=[],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )
@triton.jit
def get_flat_tid():
    """Return the thread index flattened over the x, y and z dimensions.

    Combines the components from ``get_tid()`` with the x and y extents from
    ``get_ntid()`` so that x varies fastest, then y, then z. The z extent is
    not needed and is ignored.
    """
    x, y, z = get_tid()
    extent_x, extent_y, _ = get_ntid()
    # Contribution of each dimension to the linear index.
    z_term = z * extent_y * extent_x
    y_term = y * extent_x
    return z_term + y_term + x
@triton.jit
def get_flat_bid():
    """Return the program index flattened over launch-grid axes 0, 1 and 2.

    Axis 0 varies fastest, then axis 1, then axis 2. The extent of axis 2
    is not needed and is never queried.
    """
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    pid_2 = tl.program_id(2)
    count_0 = tl.num_programs(0)
    count_1 = tl.num_programs(1)
    # Contribution of each grid axis to the linear index.
    axis_2_term = pid_2 * count_1 * count_0
    axis_1_term = pid_1 * count_0
    return axis_2_term + axis_1_term + pid_0
@triton.jit
def sync_threads():
    """Emit the PTX ``bar.sync 0`` instruction through inline assembly.

    Nothing is returned; the value produced by the inline-asm call is
    discarded.
    """
    # is_pure=False tells the compiler that the asm has side effects.
    tl.inline_asm_elementwise(
        asm="bar.sync 0;",
        constraints="=r",
        args=[],
        dtype=tl.int32,
        is_pure=False,
        pack=1,
    )