Add files using upload-large-folder tool
- .gitattributes +1 -0
- .gitignore +9 -0
- =0.14.0 +18 -0
- README.md +49 -0
- checkpoints/ae_best.pt +3 -0
- checkpoints/ae_last.pt +3 -0
- checkpoints/flow_best.pt +3 -0
- checkpoints/flow_last.pt +3 -0
- code_x_glue_cc_code_refinement_full/.gitattributes +27 -0
- code_x_glue_cc_code_refinement_full/README.md +273 -0
- code_x_glue_cc_code_refinement_full/medium/test-00000-of-00001.parquet +3 -0
- code_x_glue_cc_code_refinement_full/medium/train-00000-of-00001.parquet +3 -0
- code_x_glue_cc_code_refinement_full/medium/validation-00000-of-00001.parquet +3 -0
- code_x_glue_cc_code_refinement_full/small/test-00000-of-00001.parquet +3 -0
- code_x_glue_cc_code_refinement_full/small/train-00000-of-00001.parquet +3 -0
- code_x_glue_cc_code_refinement_full/small/validation-00000-of-00001.parquet +3 -0
- download_data.txt +29 -0
- eval_ae.py +156 -0
- requirements.txt +9 -0
- run_repair_flow.py +162 -0
- run_wiki_flow.py +282 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/config.cpython-311.pyc +0 -0
- src/__pycache__/search.cpython-311.pyc +0 -0
- src/__pycache__/trainer.cpython-311.pyc +0 -0
- src/config.py +44 -0
- src/models/__init__.py +0 -0
- src/models/__pycache__/__init__.cpython-311.pyc +0 -0
- src/models/__pycache__/autoencoder.cpython-311.pyc +0 -0
- src/models/__pycache__/dit.cpython-311.pyc +0 -0
- src/models/autoencoder.py +181 -0
- src/models/dit.py +279 -0
- src/search.py +129 -0
- src/trainer.py +267 -0
- src/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
- src/utils/__pycache__/sandbox.cpython-311.pyc +0 -0
- src/utils/data_utils.py +157 -0
- src/utils/sandbox.py +156 -0
- tests/__pycache__/test_models.cpython-311.pyc +0 -0
- tests/test_models.py +61 -0
- train_ae.py +86 -0
- train_flow.py +227 -0
- wiki_results.tsv +0 -0
- wikilarge-dataset/.gitattributes +55 -0
- wikilarge-dataset/wiki.full.aner.ori.test.95.tsv +192 -0
- wikilarge-dataset/wiki.full.aner.ori.train.95.tsv +3 -0
- wikilarge-dataset/wiki.full.aner.ori.valid.95.tsv +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+wikilarge-dataset/wiki.full.aner.ori.train.95.tsv filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,9 @@
+.venv/*
+bigcode_humanevalpack_full/*
+google_code_x_glue_ct_code_to_text_full/*
+wiki_results.tsv
+.cache/*
+checkpoints/*
+code_x_glue_cc_code_refinement_full/*
+wikilarge-dataset/*
+
=0.14.0
ADDED
@@ -0,0 +1,18 @@
+Looking in indexes: https://bytedpypi.byted.org/simple/
+Requirement already satisfied: huggingface-hub in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (1.2.3)
+Requirement already satisfied: filelock in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (3.20.0)
+Requirement already satisfied: fsspec>=2023.5.0 in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (2025.10.0)
+Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (1.2.0)
+Requirement already satisfied: httpx<1,>=0.23.0 in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (0.28.1)
+Requirement already satisfied: packaging>=20.9 in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (25.0)
+Requirement already satisfied: pyyaml>=5.1 in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (6.0.3)
+Requirement already satisfied: shellingham in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (1.5.4)
+Requirement already satisfied: tqdm>=4.42.1 in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (4.67.1)
+Requirement already satisfied: typer-slim in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (0.20.0)
+Requirement already satisfied: typing-extensions>=3.7.4.3 in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from huggingface-hub) (4.15.0)
+Requirement already satisfied: anyio in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from httpx<1,>=0.23.0->huggingface-hub) (4.12.0)
+Requirement already satisfied: certifi in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from httpx<1,>=0.23.0->huggingface-hub) (2025.11.12)
+Requirement already satisfied: httpcore==1.* in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from httpx<1,>=0.23.0->huggingface-hub) (1.0.9)
+Requirement already satisfied: idna in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from httpx<1,>=0.23.0->huggingface-hub) (3.11)
+Requirement already satisfied: h11>=0.16 in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub) (0.16.0)
+Requirement already satisfied: click>=8.0.0 in /mlx_devbox/users/lixinyu.222/playground/Diffusion_Learning/.venv/lib/python3.11/site-packages (from typer-slim->huggingface-hub) (8.3.1)
README.md
ADDED
@@ -0,0 +1,49 @@
+# CodeFlow
+## Version 1
+CodeFlow/
+├── src/
+│   ├── __init__.py
+│   ├── config.py            # global configuration
+│   ├── models/
+│   │   ├── __init__.py
+│   │   ├── autoencoder.py   # latent-space AE
+│   │   └── dit.py           # Diffusion Transformer
+│   ├── utils/
+│   │   ├── sandbox.py       # code execution sandbox
+│   │   └── data_utils.py    # data loaders
+│   └── trainer.py           # training and inference engine
+├── tests/
+│   ├── test_models.py       # model unit tests
+│   └── test_sandbox.py      # sandbox unit tests
+├── run_wiki_flow.py         # entry point 1: Wiki simplification
+└── run_mbpp_ae.py           # entry point 2: MBPP reconstruction check
+
+## Version 2
+CodeFlow/
+├── src/
+│   ├── __init__.py
+│   ├── config.py            # global configuration (patching, dimensions)
+│   ├── models/
+│   │   ├── __init__.py
+│   │   ├── autoencoder.py   # Jina -> Linear -> Sphere -> Decoder
+│   │   └── dit.py           # patched DiT + flow logic
+│   ├── utils/
+│   │   └── data.py          # Wiki/MBPP data loading
+│   └── trainer.py           # training engine (AE & Flow)
+├── run_mbpp_ae.py           # entry point 1: validate reconstruction ability
+├── run_wiki_flow.py         # entry point 2: validate Flow Matching editing ability
+└── requirements.txt
+
+Autoencoder: the VAE/KL objective is dropped in favor of Linear Compression + L2 Normalization. This keeps the latent space on the unit sphere, so it stays semantically continuous and training is very stable.
+Backbone: still Jina-v2 (frozen) + NAR decoder.
+Generator: a patched DiT with Rectified Flow, which removes the long-sequence compute bottleneck.
+Optimization: built-in gradient accumulation, a mixed-precision switch (off by default to accommodate Jina), and multi-process data processing.
+
+### Manual download
+# Install the huggingface-hub tooling (if not already installed)
+pip install huggingface-hub
+
+# Download the model to a local directory (e.g. ./jina-embeddings-v2-base-code)
+huggingface-cli download --resume-download jinaai/jina-embeddings-v2-base-code --local-dir ./jina-embeddings-v2-base-code
+
+huggingface-cli download bogdancazan/wikilarge-text-simplification --repo-type dataset --resume-download --local-dir ./wikilarge-dataset
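The "Linear Compression + L2 Normalization" latent described in the README above can be sketched in a few lines. The function name, toy dimensions, and matrix below are illustrative assumptions for the sketch, not the repository's actual autoencoder.py (which uses learned weights and 768 -> 512 projections):

```python
import math

def compress_to_sphere(h, W, eps=1e-6):
    """Linear compression followed by L2 normalization, so every latent
    vector lands on the unit sphere. h: encoder states (list of float
    lists), W: d_enc x d_lat matrix as nested lists."""
    out = []
    for vec in h:
        # linear compression: z = vec @ W
        z = [sum(vec[i] * W[i][j] for i in range(len(vec))) for j in range(len(W[0]))]
        # L2 normalization onto the unit sphere
        norm = math.sqrt(sum(v * v for v in z))
        out.append([v / max(norm, eps) for v in z])
    return out

# tiny illustrative dimensions (the repo compresses much larger vectors)
h = [[1.0, 2.0, 3.0], [0.5, -1.0, 0.25]]
W = [[0.2, -0.1], [0.05, 0.3], [-0.4, 0.1]]
z = compress_to_sphere(h, W)
print([round(math.sqrt(sum(v * v for v in row)), 6) for row in z])  # [1.0, 1.0]
```

Because every latent has unit norm, distances between latents reduce to angles, which is what makes the space "semantically continuous" and stable to train on.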
checkpoints/ae_best.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f586f7c121051addeac0aa2e324cef974a2c46b4737c82d7b1835411efdca56d
+size 1481302557
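The checkpoint entries in this commit are Git LFS pointer files (three `key value` lines: version, oid, size), not the weights themselves. A minimal parser for this spec-v1 pointer layout, as an illustrative sketch:

```python
def parse_lfs_pointer(text):
    """Parse a git-lfs spec-v1 pointer file into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # oid is "sha256:<hex digest>", size is a decimal byte count
    algo, _, digest = fields["oid"].partition(":")
    return {"version": fields["version"], "algo": algo,
            "digest": digest, "size": int(fields["size"])}

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:f586f7c121051addeac0aa2e324cef974a2c46b4737c82d7b1835411efdca56d
size 1481302557"""
info = parse_lfs_pointer(pointer)
print(info["algo"], info["size"])  # sha256 1481302557
```

So `ae_best.pt` above is a ~1.5 GB object fetched by `git lfs pull`; cloning without LFS leaves only these pointer stubs on disk.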
checkpoints/ae_last.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:55ed2a5d4a6e5c47fa7e148da7d296096720b368947b54b72ca2d487d6581116
+size 1481302557

checkpoints/flow_best.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:909a902719e073cdb523c84cab003e8df3cb0763db4b78e38b6e5a08c529f193
+size 532346699

checkpoints/flow_last.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39d42298d1e93ead94725f2585425e8bdbc5d196dccaba57348b9bdc12ee56a1
+size 532346699
code_x_glue_cc_code_refinement_full/.gitattributes
ADDED
@@ -0,0 +1,27 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bin.* filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
code_x_glue_cc_code_refinement_full/README.md
ADDED
@@ -0,0 +1,273 @@
+---
+annotations_creators:
+- expert-generated
+language_creators:
+- found
+language:
+- code
+license:
+- c-uda
+multilinguality:
+- other-programming-languages
+size_categories:
+- 10K<n<100K
+source_datasets:
+- original
+task_categories:
+- text2text-generation
+task_ids: []
+pretty_name: CodeXGlueCcCodeRefinement
+tags:
+- debugging
+dataset_info:
+- config_name: medium
+  features:
+  - name: id
+    dtype: int32
+  - name: buggy
+    dtype: string
+  - name: fixed
+    dtype: string
+  splits:
+  - name: train
+    num_bytes: 32614786
+    num_examples: 52364
+  - name: validation
+    num_bytes: 4086733
+    num_examples: 6546
+  - name: test
+    num_bytes: 4063665
+    num_examples: 6545
+  download_size: 14929559
+  dataset_size: 40765184
+- config_name: small
+  features:
+  - name: id
+    dtype: int32
+  - name: buggy
+    dtype: string
+  - name: fixed
+    dtype: string
+  splits:
+  - name: train
+    num_bytes: 13006679
+    num_examples: 46680
+  - name: validation
+    num_bytes: 1629242
+    num_examples: 5835
+  - name: test
+    num_bytes: 1619700
+    num_examples: 5835
+  download_size: 5894462
+  dataset_size: 16255621
+configs:
+- config_name: medium
+  data_files:
+  - split: train
+    path: medium/train-*
+  - split: validation
+    path: medium/validation-*
+  - split: test
+    path: medium/test-*
+- config_name: small
+  data_files:
+  - split: train
+    path: small/train-*
+  - split: validation
+    path: small/validation-*
+  - split: test
+    path: small/test-*
+---
+
+# Dataset Card for "code_x_glue_cc_code_refinement"
+
+## Table of Contents
+- [Dataset Description](#dataset-description)
+  - [Dataset Summary](#dataset-summary)
+  - [Supported Tasks and Leaderboards](#supported-tasks)
+  - [Languages](#languages)
+- [Dataset Structure](#dataset-structure)
+  - [Data Instances](#data-instances)
+  - [Data Fields](#data-fields)
+  - [Data Splits](#data-splits-sample-size)
+- [Dataset Creation](#dataset-creation)
+  - [Curation Rationale](#curation-rationale)
+  - [Source Data](#source-data)
+  - [Annotations](#annotations)
+  - [Personal and Sensitive Information](#personal-and-sensitive-information)
+- [Considerations for Using the Data](#considerations-for-using-the-data)
+  - [Social Impact of Dataset](#social-impact-of-dataset)
+  - [Discussion of Biases](#discussion-of-biases)
+  - [Other Known Limitations](#other-known-limitations)
+- [Additional Information](#additional-information)
+  - [Dataset Curators](#dataset-curators)
+  - [Licensing Information](#licensing-information)
+  - [Citation Information](#citation-information)
+  - [Contributions](#contributions)
+
+## Dataset Description
+
+- **Homepage:** https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/code-refinement
+- **Paper:** https://arxiv.org/abs/2102.04664
+
+### Dataset Summary
+
+CodeXGLUE code-refinement dataset, available at https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/code-refinement
+
+We use the dataset released by this paper (https://arxiv.org/pdf/1812.08693.pdf). The source side is a Java function with bugs and the target side is the refined one. All function and variable names are normalized. The dataset contains two subsets (i.e. small and medium) based on function length.
+
+### Supported Tasks and Leaderboards
+
+- `text2text-generation-other-debugging`: The dataset can be used to train a model for automatically fixing buggy code.
+
+### Languages
+
+- Java **programming** language
+
+## Dataset Structure
+
+### Data Instances
+
+#### medium
+
+An example of 'train' looks as follows.
+```
+{
+    "buggy": "public static TYPE_1 init ( java.lang.String name , java.util.Date date ) { TYPE_1 VAR_1 = new TYPE_1 ( ) ; VAR_1 . METHOD_1 ( name ) ; java.util.Calendar VAR_2 = java.util.Calendar.getInstance ( ) ; VAR_2 . METHOD_2 ( date ) ; VAR_1 . METHOD_3 ( VAR_2 ) ; return VAR_1 ; }\n",
+    "fixed": "public static TYPE_1 init ( java.lang.String name , java.util.Date date ) { TYPE_1 VAR_1 = new TYPE_1 ( ) ; VAR_1 . METHOD_1 ( name ) ; java.util.Calendar VAR_2 = null ; if ( date != null ) { VAR_2 = java.util.Calendar.getInstance ( ) ; VAR_2 . METHOD_2 ( date ) ; } VAR_1 . METHOD_3 ( VAR_2 ) ; return VAR_1 ; }\n",
+    "id": 0
+}
+```
+
+#### small
+
+An example of 'validation' looks as follows.
+```
+{
+    "buggy": "public java.util.List < TYPE_1 > METHOD_1 ( ) { java.util.ArrayList < TYPE_1 > VAR_1 = new java.util.ArrayList < TYPE_1 > ( ) ; for ( TYPE_2 VAR_2 : VAR_3 ) { VAR_1 . METHOD_2 ( VAR_2 . METHOD_1 ( ) ) ; } return VAR_1 ; } \n",
+    "fixed": "public java.util.List < TYPE_1 > METHOD_1 ( ) { return VAR_1 ; } \n",
+    "id": 0
+}
+```
+
+### Data Fields
+
+In the following, each data field is explained for each config. The data fields are the same among all splits.
+
+#### medium, small
+
+|field name| type | description                    |
+|----------|------|--------------------------------|
+|id        |int32 | Index of the sample            |
+|buggy     |string| The buggy version of the code  |
+|fixed     |string| The correct version of the code|
+
+### Data Splits
+
+| name |train|validation|test|
+|------|----:|---------:|---:|
+|medium|52364|      6546|6545|
+|small |46680|      5835|5835|
+
+## Dataset Creation
+
+### Curation Rationale
+
+[More Information Needed]
+
+### Source Data
+
+#### Initial Data Collection and Normalization
+
+Every public GitHub event between March 2011 and October 2017 was downloaded from GitHub Archive and mined using the Google BigQuery APIs.
+[More Information Needed]
+
+#### Who are the source language producers?
+
+Software Engineering developers.
+
+### Annotations
+
+#### Annotation process
+
+Automatically annotated by filtering commit messages containing the pattern: ("fix" or "solve") and ("bug" or "issue" or "problem" or "error"). A statistically significant amount of samples (95% confidence level with 5% confidence interval) were manually evaluated by two authors to check if the filtered bug/fix pairs were correct. After all disagreements were settled, the authors concluded that 97.6% were true positives.
+
+#### Who are the annotators?
+
+Heuristics and the authors of the paper.
+
+### Personal and Sensitive Information
+
+[More Information Needed]
+
+## Considerations for Using the Data
+
+### Social Impact of Dataset
+
+[More Information Needed]
+
+### Discussion of Biases
+
+[More Information Needed]
+
+### Other Known Limitations
+
+[More Information Needed]
+
+## Additional Information
+
+### Dataset Curators
+
+https://github.com/microsoft, https://github.com/madlag
+
+### Licensing Information
+
+Computational Use of Data Agreement (C-UDA) License.
+
+### Citation Information
+
+```
+@article{DBLP:journals/corr/abs-2102-04664,
+  author    = {Shuai Lu and
+               Daya Guo and
+               Shuo Ren and
+               Junjie Huang and
+               Alexey Svyatkovskiy and
+               Ambrosio Blanco and
+               Colin B. Clement and
+               Dawn Drain and
+               Daxin Jiang and
+               Duyu Tang and
+               Ge Li and
+               Lidong Zhou and
+               Linjun Shou and
+               Long Zhou and
+               Michele Tufano and
+               Ming Gong and
+               Ming Zhou and
+               Nan Duan and
+               Neel Sundaresan and
+               Shao Kun Deng and
+               Shengyu Fu and
+               Shujie Liu},
+  title     = {CodeXGLUE: {A} Machine Learning Benchmark Dataset for Code Understanding
+               and Generation},
+  journal   = {CoRR},
+  volume    = {abs/2102.04664},
+  year      = {2021}
+}
+@article{tufano2019empirical,
+  title={An empirical study on learning bug-fixing patches in the wild via neural machine translation},
+  author={Tufano, Michele and Watson, Cody and Bavota, Gabriele and Penta, Massimiliano Di and White, Martin and Poshyvanyk, Denys},
+  journal={ACM Transactions on Software Engineering and Methodology (TOSEM)},
+  volume={28},
+  number={4},
+  pages={1--29},
+  year={2019},
+  publisher={ACM New York, NY, USA}
+}
+```
+
+### Contributions
+
+Thanks to @madlag (and partly also @ncoop57) for adding this dataset.
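The annotation process in the dataset card above keeps commits whose messages contain ("fix" or "solve") and ("bug" or "issue" or "problem" or "error"). A minimal sketch of that filter; the function and constant names are illustrative, not from the paper's code:

```python
FIX_WORDS = ("fix", "solve")
ISSUE_WORDS = ("bug", "issue", "problem", "error")

def looks_like_bug_fix(message):
    """Heuristic from the card: the commit message must mention both a
    fixing verb and a defect noun (case-insensitive substring match)."""
    m = message.lower()
    return any(w in m for w in FIX_WORDS) and any(w in m for w in ISSUE_WORDS)

print(looks_like_bug_fix("Fix NPE bug in parser"))  # True
print(looks_like_bug_fix("Add new feature"))        # False
```

Per the card, a 95%-confidence sample of pairs selected this way was manually checked, with 97.6% judged true positives.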
code_x_glue_cc_code_refinement_full/medium/test-00000-of-00001.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:776766b34878a999193dae39f0462ed1dc5aec3b7d219d3488db4c0127eca858
+size 1488083

code_x_glue_cc_code_refinement_full/medium/train-00000-of-00001.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:955496550f8603dd920aa2a09d1e4a14d878bda9f35f2f3575bae68b9d9493f0
+size 11943277

code_x_glue_cc_code_refinement_full/medium/validation-00000-of-00001.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9eb660729a7a43df19c714726933a517d0497cefa99cb95328bc964658eaa44c
+size 1498199

code_x_glue_cc_code_refinement_full/small/test-00000-of-00001.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be9b78909122cc07a9b31b55c999fe8b8346b1bd775e726b0cbb044f9d1ebc90
+size 588578

code_x_glue_cc_code_refinement_full/small/train-00000-of-00001.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e0ddaf02afd36f2827d0f6edcf94b3b6b5908206cbf63fc6ac99ee4a6332472
+size 4715251

code_x_glue_cc_code_refinement_full/small/validation-00000-of-00001.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b623a6544959b1767eb3082064b3b3896e26f0770dfd2270d14d4e62da0246a5
+size 590633
download_data.txt
ADDED
@@ -0,0 +1,29 @@
+
+pip install huggingface-hub
+# Upgrade to the latest version (recommended)
+pip install --upgrade huggingface-hub
+
+# To pin a version (optional; the earliest release supporting --subset is roughly 0.14.0).
+# Quote the specifier, otherwise the shell treats ">=0.14.0" as an output redirect:
+pip install "huggingface-hub>=0.14.0"
+
+### Oops, downloaded the wrong dataset here
+huggingface-cli download \
+    google/code_x_glue_ct_code_to_text \
+    --repo-type dataset \
+    --local-dir ./google_code_x_glue_ct_code_to_text_full \
+    --local-dir-use-symlinks False
+
+huggingface-cli download \
+    bigcode/humanevalpack \
+    --repo-type dataset \
+    --local-dir ./bigcode_humanevalpack_full \
+    --local-dir-use-symlinks False
+
+## Code refinement dataset
+load_dataset("google/code_x_glue_cc_code_refinement", "medium")
+huggingface-cli download \
+    google/code_x_glue_cc_code_refinement \
+    --repo-type dataset \
+    --local-dir ./code_x_glue_cc_code_refinement_full \
+    --local-dir-use-symlinks False
eval_ae.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# scripts/eval_ae_consistency.py
|
| 2 |
+
"""
|
| 3 |
+
z0 = encoder(x)
|
| 4 |
+
x^1 = decoder(z0)
|
| 5 |
+
z1 = encoder(x^1)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from transformers import AutoTokenizer
|
| 13 |
+
|
| 14 |
+
from src.config import ModelConfig, TrainConfig
|
| 15 |
+
from src.models.autoencoder import ReshapedAutoencoder
|
| 16 |
+
from src.utils.data_utils import prepare_data
|
| 17 |
+
|
| 18 |
+
def pick_stop_id(tokenizer):
|
| 19 |
+
return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
|
| 20 |
+
|
| 21 |
+
def masked_mean(x, mask, eps=1e-6):
|
| 22 |
+
# x: [B,L] or [B,L,D] reduced already, mask: [B,L]
|
| 23 |
+
denom = mask.sum().clamp(min=eps)
|
| 24 |
+
return (x * mask).sum() / denom
|
| 25 |
+
|
| 26 |
+
@torch.no_grad()
|
| 27 |
+
def main():
|
| 28 |
+
ap = argparse.ArgumentParser()
|
| 29 |
+
ap.add_argument("--dataset", type=str, default="wiki")
|
| 30 |
+
ap.add_argument("--split", type=str, default="test")
|
| 31 |
+
ap.add_argument("--max_seq_len", type=int, default=128)
|
| 32 |
+
ap.add_argument("--batch_size", type=int, default=16)
|
| 33 |
+
ap.add_argument("--ckpt", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt", help="path to ae.state_dict()")
|
| 34 |
+
ap.add_argument("--max_batches", type=int, default=0, help="0 means full eval")
|
| 35 |
+
ap.add_argument("--print_n", type=int, default=8)
|
| 36 |
+
args = ap.parse_args()
|
| 37 |
+
|
| 38 |
+
# configs
|
| 39 |
+
m_cfg = ModelConfig(
|
| 40 |
+
encoder_name='../jina-embeddings-v2-base-code',
|
| 41 |
+
latent_dim=512,
|
| 42 |
+
max_seq_len=args.max_seq_len,
|
| 43 |
+
)
|
| 44 |
+
t_cfg = TrainConfig(batch_size=args.batch_size)
|
| 45 |
+
|
| 46 |
+
device = t_cfg.device
|
| 47 |
+
tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name,local_files_only=True,trust_remote_code=False)
|
| 48 |
+
stop_id = pick_stop_id(tokenizer)
|
| 49 |
+
|
| 50 |
+
loader = prepare_data(args.dataset, tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split=args.split)
|
| 51 |
+
# test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")
|
| 52 |
+
|
| 53 |
+
ae = ReshapedAutoencoder(m_cfg).to(device).float()
|
| 54 |
+
if args.ckpt:
|
| 55 |
+
sd = torch.load(args.ckpt, map_location="cpu")
|
| 56 |
+
ae.load_state_dict(sd, strict=True)
|
| 57 |
+
    ae.eval()

    total_ce = 0.0
    total_acc = 0.0
    total_tokens = 0.0

    eos_found = 0
    eos_pos_err = 0.0
    eos_count = 0

    total_cos = 0.0
    total_l2 = 0.0
    total_lat_tokens = 0.0

    printed = 0

    for bi, batch in enumerate(tqdm(loader, desc="Eval AE")):
        if args.max_batches and bi >= args.max_batches:
            break

        ids = batch["tgt_ids"].to(device)
        mask = batch["tgt_mask"].to(device)

        # --- forward ---
        z0 = ae.encode(ids, mask)                    # [B, L, D]
        logits = ae.decode(z0, attention_mask=mask)  # [B, L, V]
        pred = logits.argmax(dim=-1)                 # [B, L]

        # --- masked CE ---
        labels = ids.masked_fill(mask == 0, -100)
        ce = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100, reduction="sum")
        total_ce += ce.item()

        # --- token accuracy (masked) ---
        correct = ((pred == ids) & (mask.bool())).sum().item()
        tok = mask.sum().item()
        total_acc += correct
        total_tokens += tok

        # --- EOS stats ---
        # true/pred EOS position (first occurrence)
        B, L = ids.shape
        for i in range(B):
            # only search within valid tokens
            valid_len = int(mask[i].sum().item())
            true_seq = ids[i, :valid_len]
            pred_seq = pred[i, :valid_len]

            true_pos = (true_seq == stop_id).nonzero(as_tuple=True)[0]
            pred_pos = (pred_seq == stop_id).nonzero(as_tuple=True)[0]

            if pred_pos.numel() > 0:
                eos_found += 1
            if true_pos.numel() > 0:
                eos_count += 1
                tpos = int(true_pos[0].item())
                ppos = int(pred_pos[0].item()) if pred_pos.numel() > 0 else valid_len - 1
                eos_pos_err += abs(ppos - tpos)

        # --- latent cycle: z0 -> token -> z1 ---
        z1 = ae.encode(pred, mask)
        cos = F.cosine_similarity(z0, z1, dim=-1)  # [B, L]
        l2 = (z0 - z1).pow(2).mean(dim=-1)         # [B, L]
        total_cos += (cos * mask).sum().item()
        total_l2 += (l2 * mask).sum().item()
        total_lat_tokens += mask.sum().item()

        # --- print a few examples ---
        if printed < args.print_n:
            s = tokenizer.decode(ids[0], skip_special_tokens=True)
            ## no position truncation is applied here
            # valid_len = int(mask[0].sum().item())
            # pred_seq = pred[0, :valid_len]
            # # find the stop token (eos/sep)
            # end = _first_pos(pred_seq, stop_id, default=valid_len - 1) + 1
            # g = tokenizer.decode(pred_seq[:end], skip_special_tokens=True)
            g = tokenizer.decode(pred[0], skip_special_tokens=True)
            print("\n--- Example ---")
            print("GT :", s)
            print("REC:", g)
            printed += 1

    avg_ce = total_ce / max(total_tokens, 1.0)
    avg_acc = total_acc / max(total_tokens, 1.0)
    avg_cos = total_cos / max(total_lat_tokens, 1.0)
    avg_l2 = total_l2 / max(total_lat_tokens, 1.0)

    eos_found_rate = eos_found / max(total_tokens / args.max_seq_len, 1.0)  # rough proxy for the batch count
    eos_mae = eos_pos_err / max(eos_count, 1)

    print("\n===== AE Metrics =====")
    print(f"Masked CE per-token: {avg_ce:.4f}")
    print(f"Token Acc (masked): {avg_acc:.4f}")
    print(f"Latent cycle cosine(z0,z1): {avg_cos:.4f}")
    print(f"Latent cycle l2(z0,z1): {avg_l2:.6f}")
    print(f"EOS found rate (rough): {eos_found_rate:.4f}")
    print(f"EOS position MAE (only where GT has EOS): {eos_mae:.2f}")

if __name__ == "__main__":
    main()
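The masked metrics in `eval_ae.py` exclude padded positions from both the cross-entropy and the accuracy, mirroring `ignore_index=-100`. A minimal pure-Python sketch of that computation (toy probabilities and ids, hypothetical values, no model required):

```python
import math

def masked_ce_and_acc(probs, targets, mask):
    """Per-token masked cross-entropy and accuracy.

    probs:   one probability distribution per position (list of lists)
    targets: gold token ids
    mask:    1 for valid positions, 0 for padding
    """
    ce_sum, correct, tokens = 0.0, 0, 0
    for p, t, m in zip(probs, targets, mask):
        if m == 0:
            continue  # padding is excluded, like ignore_index=-100
        ce_sum += -math.log(p[t])
        correct += int(max(range(len(p)), key=p.__getitem__) == t)
        tokens += 1
    return ce_sum / max(tokens, 1), correct / max(tokens, 1)

# two valid tokens, one padded position that is ignored
probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
ce, acc = masked_ce_and_acc(probs, targets=[0, 1, 2], mask=[1, 1, 0])
```

Note the third position contributes nothing even though its argmax differs from the target, which is exactly why `total_tokens` in the script counts only `mask.sum()`.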
requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch
transformers
datasets
accelerate
scikit-learn
timm
evaluate
sacrebleu
sacremoses
run_repair_flow.py
ADDED
@@ -0,0 +1,162 @@
import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import SphericalAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
from src.utils.sandbox import SafeSandbox
from src.search import DiffuMCTS

def inference(ae, flow, src_ids, src_mask, device, steps=10):
    ae.eval(); flow.eval()
    with torch.no_grad():
        # Encode source (buggy) -> z_0
        z_curr = ae.encode(src_ids, src_mask)
        z_cond = z_curr.clone()

        dt = 1.0 / steps
        for i in range(steps):
            t = torch.ones(z_curr.shape[0], device=device) * (i / steps)
            v = flow(z_curr, t, condition=z_cond).float()
            z_curr = z_curr + v * dt

        z_curr = torch.nn.functional.normalize(z_curr, p=2, dim=-1)
        logits = ae.decode(z_curr)
        return torch.argmax(logits, dim=-1)

def evaluate_on_humaneval(ae, flow, tokenizer, device, num_samples=20):
    """
    Real execution-based evaluation on HumanEvalPack.
    """
    print("\n>>> Starting Evaluation on HumanEvalPack (Real Execution)...")
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox()

    passed = 0
    total = 0

    # only evaluate the first num_samples to save time
    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples: break

        src = batch['src_ids'].to(device)
        mask = batch['src_mask'].to(device)
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        # 1. Flow inference
        out_ids = inference(ae, flow, src, mask, device)
        gen_code = tokenizer.decode(out_ids[0], skip_special_tokens=True)

        # 2. Sandbox execution
        is_pass, msg = sandbox.run(gen_code, test_code, entry_point)

        if is_pass:
            passed += 1

        total += 1

        # print the first case to inspect the output
        if i == 0:
            print(f"\n[Case 0] Pass: {is_pass}")
            print(f"Error: {msg}")
            print(f"Generated:\n{gen_code[:200]}...")

    print(f"\n=== Eval Result ===")
    print(f"Pass@1: {passed}/{total} = {passed/total*100:.2f}%")

def evaluate_with_mcts(ae, flow, tokenizer, device, num_samples=20):
    """
    Search-augmented evaluation with Diffu-MCTS.
    """
    print(f"\n>>> Starting Diffu-MCTS Evaluation (samples={num_samples})...")

    # 1. Prepare data and components
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox(timeout=2.0)  # 2s timeout to guard against infinite loops

    # 2. Initialize the searcher
    mcts = DiffuMCTS(ae, flow, tokenizer, sandbox, device, config=None)
    mcts.num_branches = 8  # branching factor K=8

    passed = 0
    total = 0

    # 3. Evaluation loop
    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples: break

        # Recover the raw text (MCTS handles tokenization internally).
        # batch['src_ids'] is a tensor, but the data loader drops the raw string,
        # so we decode the buggy code back to text here; alternatively,
        # prepare_data could be changed to also return the raw text.
        src_ids = batch['src_ids'].to(device)
        buggy_code = tokenizer.decode(src_ids[0], skip_special_tokens=True)

        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        # --- run MCTS ---
        fixed_code, is_success = mcts.solve(buggy_code, test_code, entry_point)

        if is_success:
            passed += 1

        total += 1

        # log the first sample
        if i == 0:
            print(f"\n[Case 0]")
            print(f"Buggy:\n{buggy_code[:100]}...")
            print(f"Fixed:\n{fixed_code[:100]}...")
            print(f"Result: {'✅ PASS' if is_success else '❌ FAIL'}")

    print(f"\n=== MCTS Results ===")
    print(f"Pass@1 (with Search K={mcts.num_branches}): {passed}/{total} = {passed/total*100:.2f}%")

def main():
    m_cfg = ModelConfig()
    t_cfg = TrainConfig(batch_size=8, grad_accum_steps=4)

    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, trust_remote_code=True)
    if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

    # 1. Load CodeXGLUE for training
    train_loader = prepare_data("codexglue", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")

    ae = SphericalAutoencoder(m_cfg).to(t_cfg.device).float()
    # Patch pad token
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()

    trainer = Trainer(ae, flow, t_cfg, train_loader)

    # --- Training loop ---

    # Step 1: Train AE
    print("\n>>> Training AE on CodeXGLUE...")
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)
    for epoch in range(t_cfg.num_epochs_ae):
        loss = trainer.train_ae(opt_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

    # Step 2: Train Flow
    print("\n>>> Training Flow Matching on CodeXGLUE...")
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    for epoch in range(t_cfg.num_epochs_flow):
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")

    # --- Evaluation ---
    evaluate_on_humaneval(ae, flow, tokenizer, t_cfg.device)
    # run the MCTS evaluation after training finishes
    evaluate_with_mcts(ae, flow, tokenizer, t_cfg.device, num_samples=50)

if __name__ == "__main__":
    main()
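The `inference` loop in `run_repair_flow.py` is a plain forward-Euler integration of the learned velocity field from t=0 toward t=1. Stripped of the model, the update rule alone can be sketched on a toy 1-D velocity field (hypothetical, not the trained `flow`):

```python
def euler_integrate(z0, velocity, steps=10):
    """Forward Euler: z_{i+1} = z_i + v(z_i, t_i) * dt with t_i = i / steps."""
    z, dt = z0, 1.0 / steps
    for i in range(steps):
        t = i / steps
        z = z + velocity(z, t) * dt  # same update as the inference loop
    return z

# with a constant velocity v(z, t) = 2.0, the exact endpoint is z0 + 2.0
z1 = euler_integrate(0.0, lambda z, t: 2.0, steps=10)
```

For a non-constant field the result carries the usual O(dt) Euler discretization error, which is why `steps` trades quality against speed in the script.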
run_wiki_flow.py
ADDED
@@ -0,0 +1,282 @@
import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm
import torch.nn.functional as F
import os
import evaluate
import sacrebleu

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data

### helpers for detecting the EOS token
def _pick_stop_id(tokenizer):
    # BERT/Jina-style tokenizers usually have eos_token_id=None; use sep_token_id as the stop token
    return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id

def _first_pos(x_1d, token_id, default):
    # x_1d: [L]
    idx = (x_1d == token_id).nonzero(as_tuple=True)[0]
    return idx[0].item() if idx.numel() > 0 else default


def calculate_metrics(sources, predictions, references):
    """
    Compute SARI, BLEU, and the compression ratio.
    """
    ## Tried fetching the BLEU evaluation script from the Hugging Face Hub,
    ## but it could not be reached due to network issues.
    # sari_metric = evaluate.load("sari")
    # bleu_metric = evaluate.load("bleu")
    # # SARI needs the sources
    # sari_score = sari_metric.compute(sources=sources, predictions=predictions, references=[[r] for r in references])
    # # BLEU
    # bleu_score = bleu_metric.compute(predictions=predictions, references=[[r] for r in references])

    # 1. BLEU
    # sacrebleu expects references as List[List[str]] (multiple references),
    # but here references is List[str] (a single reference),
    # so it has to be transposed: [[ref1, ref2, ...]]
    bleu = sacrebleu.corpus_bleu(predictions, [references])

    # 2. SARI
    try:
        # corpus_sari returns a SARI object whose .score attribute is a float
        sari = sacrebleu.corpus_sari(sources, predictions, [references])
        sari_score = sari.score
    except Exception as e:
        print(f"SARI calculation failed: {e}")
        sari_score = 0.0

    # 3. Compression ratio
    ratios = [len(p) / len(s) if len(s) > 0 else 0 for p, s in zip(predictions, sources)]
    avg_ratio = sum(ratios) / len(ratios)

    return {
        "SARI": sari_score,  # plain float
        "BLEU": bleu.score,
        "Compression Ratio": avg_ratio
    }

@torch.no_grad()
def inference_batch(ae, flow, loader, tokenizer, device, steps=10, save_path="results.txt", use_oneshot=True):
    ae.eval()
    flow.eval()

    stop_id = _pick_stop_id(tokenizer)
    pad_id = tokenizer.pad_token_id

    print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")

    all_sources = []
    all_targets = []
    all_generated = []

    scale = getattr(ae, "latent_scale", 10.0)

    with open(save_path, "w", encoding="utf-8") as f:
        f.write("Source\tTarget\tGenerated\n")

        for batch in tqdm(loader, desc="Inferencing"):
            src_ids = batch['src_ids'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_ids = batch['tgt_ids'].to(device)
            B, L = src_ids.shape

            z_curr = ae.encode(src_ids, src_mask)
            z_cond = z_curr.clone()

            ## either one-shot or multi-step sampling
            if use_oneshot:
                # x-prediction is most stable: one-shot directly at t=0
                t0 = torch.zeros(B, device=device)
                z_curr = flow(z_curr, t0, condition=z_cond).float()
            else:
                dt = 1.0 / steps
                for i in range(steps):
                    t_val = i / steps
                    # avoid division by zero at t=1 (rare, but worth guarding against)
                    if t_val >= 0.999: break
                    t = torch.ones(z_curr.shape[0], device=device) * t_val

                    ## from v to z
                    # v = flow(z_curr, t, condition=z_cond).float()
                    # z_curr = z_curr + v * dt
                    ## from z to v to z_curr

                    pred_z1 = flow(z_curr, t, condition=z_cond).float()
                    ## possible optimization: replace 1 - t_val with 1
                    v = (pred_z1 - z_curr) / (1.0 - t_val + 1e-4)  # add epsilon
                    z_curr = z_curr + v * dt
                    z_curr = F.normalize(z_curr, p=2, dim=-1) * scale
                z_curr = pred_z1  # use the final endpoint prediction directly

            z_curr = torch.nn.functional.normalize(z_curr, p=2, dim=-1) * scale  ## align the scaling

            # ---- 3) two-pass decode to determine length by EOS ----
            full_mask = torch.ones(B, L, device=device)  # allow growth: every position is "generatable"

            # Pass-1: decode with full mask
            logits1 = ae.decode(z_curr, attention_mask=full_mask)
            ids1 = logits1.argmax(dim=-1)  # [B, L]

            # find stop positions and build gen_mask
            stop_pos = []
            for i in range(B):
                # if no stop token is predicted, treat L-1 as the "maximum length"
                pos = _first_pos(ids1[i], stop_id, default=L - 1)
                stop_pos.append(pos)
            stop_pos = torch.tensor(stop_pos, device=device)

            gen_mask = torch.zeros(B, L, device=device)
            for i in range(B):
                gen_mask[i, : stop_pos[i].item() + 1] = 1.0

            # Pass-2: decode again with gen_mask, reducing tail interference
            logits2 = ae.decode(z_curr, attention_mask=gen_mask)
            ids2 = logits2.argmax(dim=-1)

            # enforce pad after stop for clean decoding
            ids2 = ids2.masked_fill(gen_mask == 0, pad_id)

            # ---- 4) decode to text with truncation ----
            src_texts = tokenizer.batch_decode(src_ids, skip_special_tokens=True)
            tgt_texts = tokenizer.batch_decode(tgt_ids, skip_special_tokens=True)

            gen_texts = []
            for i in range(B):
                end = stop_pos[i].item() + 1
                ids_cut = ids2[i, :end]
                gen_texts.append(tokenizer.decode(ids_cut, skip_special_tokens=True))

            # Save & collect
            for s, t, g in zip(src_texts, tgt_texts, gen_texts):
                # simple post-processing: strip newlines so each row fits in a TSV line
                s_clean = s.replace("\n", " ")
                t_clean = t.replace("\n", " ")
                g_clean = g.replace("\n", " ")

                f.write(f"{s_clean}\t{t_clean}\t{g_clean}\n")

                all_sources.append(s_clean)
                all_targets.append(t_clean)
                all_generated.append(g_clean)

    return all_sources, all_targets, all_generated

### add saving ckpts
def main():

    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {ckpt_dir}")

    # Config
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128  # Wiki sentences are short; 128 is sufficient and fast
    )

    t_cfg = TrainConfig(
        batch_size=16,       # can be larger for inference
        num_epochs_ae=20,    # a bit more AE training
        num_epochs_flow=35,  # more Flow training
        grad_accum_steps=4,
        use_amp=False
    )

    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, trust_remote_code=True)

    # 1. Load data (train & test)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
    test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")

    # Init
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()

    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    # trainer = Trainer(ae, flow, t_cfg, train_loader)
    ## also pass pad_id and stop_id
    trainer = Trainer(ae, flow, t_cfg, train_loader, pad_id=tokenizer.pad_token_id, stop_id=_pick_stop_id(tokenizer))


    # 2. Train AE
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)
    best_ae_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")
    for epoch in range(t_cfg.num_epochs_ae):
        loss = trainer.train_ae(opt_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

        # save best
        if loss < best_ae_loss:
            best_ae_loss = loss
            torch.save(ae.state_dict(), os.path.join(ckpt_dir, "ae_best.pt"))
            # print(f"  Saved Best AE (Loss {loss:.4f})")

        # save last (overwritten each epoch, for resuming or inspection)
        torch.save(ae.state_dict(), os.path.join(ckpt_dir, "ae_last.pt"))

    print(f"AE Training Done. Best Loss: {best_ae_loss:.4f}")

    # 3. Train Flow
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    best_flow_loss = float('inf')
    print("\n>>> Start Training Flow DiT...")
    for epoch in range(t_cfg.num_epochs_flow):
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")

        # save best
        if loss < best_flow_loss:
            best_flow_loss = loss
            torch.save(flow.state_dict(), os.path.join(ckpt_dir, "flow_best.pt"))
            # print(f"  Saved Best Flow (Loss {loss:.4f})")

        # save last
        torch.save(flow.state_dict(), os.path.join(ckpt_dir, "flow_last.pt"))

    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")

    # 4. Evaluation
    # load the best weights
    ae_path = os.path.join(ckpt_dir, "ae_best.pt")
    flow_path = os.path.join(ckpt_dir, "flow_best.pt")

    if os.path.exists(ae_path):
        ae.load_state_dict(torch.load(ae_path, map_location=t_cfg.device))
        print("Loaded AE Best.")
    else:
        print("Warning: AE Best ckpt not found, using last state.")

    if os.path.exists(flow_path):
        flow.load_state_dict(torch.load(flow_path, map_location=t_cfg.device))
        print("Loaded Flow Best.")
    else:
        print("Warning: Flow Best ckpt not found, using last state.")

    print("\n--- Starting Inference ---")
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=10,
        save_path="wiki_results.tsv"
    )

    # Calculate metrics
    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    print(f"\nResults saved to wiki_results.tsv")

if __name__ == "__main__":
    main()
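The two-pass decode in `run_wiki_flow.py` keys everything off the first predicted stop token: everything up to and including it is kept, and the tail is masked and padded. That truncation logic on its own reduces to a few lines (hypothetical token ids and stop/pad ids, plain lists instead of tensors):

```python
def first_pos(ids, token_id, default):
    """Index of the first occurrence of token_id, or default if absent."""
    for i, t in enumerate(ids):
        if t == token_id:
            return i
    return default

def truncate_and_pad(ids, stop_id, pad_id):
    """Keep up to and including the first stop token; pad out the tail."""
    end = first_pos(ids, stop_id, default=len(ids) - 1) + 1
    return ids[:end] + [pad_id] * (len(ids) - end)

# stop token (id 2) at position 2 -> keep 3 tokens, pad the rest
out = truncate_and_pad([5, 9, 2, 7, 7], stop_id=2, pad_id=0)
```

The `default=len(ids) - 1` fallback matches the script's behavior when no stop token is predicted: the whole length is treated as valid rather than producing an empty sequence.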
src/__init__.py
ADDED
File without changes

src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (174 Bytes)

src/__pycache__/config.cpython-311.pyc
ADDED
Binary file (2.29 kB)

src/__pycache__/search.cpython-311.pyc
ADDED
Binary file (4.62 kB)

src/__pycache__/trainer.cpython-311.pyc
ADDED
Binary file (13 kB)
src/config.py
ADDED
@@ -0,0 +1,44 @@
from dataclasses import dataclass
import torch

@dataclass
class ModelConfig:
    # sequence latent space config
    encoder_name: str = "../jina-embeddings-v2-base-code"  # "jinaai/jina-embeddings-v2-base-code", "microsoft/codebert-base", or roberta-base
    input_dim: int = 768     # Jina Base is 768
    latent_dim: int = 768    # keep the full semantic capacity
    decoder_layers: int = 4  # simple NAR decoder

    # VAE adapter config
    max_seq_len: int = 2048  # set according to the task
    patch_size: int = 4      # patching compression rate

    # DiT settings
    dit_layers: int = 12
    dit_heads: int = 8
    dit_hidden: int = 768    # hidden width, less than latent_dim * patch_size to avoid OOM
    mlp_ratio: float = 4.0

    # @property
    # def dit_hidden(self):
    #     return self.latent_dim

@dataclass
class TrainConfig:
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    lr_ae: float = 1e-4
    lr_flow: float = 5e-4
    batch_size: int = 8
    grad_accum_steps: int = 4  # gradient accumulation, effective batch size = 32

    num_epochs_ae: int = 20    # train the AE first, then the Flow
    num_epochs_flow: int = 50  # the Flow needs more training epochs

    grad_clip: float = 1.0
    use_amp: bool = False      # mixed precision; Jina + AMP tends to error out
    save_dir: str = "./checkpoints"

    def __post_init__(self):
        import os
        os.makedirs(self.save_dir, exist_ok=True)
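The `grad_accum_steps = 4` comment above claims an effective batch size of 32 (8 × 4). The arithmetic behind that claim is that dividing each micro-batch loss by the number of accumulation steps makes the summed gradient equal the gradient of one large batch. A toy sketch with scalar stand-in gradients (hypothetical values, no framework):

```python
def accumulated_grad(micro_grads, accum_steps):
    """Sum of per-micro-batch gradients, each scaled by 1/accum_steps,
    i.e. what (loss / accum_steps).backward() accumulates over the steps."""
    total = 0.0
    for g in micro_grads:
        total += g / accum_steps
    return total

# four micro-batches of size 8 behave like one batch of 32:
# the accumulated gradient is the mean of the per-micro-batch gradients
g = accumulated_grad([1.0, 2.0, 3.0, 4.0], accum_steps=4)
```

This is why the optimizer step (and gradient clipping with `grad_clip`) should happen only once per `accum_steps` micro-batches, not after every backward pass.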
src/models/__init__.py
ADDED
File without changes

src/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (181 Bytes)

src/models/__pycache__/autoencoder.cpython-311.pyc
ADDED
Binary file (9.72 kB)

src/models/__pycache__/dit.cpython-311.pyc
ADDED
Binary file (17.1 kB)
src/models/autoencoder.py
ADDED
@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x + self.net(x))

class ResidualAutoencoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # 1. Encoder (frozen)
        print(f"Loading Encoder: {cfg.encoder_name}...")
        self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.hidden_dim = self.encoder.config.hidden_size
        for p in self.encoder.parameters():
            p.requires_grad = False

        # 2. Latent processor (no dimension reduction)
        # keep 768 dims and only reorganize the features;
        # residual blocks preserve gradient flow
        self.compressor = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),  # optional
            ResidualBlock(self.hidden_dim),
            # ResidualBlock(self.hidden_dim)  # optional: go deeper
        )

        self.decompressor = nn.Sequential(
            ResidualBlock(self.hidden_dim),
            # ResidualBlock(self.hidden_dim),
            nn.Linear(self.hidden_dim, self.hidden_dim)
        )

        # 3. Decoder (pretrained)
        print(f"Loading Decoder: {cfg.encoder_name}...")
        self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.decoder.config.is_decoder = False

        # 4. Head
        self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
        with torch.no_grad():
            self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        self.lm_head.weight.requires_grad = True

    def encode(self, input_ids, attention_mask):
        with torch.no_grad():
            enc_out = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.compressor(enc_out)

    def decode(self, z, attention_mask):
        h = self.decompressor(z)
        dec_out = self.decoder(inputs_embeds=h, attention_mask=attention_mask).last_hidden_state
        return self.lm_head(dec_out)

    def forward(self, input_ids, attention_mask):
        z = self.encode(input_ids, attention_mask)
        logits = self.decode(z, attention_mask)
        return logits, z

class ReshapedAutoencoder(nn.Module):
    """
    Sequence-to-Sequence Autoencoder with Spherical Latent Space.
    Logic: Token -> Jina -> Linear -> Linear -> Decoder -> Token
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.latent_scale = getattr(cfg, "latent_scale", 10.0)

        # 1. Encoder (frozen Jina)
        print(f"Loading Pretrained Encoder: {cfg.encoder_name}...")
        # self.encoder = AutoModel.from_pretrained(cfg.encoder_name, local_files_only=True, trust_remote_code=False)
        self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.hidden_dim = self.encoder.config.hidden_size
        self.vocab_size = self.encoder.config.vocab_size

        # freeze the encoder parameters
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Drop the hard normalization and use LayerNorm as a "soft constraint".
        # Structure: Hidden -> Project -> LayerNorm -> Latent
        self.compress = nn.Sequential(
            nn.Linear(self.hidden_dim, cfg.latent_dim),
            nn.GELU(),
            nn.Linear(cfg.latent_dim, cfg.latent_dim),
            nn.LayerNorm(cfg.latent_dim)  # key: keeps the latent distribution stable, which helps the Flow
        )

        # 3. Decompressor
        self.decompress = nn.Sequential(
            nn.Linear(cfg.latent_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.LayerNorm(self.hidden_dim)
        )

        # 4. Decoder (pretrained!)
        # <--- load from the pretrained config --->
        print(f"Loading Pretrained Decoder: {cfg.encoder_name}...")
        # self.decoder = AutoModel.from_pretrained(cfg.encoder_name, local_files_only=True, trust_remote_code=False)
        self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)

        # For BERT, is_decoder=False gives bidirectional attention, which is
        # exactly what NAR needs; no causal mask is required.
        self.decoder.config.is_decoder = False

        # 5. Output head (trainable)
        # initialized from the encoder's embeddings, but kept trainable
        self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
        with torch.no_grad():
            self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        # allow fine-tuning to adapt to drift in the decoder outputs
        self.lm_head.weight.requires_grad = True

    def encode(self, input_ids, attention_mask):
        """
        Input: [B, L]
        Output: [B, L, Latent_Dim]
        """
        with torch.no_grad():
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Compression
        z = self.compress(outputs.last_hidden_state)  # [B, L, 768]

        ## increase the scale
        # z = z * self.latent_scale

        return z

    ## Requires an attention_mask; open question: at inference time there is
    ## no mask, and there appears to be no EOS detection here either.
    def decode(self, latents, attention_mask=None):
        """
        Input: [B, L, Latent_Dim]
        Output: [B, L, Vocab]
        """
        ## back to the original scale
        # latents = latents / self.latent_scale
        # 1. Decompress (back to hidden size)
        hidden = self.decompress(latents)

        # 2. Backbone forward (injected via inputs_embeds)
        # AutoModel handles the mask automatically (full attention in NAR mode)
        decoder_outputs = self.decoder(
            inputs_embeds=hidden,
            attention_mask=attention_mask
        )

        sequence_output = decoder_outputs.last_hidden_state

        # 3. Logits
        return self.lm_head(sequence_output)

    # def forward(self, input_ids, attention_mask):
    #     z = self.encode(input_ids, attention_mask)
    #     logits = self.decode(z, attention_mask=attention_mask)
    #     return logits, z
    def forward(self, input_ids, encoder_mask, decoder_mask=None):
        if decoder_mask is None:
|
| 178 |
+
decoder_mask = encoder_mask
|
| 179 |
+
z = self.encode(input_ids, encoder_mask)
|
| 180 |
+
logits = self.decode(z, attention_mask=decoder_mask)
|
| 181 |
+
return logits, z
|
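The `compress` head above ends in a LayerNorm rather than a hard L2-normalize. A standalone NumPy sketch (toy numbers, independent of the model) shows the "soft constraint" this imposes: each latent vector is shifted to zero mean and unit variance, which keeps the latent distribution stable for the flow without pinning vectors to a sphere.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each vector (last axis) to zero mean and unit variance,
    # the same statistic nn.LayerNorm computes before its affine terms.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

z = np.array([[10.0, 20.0, 30.0, 40.0],
              [-5.0, 0.0, 5.0, 10.0]])
z_ln = layer_norm(z)

print(z_ln.mean(axis=-1))  # per-vector means ~ 0
print(z_ln.std(axis=-1))   # per-vector stds  ~ 1
```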
src/models/dit.py
ADDED
|
import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
# from timm.models.vision_transformer import Attention, Mlp  -> hand-rolled below instead of timm


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Replaced the manual softmax attention with fused
        # scaled-dot-product attention (Flash Attention)
        x = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Patch1D(nn.Module):
    """
    [B, L, D] -> [B, L/P, D*P]
    """
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        B, L, D = x.shape
        # Pad the sequence if its length is not divisible by patch_size,
        # e.g. [B, 31, 4] with patch_size=2 -> [B, 16, 8], zero-padded after x_31.
        if L % self.patch_size != 0:
            pad = self.patch_size - (L % self.patch_size)
            x = F.pad(x, (0, 0, 0, pad))

        B, L_new, D = x.shape
        # View as patches
        return x.view(B, L_new // self.patch_size, D * self.patch_size)


class Unpatch1D(nn.Module):
    """
    [B, L/P, D*P] -> [B, L, D]
    """
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        B, L_new, DP = x.shape
        return x.view(B, L_new * self.patch_size, DP // self.patch_size)
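The pad-then-reshape arithmetic in `Patch1D` can be checked with plain integers (toy lengths, no tensors involved):

```python
def patched_shape(L, D, patch_size):
    # Length after zero-padding L up to a multiple of patch_size,
    # then grouping every patch_size steps into one wider vector.
    pad = (-L) % patch_size
    L_new = L + pad
    return L_new // patch_size, D * patch_size

# The docstring example: [B, 31, 4] with patch_size=2 -> [B, 16, 8]
print(patched_shape(31, 4, 2))  # (16, 8)
# Already divisible: no padding needed
print(patched_shape(32, 4, 2))  # (16, 8)
```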
# Note: this DiT's pos_embed is a learned parameter rather than sinusoidal; there is
# also no forward_with_cfg with a label embedding yet (no label_embedding for now).
# From: https://github.com/willisma/SiT/blob/main/models.py
class TimestepEmbedder(nn.Module):
    """Sinusoidal time embeddings."""
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        # Accept a wider range of t shapes
        if t.ndim > 1:
            t = t.view(-1)

        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
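The sinusoidal embedding above can be reproduced without torch; a math-only sketch of `timestep_embedding` for one fractional timestep (toy dim=8):

```python
import math

def timestep_embedding(t, dim, max_period=10000):
    # Frequencies decay geometrically from 1 down to ~1/max_period,
    # mirroring the torch version above.
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # zero-pad odd dims, as in the torch code
    return emb

e = timestep_embedding(0.5, 8)
print(len(e))  # 8
print(e[0])    # cos(0.5 * 1.0), the lowest-index (highest-frequency) channel
```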
def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


# DiT block with adaptive layer-norm conditioning
class DiTBlock(nn.Module):
    """Transformer block with adaptive layer norm (adaLN)."""
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # c shape: [B, hidden_size]
        # adaLN_out shape should be [B, 6 * hidden_size]
        adaLN_out = self.adaLN_modulation(c)

        # --- Debug probe (if this errors again, inspect the shapes printed here) ---
        if adaLN_out.shape[1] != 6 * self.hidden_size:
            print("⚠️ DiTBlock shape error!")
            print(f"Input c shape: {c.shape}")
            print(f"adaLN output shape: {adaLN_out.shape}")
            print(f"Expected dim1: {6 * self.hidden_size}")
            raise ValueError("adaLN output dimension mismatch!")
        # ---------------------------------------------------------------------------

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN_out.chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
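The adaLN path chunks one conditioning vector into six pieces and applies `x * (1 + scale) + shift`; with zero-initialized modulation weights (see `initialize_weights` below), every shift, scale, and gate starts at zero, so each block begins as the identity. A scalar sketch with toy numbers:

```python
def modulate(x, shift, scale):
    # Same formula as the torch helper, applied to scalars
    return x * (1 + scale) + shift

# At init, adaLN_modulation outputs all zeros -> identity modulation
print(modulate(3.0, shift=0.0, scale=0.0))  # 3.0

# gate = 0 at init, so the residual branch contributes nothing:
x, gate, branch_out = 3.0, 0.0, 1.7
print(x + gate * branch_out)  # 3.0

# A nonzero condition shifts and rescales the normalized activations
print(modulate(3.0, shift=0.5, scale=0.1))  # 3.8
```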
class PatchedFlowDiT(nn.Module):
    """
    Main DiT architecture for flow matching.
    Input:  z_t (noisy latent) + t (time) + condition (original latent)
    Output: velocity vector
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Patch and unpatch blocks
        self.patcher = Patch1D(cfg.patch_size)
        self.unpatcher = Unpatch1D(cfg.patch_size)

        # Input dimension after patching:
        # input to the DiT = Patch(z_t) concatenated with Patch(condition),
        # i.e. (latent_dim * patch_size) * 2
        input_feat_dim = cfg.latent_dim * cfg.patch_size

        # Projection to the DiT hidden size
        self.input_proj = nn.Linear(input_feat_dim * 2, cfg.dit_hidden)

        # Time & positional embeddings
        self.time_embed = TimestepEmbedder(cfg.dit_hidden)
        patched_len = (cfg.max_seq_len + cfg.patch_size - 1) // cfg.patch_size
        self.pos_embed = nn.Parameter(torch.zeros(1, patched_len, cfg.dit_hidden))

        self.blocks = nn.ModuleList([
            DiTBlock(cfg.dit_hidden, cfg.dit_heads) for _ in range(cfg.dit_layers)
        ])

        # Output projection (predict velocity)
        self.final_layer = nn.Linear(cfg.dit_hidden, input_feat_dim)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize pos_embed
        nn.init.normal_(self.pos_embed, std=0.02)

        # Zero-out adaLN modulation layers
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Final layer: since the model was modified to predict data rather than
        # noise, use Xavier (or normal) initialization instead of zero-init.
        # nn.init.constant_(self.final_layer.weight, 0)
        # nn.init.constant_(self.final_layer.bias, 0)
        nn.init.xavier_uniform_(self.final_layer.weight)
        nn.init.constant_(self.final_layer.bias, 0)

    def forward(self, z_t, t, condition):
        """
        z_t:       [B, L, D] noisy latent
        t:         [B] timesteps
        condition: [B, L, D] (e.g. the source-sentence latent)
        """
        # 1. Patching
        z_p = self.patcher(z_t)
        c_p = self.patcher(condition)

        # 2. Concat & project (SiT-style explicit conditioning)
        x = torch.cat([z_p, c_p], dim=-1)
        x = self.input_proj(x)

        # 3. Add embeddings
        t_emb = self.time_embed(t)
        # Handle length mismatch due to padding
        L_curr = x.shape[1]
        x = x + self.pos_embed[:, :L_curr, :]

        # 4. Transformer
        for block in self.blocks:
            x = block(x, t_emb)

        # 5. Output & unpatch
        v_p = self.final_layer(x)
        v = self.unpatcher(v_p)

        # Crop to the original length
        return v[:, :z_t.shape[1], :]

    def forward_with_cfg(self, x, t, condition, cfg_scale):
        """
        Forward pass with classifier-free guidance.
        """
        # 1. Conditional
        cond_out = self.forward(x, t, condition)

        # 2. Unconditional. Note: forward() patches the condition, so passing
        # condition=None would crash; a zero tensor is used here as the null
        # condition (assumed intent of the original condition=None call).
        uncond_out = self.forward(x, t, condition=torch.zeros_like(x))

        # 3. Classifier-free guidance:
        # eps = eps_uncond + s * (eps_cond - eps_uncond)
        return uncond_out + cfg_scale * (cond_out - uncond_out)
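The guidance combination at the end is plain extrapolation between the two predictions; a scalar sketch with toy values:

```python
def cfg_combine(uncond, cond, scale):
    # uncond + s * (cond - uncond): s=0 -> unconditional, s=1 -> conditional,
    # s>1 extrapolates past the conditional prediction
    return uncond + scale * (cond - uncond)

print(cfg_combine(1.0, 3.0, 0.0))  # 1.0
print(cfg_combine(1.0, 3.0, 1.0))  # 3.0
print(cfg_combine(1.0, 3.0, 2.0))  # 5.0
```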
src/search.py
ADDED
|
import torch
import torch.nn.functional as F
from tqdm import tqdm

class DiffuMCTS:
    """
    Diffusion-based Monte Carlo tree search (simulation-based).
    Flow-matching generation -> rollout.
    A sandbox is used for evaluation.
    """
    def __init__(self, ae, flow, tokenizer, sandbox, device, config):
        self.ae = ae
        self.flow = flow
        self.tokenizer = tokenizer
        self.sandbox = sandbox
        self.device = device

        # Search configuration
        self.num_branches = 8   # number of branches (K)
        self.split_t = 0.5      # time point at which to branch (0 = target, 1 = source)
        self.noise_scale = 0.1  # perturbation strength when branching
        self.steps = 10         # number of flow ODE steps

    @torch.no_grad()
    def solve(self, buggy_code, test_code, entry_point):
        """
        Public interface: try to repair the code.
        Returns:
            fixed_code (str): the repaired code
            success (bool): whether it passed the tests
        """
        # 1. Encode the buggy code (source) -> z_1
        tokens = self.tokenizer(
            buggy_code,
            max_length=2048,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

        z_buggy = self.ae.encode(tokens['input_ids'], tokens['attention_mask'])

        # 2. Search strategy: parallel branching.
        # Run a single inference pass with branching.
        best_code, success = self._parallel_rollout(z_buggy, test_code, entry_point)

        return best_code, success

    def _parallel_rollout(self, z_start, test_code, entry_point):
        """
        Run the parallel rollout search.
        """
        B, L, D = z_start.shape
        K = self.num_branches

        # --- Stage 1: deterministic flow (0 -> split_t) ---
        # First take a few steps from the buggy state so the semantics stabilize a bit.
        z_curr = z_start.clone()
        z_cond = z_start.clone()  # the condition is always the buggy code

        dt = 1.0 / self.steps
        # Step index corresponding to the split point
        split_step_idx = int((1.0 - self.split_t) * self.steps)

        # First half of the trajectory
        for i in range(split_step_idx):
            # Time convention: training uses z_t = (1-t) * z_bad + t * z_good,
            # so t=0 is the bad code and t=1 is the good code. We therefore
            # integrate forward Euler from t=0 towards t=1: z_{t+dt} = z_t + v * dt.
            current_t_val = i / self.steps
            t_tensor = torch.ones(B, device=self.device) * current_t_val

            v = self.flow(z_curr, t_tensor, condition=z_cond)
            z_curr = z_curr + v * dt

        # --- Stage 2: expansion (branching) ---
        # Replicate K copies: [B, L, D] -> [B*K, L, D]
        z_branches = z_curr.repeat(K, 1, 1)
        z_cond_branches = z_cond.repeat(K, 1, 1)

        # Inject Gaussian noise (exploration): z' = z + noise
        noise = torch.randn_like(z_branches) * self.noise_scale
        z_branches = z_branches + noise

        # Re-project onto the sphere (keep the manifold constraint)
        z_branches = F.normalize(z_branches, p=2, dim=-1)

        # --- Stage 3: rollout (split_t -> 1.0) ---
        # Roll out all branches in parallel
        remaining_steps = self.steps - split_step_idx

        for i in range(remaining_steps):
            step_idx = split_step_idx + i
            current_t_val = step_idx / self.steps

            # [B*K]
            t_tensor = torch.ones(z_branches.shape[0], device=self.device) * current_t_val

            v = self.flow(z_branches, t_tensor, condition=z_cond_branches)
            z_branches = z_branches + v * dt

        # --- Stage 4: decoding & verification ---
        # Batch decode: [B*K, L, D] -> [B*K, L, Vocab]
        logits = self.ae.decode(z_branches)
        pred_ids = torch.argmax(logits, dim=-1)

        candidate_codes = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

        # Verification loop: success if any single candidate passes (pass@k)
        for code in candidate_codes:
            is_pass, msg = self.sandbox.run(code, test_code, entry_point)
            if is_pass:
                return code, True

        # If all fail, return the first candidate (a heuristic that picks the
        # closest one could also be designed)
        return candidate_codes[0], False
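The two rollout loops above are forward-Euler integration of dz/dt = v from t=0 to t=1. For a rectified flow whose velocity is the constant v = z_good - z_bad, Euler with any step count lands (up to float error) exactly on z_good; a scalar sketch with toy endpoints:

```python
def euler_rollout(z0, z1, steps):
    # Integrate dz/dt = v with v = z1 - z0 (the ideal rectified-flow velocity);
    # a trained flow model approximates this v at each step.
    z, dt = z0, 1.0 / steps
    for _ in range(steps):
        v = z1 - z0
        z = z + v * dt
    return z

print(euler_rollout(z0=-2.0, z1=4.0, steps=10))  # ≈ 4.0
```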
src/trainer.py
ADDED
|
import torch
import torch.nn.functional as F
from tqdm import tqdm

class Trainer:
    def __init__(self, ae, flow, cfg, loader, pad_id, stop_id):
        self.ae = ae.to(cfg.device)
        self.flow = flow.to(cfg.device) if flow else None
        self.cfg = cfg
        self.loader = loader
        self.device = cfg.device
        self.pad_id = pad_id
        self.stop_id = stop_id

    def train_ae(self, optimizer):
        self.ae.train()
        total_loss = 0
        pbar = tqdm(self.loader, desc="Train AE")
        optimizer.zero_grad()

        for step, batch in enumerate(pbar):
            tgt = batch['tgt_ids'].to(self.device)
            mask = batch['tgt_mask'].to(self.device)

            # Earlier variant: a single reconstruction CE with pad positions
            # masked to -100 (or ignore_index=1), replaced by the three-term
            # loss below.
            logits, z = self.ae(tgt, mask)  # decoder_mask defaults to mask

            V = logits.size(-1)
            B, L = tgt.shape

            # 1) Token loss: only positions where mask == 1
            labels_tok = tgt.masked_fill(mask == 0, -100)
            loss_tok = F.cross_entropy(
                logits.view(-1, V),
                labels_tok.view(-1),
                ignore_index=-100,
                reduction="mean"
            )

            # 2) Pad loss: positions with mask == 0 are pushed to predict PAD (small weight)
            pad_pos = (mask == 0)
            if pad_pos.any():
                # Per-position cross entropy against the PAD id
                ce_all = F.cross_entropy(
                    logits.view(-1, V),
                    tgt.new_full((B * L,), self.pad_id),
                    reduction="none"
                ).view(B, L)
                loss_pad = (ce_all * pad_pos).sum() / (pad_pos.sum() + 1e-6)
            else:
                loss_pad = logits.new_tensor(0.0)

            # 3) Optional: upweight the stop position (makes SEP more stable)
            stop_pos = ((tgt == self.stop_id) & (mask == 1))
            if stop_pos.any():
                ce_tok = F.cross_entropy(
                    logits.view(-1, V),
                    tgt.view(-1),
                    reduction="none"
                ).view(B, L)
                loss_stop = (ce_tok * stop_pos).sum() / (stop_pos.sum() + 1e-6)
            else:
                loss_stop = logits.new_tensor(0.0)

            # Combine: keep the pad/stop weights small
            lambda_pad = 0.1
            lambda_stop = 0.2
            loss = loss_tok + lambda_pad * loss_pad + lambda_stop * loss_stop

            loss = loss / self.cfg.grad_accum_steps
            loss.backward()

            if (step + 1) % self.cfg.grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * self.cfg.grad_accum_steps
            pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)

        return total_loss / len(self.loader)
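The pad and stop terms both reduce per-position values against a 0/1 mask with `(ce * pos).sum() / (pos.sum() + eps)`; a list-based sketch of that masked mean with toy numbers:

```python
def masked_mean(values, mask, eps=1e-6):
    # Average only the positions where mask == 1; eps guards empty masks
    num = sum(v * m for v, m in zip(values, mask))
    den = sum(mask) + eps
    return num / den

ce = [2.0, 4.0, 6.0, 8.0]
print(masked_mean(ce, [1, 1, 0, 0]))  # ≈ 3.0 (masked-out positions ignored)
print(masked_mean(ce, [0, 0, 0, 0]))  # 0.0 (eps prevents division by zero)
```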
    def train_robust_ae(self, optimizer):

        self.ae.train()
        total_loss = 0
        noise_std = 0.05

        for batch in self.loader:
            tgt_ids = batch['tgt_ids'].to(self.device)
            tgt_mask = batch['tgt_mask'].to(self.device)

            # 1. Get the clean latent z
            with torch.no_grad():
                z_clean = self.ae.encode(tgt_ids, tgt_mask)

            # 2. Add noise (denoising training), so the decoder tolerates
            #    latents that are only approximately z
            noise = torch.randn_like(z_clean) * noise_std
            z_noisy = z_clean + noise

            # 3. Decode
            logits = self.ae.decode(z_noisy, attention_mask=tgt_mask)

            # 4. Masked reconstruction loss
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   tgt_ids.view(-1),
                                   reduction='none')
            loss = (loss * tgt_mask.view(-1)).sum() / tgt_mask.sum()

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        return total_loss / len(self.loader)
    def train_ae_combined(self, optimizer, epoch, max_epochs):
        """
        Combines basic reconstruction + curriculum denoising + pad/stop weighting.
        """
        self.ae.train()
        total_loss = 0

        # --- Curriculum noise schedule ---
        # No noise for the first 20% of epochs, so reconstruction is learned first;
        # afterwards the noise increases linearly up to 0.1.
        if epoch < max_epochs * 0.2:
            current_noise = 0.0
        else:
            progress = (epoch - max_epochs * 0.2) / (max_epochs * 0.8)
            current_noise = 0.1 * progress  # maximum noise 0.1

        pbar = tqdm(self.loader, desc=f"Train AE (Noise={current_noise:.4f})")

        for step, batch in enumerate(pbar):
            tgt = batch['tgt_ids'].to(self.device)
            mask = batch['tgt_mask'].to(self.device)

            # 1. Encode clean
            with torch.no_grad():
                z_clean = self.ae.encode(tgt, mask)

            # 2. Add noise (if noise > 0)
            if current_noise > 0:
                noise = torch.randn_like(z_clean) * current_noise
                z_input = z_clean + noise
            else:
                z_input = z_clean

            # 3. Decode
            logits = self.ae.decode(z_input, attention_mask=mask)

            # 4. Compute the combined loss (same terms as train_ae)
            V = logits.size(-1)
            B, L = tgt.shape

            # Token loss (only positions where mask == 1)
            labels_tok = tgt.masked_fill(mask == 0, -100)
            loss_tok = F.cross_entropy(logits.view(-1, V), labels_tok.view(-1), ignore_index=-100)

            # Pad loss (mask == 0)
            pad_pos = (mask == 0)
            if pad_pos.any():
                ce_pad = F.cross_entropy(logits.view(-1, V), tgt.new_full((B * L,), self.pad_id), reduction='none').view(B, L)
                loss_pad = (ce_pad * pad_pos).sum() / (pad_pos.sum() + 1e-6)
            else:
                loss_pad = torch.tensor(0.0, device=self.device)

            # Stop loss
            stop_pos = ((tgt == self.stop_id) & (mask == 1))
            if stop_pos.any():
                ce_stop = F.cross_entropy(logits.view(-1, V), tgt.view(-1), reduction='none').view(B, L)
                loss_stop = (ce_stop * stop_pos).sum() / (stop_pos.sum() + 1e-6)
            else:
                loss_stop = torch.tensor(0.0, device=self.device)

            # Combine losses (slightly higher stop weight here)
            loss = loss_tok + 0.1 * loss_pad + 0.5 * loss_stop

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

        return total_loss / len(self.loader)
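The warm-up-then-ramp schedule at the top of `train_ae_combined` can be isolated as a small pure function (toy epoch counts; `curriculum_noise` is a hypothetical helper, not part of the trainer):

```python
def curriculum_noise(epoch, max_epochs, max_noise=0.1, warmup_frac=0.2):
    # Zero noise during the warm-up fraction of training, then a linear
    # ramp up to max_noise by the final epoch.
    warmup = max_epochs * warmup_frac
    if epoch < warmup:
        return 0.0
    return max_noise * (epoch - warmup) / (max_epochs - warmup)

print(curriculum_noise(0, 10))   # 0.0 (still in warm-up)
print(curriculum_noise(2, 10))   # 0.0 (ramp starts here)
print(curriculum_noise(10, 10))  # 0.1 (full noise at the end)
```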
| 202 |
+
|
| 203 |
+
def train_flow(self, optimizer):
|
| 204 |
+
self.flow.train()
|
| 205 |
+
self.ae.eval()
|
| 206 |
+
total_loss = 0
|
| 207 |
+
pbar = tqdm(self.loader, desc="Train Flow")
|
| 208 |
+
optimizer.zero_grad()
|
| 209 |
+
|
| 210 |
+
scale = getattr(self.ae, "latent_scale", 10.0)

for step, batch in enumerate(pbar):
    src = batch['src_ids'].to(self.device)
    src_mask = batch['src_mask'].to(self.device)
    tgt = batch['tgt_ids'].to(self.device)
    tgt_mask = batch['tgt_mask'].to(self.device)

    with torch.no_grad():
        z_bad = self.ae.encode(src, src_mask)    # norm ~ scale
        z_good = self.ae.encode(tgt, tgt_mask)   # norm ~ scale

    # Rectified Flow
    bs = z_bad.shape[0]
    t = torch.rand(bs, device=self.device).view(bs, 1, 1)

    # Interpolation: bad -> good (optionally re-project onto the sphere; test both)
    z_t_linear = (1 - t) * z_bad + t * z_good
    # z_t = F.normalize(z_t_linear, p=2, dim=-1) * scale
    z_t = z_t_linear

    # Velocity-prediction variant, kept for reference:
    # target_v = z_good - z_bad
    # pred_v = self.flow(z_t, t.squeeze(), condition=z_bad)
    # loss = F.mse_loss(pred_v, target_v)

    # x-prediction: the flow predicts z_good (the target) directly
    pred_z1 = self.flow(z_t, t, condition=z_bad)
    # Optionally project the output back onto the same sphere to avoid going
    # off-manifold; here neither input nor output is normalized.
    # pred_z1 = F.normalize(pred_z1, p=2, dim=-1) * scale

    # Loss: distance to z_good, masked so only valid tokens contribute
    mse = (pred_z1 - z_good).pow(2).mean(dim=-1)  # [B, L]
    w = tgt_mask.float()

    # Up-weight the stop (SEP/EOS) position so the model learns where to end
    stop_pos = ((tgt == self.stop_id) & (tgt_mask == 1))
    w = w + stop_pos.float() * 2.0  # e.g. +2 extra weight at the SEP position

    loss = (mse * w).sum() / (w.sum() + 1e-6)

    # loss = (mse * tgt_mask).sum() / (tgt_mask.sum() + 1e-6)
    # loss = F.mse_loss(pred_z1, z_good)

    loss = loss / self.cfg.grad_accum_steps
    loss.backward()

    if (step + 1) % self.cfg.grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

    total_loss += loss.item() * self.cfg.grad_accum_steps
    pbar.set_postfix(loss=loss.item() * self.cfg.grad_accum_steps)

return total_loss / len(self.loader)
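The mask-weighted x-prediction loss above can be checked on toy tensors (shapes and values below are illustrative, not taken from the repo): positions where `tgt_mask` is zero contribute nothing, so a prediction that is exact on the valid tokens yields a loss of exactly zero even if the padded positions are corrupted.

```python
import torch

B, L, D = 2, 4, 3
z_good = torch.randn(B, L, D)
pred_z1 = z_good.clone()
pred_z1[:, -1] += 100.0  # corrupt only the last (padded) position
tgt_mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]], dtype=torch.float)

# Same reduction as the trainer: per-token MSE, then a masked mean
mse = (pred_z1 - z_good).pow(2).mean(dim=-1)          # [B, L]
loss = (mse * tgt_mask).sum() / (tgt_mask.sum() + 1e-6)
print(loss.item())  # 0.0 -- the corrupted padding is ignored
```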
src/utils/__pycache__/data_utils.cpython-311.pyc
ADDED
Binary file (8.25 kB).

src/utils/__pycache__/sandbox.cpython-311.pyc
ADDED
Binary file (7.41 kB).
src/utils/data_utils.py
ADDED
@@ -0,0 +1,157 @@
import os
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader

def custom_collate(batch):
    return {
        'src_ids': torch.stack([torch.tensor(x['src_ids']) for x in batch]),
        'src_mask': torch.stack([torch.tensor(x['src_mask']) for x in batch]),
        'tgt_ids': torch.stack([torch.tensor(x['tgt_ids']) for x in batch]),
        'tgt_mask': torch.stack([torch.tensor(x['tgt_mask']) for x in batch]),
        # Keep the test cases for verification (only meaningful at eval time)
        'test_code': [x.get('test_code', "") for x in batch],
        'entry_point': [x.get('entry_point', "") for x in batch]
    }

def prepare_data(task_name, tokenizer, max_len, batch_size, split="train"):
    """
    Accepts a `split` argument so train and test sets can be prepared separately.
    """
    print(f"Loading {task_name} ({split})...")

    # --- 1. Load Dataset ---
    if task_name == "codexglue":
        # Training set: Microsoft CodeXGLUE code refinement
        # (GitHub bug -> fix pairs)
        dataset = load_dataset("./code_x_glue_cc_code_refinement_full", "medium", split=split)

        # Cap the training set at 40k examples
        if split == "train":
            dataset = dataset.select(range(40000))

        cols = dataset.column_names  # must be read before branching on the columns

        # Case A: standard refinement data (has `source` and `target`)
        if 'source' in cols and 'target' in cols:
            print(">> Detected standard refinement pairs.")
            def preprocess(ex):
                src = tokenizer(ex['source'], max_length=max_len, padding="max_length", truncation=True)
                tgt = tokenizer(ex['target'], max_length=max_len, padding="max_length", truncation=True)
                return {
                    'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                    'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
                }

        # Case B: raw code only (has `code`); bugs would have to be injected synthetically
        elif 'code' in cols:
            print(">> Detected raw code. Not injecting synthetic bugs...")

        else:
            raise ValueError(f"Dataset columns {cols} not recognized. Need 'source'/'target' or 'code'.")

    elif task_name == "humanevalpack":
        # Evaluation set: HumanEvalPack (fix task)
        # Contains buggy code plus the corresponding unit tests
        dataset = load_dataset("./bigcode_humanevalpack_full", "python", split="test")  # only a test split exists

        # Keep only the FIX tasks
        dataset = dataset.filter(lambda x: x['task_id'].startswith("Python/FIX"))

        def preprocess(ex):
            # `prompt` is the leading description, `buggy_solution` the buggy code;
            # for simplicity, prompt + buggy_solution is used as the input.
            # (Written for dataset.map(batched=True): fields arrive as lists.)
            full_buggy = [p + "\n" + b for p, b in zip(ex['prompt'], ex['buggy_solution'])]
            full_fixed = [p + "\n" + c for p, c in zip(ex['prompt'], ex['canonical_solution'])]

            src = tokenizer(full_buggy, max_length=max_len, padding="max_length", truncation=True)
            tgt = tokenizer(full_fixed, max_length=max_len, padding="max_length", truncation=True)

            return {
                'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask'],
                'test_code': ex['test'],          # key: keep the test code
                'entry_point': ex['entry_point']  # key: keep the entry-point function name
            }
        # dataset.map handles the returned dict; the new `test_code`/`entry_point`
        # columns survive remove_columns and are picked up in custom_collate.

    elif task_name == "wiki":
        # Try to load locally; fall back to downloading from the Hub
        try:
            dataset = load_dataset("./wikilarge-dataset")
        except Exception:
            print("Local load failed, downloading from Hub...")
            dataset = load_dataset("wikilarge")

        # Manual split: first 20k rows for train, the next 5k for test
        # (enough for a demo; the full corpus is too slow)
        if split == "train":
            dataset = dataset['train'].select(range(20000))
        else:
            # The full corpus has ~290k rows; take a later slice for testing
            dataset = dataset['train'].select(range(20000, 25000))

        # Auto-detect column names
        cols = dataset.column_names
        print(f"Wiki Dataset Columns: {cols}")

        # Map the column names to src/tgt
        if 'src' in cols and 'dst' in cols:
            src_key, tgt_key = 'src', 'dst'
        elif 'Normal' in cols and 'Simple' in cols:
            src_key, tgt_key = 'Normal', 'Simple'
        else:
            raise ValueError(f"Unknown column format for WikiLarge: {cols}")

        def preprocess(ex):
            # Source (complex) -> Target (simple)
            src = tokenizer(ex[src_key], max_length=max_len, padding="max_length", truncation=True)
            tgt = tokenizer(ex[tgt_key], max_length=max_len, padding="max_length", truncation=True)
            return {
                'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
            }

    elif task_name == "mbpp":
        dataset = load_dataset("mbpp", split="train[:500]")
        print(f"MBPP Dataset Columns: {dataset.column_names}")

        # MBPP self-reconstruction task: src = code, tgt = code
        def preprocess(ex):
            enc = tokenizer(ex['code'], max_length=max_len, padding="max_length", truncation=True)
            return {
                'src_ids': enc['input_ids'], 'src_mask': enc['attention_mask'],
                'tgt_ids': enc['input_ids'], 'tgt_mask': enc['attention_mask']
            }

    else:
        raise ValueError(f"Unknown task: {task_name}")

    # --- 2. Map & Batch ---
    # remove_columns=dataset.column_names drops every original column
    print(f"Preprocessing {len(dataset)} examples...")
    dataset = dataset.map(
        preprocess,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4
    )

    # Do not shuffle the test set, so rows stay aligned
    return DataLoader(dataset, batch_size=batch_size, shuffle=(split == "train"), collate_fn=custom_collate)
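As a sanity check, the stacking logic of `custom_collate` can be exercised on toy records standing in for tokenizer output (the id values below are made up, and the helper is reduced to two keys):

```python
import torch

def collate_ids(batch):
    # Same stacking pattern as data_utils.custom_collate, reduced to two keys
    return {
        'src_ids': torch.stack([torch.tensor(x['src_ids']) for x in batch]),
        'src_mask': torch.stack([torch.tensor(x['src_mask']) for x in batch]),
    }

rows = [
    {'src_ids': [101, 7, 102, 0], 'src_mask': [1, 1, 1, 0]},
    {'src_ids': [101, 9, 8, 102], 'src_mask': [1, 1, 1, 1]},
]
out = collate_ids(rows)
print(out['src_ids'].shape, out['src_mask'].sum().item())  # torch.Size([2, 4]) 7
```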
src/utils/sandbox.py
ADDED
@@ -0,0 +1,156 @@
import os
import sys
import io
import multiprocessing
import contextlib
import signal
import subprocess
import tempfile
import shutil

# Timeout control
class TimeoutException(Exception): pass

def timeout_handler(signum, frame):
    raise TimeoutException

def _exec_code(code_str, test_code, entry_point, result_queue):
    """
    Run the generated code together with its test cases.
    """
    capture = io.StringIO()
    success = False
    error_msg = ""

    try:
        # Simple alarm-based timeout (Linux only; Windows needs another mechanism)
        # signal.signal(signal.SIGALRM, timeout_handler)
        # signal.alarm(2)  # 2-second timeout

        with contextlib.redirect_stdout(capture), contextlib.redirect_stderr(capture):
            # Isolated namespace
            scope = {}

            # 1. Execute the generated code (defines the function)
            exec(code_str, scope)

            # 2. Check that the entry point exists
            if entry_point not in scope:
                raise ValueError(f"Entry point {entry_point} not found in generated code.")

            # 3. Run the test cases.
            # HumanEval tests usually take the form "check(entry_point_func)",
            # so the check definition must be exec'd too (or concatenated).
            full_test_script = code_str + "\n" + test_code + f"\ncheck({entry_point})"
            exec(full_test_script, scope)

            success = True

    except Exception as e:
        error_msg = str(e)
    # finally:
    #     signal.alarm(0)

    result_queue.put((success, error_msg))

class SafeSandbox:
    def __init__(self, timeout=5.0):
        self.timeout = timeout

    def run(self, code, test_code, entry_point):
        queue = multiprocessing.Queue()
        p = multiprocessing.Process(target=_exec_code, args=(code, test_code, entry_point, queue))
        p.start()
        p.join(self.timeout)

        if p.is_alive():
            p.terminate()
            p.join()
            return False, "Timeout"

        if not queue.empty():
            return queue.get()
        return False, "Unknown Error"

class JavaSandbox:
    def __init__(self, timeout=5.0):
        self.timeout = timeout

        # Check the Java environment
        if shutil.which("javac") is None or shutil.which("java") is None:
            raise RuntimeError("Java environment (JDK) not found. Please install Java.")

    def run(self, code, test_code, entry_point):
        """
        code:        the repaired Java method
        test_code:   a test class body containing `main`, which calls the entry point
        entry_point: the method name (usually unused in Java as long as test_code is correct)
        """
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
            file_name = "Solution.java"  # assume the class is named Solution
            file_path = os.path.join(temp_dir, file_name)

            # Assemble the source: the generated method is wrapped in a class.
            # Here `code` is assumed to be a single method and `test_code` the main
            # function; adjust the concatenation to match the dataset's actual format.
            full_source = f"""
public class Solution {{
    {code}

    {test_code}
}}
"""
            # 1. Write the file
            with open(file_path, "w") as f:
                f.write(full_source)

            # 2. Compile
            compile_cmd = ["javac", file_path]
            try:
                subprocess.run(compile_cmd, check=True, capture_output=True, timeout=10)
            except subprocess.CalledProcessError as e:
                return False, f"Compilation Error: {e.stderr.decode()}"
            except subprocess.TimeoutExpired:
                return False, "Compilation Timeout"

            # 3. Run
            run_cmd = ["java", "-cp", temp_dir, "Solution"]
            try:
                result = subprocess.run(run_cmd, capture_output=True, timeout=self.timeout)
                if result.returncode == 0:
                    return True, result.stdout.decode()
                else:
                    return False, f"Runtime Error: {result.stderr.decode()}"
            except subprocess.TimeoutExpired:
                return False, "Runtime Timeout"

# Smoke test
if __name__ == "__main__":
    sandbox = JavaSandbox()

    # Correct code
    code_pass = """
    public static int add(int a, int b) {
        return a + b;
    }
    """
    test_pass = """
    public static void main(String[] args) {
        if (add(1, 1) == 2) {
            System.out.println("PASS");
        } else {
            System.exit(1);
        }
    }
    """
    print("Test Pass:", sandbox.run(code_pass, test_pass, "add"))

    # Buggy code
    code_fail = """
    public static int add(int a, int b) {
        return a * b; // Bug
    }
    """
    print("Test Fail:", sandbox.run(code_fail, test_pass, "add"))
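The core of `_exec_code`, minus the process isolation, can be sketched in a few lines. `exec_with_tests` below is a hypothetical helper, not part of the repo, and unlike `SafeSandbox` it offers no timeout or crash protection:

```python
import io
import contextlib

def exec_with_tests(code_str, test_code, entry_point):
    # Run candidate code plus HumanEval-style tests in one namespace,
    # swallowing stdout the way _exec_code does.
    scope = {}
    try:
        with contextlib.redirect_stdout(io.StringIO()):
            exec(code_str + "\n" + test_code + f"\ncheck({entry_point})", scope)
        return True, ""
    except Exception as e:
        return False, str(e)

tests = "def check(f):\n    assert f(1, 2) == 3"
ok, _ = exec_with_tests("def add(a, b):\n    return a + b", tests, "add")
bad, err = exec_with_tests("def add(a, b):\n    return a - b", tests, "add")
print(ok, bad)  # True False
```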
tests/__pycache__/test_models.cpython-311.pyc
ADDED
Binary file (3.82 kB).
tests/test_models.py
ADDED
@@ -0,0 +1,61 @@
import unittest
import torch
import sys
import os

# Make sure `src` is importable
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.config import ModelConfig
from src.models.autoencoder import LatentAutoencoder
from src.models.dit import FlowDiT

class TestModels(unittest.TestCase):
    def setUp(self):
        # Build a small config for testing
        self.cfg = ModelConfig(
            encoder_name="roberta-base",
            latent_dim=128,
            max_seq_len=32,
            decoder_layers=2,  # keep it fast
            dit_layers=2
        )
        # Forcing dit_hidden = 128 is unnecessary: the property already guarantees it
        # self.cfg.dit_hidden = 128

    def test_ae_shape(self):
        print("\nTesting Autoencoder Shape...")
        model = LatentAutoencoder(self.cfg)
        input_ids = torch.randint(0, 100, (2, 32))
        mask = torch.ones((2, 32))
        logits, z = model(input_ids, mask)

        self.assertEqual(z.shape, (2, 32, 128))
        # 50265 is RoBERTa's vocabulary size
        self.assertEqual(logits.shape, (2, 32, 50265))
        print("AE Shape Check Passed.")

    def test_dit_shape(self):
        print("\nTesting DiT Shape...")
        model = FlowDiT(self.cfg)
        x = torch.randn(2, 32, 128)  # B, Seq, Dim
        t = torch.rand(2)            # B
        cond = torch.randn(2, 32, 128)

        out = model(x, t, condition=cond)
        self.assertEqual(out.shape, (2, 32, 128))
        print("DiT Shape Check Passed.")

    def test_cfg_forward(self):
        print("\nTesting CFG Forward...")
        model = FlowDiT(self.cfg)
        x = torch.randn(2, 32, 128)
        t = torch.rand(2)
        cond = torch.randn(2, 32, 128)

        out = model.forward_with_cfg(x, t, cond, cfg_scale=3.0)
        self.assertEqual(out.shape, (2, 32, 128))
        print("CFG Check Passed.")

if __name__ == "__main__":
    unittest.main()
train_ae.py
ADDED
@@ -0,0 +1,86 @@
import torch
import torch.optim as optim
from transformers import AutoTokenizer
import os
import argparse

from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder, ResidualAutoencoder
from src.trainer import Trainer
from src.utils.data_utils import prepare_data

def _pick_stop_id(tokenizer):
    return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_dir", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/robust_checkpoints", help="Directory to save checkpoints")
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {args.save_dir}")

    # --- Config ---
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',  # adjust to the actual local path
        latent_dim=512,
        max_seq_len=128
    )

    t_cfg = TrainConfig(
        batch_size=16,
        num_epochs_ae=20,  # only the AE epochs matter here
        grad_accum_steps=4,
        use_amp=False,
        lr_ae=1e-4
    )

    # --- Data & Tokenizer ---
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")

    # --- Model ---
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    # ae = ResidualAutoencoder(m_cfg).to(t_cfg.device).float()

    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    # --- Trainer ---
    # flow is None here because only the AE is trained
    trainer = Trainer(
        ae=ae,
        flow=None,
        cfg=t_cfg,
        loader=train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer)
    )

    # --- Optimizer ---
    opt_ae = optim.AdamW(filter(lambda p: p.requires_grad, ae.parameters()), lr=t_cfg.lr_ae)

    # --- Training Loop ---
    best_ae_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")

    for epoch in range(t_cfg.num_epochs_ae):
        # loss = trainer.train_ae(opt_ae)
        # loss = trainer.train_robust_ae(opt_ae)
        loss = trainer.train_ae_combined(opt_ae, epoch, t_cfg.num_epochs_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

        # Save Best
        if loss < best_ae_loss:
            best_ae_loss = loss
            save_path = os.path.join(args.save_dir, "ae_best.pt")
            torch.save(ae.state_dict(), save_path)
            print(f"  Saved Best AE to {save_path}")

    # Save Last
    torch.save(ae.state_dict(), os.path.join(args.save_dir, "ae_last.pt"))

    print(f"AE Training Done. Best Loss: {best_ae_loss:.4f}")

if __name__ == "__main__":
    main()
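train_ae.py relies on the Trainer's gradient accumulation (`grad_accum_steps=4` on top of `batch_size=16`), so the effective batch size is 64. The equivalence can be verified on a toy least-squares model (all tensors here are synthetic):

```python
import torch

torch.manual_seed(0)
w = torch.zeros(3, requires_grad=True)
data = torch.randn(64, 3)
target = torch.randn(64)

# Accumulate over 4 micro-batches of 16, scaling each loss by 1/4,
# exactly as the trainer divides by cfg.grad_accum_steps
for i in range(4):
    mb, tb = data[i * 16:(i + 1) * 16], target[i * 16:(i + 1) * 16]
    ((mb @ w - tb).pow(2).mean() / 4).backward()
grad_accum = w.grad.clone()

# One full batch of 64 gives the same gradient
w.grad = None
(data @ w - target).pow(2).mean().backward()
print(torch.allclose(grad_accum, w.grad, atol=1e-6))  # True
```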
train_flow.py
ADDED
@@ -0,0 +1,227 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.optim as optim
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import os
|
| 7 |
+
import argparse
|
| 8 |
+
import sacrebleu
|
| 9 |
+
|
| 10 |
+
from src.config import ModelConfig, TrainConfig
|
| 11 |
+
from src.models.autoencoder import ReshapedAutoencoder
|
| 12 |
+
from src.models.dit import PatchedFlowDiT
|
| 13 |
+
from src.trainer import Trainer
|
| 14 |
+
from src.utils.data_utils import prepare_data
|
| 15 |
+
|
| 16 |
+
# --- Helper Functions for Inference (复制过来以便独立运行) ---
|
| 17 |
+
def _pick_stop_id(tokenizer):
|
| 18 |
+
return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
|
| 19 |
+
|
| 20 |
+
def _first_pos(x_1d, token_id, default):
|
| 21 |
+
idx = (x_1d == token_id).nonzero(as_tuple=True)[0]
|
| 22 |
+
return idx[0].item() if idx.numel() > 0 else default
|
| 23 |
+
|
| 24 |
+
def calculate_metrics(sources, predictions, references):
|
| 25 |
+
bleu = sacrebleu.corpus_bleu(predictions, [references])
|
| 26 |
+
try:
|
| 27 |
+
sari = sacrebleu.corpus_sari(sources, predictions, [references])
|
| 28 |
+
sari_score = sari.score
|
| 29 |
+
except Exception:
|
| 30 |
+
sari_score = 0.0
|
| 31 |
+
|
| 32 |
+
ratios = [len(p) / len(s) if len(s) > 0 else 0 for p, s in zip(predictions, sources)]
|
| 33 |
+
avg_ratio = sum(ratios) / len(ratios) if ratios else 0
|
| 34 |
+
|
| 35 |
+
return {"SARI": sari_score, "BLEU": bleu.score, "Compression Ratio": avg_ratio}
|
| 36 |
+
|
| 37 |
+
@torch.no_grad()
|
| 38 |
+
def inference_batch(ae, flow, loader, tokenizer, device, steps=10, save_path="results.txt", use_oneshot=True):
|
| 39 |
+
ae.eval()
|
| 40 |
+
flow.eval()
|
| 41 |
+
stop_id = _pick_stop_id(tokenizer)
|
| 42 |
+
pad_id = tokenizer.pad_token_id
|
| 43 |
+
|
| 44 |
+
print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")
|
| 45 |
+
|
| 46 |
+
all_sources, all_targets, all_generated = [], [], []
|
| 47 |
+
scale = getattr(ae, "latent_scale", 10.0) # 兼容逻辑
|
| 48 |
+
|
| 49 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
| 50 |
+
f.write("Source\tTarget\tGenerated\n")
|
| 51 |
+
|
| 52 |
+
for batch in tqdm(loader, desc="Inferencing"):
|
| 53 |
+
src_ids = batch['src_ids'].to(device)
|
| 54 |
+
src_mask = batch['src_mask'].to(device)
|
| 55 |
+
tgt_ids = batch['tgt_ids'].to(device)
|
| 56 |
+
B, L = src_ids.shape
|
| 57 |
+
|
| 58 |
+
# Encode
|
| 59 |
+
z_curr = ae.encode(src_ids, src_mask)
|
| 60 |
+
z_cond = z_curr.clone()
|
| 61 |
+
|
| 62 |
+
# Flow Sampling
|
| 63 |
+
if use_oneshot:
|
| 64 |
+
t0 = torch.zeros(B, device=device)
|
| 65 |
+
z_curr = flow(z_curr, t0, condition=z_cond).float()
|
| 66 |
+
else:
|
| 67 |
+
dt = 1.0 / steps
|
| 68 |
+
for i in range(steps):
|
| 69 |
+
t_val = i / steps
|
| 70 |
+
if t_val >= 0.999: break
|
| 71 |
+
t = torch.ones(B, device=device) * t_val
|
| 72 |
+
pred_z1 = flow(z_curr, t, condition=z_cond).float()
|
| 73 |
+
v = (pred_z1 - z_curr) / (1.0 - t_val + 1e-4)
|
| 74 |
+
z_curr = z_curr + v * dt
|
| 75 |
+
z_curr = pred_z1
|
| 76 |
+
|
| 77 |
+
# Decode (Pass 1: Detect Length)
|
| 78 |
+
full_mask = torch.ones(B, L, device=device)
|
| 79 |
+
logits1 = ae.decode(z_curr, attention_mask=full_mask)
|
| 80 |
+
ids1 = logits1.argmax(dim=-1)
|
| 81 |
+
|
| 82 |
+
stop_pos = []
|
| 83 |
+
for i in range(B):
|
| 84 |
+
pos = _first_pos(ids1[i], stop_id, default=L - 1)
|
| 85 |
+
stop_pos.append(pos)
|
| 86 |
+
|
| 87 |
+
# Decode (Pass 2: Clean Decode)
|
| 88 |
+
gen_mask = torch.zeros(B, L, device=device)
|
| 89 |
+
for i in range(B):
|
| 90 |
+
gen_mask[i, : stop_pos[i] + 1] = 1.0
|
| 91 |
+
|
| 92 |
+
logits2 = ae.decode(z_curr, attention_mask=gen_mask)
|
| 93 |
+
ids2 = logits2.argmax(dim=-1)
|
| 94 |
+
ids2 = ids2.masked_fill(gen_mask == 0, pad_id)
|
| 95 |
+
|
| 96 |
+
# Convert to Text
|
| 97 |
+
src_texts = tokenizer.batch_decode(src_ids, skip_special_tokens=True)
|
| 98 |
+
tgt_texts = tokenizer.batch_decode(tgt_ids, skip_special_tokens=True)
|
| 99 |
+
|
| 100 |
+
gen_texts = []
|
| 101 |
+
for i in range(B):
|
| 102 |
+
end = stop_pos[i] + 1
|
| 103 |
+
ids_cut = ids2[i, :end]
|
| 104 |
+
gen_texts.append(tokenizer.decode(ids_cut, skip_special_tokens=True))
|
| 105 |
+
|
| 106 |
+
for s, t, g in zip(src_texts, tgt_texts, gen_texts):
|
| 107 |
+
s_c = s.replace("\n", " ")
|
| 108 |
+
t_c = t.replace("\n", " ")
|
| 109 |
+
g_c = g.replace("\n", " ")
|
| 110 |
+
f.write(f"{s_c}\t{t_c}\t{g_c}\n")
|
| 111 |
+
all_sources.append(s_c)
|
| 112 |
+
all_targets.append(t_c)
|
| 113 |
+
all_generated.append(g_c)
|
| 114 |
+
|
| 115 |
+
return all_sources, all_targets, all_generated
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def main():
|
| 119 |
+
parser = argparse.ArgumentParser()
|
| 120 |
+
parser.add_argument("--ae_ckpt", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt", help="Path to pre-trained AE checkpoint")
|
| 121 |
+
parser.add_argument("--save_dir", type=str, default="residual_robust_checkpoints", help="Directory to save flow checkpoints")
|
| 122 |
+
parser.add_argument("--use_oneshot", action="store_true", default=True, help="Use one-shot sampling for inference")
|
| 123 |
+
args = parser.parse_args()
|
| 124 |
+
|
| 125 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
| 126 |
+
|
| 127 |
+
# --- Config ---
|
| 128 |
+
m_cfg = ModelConfig(
|
| 129 |
+
encoder_name='../jina-embeddings-v2-base-code',
|
| 130 |
+
latent_dim=512,
|
| 131 |
+
max_seq_len=128
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
t_cfg = TrainConfig(
|
| 135 |
+
batch_size=16,
|
| 136 |
+
num_epochs_flow=35, # 只关注 Flow 的 epoch
|
| 137 |
+
grad_accum_steps=4,
|
| 138 |
+
use_amp=False,
|
| 139 |
+
lr_flow=2e-4
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# --- Tokenizer & Data ---
|
| 143 |
+
tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name,local_files_only=True, trust_remote_code=False)
|
| 144 |
+
train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
|
| 145 |
+
test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")
|
| 146 |
+
|
| 147 |
+
# --- Load AE (Pre-trained) ---
|
| 148 |
+
print(f"\n>>> Loading Pre-trained Autoencoder from {args.ae_ckpt} ...")
|
| 149 |
+
ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
|
| 150 |
+
|
| 151 |
+
if not os.path.exists(args.ae_ckpt):
|
| 152 |
+
raise FileNotFoundError(f"AE checkpoint not found at {args.ae_ckpt}. Please run train_ae.py first.")
|
| 153 |
+
|
| 154 |
+
ae.load_state_dict(torch.load(args.ae_ckpt, map_location=t_cfg.device))
|
| 155 |
+
|
| 156 |
+
# 冻结 AE 的所有参数,Flow 训练时不更新 AE
|
| 157 |
+
ae.eval()
|
| 158 |
+
for param in ae.parameters():
|
| 159 |
+
param.requires_grad = False
|
| 160 |
+
print(">>> Autoencoder loaded and frozen.")
|
| 161 |
+
|
| 162 |
+
if ae.encoder.config.pad_token_id is None:
|
| 163 |
+
ae.encoder.config.pad_token_id = tokenizer.pad_token_id
|
| 164 |
+
|
| 165 |
+
# --- Initialize Flow ---
|
| 166 |
+
flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()
|
| 167 |
+
|
| 168 |
+
# --- Trainer ---
|
| 169 |
+
trainer = Trainer(
|
| 170 |
+
ae=ae,
|
| 171 |
+
flow=flow,
|
| 172 |
+
cfg=t_cfg,
|
| 173 |
+
loader=train_loader,
|
| 174 |
+
pad_id=tokenizer.pad_token_id,
|
| 175 |
+
stop_id=_pick_stop_id(tokenizer)
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# --- Optimizer ---
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)

    # --- Training Loop ---
    best_flow_loss = float('inf')
    print("\n>>> Start Training Flow DiT...")

    for epoch in range(t_cfg.num_epochs_flow):
        # Pass opt_flow in so train_flow updates the Flow model
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")

        # Save Best
        if loss < best_flow_loss:
            best_flow_loss = loss
            save_path = os.path.join(args.save_dir, "flow_best.pt")
            torch.save(flow.state_dict(), save_path)
            # print(f" Saved Best Flow to {save_path}")

        # Save Last
        torch.save(flow.state_dict(), os.path.join(args.save_dir, "flow_last.pt"))

    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")

    # --- Inference / Evaluation ---
    print("\n>>> Loading Best Flow Checkpoint for Evaluation...")
    best_flow_path = os.path.join(args.save_dir, "flow_best.pt")
    if os.path.exists(best_flow_path):
        flow.load_state_dict(torch.load(best_flow_path, map_location=t_cfg.device))
    else:
        print("Warning: Best checkpoint not found, using last-epoch weights.")

    print("\n--- Starting Inference ---")
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=10,
        save_path="wiki_results.tsv",
        use_oneshot=args.use_oneshot
    )

    # Metrics
    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    print("\nResults saved to wiki_results.tsv")

if __name__ == "__main__":
    main()
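The `steps=10` argument passed to `inference_batch` suggests the flow is sampled by fixed-step ODE integration from noise to the target latent. A minimal sketch of that idea, assuming an Euler integrator over a learned velocity field (the names `euler_sample` and `velocity_fn` are illustrative, not this repo's API):

```python
import torch

@torch.no_grad()
def euler_sample(velocity_fn, z0, steps=10):
    """Integrate dz/dt = v(z, t) from t=0 to t=1 with fixed-step Euler."""
    z = z0
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((z.shape[0],), i * dt)  # per-sample time conditioning
        z = z + dt * velocity_fn(z, t)
    return z

# Toy velocity field v(z, t) = -z: each Euler step scales z by (1 - dt).
z0 = torch.ones(4, 8)
z1 = euler_sample(lambda z, t: -z, z0, steps=10)  # z1 == 0.9**10 everywhere
```

With more steps the integration error shrinks; `steps=10` trades accuracy for speed at evaluation time.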
|
wiki_results.tsv
ADDED

The diff for this file is too large to render. See raw diff
wikilarge-dataset/.gitattributes
ADDED
|
@@ -0,0 +1,55 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.lz4 filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
*.sam filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
# Audio files - compressed
*.aac filter=lfs diff=lfs merge=lfs -text
*.flac filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
# Image files - uncompressed
*.bmp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
# Image files - compressed
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
wiki.full.aner.ori.train.95.tsv filter=lfs diff=lfs merge=lfs -text
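Each `filter=lfs diff=lfs merge=lfs -text` line routes files matching the leading pattern through Git LFS (such lines are typically generated by `git lfs track`). A small sketch of how these attribute lines can be read back to list the LFS-tracked patterns (the helper `lfs_patterns` is illustrative, not part of this repo):

```python
def lfs_patterns(text):
    """Return the patterns in .gitattributes text that are routed through Git LFS."""
    patterns = []
    for line in text.splitlines():
        parts = line.split()
        # Skip blanks and comments; keep entries whose attributes include filter=lfs.
        if parts and not parts[0].startswith("#") and "filter=lfs" in parts[1:]:
            patterns.append(parts[0])
    return patterns

sample = """*.parquet filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
README.md -text"""
print(lfs_patterns(sample))  # → ['*.parquet', '*.pcm']
```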
wikilarge-dataset/wiki.full.aner.ori.test.95.tsv
ADDED
|
@@ -0,0 +1,192 @@
Normal Simple
his next work saturday follows an especially eventful day in the life of a successful neurosurgeon. his next work at saturday will be a successful neurosurgeon.
the tarantula the trickster character spun a black cord and attaching it to the ball crawled away fast to the east pulling on the cord with all his strength. the tarantuala the trickster spun a black chord and attached it to the ball crawling away fast to the east and pulling the chord with all his strength.
there he died six weeks later on january. he died six weeks later on january th.
they are culturally akin to the coastal peoples of papua new guinea. their culture is similar to the culture of the coastal peoples of papua new guinea.
following the drummers are dancers who often play the sogo lrb a tiny drum that makes almost no sound rrb and tend to have more elaborate even acrobatic choreography. the drummers are dancers and often play the sogo which they tend to have arcobatic choreography.
the spacecraft consists of two main elements the nasa cassini orbiter named after the italian french astronomer giovanni domenico cassini and the esa huygens probe named after the dutch astronomer mathematician and physicist christiaan huygens. the spacecraft is having two main parts. one is known as nasa cassini orbiter. it is named after giovanni domenico cassini an italian french astronomer. the other part is known as esa huygens probe. it is named after christiaan huygens. he was a dutch astronomer mathematician and physicist.
alessandro lrb sandro rrb mazzola lrb born november rrb is an italian former football player. alessandro mazzola is an italian former football player.
it was originally thought that the debris thrown up by the collision filled in the smaller craters. it was thought that the debris thrown up by the collision filled the smaller craters.
graham attended wheaton college from to when he graduated with a ba in anthropology. graham attended wheaton college from to and graduated with a ba in anthropology.
he was also named sportsman of the year by sports illustrated. sports illustrated named him sportsman of the year in.
fives is a british sport believed to derive from the same origins as many racquet sports. fives a british sport came from the same games as many racquet sports.
for example king bhumibol was born on monday so on his birthday throughout thailand will be decorated with yellow color. all over thailand the color yellow will be used to celebrate king bhumibal
nevertheless tagore emulated numerous styles including craftwork from northern new ireland haida carvings from the west coast of canada lrb british columbia rrb and woodcuts by max pechstein. tagore emulated various styles including craftwork from northern new ireland carvings from the west coast of canada and woodcuts by max pechstein neverthelessly
she performed for president reagan in s great performances at the white house series which aired on the public broadcasting service. she did a show for president reagan in s great performances at the white house series which was shown on television on the public broadcasting service.
perry saturn lrb with terri rrb defeated eddie guerrero lrb with chyna rrb to win the wwf european championship lrb rrb saturn pinned guerrero after a diving elbow drop. perry saturn and terri defeated eddie guerrero and chyna to win the wwf european championship lrb rrb saturn pinned guerrero after a diving elbow drop.
she remained in the united states until when she and her husband returned to france. she stayed in the united states until then she and her husband went to france.
he also completed two collections of short stories entitled the ribbajack other curious yarns and seven strange and ghostly tales. he also completed two collections of short storeis. the one title was the ribbajack other curious yarns. the other one was titled as seven strange and ghostly tales.
at the voyager images ophelia appears as an elongated object the major axis pointing towards uranus. at the voyager pictures ophelia appears as a stretched object. a stretched object was the major axis. it pointing towards uranus.
the british decided to eliminate him and take the land by force. the british decided to put an end to him and take the land by force.
some towns on the eyre highway in the south east corner of western australia between the south australian border almost as far as caiguna do not follow official western australian time. there are some towns in western australia that do not follow official western australian time.
in architectural decoration small pieces of colored and iridescent shell have been used to create mosaics and inlays which have been used to decorate walls furniture and boxes. small pieces of colored and shiny shell has been used to decorate walls furniture and boxes.
the other incorporated cities on the palos verdes peninsula include rancho palos verdes rolling hills estates and rolling hills. rancho palos verdes rolling hills estates and rolling hills are three cities on the palos verdes peninsula.
fearing that drek will destroy the galaxy clank asks ratchet to help him find the famous superhero captain qwark in an effort to stop drek. clank asks ratchet to help him find the famous superhero captain qwark because he is afraid that drek will try to destroy the galaxy and wants to stop him.
he advocates applying a user centered design process in product development cycles and also works towards popularizing interaction design as a mainstream discipline. he favors product development cycles that features an easy to use design process and works towards bringing interaction design into mainstream popuarity.
it is theoretically possible that the other editors who may have reported you and the administrator who blocked you are part of a conspiracy against someone half a world away they ve never met in person. it is possible that the other editors who may have reported you is a part of the conspiracy. similarly the administrator who blocked you may also be a part of the conspiracy. the conspiracy is against someone they have not met in prison.
working group i assesses scientific aspects of the climate system and climate change. working group i makes note of climate system and climate change
formal minor planet designations are number name combinations overseen by the minor planet center a branch of the iau. formal minor planet designations are number name combinations overseen by the minor planet center.
as a result although many mosques will not enforce violations both men and women when attending a mosque must adhere to these guidelines. although many mosques will not enforce rules both men and women when there must follow these rules
mariel of redwall is a fantasy novel by brian jacques published in. mariel of redwall is a fiction novel in the category of fantasy by author brian jacques published in.
ryan prosser lrb born july rrb is a professional rugby union player for bristol rugby in the guinness premiership. ryan prosser who was born on july is a professional rugby union player he has played for briston rugby in the guinness premiership.
like previous assessment reports it consists of four reports three of them from its working groups. the assessment report contiains four reports just like previous reports and three of them are from working groups.
this stamp remained the standard letter stamp for the remainder of victoria s reign and vast quantities were printed. this stamp stayed the standard letter stamp for the rest of victoria s reign and many were printed.
the international fight league was an american mixed martial arts lrb mma rrb promotion billed as the world s first mma league. the world s first mma league was the international fight league and american mixed martial arts.
giardia lamblia lrb synonymous with lamblia intestinalis and giardia duodenalis rrb is a flagellated protozoan parasite that colonises and reproduces in the small intestine causing giardiasis. giardia lamblia is a flagellated protozoan parasite that colonises and reproduces in the small intestine causing giardiasis.
this was the area east of the mouth of the vistula river later sometimes called prussia proper. this area which later was sometimes called prussia proper was east of the place where the vistula river begins.
after graduation he returned to yerevan to teach at the local conservatory and later he was appointed artistic director of the armenian philarmonic orchestra. fter graduation he came back to yerevan to teach at the local conservatory and then he was appointed artistic director of the armenian philarmonic orchestra.
weelkes was later to find himself in trouble with the chichester cathedral authorities for his heavy drinking and immoderate behaviour. weelkes later found himself in trouble with the chichester cathedral authorities for his heavy drinking and immoderate behavior.
he is also a member of another jungiery boyband club. he is also a member of club which is another jungiery boyband.
the apostolic tradition attributed to the theologian hippolytus attests the singing of hallel psalms with alleluia as the refrain in early christian agape feasts. the apostolic tradition connected to the scientist hippolytus who is an expert in theology starts the singing of hallel psalms with alleluia as the repeated line in early christian lovable and wonderful festivals.
it was the first asteroid to be discovered by a spacecraft. it was the first asteroid discovered by a spacecraft.
it continues as the bohemian switzerland in the czech republic. it is still called as the bohemian switzerland in the czech republic.
this leads to consumer confusion when lrb rrb bytes is referenced as mb lrb megabyte rrb instead of mib. the consumer gets confused when bytes is called mb instead of mib.
the incident has been the subject of numerous reports as to ethics in scholarship. the incident has been the subject of numerous reports regarding scholarship ethics.
they are castrated so that the animal may be more docile or may put on weight more quickly. they are castrated so that the animal is docile or may put on weight quickly.
seventh sons have strong knacks lrb specific magical abilities rrb and seventh sons of seventh sons are both extraordinarily rare and powerful. seventh sons have strong knacks lrb specific magical abilities rrb and he is extraordinarily rare and powerful.
benchmarking conducted by passmark software highlights the version s second install time second scan time and mb memory utilization. passmark software tested standards of version and the highlights are second install time second scan time and mb memory utilization.
if there are no strong land use controls buildings are built along a bypass converting it into an ordinary town road and the bypass may eventually become as congested as the local streets it was intended to avoid. if there are no strong land use controls the bypass as a result may become congested. this is because buildings are built along a bypass converting it into an ordinary town road. the byepass is intended to avoid such congestion.
it is also a starting point for people wanting to explore cooktown cape york peninsula and the atherton tableland. it is a starting point for people wanting to explore cooktown cape york peninsula and atherton tableland.
bruises often induce pain but are not normally dangerous. bruises often hurt but are not normally dangerous.
tickets can be retailed for national rail services the docklands light railway and on oyster card. tickets can be retailed for national rail services and the docklands light railway on oyster card.
the historical method comprises the techniques and guidelines by which historians use primary sources and other evidence to research and then to write history. in writing history there is a method called the historical method which uses primary sources and other evidence to research the historical events.
the sheer weight of the continental icecap sitting on top of lake vostok is believed to contribute to the high oxygen concentration. the lake vostok has a very large weight of the continental icecap on its waters. the high oxygen concentration of the lake water may because of this icecap.
aliteracy lrb sometimes spelled alliteracy rrb is the state of being able to read but being uninterested in doing so. aliteracy is being able to read but uninterested in read.
mifepristone is a synthetic steroid compound used as a pharmaceutical. pharmaceutical has used a mifepristone is a synthetic steroid
shortly after attaining category status the outer convection of the hurricane became ragged. shortly after reaching category status the outer convection of the hurricane became worn out
the equilibrium price for a certain type of labor is the wage. the balanced price for any kind of labor is called a wage.
convinced that the grounds were haunted they decided to publish their findings in a book an adventure lrb rrb under the pseudonyms of elizabeth morison and frances lamont. the authors using pseudonyms elizabeth morison and frances lamont published the book an adventure in written about particular hauntings.
he left a detachment of troops to garrison the newly conquered region. he left troops to garrison the newly conquered region.
the depression moved inland on the th as a circulation devoid of convection and dissipated the next day over brazil where it caused heavy rains and flooding. on th depression moved inland as convection less circulation and after weakening over brazil it caused heavy rains and flooding.
the characters are foul mouthed extensions of their earlier characters pete and dud. the characters speak bad language of their earlier characters pete and dud
in culver ran for iowa secretary of state and was victorious. in culver successfully ran for iowa secretary of state.
in mark messier took the hart over ray bourque by a margin of two votes the difference being a single first place vote. in mark messier took hart over ray bourque by two votes the difference being a first place vote.
shade sets the main plot of the novel in motion when he impetuously defies that law and inadvertently initiates a chain of events that leads to the destruction of his colony s home forcing their premature migration and his separation from them. the main plot of the novel is when shade defies the law and sets off a chain of events that lead to the destuction of his colony s home and becoming separated from them.
the amazon basin is the part of south america drained by the amazon river and its tributaries. the amazon basin is the part of south america drained by the amazon river and those who pay tribute to it.
the two former presidents were later separately charged with mutiny and treason for their roles in the coup and the gwangju massacre. the two former presidents were later charged each on their own with mutiny and treason for their roles in the coup and the gwangju massacre.
moderate to severe damage extended up the atlantic coastline and as far inland as west virginia. there was moderate to severe damage all the way up the atlantic coastline and as far inland as west virginia.
because the owner tends to be unaware these computers are metaphorically compared to zombies. these computers are metaphorically compared to zombies as the owner was not concious about it
for example the stylebook of the associated press is updated annually. the stylebook of the associated press is updated yearly.
the four canonical texts are the gospel of matthew gospel of mark gospel of luke and gospel of john probably written between ad and lrb see also the gospel according to the hebrews rrb. gospels matthew mark luke and john were most likey written after christ
development stable releases are rare but there are often subversion snapshots which are stable enough to use. development stable releases are rare but there are often subversion snapshots lrb an informal photograph taken quickly rrb which are stable enough to use.
cogeneration lrb also combined heat and power chp rrb is the use of a heat engine or a power station to simultaneously generate both electricity and useful heat. cogeneration is a combination of heart and power to simultaneously generate both electricity and useful heat.
on occasion the male den master will also allow a second male into the den the reason for this is unclear. on opportunity the male den master will let a second male inside of the den the basis for this is poorly explained.
below are some useful links to facilitate your involvement. below are some useful links to help your involvement.
he served as the prime minister of egypt between and and again from and. he served as prime minster of egypt from and as well as from through.
she was left behind lrb explanations for this vary rrb when the rest of the nicole os were moved to the mainland. people have different thinking about why she was left behind when the nicolenos were moved to the main part of the country.
chauvin was embarrassed to receive his award and initially indicated that he may not accept it. chauvin was embarrassed to get his award and at first said that he may not accept it.
later esperanto speakers began to see the language and the culture that had grown up around it as ends in themselves even if esperanto is never adopted by the united nations or other international organizations. later esperanto speakers started to see the language and culture that had grown up around it as ends in themselves though esperanto is never accepted by the united nations of other international organizations.
dry air wrapping around the southern periphery of the cyclone eroded most of the deep convection by early on september. early september dry air wrapping around the southern area of the cyclone caused most of the heat to leave.
a few animals have chromatic response changing color in changing environments either seasonally lrb ermine snowshoe hare rrb or far more rapidly with chromatophores in their integument lrb the cephalopod family rrb. some animals change color when their environments change a process called chromatic response either seasonally as with ermine and snowshoe hare or far more rapidly with chromoa tophonres in theri integument lrb the cephalapod family. rrb
this closely resembles the unix philosophy of having multiple programs each doing one thing well and working together over universal interfaces. this looks like the unix idea of having several programs with each doing one thing and working together.
he came from a musical family his mother larue was an administrative assistant and singer and his father keith brion was a band director at yale. his was a musical family as his mother larue was a secretary and singer while his father keith brion was a band director at yale.
the largest populations of mennonites are in canada democratic republic of congo and the united states but mennonites can also be found in tight knit communities in at least countries on six continents or scattered amongst the populace of those countries. the largest populations of mennonites are in canada democratic republic of congo and the united states. mennonites also live in close communities in at least countries on six continents or scattered throughout the populations of those countries.
acanthopholis s armour consisted of oval plates set almost horizontally into the skin with spikes protruding from the neck and shoulder area along the spine. acanthopholis s armour was made up of oval plates that were put into the skin lengthwise and had spikes that jutted out from the neck and shoulder area across the spine.
conversely bills proposed by the law commission and consolidation bills start in the house of lords. the bills proposed by the law commission and consolidation bills start in the house of lords contrarily.
reflection nebulae are usually blue because the scattering is more efficient for blue light than red lrb this is the same scattering process that gives us blue skies and red sunsets rrb. reflection nebulae are commonly blue because the scattering is more powerful for blue light than red lrb this is the same reason for the sky appears in blue and the sunset in red colors
macgruber starts asking for simple objects to make something to defuse the bomb but he is later distracted by something lrb usually involving his personal life rrb that makes him run out of time. macgruber asked for many items to help shut the bomb off but he was distracted and ran out of time.
the pad called for the resignation of the governments of thaksin shinawatra samak sundaravej and somchai wongsawat whom the pad accused of being proxies for thaksin. the pad called for thakin shinaatra samak sundaravej and somchai to step down as government leaders because the pad considered them to be used by thaksin.
while at kahn he was chief architect for the fisher building in. he was a chief architect fot fisher building in when he was at kahn
he excuses himself because he has to leave for rehearsal and he and dr. sch n leave. he excuses himself because he has to leave for rehearsal dr.dr. sch n leave
the sheppard line currently has fewer users than the other two subway lines and shorter trains are run. the sheppard line not only has fewer users than the other two subway lines it also runs shorter trains.
it has a capacity of making it the largest stadium in europe and the eleventh largest in the world. it can seat. which makes it the largest stadium in europe and the eleventh largest in the world.
in december ten boom was honored as one of the righteous among the nations by the state of israel. in december ten boon was honored as part of the righteous amoung the nations by the state of israel.
terms such as undies for underwear and movie for moving picture are oft heard terms in english. words like undies movie are oft heard terms in english.
jurisdiction draws its substance from public international law conflict of laws constitutional law and the powers of the executive and legislative branches of government to allocate resources to best serve the needs of its native society. power moves towards its material from community national boundaries opposition of laws organizational law and ability of the administrative and creative offshoots of government to assign support to best serve the needs of its native society.
despite this farrenc was paid less than her male counterparts for nearly a decade. even so farrenc was paid less than her male peers for nearly years.
gumbasia was created in a style vorkapich taught called kinesthetic film principles. vorkapich taught gumbasia in a style called kinesthetic film principles.
the lawyer brandon lrb waise lee rrb became his idol and mk sun grew up to be a lawyer. the lawyer brandon lrb waise lee rrb was his idol as mk sun grew up to be a lawyer.
military career donaldson enlisted in the australian army on june. donaldson inlisted in australia s army on june to start his military career.
the kindle features level grayscale display improved battery life percent faster page refreshing a text to speech option to read the text aloud and overall thickness reduced from. to. inches lrb. millimeters rrb. the kindle features grayscale display improved battery life and overall thickness reduced.
yoghurt or yogurt is a dairy product produced by bacterial fermentation of milk. yoghurt or yogurt is a milk based food made by bacterial fermentation of milk.
seventy five defencemen are in the hall of fame more than any other current position while only goaltenders have been inducted. out of seventy five defencemen in the hall of fame only goaltenders have been inducted.
alternative views on the subject have been proposed throughout the centuries lrb see below rrb but all were rejected by mainstream christian bodies. different views on the subject have been brought up over the centuries lrb see below rrb but all were rejected by mainstream christian bodies
the album however was banned from many record stores nationwide. the album is banned from many record stores nationwide.
the company opened twice as many canadian outlets as mcdonald s wendy s confirms tim hortons ipo by march ottawa business journal december and system wide sales also surpassed those of mcdonald s canadian operations as of. the company opened two times as many restaurants in canada as mcdonald s wendy s confirms tim hortons ipo by march ottawa business journal december and sales throughout the company were also greater than those of mcdonald s canadian business as of.
he won the presidential election held on march with. of the popular vote. he conquered the presidential poll on march with. of the popular vote.
in she was the only female entertainer allowed to perform in saudi arabia. as a female entertainer she alone was allowed to perform in saudi arabia during
offenbach s numerous operettas such as orpheus in the underworld and la belle h l ne were extremely popular in both france and the english speaking world during the s and s. offenbach s a great number of operettas such as orpheus in the underworld and la beautiful woman helene were greatly pleasing to all in both france and the english talking earth during the s and s
roof tiles dating back to the tang dynasty with this symbol have been found west of the ancient city of chang an lrb modern day xian rrb. roof tiles during tang dynasty with this symbol have been found west of the ancient city of chang an or modern day xian.
by most accounts the instrument was nearly impossible to control. it was nearly impossible to control the instrument by most accounts.
characteristics radar observations indicate a fairly pure iron nickel composition. radar testing shows composition of mostly iron nickel.
lo che harbours the installations of onyx the swiss interception system for electronic intelligence gathering. lo che harbours of onyx is the swiss interception system for electronic intelligence gathering.
a matchbook is a small cardboard folder lrb matchcover rrb enclosing a quantity of matches and having a coarse striking surface on the exterior. a matchbook is a small cardboard folder lrb or matchcover rrb that holds some matches and has a rough area on the outside.
|
| 114 |
+
she was among the first doctors to object to cigarette smoking around children and drug use in pregnant women. she was one of the first doctors that said cigarette smoking near children and drug use in pregnant women was not safe.
|
| 115 |
+
defiantly she vowed to never renounce the commune and dared the judges to sentence her to death. she refused to give up the commune and prefered the death sentence
|
| 116 |
+
oel manga series graystripe s trilogy there is a three volume original english language manga series following graystripe between the time that he was taken by twolegs in dawn until he returned to thunderclan in the sight. oel manga occuring in sequence graystripe s trilogy. there s a three quantity volume earliest english language manga series following graystripe linking the time that it was accepted by twolegs in dawn up until he came back to thunderclan in the sight.
|
| 117 |
+
samovar porter lrb rrb p. syrians did not congregate in urban enclaves many of the immigrants who had worked as peddlers were able to interact with americans on a daily basis. samovar porter lrb rrb p. syrians did not get together in city groups many of the immigrants who had worked as sellers on the street were able to talk with americans every day.
|
| 118 |
+
he was also famous for his prints book covers posters and garden metalwork furniture. he is famous for prints book covers posters and garden metalwork furniture.
|
| 119 |
+
during childhood she suffered from collapsed lungs twice she had pneumonia times a year a ruptured appendix and had a tonsillar cyst. for two times she had lung disorder when she was a child. she was also suffered from pneumonia to times a year. she was also affected by appendix disorder and had a tonsillar cyst. all these happened during her childhood period.
|
| 120 |
+
small value inductors can also be built on integrated circuits using the same processes that are used to make transistors. both small value inductors and transistors can be built on integrated circuits.
|
| 121 |
+
no skater has yet accomplished a quadruple axel in competition. quadruple axel at a competition is yet to be fulfilled by any skater
|
| 122 |
+
from the telephone exchange the port jackson district commandant could communicate with all military installations on the harbour. by use of the telephone exchange the port jackson district commandant could talk to all military installations on the harbour.
|
| 123 |
+
however even to those who enter the prayer hall of a mosque without the intention of praying there are still rules that apply. however even to those who enter the prayer hall of a mosque without the purpose of praying there are some rules applicable to them.
|
| 124 |
+
it is described as pointed in the face and about the size of a rabbit. it is about the size of a rabbit and has a pointed face.
|
| 125 |
+
human skin hues can range from very dark brown to very pale pink. the colors of human skin can be very dark brown or very pale pink or anywhere in between.
|
| 126 |
+
nupedia was founded on march under the ownership of bomis inc a web portal company. bomis inc a web portal company founded nupedia on march.
|
| 127 |
+
notable features of the design include key dependent s boxes and a highly complex key schedule. notable features of the design include s boxes which is a highly complex key schedule.
|
| 128 |
+
the primavera is a painting by the italian renaissance painter sandro botticelli c.. painted around the primavera is a painting by the italian renaissance painter sandro botticelli.
|
| 129 |
+
new south wales s largest city and capital is sydney. largest city new south wales and its capital is sydney.
|
| 130 |
+
the polymer is most often epoxy but other polymers such as polyester vinyl ester or nylon are also sometimes used. polymers such as polyester vinylester or nylon are also sometimes used as epoxies.
|
| 131 |
+
stands were eventually added behind each set of goals during the s and s as the ground began to be modernised. during the s and s the ground got more modern and stands were eventually added behind each set of goals.
|
| 132 |
+
a bastion on the eastern approaches was built later. later a bastion was built on the eastern approaches.
|
| 133 |
+
events europe july battle of stiklestad lrb norway rrb olav haraldsson loses to his pagan vassals and is killed in the battle. among events that happened in europe on july. was the battle of stiklesstand in norway in which olav haraldsson lost his pagan subjects and his life.
|
| 134 |
+
others have theorized that tresca was eliminated by the nkvd as retribution for criticism of the stalin regime of the soviet union. others have made theories that tresca was took away by the nkvd as retribution for views put forward as to errors of the stalin system of things of the soviet union.
|
| 135 |
+
schuschnigg immediately responded publicly that reports of riots were false. schuschnigg responded publicly that reports of riots were false.
|
| 136 |
+
depending on the context another closely related meaning of constituent is that of a citizen residing in the area governed represented or otherwise served by a politician sometimes this is restricted to citizens who elected the politician. based on the given situation one more nearest meaning of constituent would be a citizen residing in the area governed represented or otherwise served by a politician at times this is limited to the citizens who elected the politician.
|
| 137 |
+
wario land the wario land series is a platforming series that started with wario land super mario land a spin off of the super mario land series. the platform series wario land from the wario land series stared with the super mario land series.
|
| 138 |
+
these attacks may have been psychological in origin rather than physical. these attacks may have been psychological in origin not physical.
|
| 139 |
+
furthermore spectroscopic studies have shown evidence of hydrated minerals and silicates which indicate rather a stony surface composition. furthermore spectroscopic studies have shown proof of hydrated minerals and silicates which points to rather a stony surface material.
|
| 140 |
+
she became the authoritative editor of her husband s works for breitkopf und h rtel. she became the editor of her husband s works for breitkopf und hartel.
|
| 141 |
+
mercury is similar in appearance to the moon it is heavily cratered with regions of smooth plains has no natural satellites and no substantial atmosphere. mercury is similar in appearance to the moon with heavily cratered regions of smooth plains and no natural satellites or substantial atmosphere.
|
| 142 |
+
geography the town lies in the limmat valley between baden and z rich. geographically the town lies in the limmat valley between baden and zurich.
|
| 143 |
+
these ideally provide excellent habitat for chinkara hog deer and blue bull. ideally these make an excellent breeding ground for chinkara hog deer and blue bull.
|
| 144 |
+
after the sena dynasty dhaka was successively ruled by the turkish and afghan governors descending from the delhi sultanate before the arrival of the mughals in. before the mughals arrived in after the sena dynasty dhaka was ruled for a long time by turkish and afghan governors that descended from the delhi sultanate.
|
| 145 |
+
for rowling this scene is important because it shows harry s bravery and by retrieving cedric s corpse he demonstrates selflessness and compassion. for rowling this incident is important becaus it shows harry s courage and by regaining cedric d corpse he establishes no concern for oneself and sympathy.
|
| 146 |
+
on june he and fellow raf members jan carl raspe and holger meins were apprehended after a lengthy shootout in frankfurt. he and fellow raf members jan carl raspe and holger meins were taken hold after a lengthy shootout. it was in frankfurt on june.
|
| 147 |
+
together they formed new music manchester a group committed to contemporary music. they formed the new music manchester band and sang contemporary
|
| 148 |
+
the compact and intense hurricane caused extreme damage in the upper florida keys as a storm surge of approximately to feet affected the region. small but intense the hurricane caused a lot of damage in the upper florida keys when a surge of nearly to feet hit the area.
|
| 149 |
+
it is now the site of meher baba s samadhi lrb tomb shrine rrb as well as facilities and accommodations for pilgrims. it is now the meher baba s tomb shrine and a place for pilgrims.
|
| 150 |
+
salem is a city in essex county massachusetts united states. salem is a city in essex county massachusetts.
|
| 151 |
+
forty nine species of pipefish and nine species of seahorse have been recorded. forty nine species of pipefish and nine species of seahorse are recorded.
|
| 152 |
+
therefore these pdfs can not be distributed without further manipulation if they contain images. if any of these pdfs contain pictures then they require additional processing before they can be issued
|
| 153 |
+
heavy rain fell across portions of britain on october causing localized accumulation of flood waters. heavy rain fell across britain on october causing accumulation of flood waters.
|
| 154 |
+
ohio state s library system encompasses twenty one libraries located on its columbus campus. ohio state s library system has twenty one libraries located on its campus.
|
| 155 |
+
in other developments both iceland and greenland accepted the overlordship of norway but scotland was able to repulse a norse invasion and broker a favorable peace settlement. both iceland and greenland accepted the ruler of norway but scotland was able to prevent a norwegian invasion and a negotiate a peace settlement.
|
| 156 |
+
the singles from the album included by the way the zephyr song ca n t stop dosed and universally speaking. some singles from the album are by the way the zephyr song ca n t stop dosed and universally speaking.
|
| 157 |
+
in april minix became free open source software under a permissive free software licence but by this time other operating systems had surpassed its capabilities and it remained primarily an operating system for students and hobbyists. in april minix became free or open source software under a non restrictive free software licence but by this time other operating systems had exceeded its capabilities and it continued to be mainly an operating system for students and hobbyists.
|
| 158 |
+
the body color varies from medium brown to gold ish to beige white and occasionally is marked with dark brown spots especially on the limbs. the body color varies from medium brown to goldish to beige white and sometimes is marked with dark brown spots.
|
| 159 |
+
the latter provided audiences with the sort of information later provided by intertitles and can help historians imagine what the film may have been like. the latter which gave audiences the same sort of information later audience members would gett from subtitles can help historians imagine what the film may have been like.
|
| 160 |
+
that is because real estate businesses and other assets in the underground economies of the third world can not be used as collateral to raise capital to finance industrial and commercial expansion. that is beacuase real estate businesses and other assets in the underground economies of the third world can not be used as collateral to raise capital.
|
| 161 |
+
ned and dan advanced to the police camp ordering them to surrender. ned and dan told by the police to surrender.
|
| 162 |
+
a mutant is a type of fictional character that appears in comic books published by marvel comics. mutants are fictional characters from the x men comic books published by marvel.
|
| 163 |
+
the sat reasoning test lrb formerly scholastic aptitude test and scholastic assessment test rrb is a standardized test for college admissions in the united states. the sat is a standardized test for college admissions in the united states.
|
| 164 |
+
some reports read that various factors increase the likelihood of both paralysis and hallucinations. some reports said that various things make it more possible to have paralysis and hallucinations.
|
| 165 |
+
his sentence was transportation to australia for seven years. his sentence was carried to australia for seven years.
|
| 166 |
+
her notorious friendship with the russian mystic grigori rasputin was also an important factor in her life. her well known relation with the russian mystic grigori rasputin was additionaly an important number in her life.
|
| 167 |
+
the term dorsal refers to anatomical structures that are either situated toward or grow off that side of an animal. the word dorsal means any body part that grows off that side of an animal or that grows toward that side of an animal.
|
| 168 |
+
the term protein itself was coined by berzelius after mulder observed that all proteins seemed to have the same empirical formula and might be composed of a single type of lrb very large rrb molecule. the term protein was made by berzelius after mulder
|
| 169 |
+
after the jerilderie raid the gang laid low for months evading capture. after the jerilderie raid the gang laid low for months.
|
| 170 |
+
in an extension was added curving north from union station below university avenue and queen s park to near bloor street where it turned west to terminate at st. george and bloor streets. an extension was added in which curved north from union station below university avenue and queens park reaching nearly to bloor street and ending on the west side at st. george and bloor streets.
|
| 171 |
+
it is located on an old portage trail which led west through the mountains to unalakleet. it s near an old portage trail that led west to unalakleet through the mountains.
|
| 172 |
+
people with cardiomyopathy are often at risk of arrhythmia or sudden cardiac death or both. arrhythmia or heart beat disorder and sudden cardiac arrest are often associated with cardiomyopathy. cardiomyopathy is deterioration of heart muscle and persons with this disease may subject to arrhythmia or sudden cardiac arrest. sometimes both may happen at once.
|
| 173 |
+
as the largest sub region in mesoamerica it encompassed a vast and varied landscape from the mountainous regions of the sierra madre to the semi arid plains of northern yucat n. as the largest sub region in mesoamerica it is a vast and varied landscape from the mountainous regions of the sierra madre to the plains of yucatan.
|
| 174 |
+
google subsequently made the comic available on google books and their site and mentioned it on its official blog along with an explanation for the early release. google made the comic available on google books mentioned it on their blog explaining the early release.
|
| 175 |
+
the book political economy was published in but had limited classroom adoption. the book political economy was published in but was not used in many classrooms.
|
| 176 |
+
he toured with the ipo in the spring of for their first ever performance in the soviet union with concerts in moscow and leningrad and toured with the ipo again in performing in china and india. for their first ever performance in the soviet union he toured with the ipo spring from to with china and india
|
| 177 |
+
napoleonic wars austrian general mack surrenders his army to the grand army of napoleon at ulm reaping napoleon over prisoners and inflicting casualties on the losers. austrian general mack surrenders his army to grand army of napoleon at ulm.
|
| 178 |
+
it has long been the economic centre of northern nigeria and a centre for the production and export of groundnuts. it has long been the economic centre of norther nigeria along with the center for production and export of groundnuts.
|
| 179 |
+
a majority of south indians speak one of the five dravidian languages kannada malayalam tamil telugu and tulu. most south indians speak one of the five dravidian languages kannada malayalam tamil telugu and tulu.
|
| 180 |
+
meteora earned the band multiple awards and honors. meteora won many awards and honors for the band.
|
| 181 |
+
in the th century slavs started to move into the area. in the th century slaves started to move in the area.
|
| 182 |
+
winchester is a city in scott county illinois united states. winchester is a city located in scott county illinois united states
|
| 183 |
+
out of participants in the national casting she was chosen among the candidates to appear on the tv show. out of participants in the national casting and it was selected the candidates to appear on the tv show.
|
| 184 |
+
the latter device can then be designed and used in less stringent environments. the device can be designed for use in less exact environments.
|
| 185 |
+
gimnasia hired first famed colombian trainer francisco maturana and then julio c sar falcioni but both had limited success. gimnasia hired francisco maturana a columbian trainer and then julio c sar falcioni but they were not very successful.
|
| 186 |
+
brighton is a city in washington county iowa united states. brighton is a city in washington county iowa.
|
| 187 |
+
pauline returned in the game boy remake of donkey kong in and later mario vs. donkey kong march of the minis in although the character is now described as mario s friend. pauline turned in the game boy remake of donkey kong in and later mario vs. donkey kong march of the minis in even though the character is now mario s friend.
|
| 188 |
+
his real date of birth was never recorded but it is believed to be a date between and. since his actual date of birth was not recorded it is believed to be between.
|
| 189 |
+
this quantitative measure indicates how much of a particular drug or other substance lrb inhibitor rrb is needed to inhibit a given biological process lrb or component of a process i.e. an enzyme cell cell receptor or microorganism rrb by half. this quantitative measure indicates how much of a drug or other substance is needed to inhibit a biological process by half.
|
| 190 |
+
there he had one daughter later baptized as mary ann fisher power to ann lrb e rrb power. he had one daughter named mary ann fisher power who was later baptized to ann lrb e rrb power.
|
| 191 |
+
during an interview edward gorey mentioned that bawden was one of his favorite artists lamenting the fact that not many people remembered or knew about this fine artist. during an interview edward gorey said that bawden was one of his favorite artists and is saddened by the fact that not many people remembered or knew about this fine artist.
|
| 192 |
+
gable also earned an academy award nomination when he portrayed fletcher christian in s mutiny on the bounty. gable also earned an academy award nomination for his portrayal of fletcher christian in the film mutiny on the bounty.
|
wikilarge-dataset/wiki.full.aner.ori.train.95.tsv
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3173629a89d5b4a7311262bc7c7fb2ec480c400330410e32730a75afe365dea0
|
| 3 |
+
size 36251904
|
wikilarge-dataset/wiki.full.aner.ori.valid.95.tsv
ADDED
|
The diff for this file is too large to render.