forthezero's picture
Upload 28 files
2651102 verified
"""
Diffutslator 主入口
基于扩散模型的中英互译系统
"""
import os
import sys
import argparse
def main():
parser = argparse.ArgumentParser(
description="Diffutslator - 基于扩散模型的翻译系统",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 快速验证训练
python main.py train --quick
# 完整训练
python main.py train --full
# 从检查点恢复训练
python main.py train --resume checkpoints/epoch_5.pt
# 交互式翻译
python main.py translate
# 翻译单个句子
python main.py translate --text "你好世界" --zh
# 使用更多DDIM步数
python main.py translate --text "Hello world" --en --ddim-steps 100
"""
)
subparsers = parser.add_subparsers(dest="command", help="命令")
# 训练命令
train_parser = subparsers.add_parser("train", help="训练模型")
train_parser.add_argument("--quick", action="store_true", help="快速验证模式")
train_parser.add_argument("--full", action="store_true", help="完整训练模式")
train_parser.add_argument("--samples", type=int, default=None, help="使用的数据量")
train_parser.add_argument("--epochs", type=int, default=None, help="训练轮数")
train_parser.add_argument("--batch-size", type=int, default=None, help="批量大小")
train_parser.add_argument("--resume", type=str, default=None, help="恢复训练的检查点")
# 翻译命令
translate_parser = subparsers.add_parser("translate", help="翻译文本")
translate_parser.add_argument("--checkpoint", type=str, default=None, help="检查点路径")
translate_parser.add_argument("--text", type=str, default=None, help="要翻译的文本")
translate_parser.add_argument("--zh", action="store_true", help="输入是中文")
translate_parser.add_argument("--en", action="store_true", help="输入是英文")
translate_parser.add_argument("--interactive", "-i", action="store_true", help="交互模式")
translate_parser.add_argument("--quiet", "-q", action="store_true", help="安静模式")
translate_parser.add_argument("--ddim-steps", type=int, default=50, help="DDIM步数")
args = parser.parse_args()
if args.command == "train":
# 导入并运行训练
from train import main as train_main
sys.argv = ["train.py"]
if args.quick:
sys.argv.append("--quick")
if args.full:
sys.argv.append("--full")
if args.samples:
sys.argv.extend(["--samples", str(args.samples)])
if args.epochs:
sys.argv.extend(["--epochs", str(args.epochs)])
if args.batch_size:
sys.argv.extend(["--batch-size", str(args.batch_size)])
if args.resume:
sys.argv.extend(["--resume", args.resume])
train_main()
elif args.command == "translate":
# 导入并运行推理
from inference import main as inference_main
sys.argv = ["inference.py"]
if args.checkpoint:
sys.argv.extend(["--checkpoint", args.checkpoint])
if args.text:
sys.argv.extend(["--text", args.text])
if args.zh:
sys.argv.append("--zh")
if args.en:
sys.argv.append("--en")
if args.interactive:
sys.argv.append("--interactive")
if args.quiet:
sys.argv.append("--quiet")
if args.ddim_steps:
sys.argv.extend(["--ddim-steps", str(args.ddim_steps)])
inference_main()
else:
parser.print_help()
if __name__ == "__main__":
main()