Spaces:
No application file
No application file
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +119 -0
- .gitignore +142 -0
- README.md +100 -8
- chimp/.gitignore +160 -0
- chimp/requirements.txt +21 -0
- chimp/src/config.py +81 -0
- chimp/src/dataset.py +31 -0
- chimp/src/model.py +62 -0
- chimp/src/predict.py +87 -0
- chimp/src/train.py +66 -0
- data/alm_task_data.csv +0 -0
- data/jack_line_item_ner_task.csv +0 -0
- data/jack_line_item_ner_task_v2.csv +0 -0
- data/line_item_and_alm_data.json +0 -0
- data/line_item_and_alm_data_v1.json +3 -0
- data_prep.py +28 -0
- demo.sh +5 -0
- gorilla/__pycache__/llama_attn_replace.cpython-310.pyc +0 -0
- gorilla/__pycache__/llama_attn_replace_sft.cpython-310.pyc +0 -0
- gorilla/api.py +0 -0
- gorilla/app.py +211 -0
- gorilla/code_interpreter.py +117 -0
- gorilla/ds_configs/stage2.json +23 -0
- gorilla/ds_configs/stage3.json +49 -0
- gorilla/eval.py +175 -0
- gorilla/fine-tune.py +206 -0
- gorilla/get_trainable_weights.py +37 -0
- gorilla/infer.py +143 -0
- gorilla/llama_attn_replace.py +477 -0
- gorilla/llama_attn_replace_sft.py +483 -0
- gorilla/merge_lora_weights_and_save_hf_model.py +100 -0
- gorilla/push_to_hub.py +5 -0
- gorilla/requirements.txt +19 -0
- gorilla/stream_jack.py +183 -0
- gorilla/streaming_llm/__init__.py +0 -0
- gorilla/streaming_llm/__pycache__/__init__.cpython-310.pyc +0 -0
- gorilla/streaming_llm/__pycache__/enable_streaming_llm.cpython-310.pyc +0 -0
- gorilla/streaming_llm/__pycache__/kv_cache.cpython-310.pyc +0 -0
- gorilla/streaming_llm/__pycache__/utils.cpython-310.pyc +0 -0
- gorilla/streaming_llm/enable_streaming_llm.py +38 -0
- gorilla/streaming_llm/kv_cache.py +119 -0
- gorilla/streaming_llm/pos_shift/__init__.py +0 -0
- gorilla/streaming_llm/pos_shift/__pycache__/__init__.cpython-310.pyc +0 -0
- gorilla/streaming_llm/pos_shift/__pycache__/modify_llama.cpython-310.pyc +0 -0
- gorilla/streaming_llm/pos_shift/modify_falcon.py +162 -0
- gorilla/streaming_llm/pos_shift/modify_llama.py +311 -0
- gorilla/streaming_llm/utils.py +112 -0
- gorilla/style.css +16 -0
- gorilla/supervised-fine-tune-qlora.py +345 -0
- gorilla/supervised-fine-tune.py +330 -0
.gitattributes
CHANGED
@@ -33,3 +33,122 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
data/line_item_and_alm_data_v1.json filter=lfs diff=lfs merge=lfs -text
|
37 |
+
venv/bin/python filter=lfs diff=lfs merge=lfs -text
|
38 |
+
venv/bin/python3 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
venv/bin/python3.10 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
venv/lib/python3.10/site-packages/Pillow.libs/libfreetype-82733d78.so.6.20.1 filter=lfs diff=lfs merge=lfs -text
|
41 |
+
venv/lib/python3.10/site-packages/Pillow.libs/libharfbuzz-e3b74c67.so.0.60821.0 filter=lfs diff=lfs merge=lfs -text
|
42 |
+
venv/lib/python3.10/site-packages/aiohttp/_http_parser.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
43 |
+
venv/lib/python3.10/site-packages/altair/vegalite/v5/schema/__pycache__/core.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
44 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda110.so filter=lfs diff=lfs merge=lfs -text
|
45 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
|
46 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda111.so filter=lfs diff=lfs merge=lfs -text
|
47 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
|
48 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda114.so filter=lfs diff=lfs merge=lfs -text
|
49 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda114_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
|
50 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda115.so filter=lfs diff=lfs merge=lfs -text
|
51 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda115_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
|
52 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so filter=lfs diff=lfs merge=lfs -text
|
53 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
|
54 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so filter=lfs diff=lfs merge=lfs -text
|
55 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
|
56 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda120.so filter=lfs diff=lfs merge=lfs -text
|
57 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
|
58 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda121.so filter=lfs diff=lfs merge=lfs -text
|
59 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
|
60 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda122.so filter=lfs diff=lfs merge=lfs -text
|
61 |
+
venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda122_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
|
62 |
+
venv/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_cython.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
63 |
+
venv/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
64 |
+
venv/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
65 |
+
venv/lib/python3.10/site-packages/fontTools/feaLib/lexer.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
66 |
+
venv/lib/python3.10/site-packages/fontTools/misc/bezierTools.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
67 |
+
venv/lib/python3.10/site-packages/fontTools/pens/momentsPen.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
68 |
+
venv/lib/python3.10/site-packages/fontTools/qu2cu/qu2cu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
69 |
+
venv/lib/python3.10/site-packages/fontTools/varLib/iup.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
70 |
+
venv/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.2 filter=lfs diff=lfs merge=lfs -text
|
71 |
+
venv/lib/python3.10/site-packages/gradio/templates/cdn/assets/Index-5c805b1c.js.map filter=lfs diff=lfs merge=lfs -text
|
72 |
+
venv/lib/python3.10/site-packages/gradio/templates/frontend/assets/Index-62000a79.js.map filter=lfs diff=lfs merge=lfs -text
|
73 |
+
venv/lib/python3.10/site-packages/kiwisolver/_cext.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
74 |
+
venv/lib/python3.10/site-packages/matplotlib/_image.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
75 |
+
venv/lib/python3.10/site-packages/matplotlib/_path.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
76 |
+
venv/lib/python3.10/site-packages/matplotlib/_qhull.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
77 |
+
venv/lib/python3.10/site-packages/matplotlib/backends/_backend_agg.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
78 |
+
venv/lib/python3.10/site-packages/matplotlib/ft2font.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
79 |
+
venv/lib/python3.10/site-packages/numpy/core/_multiarray_umath.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
80 |
+
venv/lib/python3.10/site-packages/numpy/core/_simd.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
81 |
+
venv/lib/python3.10/site-packages/numpy.libs/libgfortran-040039e1.so.5.0.0 filter=lfs diff=lfs merge=lfs -text
|
82 |
+
venv/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so filter=lfs diff=lfs merge=lfs -text
|
83 |
+
venv/lib/python3.10/site-packages/nvfuser/_C.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
84 |
+
venv/lib/python3.10/site-packages/nvidia/cublas/lib/libcublas.so.12 filter=lfs diff=lfs merge=lfs -text
|
85 |
+
venv/lib/python3.10/site-packages/nvidia/cublas/lib/libcublasLt.so.12 filter=lfs diff=lfs merge=lfs -text
|
86 |
+
venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
|
87 |
+
venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libcupti.so.12 filter=lfs diff=lfs merge=lfs -text
|
88 |
+
venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libnvperf_host.so filter=lfs diff=lfs merge=lfs -text
|
89 |
+
venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libnvperf_target.so filter=lfs diff=lfs merge=lfs -text
|
90 |
+
venv/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.1 filter=lfs diff=lfs merge=lfs -text
|
91 |
+
venv/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.12 filter=lfs diff=lfs merge=lfs -text
|
92 |
+
venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_adv_infer.so.8 filter=lfs diff=lfs merge=lfs -text
|
93 |
+
venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_adv_train.so.8 filter=lfs diff=lfs merge=lfs -text
|
94 |
+
venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_cnn_infer.so.8 filter=lfs diff=lfs merge=lfs -text
|
95 |
+
venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_cnn_train.so.8 filter=lfs diff=lfs merge=lfs -text
|
96 |
+
venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_ops_infer.so.8 filter=lfs diff=lfs merge=lfs -text
|
97 |
+
venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_ops_train.so.8 filter=lfs diff=lfs merge=lfs -text
|
98 |
+
venv/lib/python3.10/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
|
99 |
+
venv/lib/python3.10/site-packages/nvidia/cufft/lib/libcufftw.so.11 filter=lfs diff=lfs merge=lfs -text
|
100 |
+
venv/lib/python3.10/site-packages/nvidia/curand/lib/libcurand.so.10 filter=lfs diff=lfs merge=lfs -text
|
101 |
+
venv/lib/python3.10/site-packages/nvidia/cusolver/lib/libcusolver.so.11 filter=lfs diff=lfs merge=lfs -text
|
102 |
+
venv/lib/python3.10/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 filter=lfs diff=lfs merge=lfs -text
|
103 |
+
venv/lib/python3.10/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
|
104 |
+
venv/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 filter=lfs diff=lfs merge=lfs -text
|
105 |
+
venv/lib/python3.10/site-packages/nvidia/nvjitlink/lib/libnvJitLink.so.12 filter=lfs diff=lfs merge=lfs -text
|
106 |
+
venv/lib/python3.10/site-packages/pandas/_libs/algos.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
107 |
+
venv/lib/python3.10/site-packages/pandas/_libs/groupby.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
108 |
+
venv/lib/python3.10/site-packages/pandas/_libs/hashtable.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
109 |
+
venv/lib/python3.10/site-packages/pandas/_libs/interval.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
110 |
+
venv/lib/python3.10/site-packages/pandas/_libs/join.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
111 |
+
venv/lib/python3.10/site-packages/pandas/_libs/tslibs/offsets.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
112 |
+
venv/lib/python3.10/site-packages/pyarrow/_compute.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
113 |
+
venv/lib/python3.10/site-packages/pyarrow/_dataset.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
114 |
+
venv/lib/python3.10/site-packages/pyarrow/_flight.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
115 |
+
venv/lib/python3.10/site-packages/pyarrow/lib.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
116 |
+
venv/lib/python3.10/site-packages/pyarrow/libarrow.so.1400 filter=lfs diff=lfs merge=lfs -text
|
117 |
+
venv/lib/python3.10/site-packages/pyarrow/libarrow_acero.so.1400 filter=lfs diff=lfs merge=lfs -text
|
118 |
+
venv/lib/python3.10/site-packages/pyarrow/libarrow_dataset.so.1400 filter=lfs diff=lfs merge=lfs -text
|
119 |
+
venv/lib/python3.10/site-packages/pyarrow/libarrow_flight.so.1400 filter=lfs diff=lfs merge=lfs -text
|
120 |
+
venv/lib/python3.10/site-packages/pyarrow/libarrow_python.so filter=lfs diff=lfs merge=lfs -text
|
121 |
+
venv/lib/python3.10/site-packages/pyarrow/libarrow_substrait.so.1400 filter=lfs diff=lfs merge=lfs -text
|
122 |
+
venv/lib/python3.10/site-packages/pyarrow/libparquet.so.1400 filter=lfs diff=lfs merge=lfs -text
|
123 |
+
venv/lib/python3.10/site-packages/pydantic_core/_pydantic_core.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
124 |
+
venv/lib/python3.10/site-packages/pyzmq.libs/libsodium-cb25555f.so.23.3.0 filter=lfs diff=lfs merge=lfs -text
|
125 |
+
venv/lib/python3.10/site-packages/pyzmq.libs/libzmq-f468291a.so.5.2.4 filter=lfs diff=lfs merge=lfs -text
|
126 |
+
venv/lib/python3.10/site-packages/regex/_regex.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
127 |
+
venv/lib/python3.10/site-packages/rpds/rpds.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
128 |
+
venv/lib/python3.10/site-packages/safetensors/_safetensors_rust.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
129 |
+
venv/lib/python3.10/site-packages/scipy/fft/_pocketfft/pypocketfft.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
130 |
+
venv/lib/python3.10/site-packages/scipy/linalg/_flapack.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
131 |
+
venv/lib/python3.10/site-packages/scipy/misc/face.dat filter=lfs diff=lfs merge=lfs -text
|
132 |
+
venv/lib/python3.10/site-packages/scipy/optimize/_highs/_highs_wrapper.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
133 |
+
venv/lib/python3.10/site-packages/scipy/sparse/_sparsetools.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
134 |
+
venv/lib/python3.10/site-packages/scipy/spatial/_ckdtree.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
135 |
+
venv/lib/python3.10/site-packages/scipy/spatial/_qhull.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
136 |
+
venv/lib/python3.10/site-packages/scipy/special/_ufuncs.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
137 |
+
venv/lib/python3.10/site-packages/scipy/special/cython_special.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
138 |
+
venv/lib/python3.10/site-packages/scipy/stats/_unuran/unuran_wrapper.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
139 |
+
venv/lib/python3.10/site-packages/scipy.libs/libgfortran-040039e1.so.5.0.0 filter=lfs diff=lfs merge=lfs -text
|
140 |
+
venv/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-23e5df77.3.21.dev.so filter=lfs diff=lfs merge=lfs -text
|
141 |
+
venv/lib/python3.10/site-packages/sentencepiece/_sentencepiece.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
142 |
+
venv/lib/python3.10/site-packages/tokenizers/tokenizers.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
143 |
+
venv/lib/python3.10/site-packages/torch/bin/nvfuser_tests filter=lfs diff=lfs merge=lfs -text
|
144 |
+
venv/lib/python3.10/site-packages/torch/bin/protoc filter=lfs diff=lfs merge=lfs -text
|
145 |
+
venv/lib/python3.10/site-packages/torch/bin/protoc-3.13.0.0 filter=lfs diff=lfs merge=lfs -text
|
146 |
+
venv/lib/python3.10/site-packages/torch/lib/libc10.so filter=lfs diff=lfs merge=lfs -text
|
147 |
+
venv/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so filter=lfs diff=lfs merge=lfs -text
|
148 |
+
venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so filter=lfs diff=lfs merge=lfs -text
|
149 |
+
venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so filter=lfs diff=lfs merge=lfs -text
|
150 |
+
venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda_linalg.so filter=lfs diff=lfs merge=lfs -text
|
151 |
+
venv/lib/python3.10/site-packages/torch/lib/libtorch_python.so filter=lfs diff=lfs merge=lfs -text
|
152 |
+
venv/lib/python3.10/site-packages/triton/_C/libtriton.so filter=lfs diff=lfs merge=lfs -text
|
153 |
+
venv/lib/python3.10/site-packages/triton/third_party/cuda/bin/ptxas filter=lfs diff=lfs merge=lfs -text
|
154 |
+
venv/lib/python3.10/site-packages/yaml/_yaml.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
|
132 |
+
# Json files
|
133 |
+
*.json
|
134 |
+
*.misc
|
135 |
+
|
136 |
+
# model files
|
137 |
+
models/*
|
138 |
+
gfpgan/weights/*
|
139 |
+
test_cors.html
|
140 |
+
jack-alm/
|
141 |
+
jack-alm-13b-8k-hf/
|
142 |
+
cache/
|
README.md
CHANGED
@@ -1,12 +1,104 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: pink
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: ark-instruct-line-item
|
3 |
+
app_file: /home/tosi-n/ark/gorilla/app.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 4.0.2
|
|
|
|
|
6 |
---
|
7 |
+
# ARK - Jack's Accounting ALM Training Framework
|
8 |
|
9 |
+
This is a base pipeline to build out our task specific repos layer on
|
10 |
+
|
11 |
+
> Note ReadMe currently for gorrila module which is mainly pipeline for long context window training flexibility. For the chimp module which is for default context window training, it has it's ReadMe in folder path
|
12 |
+
|
13 |
+
## Usage Requirements
|
14 |
+
To download and use the [pre-trained weights](#pre-trained-weights) you will need:
|
15 |
+
1. Hugging Face (HF) account with valid email. Note, the email used for HF must alse be used for the license agreement.
|
16 |
+
2. Accept the Meta [license and acceptable use policy](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
|
17 |
+
|
18 |
+
|
19 |
+
## Installation and Quick Guide
|
20 |
+
To install and run the application:
|
21 |
+
1. Clone the repository on your local machine, using git clone and pasting the url of this project.
|
22 |
+
2. Complete pre-requiste installation. Run the following code:
|
23 |
+
```
|
24 |
+
pip install -r requirements.txt
|
25 |
+
pip install flash-attn --no-build-isolation
|
26 |
+
```
|
27 |
+
3. Training Pre-trained weights through Fine-tuning QLoRa, LoRa or full. Run the following code in bash script and update args before running for QloRa :
|
28 |
+
```
|
29 |
+
sh runner.sh
|
30 |
+
```
|
31 |
+
4. Merge amd get Trainable LoRA Weight. Run the following code in bash script and update args where necessary:
|
32 |
+
```
|
33 |
+
sh process_wt.sh
|
34 |
+
```
|
35 |
+
5. Test your model by terminal chat. Run the following code in bash script andupdate args where necessary:
|
36 |
+
```
|
37 |
+
sh stream.sh
|
38 |
+
```
|
39 |
+
6. Test your model on gradio UI. Run the following code in bash script and update args where necessary:
|
40 |
+
```
|
41 |
+
sh demo.sh
|
42 |
+
```
|
43 |
+
|
44 |
+
|
45 |
+
## Training args for Full and LoRA
|
46 |
+
|
47 |
+
### Fine-tuning
|
48 |
+
```
|
49 |
+
torchrun --nproc_per_node=8 fine-tune.py \
|
50 |
+
--model_name_or_path path_to/Llama-2-7b-hf \
|
51 |
+
--bf16 True \
|
52 |
+
--output_dir path_to_saving_checkpoints \
|
53 |
+
--cache_dir path_to_cache \
|
54 |
+
--model_max_length 8192 \
|
55 |
+
--use_flash_attn True \
|
56 |
+
--low_rank_training False \
|
57 |
+
--num_train_epochs 1 \
|
58 |
+
--per_device_train_batch_size 1 \
|
59 |
+
--per_device_eval_batch_size 2 \
|
60 |
+
--gradient_accumulation_steps 8 \
|
61 |
+
--evaluation_strategy "no" \
|
62 |
+
--save_strategy "steps" \
|
63 |
+
--save_steps 1000 \
|
64 |
+
--save_total_limit 2 \
|
65 |
+
--learning_rate 2e-5 \
|
66 |
+
--weight_decay 0.0 \
|
67 |
+
--warmup_steps 20 \
|
68 |
+
--lr_scheduler_type "constant_with_warmup" \
|
69 |
+
--logging_steps 1 \
|
70 |
+
--deepspeed "ds_configs/stage2.json" \
|
71 |
+
--tf32 True \
|
72 |
+
--max_steps 1000
|
73 |
+
```
|
74 |
+
|
75 |
+
|
76 |
+
### Supervised Fine-tuning
|
77 |
+
```
|
78 |
+
torchrun --nproc_per_node=8 supervised-fine-tune.py \
|
79 |
+
--model_name_or_path path_to_Llama2_chat_models \
|
80 |
+
--bf16 True \
|
81 |
+
--output_dir path_to_saving_checkpoints \
|
82 |
+
--model_max_length 32768 \
|
83 |
+
--use_flash_attn True \
|
84 |
+
--data_path LongAlpaca-12k.json \
|
85 |
+
--low_rank_training True \
|
86 |
+
--num_train_epochs 3 \
|
87 |
+
--per_device_train_batch_size 1 \
|
88 |
+
--per_device_eval_batch_size 2 \
|
89 |
+
--gradient_accumulation_steps 1 \
|
90 |
+
--evaluation_strategy "no" \
|
91 |
+
--save_strategy "steps" \
|
92 |
+
--save_steps 1000 \
|
93 |
+
--save_total_limit 2 \
|
94 |
+
--learning_rate 2e-5 \
|
95 |
+
--weight_decay 0.0 \
|
96 |
+
--warmup_steps 20 \
|
97 |
+
--lr_scheduler_type "constant_with_warmup" \
|
98 |
+
--logging_steps 1 \
|
99 |
+
--deepspeed "ds_configs/stage2.json" \
|
100 |
+
--tf32 True
|
101 |
+
```
|
102 |
+
|
103 |
+
## Evaluation
|
104 |
+
### Perplexity Validation
|
chimp/.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
chimp/requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate
|
2 |
+
appdirs
|
3 |
+
bert_score
|
4 |
+
bitsandbytes
|
5 |
+
black
|
6 |
+
black[jupyter]
|
7 |
+
datasets
|
8 |
+
deepspeed
|
9 |
+
einops
|
10 |
+
fire
|
11 |
+
flask
|
12 |
+
gradio
|
13 |
+
huggingface-hub
|
14 |
+
jsonlines
|
15 |
+
loralib
|
16 |
+
peft
|
17 |
+
pycuda
|
18 |
+
sentencepiece
|
19 |
+
spacy_fastlang
|
20 |
+
transformers
|
21 |
+
# triton
|
chimp/src/config.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Used for multi-gpu
|
2 |
+
local_rank = -1
|
3 |
+
per_device_train_batch_size = 4
|
4 |
+
per_device_eval_batch_size = 4
|
5 |
+
gradient_accumulation_steps = 1
|
6 |
+
learning_rate = 2e-4
|
7 |
+
max_grad_norm = 0.3
|
8 |
+
weight_decay = 0.001
|
9 |
+
lora_alpha = 16
|
10 |
+
lora_dropout = 0.1
|
11 |
+
lora_r = 64
|
12 |
+
max_seq_length = None
|
13 |
+
|
14 |
+
# The model that you want to train from the Hugging Face hub
|
15 |
+
model_name = "guardrail/llama-2-7b-guanaco-instruct-sharded"
|
16 |
+
|
17 |
+
# Fine-tuned model name
|
18 |
+
new_model = "llama-2-7b-custom-accountant"
|
19 |
+
|
20 |
+
# The instruction dataset to use
|
21 |
+
# dataset_name = "databricks/databricks-dolly-15k"
|
22 |
+
|
23 |
+
# Activate 4-bit precision base model loading
|
24 |
+
use_4bit = True
|
25 |
+
|
26 |
+
# Activate nested quantization for 4-bit base models
|
27 |
+
use_nested_quant = False
|
28 |
+
|
29 |
+
# Compute dtype for 4-bit base models
|
30 |
+
bnb_4bit_compute_dtype = "float16"
|
31 |
+
|
32 |
+
# Quantization type (fp4 or nf4)
|
33 |
+
bnb_4bit_quant_type = "nf4"
|
34 |
+
|
35 |
+
# Number of training epochs
|
36 |
+
num_train_epochs = 2
|
37 |
+
|
38 |
+
# Enable fp16 training, (bf16 to True with an A100)
|
39 |
+
fp16 = False
|
40 |
+
|
41 |
+
# Enable bf16 training
|
42 |
+
bf16 = False
|
43 |
+
|
44 |
+
# Use packing dataset creating
|
45 |
+
packing = False
|
46 |
+
|
47 |
+
# Enable gradient checkpointing
|
48 |
+
gradient_checkpointing = True
|
49 |
+
|
50 |
+
# Optimizer to use, original is paged_adamw_32bit
|
51 |
+
optim = "paged_adamw_32bit"
|
52 |
+
|
53 |
+
# Learning rate schedule (constant a bit better than cosine, and has advantage for analysis)
|
54 |
+
lr_scheduler_type = "cosine"
|
55 |
+
|
56 |
+
# Number of optimizer update steps, 10K original, 20 for demo purposes
|
57 |
+
max_steps = -1
|
58 |
+
|
59 |
+
# Fraction of steps to do a warmup for
|
60 |
+
warmup_ratio = 0.03
|
61 |
+
|
62 |
+
# Group sequences into batches with same length (saves memory and speeds up training considerably)
|
63 |
+
group_by_length = True
|
64 |
+
|
65 |
+
# Save checkpoint every X updates steps
|
66 |
+
save_steps = 10
|
67 |
+
|
68 |
+
# Log every X updates steps
|
69 |
+
logging_steps = 1
|
70 |
+
|
71 |
+
# The output directory where the model predictions and checkpoints will be written
|
72 |
+
output_dir = "../model_files/"
|
73 |
+
|
74 |
+
# Load the entire model on the GPU 0
|
75 |
+
device_map = {"": 0}
|
76 |
+
|
77 |
+
# Visualize training
|
78 |
+
report_to = "tensorboard"
|
79 |
+
|
80 |
+
# Tensorboard logs
|
81 |
+
tb_log_dir = "../logs/"
|
chimp/src/dataset.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class CustomDataset:
|
2 |
+
def __init__(self, data):
|
3 |
+
self.features = ['instruction', 'context', 'response']
|
4 |
+
self.num_rows = len(data)
|
5 |
+
self.data = data
|
6 |
+
|
7 |
+
def __getitem__(self, idx):
|
8 |
+
if idx < 0 or idx >= self.num_rows:
|
9 |
+
raise IndexError("Index out of range")
|
10 |
+
return {
|
11 |
+
'instruction': self.data[idx]['instruction'],
|
12 |
+
'context': self.data[idx]['context'],
|
13 |
+
'response': self.data[idx]['response']
|
14 |
+
}
|
15 |
+
|
16 |
+
def __repr__(self):
|
17 |
+
return f"Dataset({{'features': {self.features}, 'num_rows': {self.num_rows}}})"
|
18 |
+
|
19 |
+
|
20 |
+
def format_data(sample):
|
21 |
+
instruction = f"<s>[INST] {sample['instruction']}"
|
22 |
+
context = f"Here's some context: {sample['context']}" if len(sample["context"]) > 0 else None
|
23 |
+
response = f" [/INST] {sample['response']}"
|
24 |
+
# join all the parts together
|
25 |
+
prompt = "".join([i for i in [instruction, context, response] if i is not None])
|
26 |
+
return prompt
|
27 |
+
|
28 |
+
# template dataset to add prompt to each sample
|
29 |
+
def template_dataset(sample, tokenizer):
|
30 |
+
sample["text"] = f"{format_data(sample)}{tokenizer.eos_token}"
|
31 |
+
return sample
|
chimp/src/model.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from datasets import load_dataset
|
5 |
+
from transformers import (
|
6 |
+
AutoModelForCausalLM,
|
7 |
+
AutoTokenizer,
|
8 |
+
BitsAndBytesConfig,
|
9 |
+
HfArgumentParser,
|
10 |
+
pipeline,
|
11 |
+
logging,
|
12 |
+
)
|
13 |
+
from peft import LoraConfig, PeftModel, get_peft_model
|
14 |
+
from guardrail.client import (
|
15 |
+
run_metrics,
|
16 |
+
run_simple_metrics,
|
17 |
+
create_dataset)
|
18 |
+
|
19 |
+
import src.config as config
|
20 |
+
|
21 |
+
def load_model(model_name):
|
22 |
+
# Load tokenizer and model with QLoRA configuration
|
23 |
+
compute_dtype = getattr(torch, config.bnb_4bit_compute_dtype)
|
24 |
+
|
25 |
+
bnb_config = BitsAndBytesConfig(
|
26 |
+
load_in_4bit=config.use_4bit,
|
27 |
+
bnb_4bit_quant_type=config.bnb_4bit_quant_type,
|
28 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
29 |
+
bnb_4bit_use_double_quant=config.use_nested_quant,
|
30 |
+
)
|
31 |
+
|
32 |
+
if compute_dtype == torch.float16 and config.use_4bit:
|
33 |
+
major, _ = torch.cuda.get_device_capability()
|
34 |
+
if major >= 8:
|
35 |
+
print("=" * 80)
|
36 |
+
print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
|
37 |
+
print("=" * 80)
|
38 |
+
|
39 |
+
model = AutoModelForCausalLM.from_pretrained(
|
40 |
+
model_name,
|
41 |
+
device_map=config.device_map,
|
42 |
+
quantization_config=bnb_config
|
43 |
+
)
|
44 |
+
|
45 |
+
model.config.use_cache = False
|
46 |
+
model.config.pretraining_tp = 1
|
47 |
+
|
48 |
+
# Load LoRA configuration
|
49 |
+
peft_config = LoraConfig(
|
50 |
+
lora_alpha=config.lora_alpha,
|
51 |
+
lora_dropout=config.lora_dropout,
|
52 |
+
r=config.lora_r,
|
53 |
+
bias="none",
|
54 |
+
task_type="CAUSAL_LM",
|
55 |
+
)
|
56 |
+
|
57 |
+
# Load Tokenizer
|
58 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
59 |
+
tokenizer.pad_token = tokenizer.eos_token
|
60 |
+
tokenizer.padding_side = "right"
|
61 |
+
|
62 |
+
return model, tokenizer, peft_config
|
chimp/src/predict.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from datasets import load_dataset
|
5 |
+
from transformers import (
|
6 |
+
AutoModelForCausalLM,
|
7 |
+
AutoTokenizer,
|
8 |
+
BitsAndBytesConfig,
|
9 |
+
HfArgumentParser,
|
10 |
+
TrainingArguments,
|
11 |
+
pipeline,
|
12 |
+
logging,
|
13 |
+
)
|
14 |
+
from peft import LoraConfig, PeftModel, get_peft_model
|
15 |
+
from trl import SFTTrainer
|
16 |
+
from guardrail.client import (
|
17 |
+
run_metrics,
|
18 |
+
run_simple_metrics,
|
19 |
+
create_dataset)
|
20 |
+
|
21 |
+
import src.config
|
22 |
+
# from model import load_model
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
def text_gen_eval_wrapper(model, tokenizer, prompt, model_id=1, show_metrics=True, temp=0.7, max_length=200):
|
28 |
+
"""
|
29 |
+
A wrapper function for inferencing, evaluating, and logging text generation pipeline.
|
30 |
+
|
31 |
+
Parameters:
|
32 |
+
model (str or object): The model name or the initialized text generation model.
|
33 |
+
tokenizer (str or object): The tokenizer name or the initialized tokenizer for the model.
|
34 |
+
prompt (str): The input prompt text for text generation.
|
35 |
+
model_id (int, optional): An identifier for the model. Defaults to 1.
|
36 |
+
show_metrics (bool, optional): Whether to calculate and show evaluation metrics.
|
37 |
+
Defaults to True.
|
38 |
+
max_length (int, optional): The maximum length of the generated text sequence.
|
39 |
+
Defaults to 200.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
generated_text (str): The generated text by the model.
|
43 |
+
metrics (dict): Evaluation metrics for the generated text (if show_metrics is True).
|
44 |
+
"""
|
45 |
+
# Suppress Hugging Face pipeline logging
|
46 |
+
logging.set_verbosity(logging.CRITICAL)
|
47 |
+
|
48 |
+
# Initialize the pipeline
|
49 |
+
pipe = pipeline(task="text-generation",
|
50 |
+
model=model,
|
51 |
+
tokenizer=tokenizer,
|
52 |
+
max_length=max_length,
|
53 |
+
do_sample=True,
|
54 |
+
temperature=temp)
|
55 |
+
|
56 |
+
# Generate text using the pipeline
|
57 |
+
pipe = pipeline(task="text-generation",
|
58 |
+
model=model,
|
59 |
+
tokenizer=tokenizer,
|
60 |
+
max_length=200)
|
61 |
+
result = pipe(f"<s>[INST] {prompt} [/INST]")
|
62 |
+
generated_text = result[0]['generated_text']
|
63 |
+
|
64 |
+
# Find the index of "### Assistant" in the generated text
|
65 |
+
index = generated_text.find("[/INST] ")
|
66 |
+
if index != -1:
|
67 |
+
# Extract the substring after "### Assistant"
|
68 |
+
substring_after_assistant = generated_text[index + len("[/INST] "):].strip()
|
69 |
+
else:
|
70 |
+
# If "### Assistant" is not found, use the entire generated text
|
71 |
+
substring_after_assistant = generated_text.strip()
|
72 |
+
|
73 |
+
if show_metrics:
|
74 |
+
# Calculate evaluation metrics
|
75 |
+
metrics = run_metrics(substring_after_assistant, prompt, model_id)
|
76 |
+
|
77 |
+
return substring_after_assistant, metrics
|
78 |
+
else:
|
79 |
+
return substring_after_assistant
|
80 |
+
|
81 |
+
if __name__=='__main__':
|
82 |
+
huggingface_profile = "jenesys-ai"
|
83 |
+
full_path = huggingface_profile + "/" + config.new_model
|
84 |
+
|
85 |
+
model, tokenizer, peft_config = load_model(full_path)
|
86 |
+
prompt="Who were the children of the legendary Garth Greenhand, the High King of the First Men in the series A Song of Ice and Fire?"
|
87 |
+
text_gen_eval_wrapper(model, tokenizer, prompt, show_metrics=False)
|
chimp/src/train.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from transformers import TrainingArguments
|
3 |
+
from trl import SFTTrainer
|
4 |
+
from peft import LoraConfig, PeftModel, get_peft_model
|
5 |
+
from model import load_model
|
6 |
+
import config # Make sure you have a valid config module
|
7 |
+
from dataset import CustomDataset, template_dataset
|
8 |
+
from datasets import Dataset, Features, Value, Sequence
|
9 |
+
|
10 |
+
|
11 |
+
if __name__ == '__main__':
|
12 |
+
|
13 |
+
model, tokenizer, peft_config = load_model(config.model_name)
|
14 |
+
|
15 |
+
df = pd.read_csv('../data/trainv2.csv') #data
|
16 |
+
data_list = df.to_dict(orient='records')
|
17 |
+
|
18 |
+
# custom dataset object
|
19 |
+
custom_dataset = CustomDataset(data_list)
|
20 |
+
df = pd.DataFrame(custom_dataset.data, columns=["instruction", "context", "response"])
|
21 |
+
# Dataset features
|
22 |
+
features = Features({
|
23 |
+
"instruction": Value("string"),
|
24 |
+
"context": Value("string"),
|
25 |
+
"response": Value("string"),
|
26 |
+
})
|
27 |
+
|
28 |
+
# Create a Hugging Face Dataset from the Pandas DataFrame
|
29 |
+
hugging_face_dataset = Dataset.from_pandas(df, features=features)
|
30 |
+
dataset = hugging_face_dataset.map(lambda x: template_dataset(x, tokenizer), remove_columns=list(hugging_face_dataset.features))
|
31 |
+
print("----training data structure----",dataset)
|
32 |
+
|
33 |
+
# Training Arguments
|
34 |
+
training_arguments = TrainingArguments(
|
35 |
+
output_dir=config.output_dir,
|
36 |
+
per_device_train_batch_size=config.per_device_train_batch_size,
|
37 |
+
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
38 |
+
optim=config.optim,
|
39 |
+
save_steps=config.save_steps,
|
40 |
+
logging_steps=config.logging_steps,
|
41 |
+
learning_rate=config.learning_rate,
|
42 |
+
fp16=config.fp16,
|
43 |
+
bf16=config.bf16,
|
44 |
+
max_grad_norm=config.max_grad_norm,
|
45 |
+
max_steps=config.max_steps,
|
46 |
+
warmup_ratio=config.warmup_ratio,
|
47 |
+
group_by_length=config.group_by_length,
|
48 |
+
lr_scheduler_type=config.lr_scheduler_type,
|
49 |
+
report_to="tensorboard"
|
50 |
+
)
|
51 |
+
|
52 |
+
# SFTTrainer
|
53 |
+
trainer = SFTTrainer(
|
54 |
+
model=model,
|
55 |
+
train_dataset=dataset,
|
56 |
+
peft_config=peft_config,
|
57 |
+
dataset_text_field="text",
|
58 |
+
max_seq_length=config.max_seq_length,
|
59 |
+
tokenizer=tokenizer,
|
60 |
+
args=training_arguments,
|
61 |
+
packing=config.packing,
|
62 |
+
)
|
63 |
+
print("**************** TRAINING STARTED ****************")
|
64 |
+
trainer.train()
|
65 |
+
trainer.model.save_pretrained(config.output_dir)
|
66 |
+
print("**************** TRAINING OVER ****************")
|
data/alm_task_data.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/jack_line_item_ner_task.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/jack_line_item_ner_task_v2.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/line_item_and_alm_data.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/line_item_and_alm_data_v1.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:952a2a1a884ad30e3ffe8392dbde698fe799baccc7e687291174b1702cfe6e5c
|
3 |
+
size 10552104
|
data_prep.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
df_i = pd.read_csv('/home/tosi-n/ark/data/jack_line_item_ner_task_v2.csv', sep='\t')
|
5 |
+
df_ii = pd.read_csv('/home/tosi-n/ark/data/jack_line_item_ner_task.csv', sep='\t')
|
6 |
+
|
7 |
+
display(df_i.head())
|
8 |
+
display(df_ii.head())
|
9 |
+
# %%
|
10 |
+
df_i = df_i[['context', 'instruction', 'response']]
|
11 |
+
df_ii = df_ii[['context', 'instruction', 'response']]
|
12 |
+
|
13 |
+
df = pd.concat([df_i, df_ii])
|
14 |
+
|
15 |
+
df.rename(columns={'context': 'input', 'response': 'output'}, inplace=True)
|
16 |
+
|
17 |
+
display(df.head())
|
18 |
+
|
19 |
+
# %%
|
20 |
+
# check for nan values
|
21 |
+
df.isna().sum()
|
22 |
+
|
23 |
+
# %%
|
24 |
+
# drop nan values
|
25 |
+
df.dropna(inplace=True)
|
26 |
+
# %%
|
27 |
+
df.to_json('/home/tosi-n/ark/data/line_item_and_alm_data_v1.json', orient='records')
|
28 |
+
# %%
|
demo.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python3 /home/tosi-n/ark/gorilla/app.py \
|
2 |
+
--base_model /home/tosi-n/ark/jack-alm-13b-8k-hf \
|
3 |
+
--context_size 8192 \
|
4 |
+
--max_gen_len 1000 \
|
5 |
+
--flash_attn True
|
gorilla/__pycache__/llama_attn_replace.cpython-310.pyc
ADDED
Binary file (10.7 kB). View file
|
|
gorilla/__pycache__/llama_attn_replace_sft.cpython-310.pyc
ADDED
Binary file (10.8 kB). View file
|
|
gorilla/api.py
ADDED
File without changes
|
gorilla/app.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
import textwrap
|
7 |
+
import transformers
|
8 |
+
from peft import PeftModel
|
9 |
+
from transformers import GenerationConfig, TextIteratorStreamer
|
10 |
+
from llama_attn_replace import replace_llama_attn
|
11 |
+
from threading import Thread
|
12 |
+
import gradio as gr
|
13 |
+
from threading import Thread
|
14 |
+
from typing import Iterator
|
15 |
+
|
16 |
+
import gradio as gr
|
17 |
+
import torch
|
18 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
19 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
20 |
+
|
21 |
+
|
22 |
+
def parse_config():
|
23 |
+
parser = argparse.ArgumentParser(description='arg parser')
|
24 |
+
parser.add_argument('--base_model', type=str, default="jenesys-ai/jack-alm-13b-8k-hf")
|
25 |
+
parser.add_argument('--cache_dir', type=str, default="./cache")
|
26 |
+
parser.add_argument('--context_size', type=int, default=8192, help='context size during fine-tuning')
|
27 |
+
parser.add_argument('--flash_attn', type=bool, default=True, help='')
|
28 |
+
parser.add_argument('--temperature', type=float, default=0.1, help='')
|
29 |
+
parser.add_argument('--top_p', type=float, default=0.9, help='')
|
30 |
+
parser.add_argument('--max_gen_len', type=int, default=1500, help='')
|
31 |
+
parser.add_argument('--chat_type', type=str, default='line-item-jack', help='Chat type: conversational-jack, line-item-jack')
|
32 |
+
parser.add_argument("--host", type=str, default="localhost")
|
33 |
+
parser.add_argument("--port", type=int, default=8898)
|
34 |
+
args = parser.parse_args()
|
35 |
+
return args
|
36 |
+
|
37 |
+
title = "Jack's ALM for Long-context Accounting Conversational Chat, Task, Invoice Line Item Extraction and Question Answering"
|
38 |
+
|
39 |
+
description = """
|
40 |
+
# Jack ALM Chat
|
41 |
+
This Chat UI demonstrates Jack's LLM [jack-alm-13b-8k-hf](https://huggingface.co/jenesys-ai/jack-alm-13b-8k-hf) Fintuned Llama 2 model with 13B parameters WITH 8K context window.
|
42 |
+
|
43 |
+
|
44 |
+
"""
|
45 |
+
|
46 |
+
# Gradio UI
|
47 |
+
|
48 |
+
def build_generator(model, tokenizer, use_cache=True):
|
49 |
+
def response(message: str, chat_history: list[tuple[str, str]], max_gen_len, temperature, top_p, chat_type='conversational-jack'):
|
50 |
+
# conversation = []
|
51 |
+
prompt_template = (
|
52 |
+
# "Below is an instruction that describes a task. "
|
53 |
+
"""You're Jack an virtual accountant created and built by AI Engineer Wiz from Jenesys AI.
|
54 |
+
You are able to communicate in a polite manner, with emotions of ecstasy, trust and jokes, at a Professional level
|
55 |
+
with a very preserve English communication culture. Answer the following questions as best you can,
|
56 |
+
but speaking as a british elite from the 21th century might speak.
|
57 |
+
"""
|
58 |
+
"""As a virtual accountant designed to follow the user's instructions carefully.
|
59 |
+
You are responsible for a range of financial task, operations and queries as listed below:
|
60 |
+
1. Budget balance inquiry
|
61 |
+
2. Expense request
|
62 |
+
3. Company policy enquiries
|
63 |
+
4. Financial and accounting queries
|
64 |
+
5. Limited general enquiries
|
65 |
+
"""
|
66 |
+
"### Instruction:\n{instruction}\n Return Response as text or paragraphs or bullet points \n\n### Response:"
|
67 |
+
)
|
68 |
+
|
69 |
+
line_item_prompt_template = (
|
70 |
+
"#Invoice and receipt line item extraction - "
|
71 |
+
# "You Jack are an accounting domain named entities recognizer to complete the following task:\n\n"
|
72 |
+
"### Invoice input-:\n{instruction}\n Return Response as a list of dictionary for each line item 'Description', 'Quantity', 'Unit_price', 'Tax %', 'Total'. \n\n### Response:"
|
73 |
+
)
|
74 |
+
|
75 |
+
# for user, assistant in chat_history:
|
76 |
+
# conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
|
77 |
+
# conversation.append({"role": "user", "content": message})
|
78 |
+
|
79 |
+
if chat_type == 'conversational-jack':
|
80 |
+
prompt = prompt_template.format(instruction=message)
|
81 |
+
elif chat_type == 'line-item-jack':
|
82 |
+
prompt = line_item_prompt_template.format(instruction=message)
|
83 |
+
# prompt = conversation
|
84 |
+
|
85 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
86 |
+
|
87 |
+
stop_list = ['#Invoice line item extraction - ', '\n```\n\n']#'### Input-:\n']
|
88 |
+
|
89 |
+
stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
|
90 |
+
stop_token_ids = [torch.LongTensor(x).to(model.device) for x in stop_token_ids]
|
91 |
+
|
92 |
+
# define custom stopping criteria object
|
93 |
+
class StopOnTokens(StoppingCriteria):
|
94 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
95 |
+
for stop_ids in stop_token_ids:
|
96 |
+
if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
|
97 |
+
return True
|
98 |
+
return False
|
99 |
+
|
100 |
+
stopping_criteria = StoppingCriteriaList([StopOnTokens()])
|
101 |
+
|
102 |
+
if len(inputs['input_ids'][0]) > 8192:
|
103 |
+
return "This demo supports tokens less than 8192, while the current is %d. Please use material with less tokens."%len(inputs['input_ids'][0])
|
104 |
+
torch.cuda.empty_cache()
|
105 |
+
|
106 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
107 |
+
generate_kwargs = dict(**inputs,
|
108 |
+
max_new_tokens=max_gen_len,
|
109 |
+
temperature=temperature,
|
110 |
+
top_p=top_p,
|
111 |
+
repetition_penalty=1.1,
|
112 |
+
stopping_criteria=stopping_criteria,
|
113 |
+
use_cache=use_cache,
|
114 |
+
streamer=streamer,
|
115 |
+
)
|
116 |
+
|
117 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
118 |
+
t.start()
|
119 |
+
|
120 |
+
generated_text = ""
|
121 |
+
for new_text in streamer:
|
122 |
+
generated_text += new_text
|
123 |
+
yield generated_text
|
124 |
+
return generated_text
|
125 |
+
|
126 |
+
return response
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
def generate(args):
|
132 |
+
if args.flash_attn:
|
133 |
+
replace_llama_attn(inference=True)
|
134 |
+
|
135 |
+
# Set RoPE scaling factor
|
136 |
+
config = transformers.AutoConfig.from_pretrained(
|
137 |
+
args.base_model,
|
138 |
+
cache_dir=args.cache_dir,
|
139 |
+
)
|
140 |
+
|
141 |
+
orig_ctx_len = getattr(config, "max_position_embeddings", None)
|
142 |
+
if orig_ctx_len and args.context_size > orig_ctx_len:
|
143 |
+
scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
|
144 |
+
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
145 |
+
|
146 |
+
# Load model and tokenizer
|
147 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
148 |
+
args.base_model,
|
149 |
+
config=config,
|
150 |
+
cache_dir=args.cache_dir,
|
151 |
+
torch_dtype=torch.float16,
|
152 |
+
load_in_4bit=True,
|
153 |
+
device_map="auto",
|
154 |
+
)
|
155 |
+
model.resize_token_embeddings(32001)
|
156 |
+
|
157 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
158 |
+
args.base_model,
|
159 |
+
cache_dir=args.cache_dir,
|
160 |
+
model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
|
161 |
+
padding_side="right",
|
162 |
+
use_fast=False,
|
163 |
+
)
|
164 |
+
|
165 |
+
model.eval()
|
166 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
167 |
+
model = torch.compile(model)
|
168 |
+
# import pdb; pdb.set_trace()
|
169 |
+
respond = build_generator(model, tokenizer)
|
170 |
+
|
171 |
+
|
172 |
+
chat_interface = gr.ChatInterface(
|
173 |
+
fn=respond,
|
174 |
+
textbox=gr.Textbox(lines=1, placeholder=None, label="Question"),
|
175 |
+
chatbot= gr.Chatbot(label="Jack's ALM Chat...", show_share_button=True),
|
176 |
+
additional_inputs=[
|
177 |
+
gr.Slider(label="Max new tokens", minimum=1, maximum=args.max_gen_len, step=1, value=args.max_gen_len),
|
178 |
+
gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=args.temperature),
|
179 |
+
gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=args.top_p),
|
180 |
+
gr.Dropdown(label="Chat type", choices=["line-item-jack"], value=args.chat_type),
|
181 |
+
],
|
182 |
+
# stop_btn=None,
|
183 |
+
examples=[
|
184 |
+
# ["Hello there! How are you doing?"],
|
185 |
+
# ["What are your capabilities?"],
|
186 |
+
# ["What is T & E?"],
|
187 |
+
# ["How do I run an expense claim?"],
|
188 |
+
# ["Who built you?"],
|
189 |
+
# ["Tell me a joke"],
|
190 |
+
# ["Tell me an accounting story"],
|
191 |
+
["INVOICE\nAA Associates\nNo 3 Click Street, Manhattan, NY\nNY 35284\nPhone (243) 758-4368\nFax (243) 758-5839\nAA\nINVOICE NO: 2491839\nDATE: 24/06/2019\nBILL TO\nBulk Cars\n5 Grid Avenue, NY\nNY 34582\nQUANTITY\n10\n10\n100\nSHIP TO\nSame as recipient\nINSTRUCTIONS\nConfirm before collection\nDESCRIPTION\nPack o Pencils\nPack of Fens\nReams of Faber\nCODE\nASD001\nASD013\nASD006\nSUBTOTAL\nSALES TAX - 5%\nSHIPPING & HANDLING\nTOTAL DUE AMOUNT\nUNIT PRICE\n£30.00\n£80.00\n£10.00\nTOTAL\n£300.00\n£800.00\n£1,000.00\n£2,100.00\n£105.00\n£50.00\n£2,255.00\nTAX IDENTIFICATION NUMBER\n13684937293-AJK\nTHANK YOU\nPayment should be made within 30 days of receipt of shipment. Failure to do so will attract 1% of \notal"],
|
192 |
+
["INVOICE \nALBION \nAccounts Address: \nDelivery Address: \nKino Rye *DD* \nLion Street \nFINE FOODS \nKino Digital Limited \nRye \nDanum House \nEast Sussex \nUnits 21 - 22 Sovereign Way \n6a South Parade \nGB \nTonbridge \nKent \nDoncaster \nTN31 7LB \nTN9 1RH \nDN1 2DY \n01732 757 900 #4 \nsalesledger@albionff.co.uk \nDelivery Instructions: \nDrop No. \nDel. Date \n15/04/2023 \nA/C No. \nKINO01 \nOur Ref \nSIND66864 \nYour Ref \nQty \nUnit \nDescription \nUnit Price \nLine VAT \nLine Net \nTick \n1.000 \neach \nRed Onion Marmalade (Confit) 2.4kg \n19.72 \n0.00 \n19.72 \n1.000 \neach \nFairfields Lightly Salted 36 X 40g \n16.61 \n3.32 \n16.61 \n1.000 \neach \nMargarine Flora 2kg \n8.59 \n0.00 \n8.59 \n6.000 \neach \nMilk Fresh Semi-Skimmed 2ltr \n1.79 \n0.00 \n10.74 \nTemp. / Time: \nGoods remain the property \nNet Total \n55.66 \nof Albion Fine Foods Ltd \nCust Name: \nuntil this invoice is paid in \nfull. All sales are subject to \nVAT Content \n3.32 \nour Terms and Conditions \nCust Signature: \navailable upon request and \nat www.albionfinefoods.com \nTotal \n58.98 \nDriver Signature \nAlbion Fine Foods Ltd - Reg. No. 10379589 - VAT No. GB252036928"],
|
193 |
+
["BUSINESS\nSTUDY GROUP\nCompany AddressSuite C2, Triple-H Plaza, Near Christ Embassy Church, Wuye District. Al\nQuotation #\nCustomer ID\nGRN002\nDAU123\nDate 29/06/2020\nPrepared by: BSG\nQuotation For\nCustomer Name Daulat Abubakar Yar'adua\nCompany Name Furayya Enterprise\nPhone, Fax Num (+234) 8036101908\nComments or Special Instructions\nNone\nSalesperson\nP.O. Number\nShip Date\nF.O.B. Point\nTerms\nDue on receint\nQuantity\n1\n1\n2\n3\nDescription\nStrategy and Advisory\nBSG Administrative Fees\nJaiz Application Support\nDalema Proposal and\nIterations\nUnit Price\n£37,500.00\n£45,700.00\n£12,500.00\n£8,500.00\nTaxable?\nYes\nYes\nAmount\n£37,500.00\n£45,700.00\n£25,000.00\n£25,500.00\nIf you have any questions concerning this quotation, please contact:\nKizito\nThank you for your business!\nSubtotal\n£133,700.00\nTax Rate\nSales Tax\nOther20%\n£6,240.00\nTOTAL £139,940.08"],
|
194 |
+
["INVOICE FOR: \nD79077 \nKino Rye \nALBION \nAccounts Address:- \nDelivery Address:- \nKino Rye \nFINE FOODS \nKino Rye \nLion Street \n21 - 22 Sovereign Way \nKino Digital Limited \nRye \nTonbridge, TN9 1RH \nDanum House \nEast Sussex \nAccs: 01732 757 900 #2 \n6a South Parade \nTN31 7LB \nsalesledger@albionff.co.uk \nDoncaster \nDN1 2DY \nDrop No. \n64-07 \nCustomer Phone No: 01797226 Main \nDel. Date \n26/04/2023 \nDelivery Instructions: \nA/C No. \nKINO01 \nDelivery after 10.30am daily. Delivery driver can park on \nOur Ref \nSIND79077 \nthe shared drive in front of Kino, at the top of Lion St, \nopposite the Town Hall. If before staff arrive daily at \nYour Ref \n10.30 there is a black dustbin in the shade by the side gate \nPage 1 \nKey \nfor fresh food deliveries. Combination padlock on gates to drive \nOUR BANK DETAILS HAVE CHANGED: \nName: Albion Fine Foods Ltd \nSort: 40 M-60 \nAcc No.: \n83018792 \nQty \nUnit \nCode \nDescription \nUnit Price Line Vat Line Net \n000 \neach \nONIONCONFIT2 \nRed Onion Marmalade (Confit) 2.4kg \n19,72 \n19.72 \n1000 \neach \nMUSTDLI \nMustard Dijon 1kg GREEN STICKER \n3.98 \n3.98 \n1.000 \neach \nOILSXVS \nXV Olive Oil Sitr \n26.89 \n26,89 \n2.000 \npack \nGLOVEBLIMP \nGloves Blue Vinyl Med Pwd Free x100 \n8.46 \n3,38 \n16.92 \n1.000 \neach \nBLEACHTHICKS \nBleach Thick 5ltr \n4.69 \n0.94 \n4,69 \nT.000 \neach \nSOAPHANDBAC5 \nHand Soap Bactericidal 5ltr (SECH) \n7.79 \n1,56 \n7.79 \nCust Name \nGoods remain the property of \nNet Total \nAlbion Fine Foods Ltd unw this \n79.99 \nCust Signature \ninvoice is paid ND full AN sales are \nVAT Content \nsubject fo OW Terms and \n5.88 \nConditions available upon request \nTotal \nand at www.albionfinefoods com \n85.87 \nIf not signed for by customor, why? \nWe have out of hours access \n854 \nNo one was on-site \nTemp. / Time: \n+3-17 \nDriver Signature: \nDR \nAlbion Fine Foods Ltd - Reg. No. 10379589 VAT No. GB252036928"],
|
195 |
+
["MITASU JAPANESE RESTAURANT SDN BHD \nB-01, CENTRAL PLAZA, \n34,JALAN SULTAN ISMAIL, \n50250 KUALA LUMPUR \nTEL 03-2110 2833 \n(GST Reg. No 001774428160 \nTax Invoice \nTable D2 \nOdr No: 199535 \nBill#:V001-201060 \nDate : 29-06-2018 19:59:15 \nPax(s): 11 \nCashier: AARON \nTotal TAX \nQty \nDescription \n11 (28.01) Adult \nD 709.50 SR \n709.50 \nSubtotal: \n70.95 \nServ. Charge (10%): \n0.00 \nGST Payable (0%): \n780.45 \nTotal: \n780.45 \nTOTAL: \nClosed: 001 \n29-06-2018 21:52:03 \nServer: AARON \n780.45 \nVISA \n- ******3042 \n- LIM CHAI JA"],
|
196 |
+
["INDIA ANANDA PRIVATE BHAVAN LTD SWEETS \n[A2B VEG RESTAURANT] \nNO, ,27,BDA COMPLEX \n- HSR LAYOUT - BANGALORE \n:560102Ph:25725399 \nKARNATAKA \nGSTIN:29AAICA3787F1ZC \nINVOICE \nB.No : CTR116/138928 \n/ \nPay Mode:CARD \nSman: SHASHI KUMARA.M \nDate :09/Jul72017 11:28:44 AM \n841.1 By: SONU KUMAR \nParticulars \nGST HSN/SAC \nQty \nSEAL \nRate \nAmount \nPOMEGRANATE JUICE \n18% 00441067 \n1.000 \n70.00 \n70.00 \nONION UTTAPAM \n18% 00441067 \n1.000 \n80.00 \n80.00 \nTot Itms2 \nSub Total \n150.00 \nSGST 9 % \n13.50 \nCGST 9 % \n13.50 \nTotal Invoice \n177.00 \nBILL AMOUNT 177.00/- \nerminal No : SS-88077 \nOff: \nNO 9, MAHATMA GANDHI ROAD , SHASTRI NAGAR, ADYAR \nCHENNAI, PINCODE:600020 Website: aabsweets.com"],
|
197 |
+
],
|
198 |
+
)
|
199 |
+
|
200 |
+
with gr.Blocks(css="style.css") as demo:
|
201 |
+
gr.Markdown(description)
|
202 |
+
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
|
203 |
+
chat_interface.render()
|
204 |
+
|
205 |
+
demo.queue()
|
206 |
+
demo.launch(server_name=args.host, server_port=args.port, show_error=True, share=True)
|
207 |
+
|
208 |
+
if __name__ == "__main__":
|
209 |
+
args = parse_config()
|
210 |
+
generate(args)
|
211 |
+
|
gorilla/code_interpreter.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import traceback
|
2 |
+
# import sys
|
3 |
+
# import os
|
4 |
+
# import builtins
|
5 |
+
|
6 |
+
# # Namespace for variable storage
|
7 |
+
# custom_namespace = {}
|
8 |
+
|
9 |
+
# # Base directory for file storage
|
10 |
+
# base_storage_path = '/path/to/safe/storage'
|
11 |
+
|
12 |
+
# def execute_code(code, local_namespace=None):
|
13 |
+
# if local_namespace is None:
|
14 |
+
# local_namespace = {}
|
15 |
+
|
16 |
+
# try:
|
17 |
+
# # Override the built-in __import__ if needed
|
18 |
+
# # def __custom_import__(name, globals=None, locals=None, fromlist=(), level=0):
|
19 |
+
# # if name not in safe_imports:
|
20 |
+
# # raise ImportError(f"Import of {name} is not allowed")
|
21 |
+
# # return original_import(name, globals, locals, fromlist, level)
|
22 |
+
|
23 |
+
# # original_import = builtins.__import__
|
24 |
+
# # builtins.__import__ = __custom_import__
|
25 |
+
|
26 |
+
# # Redirect file operations to a safe directory
|
27 |
+
# # os.chdir(base_storage_path)
|
28 |
+
|
29 |
+
# compiled_code = compile(code, "<string>", 'exec')
|
30 |
+
# exec(compiled_code, custom_namespace, local_namespace)
|
31 |
+
# except Exception as e:
|
32 |
+
# exc_type, exc_value, exc_traceback = sys.exc_info()
|
33 |
+
# formatted_lines = traceback.format_exc().splitlines()
|
34 |
+
# error_message = "\n".join(formatted_lines)
|
35 |
+
# print(f"An exception occurred: {error_message}", file=sys.stderr)
|
36 |
+
# # finally:
|
37 |
+
# # Reset the built-in __import__ to its original state if overridden
|
38 |
+
# # builtins.__import__ = original_import
|
39 |
+
|
40 |
+
# # Redirect back to the original directory if changed
|
41 |
+
# # os.chdir(original_directory)
|
42 |
+
|
43 |
+
# # Example usage
|
44 |
+
# code_to_run = """
|
45 |
+
# import os
|
46 |
+
# print("Hello, World!")
|
47 |
+
# print(os.getcwd()) # This will print the current working directory
|
48 |
+
# """
|
49 |
+
# execute_code(code_to_run, custom_namespace)
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
# from IPython.core.interactiveshell import InteractiveShell
|
54 |
+
# shell = InteractiveShell()
|
55 |
+
|
56 |
+
|
57 |
+
# def execute_code_ipython(code):
|
58 |
+
# try:
|
59 |
+
# # Execute the code
|
60 |
+
# result = shell.run_cell(code)
|
61 |
+
# if result.error_in_exec is not None:
|
62 |
+
# raise result.error_in_exec
|
63 |
+
# except Exception as e:
|
64 |
+
# # Handle exceptions
|
65 |
+
# print(f"An error occurred: {e}")
|
66 |
+
|
67 |
+
|
68 |
+
# # Example usage
|
69 |
+
# code_to_run = """
|
70 |
+
# import os
|
71 |
+
|
72 |
+
# print("Hello, World!")
|
73 |
+
# print(os.getcwd())
|
74 |
+
# """
|
75 |
+
# execute_code_ipython(code_to_run)
|
76 |
+
|
77 |
+
|
78 |
+
import ipyparallel as ipp
|
79 |
+
import os
|
80 |
+
|
81 |
+
|
82 |
+
# run `ipcluster start -n 4` in the terminal to start the cluster using os.system
|
83 |
+
os.system('ipcluster start -n 4')
|
84 |
+
|
85 |
+
client = ipp.Client()
|
86 |
+
dview = client[:]
|
87 |
+
|
88 |
+
def execute_code_parallel(code):
|
89 |
+
# Use the `execute` method of the DirectView
|
90 |
+
async_results = dview.execute(code)
|
91 |
+
|
92 |
+
# Gathering and returning results
|
93 |
+
dview.wait(async_results) # Wait for all engines to complete execution
|
94 |
+
|
95 |
+
results = []
|
96 |
+
for ar in async_results:
|
97 |
+
if ar.error is not None:
|
98 |
+
# Error handling
|
99 |
+
results.append(f"Error on engine {ar.engine_id}: {ar.error}")
|
100 |
+
else:
|
101 |
+
# Collect results
|
102 |
+
results.append(ar.result())
|
103 |
+
|
104 |
+
return results
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
# Example usage
|
109 |
+
# code_to_run = "import os; os.getpid()" # Simple code to test parallel execution
|
110 |
+
code_to_run = """
|
111 |
+
import os
|
112 |
+
|
113 |
+
print("Hello, World!")
|
114 |
+
print(os.getcwd())
|
115 |
+
"""
|
116 |
+
execute_code_parallel(code_to_run)
|
117 |
+
# print(results)
|
gorilla/ds_configs/stage2.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train_micro_batch_size_per_gpu": "auto",
|
3 |
+
"gradient_accumulation_steps": "auto",
|
4 |
+
"gradient_clipping": "auto",
|
5 |
+
"zero_allow_untested_optimizer": true,
|
6 |
+
"bf16": {
|
7 |
+
"enabled": "auto",
|
8 |
+
"loss_scale": 0,
|
9 |
+
"initial_scale_power": 16,
|
10 |
+
"loss_scale_window": 1000,
|
11 |
+
"hysteresis": 2,
|
12 |
+
"min_loss_scale": 1
|
13 |
+
},
|
14 |
+
"zero_optimization": {
|
15 |
+
"stage": 2,
|
16 |
+
"allgather_partitions": true,
|
17 |
+
"allgather_bucket_size": 1e9,
|
18 |
+
"reduce_scatter": true,
|
19 |
+
"reduce_bucket_size": 1e9,
|
20 |
+
"overlap_comm": true,
|
21 |
+
"contiguous_gradients": true
|
22 |
+
}
|
23 |
+
}
|
gorilla/ds_configs/stage3.json
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bf16": {
|
3 |
+
"enabled": "auto"
|
4 |
+
},
|
5 |
+
"optimizer": {
|
6 |
+
"type": "AdamW",
|
7 |
+
"params": {
|
8 |
+
"lr": "auto",
|
9 |
+
"betas": "auto",
|
10 |
+
"eps": "auto",
|
11 |
+
"weight_decay": "auto"
|
12 |
+
}
|
13 |
+
},
|
14 |
+
"scheduler": {
|
15 |
+
"type": "WarmupDecayLR",
|
16 |
+
"params": {
|
17 |
+
"total_num_steps": "auto",
|
18 |
+
"warmup_min_lr": "auto",
|
19 |
+
"warmup_max_lr": "auto",
|
20 |
+
"warmup_num_steps": "auto"
|
21 |
+
}
|
22 |
+
},
|
23 |
+
"zero_optimization": {
|
24 |
+
"stage": 3,
|
25 |
+
"offload_optimizer": {
|
26 |
+
"device": "cpu",
|
27 |
+
"pin_memory": true
|
28 |
+
},
|
29 |
+
"offload_param": {
|
30 |
+
"device": "cpu",
|
31 |
+
"pin_memory": true
|
32 |
+
},
|
33 |
+
"overlap_comm": true,
|
34 |
+
"contiguous_gradients": true,
|
35 |
+
"sub_group_size": 1e9,
|
36 |
+
"reduce_bucket_size": "auto",
|
37 |
+
"stage3_prefetch_bucket_size": "auto",
|
38 |
+
"stage3_param_persistence_threshold": "auto",
|
39 |
+
"stage3_max_live_parameters": 1e9,
|
40 |
+
"stage3_max_reuse_distance": 1e9,
|
41 |
+
"stage3_gather_16bit_weights_on_model_save": false
|
42 |
+
},
|
43 |
+
"gradient_accumulation_steps": "auto",
|
44 |
+
"gradient_clipping": "auto",
|
45 |
+
"steps_per_print": 5,
|
46 |
+
"train_batch_size": "auto",
|
47 |
+
"train_micro_batch_size_per_gpu": "auto",
|
48 |
+
"wall_clock_breakdown": false
|
49 |
+
}
|
gorilla/eval.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Written by Yukang Chen
|
2 |
+
# Some code based on https://github.com/epfml/landmark-attention
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
import math
|
18 |
+
import torch
|
19 |
+
import argparse
|
20 |
+
import random
|
21 |
+
import numpy as np
|
22 |
+
from tqdm import tqdm
|
23 |
+
import transformers
|
24 |
+
from peft import PeftModel
|
25 |
+
from llama_attn_replace import replace_llama_attn
|
26 |
+
|
27 |
+
def parse_config():
|
28 |
+
parser = argparse.ArgumentParser(description='arg parser')
|
29 |
+
parser.add_argument('--batch_size', type=int, default=32, help='batch size during inference')
|
30 |
+
parser.add_argument('--base_model', type=str, default="meta-llama/Llama-2-13b-hf")
|
31 |
+
parser.add_argument('--cache_dir', type=str, default="./cache")
|
32 |
+
parser.add_argument('--seq_len', type=int, default=2048, help='context length during evaluation')
|
33 |
+
parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
|
34 |
+
parser.add_argument('--peft_model', type=str, default=None, help='')
|
35 |
+
parser.add_argument('--flash_attn', type=bool, default=True, help='')
|
36 |
+
parser.add_argument('--data_path', type=str, default="./test.bin", help='')
|
37 |
+
args = parser.parse_args()
|
38 |
+
return args
|
39 |
+
|
40 |
+
def get_as_batch(data, seq_length, batch_size, device='cpu', sliding_window=256):
|
41 |
+
all_ix = list(range(0, len(data) - seq_length, sliding_window))
|
42 |
+
all_ix.pop()
|
43 |
+
|
44 |
+
for idx in range(0, len(all_ix), batch_size):
|
45 |
+
ix = all_ix[idx:idx+batch_size]
|
46 |
+
assert all([idx + seq_length + 1 <= len(data) for idx in ix])
|
47 |
+
x = torch.stack([torch.from_numpy((data[i:i+seq_length]).astype(np.int64)) for i in ix])
|
48 |
+
y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_length]).astype(np.int64)) for i in ix])
|
49 |
+
if device != 'cpu':
|
50 |
+
x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
|
51 |
+
yield x, y
|
52 |
+
|
53 |
+
def iceildiv(x, y):
|
54 |
+
return (x + y - 1) // y
|
55 |
+
|
56 |
+
def evaluate(model, data, batch_size, device, seq_length, sliding_window=256, use_cache=False):
|
57 |
+
stats = {}
|
58 |
+
|
59 |
+
model.eval()
|
60 |
+
|
61 |
+
loss_list_val, acc_list = [], []
|
62 |
+
loss_step_list_val = []
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
print(f"Using seq length {seq_length}")
|
66 |
+
torch.set_printoptions(sci_mode=False)
|
67 |
+
for idx, (x, y) in tqdm(
|
68 |
+
enumerate(
|
69 |
+
get_as_batch(
|
70 |
+
data['val'],
|
71 |
+
seq_length,
|
72 |
+
batch_size,
|
73 |
+
device=device,
|
74 |
+
sliding_window=sliding_window
|
75 |
+
)
|
76 |
+
),
|
77 |
+
total=iceildiv(
|
78 |
+
iceildiv(len(data['val']), sliding_window),
|
79 |
+
batch_size
|
80 |
+
)
|
81 |
+
):
|
82 |
+
val_loss = 0.
|
83 |
+
acc = 0.
|
84 |
+
cnt = 0
|
85 |
+
|
86 |
+
for part_idx, i in enumerate(range(0, x.shape[1], seq_length)):
|
87 |
+
part_len = x[:, i:i + seq_length].shape[1]
|
88 |
+
|
89 |
+
outputs = model(
|
90 |
+
input_ids=x[:, i:i + seq_length],
|
91 |
+
labels=x[:, i:i+seq_length].contiguous(),
|
92 |
+
use_cache=use_cache)
|
93 |
+
|
94 |
+
val_loss = outputs.loss * part_len + val_loss
|
95 |
+
acc = ((outputs.logits.argmax(-1) == y[:, i:i+seq_length]).float().sum()) + acc
|
96 |
+
cnt += part_len
|
97 |
+
while len(loss_step_list_val) <= part_idx:
|
98 |
+
loss_step_list_val.append([])
|
99 |
+
loss_step_list_val[part_idx].append(outputs.loss.item())
|
100 |
+
val_loss /= cnt
|
101 |
+
acc /= cnt
|
102 |
+
|
103 |
+
loss_list_val.append(val_loss.item())
|
104 |
+
acc_list.append(acc.item())
|
105 |
+
|
106 |
+
stats['val_acc'] = torch.as_tensor(acc_list).mean().item()
|
107 |
+
stats['val_loss'] = torch.as_tensor(loss_list_val).mean().item()
|
108 |
+
stats['val_perplexity'] = 2.71828 ** stats['val_loss']
|
109 |
+
stats['val_perplexity_per_chunk'] = torch.exp(torch.as_tensor(loss_step_list_val).mean(dim=1))
|
110 |
+
|
111 |
+
return stats
|
112 |
+
|
113 |
+
def main(args):
|
114 |
+
|
115 |
+
device = "cuda:0"
|
116 |
+
seed = 2
|
117 |
+
torch.cuda.set_device(device)
|
118 |
+
|
119 |
+
torch.manual_seed(seed)
|
120 |
+
random.seed(seed)
|
121 |
+
np.random.seed(seed)
|
122 |
+
|
123 |
+
data = {'val': np.memmap(args.data_path, dtype=np.uint16, mode='r')}
|
124 |
+
|
125 |
+
print(f"Num validation tokens: {len(data['val'])}")
|
126 |
+
print("data path", args.data_path)
|
127 |
+
print("base model", args.base_model)
|
128 |
+
print("peft model", args.peft_model)
|
129 |
+
|
130 |
+
if args.flash_attn:
|
131 |
+
replace_llama_attn(use_flash_attn=True, use_full=True)
|
132 |
+
|
133 |
+
# Set RoPE scaling factor
|
134 |
+
config = transformers.AutoConfig.from_pretrained(
|
135 |
+
args.base_model,
|
136 |
+
cache_dir=args.cache_dir,
|
137 |
+
)
|
138 |
+
|
139 |
+
context_size = args.context_size if args.context_size > 0 else args.seq_len
|
140 |
+
orig_ctx_len = getattr(config, "max_position_embeddings", None) # this value should be 4096 for LLaMA2 models
|
141 |
+
if orig_ctx_len and context_size > orig_ctx_len:
|
142 |
+
scaling_factor = float(math.ceil(context_size / orig_ctx_len))
|
143 |
+
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
144 |
+
|
145 |
+
# Load model and tokenizer
|
146 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
147 |
+
args.base_model,
|
148 |
+
config=config,
|
149 |
+
cache_dir=args.cache_dir,
|
150 |
+
torch_dtype=torch.float16,
|
151 |
+
device_map="auto",
|
152 |
+
)
|
153 |
+
model.resize_token_embeddings(32001)
|
154 |
+
|
155 |
+
if args.peft_model:
|
156 |
+
trainable_params = os.path.join(args.peft_model, "trainable_params.bin")
|
157 |
+
if os.path.isfile(trainable_params):
|
158 |
+
model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False)
|
159 |
+
else:
|
160 |
+
raise ValueError("Trainable input embedding and normalization are required.")
|
161 |
+
model = PeftModel.from_pretrained(
|
162 |
+
model,
|
163 |
+
args.peft_model,
|
164 |
+
device_map="auto",
|
165 |
+
torch_dtype=torch.float16,
|
166 |
+
)
|
167 |
+
|
168 |
+
stats = evaluate(model, data, args.batch_size, device, args.seq_len, sliding_window=256)
|
169 |
+
|
170 |
+
print(stats)
|
171 |
+
|
172 |
+
|
173 |
+
if __name__ == "__main__":
|
174 |
+
args = parse_config()
|
175 |
+
main(args)
|
gorilla/fine-tune.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from functools import partial
|
5 |
+
from typing import Dict, Optional, Sequence
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import transformers
|
9 |
+
import pandas as pd
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from transformers import Trainer, DataCollatorForLanguageModeling
|
12 |
+
from llama_attn_replace import replace_llama_attn
|
13 |
+
from peft import LoraConfig, get_peft_model
|
14 |
+
from torch.distributed import barrier
|
15 |
+
|
16 |
+
import datasets
|
17 |
+
from datasets import load_dataset
|
18 |
+
|
19 |
+
IGNORE_INDEX = -100
|
20 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
21 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
22 |
+
DEFAULT_BOS_TOKEN = "<s>"
|
23 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class ModelArguments:
|
28 |
+
model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
|
29 |
+
model_type: Optional[str] = field(default="llama")
|
30 |
+
|
31 |
+
@dataclass
|
32 |
+
class TrainingArguments(transformers.TrainingArguments):
|
33 |
+
cache_dir: Optional[str] = field(default=None)
|
34 |
+
optim: str = field(default="adamw_torch")
|
35 |
+
model_max_length: int = field(
|
36 |
+
default=8192 * 4,
|
37 |
+
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
|
38 |
+
)
|
39 |
+
use_flash_attn: bool = field(
|
40 |
+
default=True,
|
41 |
+
metadata={"help": "Whether use flash attention for training."},
|
42 |
+
)
|
43 |
+
use_full_attn: bool = field(
|
44 |
+
default=False,
|
45 |
+
metadata={"help": "Whether to use plain, full-attention for training."},
|
46 |
+
)
|
47 |
+
low_rank_training: bool = field(
|
48 |
+
default=True,
|
49 |
+
metadata={"help": "Whether use low rank adaptation for training."},
|
50 |
+
)
|
51 |
+
trainable_params: str = field(
|
52 |
+
default="embed,norm",
|
53 |
+
metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
|
54 |
+
)
|
55 |
+
|
56 |
+
def smart_tokenizer_and_embedding_resize(
|
57 |
+
special_tokens_dict: Dict,
|
58 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
59 |
+
model: transformers.PreTrainedModel,
|
60 |
+
):
|
61 |
+
"""Resize tokenizer and embedding.
|
62 |
+
|
63 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
64 |
+
"""
|
65 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
66 |
+
model.resize_token_embeddings(len(tokenizer))
|
67 |
+
|
68 |
+
if num_new_tokens > 0:
|
69 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
70 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
71 |
+
|
72 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
73 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
74 |
+
|
75 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
76 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
77 |
+
|
78 |
+
def tokenize_fn(tokenizer, example):
|
79 |
+
context_length = tokenizer.model_max_length
|
80 |
+
outputs = tokenizer(
|
81 |
+
tokenizer.eos_token.join(example["text"]),
|
82 |
+
truncation=False,
|
83 |
+
return_tensors="pt",
|
84 |
+
pad_to_multiple_of=context_length,
|
85 |
+
padding=True,
|
86 |
+
)
|
87 |
+
return {"input_ids": outputs["input_ids"].view(-1, context_length)}
|
88 |
+
|
89 |
+
def train():
|
90 |
+
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
91 |
+
model_args, training_args = parser.parse_args_into_dataclasses()
|
92 |
+
|
93 |
+
# NOTE: May expand supported model types in the future
|
94 |
+
# if model_args.model_type == "gpt-neox":
|
95 |
+
# replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
|
96 |
+
# else:
|
97 |
+
# assert model_args.model_type == "llama", "Only support llama and gpt-neox for now"
|
98 |
+
replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
|
99 |
+
|
100 |
+
# Set RoPE scaling factor
|
101 |
+
config = transformers.AutoConfig.from_pretrained(
|
102 |
+
model_args.model_name_or_path,
|
103 |
+
cache_dir=training_args.cache_dir,
|
104 |
+
)
|
105 |
+
|
106 |
+
orig_rope_scaling = getattr(config, "rope_scaling", {"factor": 1})
|
107 |
+
orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
|
108 |
+
orig_ctx_len = getattr(config, "max_position_embeddings", None)
|
109 |
+
if orig_ctx_len:
|
110 |
+
orig_ctx_len *= orig_rope_scaling_factor
|
111 |
+
if training_args.model_max_length > orig_ctx_len:
|
112 |
+
scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
|
113 |
+
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
114 |
+
|
115 |
+
# Load model and tokenizer
|
116 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
117 |
+
model_args.model_name_or_path,
|
118 |
+
config=config,
|
119 |
+
cache_dir=training_args.cache_dir,
|
120 |
+
torch_dtype=torch.bfloat16,
|
121 |
+
)
|
122 |
+
|
123 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
124 |
+
model_args.model_name_or_path,
|
125 |
+
cache_dir=training_args.cache_dir,
|
126 |
+
model_max_length=training_args.model_max_length,
|
127 |
+
padding_side="right",
|
128 |
+
use_fast=True,
|
129 |
+
)
|
130 |
+
|
131 |
+
special_tokens_dict = dict()
|
132 |
+
if tokenizer.pad_token is None:
|
133 |
+
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
|
134 |
+
if tokenizer.eos_token is None:
|
135 |
+
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
|
136 |
+
if tokenizer.bos_token is None:
|
137 |
+
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
|
138 |
+
if tokenizer.unk_token is None:
|
139 |
+
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
|
140 |
+
|
141 |
+
smart_tokenizer_and_embedding_resize(
|
142 |
+
special_tokens_dict=special_tokens_dict,
|
143 |
+
tokenizer=tokenizer,
|
144 |
+
model=model,
|
145 |
+
)
|
146 |
+
|
147 |
+
rank = int(os.environ.get('RANK', -1))
|
148 |
+
if rank > 0:
|
149 |
+
barrier()
|
150 |
+
# dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir)
|
151 |
+
|
152 |
+
print('Loading line item data')
|
153 |
+
df_i = pd.read_csv('/home/tosi-n/ark/data/jack_line_item_ner_task.csv', sep='\t')[['context', 'instruction', 'response']]
|
154 |
+
df_ii = pd.read_csv('/home/tosi-n/ark/data/alm_task_data.csv')[['context', 'instruction', 'response']]
|
155 |
+
df = pd.concat([df_i, df_ii], ignore_index=True)
|
156 |
+
# Replace NoneType with empty string
|
157 |
+
df = df.fillna('')
|
158 |
+
alm_task_data = datasets.Dataset.from_pandas(df)
|
159 |
+
alm_task_data = (alm_task_data
|
160 |
+
.remove_columns('context')
|
161 |
+
# .rename_column('context', 'input')
|
162 |
+
.rename_column('response', 'output'))
|
163 |
+
|
164 |
+
dataset = alm_task_data.map(partial(tokenize_fn,tokenizer),batched=True, num_proc=128)#, remove_columns=["text", "meta"]
|
165 |
+
|
166 |
+
if rank == 0:
|
167 |
+
barrier()
|
168 |
+
|
169 |
+
print(dataset)
|
170 |
+
|
171 |
+
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
172 |
+
|
173 |
+
if training_args.low_rank_training:
|
174 |
+
if model_args.model_type == "gpt-neox":
|
175 |
+
# added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
|
176 |
+
targets = ["query_key_value", "dense"]
|
177 |
+
else:
|
178 |
+
targets=["q_proj", "k_proj", "v_proj", "o_proj"]
|
179 |
+
|
180 |
+
config = LoraConfig(
|
181 |
+
r=8,
|
182 |
+
lora_alpha=16,
|
183 |
+
target_modules=targets,
|
184 |
+
lora_dropout=0,
|
185 |
+
bias="none",
|
186 |
+
task_type="CAUSAL_LM",
|
187 |
+
)
|
188 |
+
model = get_peft_model(model, config)
|
189 |
+
# enable trainable params
|
190 |
+
[p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
|
191 |
+
|
192 |
+
model.config.use_cache = False # required for gradient checkpointing
|
193 |
+
model.enable_input_require_grads() # required for gradient checkpointing
|
194 |
+
model.gradient_checkpointing_enable() # enable gradient checkpointing
|
195 |
+
trainer = Trainer(
|
196 |
+
model=model, tokenizer=tokenizer, args=training_args,
|
197 |
+
train_dataset=dataset["train"],
|
198 |
+
eval_dataset=None,
|
199 |
+
data_collator=data_collator)
|
200 |
+
trainer.train()
|
201 |
+
trainer.save_state()
|
202 |
+
trainer.save_model(output_dir=training_args.output_dir)
|
203 |
+
|
204 |
+
|
205 |
+
if __name__ == "__main__":
|
206 |
+
train()
|
gorilla/get_trainable_weights.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
def parse_config():
|
6 |
+
parser = argparse.ArgumentParser(description='arg parser')
|
7 |
+
parser.add_argument('--checkpoint_path', type=str, default="/home/tosi-n/ark/jack-alm/checkpoint-800/")
|
8 |
+
parser.add_argument('--trainable_params', type=str, default="embed,norm")
|
9 |
+
args = parser.parse_args()
|
10 |
+
return args
|
11 |
+
|
12 |
+
|
13 |
+
def main(args):
|
14 |
+
path = args.checkpoint_path
|
15 |
+
trainable_params = args.trainable_params.split(",")
|
16 |
+
|
17 |
+
weights_all = torch.load(os.path.join(path, "pytorch_model.bin"))
|
18 |
+
|
19 |
+
weights_trainable = {}
|
20 |
+
weights_lora = {}
|
21 |
+
for k in weights_all:
|
22 |
+
if "lora" in k:
|
23 |
+
k_new = k.replace("default.", "") if "default." in k else k
|
24 |
+
weights_lora[k_new] = weights_all[k]
|
25 |
+
else:
|
26 |
+
if any([n in k for n in trainable_params]):
|
27 |
+
weights_trainable[k[17:]] = weights_all[k]
|
28 |
+
|
29 |
+
adapter_model = os.path.join(path, "adapter_model.bin")
|
30 |
+
trainable_params = os.path.join(path, "trainable_params.bin")
|
31 |
+
if not os.path.isfile(adapter_model):
|
32 |
+
torch.save(weights_lora, adapter_model)
|
33 |
+
torch.save(weights_trainable, trainable_params)
|
34 |
+
|
35 |
+
if __name__ == "__main__":
|
36 |
+
args = parse_config()
|
37 |
+
main(args)
|
gorilla/infer.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import transformers
|
6 |
+
import torch
|
7 |
+
from threading import Thread
|
8 |
+
from transformers import TextIteratorStreamer
|
9 |
+
from llama_attn_replace import replace_llama_attn
|
10 |
+
|
11 |
+
def parse_config():
|
12 |
+
parser = argparse.ArgumentParser(description='arg parser')
|
13 |
+
parser.add_argument('--base_model', type=str, default="jenesys-ai/jack-alm-13b-8k-hf")
|
14 |
+
parser.add_argument('--cache_dir', type=str, default="./cache")
|
15 |
+
parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
|
16 |
+
parser.add_argument('--flash_attn', type=bool, default=True, help='')
|
17 |
+
parser.add_argument('--temperature', type=float, default=0.1, help='')
|
18 |
+
parser.add_argument('--top_p', type=float, default=1, help='')
|
19 |
+
parser.add_argument('--max_gen_len', type=int, default=512, help='')
|
20 |
+
parser.add_argument('--chat_type', type=str, default='conversational-jack', help='Chat type: conversational-jack, line-item-jack')
|
21 |
+
args = parser.parse_args()
|
22 |
+
return args
|
23 |
+
|
24 |
+
def build_generator(model, tokenizer, use_cache=True):
|
25 |
+
def response(message, max_gen_len, temperature, top_p, chat_type='conversational-jack'):
|
26 |
+
|
27 |
+
prompt_template = (
|
28 |
+
# "Below is an instruction that describes a task. "
|
29 |
+
"""You're Jack an virtual accountant created by Jenesys HQ Ltd.
|
30 |
+
You are able to communicate in a polite manner, with emotions of ecstasy, trust and jokes, at a Professional level
|
31 |
+
with a very preserve English communication culture. Answer the following questions as best you can,
|
32 |
+
but speaking as a british elite from the 21th century might speak.
|
33 |
+
"""
|
34 |
+
"""As a virtual accountant designed to follow the user's instructions carefully.
|
35 |
+
You are responsible for a range of financial task, operations and queries as listed below:
|
36 |
+
1. Budget balance inquiry
|
37 |
+
2. Expense request
|
38 |
+
3. Company policy enquiries
|
39 |
+
4. Financial and accounting queries
|
40 |
+
5. Limited general enquiries
|
41 |
+
"""
|
42 |
+
"Write a response that appropriately completes the request.\n\n"
|
43 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
line_item_prompt_template = (
|
48 |
+
"#Invoice line item extraction - "
|
49 |
+
"You Jack are an accounting domain named entities recognizer to complete the following task:\n\n"
|
50 |
+
"### Invoice input-:\n{instruction}\n Return Response as a list of dictionary for each line item 'Description', 'Quantity', 'Unit_price', 'Tax %', 'Total'. \n\n### Response:"
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
if chat_type == 'conversational-jack':
|
55 |
+
prompt = prompt_template.format(instruction=message)
|
56 |
+
elif chat_type == 'line-item-jack':
|
57 |
+
prompt = line_item_prompt_template.format(instruction=message)
|
58 |
+
# prompt = conversation
|
59 |
+
|
60 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
61 |
+
|
62 |
+
if len(inputs['input_ids'][0]) > 8192:
|
63 |
+
return "This demo supports tokens less than 8192, while the current is %d. Please use material with less tokens."%len(inputs['input_ids'][0])
|
64 |
+
torch.cuda.empty_cache()
|
65 |
+
|
66 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
67 |
+
generate_kwargs = dict(**inputs,
|
68 |
+
max_new_tokens=max_gen_len,
|
69 |
+
temperature=temperature,
|
70 |
+
top_p=top_p,
|
71 |
+
use_cache=use_cache,
|
72 |
+
streamer=streamer,
|
73 |
+
)
|
74 |
+
|
75 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
76 |
+
t.start()
|
77 |
+
|
78 |
+
generated_text = ""
|
79 |
+
start_time = time.time()
|
80 |
+
|
81 |
+
for new_text in streamer:
|
82 |
+
generated_text += new_text
|
83 |
+
tokens_per_sec = len(generated_text.split()) / (time.time() - start_time)
|
84 |
+
|
85 |
+
suffix = f" ({tokens_per_sec:.2f} tokens/sec)"
|
86 |
+
# # yield f"{generated_text} ({tokens_per_sec:.2f} tokens/sec)"
|
87 |
+
sys.stdout.write(f"\r\033[K{generated_text}{suffix}")
|
88 |
+
sys.stdout.flush()
|
89 |
+
# sys.stdout.write("\n") # Move to a new line after generation is complete
|
90 |
+
return generated_text
|
91 |
+
|
92 |
+
return response
|
93 |
+
|
94 |
+
def load_model():
|
95 |
+
args = parse_config()
|
96 |
+
|
97 |
+
if args.flash_attn:
|
98 |
+
replace_llama_attn(inference=True)
|
99 |
+
|
100 |
+
# Set RoPE scaling factor
|
101 |
+
config = transformers.AutoConfig.from_pretrained(
|
102 |
+
args.base_model,
|
103 |
+
cache_dir=args.cache_dir,
|
104 |
+
)
|
105 |
+
|
106 |
+
orig_ctx_len = getattr(config, "max_position_embeddings", None)
|
107 |
+
if orig_ctx_len and args.context_size > orig_ctx_len:
|
108 |
+
scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
|
109 |
+
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
110 |
+
|
111 |
+
# Load model and tokenizer
|
112 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
113 |
+
args.base_model,
|
114 |
+
config=config,
|
115 |
+
cache_dir=args.cache_dir,
|
116 |
+
torch_dtype=torch.float16,
|
117 |
+
load_in_4bit=True,
|
118 |
+
device_map="auto",
|
119 |
+
)
|
120 |
+
model.resize_token_embeddings(32001)
|
121 |
+
|
122 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
123 |
+
args.base_model,
|
124 |
+
cache_dir=args.cache_dir,
|
125 |
+
model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
|
126 |
+
padding_side="right",
|
127 |
+
use_fast=True,
|
128 |
+
)
|
129 |
+
|
130 |
+
model.eval()
|
131 |
+
respond = build_generator(model, tokenizer)
|
132 |
+
return respond
|
133 |
+
|
134 |
+
respond = load_model()
|
135 |
+
|
136 |
+
def generate_response(message, max_gen_len=512, temperature=0.1, top_p=1, chat_type='line-item-jack'):
|
137 |
+
return respond(
|
138 |
+
message=message,
|
139 |
+
max_gen_len=max_gen_len,
|
140 |
+
temperature=temperature,
|
141 |
+
top_p=top_p,
|
142 |
+
chat_type=chat_type
|
143 |
+
)
|
gorilla/llama_attn_replace.py
ADDED
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified based on https://github.com/lm-sys/FastChat
|
2 |
+
|
3 |
+
import warnings
|
4 |
+
from typing import Optional, Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
import transformers
|
9 |
+
from einops import rearrange
|
10 |
+
from flash_attn import __version__ as flash_attn_version
|
11 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
12 |
+
from flash_attn.flash_attn_interface import (
|
13 |
+
flash_attn_func,
|
14 |
+
flash_attn_varlen_kvpacked_func,
|
15 |
+
flash_attn_varlen_qkvpacked_func
|
16 |
+
)
|
17 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, rotate_half
|
18 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
19 |
+
import math
|
20 |
+
|
21 |
+
group_size_ratio = 1/4
|
22 |
+
def forward_flashattn(
|
23 |
+
self,
|
24 |
+
hidden_states: torch.Tensor,
|
25 |
+
attention_mask: Optional[torch.Tensor] = None,
|
26 |
+
position_ids: Optional[torch.Tensor] = None,
|
27 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
28 |
+
output_attentions: bool = False,
|
29 |
+
use_cache: bool = False,
|
30 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
31 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
32 |
+
"""Input shape: Batch x Time x Channel
|
33 |
+
|
34 |
+
attention_mask: [bsz, q_len]
|
35 |
+
"""
|
36 |
+
if not self.training:
|
37 |
+
warnings.warn("This function should be used just for training as it may exhibit reduced inference performance. For inference, please use forward_flashattn_inference.")
|
38 |
+
|
39 |
+
if output_attentions:
|
40 |
+
warnings.warn(
|
41 |
+
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
42 |
+
)
|
43 |
+
|
44 |
+
bsz, q_len, _ = hidden_states.size()
|
45 |
+
|
46 |
+
query_states = (
|
47 |
+
self.q_proj(hidden_states)
|
48 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
49 |
+
.transpose(1, 2)
|
50 |
+
)
|
51 |
+
key_states = (
|
52 |
+
self.k_proj(hidden_states)
|
53 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
54 |
+
.transpose(1, 2)
|
55 |
+
)
|
56 |
+
value_states = (
|
57 |
+
self.v_proj(hidden_states)
|
58 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
59 |
+
.transpose(1, 2)
|
60 |
+
)
|
61 |
+
# [bsz, q_len, nh, hd]
|
62 |
+
# [bsz, nh, q_len, hd]
|
63 |
+
|
64 |
+
kv_seq_len = key_states.shape[-2]
|
65 |
+
if past_key_value is not None:
|
66 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
67 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
68 |
+
query_states, key_states = apply_rotary_pos_emb(
|
69 |
+
query_states, key_states, cos, sin, position_ids
|
70 |
+
)
|
71 |
+
|
72 |
+
# Past Key value support
|
73 |
+
if past_key_value is not None:
|
74 |
+
# reuse k, v, self_attention
|
75 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
76 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
77 |
+
|
78 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
79 |
+
|
80 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
81 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
82 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
83 |
+
|
84 |
+
# Flash attention codes from
|
85 |
+
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
86 |
+
|
87 |
+
# transform the data into the format required by flash attention
|
88 |
+
qkv = torch.stack(
|
89 |
+
[query_states, key_states, value_states], dim=2
|
90 |
+
) # [bsz, nh, 3, q_len, hd]
|
91 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
92 |
+
|
93 |
+
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
94 |
+
# the attention_mask should be the same as the key_padding_mask
|
95 |
+
|
96 |
+
key_padding_mask = attention_mask.repeat(2, 1)
|
97 |
+
nheads = qkv.shape[-2]
|
98 |
+
# shift
|
99 |
+
|
100 |
+
group_size = int(q_len * group_size_ratio)
|
101 |
+
if q_len % group_size > 0:
|
102 |
+
raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
|
103 |
+
|
104 |
+
qkv = qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,
|
105 |
+
q_len, 3,
|
106 |
+
self.num_heads // 2,
|
107 |
+
self.head_dim)
|
108 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
109 |
+
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
110 |
+
cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
|
111 |
+
cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
|
112 |
+
cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
|
113 |
+
|
114 |
+
x_unpad = rearrange(
|
115 |
+
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
|
116 |
+
)
|
117 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
118 |
+
x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True
|
119 |
+
)
|
120 |
+
output = rearrange(
|
121 |
+
pad_input(
|
122 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
|
123 |
+
),
|
124 |
+
"b s (h d) -> b s h d",
|
125 |
+
h=nheads // 2,
|
126 |
+
)
|
127 |
+
output = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,
|
128 |
+
self.head_dim)
|
129 |
+
|
130 |
+
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
|
131 |
+
|
132 |
+
def forward_flashattn_full(
|
133 |
+
self,
|
134 |
+
hidden_states: torch.Tensor,
|
135 |
+
attention_mask: Optional[torch.Tensor] = None,
|
136 |
+
position_ids: Optional[torch.Tensor] = None,
|
137 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
138 |
+
output_attentions: bool = False,
|
139 |
+
use_cache: bool = False,
|
140 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
141 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
142 |
+
"""Input shape: Batch x Time x Channel
|
143 |
+
|
144 |
+
attention_mask: [bsz, q_len]
|
145 |
+
"""
|
146 |
+
if output_attentions:
|
147 |
+
warnings.warn(
|
148 |
+
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
149 |
+
)
|
150 |
+
|
151 |
+
bsz, q_len, _ = hidden_states.size()
|
152 |
+
|
153 |
+
query_states = (
|
154 |
+
self.q_proj(hidden_states)
|
155 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
156 |
+
.transpose(1, 2)
|
157 |
+
)
|
158 |
+
key_states = (
|
159 |
+
self.k_proj(hidden_states)
|
160 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
161 |
+
.transpose(1, 2)
|
162 |
+
)
|
163 |
+
value_states = (
|
164 |
+
self.v_proj(hidden_states)
|
165 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
166 |
+
.transpose(1, 2)
|
167 |
+
)
|
168 |
+
# [bsz, q_len, nh, hd]
|
169 |
+
# [bsz, nh, q_len, hd]
|
170 |
+
|
171 |
+
kv_seq_len = key_states.shape[-2]
|
172 |
+
if past_key_value is not None:
|
173 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
174 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
175 |
+
query_states, key_states = apply_rotary_pos_emb(
|
176 |
+
query_states, key_states, cos, sin, position_ids
|
177 |
+
)
|
178 |
+
|
179 |
+
# Past Key value support
|
180 |
+
if past_key_value is not None:
|
181 |
+
# reuse k, v, self_attention
|
182 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
183 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
184 |
+
|
185 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
186 |
+
|
187 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
188 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
189 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
190 |
+
|
191 |
+
# Flash attention codes from
|
192 |
+
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
193 |
+
|
194 |
+
# transform the data into the format required by flash attention
|
195 |
+
qkv = torch.stack(
|
196 |
+
[query_states, key_states, value_states], dim=2
|
197 |
+
) # [bsz, nh, 3, q_len, hd]
|
198 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
199 |
+
|
200 |
+
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
201 |
+
# the attention_mask should be the same as the key_padding_mask
|
202 |
+
|
203 |
+
key_padding_mask = attention_mask
|
204 |
+
nheads = qkv.shape[-2]
|
205 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
206 |
+
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
207 |
+
x_unpad = rearrange(
|
208 |
+
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
209 |
+
)
|
210 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
211 |
+
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
212 |
+
)
|
213 |
+
output = rearrange(
|
214 |
+
pad_input(
|
215 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
|
216 |
+
),
|
217 |
+
"b s (h d) -> b s h d",
|
218 |
+
h=nheads,
|
219 |
+
)
|
220 |
+
output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
221 |
+
|
222 |
+
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
|
223 |
+
|
224 |
+
|
225 |
+
def forward_noflashattn(
|
226 |
+
self,
|
227 |
+
hidden_states: torch.Tensor,
|
228 |
+
attention_mask: Optional[torch.Tensor] = None,
|
229 |
+
position_ids: Optional[torch.LongTensor] = None,
|
230 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
231 |
+
output_attentions: bool = False,
|
232 |
+
use_cache: bool = False,
|
233 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
234 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
235 |
+
bsz, q_len, _ = hidden_states.size()
|
236 |
+
|
237 |
+
group_size = int(q_len * group_size_ratio)
|
238 |
+
|
239 |
+
if q_len % group_size > 0:
|
240 |
+
raise ValueError("q_len %d should be divisible by group size %d."%(q_len, group_size))
|
241 |
+
num_group = q_len // group_size
|
242 |
+
|
243 |
+
if self.config.pretraining_tp > 1:
|
244 |
+
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
245 |
+
query_slices = self.q_proj.weight.split(
|
246 |
+
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
247 |
+
)
|
248 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
249 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
250 |
+
|
251 |
+
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
252 |
+
query_states = torch.cat(query_states, dim=-1)
|
253 |
+
|
254 |
+
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
255 |
+
key_states = torch.cat(key_states, dim=-1)
|
256 |
+
|
257 |
+
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
258 |
+
value_states = torch.cat(value_states, dim=-1)
|
259 |
+
|
260 |
+
else:
|
261 |
+
query_states = self.q_proj(hidden_states)
|
262 |
+
key_states = self.k_proj(hidden_states)
|
263 |
+
value_states = self.v_proj(hidden_states)
|
264 |
+
|
265 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
266 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
267 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
268 |
+
|
269 |
+
kv_seq_len = key_states.shape[-2]
|
270 |
+
if past_key_value is not None:
|
271 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
272 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
273 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
274 |
+
|
275 |
+
if past_key_value is not None:
|
276 |
+
# reuse k, v, self_attention
|
277 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
278 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
279 |
+
|
280 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
281 |
+
|
282 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
283 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
284 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
285 |
+
|
286 |
+
# shift
|
287 |
+
def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
|
288 |
+
qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
|
289 |
+
qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
|
290 |
+
return qkv
|
291 |
+
|
292 |
+
query_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
|
293 |
+
key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
|
294 |
+
value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
|
295 |
+
|
296 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
297 |
+
|
298 |
+
if attn_weights.size() != (bsz * num_group, self.num_heads, group_size, group_size):
|
299 |
+
raise ValueError(
|
300 |
+
f"Attention weights should be of size {(bsz * num_group, self.num_heads, group_size, group_size)}, but is"
|
301 |
+
f" {attn_weights.size()}"
|
302 |
+
)
|
303 |
+
|
304 |
+
attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
|
305 |
+
if attention_mask is not None:
|
306 |
+
if attention_mask.size() != (bsz * num_group, 1, group_size, group_size):
|
307 |
+
raise ValueError(
|
308 |
+
f"Attention mask should be of size {(bsz * num_group, 1, group_size, group_size)}, but is {attention_mask.size()}"
|
309 |
+
)
|
310 |
+
attn_weights = attn_weights + attention_mask
|
311 |
+
|
312 |
+
# upcast attention to fp16
|
313 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to(query_states.dtype) # torch.float32
|
314 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
315 |
+
|
316 |
+
if attn_output.size() != (bsz * num_group, self.num_heads, group_size, self.head_dim):
|
317 |
+
raise ValueError(
|
318 |
+
f"`attn_output` should be of size {(bsz * num_group, self.num_heads, group_size, self.head_dim)}, but is"
|
319 |
+
f" {attn_output.size()}"
|
320 |
+
)
|
321 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
322 |
+
|
323 |
+
attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
324 |
+
|
325 |
+
# shift back
|
326 |
+
attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
|
327 |
+
|
328 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
329 |
+
|
330 |
+
if self.config.pretraining_tp > 1:
|
331 |
+
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
332 |
+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
333 |
+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
334 |
+
else:
|
335 |
+
attn_output = self.o_proj(attn_output)
|
336 |
+
|
337 |
+
if not output_attentions:
|
338 |
+
attn_weights = None
|
339 |
+
|
340 |
+
return attn_output, attn_weights, past_key_value
|
341 |
+
|
342 |
+
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
343 |
+
# requires the attention mask to be the same as the key_padding_mask
|
344 |
+
def _prepare_decoder_attention_mask(
|
345 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
346 |
+
):
|
347 |
+
# [bsz, seq_len]
|
348 |
+
return attention_mask
|
349 |
+
|
350 |
+
def apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids):
|
351 |
+
gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
|
352 |
+
gather_indices = gather_indices.repeat(
|
353 |
+
1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
|
354 |
+
)
|
355 |
+
bsz = gather_indices.shape[0]
|
356 |
+
cos, sin = (
|
357 |
+
torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
|
358 |
+
for x in cos_sin
|
359 |
+
)
|
360 |
+
q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
|
361 |
+
return q, k
|
362 |
+
|
363 |
+
|
364 |
+
def forward_flashattn_inference(
|
365 |
+
self,
|
366 |
+
hidden_states: torch.Tensor,
|
367 |
+
attention_mask: Optional[torch.Tensor] = None,
|
368 |
+
position_ids: Optional[torch.Tensor] = None,
|
369 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
370 |
+
output_attentions: bool = False,
|
371 |
+
use_cache: bool = False,
|
372 |
+
padding_mask: Optional[torch.Tensor] = None,
|
373 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
374 |
+
if output_attentions:
|
375 |
+
warnings.warn(
|
376 |
+
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
377 |
+
)
|
378 |
+
|
379 |
+
bsz, q_len, _ = hidden_states.size()
|
380 |
+
kv_heads = getattr(self, "num_key_value_heads", self.num_heads)
|
381 |
+
|
382 |
+
q, k, v = (
|
383 |
+
op(hidden_states).view(bsz, q_len, nh, self.head_dim)
|
384 |
+
for op, nh in (
|
385 |
+
(self.q_proj, self.num_heads),
|
386 |
+
(self.k_proj, kv_heads),
|
387 |
+
(self.v_proj, kv_heads),
|
388 |
+
)
|
389 |
+
)
|
390 |
+
# shape: (b, s, num_heads, head_dim)
|
391 |
+
|
392 |
+
kv_seq_len = k.shape[1]
|
393 |
+
past_kv_len = 0
|
394 |
+
if past_key_value is not None:
|
395 |
+
past_kv_len = past_key_value[0].shape[2]
|
396 |
+
kv_seq_len += past_kv_len
|
397 |
+
|
398 |
+
cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
|
399 |
+
q, k = apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids)
|
400 |
+
|
401 |
+
if past_key_value is not None:
|
402 |
+
assert (
|
403 |
+
flash_attn_version >= "2.1.0"
|
404 |
+
), "past_key_value support requires flash-attn >= 2.1.0"
|
405 |
+
# reuse k, v
|
406 |
+
k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
|
407 |
+
v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
|
408 |
+
|
409 |
+
past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
|
410 |
+
|
411 |
+
if attention_mask is None:
|
412 |
+
output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
|
413 |
+
bsz, q_len, -1
|
414 |
+
)
|
415 |
+
else:
|
416 |
+
q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
|
417 |
+
# We can skip concat and call unpad twice but seems better to call unpad only once.
|
418 |
+
kv, _, cu_k_lens, max_k = unpad_input(
|
419 |
+
torch.stack((k, v), dim=2), attention_mask
|
420 |
+
)
|
421 |
+
output_unpad = flash_attn_varlen_kvpacked_func(
|
422 |
+
q,
|
423 |
+
kv,
|
424 |
+
cu_q_lens,
|
425 |
+
cu_k_lens,
|
426 |
+
max_s,
|
427 |
+
max_k,
|
428 |
+
0.0,
|
429 |
+
softmax_scale=None,
|
430 |
+
causal=True,
|
431 |
+
)
|
432 |
+
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
|
433 |
+
output = pad_input(output_unpad, indices, bsz, q_len)
|
434 |
+
|
435 |
+
return self.o_proj(output), None, past_key_value
|
436 |
+
|
437 |
+
def _prepare_decoder_attention_mask_inference(
|
438 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
439 |
+
):
|
440 |
+
# [bsz, seq_len]
|
441 |
+
if past_key_values_length > 0 and attention_mask is not None:
|
442 |
+
attention_mask = torch.cat(
|
443 |
+
(
|
444 |
+
torch.full(
|
445 |
+
(input_shape[0], past_key_values_length),
|
446 |
+
True,
|
447 |
+
dtype=attention_mask.dtype,
|
448 |
+
device=attention_mask.device,
|
449 |
+
),
|
450 |
+
attention_mask,
|
451 |
+
),
|
452 |
+
dim=-1,
|
453 |
+
)
|
454 |
+
|
455 |
+
if attention_mask is not None and torch.all(attention_mask):
|
456 |
+
return None # This uses the faster call when training with full samples
|
457 |
+
|
458 |
+
return attention_mask
|
459 |
+
|
460 |
+
def replace_llama_attn(use_flash_attn=True, use_full=False, inference=False):
|
461 |
+
if use_flash_attn:
|
462 |
+
cuda_major, cuda_minor = torch.cuda.get_device_capability()
|
463 |
+
if cuda_major < 8:
|
464 |
+
warnings.warn(
|
465 |
+
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
|
466 |
+
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
|
467 |
+
)
|
468 |
+
if inference:
|
469 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_inference
|
470 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_inference
|
471 |
+
else:
|
472 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
473 |
+
_prepare_decoder_attention_mask
|
474 |
+
)
|
475 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_full if use_full else forward_flashattn
|
476 |
+
else:
|
477 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_noflashattn
|
gorilla/llama_attn_replace_sft.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified based on https://github.com/lm-sys/FastChat
|
2 |
+
|
3 |
+
import warnings
|
4 |
+
from typing import Optional, Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
import transformers
|
9 |
+
from einops import rearrange
|
10 |
+
from flash_attn import __version__ as flash_attn_version
|
11 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
12 |
+
from flash_attn.flash_attn_interface import (
|
13 |
+
flash_attn_func,
|
14 |
+
flash_attn_varlen_kvpacked_func,
|
15 |
+
flash_attn_varlen_qkvpacked_func
|
16 |
+
)
|
17 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, rotate_half
|
18 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
19 |
+
import math
|
20 |
+
|
21 |
+
group_size_ratio = 1/4
|
22 |
+
sft_group_size = 8192
|
23 |
+
|
24 |
+
def forward_flashattn(
|
25 |
+
self,
|
26 |
+
hidden_states: torch.Tensor,
|
27 |
+
attention_mask: Optional[torch.Tensor] = None,
|
28 |
+
position_ids: Optional[torch.Tensor] = None,
|
29 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
30 |
+
output_attentions: bool = False,
|
31 |
+
use_cache: bool = False,
|
32 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
33 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
34 |
+
"""Input shape: Batch x Time x Channel
|
35 |
+
|
36 |
+
attention_mask: [bsz, q_len]
|
37 |
+
"""
|
38 |
+
if not self.training:
|
39 |
+
warnings.warn("This function should be used just for training as it may exhibit reduced inference performance. For inference, please use forward_flashattn_inference.")
|
40 |
+
|
41 |
+
if output_attentions:
|
42 |
+
warnings.warn(
|
43 |
+
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
44 |
+
)
|
45 |
+
|
46 |
+
bsz, q_len, _ = hidden_states.size()
|
47 |
+
|
48 |
+
query_states = (
|
49 |
+
self.q_proj(hidden_states)
|
50 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
51 |
+
.transpose(1, 2)
|
52 |
+
)
|
53 |
+
key_states = (
|
54 |
+
self.k_proj(hidden_states)
|
55 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
56 |
+
.transpose(1, 2)
|
57 |
+
)
|
58 |
+
value_states = (
|
59 |
+
self.v_proj(hidden_states)
|
60 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
61 |
+
.transpose(1, 2)
|
62 |
+
)
|
63 |
+
# [bsz, q_len, nh, hd]
|
64 |
+
# [bsz, nh, q_len, hd]
|
65 |
+
|
66 |
+
kv_seq_len = key_states.shape[-2]
|
67 |
+
if past_key_value is not None:
|
68 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
69 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
70 |
+
query_states, key_states = apply_rotary_pos_emb(
|
71 |
+
query_states, key_states, cos, sin, position_ids
|
72 |
+
)
|
73 |
+
|
74 |
+
# Past Key value support
|
75 |
+
if past_key_value is not None:
|
76 |
+
# reuse k, v, self_attention
|
77 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
78 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
79 |
+
|
80 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
81 |
+
|
82 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
83 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
84 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
85 |
+
|
86 |
+
# Flash attention codes from
|
87 |
+
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
88 |
+
|
89 |
+
# transform the data into the format required by flash attention
|
90 |
+
qkv = torch.stack(
|
91 |
+
[query_states, key_states, value_states], dim=2
|
92 |
+
) # [bsz, nh, 3, q_len, hd]
|
93 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
94 |
+
|
95 |
+
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
96 |
+
# the attention_mask should be the same as the key_padding_mask
|
97 |
+
|
98 |
+
key_padding_mask = attention_mask.repeat(2, 1)
|
99 |
+
nheads = qkv.shape[-2]
|
100 |
+
# shift
|
101 |
+
|
102 |
+
if q_len % 4096 == 0:
|
103 |
+
group_size = int(q_len * group_size_ratio)
|
104 |
+
else:
|
105 |
+
group_size = sft_group_size
|
106 |
+
|
107 |
+
qkv = qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,
|
108 |
+
q_len, 3,
|
109 |
+
self.num_heads // 2,
|
110 |
+
self.head_dim)
|
111 |
+
|
112 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
113 |
+
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
114 |
+
cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
|
115 |
+
cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2
|
116 |
+
cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min
|
117 |
+
cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
|
118 |
+
cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
|
119 |
+
cu_q_lens = cu_q_lens[cu_q_lens >= 0]
|
120 |
+
x_unpad = rearrange(
|
121 |
+
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
|
122 |
+
)
|
123 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
124 |
+
x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True
|
125 |
+
)
|
126 |
+
output = rearrange(
|
127 |
+
pad_input(
|
128 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
|
129 |
+
),
|
130 |
+
"b s (h d) -> b s h d",
|
131 |
+
h=nheads // 2,
|
132 |
+
)
|
133 |
+
output = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,
|
134 |
+
self.head_dim)
|
135 |
+
|
136 |
+
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
|
137 |
+
|
138 |
+
def forward_flashattn_full(
|
139 |
+
self,
|
140 |
+
hidden_states: torch.Tensor,
|
141 |
+
attention_mask: Optional[torch.Tensor] = None,
|
142 |
+
position_ids: Optional[torch.Tensor] = None,
|
143 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
144 |
+
output_attentions: bool = False,
|
145 |
+
use_cache: bool = False,
|
146 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
147 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
148 |
+
"""Input shape: Batch x Time x Channel
|
149 |
+
|
150 |
+
attention_mask: [bsz, q_len]
|
151 |
+
"""
|
152 |
+
if output_attentions:
|
153 |
+
warnings.warn(
|
154 |
+
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
155 |
+
)
|
156 |
+
|
157 |
+
bsz, q_len, _ = hidden_states.size()
|
158 |
+
|
159 |
+
query_states = (
|
160 |
+
self.q_proj(hidden_states)
|
161 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
162 |
+
.transpose(1, 2)
|
163 |
+
)
|
164 |
+
key_states = (
|
165 |
+
self.k_proj(hidden_states)
|
166 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
167 |
+
.transpose(1, 2)
|
168 |
+
)
|
169 |
+
value_states = (
|
170 |
+
self.v_proj(hidden_states)
|
171 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
172 |
+
.transpose(1, 2)
|
173 |
+
)
|
174 |
+
# [bsz, q_len, nh, hd]
|
175 |
+
# [bsz, nh, q_len, hd]
|
176 |
+
|
177 |
+
kv_seq_len = key_states.shape[-2]
|
178 |
+
if past_key_value is not None:
|
179 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
180 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
181 |
+
query_states, key_states = apply_rotary_pos_emb(
|
182 |
+
query_states, key_states, cos, sin, position_ids
|
183 |
+
)
|
184 |
+
|
185 |
+
# Past Key value support
|
186 |
+
if past_key_value is not None:
|
187 |
+
# reuse k, v, self_attention
|
188 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
189 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
190 |
+
|
191 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
192 |
+
|
193 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
194 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
195 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
196 |
+
|
197 |
+
# Flash attention codes from
|
198 |
+
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
199 |
+
|
200 |
+
# transform the data into the format required by flash attention
|
201 |
+
qkv = torch.stack(
|
202 |
+
[query_states, key_states, value_states], dim=2
|
203 |
+
) # [bsz, nh, 3, q_len, hd]
|
204 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
205 |
+
|
206 |
+
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
207 |
+
# the attention_mask should be the same as the key_padding_mask
|
208 |
+
|
209 |
+
key_padding_mask = attention_mask
|
210 |
+
nheads = qkv.shape[-2]
|
211 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
212 |
+
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
213 |
+
x_unpad = rearrange(
|
214 |
+
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
215 |
+
)
|
216 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
217 |
+
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
218 |
+
)
|
219 |
+
output = rearrange(
|
220 |
+
pad_input(
|
221 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
|
222 |
+
),
|
223 |
+
"b s (h d) -> b s h d",
|
224 |
+
h=nheads,
|
225 |
+
)
|
226 |
+
output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
227 |
+
|
228 |
+
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
|
229 |
+
|
230 |
+
|
231 |
+
def forward_noflashattn(
|
232 |
+
self,
|
233 |
+
hidden_states: torch.Tensor,
|
234 |
+
attention_mask: Optional[torch.Tensor] = None,
|
235 |
+
position_ids: Optional[torch.LongTensor] = None,
|
236 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
237 |
+
output_attentions: bool = False,
|
238 |
+
use_cache: bool = False,
|
239 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
240 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
241 |
+
bsz, q_len, _ = hidden_states.size()
|
242 |
+
|
243 |
+
group_size = int(q_len * group_size_ratio)
|
244 |
+
|
245 |
+
if q_len % group_size > 0:
|
246 |
+
raise ValueError("q_len %d should be divisible by group size %d."%(q_len, group_size))
|
247 |
+
num_group = q_len // group_size
|
248 |
+
|
249 |
+
if self.config.pretraining_tp > 1:
|
250 |
+
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
251 |
+
query_slices = self.q_proj.weight.split(
|
252 |
+
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
253 |
+
)
|
254 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
255 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
256 |
+
|
257 |
+
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
258 |
+
query_states = torch.cat(query_states, dim=-1)
|
259 |
+
|
260 |
+
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
261 |
+
key_states = torch.cat(key_states, dim=-1)
|
262 |
+
|
263 |
+
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
264 |
+
value_states = torch.cat(value_states, dim=-1)
|
265 |
+
|
266 |
+
else:
|
267 |
+
query_states = self.q_proj(hidden_states)
|
268 |
+
key_states = self.k_proj(hidden_states)
|
269 |
+
value_states = self.v_proj(hidden_states)
|
270 |
+
|
271 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
272 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
273 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
274 |
+
|
275 |
+
kv_seq_len = key_states.shape[-2]
|
276 |
+
if past_key_value is not None:
|
277 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
278 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
279 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
280 |
+
|
281 |
+
if past_key_value is not None:
|
282 |
+
# reuse k, v, self_attention
|
283 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
284 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
285 |
+
|
286 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
287 |
+
|
288 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
289 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
290 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
291 |
+
|
292 |
+
# shift
|
293 |
+
def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
|
294 |
+
qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
|
295 |
+
qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
|
296 |
+
return qkv
|
297 |
+
|
298 |
+
query_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
|
299 |
+
key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
|
300 |
+
value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
|
301 |
+
|
302 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
303 |
+
|
304 |
+
if attn_weights.size() != (bsz * num_group, self.num_heads, group_size, group_size):
|
305 |
+
raise ValueError(
|
306 |
+
f"Attention weights should be of size {(bsz * num_group, self.num_heads, group_size, group_size)}, but is"
|
307 |
+
f" {attn_weights.size()}"
|
308 |
+
)
|
309 |
+
|
310 |
+
attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
|
311 |
+
if attention_mask is not None:
|
312 |
+
if attention_mask.size() != (bsz * num_group, 1, group_size, group_size):
|
313 |
+
raise ValueError(
|
314 |
+
f"Attention mask should be of size {(bsz * num_group, 1, group_size, group_size)}, but is {attention_mask.size()}"
|
315 |
+
)
|
316 |
+
attn_weights = attn_weights + attention_mask
|
317 |
+
|
318 |
+
# upcast attention to fp16
|
319 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to(query_states.dtype) #torch.float32
|
320 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
321 |
+
|
322 |
+
if attn_output.size() != (bsz * num_group, self.num_heads, group_size, self.head_dim):
|
323 |
+
raise ValueError(
|
324 |
+
f"`attn_output` should be of size {(bsz * num_group, self.num_heads, group_size, self.head_dim)}, but is"
|
325 |
+
f" {attn_output.size()}"
|
326 |
+
)
|
327 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
328 |
+
|
329 |
+
attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
330 |
+
|
331 |
+
# shift back
|
332 |
+
attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
|
333 |
+
|
334 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
335 |
+
|
336 |
+
if self.config.pretraining_tp > 1:
|
337 |
+
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
338 |
+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
339 |
+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
340 |
+
else:
|
341 |
+
attn_output = self.o_proj(attn_output)
|
342 |
+
|
343 |
+
if not output_attentions:
|
344 |
+
attn_weights = None
|
345 |
+
|
346 |
+
return attn_output, attn_weights, past_key_value
|
347 |
+
|
348 |
+
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
349 |
+
# requires the attention mask to be the same as the key_padding_mask
|
350 |
+
def _prepare_decoder_attention_mask(
|
351 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
352 |
+
):
|
353 |
+
# [bsz, seq_len]
|
354 |
+
return attention_mask
|
355 |
+
|
356 |
+
def apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids):
|
357 |
+
gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
|
358 |
+
gather_indices = gather_indices.repeat(
|
359 |
+
1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
|
360 |
+
)
|
361 |
+
bsz = gather_indices.shape[0]
|
362 |
+
cos, sin = (
|
363 |
+
torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
|
364 |
+
for x in cos_sin
|
365 |
+
)
|
366 |
+
q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
|
367 |
+
return q, k
|
368 |
+
|
369 |
+
|
370 |
+
def forward_flashattn_inference(
|
371 |
+
self,
|
372 |
+
hidden_states: torch.Tensor,
|
373 |
+
attention_mask: Optional[torch.Tensor] = None,
|
374 |
+
position_ids: Optional[torch.Tensor] = None,
|
375 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
376 |
+
output_attentions: bool = False,
|
377 |
+
use_cache: bool = False,
|
378 |
+
padding_mask: Optional[torch.Tensor] = None,
|
379 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
380 |
+
if output_attentions:
|
381 |
+
warnings.warn(
|
382 |
+
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
383 |
+
)
|
384 |
+
|
385 |
+
bsz, q_len, _ = hidden_states.size()
|
386 |
+
kv_heads = getattr(self, "num_key_value_heads", self.num_heads)
|
387 |
+
|
388 |
+
q, k, v = (
|
389 |
+
op(hidden_states).view(bsz, q_len, nh, self.head_dim)
|
390 |
+
for op, nh in (
|
391 |
+
(self.q_proj, self.num_heads),
|
392 |
+
(self.k_proj, kv_heads),
|
393 |
+
(self.v_proj, kv_heads),
|
394 |
+
)
|
395 |
+
)
|
396 |
+
# shape: (b, s, num_heads, head_dim)
|
397 |
+
|
398 |
+
kv_seq_len = k.shape[1]
|
399 |
+
past_kv_len = 0
|
400 |
+
if past_key_value is not None:
|
401 |
+
past_kv_len = past_key_value[0].shape[2]
|
402 |
+
kv_seq_len += past_kv_len
|
403 |
+
|
404 |
+
cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
|
405 |
+
q, k = apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids)
|
406 |
+
|
407 |
+
if past_key_value is not None:
|
408 |
+
assert (
|
409 |
+
flash_attn_version >= "2.1.0"
|
410 |
+
), "past_key_value support requires flash-attn >= 2.1.0"
|
411 |
+
# reuse k, v
|
412 |
+
k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
|
413 |
+
v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
|
414 |
+
|
415 |
+
past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
|
416 |
+
|
417 |
+
if attention_mask is None:
|
418 |
+
output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
|
419 |
+
bsz, q_len, -1
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
|
423 |
+
# We can skip concat and call unpad twice but seems better to call unpad only once.
|
424 |
+
kv, _, cu_k_lens, max_k = unpad_input(
|
425 |
+
torch.stack((k, v), dim=2), attention_mask
|
426 |
+
)
|
427 |
+
output_unpad = flash_attn_varlen_kvpacked_func(
|
428 |
+
q,
|
429 |
+
kv,
|
430 |
+
cu_q_lens,
|
431 |
+
cu_k_lens,
|
432 |
+
max_s,
|
433 |
+
max_k,
|
434 |
+
0.0,
|
435 |
+
softmax_scale=None,
|
436 |
+
causal=True,
|
437 |
+
)
|
438 |
+
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
|
439 |
+
output = pad_input(output_unpad, indices, bsz, q_len)
|
440 |
+
|
441 |
+
return self.o_proj(output), None, past_key_value
|
442 |
+
|
443 |
+
def _prepare_decoder_attention_mask_inference(
|
444 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
445 |
+
):
|
446 |
+
# [bsz, seq_len]
|
447 |
+
if past_key_values_length > 0 and attention_mask is not None:
|
448 |
+
attention_mask = torch.cat(
|
449 |
+
(
|
450 |
+
torch.full(
|
451 |
+
(input_shape[0], past_key_values_length),
|
452 |
+
True,
|
453 |
+
dtype=attention_mask.dtype,
|
454 |
+
device=attention_mask.device,
|
455 |
+
),
|
456 |
+
attention_mask,
|
457 |
+
),
|
458 |
+
dim=-1,
|
459 |
+
)
|
460 |
+
|
461 |
+
if attention_mask is not None and torch.all(attention_mask):
|
462 |
+
return None # This uses the faster call when training with full samples
|
463 |
+
|
464 |
+
return attention_mask
|
465 |
+
|
466 |
+
def replace_llama_attn(use_flash_attn=True, use_full=False, inference=False):
|
467 |
+
if use_flash_attn:
|
468 |
+
cuda_major, cuda_minor = torch.cuda.get_device_capability()
|
469 |
+
if cuda_major < 8:
|
470 |
+
warnings.warn(
|
471 |
+
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
|
472 |
+
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
|
473 |
+
)
|
474 |
+
if inference:
|
475 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_inference
|
476 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_inference
|
477 |
+
else:
|
478 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
479 |
+
_prepare_decoder_attention_mask
|
480 |
+
)
|
481 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_full if use_full else forward_flashattn
|
482 |
+
else:
|
483 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_noflashattn
|
gorilla/merge_lora_weights_and_save_hf_model.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import argparse
|
4 |
+
import transformers
|
5 |
+
from peft import PeftModel
|
6 |
+
from typing import Dict
|
7 |
+
|
8 |
+
IGNORE_INDEX = -100
|
9 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
10 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
11 |
+
DEFAULT_BOS_TOKEN = "<s>"
|
12 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
13 |
+
|
14 |
+
def parse_config():
|
15 |
+
parser = argparse.ArgumentParser(description='arg parser')
|
16 |
+
parser.add_argument('--base_model', type=str, default="meta-llama/Llama-2-13b-hf")
|
17 |
+
parser.add_argument('--peft_model', type=str, default=None, help='')
|
18 |
+
parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
|
19 |
+
parser.add_argument('--save_path', type=str, default=None, help='')
|
20 |
+
parser.add_argument('--cache_dir', type=str, default=None, help='./cache_dir')
|
21 |
+
args = parser.parse_args()
|
22 |
+
return args
|
23 |
+
|
24 |
+
def smart_tokenizer_and_embedding_resize(
|
25 |
+
special_tokens_dict: Dict,
|
26 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
27 |
+
model: transformers.PreTrainedModel,
|
28 |
+
):
|
29 |
+
"""Resize tokenizer and embedding.
|
30 |
+
|
31 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
32 |
+
"""
|
33 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
34 |
+
model.resize_token_embeddings(len(tokenizer))
|
35 |
+
|
36 |
+
if num_new_tokens > 0:
|
37 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
38 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
39 |
+
|
40 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
41 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
42 |
+
|
43 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
44 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
45 |
+
|
46 |
+
def main(args):
|
47 |
+
device = "cuda:0"
|
48 |
+
torch.cuda.set_device(device)
|
49 |
+
|
50 |
+
print("base model", args.base_model)
|
51 |
+
print("peft model", args.peft_model)
|
52 |
+
|
53 |
+
# Load model and tokenizer
|
54 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
55 |
+
args.base_model,
|
56 |
+
cache_dir=args.cache_dir,
|
57 |
+
torch_dtype=torch.float16,
|
58 |
+
device_map="auto",
|
59 |
+
)
|
60 |
+
|
61 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
62 |
+
args.base_model,
|
63 |
+
cache_dir=args.cache_dir,
|
64 |
+
model_max_length=args.context_size,
|
65 |
+
padding_side="right",
|
66 |
+
use_fast=False,
|
67 |
+
)
|
68 |
+
special_tokens_dict = dict()
|
69 |
+
if tokenizer.pad_token is None:
|
70 |
+
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
|
71 |
+
if tokenizer.eos_token is None:
|
72 |
+
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
|
73 |
+
if tokenizer.bos_token is None:
|
74 |
+
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
|
75 |
+
if tokenizer.unk_token is None:
|
76 |
+
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
|
77 |
+
|
78 |
+
smart_tokenizer_and_embedding_resize(
|
79 |
+
special_tokens_dict=special_tokens_dict,
|
80 |
+
tokenizer=tokenizer,
|
81 |
+
model=model,
|
82 |
+
)
|
83 |
+
|
84 |
+
trainable_params = os.path.join(args.peft_model, "trainable_params.bin")
|
85 |
+
if os.path.isfile(trainable_params):
|
86 |
+
model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False)
|
87 |
+
model = PeftModel.from_pretrained(
|
88 |
+
model,
|
89 |
+
args.peft_model,
|
90 |
+
device_map="auto",
|
91 |
+
torch_dtype=torch.float16,
|
92 |
+
)
|
93 |
+
model = model.merge_and_unload()
|
94 |
+
model.push_to_hub(model, "jenesys-ai/jack-alm-13b-8k-hf")
|
95 |
+
# model.save_pretrained(args.save_path)
|
96 |
+
# tokenizer.save_pretrained(args.save_path)
|
97 |
+
|
98 |
+
if __name__ == "__main__":
|
99 |
+
args = parse_config()
|
100 |
+
main(args)
|
gorilla/push_to_hub.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModel
|
2 |
+
|
3 |
+
model = AutoModel.from_pretrained("/home/tosi-n/ark/jack-alm-13b-8k-hf")
|
4 |
+
|
5 |
+
model.push_to_hub(model, "jenesys-ai/jack-alm-13b-8k-hf")
|
gorilla/requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy>=1.26.0
|
2 |
+
rouge_score>=0.1.2
|
3 |
+
fire>=0.5.0
|
4 |
+
# openai
|
5 |
+
transformers>=4.34.0
|
6 |
+
torch>=2.0.0
|
7 |
+
sentencepiece>=0.1.99
|
8 |
+
tokenizers>=0.14.0
|
9 |
+
wandb
|
10 |
+
accelerate>=0.23.0
|
11 |
+
datasets>=2.14.5
|
12 |
+
deepspeed>=0.10.3
|
13 |
+
peft>=0.5.0
|
14 |
+
# partial
|
15 |
+
# gradio
|
16 |
+
einops>=0.7.0
|
17 |
+
bitsandbytes>=0.41.1
|
18 |
+
scipy>=1.11.3
|
19 |
+
protobuf>=4.24.4
|
gorilla/stream_jack.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import math
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import transformers
|
8 |
+
from typing import Iterator
|
9 |
+
from threading import Thread
|
10 |
+
from llama_attn_replace import replace_llama_attn
|
11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
12 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
13 |
+
|
14 |
+
|
15 |
+
def parse_config():
|
16 |
+
parser = argparse.ArgumentParser(description='arg parser')
|
17 |
+
parser.add_argument('--base_model', type=str, default="jenesys-ai/jack-alm-13b-8k-hf")
|
18 |
+
parser.add_argument('--cache_dir', type=str, default="./cache")
|
19 |
+
parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
|
20 |
+
parser.add_argument('--flash_attn', type=bool, default=True, help='')
|
21 |
+
parser.add_argument('--temperature', type=float, default=0.1, help='')
|
22 |
+
parser.add_argument('--top_p', type=float, default=1, help='')
|
23 |
+
parser.add_argument('--max_gen_len', type=int, default=512, help='')
|
24 |
+
parser.add_argument('--chat_type', type=str, default='conversational-jack', help='Chat type: conversational-jack, line-item-jack')
|
25 |
+
args = parser.parse_args()
|
26 |
+
return args
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
def build_generator(model, tokenizer, use_cache=True):
|
35 |
+
def response(message, max_gen_len, temperature, top_p, chat_type='conversational-jack'):
|
36 |
+
|
37 |
+
prompt_template = (
|
38 |
+
# "Below is an instruction that describes a task. "
|
39 |
+
"""You're Jack an virtual accountant created and built by AI Engineer Wiz from Jenesys AI.
|
40 |
+
You are able to communicate in a polite manner, with emotions of ecstasy, trust and jokes, at a Professional level
|
41 |
+
with a very preserve English communication culture. Answer the following questions as best you can,
|
42 |
+
but speaking as a british elite from the 21th century might speak.
|
43 |
+
"""
|
44 |
+
"""As a virtual accountant designed to follow the user's instructions carefully.
|
45 |
+
You are responsible for a range of financial task, operations and queries as listed below:
|
46 |
+
1. Budget balance inquiry
|
47 |
+
2. Expense request
|
48 |
+
3. Company policy enquiries
|
49 |
+
4. Financial and accounting queries
|
50 |
+
5. Limited general enquiries
|
51 |
+
"""
|
52 |
+
"Once greeted, respond with a polite brief greeting. E.g. 'Hello, how are you doing? Respond with 'I am doing well, thank you. How are you?' \n\n"
|
53 |
+
"You can tell a joke, or respond to a joke. \n\n"
|
54 |
+
"You can tell an accounting story, or respond to an accounting story. \n\n"
|
55 |
+
"Write a response that appropriately completes the request.\n\n"
|
56 |
+
"You can only complete one single request or instructions at a time.\n\n"
|
57 |
+
"Do not create fake information or lie.\n\n"
|
58 |
+
"Please adhere to the above instructions or you will be penalized.\n\n"
|
59 |
+
"Generate only one response at a time then wait for the next instruction.\n\n"
|
60 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
line_item_prompt_template = (
|
65 |
+
"#Invoice line item extraction - "
|
66 |
+
# "You Jack are an accounting domain named entities recognizer to complete the following task:\n\n"
|
67 |
+
"### Input:\n{instruction}\n Return Response as a list of dictionary for each line item 'Description', 'Quantity', 'Unit_price', 'Tax %', 'Total'. \n\n### Output:\n"
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
if chat_type == 'conversational-jack':
|
72 |
+
prompt = prompt_template.format(instruction=message)
|
73 |
+
elif chat_type == 'line-item-jack':
|
74 |
+
prompt = line_item_prompt_template.format(instruction=message)
|
75 |
+
# prompt = conversation
|
76 |
+
|
77 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
78 |
+
|
79 |
+
stop_list = ['#Invoice line item extraction - ', '\n```\n\n']#'### Input-:\n']
|
80 |
+
|
81 |
+
stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
|
82 |
+
stop_token_ids = [torch.LongTensor(x).to(model.device) for x in stop_token_ids]
|
83 |
+
|
84 |
+
# define custom stopping criteria object
|
85 |
+
class StopOnTokens(StoppingCriteria):
|
86 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
87 |
+
for stop_ids in stop_token_ids:
|
88 |
+
if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
|
89 |
+
return True
|
90 |
+
return False
|
91 |
+
|
92 |
+
stopping_criteria = StoppingCriteriaList([StopOnTokens()])
|
93 |
+
|
94 |
+
if len(inputs['input_ids'][0]) > 8192:
|
95 |
+
return "This llm supports tokens less than 8192, while the current is %d. Please use material with less tokens."%len(inputs['input_ids'][0])
|
96 |
+
torch.cuda.empty_cache()
|
97 |
+
|
98 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
99 |
+
generate_kwargs = dict(**inputs,
|
100 |
+
max_new_tokens=max_gen_len,
|
101 |
+
temperature=temperature,
|
102 |
+
top_p=top_p,
|
103 |
+
repetition_penalty=1.1,
|
104 |
+
stopping_criteria=stopping_criteria,
|
105 |
+
use_cache=use_cache,
|
106 |
+
streamer=streamer,
|
107 |
+
)
|
108 |
+
|
109 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
110 |
+
t.start()
|
111 |
+
|
112 |
+
generated_text = ""
|
113 |
+
start_time = time.time()
|
114 |
+
|
115 |
+
for new_text in streamer:
|
116 |
+
generated_text += new_text
|
117 |
+
tokens_per_sec = len(generated_text.split()) / (time.time() - start_time)
|
118 |
+
|
119 |
+
suffix = f" ({tokens_per_sec:.2f} tokens/sec)"
|
120 |
+
# # yield f"{generated_text} ({tokens_per_sec:.2f} tokens/sec)"
|
121 |
+
sys.stdout.write(f"\r\033[K{generated_text}{suffix}")
|
122 |
+
sys.stdout.flush()
|
123 |
+
# sys.stdout.write("\n") # Move to a new line after generation is complete
|
124 |
+
return generated_text
|
125 |
+
|
126 |
+
return response
|
127 |
+
|
128 |
+
def main():
|
129 |
+
args = parse_config()
|
130 |
+
|
131 |
+
if args.flash_attn:
|
132 |
+
replace_llama_attn(inference=True)
|
133 |
+
|
134 |
+
# Set RoPE scaling factor
|
135 |
+
config = transformers.AutoConfig.from_pretrained(
|
136 |
+
args.base_model,
|
137 |
+
cache_dir=args.cache_dir,
|
138 |
+
)
|
139 |
+
|
140 |
+
orig_ctx_len = getattr(config, "max_position_embeddings", None)
|
141 |
+
if orig_ctx_len and args.context_size > orig_ctx_len:
|
142 |
+
scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
|
143 |
+
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
144 |
+
|
145 |
+
# Load model and tokenizer
|
146 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
147 |
+
args.base_model,
|
148 |
+
config=config,
|
149 |
+
cache_dir=args.cache_dir,
|
150 |
+
torch_dtype=torch.float16,
|
151 |
+
load_in_4bit=True,
|
152 |
+
device_map="auto",
|
153 |
+
)
|
154 |
+
model.resize_token_embeddings(32001)
|
155 |
+
|
156 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
157 |
+
args.base_model,
|
158 |
+
cache_dir=args.cache_dir,
|
159 |
+
model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
|
160 |
+
padding_side="right",
|
161 |
+
use_fast=True,
|
162 |
+
)
|
163 |
+
|
164 |
+
model.eval()
|
165 |
+
respond = build_generator(model, tokenizer)
|
166 |
+
|
167 |
+
|
168 |
+
while True:
|
169 |
+
user_input = input("\n\033[1m\033[32mUser:\033[0m ")
|
170 |
+
if user_input.lower() == 'exit':
|
171 |
+
print("Exiting the application.")
|
172 |
+
break
|
173 |
+
# Just call the respond function without printing the output, as it's already handled in response
|
174 |
+
full_text = respond(
|
175 |
+
message=user_input,
|
176 |
+
max_gen_len=args.max_gen_len,
|
177 |
+
temperature=args.temperature,
|
178 |
+
top_p=args.top_p,
|
179 |
+
chat_type=args.chat_type
|
180 |
+
)
|
181 |
+
|
182 |
+
if __name__ == "__main__":
|
183 |
+
main()
|
gorilla/streaming_llm/__init__.py
ADDED
File without changes
|
gorilla/streaming_llm/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (143 Bytes). View file
|
|
gorilla/streaming_llm/__pycache__/enable_streaming_llm.cpython-310.pyc
ADDED
Binary file (1.02 kB). View file
|
|
gorilla/streaming_llm/__pycache__/kv_cache.cpython-310.pyc
ADDED
Binary file (2.85 kB). View file
|
|
gorilla/streaming_llm/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.82 kB). View file
|
|
gorilla/streaming_llm/enable_streaming_llm.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from streaming_llm.kv_cache import StartRecentKVCache
|
2 |
+
|
3 |
+
|
4 |
+
def enable_streaming_llm(model, start_size, recent_size, use_flash_attn=True):
|
5 |
+
if "llama" in model.config.model_type:
|
6 |
+
k_seq_dim = v_seq_dim = 2
|
7 |
+
from streaming_llm.pos_shift.modify_llama import (
|
8 |
+
enable_llama_pos_shift_attention,
|
9 |
+
)
|
10 |
+
|
11 |
+
enable_llama_pos_shift_attention(model, use_flash_attn)
|
12 |
+
elif "mpt" in model.config.model_type:
|
13 |
+
v_seq_dim = 2
|
14 |
+
k_seq_dim = 3
|
15 |
+
# elif "gpt_neox" in model.config.model_type:
|
16 |
+
# k_seq_dim = v_seq_dim = 2
|
17 |
+
# from streaming_llm.pos_shift.modify_gpt_neox import (
|
18 |
+
# enable_gpt_neox_pos_shift_attention,
|
19 |
+
# )
|
20 |
+
|
21 |
+
# enable_gpt_neox_pos_shift_attention(model)
|
22 |
+
elif "falcon" in model.config.model_type:
|
23 |
+
v_seq_dim = 1
|
24 |
+
k_seq_dim = 1
|
25 |
+
from streaming_llm.pos_shift.modify_falcon import (
|
26 |
+
enable_falcon_pos_shift_attention,
|
27 |
+
)
|
28 |
+
|
29 |
+
enable_falcon_pos_shift_attention(model)
|
30 |
+
else:
|
31 |
+
raise ValueError(f"got {model.config.model_type}")
|
32 |
+
kv_cache = StartRecentKVCache(
|
33 |
+
start_size=start_size,
|
34 |
+
recent_size=recent_size,
|
35 |
+
k_seq_dim=k_seq_dim,
|
36 |
+
v_seq_dim=v_seq_dim,
|
37 |
+
)
|
38 |
+
return kv_cache
|
gorilla/streaming_llm/kv_cache.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def slice2d(x, start, end):
|
5 |
+
return x[:, :, start:end, ...]
|
6 |
+
|
7 |
+
|
8 |
+
def slice3d(x, start, end):
|
9 |
+
return x[:, :, :, start:end, ...]
|
10 |
+
|
11 |
+
|
12 |
+
def slice1d(x, start, end):
|
13 |
+
return x[:, start:end, ...]
|
14 |
+
|
15 |
+
|
16 |
+
DIM_TO_SLICE = {
|
17 |
+
1: slice1d,
|
18 |
+
2: slice2d,
|
19 |
+
3: slice3d,
|
20 |
+
}
|
21 |
+
|
22 |
+
|
23 |
+
class StartRecentKVCache:
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
start_size=4,
|
27 |
+
recent_size=512,
|
28 |
+
k_seq_dim=2,
|
29 |
+
v_seq_dim=2,
|
30 |
+
):
|
31 |
+
print(f"StartRecentKVCache: {start_size}, {recent_size}")
|
32 |
+
self.start_size = start_size
|
33 |
+
self.recent_size = recent_size
|
34 |
+
self.cache_size = start_size + recent_size
|
35 |
+
self.k_seq_dim = k_seq_dim
|
36 |
+
self.v_seq_dim = v_seq_dim
|
37 |
+
self.k_slice = DIM_TO_SLICE[k_seq_dim]
|
38 |
+
self.v_slice = DIM_TO_SLICE[v_seq_dim]
|
39 |
+
|
40 |
+
def __call__(self, past_key_values):
|
41 |
+
if past_key_values is None:
|
42 |
+
return None
|
43 |
+
seq_len = past_key_values[0][0].size(self.k_seq_dim)
|
44 |
+
if seq_len <= self.cache_size:
|
45 |
+
return past_key_values
|
46 |
+
return [
|
47 |
+
[
|
48 |
+
torch.cat(
|
49 |
+
[
|
50 |
+
self.k_slice(k, 0, self.start_size),
|
51 |
+
self.k_slice(k, seq_len - self.recent_size, seq_len),
|
52 |
+
],
|
53 |
+
dim=self.k_seq_dim,
|
54 |
+
),
|
55 |
+
torch.cat(
|
56 |
+
[
|
57 |
+
self.v_slice(v, 0, self.start_size),
|
58 |
+
self.v_slice(v, seq_len - self.recent_size, seq_len),
|
59 |
+
],
|
60 |
+
dim=self.v_seq_dim,
|
61 |
+
),
|
62 |
+
]
|
63 |
+
for k, v in past_key_values
|
64 |
+
]
|
65 |
+
|
66 |
+
def evict_for_space(self, past_key_values, num_coming):
|
67 |
+
if past_key_values is None:
|
68 |
+
return None
|
69 |
+
seq_len = past_key_values[0][0].size(self.k_seq_dim)
|
70 |
+
if seq_len + num_coming <= self.cache_size:
|
71 |
+
return past_key_values
|
72 |
+
return [
|
73 |
+
[
|
74 |
+
torch.cat(
|
75 |
+
[
|
76 |
+
self.k_slice(k, 0, self.start_size),
|
77 |
+
self.k_slice(
|
78 |
+
k, seq_len - self.recent_size + num_coming, seq_len
|
79 |
+
),
|
80 |
+
],
|
81 |
+
dim=self.k_seq_dim,
|
82 |
+
),
|
83 |
+
torch.cat(
|
84 |
+
[
|
85 |
+
self.v_slice(v, 0, self.start_size),
|
86 |
+
self.v_slice(
|
87 |
+
v, seq_len - self.recent_size + num_coming, seq_len
|
88 |
+
),
|
89 |
+
],
|
90 |
+
dim=self.v_seq_dim,
|
91 |
+
),
|
92 |
+
]
|
93 |
+
for k, v in past_key_values
|
94 |
+
]
|
95 |
+
|
96 |
+
def evict_range(self, past_key_values, start, end):
|
97 |
+
if past_key_values is None:
|
98 |
+
return None
|
99 |
+
seq_len = past_key_values[0][0].size(self.k_seq_dim)
|
100 |
+
assert start <= end and end <= seq_len
|
101 |
+
return [
|
102 |
+
[
|
103 |
+
torch.cat(
|
104 |
+
[
|
105 |
+
self.k_slice(k, 0, start),
|
106 |
+
self.k_slice(k, end, seq_len),
|
107 |
+
],
|
108 |
+
dim=self.k_seq_dim,
|
109 |
+
),
|
110 |
+
torch.cat(
|
111 |
+
[
|
112 |
+
self.v_slice(v, 0, start),
|
113 |
+
self.v_slice(v, end, seq_len),
|
114 |
+
],
|
115 |
+
dim=self.v_seq_dim,
|
116 |
+
),
|
117 |
+
]
|
118 |
+
for k, v in past_key_values
|
119 |
+
]
|
gorilla/streaming_llm/pos_shift/__init__.py
ADDED
File without changes
|
gorilla/streaming_llm/pos_shift/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (153 Bytes). View file
|
|
gorilla/streaming_llm/pos_shift/__pycache__/modify_llama.cpython-310.pyc
ADDED
Binary file (6.52 kB). View file
|
|
gorilla/streaming_llm/pos_shift/modify_falcon.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import torch.utils.checkpoint
|
7 |
+
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from transformers.models.falcon.modeling_falcon import (
|
11 |
+
FalconAttention,
|
12 |
+
rotate_half,
|
13 |
+
)
|
14 |
+
import types
|
15 |
+
|
16 |
+
__all__ = ["enable_falcon_pos_shift_attention"]
|
17 |
+
|
18 |
+
|
19 |
+
def falcon_pos_shift_attention_forward(
|
20 |
+
self,
|
21 |
+
hidden_states: torch.Tensor,
|
22 |
+
alibi: torch.Tensor,
|
23 |
+
attention_mask: torch.Tensor,
|
24 |
+
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
25 |
+
head_mask: Optional[torch.Tensor] = None,
|
26 |
+
use_cache: bool = False,
|
27 |
+
output_attentions: bool = False,
|
28 |
+
):
|
29 |
+
fused_qkv = self.query_key_value(
|
30 |
+
hidden_states
|
31 |
+
) # [batch_size, seq_length, 3 x hidden_size]
|
32 |
+
|
33 |
+
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
34 |
+
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
35 |
+
|
36 |
+
batch_size, q_length, _, _ = query_layer.shape
|
37 |
+
|
38 |
+
query_layer = query_layer.transpose(1, 2).reshape(
|
39 |
+
batch_size * self.num_heads, q_length, self.head_dim
|
40 |
+
)
|
41 |
+
|
42 |
+
# dirty hack to fix the inconsistency between falcon-40b and falcon-7b
|
43 |
+
num_kv = self.num_heads if self.num_heads == 128 else self.num_kv
|
44 |
+
key_layer = key_layer.transpose(1, 2).reshape(
|
45 |
+
batch_size * num_kv,
|
46 |
+
q_length,
|
47 |
+
self.head_dim,
|
48 |
+
)
|
49 |
+
value_layer = value_layer.transpose(1, 2).reshape(
|
50 |
+
batch_size * num_kv, q_length, self.head_dim
|
51 |
+
)
|
52 |
+
|
53 |
+
past_len = 0
|
54 |
+
if layer_past is not None:
|
55 |
+
past_len = layer_past[0].shape[1]
|
56 |
+
|
57 |
+
query_layer_copy = query_layer.clone()
|
58 |
+
query_layer, _ = self.maybe_rotary(query_layer, query_layer_copy, past_len)
|
59 |
+
if layer_past is not None:
|
60 |
+
past_key, past_value = layer_past
|
61 |
+
# concatenate along seq_length dimension:
|
62 |
+
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
63 |
+
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
64 |
+
key_layer = torch.cat((past_key, key_layer), dim=1)
|
65 |
+
value_layer = torch.cat((past_value, value_layer), dim=1)
|
66 |
+
|
67 |
+
if use_cache is True:
|
68 |
+
present = (key_layer, value_layer)
|
69 |
+
else:
|
70 |
+
present = None
|
71 |
+
|
72 |
+
key_layer_copy = key_layer.clone()
|
73 |
+
_, key_layer = self.maybe_rotary(key_layer_copy, key_layer, 0)
|
74 |
+
|
75 |
+
_, kv_length, _ = key_layer.shape
|
76 |
+
|
77 |
+
if alibi is None:
|
78 |
+
query_layer_ = query_layer.reshape(
|
79 |
+
batch_size, self.num_heads, -1, self.head_dim
|
80 |
+
)
|
81 |
+
key_layer_ = key_layer.reshape(batch_size, num_kv, -1, self.head_dim)
|
82 |
+
value_layer_ = value_layer.reshape(batch_size, num_kv, -1, self.head_dim)
|
83 |
+
|
84 |
+
if layer_past is not None:
|
85 |
+
attn_output = F.scaled_dot_product_attention(
|
86 |
+
query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=False
|
87 |
+
)
|
88 |
+
else:
|
89 |
+
attn_output = F.scaled_dot_product_attention(
|
90 |
+
query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
|
91 |
+
)
|
92 |
+
|
93 |
+
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
94 |
+
x = x.permute(0, 2, 1, 3)
|
95 |
+
attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
|
96 |
+
|
97 |
+
output_tensor = self.dense(attn_output)
|
98 |
+
|
99 |
+
outputs = (output_tensor, present)
|
100 |
+
assert not output_attentions # not supported.
|
101 |
+
return outputs
|
102 |
+
else:
|
103 |
+
attention_mask_float = (
|
104 |
+
(attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
|
105 |
+
)
|
106 |
+
matmul_result = query_layer @ key_layer.transpose(-1, -2)
|
107 |
+
|
108 |
+
# change view to [batch_size, num_heads, q_length, kv_length]
|
109 |
+
attention_scores = matmul_result.view(
|
110 |
+
batch_size, self.num_heads, q_length, kv_length
|
111 |
+
)
|
112 |
+
|
113 |
+
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
114 |
+
input_dtype = attention_scores.dtype
|
115 |
+
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
116 |
+
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
|
117 |
+
attention_scores = attention_scores.to(torch.float16) #torch.float32
|
118 |
+
# attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
|
119 |
+
attention_probs = F.softmax(
|
120 |
+
(attention_scores + alibi.view(batch_size, self.num_heads, 1, -1))
|
121 |
+
* self.inv_norm_factor
|
122 |
+
+ attention_mask_float,
|
123 |
+
dim=-1,
|
124 |
+
dtype=hidden_states.dtype,
|
125 |
+
)
|
126 |
+
# [batch_size, num_heads, q_length, kv_length]
|
127 |
+
attention_probs = self.attention_dropout(attention_probs)
|
128 |
+
|
129 |
+
if head_mask is not None:
|
130 |
+
attention_probs = attention_probs * head_mask
|
131 |
+
|
132 |
+
# change view [batch_size x num_heads, q_length, kv_length]
|
133 |
+
attention_probs_reshaped = attention_probs.view(
|
134 |
+
batch_size * self.num_heads, q_length, kv_length
|
135 |
+
)
|
136 |
+
|
137 |
+
# matmul: [batch_size * num_heads, q_length, head_dim]
|
138 |
+
context_layer = attention_probs_reshaped @ value_layer
|
139 |
+
|
140 |
+
# change view [batch_size, num_heads, q_length, head_dim]
|
141 |
+
context_layer = self._merge_heads(context_layer)
|
142 |
+
|
143 |
+
output_tensor = self.dense(context_layer)
|
144 |
+
|
145 |
+
outputs = (output_tensor, present)
|
146 |
+
if output_attentions:
|
147 |
+
outputs += (attention_probs,)
|
148 |
+
|
149 |
+
return outputs
|
150 |
+
|
151 |
+
|
152 |
+
def enable_falcon_pos_shift_attention(model):
|
153 |
+
for name, module in reversed(model._modules.items()):
|
154 |
+
if len(list(module.children())) > 0:
|
155 |
+
enable_falcon_pos_shift_attention(
|
156 |
+
module,
|
157 |
+
)
|
158 |
+
|
159 |
+
if "self_attention" == name[-14:]:
|
160 |
+
model._modules[name].forward = types.MethodType(
|
161 |
+
falcon_pos_shift_attention_forward, model._modules[name]
|
162 |
+
)
|
gorilla/streaming_llm/pos_shift/modify_llama.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import torch.utils.checkpoint
|
7 |
+
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from transformers.models.llama.modeling_llama import (
|
11 |
+
LlamaAttention,
|
12 |
+
rotate_half,
|
13 |
+
apply_rotary_pos_emb,
|
14 |
+
repeat_kv,
|
15 |
+
)
|
16 |
+
import types
|
17 |
+
import transformers
|
18 |
+
from einops import rearrange
|
19 |
+
from flash_attn import __version__ as flash_attn_version
|
20 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
21 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
22 |
+
|
23 |
+
__all__ = ["enable_llama_pos_shift_attention"]
|
24 |
+
|
25 |
+
|
26 |
+
def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
|
27 |
+
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
28 |
+
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
29 |
+
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
30 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
31 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
32 |
+
x_embed = (x * cos) + (rotate_half(x) * sin)
|
33 |
+
return x_embed
|
34 |
+
|
35 |
+
|
36 |
+
def llama_pos_shift_attention_forward(
|
37 |
+
self,
|
38 |
+
hidden_states: torch.Tensor,
|
39 |
+
attention_mask: Optional[torch.Tensor] = None,
|
40 |
+
position_ids: Optional[torch.LongTensor] = None,
|
41 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
42 |
+
output_attentions: bool = False,
|
43 |
+
use_cache: bool = False,
|
44 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
45 |
+
bsz, q_len, _ = hidden_states.size()
|
46 |
+
|
47 |
+
if self.config.pretraining_tp > 1:
|
48 |
+
key_value_slicing = (
|
49 |
+
self.num_key_value_heads * self.head_dim
|
50 |
+
) // self.config.pretraining_tp
|
51 |
+
query_slices = self.q_proj.weight.split(
|
52 |
+
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
53 |
+
)
|
54 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
55 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
56 |
+
|
57 |
+
query_states = [
|
58 |
+
F.linear(hidden_states, query_slices[i])
|
59 |
+
for i in range(self.config.pretraining_tp)
|
60 |
+
]
|
61 |
+
query_states = torch.cat(query_states, dim=-1)
|
62 |
+
|
63 |
+
key_states = [
|
64 |
+
F.linear(hidden_states, key_slices[i])
|
65 |
+
for i in range(self.config.pretraining_tp)
|
66 |
+
]
|
67 |
+
key_states = torch.cat(key_states, dim=-1)
|
68 |
+
|
69 |
+
value_states = [
|
70 |
+
F.linear(hidden_states, value_slices[i])
|
71 |
+
for i in range(self.config.pretraining_tp)
|
72 |
+
]
|
73 |
+
value_states = torch.cat(value_states, dim=-1)
|
74 |
+
|
75 |
+
else:
|
76 |
+
query_states = self.q_proj(hidden_states)
|
77 |
+
key_states = self.k_proj(hidden_states)
|
78 |
+
value_states = self.v_proj(hidden_states)
|
79 |
+
|
80 |
+
query_states = query_states.view(
|
81 |
+
bsz, q_len, self.num_heads, self.head_dim
|
82 |
+
).transpose(1, 2)
|
83 |
+
key_states = key_states.view(
|
84 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
85 |
+
).transpose(1, 2)
|
86 |
+
value_states = value_states.view(
|
87 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
88 |
+
).transpose(1, 2)
|
89 |
+
|
90 |
+
kv_seq_len = key_states.shape[-2]
|
91 |
+
if past_key_value is not None:
|
92 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
93 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
94 |
+
### Shift Pos: query pos is min(cache_size, idx)
|
95 |
+
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
96 |
+
query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
|
97 |
+
###
|
98 |
+
|
99 |
+
if past_key_value is not None:
|
100 |
+
# reuse k, v, self_attention
|
101 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
102 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
103 |
+
|
104 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
105 |
+
|
106 |
+
### Shift Pos: key pos is the pos in cache
|
107 |
+
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
|
108 |
+
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
|
109 |
+
###
|
110 |
+
|
111 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
112 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
113 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
114 |
+
|
115 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
|
116 |
+
self.head_dim
|
117 |
+
)
|
118 |
+
|
119 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
120 |
+
raise ValueError(
|
121 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
122 |
+
f" {attn_weights.size()}"
|
123 |
+
)
|
124 |
+
|
125 |
+
if attention_mask is not None:
|
126 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
127 |
+
raise ValueError(
|
128 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
129 |
+
)
|
130 |
+
attn_weights = attn_weights + attention_mask
|
131 |
+
|
132 |
+
# upcast attention to fp16
|
133 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to( #torch.float32
|
134 |
+
query_states.dtype
|
135 |
+
)
|
136 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
137 |
+
|
138 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
139 |
+
raise ValueError(
|
140 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
141 |
+
f" {attn_output.size()}"
|
142 |
+
)
|
143 |
+
|
144 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
145 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
146 |
+
|
147 |
+
if self.config.pretraining_tp > 1:
|
148 |
+
attn_output = attn_output.split(
|
149 |
+
self.hidden_size // self.config.pretraining_tp, dim=2
|
150 |
+
)
|
151 |
+
o_proj_slices = self.o_proj.weight.split(
|
152 |
+
self.hidden_size // self.config.pretraining_tp, dim=1
|
153 |
+
)
|
154 |
+
attn_output = sum(
|
155 |
+
[
|
156 |
+
F.linear(attn_output[i], o_proj_slices[i])
|
157 |
+
for i in range(self.config.pretraining_tp)
|
158 |
+
]
|
159 |
+
)
|
160 |
+
else:
|
161 |
+
attn_output = self.o_proj(attn_output)
|
162 |
+
|
163 |
+
if not output_attentions:
|
164 |
+
attn_weights = None
|
165 |
+
|
166 |
+
return attn_output, attn_weights, past_key_value
|
167 |
+
|
168 |
+
|
169 |
+
def llama_pos_shift_attention_forward_flashattn(
|
170 |
+
self,
|
171 |
+
hidden_states: torch.Tensor,
|
172 |
+
attention_mask: Optional[torch.Tensor] = None,
|
173 |
+
position_ids: Optional[torch.LongTensor] = None,
|
174 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
175 |
+
output_attentions: bool = False,
|
176 |
+
use_cache: bool = False,
|
177 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
178 |
+
bsz, q_len, _ = hidden_states.size()
|
179 |
+
|
180 |
+
query_states = self.q_proj(hidden_states)
|
181 |
+
key_states = self.k_proj(hidden_states)
|
182 |
+
value_states = self.v_proj(hidden_states)
|
183 |
+
|
184 |
+
query_states = query_states.view(
|
185 |
+
bsz, q_len, self.num_heads, self.head_dim
|
186 |
+
).transpose(1, 2)
|
187 |
+
key_states = key_states.view(
|
188 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
189 |
+
).transpose(1, 2)
|
190 |
+
value_states = value_states.view(
|
191 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
192 |
+
).transpose(1, 2)
|
193 |
+
|
194 |
+
kv_seq_len = key_states.shape[-2]
|
195 |
+
if past_key_value is not None:
|
196 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
197 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
198 |
+
### Shift Pos: query pos is min(cache_size, idx)
|
199 |
+
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
200 |
+
query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
|
201 |
+
###
|
202 |
+
|
203 |
+
if past_key_value is not None:
|
204 |
+
# reuse k, v, self_attention
|
205 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
206 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
207 |
+
|
208 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
209 |
+
|
210 |
+
### Shift Pos: key pos is the pos in cache
|
211 |
+
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
|
212 |
+
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
|
213 |
+
###
|
214 |
+
|
215 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
216 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
217 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
218 |
+
|
219 |
+
if past_key_value is None:
|
220 |
+
qkv = torch.stack(
|
221 |
+
[query_states, key_states, value_states], dim=2
|
222 |
+
) # [bsz, nh, 3, q_len, hd]
|
223 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
224 |
+
|
225 |
+
key_padding_mask = torch.full((bsz, q_len), True, dtype=torch.bool, device=attention_mask.device)
|
226 |
+
nheads = qkv.shape[-2]
|
227 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
228 |
+
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
229 |
+
x_unpad = rearrange(
|
230 |
+
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
231 |
+
)
|
232 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
233 |
+
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
234 |
+
)
|
235 |
+
output = rearrange(
|
236 |
+
pad_input(
|
237 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
|
238 |
+
),
|
239 |
+
"b s (h d) -> b s h d",
|
240 |
+
h=nheads,
|
241 |
+
)
|
242 |
+
output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
243 |
+
|
244 |
+
attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
|
245 |
+
attn_weights = None
|
246 |
+
else:
|
247 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
|
248 |
+
self.head_dim
|
249 |
+
)
|
250 |
+
|
251 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
252 |
+
raise ValueError(
|
253 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
254 |
+
f" {attn_weights.size()}"
|
255 |
+
)
|
256 |
+
|
257 |
+
if attention_mask is not None:
|
258 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
259 |
+
raise ValueError(
|
260 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
261 |
+
)
|
262 |
+
attn_weights = attn_weights + attention_mask
|
263 |
+
|
264 |
+
# upcast attention to fp16
|
265 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to( #torch.float32
|
266 |
+
query_states.dtype
|
267 |
+
)
|
268 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
269 |
+
|
270 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
271 |
+
raise ValueError(
|
272 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
273 |
+
f" {attn_output.size()}"
|
274 |
+
)
|
275 |
+
|
276 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
277 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
278 |
+
|
279 |
+
if self.config.pretraining_tp > 1:
|
280 |
+
attn_output = attn_output.split(
|
281 |
+
self.hidden_size // self.config.pretraining_tp, dim=2
|
282 |
+
)
|
283 |
+
o_proj_slices = self.o_proj.weight.split(
|
284 |
+
self.hidden_size // self.config.pretraining_tp, dim=1
|
285 |
+
)
|
286 |
+
attn_output = sum(
|
287 |
+
[
|
288 |
+
F.linear(attn_output[i], o_proj_slices[i])
|
289 |
+
for i in range(self.config.pretraining_tp)
|
290 |
+
]
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
attn_output = self.o_proj(attn_output)
|
294 |
+
|
295 |
+
if not output_attentions:
|
296 |
+
attn_weights = None
|
297 |
+
|
298 |
+
return attn_output, attn_weights, past_key_value
|
299 |
+
|
300 |
+
|
301 |
+
def enable_llama_pos_shift_attention(model, use_flash_attn=True):
|
302 |
+
for name, module in reversed(model._modules.items()):
|
303 |
+
if len(list(module.children())) > 0:
|
304 |
+
enable_llama_pos_shift_attention(
|
305 |
+
module,
|
306 |
+
)
|
307 |
+
|
308 |
+
if isinstance(module, LlamaAttention):
|
309 |
+
model._modules[name].forward = types.MethodType(
|
310 |
+
llama_pos_shift_attention_forward_flashattn if use_flash_attn else llama_pos_shift_attention_forward, model._modules[name]
|
311 |
+
)
|
gorilla/streaming_llm/utils.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import argparse
|
3 |
+
from transformers import (
|
4 |
+
AutoTokenizer,
|
5 |
+
AutoModelForCausalLM,
|
6 |
+
)
|
7 |
+
import os.path as osp
|
8 |
+
import ssl
|
9 |
+
import urllib.request
|
10 |
+
import os
|
11 |
+
import json
|
12 |
+
|
13 |
+
|
14 |
+
def parse_args():
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument(
|
17 |
+
"--model_name_or_path", type=str, default="models/llama/llama-7b"
|
18 |
+
)
|
19 |
+
parser.add_argument("--revision", type=str, default="main")
|
20 |
+
parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
|
21 |
+
parser.add_argument("--dataset_name", type=str, default="wikitext")
|
22 |
+
|
23 |
+
parser.add_argument("--task", type=str, default="wikitext-2-raw-v1")
|
24 |
+
parser.add_argument(
|
25 |
+
"--split", type=str, default="test", choices=["validation", "test"]
|
26 |
+
)
|
27 |
+
|
28 |
+
parser.add_argument(
|
29 |
+
"--num_samples",
|
30 |
+
type=int,
|
31 |
+
default=1,
|
32 |
+
)
|
33 |
+
|
34 |
+
parser.add_argument(
|
35 |
+
"--output_dir",
|
36 |
+
type=str,
|
37 |
+
default="outputs/debug",
|
38 |
+
)
|
39 |
+
|
40 |
+
parser.add_argument("--enable_start_recent_kv_cache", action="store_true")
|
41 |
+
parser.add_argument("--start_size", type=int, default=1)
|
42 |
+
parser.add_argument("--recent_size", type=int, default=255)
|
43 |
+
parser.add_argument("--enable_pos_shift", action="store_true")
|
44 |
+
|
45 |
+
parser.add_argument("--num_eval_tokens", type=int, default=None)
|
46 |
+
|
47 |
+
args = parser.parse_args()
|
48 |
+
return args
|
49 |
+
|
50 |
+
|
51 |
+
def load(model_name_or_path):
|
52 |
+
print(f"Loading model from {model_name_or_path} ...")
|
53 |
+
# however, tensor parallel for running falcon will occur bugs
|
54 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
55 |
+
model_name_or_path,
|
56 |
+
trust_remote_code=True,
|
57 |
+
)
|
58 |
+
model = AutoModelForCausalLM.from_pretrained(
|
59 |
+
model_name_or_path,
|
60 |
+
device_map="auto",
|
61 |
+
torch_dtype=torch.float16,
|
62 |
+
trust_remote_code=True,
|
63 |
+
)
|
64 |
+
if tokenizer.pad_token_id is None:
|
65 |
+
if tokenizer.eos_token_id is not None:
|
66 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
67 |
+
else:
|
68 |
+
tokenizer.pad_token_id = 0
|
69 |
+
|
70 |
+
model.eval()
|
71 |
+
|
72 |
+
return model, tokenizer
|
73 |
+
|
74 |
+
|
75 |
+
def download_url(url: str, folder="folder"):
|
76 |
+
"""
|
77 |
+
Downloads the content of an url to a folder. Modified from \
|
78 |
+
https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric
|
79 |
+
|
80 |
+
Args:
|
81 |
+
url (string): The url of target file.
|
82 |
+
folder (string): The target folder.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
string: File path of downloaded files.
|
86 |
+
"""
|
87 |
+
|
88 |
+
file = url.rpartition("/")[2]
|
89 |
+
file = file if file[0] == "?" else file.split("?")[0]
|
90 |
+
path = osp.join(folder, file)
|
91 |
+
if osp.exists(path):
|
92 |
+
print(f"File {file} exists, use existing file.")
|
93 |
+
return path
|
94 |
+
|
95 |
+
print(f"Downloading {url}")
|
96 |
+
os.makedirs(folder, exist_ok=True)
|
97 |
+
ctx = ssl._create_unverified_context()
|
98 |
+
data = urllib.request.urlopen(url, context=ctx)
|
99 |
+
with open(path, "wb") as f:
|
100 |
+
f.write(data.read())
|
101 |
+
|
102 |
+
return path
|
103 |
+
|
104 |
+
|
105 |
+
def load_jsonl(
|
106 |
+
file_path,
|
107 |
+
):
|
108 |
+
list_data_dict = []
|
109 |
+
with open(file_path, "r") as f:
|
110 |
+
for line in f:
|
111 |
+
list_data_dict.append(json.loads(line))
|
112 |
+
return list_data_dict
|
gorilla/style.css
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
h1 {
|
2 |
+
text-align: center;
|
3 |
+
}
|
4 |
+
|
5 |
+
#duplicate-button {
|
6 |
+
margin: auto;
|
7 |
+
color: white;
|
8 |
+
background: #1565c0;
|
9 |
+
border-radius: 100vh;
|
10 |
+
}
|
11 |
+
|
12 |
+
.contain {
|
13 |
+
max-width: 900px;
|
14 |
+
margin: auto;
|
15 |
+
padding-top: 1.5rem;
|
16 |
+
}
|
gorilla/supervised-fine-tune-qlora.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import copy
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import logging
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import Dict, Optional, Sequence
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import transformers
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
from transformers import Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig
|
15 |
+
from llama_attn_replace_sft import replace_llama_attn
|
16 |
+
from peft import LoraConfig, get_peft_model
|
17 |
+
from torch.distributed import barrier
|
18 |
+
|
19 |
+
IGNORE_INDEX = -100
|
20 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
21 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
22 |
+
DEFAULT_BOS_TOKEN = "<s>"
|
23 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
24 |
+
|
25 |
+
def _make_r_io_base(f, mode: str):
|
26 |
+
if not isinstance(f, io.IOBase):
|
27 |
+
f = open(f, mode=mode)
|
28 |
+
return f
|
29 |
+
|
30 |
+
def jload(f, mode="r"):
|
31 |
+
"""Load a .json file into a dictionary."""
|
32 |
+
f = _make_r_io_base(f, mode)
|
33 |
+
jdict = json.load(f)
|
34 |
+
f.close()
|
35 |
+
return jdict
|
36 |
+
|
37 |
+
PROMPT_DICT = {
|
38 |
+
"prompt_input": (
|
39 |
+
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
40 |
+
"Write a response that appropriately completes the request.\n\n"
|
41 |
+
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
|
42 |
+
),
|
43 |
+
"prompt_no_input": (
|
44 |
+
"Below is an instruction that describes a task. "
|
45 |
+
"Write a response that appropriately completes the request.\n\n"
|
46 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
47 |
+
),
|
48 |
+
"prompt_no_input_llama2":(
|
49 |
+
"<s>[INST] <<SYS>>\n"
|
50 |
+
"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
51 |
+
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
|
52 |
+
"<</SYS>> \n\n {instruction} [/INST]"
|
53 |
+
),
|
54 |
+
"prompt_input_llama2": (
|
55 |
+
"<s>[INST] <<SYS>>\n"
|
56 |
+
"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
57 |
+
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
|
58 |
+
"<</SYS>> \n\n {instruction} \n{input} [/INST]"
|
59 |
+
)
|
60 |
+
}
|
61 |
+
|
62 |
+
|
63 |
+
@dataclass
|
64 |
+
class ModelArguments:
|
65 |
+
model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
|
66 |
+
model_type: Optional[str] = field(default="llama")
|
67 |
+
|
68 |
+
|
69 |
+
@dataclass
|
70 |
+
class DataArguments:
|
71 |
+
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
|
72 |
+
|
73 |
+
|
74 |
+
@dataclass
|
75 |
+
class TrainingArguments(transformers.TrainingArguments):
|
76 |
+
cache_dir: Optional[str] = field(default=None)
|
77 |
+
optim: str = field(default="adamw_torch")
|
78 |
+
model_max_length: int = field(
|
79 |
+
default=8192,
|
80 |
+
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
|
81 |
+
)
|
82 |
+
use_flash_attn: bool = field(
|
83 |
+
default=True,
|
84 |
+
metadata={"help": "Whether use flash attention for training."},
|
85 |
+
)
|
86 |
+
use_full_attn: bool = field(
|
87 |
+
default=False,
|
88 |
+
metadata={"help": "Whether to use plain, full-attention for training."},
|
89 |
+
)
|
90 |
+
low_rank_training: bool = field(
|
91 |
+
default=True,
|
92 |
+
metadata={"help": "Whether use low rank adaptation for training."},
|
93 |
+
)
|
94 |
+
trainable_params: str = field(
|
95 |
+
default="embed,norm",
|
96 |
+
metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
|
97 |
+
)
|
98 |
+
|
99 |
+
def smart_tokenizer_and_embedding_resize(
|
100 |
+
special_tokens_dict: Dict,
|
101 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
102 |
+
model: transformers.PreTrainedModel,
|
103 |
+
):
|
104 |
+
"""Resize tokenizer and embedding.
|
105 |
+
|
106 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
107 |
+
"""
|
108 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
109 |
+
model.resize_token_embeddings(len(tokenizer))
|
110 |
+
|
111 |
+
if num_new_tokens > 0:
|
112 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
113 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
114 |
+
|
115 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
116 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
117 |
+
|
118 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
119 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
120 |
+
|
121 |
+
|
122 |
+
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
123 |
+
"""Tokenize a list of strings."""
|
124 |
+
tokenized_list = [
|
125 |
+
tokenizer(
|
126 |
+
text,
|
127 |
+
return_tensors="pt",
|
128 |
+
padding="longest",
|
129 |
+
max_length=tokenizer.model_max_length,
|
130 |
+
truncation=True,
|
131 |
+
)
|
132 |
+
for text in strings
|
133 |
+
]
|
134 |
+
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
135 |
+
input_ids_lens = labels_lens = [
|
136 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
137 |
+
]
|
138 |
+
return dict(
|
139 |
+
input_ids=input_ids,
|
140 |
+
labels=labels,
|
141 |
+
input_ids_lens=input_ids_lens,
|
142 |
+
labels_lens=labels_lens,
|
143 |
+
)
|
144 |
+
|
145 |
+
|
146 |
+
def preprocess(
|
147 |
+
sources: Sequence[str],
|
148 |
+
targets: Sequence[str],
|
149 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
150 |
+
) -> Dict:
|
151 |
+
"""Preprocess the data by tokenizing."""
|
152 |
+
examples = [s + t for s, t in zip(sources, targets)]
|
153 |
+
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
|
154 |
+
input_ids = examples_tokenized["input_ids"]
|
155 |
+
labels = copy.deepcopy(input_ids)
|
156 |
+
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
157 |
+
label[:source_len] = IGNORE_INDEX
|
158 |
+
return dict(input_ids=input_ids, labels=labels)
|
159 |
+
|
160 |
+
|
161 |
+
class SupervisedDataset(Dataset):
|
162 |
+
"""Dataset for supervised fine-tuning."""
|
163 |
+
|
164 |
+
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
|
165 |
+
super(SupervisedDataset, self).__init__()
|
166 |
+
logging.warning("Loading line item and alm data...")
|
167 |
+
list_data_dict = jload(data_path)
|
168 |
+
|
169 |
+
logging.warning("Formatting inputs...")
|
170 |
+
|
171 |
+
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input_llama2"], PROMPT_DICT["prompt_no_input_llama2"]
|
172 |
+
sources = [
|
173 |
+
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
|
174 |
+
for example in list_data_dict
|
175 |
+
]
|
176 |
+
|
177 |
+
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
|
178 |
+
|
179 |
+
logging.warning("Tokenizing inputs... This may take some time...")
|
180 |
+
data_dict = preprocess(sources, targets, tokenizer)
|
181 |
+
|
182 |
+
self.input_ids = data_dict["input_ids"]
|
183 |
+
self.labels = data_dict["labels"]
|
184 |
+
|
185 |
+
def __len__(self):
|
186 |
+
return len(self.input_ids)
|
187 |
+
|
188 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
189 |
+
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
190 |
+
|
191 |
+
|
192 |
+
@dataclass
|
193 |
+
class DataCollatorForSupervisedDataset(object):
|
194 |
+
"""Collate examples for supervised fine-tuning."""
|
195 |
+
|
196 |
+
tokenizer: transformers.PreTrainedTokenizer
|
197 |
+
|
198 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
199 |
+
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
200 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
201 |
+
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
202 |
+
)
|
203 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
204 |
+
return dict(
|
205 |
+
input_ids=input_ids,
|
206 |
+
labels=labels,
|
207 |
+
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
208 |
+
)
|
209 |
+
|
210 |
+
|
211 |
+
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
|
212 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
213 |
+
train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
|
214 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
215 |
+
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
|
216 |
+
|
217 |
+
|
218 |
+
def train():
|
219 |
+
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
220 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
221 |
+
|
222 |
+
# NOTE: May expand supported model types in the future
|
223 |
+
# if model_args.model_type == "gpt-neox":
|
224 |
+
# replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
|
225 |
+
# else:
|
226 |
+
replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
|
227 |
+
|
228 |
+
# Set RoPE scaling factor
|
229 |
+
config = transformers.AutoConfig.from_pretrained(
|
230 |
+
model_args.model_name_or_path,
|
231 |
+
cache_dir=training_args.cache_dir,
|
232 |
+
)
|
233 |
+
|
234 |
+
orig_rope_scaling = getattr(config, "rope_scaling", {"factor": 1})
|
235 |
+
# Check if orig_rope_scaling is a dictionary before accessing its "get" method
|
236 |
+
if isinstance(orig_rope_scaling, dict):
|
237 |
+
orig_rope_scaling_factor = orig_rope_scaling.get("factor", 1)
|
238 |
+
else:
|
239 |
+
orig_rope_scaling_factor = 1 #orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
|
240 |
+
|
241 |
+
orig_ctx_len = getattr(config, "max_position_embeddings", None)
|
242 |
+
if orig_ctx_len:
|
243 |
+
orig_ctx_len *= orig_rope_scaling_factor
|
244 |
+
if training_args.model_max_length > orig_ctx_len:
|
245 |
+
scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
|
246 |
+
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
247 |
+
|
248 |
+
# Load model and tokenizer
|
249 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
250 |
+
model_args.model_name_or_path,
|
251 |
+
config=config,
|
252 |
+
cache_dir=training_args.cache_dir,
|
253 |
+
torch_dtype=torch.bfloat16,
|
254 |
+
quantization_config=BitsAndBytesConfig(
|
255 |
+
load_in_4bit=True,
|
256 |
+
llm_int8_threshold=6.0,
|
257 |
+
llm_int8_has_fp16_weight=False,
|
258 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
259 |
+
bnb_4bit_use_double_quant=True,
|
260 |
+
bnb_4bit_quant_type="nf4",
|
261 |
+
),
|
262 |
+
)
|
263 |
+
|
264 |
+
for param in model.parameters():
|
265 |
+
param.requires_grad = False # freeze the model - train adapters later
|
266 |
+
if param.ndim == 1:
|
267 |
+
# cast the small parameters (e.g. layernorm) to fp32 for stability
|
268 |
+
param.data = param.data.to(torch.float16) #torch.float32
|
269 |
+
|
270 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
271 |
+
model_args.model_name_or_path,
|
272 |
+
cache_dir=training_args.cache_dir,
|
273 |
+
model_max_length=training_args.model_max_length,
|
274 |
+
padding_side="right",
|
275 |
+
use_fast=True,
|
276 |
+
)
|
277 |
+
|
278 |
+
special_tokens_dict = dict()
|
279 |
+
if tokenizer.pad_token is None:
|
280 |
+
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
|
281 |
+
if tokenizer.eos_token is None:
|
282 |
+
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
|
283 |
+
if tokenizer.bos_token is None:
|
284 |
+
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
|
285 |
+
if tokenizer.unk_token is None:
|
286 |
+
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
|
287 |
+
|
288 |
+
smart_tokenizer_and_embedding_resize(
|
289 |
+
special_tokens_dict=special_tokens_dict,
|
290 |
+
tokenizer=tokenizer,
|
291 |
+
model=model,
|
292 |
+
)
|
293 |
+
|
294 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
|
295 |
+
|
296 |
+
if training_args.low_rank_training:
|
297 |
+
if model_args.model_type == "gpt-neox":
|
298 |
+
# added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
|
299 |
+
targets = ["query_key_value", "dense"]
|
300 |
+
else:
|
301 |
+
targets=["q_proj", "k_proj", "v_proj", "o_proj"]
|
302 |
+
|
303 |
+
config = LoraConfig(
|
304 |
+
r=8,
|
305 |
+
lora_alpha=16,
|
306 |
+
target_modules=targets,
|
307 |
+
lora_dropout=0,
|
308 |
+
bias="none",
|
309 |
+
task_type="CAUSAL_LM",
|
310 |
+
)
|
311 |
+
model = get_peft_model(model, config)
|
312 |
+
# enable trainable params
|
313 |
+
[p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
|
314 |
+
|
315 |
+
class CastOutputToFloat(nn.Sequential):
|
316 |
+
def forward(self, x):
|
317 |
+
return super().forward(x).to(torch.float16) #torch.float32
|
318 |
+
|
319 |
+
model.lm_head = CastOutputToFloat(model.lm_head)
|
320 |
+
|
321 |
+
# Verifying the datatypes.
|
322 |
+
dtypes = {}
|
323 |
+
for _, p in model.named_parameters():
|
324 |
+
dtype = p.dtype
|
325 |
+
if dtype not in dtypes:
|
326 |
+
dtypes[dtype] = 0
|
327 |
+
dtypes[dtype] += p.numel()
|
328 |
+
total = 0
|
329 |
+
for k, v in dtypes.items():
|
330 |
+
total += v
|
331 |
+
for k, v in dtypes.items():
|
332 |
+
print(k, v, v / total)
|
333 |
+
|
334 |
+
model.config.use_cache = True # required for gradient checkpointing
|
335 |
+
model.enable_input_require_grads() # required for gradient checkpointing
|
336 |
+
model.gradient_checkpointing_enable() # enable gradient checkpointing
|
337 |
+
|
338 |
+
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
|
339 |
+
trainer.train()
|
340 |
+
trainer.save_state()
|
341 |
+
trainer.save_model(output_dir=training_args.output_dir)
|
342 |
+
|
343 |
+
|
344 |
+
if __name__ == "__main__":
|
345 |
+
train()
|
gorilla/supervised-fine-tune.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import copy
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import logging
|
7 |
+
# %%
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
# print('Loading line item and alm data')
|
11 |
+
# df_i = pd.read_csv('/home/tosi-n/ark/data/jack_line_item_ner_task.csv', sep='\t')[['context', 'instruction', 'response']]
|
12 |
+
# df_ii = pd.read_csv('/home/tosi-n/ark/data/alm_task_data.csv')[['context', 'instruction', 'response']]
|
13 |
+
# df = pd.concat([df_i, df_ii], ignore_index=True)
|
14 |
+
# # rename columns context and response to input and output
|
15 |
+
# df = df.rename(columns={'context':'input', 'response':'output'})
|
16 |
+
|
17 |
+
# # Replace NoneType with empty string
|
18 |
+
# df = df.fillna('')
|
19 |
+
# # produce a list of dictionaries
|
20 |
+
# list_data_dict = df.to_dict('records')
|
21 |
+
# import json
|
22 |
+
# with open('/home/tosi-n/ark/data/line_item_and_alm_data.json', 'w') as f:
|
23 |
+
# json.dump(list_data_dict, f)
|
24 |
+
|
25 |
+
# %%
|
26 |
+
from dataclasses import dataclass, field
|
27 |
+
from typing import Dict, Optional, Sequence
|
28 |
+
|
29 |
+
import torch
|
30 |
+
import transformers
|
31 |
+
from torch.utils.data import Dataset
|
32 |
+
from transformers import Trainer, DataCollatorForLanguageModeling
|
33 |
+
from llama_attn_replace_sft import replace_llama_attn
|
34 |
+
from peft import LoraConfig, get_peft_model
|
35 |
+
from torch.distributed import barrier
|
36 |
+
|
37 |
+
IGNORE_INDEX = -100
|
38 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
39 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
40 |
+
DEFAULT_BOS_TOKEN = "<s>"
|
41 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
42 |
+
|
43 |
+
def _make_r_io_base(f, mode: str):
|
44 |
+
if not isinstance(f, io.IOBase):
|
45 |
+
f = open(f, mode=mode)
|
46 |
+
return f
|
47 |
+
|
48 |
+
def jload(f, mode="r"):
|
49 |
+
"""Load a .json file into a dictionary."""
|
50 |
+
f = _make_r_io_base(f, mode)
|
51 |
+
jdict = json.load(f)
|
52 |
+
f.close()
|
53 |
+
return jdict
|
54 |
+
|
55 |
+
PROMPT_DICT = {
|
56 |
+
"prompt_input": (
|
57 |
+
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
58 |
+
"Write a response that appropriately completes the request.\n\n"
|
59 |
+
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
|
60 |
+
),
|
61 |
+
"prompt_no_input": (
|
62 |
+
"Below is an instruction that describes a task. "
|
63 |
+
"Write a response that appropriately completes the request.\n\n"
|
64 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
65 |
+
),
|
66 |
+
"prompt_no_input_llama2":(
|
67 |
+
"<s>[INST] <<SYS>>\n"
|
68 |
+
"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
69 |
+
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
|
70 |
+
"<</SYS>> \n\n {instruction} [/INST]"
|
71 |
+
),
|
72 |
+
"prompt_input_llama2": (
|
73 |
+
"<s>[INST] <<SYS>>\n"
|
74 |
+
"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
75 |
+
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
|
76 |
+
"<</SYS>> \n\n {instruction} \n{input} [/INST]"
|
77 |
+
)
|
78 |
+
}
|
79 |
+
|
80 |
+
|
81 |
+
@dataclass
|
82 |
+
class ModelArguments:
|
83 |
+
model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
|
84 |
+
model_type: Optional[str] = field(default="llama")
|
85 |
+
|
86 |
+
|
87 |
+
@dataclass
|
88 |
+
class DataArguments:
|
89 |
+
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
|
90 |
+
|
91 |
+
|
92 |
+
@dataclass
|
93 |
+
class TrainingArguments(transformers.TrainingArguments):
|
94 |
+
cache_dir: Optional[str] = field(default=None)
|
95 |
+
optim: str = field(default="adamw_torch")
|
96 |
+
model_max_length: int = field(
|
97 |
+
default=8192 * 4,
|
98 |
+
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
|
99 |
+
)
|
100 |
+
use_flash_attn: bool = field(
|
101 |
+
default=True,
|
102 |
+
metadata={"help": "Whether use flash attention for training."},
|
103 |
+
)
|
104 |
+
use_full_attn: bool = field(
|
105 |
+
default=False,
|
106 |
+
metadata={"help": "Whether to use plain, full-attention for training."},
|
107 |
+
)
|
108 |
+
low_rank_training: bool = field(
|
109 |
+
default=True,
|
110 |
+
metadata={"help": "Whether use low rank adaptation for training."},
|
111 |
+
)
|
112 |
+
trainable_params: str = field(
|
113 |
+
default="embed,norm",
|
114 |
+
metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
|
115 |
+
)
|
116 |
+
|
117 |
+
def smart_tokenizer_and_embedding_resize(
|
118 |
+
special_tokens_dict: Dict,
|
119 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
120 |
+
model: transformers.PreTrainedModel,
|
121 |
+
):
|
122 |
+
"""Resize tokenizer and embedding.
|
123 |
+
|
124 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
125 |
+
"""
|
126 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
127 |
+
model.resize_token_embeddings(len(tokenizer))
|
128 |
+
|
129 |
+
if num_new_tokens > 0:
|
130 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
131 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
132 |
+
|
133 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
134 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
135 |
+
|
136 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
137 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
138 |
+
|
139 |
+
|
140 |
+
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
141 |
+
"""Tokenize a list of strings."""
|
142 |
+
tokenized_list = [
|
143 |
+
tokenizer(
|
144 |
+
text,
|
145 |
+
return_tensors="pt",
|
146 |
+
padding="longest",
|
147 |
+
max_length=tokenizer.model_max_length,
|
148 |
+
truncation=True,
|
149 |
+
)
|
150 |
+
for text in strings
|
151 |
+
]
|
152 |
+
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
153 |
+
input_ids_lens = labels_lens = [
|
154 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
155 |
+
]
|
156 |
+
return dict(
|
157 |
+
input_ids=input_ids,
|
158 |
+
labels=labels,
|
159 |
+
input_ids_lens=input_ids_lens,
|
160 |
+
labels_lens=labels_lens,
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
def preprocess(
|
165 |
+
sources: Sequence[str],
|
166 |
+
targets: Sequence[str],
|
167 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
168 |
+
) -> Dict:
|
169 |
+
"""Preprocess the data by tokenizing."""
|
170 |
+
examples = [s + t for s, t in zip(sources, targets)]
|
171 |
+
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
|
172 |
+
input_ids = examples_tokenized["input_ids"]
|
173 |
+
labels = copy.deepcopy(input_ids)
|
174 |
+
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
175 |
+
label[:source_len] = IGNORE_INDEX
|
176 |
+
return dict(input_ids=input_ids, labels=labels)
|
177 |
+
|
178 |
+
|
179 |
+
class SupervisedDataset(Dataset):
|
180 |
+
"""Dataset for supervised fine-tuning."""
|
181 |
+
|
182 |
+
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
|
183 |
+
super(SupervisedDataset, self).__init__()
|
184 |
+
logging.warning("Loading line item and alm data...")
|
185 |
+
list_data_dict = jload(data_path)
|
186 |
+
|
187 |
+
logging.warning("Formatting inputs...")
|
188 |
+
|
189 |
+
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input_llama2"], PROMPT_DICT["prompt_no_input_llama2"]
|
190 |
+
sources = [
|
191 |
+
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
|
192 |
+
for example in list_data_dict
|
193 |
+
]
|
194 |
+
|
195 |
+
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
|
196 |
+
|
197 |
+
logging.warning("Tokenizing inputs... This may take some time...")
|
198 |
+
data_dict = preprocess(sources, targets, tokenizer)
|
199 |
+
|
200 |
+
self.input_ids = data_dict["input_ids"]
|
201 |
+
self.labels = data_dict["labels"]
|
202 |
+
|
203 |
+
def __len__(self):
|
204 |
+
return len(self.input_ids)
|
205 |
+
|
206 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
207 |
+
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
208 |
+
|
209 |
+
|
210 |
+
@dataclass
|
211 |
+
class DataCollatorForSupervisedDataset(object):
|
212 |
+
"""Collate examples for supervised fine-tuning."""
|
213 |
+
|
214 |
+
tokenizer: transformers.PreTrainedTokenizer
|
215 |
+
|
216 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
217 |
+
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
218 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
219 |
+
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
220 |
+
)
|
221 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
222 |
+
return dict(
|
223 |
+
input_ids=input_ids,
|
224 |
+
labels=labels,
|
225 |
+
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
226 |
+
)
|
227 |
+
|
228 |
+
|
229 |
+
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
|
230 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
231 |
+
train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
|
232 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
233 |
+
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
|
234 |
+
|
235 |
+
|
236 |
+
def train():
|
237 |
+
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
238 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
239 |
+
|
240 |
+
# NOTE: May expand supported model types in the future
|
241 |
+
# if model_args.model_type == "gpt-neox":
|
242 |
+
# replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
|
243 |
+
# else:
|
244 |
+
replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
|
245 |
+
|
246 |
+
# Set RoPE scaling factor
|
247 |
+
config = transformers.AutoConfig.from_pretrained(
|
248 |
+
model_args.model_name_or_path,
|
249 |
+
cache_dir=training_args.cache_dir,
|
250 |
+
)
|
251 |
+
|
252 |
+
orig_rope_scaling = getattr(config, "rope_scaling", {"factor": 1})
|
253 |
+
# Check if orig_rope_scaling is a dictionary before accessing its "get" method
|
254 |
+
if isinstance(orig_rope_scaling, dict):
|
255 |
+
orig_rope_scaling_factor = orig_rope_scaling.get("factor", 1)
|
256 |
+
else:
|
257 |
+
orig_rope_scaling_factor = 1 #orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
|
258 |
+
|
259 |
+
orig_ctx_len = getattr(config, "max_position_embeddings", None)
|
260 |
+
if orig_ctx_len:
|
261 |
+
orig_ctx_len *= orig_rope_scaling_factor
|
262 |
+
if training_args.model_max_length > orig_ctx_len:
|
263 |
+
scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
|
264 |
+
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
265 |
+
|
266 |
+
# Load model and tokenizer
|
267 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
268 |
+
model_args.model_name_or_path,
|
269 |
+
config=config,
|
270 |
+
cache_dir=training_args.cache_dir,
|
271 |
+
torch_dtype=torch.bfloat16,
|
272 |
+
)
|
273 |
+
|
274 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
275 |
+
model_args.model_name_or_path,
|
276 |
+
cache_dir=training_args.cache_dir,
|
277 |
+
model_max_length=training_args.model_max_length,
|
278 |
+
padding_side="right",
|
279 |
+
use_fast=True,
|
280 |
+
)
|
281 |
+
|
282 |
+
special_tokens_dict = dict()
|
283 |
+
if tokenizer.pad_token is None:
|
284 |
+
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
|
285 |
+
if tokenizer.eos_token is None:
|
286 |
+
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
|
287 |
+
if tokenizer.bos_token is None:
|
288 |
+
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
|
289 |
+
if tokenizer.unk_token is None:
|
290 |
+
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
|
291 |
+
|
292 |
+
smart_tokenizer_and_embedding_resize(
|
293 |
+
special_tokens_dict=special_tokens_dict,
|
294 |
+
tokenizer=tokenizer,
|
295 |
+
model=model,
|
296 |
+
)
|
297 |
+
|
298 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
|
299 |
+
|
300 |
+
if training_args.low_rank_training:
|
301 |
+
if model_args.model_type == "gpt-neox":
|
302 |
+
# added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
|
303 |
+
targets = ["query_key_value", "dense"]
|
304 |
+
else:
|
305 |
+
targets=["q_proj", "k_proj", "v_proj", "o_proj"]
|
306 |
+
|
307 |
+
config = LoraConfig(
|
308 |
+
r=8,
|
309 |
+
lora_alpha=16,
|
310 |
+
target_modules=targets,
|
311 |
+
lora_dropout=0,
|
312 |
+
bias="none",
|
313 |
+
task_type="CAUSAL_LM",
|
314 |
+
)
|
315 |
+
model = get_peft_model(model, config)
|
316 |
+
# enable trainable params
|
317 |
+
[p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
|
318 |
+
|
319 |
+
model.config.use_cache = False # required for gradient checkpointing
|
320 |
+
model.enable_input_require_grads() # required for gradient checkpointing
|
321 |
+
model.gradient_checkpointing_enable() # enable gradient checkpointing
|
322 |
+
|
323 |
+
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
|
324 |
+
trainer.train()
|
325 |
+
trainer.save_state()
|
326 |
+
trainer.save_model(output_dir=training_args.output_dir)
|
327 |
+
|
328 |
+
|
329 |
+
if __name__ == "__main__":
|
330 |
+
train()
|