softpick-340M-4096-batch16-steps100000
/
torchtitan
/experiments
/deepseek_v3
/symm_mem_recipes
/triton_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import triton | |
import triton.language as tl | |
@triton.jit
def get_tid():
    """Return the calling thread's index as three ``uint32`` values.

    The values are read from the PTX special registers ``%tid.x``,
    ``%tid.y`` and ``%tid.z`` via inline assembly, so this helper only
    works on backends that accept PTX.

    It must be JIT-compiled (hence ``@triton.jit``) and called from inside
    a Triton kernel; ``tl.inline_asm_elementwise`` is not usable from host
    Python.

    Returns:
        A tuple ``(tid_x, tid_y, tid_z)``, each of dtype ``tl.uint32``.
    """
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %tid.x;
        mov.u32 $1, %tid.y;
        mov.u32 $2, %tid.z;
        """,
        "=r,=r,=r",  # three 32-bit output registers ($0..$2), no inputs
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,  # only reads special registers; no side effects
        pack=1,
    )
@triton.jit
def get_ntid():
    """Return the thread-block dimensions as three ``uint32`` values.

    The values are read from the PTX special registers ``%ntid.x``,
    ``%ntid.y`` and ``%ntid.z`` (the number of threads per block along
    each axis) via inline assembly, so this helper only works on backends
    that accept PTX.

    It must be JIT-compiled (hence ``@triton.jit``) and called from inside
    a Triton kernel; ``tl.inline_asm_elementwise`` is not usable from host
    Python.

    Returns:
        A tuple ``(ntid_x, ntid_y, ntid_z)``, each of dtype ``tl.uint32``.
    """
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %ntid.x;
        mov.u32 $1, %ntid.y;
        mov.u32 $2, %ntid.z;
        """,
        "=r,=r,=r",  # three 32-bit output registers ($0..$2), no inputs
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,  # only reads special registers; no side effects
        pack=1,
    )
@triton.jit
def get_flat_tid():
    """Return the calling thread's linearized index.

    The 3-D thread index from ``get_tid()`` is flattened with the block
    dimensions from ``get_ntid()``, with x varying fastest::

        tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x

    Must be JIT-compiled (hence ``@triton.jit``) because it calls other
    device-side helpers and returns a device value.
    """
    tid_x, tid_y, tid_z = get_tid()
    # The z extent is not needed to linearize an (x, y, z) index.
    ntid_x, ntid_y, _ = get_ntid()
    return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x
@triton.jit
def get_flat_bid():
    """Return the calling program's linearized index in the launch grid.

    The 3-D program id is flattened with the grid sizes, with axis 0
    varying fastest::

        pid(2) * nprog(1) * nprog(0) + pid(1) * nprog(0) + pid(0)

    Must be JIT-compiled (hence ``@triton.jit``): ``tl.program_id`` and
    ``tl.num_programs`` are only meaningful inside a Triton kernel.
    """
    return (
        tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
        + tl.program_id(1) * tl.num_programs(0)
        + tl.program_id(0)
    )
@triton.jit
def sync_threads():
    """Issue a block-wide barrier by emitting PTX ``bar.sync 0``.

    Per the PTX ISA, ``bar.sync`` with only a barrier number makes all
    threads of the block wait on that barrier (here barrier 0).

    The ``"=r"`` constraint declares one ``int32`` output that the
    assembly never writes; the result is discarded. ``is_pure=False``
    tells the compiler the assembly has side effects, so it is not
    treated as removable.

    Must be JIT-compiled (hence ``@triton.jit``) and called from inside a
    Triton kernel. Returns nothing.
    """
    tl.inline_asm_elementwise(
        "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
    )