# Mamba-Trainer / Dockerfile
# Author: Pratik Dwivedi
# Commit: Fix Imports (#5) (da2bc01)
# (Source page metadata: scan result "No virus"; original file size 622 bytes)
# Full (non-slim) Python 3.9 image: it ships a compiler toolchain, which is
# kept in case packages in requirements.txt build native extensions.
# NOTE(review): confirm nothing needs gcc before switching to a -slim tag.
FROM python:3.9

WORKDIR /code

# Build helpers installed ahead of the project requirements (ninja/packaging
# are commonly required by packages that compile extensions at install time).
# NOTE(review): `buildtools` looks unrelated to that purpose — confirm it is needed.
RUN pip install --no-cache-dir packaging ninja buildtools

# PyTorch pinned to 2.2.1 from the CUDA 12.1 wheel index, installed before
# requirements.txt (presumably so packages that need torch at build time find it).
RUN pip install --no-cache-dir torch==2.2.1 --index-url https://download.pytorch.org/whl/cu121

# Copy only the requirements file before installing, so the dependency layers
# stay cached when application source changes.
# NOTE(review): assumes requirements.txt does not reference other local files.
COPY requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir -r /code/requirements.txt

# Unprivileged runtime user (UID 1000) that owns /code, so the trainer can
# write into its working directory without root.
RUN useradd --create-home --uid 1000 user \
    && chown user:user /code

# Application source (train_mamba.py, ./data/..., etc.), copied once after deps.
COPY --chown=user:user . /code

USER user

# CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
# Exec form: every flag and every value must be its own array element
# ("--optim" and "paged_adamw_8bit" were previously fused into one argv entry).
CMD ["python", "train_mamba.py", "--model", "state-spaces/mamba-130m", "--tokenizer", "EleutherAI/gpt-neox-20b", "--learning_rate", "5e-5", "--batch_size", "1", "--gradient_accumulation_steps", "1", "--optim", "paged_adamw_8bit", "--data_path", "./data/ultrachat_small.jsonl", "--num_epochs", "1"]