| """ |
| Diffutslator 主入口 |
| 基于扩散模型的中英互译系统 |
| """ |
|
|
| import os |
| import sys |
| import argparse |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser( |
| description="Diffutslator - 基于扩散模型的翻译系统", |
| formatter_class=argparse.RawDescriptionHelpFormatter, |
| epilog=""" |
| 示例: |
| # 快速验证训练 |
| python main.py train --quick |
| |
| # 完整训练 |
| python main.py train --full |
| |
| # 从检查点恢复训练 |
| python main.py train --resume checkpoints/epoch_5.pt |
| |
| # 交互式翻译 |
| python main.py translate |
| |
| # 翻译单个句子 |
| python main.py translate --text "你好世界" --zh |
| |
| # 使用更多DDIM步数 |
| python main.py translate --text "Hello world" --en --ddim-steps 100 |
| """ |
| ) |
| |
| subparsers = parser.add_subparsers(dest="command", help="命令") |
| |
| |
| train_parser = subparsers.add_parser("train", help="训练模型") |
| train_parser.add_argument("--quick", action="store_true", help="快速验证模式") |
| train_parser.add_argument("--full", action="store_true", help="完整训练模式") |
| train_parser.add_argument("--samples", type=int, default=None, help="使用的数据量") |
| train_parser.add_argument("--epochs", type=int, default=None, help="训练轮数") |
| train_parser.add_argument("--batch-size", type=int, default=None, help="批量大小") |
| train_parser.add_argument("--resume", type=str, default=None, help="恢复训练的检查点") |
| |
| |
| translate_parser = subparsers.add_parser("translate", help="翻译文本") |
| translate_parser.add_argument("--checkpoint", type=str, default=None, help="检查点路径") |
| translate_parser.add_argument("--text", type=str, default=None, help="要翻译的文本") |
| translate_parser.add_argument("--zh", action="store_true", help="输入是中文") |
| translate_parser.add_argument("--en", action="store_true", help="输入是英文") |
| translate_parser.add_argument("--interactive", "-i", action="store_true", help="交互模式") |
| translate_parser.add_argument("--quiet", "-q", action="store_true", help="安静模式") |
| translate_parser.add_argument("--ddim-steps", type=int, default=50, help="DDIM步数") |
| |
| args = parser.parse_args() |
| |
| if args.command == "train": |
| |
| from train import main as train_main |
| sys.argv = ["train.py"] |
| |
| if args.quick: |
| sys.argv.append("--quick") |
| if args.full: |
| sys.argv.append("--full") |
| if args.samples: |
| sys.argv.extend(["--samples", str(args.samples)]) |
| if args.epochs: |
| sys.argv.extend(["--epochs", str(args.epochs)]) |
| if args.batch_size: |
| sys.argv.extend(["--batch-size", str(args.batch_size)]) |
| if args.resume: |
| sys.argv.extend(["--resume", args.resume]) |
| |
| train_main() |
| |
| elif args.command == "translate": |
| |
| from inference import main as inference_main |
| sys.argv = ["inference.py"] |
| |
| if args.checkpoint: |
| sys.argv.extend(["--checkpoint", args.checkpoint]) |
| if args.text: |
| sys.argv.extend(["--text", args.text]) |
| if args.zh: |
| sys.argv.append("--zh") |
| if args.en: |
| sys.argv.append("--en") |
| if args.interactive: |
| sys.argv.append("--interactive") |
| if args.quiet: |
| sys.argv.append("--quiet") |
| if args.ddim_steps: |
| sys.argv.extend(["--ddim-steps", str(args.ddim_steps)]) |
| |
| inference_main() |
| |
| else: |
| parser.print_help() |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|