abreza commited on
Commit
f244dfa
·
1 Parent(s): 6bd2666

try to find cuda home

Browse files
Files changed (1) hide show
  1. app.py +9 -15
app.py CHANGED
@@ -1,25 +1,19 @@
1
  import os
2
  import subprocess
 
3
 
4
- # Check common CUDA installation paths and set the correct one
5
- cuda_paths = [
6
- '/usr/local/cuda', # Default path
7
- '/usr/local/cuda-11.0', # Example of a versioned path
8
- '/usr/local/cuda-11.1',
9
- '/usr/local/cuda-11.2'
10
- ]
11
 
12
- cuda_home = None
13
- for path in cuda_paths:
14
- if os.path.exists(path):
15
- cuda_home = path
16
- break
17
 
18
- if cuda_home is None:
19
- raise EnvironmentError('CUDA installation not found. Please install CUDA or set CUDA_HOME manually.')
 
20
 
21
  # Set the CUDA_HOME environment variable
22
- os.environ['CUDA_HOME'] = '/usr/local/cuda'
23
  os.environ['PATH'] = os.environ['CUDA_HOME'] + '/bin:' + os.environ['PATH']
24
  os.environ['LD_LIBRARY_PATH'] = os.environ['CUDA_HOME'] + '/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')
25
 
 
1
  import os
2
  import subprocess
3
+ import glob
4
 
5
+ # Find all CUDA directories that match /usr/local/cuda*
6
+ cuda_dirs = glob.glob('/usr/local/cuda*')
 
 
 
 
 
7
 
8
+ if not cuda_dirs:
9
+ raise EnvironmentError('No CUDA installation found. Please install CUDA or set CUDA_HOME manually.')
 
 
 
10
 
11
+ # Assume the highest version of CUDA is the one to use
12
+ cuda_dirs.sort()
13
+ cuda_home = cuda_dirs[-1]
14
 
15
  # Set the CUDA_HOME environment variable
16
+ os.environ['CUDA_HOME'] = cuda_home
17
  os.environ['PATH'] = os.environ['CUDA_HOME'] + '/bin:' + os.environ['PATH']
18
  os.environ['LD_LIBRARY_PATH'] = os.environ['CUDA_HOME'] + '/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')
19