Stylique commited on
Commit
90a1041
Β·
verified Β·
1 Parent(s): 1814609

Upload 3 files

Browse files
Files changed (1) hide show
  1. post_install.py +36 -1
post_install.py CHANGED
@@ -57,7 +57,9 @@ def install_torch_sparse():
57
  print(f"PyTorch {torch.__version__} already installed with correct CUDA version")
58
  else:
59
  print(f"PyTorch {torch.__version__} installed, but need to update to 2.0.1+cu117")
60
- # First, install a compatible PyTorch version with CUDA 11.7 (as expected by PyTorch3D)
 
 
61
  print("Installing compatible PyTorch version...")
62
  if not run_command("pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117"):
63
  return False
@@ -71,10 +73,28 @@ def install_torch_sparse():
71
  print("Checking PyTorch installation...")
72
  check_pytorch_cuda()
73
 
 
 
 
 
 
 
 
 
 
74
  # Now install torch-sparse with the compatible version
75
  print("Installing torch-sparse with PyTorch 2.0.1...")
76
  if run_command("pip install torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu117.html"):
77
  print("Successfully installed torch-sparse")
 
 
 
 
 
 
 
 
 
78
  return True
79
 
80
  return False
@@ -95,6 +115,15 @@ def install_torch_scatter():
95
  print("Installing torch-scatter with PyTorch 2.0.1...")
96
  if run_command("pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu117.html"):
97
  print("Successfully installed torch-scatter")
 
 
 
 
 
 
 
 
 
98
  return True
99
 
100
  return False
@@ -286,6 +315,10 @@ def main():
286
  import torch
287
  print(f"βœ“ PyTorch {torch.__version__} - CUDA: {torch.cuda.is_available()}")
288
 
 
 
 
 
289
  import torch_sparse
290
  print("βœ“ torch-sparse")
291
 
@@ -303,9 +336,11 @@ def main():
303
 
304
  except ImportError as e:
305
  print(f"βœ— Import error: {e}")
 
306
  sys.exit(1)
307
  except Exception as e:
308
  print(f"βœ— Verification error: {e}")
 
309
  sys.exit(1)
310
 
311
  if __name__ == "__main__":
 
57
  print(f"PyTorch {torch.__version__} already installed with correct CUDA version")
58
  else:
59
  print(f"PyTorch {torch.__version__} installed, but need to update to 2.0.1+cu117")
60
+ # Uninstall current PyTorch and reinstall with correct version
61
+ print("Uninstalling current PyTorch...")
62
+ run_command("pip uninstall torch torchvision torchaudio -y")
63
  print("Installing compatible PyTorch version...")
64
  if not run_command("pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117"):
65
  return False
 
73
  print("Checking PyTorch installation...")
74
  check_pytorch_cuda()
75
 
76
+ # Verify PyTorch version was actually updated
77
+ try:
78
+ import torch
79
+ if not torch.__version__.startswith("2.0.1") or "+cu117" not in torch.__version__:
80
+ print(f"Warning: PyTorch version is still {torch.__version__}, expected 2.0.1+cu117")
81
+ print("This may cause compatibility issues with torch-sparse and torch-scatter")
82
+ except Exception as e:
83
+ print(f"Error checking PyTorch version: {e}")
84
+
85
  # Now install torch-sparse with the compatible version
86
  print("Installing torch-sparse with PyTorch 2.0.1...")
87
  if run_command("pip install torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu117.html"):
88
  print("Successfully installed torch-sparse")
89
+
90
+ # Verify torch-sparse is compatible
91
+ try:
92
+ import torch_sparse
93
+ print("torch-sparse import successful")
94
+ except Exception as e:
95
+ print(f"Warning: torch-sparse import failed: {e}")
96
+ print("This may indicate a version compatibility issue")
97
+
98
  return True
99
 
100
  return False
 
115
  print("Installing torch-scatter with PyTorch 2.0.1...")
116
  if run_command("pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu117.html"):
117
  print("Successfully installed torch-scatter")
118
+
119
+ # Verify torch-scatter is compatible
120
+ try:
121
+ import torch_scatter
122
+ print("torch-scatter import successful")
123
+ except Exception as e:
124
+ print(f"Warning: torch-scatter import failed: {e}")
125
+ print("This may indicate a version compatibility issue")
126
+
127
  return True
128
 
129
  return False
 
315
  import torch
316
  print(f"βœ“ PyTorch {torch.__version__} - CUDA: {torch.cuda.is_available()}")
317
 
318
+ # Check if PyTorch version is compatible
319
+ if not torch.__version__.startswith("2.0.1") or "+cu117" not in torch.__version__:
320
+ print(f"⚠ Warning: PyTorch version {torch.__version__} may not be compatible with installed extensions")
321
+
322
  import torch_sparse
323
  print("βœ“ torch-sparse")
324
 
 
336
 
337
  except ImportError as e:
338
  print(f"βœ— Import error: {e}")
339
+ print("This may indicate a version compatibility issue between PyTorch and its extensions")
340
  sys.exit(1)
341
  except Exception as e:
342
  print(f"βœ— Verification error: {e}")
343
+ print("This may indicate a version compatibility issue between PyTorch and its extensions")
344
  sys.exit(1)
345
 
346
  if __name__ == "__main__":