{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import notebook_launcher\n",
    "import torch\n",
    "\n",
    "from model.trainer import ChatTrainer\n",
    "from config import TrainConfig, T5ModelConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training and model configurations with their default constructor arguments.\n",
    "train_config = TrainConfig()\n",
    "model_config = T5ModelConfig()\n",
    "\n",
    "# Show both configurations in the cell output, one per line.\n",
    "print(train_config, model_config, sep='\\n')\n",
    "\n",
    "# Number of CUDA devices reported by torch; reused below as the process count for notebook_launcher.\n",
    "gpu_count = torch.cuda.device_count()\n",
    "print('gpu device count: {}'.format(gpu_count))\n",
    "\n",
    "# The trainer is created once here; its bound train/test methods are handed to notebook_launcher in the following cells.\n",
    "chat_trainer = ChatTrainer(train_config=train_config, model_config=model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional arguments forwarded to chat_trainer.train(), in this order:\n",
    "#   is_keep_training: bool, is_finetune: bool\n",
    "train_args = (False, False)\n",
    "\n",
    "# Start multi-GPU training with notebook_launcher, using one process per detected GPU.\n",
    "# The mixed-precision mode is taken from the training configuration.\n",
    "notebook_launcher(\n",
    "    chat_trainer.train,\n",
    "    args=train_args,\n",
    "    num_processes=gpu_count,\n",
    "    mixed_precision=train_config.mixed_precision,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run chat_trainer.test() through notebook_launcher with the same process count and\n",
    "# mixed-precision setting used for training; no extra arguments are passed.\n",
    "notebook_launcher(\n",
    "    chat_trainer.test,\n",
    "    num_processes=gpu_count,\n",
    "    mixed_precision=train_config.mixed_precision,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}