import os
import random
import subprocess
import sys
from enum import Enum, unique

from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui

USAGE = (
    "-" * 70
    + "\n"
    + "| Usage:" + " " * 61 + "|\n"
    + "| llamafactory-cli api -h: launch an OpenAI-style API server" + " " * 9 + "|\n"
    + "| llamafactory-cli chat -h: launch a chat interface in CLI" + " " * 11 + "|\n"
    + "| llamafactory-cli eval -h: evaluate models" + " " * 26 + "|\n"
    + "| llamafactory-cli export -h: merge LoRA adapters and export model" + " " * 3 + "|\n"
    + "| llamafactory-cli train -h: train models" + " " * 28 + "|\n"
    + "| llamafactory-cli webchat -h: launch a chat interface in Web UI" + " " * 5 + "|\n"
    + "| llamafactory-cli webui: launch LlamaBoard" + " " * 26 + "|\n"
    + "| llamafactory-cli version: show version info" + " " * 24 + "|\n"
    + "-" * 70
)
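
# Example invocations (the config paths below are illustrative placeholders,
# not files shipped by this module):
#   llamafactory-cli train path/to/train_config.yaml
#   llamafactory-cli export path/to/merge_config.yaml
#   NPROC_PER_NODE=4 llamafactory-cli train path/to/train_config.yaml  # multi-GPU launch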

VERSION = "0.7.2.dev0"

WELCOME = (
    "-" * 58
    + "\n"
    + "| Welcome to LLaMA Factory, version {}".format(VERSION)
    + " " * (21 - len(VERSION))
    + "|\n|"
    + " " * 56
    + "|\n"
    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
    + "-" * 58
)

logger = get_logger(__name__)

@unique
class Command(str, Enum):
    API = "api"
    CHAT = "chat"
    EVAL = "eval"
    EXPORT = "export"
    TRAIN = "train"
    WEBDEMO = "webchat"
    WEBUI = "webui"
    VER = "version"
    HELP = "help"

def main():
    # Fall back to the help command when no subcommand is given.
    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
    if command == Command.API:
        run_api()
    elif command == Command.CHAT:
        run_chat()
    elif command == Command.EVAL:
        run_eval()
    elif command == Command.EXPORT:
        export_model()
    elif command == Command.TRAIN:
        if get_device_count() > 1:
            # Launch distributed training with torchrun; the defaults can be
            # overridden via the NNODES, RANK, NPROC_PER_NODE, MASTER_ADDR
            # and MASTER_PORT environment variables.
            nnodes = os.environ.get("NNODES", "1")
            node_rank = os.environ.get("RANK", "0")
            nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
            logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
            subprocess.run(
                [
                    "torchrun",
                    "--nnodes",
                    nnodes,
                    "--node_rank",
                    node_rank,
                    "--nproc_per_node",
                    nproc_per_node,
                    "--master_addr",
                    master_addr,
                    "--master_port",
                    master_port,
                    launcher.__file__,
                    *sys.argv[1:],
                ]
            )
        else:
            # Single-device training runs in the current process.
            run_exp()
    elif command == Command.WEBDEMO:
        run_web_demo()
    elif command == Command.WEBUI:
        run_web_ui()
    elif command == Command.VER:
        print(WELCOME)
    elif command == Command.HELP:
        print(USAGE)
    else:
        raise NotImplementedError("Unknown command: {}".format(command))