| import subprocess |
| from argparse import ArgumentParser |
|
|
| from loguru import logger |
|
|
| from . import BaseAutoTrainCommand |
|
|
|
|
def run_app_command_factory(args):
    """Build the setup command object from parsed CLI arguments.

    Registered as the ``func`` default of the ``setup`` subparser.

    Args:
        args: Parsed argument namespace; only ``args.update_torch`` is read.

    Returns:
        A ``RunSetupCommand`` configured with the ``update_torch`` flag.
    """
    update_torch = args.update_torch
    command = RunSetupCommand(update_torch)
    return command
|
|
|
|
class RunSetupCommand(BaseAutoTrainCommand):
    """CLI ``setup`` subcommand that reinstalls selected libraries via pip.

    Running the command uninstalls transformers, peft, diffusers and trl and
    reinstalls each one from its GitHub repository. When ``update_torch`` is
    set, torch, torchvision and torchaudio are additionally installed from the
    PyTorch ``cu118`` wheel index.
    """

    # (install label, success label, shell command) for every library that is
    # always reinstalled. Steps run in the order listed.
    _INSTALL_STEPS = (
        (
            "transformers@main",
            "transformers",
            "pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers.git",
        ),
        (
            "peft@main",
            "peft",
            "pip uninstall -y peft && pip install git+https://github.com/huggingface/peft.git",
        ),
        (
            "diffusers@main",
            "diffusers",
            "pip uninstall -y diffusers && pip install git+https://github.com/huggingface/diffusers.git",
        ),
        (
            "trl@main",
            "trl",
            "pip uninstall -y trl && pip install git+https://github.com/lvwerra/trl.git",
        ),
    )

    # Optional step, executed only when ``--update-torch`` was passed.
    _TORCH_STEP = (
        "PyTorch",
        "PyTorch",
        "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
    )

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``setup`` subcommand and its ``--update-torch`` flag.

        Args:
            parser: Object exposing ``add_parser``.
                NOTE(review): annotated as ``ArgumentParser``, but
                ``add_parser`` is normally provided by the object returned
                from ``add_subparsers()`` -- confirm against the caller.
        """
        run_setup_parser = parser.add_parser(
            "setup",
            description="✨ Run AutoTrain setup",
        )
        run_setup_parser.add_argument(
            "--update-torch",
            action="store_true",
            help="Update PyTorch to latest version",
        )
        run_setup_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, update_torch: bool):
        """Store whether PyTorch should also be installed by ``run``.

        Args:
            update_torch: When true, ``run`` performs the extra PyTorch step.
        """
        self.update_torch = update_torch

    @staticmethod
    def _run_step(install_label: str, success_label: str, cmd: str) -> bool:
        """Run one pip shell command and report its real outcome.

        Args:
            install_label: Name used in the "Installing latest ..." message.
            success_label: Name used in the success / failure message.
            cmd: Shell command line to execute.

        Returns:
            True if the command exited with status 0, False otherwise.
        """
        logger.info(f"Installing latest {install_label}")
        # shell=True is required because the commands chain steps with `&&`;
        # every command string is a constant defined on this class.
        result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            # Previously the exit status was ignored and success was logged
            # even when pip failed; surface the failure and pip's stderr.
            # Values are passed as arguments so braces in pip's output are
            # not treated as format fields.
            logger.error(
                "Failed to install latest {} (exit code {}): {}",
                success_label,
                result.returncode,
                result.stderr.decode(errors="replace").strip(),
            )
            return False
        logger.info(f"Successfully installed latest {success_label}")
        return True

    def run(self):
        """Execute every install step in order.

        A failing step is logged as an error and the remaining steps still
        run, so one broken package does not abort the whole setup.
        """
        steps = list(self._INSTALL_STEPS)
        if self.update_torch:
            steps.append(self._TORCH_STEP)

        for install_label, success_label, cmd in steps:
            self._run_step(install_label, success_label, cmd)
|
|