File size: 1,625 Bytes
c41b22c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34614a7
c41b22c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import platform
import sys

import torch


def get_cuda_version():
    """Return the CUDA tag for the wheel's local version, or ``"cpu"``.

    The tag is ``"cu"`` followed by the first two characters of
    ``torch.version.cuda`` with the dots removed, e.g. ``"12.1"`` -> ``"cu12"``
    and ``"11.8"`` -> ``"cu11"``.

    Returns:
        ``"cpu"`` when ``torch.cuda.is_available()`` is false or when this
        PyTorch build reports no CUDA version; otherwise the ``"cuXX"`` tag.
    """
    if not torch.cuda.is_available():
        return "cpu"
    cuda_version = torch.version.cuda
    if cuda_version is None:
        # torch.version.cuda is None on builds without CUDA (for example ROCm
        # builds, where is_available() can still be true); calling .replace()
        # on it would raise AttributeError.
        return "cpu"
    return f"cu{cuda_version.replace('.', '')[:2]}"  # e.g. cu12


def get_torch_version():
    """Return the PyTorch tag ``"torch<major>.<minor>"``, e.g. ``"torch2.2"``.

    The local build suffix (everything from ``"+"`` on, such as ``"+cu121"``)
    is discarded, then only the first two dot-separated components are kept.
    Selecting the components explicitly, instead of chopping a fixed number of
    trailing characters, keeps the result correct for multi-digit patch
    numbers and pre-release versions such as ``"2.6.0.dev20241112"``.
    """
    release = torch.__version__.split("+")[0]
    major, minor = release.split(".")[:2]
    return f"torch{major}.{minor}"


def get_python_version():
    """Return the CPython tag of the running interpreter, e.g. ``"cp310"``."""
    major, minor = sys.version_info[:2]
    return "cp" + str(major) + str(minor)


def get_abi_flag():
    """Return ``"abiTRUE"`` or ``"abiFALSE"`` for the installed PyTorch build.

    Uses the public ``torch.compiled_with_cxx11_abi()`` helper, which reports
    whether PyTorch was built with ``_GLIBCXX_USE_CXX11_ABI=1``, instead of
    reading the private ``torch._C`` attribute directly.
    """
    return "abiTRUE" if torch.compiled_with_cxx11_abi() else "abiFALSE"


def get_platform():
    """Return the wheel platform tag for the current OS and CPU architecture.

    Returns:
        ``"linux_x86_64"``, ``"win_amd64"`` or ``"macosx_x86_64"``.

    Raises:
        ValueError: If the (system, machine) pair is not one of the
            combinations listed above.
    """
    os_name = platform.system().lower()
    arch = platform.machine().lower()

    # Supported (system, machine) pairs and their platform tags.
    known_tags = {
        ("linux", "x86_64"): "linux_x86_64",
        ("windows", "amd64"): "win_amd64",
        ("darwin", "x86_64"): "macosx_x86_64",
    }

    tag = known_tags.get((os_name, arch))
    if tag is None:
        raise ValueError(f"Unsupported platform: {os_name}_{arch}")
    return tag


def generate_flash_attn_filename(flash_attn_version="2.7.2.post1"):
    """Build the flash-attn wheel filename matching the current environment.

    The name has the form
    ``flash_attn-<version>+<cuda><torch>cxx11<abi>-<py>-<py>-<platform>.whl``,
    where each tag comes from the corresponding ``get_*`` helper in this file.

    Args:
        flash_attn_version: flash-attn release to embed in the filename.

    Returns:
        The wheel filename as a string.
    """
    # Query the helpers in a fixed order: CUDA, torch, Python, ABI, platform.
    cuda_tag = get_cuda_version()
    torch_tag = get_torch_version()
    py_tag = get_python_version()
    abi_tag = get_abi_flag()
    plat_tag = get_platform()

    local_version = "".join((cuda_tag, torch_tag, "cxx11", abi_tag))
    stem = "-".join(
        (f"flash_attn-{flash_attn_version}+{local_version}", py_tag, py_tag, plat_tag)
    )
    return stem + ".whl"


if __name__ == "__main__":
    # stdout carries only the wheel filename so that callers capturing the
    # output never mistake an error message for a filename; failures are
    # reported on stderr with a non-zero exit status.
    try:
        filename = generate_flash_attn_filename()
    except Exception as e:
        print("Error generating filename:", e, file=sys.stderr)
        sys.exit(1)
    print(filename)