File size: 1,368 Bytes
626eca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env bash
#
# Interactive environment setup:
#   1. creates a conda environment (by name or by path prefix),
#   2. installs PyTorch into it (CPU-only or with a chosen CUDA version),
#   3. installs the project in the current directory in editable mode with
#      all optional extras.
# Run it from the directory containing the project, because the last step
# performs an editable install of ".".

# setup conda
# Check that the conda command exists *before* asking it for its base
# directory; otherwise a missing conda only surfaces as a raw
# "command not found" error. `command -v` also finds conda when it is
# defined as a shell function rather than an executable.
if ! command -v conda >/dev/null 2>&1; then
  echo "Conda is not installed. Please install conda first." >&2
  exit 1
fi
CONDA_BASE=$(conda info --base)
if [ -z "$CONDA_BASE" ]; then
  echo "Could not determine the conda base directory ('conda info --base' failed)." >&2
  exit 1
fi
# Sourcing conda.sh is what makes `conda activate` usable later in this
# non-interactive script; stop here if it cannot be loaded.
if ! source "$CONDA_BASE"/etc/profile.d/conda.sh; then
  echo "Failed to load $CONDA_BASE/etc/profile.d/conda.sh." >&2
  exit 1
fi

# create conda env
# Prompt for the environment (a plain name, or a filesystem path used as a
# prefix) and the Python version, create the environment and activate it.
# Any failure aborts the script: continuing would install everything below
# into whichever environment happens to be active (e.g. "base").
read -rp "Enter environment name or prefix: " ENV_NAME
if [ -z "$ENV_NAME" ]; then
  echo "Environment name or prefix must not be empty." >&2
  exit 1
fi
read -rp "Enter python version (default 3.10): " PYTHON_VERSION
PYTHON_VERSION=${PYTHON_VERSION:-3.10}

# Conda environment names may not contain "/", so any input with a slash
# (absolute "/opt/env", relative "./env" or "envs/x", "~/envs/x") is a path
# and must be passed with --prefix; everything else is an environment name.
if [[ "$ENV_NAME" == */* ]]; then
  CONDA_NEW_ARG="--prefix"
else
  CONDA_NEW_ARG="--name"
fi

if ! conda create -y "$CONDA_NEW_ARG" "$ENV_NAME" python="$PYTHON_VERSION"; then
  echo "Failed to create conda environment '$ENV_NAME'." >&2
  exit 1
fi
# `conda activate` treats arguments containing "/" as paths, so the same
# value works for both the --name and --prefix cases.
if ! conda activate "$ENV_NAME"; then
  echo "Failed to activate conda environment '$ENV_NAME'." >&2
  exit 1
fi

# replace placeholder env with $ENV_NAME in scripts/train.sh
# NEW_CONDA_LINE="source \$CONDA_BASE/bin/activate $ENV_NAME"
# sed -i.bak -e "s,.*bin/activate.*,$NEW_CONDA_LINE,g" scripts/train.sh

# install torch
# Install PyTorch into the active environment. An empty CUDA version selects
# the CPU-only build (conda package "cpuonly" from the pytorch channel);
# otherwise "pytorch-cuda=<version>" is requested from the pytorch and nvidia
# channels. An empty PyTorch version installs the latest available release.
read -rp "Enter cuda version (e.g. '11.8', default no cuda support): " CUDA_VERSION
read -rp "Enter PyTorch version (e.g. '2.1', default latest): " PYTORCH_VERSION

# Conda match spec: "pytorch" (latest) or "pytorch=<version>".
TORCH_SPEC="pytorch${PYTORCH_VERSION:+=$PYTORCH_VERSION}"
if [ -z "$CUDA_VERSION" ]; then
  TORCH_INSTALL_ARGS=("$TORCH_SPEC" cpuonly -c pytorch)
else
  TORCH_INSTALL_ARGS=("$TORCH_SPEC" pytorch-cuda="$CUDA_VERSION" -c pytorch -c nvidia)
fi

# Abort on failure so the pip step below does not run against an environment
# without the requested PyTorch build.
if ! conda install -y "${TORCH_INSTALL_ARGS[@]}"; then
  echo "Failed to install PyTorch (${TORCH_INSTALL_ARGS[*]})." >&2
  exit 1
fi

# install python requirements
# Editable install of the project in the current directory with its "all"
# extras.
# - ".[all]" is quoted because, unquoted, "[all]" is a glob bracket
#   expression: it would expand to a file named ".a" or ".l" if one exists.
# - "python -m pip" guarantees the pip of the currently active interpreter
#   is used.
python -m pip install -e ".[all]"