File size: 353 Bytes
375ee53
 
 
4e92ab0
375ee53
 
 
 
 
 
 
 
 
 
4e92ab0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import os
import shutil


def find_cuda():
    """Locate the root directory of a CUDA toolkit installation.

    Lookup order:

    1. The ``CUDA_HOME`` environment variable, then ``CUDA_PATH``. The first
       one that is set to a path existing on disk wins; unset, empty or
       stale values are skipped.
    2. The directory two levels above the ``nvcc`` executable found on
       ``PATH`` (e.g. ``<root>/bin/nvcc`` -> ``<root>``).

    Returns:
        The toolkit root as a string, or ``None`` if no candidate was found.
    """
    # Check each variable independently: a stale CUDA_HOME pointing at a
    # removed toolkit must not hide a valid CUDA_PATH.
    for var in ('CUDA_HOME', 'CUDA_PATH'):
        cuda_home = os.environ.get(var)
        if cuda_home and os.path.exists(cuda_home):
            return cuda_home

    nvcc_path = shutil.which('nvcc')
    if nvcc_path:
        # nvcc is expected to live in <root>/bin/, so drop the file name and
        # the bin directory. The path is not resolved, so a symlinked nvcc
        # yields the grandparent of the link itself.
        return os.path.dirname(os.path.dirname(nvcc_path))

    return None